Commit 520eb77

LoRA Trainer: LoRA training node in weight adapter scheme (comfyanonymous#8446)
1 parent 5bf69bd commit 520eb77

12 files changed: +948 −23 lines

comfy/comfy_types/node_typing.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -37,6 +37,8 @@ class IO(StrEnum):
     CONTROL_NET = "CONTROL_NET"
     VAE = "VAE"
     MODEL = "MODEL"
+    LORA_MODEL = "LORA_MODEL"
+    LOSS_MAP = "LOSS_MAP"
     CLIP_VISION = "CLIP_VISION"
     CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
     STYLE_MODEL = "STYLE_MODEL"
```

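The two new IO types give nodes a typed way to pass a LoRA model object (LORA_MODEL) and a loss map (LOSS_MAP) between each other. A hedged sketch of a hypothetical custom node declaring the new type (the node class, its name, and its behavior are illustrative and not part of this commit; registration via NODE_CLASS_MAPPINGS is omitted):

```python
from comfy.comfy_types.node_typing import IO

class LoraModelPassthrough:
    """Hypothetical example node: accepts a LORA_MODEL input and returns it unchanged."""

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"lora": (IO.LORA_MODEL,)}}

    RETURN_TYPES = (IO.LORA_MODEL,)
    FUNCTION = "passthrough"
    CATEGORY = "examples"

    def passthrough(self, lora):
        return (lora,)
```
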
comfy/ldm/modules/attention.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -753,7 +753,7 @@ def forward(self, x, context=None, transformer_options={}):
             for p in patch:
                 n = p(n, extra_options)
 
-        x += n
+        x = n + x
         if "middle_patch" in transformer_patches:
             patch = transformer_patches["middle_patch"]
             for p in patch:
@@ -793,12 +793,12 @@ def forward(self, x, context=None, transformer_options={}):
             for p in patch:
                 n = p(n, extra_options)
 
-        x += n
+        x = n + x
         if self.is_res:
             x_skip = x
         x = self.ff(self.norm3(x))
         if self.is_res:
-            x += x_skip
+            x = x_skip + x
 
         return x
```

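These attention-block edits swap in-place accumulation for out-of-place addition. The distinction only matters once gradients flow through the UNet, as they do when training a LoRA against it: autograd raises a RuntimeError if a tensor it saved for the backward pass is later modified in place. A minimal standalone illustration of the failure mode (not code from this commit):

```python
import torch

w = torch.randn(4, requires_grad=True)
x = torch.exp(w)   # autograd saves exp's output for the backward pass
n = torch.ones(4)

# x += n           # in-place: backward() would raise "one of the variables
#                  # needed for gradient computation has been modified by an
#                  # inplace operation"
x = n + x          # out-of-place, the pattern this commit switches to

x.sum().backward()
print(w.grad)      # gradient of sum(exp(w) + 1) w.r.t. w, i.e. exp(w)
```
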
comfy/model_patcher.py

Lines changed: 11 additions & 8 deletions
```diff
@@ -17,23 +17,26 @@
 """
 
 from __future__ import annotations
-from typing import Optional, Callable
-import torch
+
+import collections
 import copy
 import inspect
 import logging
-import uuid
-import collections
 import math
+import uuid
+from typing import Callable, Optional
+
+import torch
 
-import comfy.utils
 import comfy.float
-import comfy.model_management
-import comfy.lora
 import comfy.hooks
+import comfy.lora
+import comfy.model_management
 import comfy.patcher_extension
-from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
+import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
+from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
+
 
 def string_to_seed(data):
     crc = 0xFFFFFFFF
```

comfy/sd.py

Lines changed: 22 additions & 1 deletion
```diff
@@ -1081,7 +1081,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     return (model_patcher, clip, vae, clipvision)
 
 
-def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
+def load_diffusion_model_state_dict(sd, model_options={}):
+    """
+    Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
+
+    Args:
+        sd (dict): State dictionary containing model weights and configuration
+        model_options (dict, optional): Additional options for model loading. Supports:
+            - dtype: Override model data type
+            - custom_operations: Custom model operations
+            - fp8_optimizations: Enable FP8 optimizations
+
+    Returns:
+        ModelPatcher: A wrapped model instance that handles device management and weight loading.
+        Returns None if the model configuration cannot be detected.
+
+    The function:
+    1. Detects and handles different model formats (regular, diffusers, mmdit)
+    2. Configures model dtype based on parameters and device capabilities
+    3. Handles weight conversion and device placement
+    4. Manages model optimization settings
+    5. Loads weights and returns a device-managed model instance
+    """
     dtype = model_options.get("dtype", None)
 
     #Allow loading unets from checkpoint files
```

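As a quick orientation for the newly documented entry point, a hedged usage sketch (the checkpoint path is illustrative; any UNet-only state dict in diffusers or regular format should work):

```python
import torch
import comfy.sd
import comfy.utils

# Illustrative path, not part of the commit.
sd = comfy.utils.load_torch_file("models/unet/example_unet.safetensors")

model_patcher = comfy.sd.load_diffusion_model_state_dict(
    sd, model_options={"dtype": torch.float16}
)
if model_patcher is None:
    raise RuntimeError("Could not detect a model configuration from this state dict")
```
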
comfy/weight_adapter/__init__.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -1,4 +1,4 @@
-from .base import WeightAdapterBase
+from .base import WeightAdapterBase, WeightAdapterTrainBase
 from .lora import LoRAAdapter
 from .loha import LoHaAdapter
 from .lokr import LoKrAdapter
@@ -15,3 +15,9 @@
     OFTAdapter,
     BOFTAdapter,
 ]
+
+__all__ = [
+    "WeightAdapterBase",
+    "WeightAdapterTrainBase",
+    "adapters"
+] + [a.__name__ for a in adapters]
```

comfy/weight_adapter/base.py

Lines changed: 33 additions & 2 deletions
```diff
@@ -12,12 +12,20 @@ class WeightAdapterBase:
     weights: list[torch.Tensor]
 
     @classmethod
-    def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
+    def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]:
         raise NotImplementedError
 
     def to_train(self) -> "WeightAdapterTrainBase":
         raise NotImplementedError
 
+    @classmethod
+    def create_train(cls, weight, *args) -> "WeightAdapterTrainBase":
+        """
+        weight: The original weight tensor to be modified.
+        *args: Additional arguments for configuration, such as rank, alpha etc.
+        """
+        raise NotImplementedError
+
     def calculate_weight(
         self,
         weight,
@@ -33,10 +41,22 @@ def calculate_weight(
 
 
 class WeightAdapterTrainBase(nn.Module):
+    # We follow the scheme of PR #7032
     def __init__(self):
         super().__init__()
 
-    # [TODO] Collaborate with LoRA training PR #7032
+    def __call__(self, w):
+        """
+        w: The original weight tensor to be modified.
+        """
+        raise NotImplementedError
+
+    def passive_memory_usage(self):
+        raise NotImplementedError("passive_memory_usage is not implemented")
+
+    def move_to(self, device):
+        self.to(device)
+        return self.passive_memory_usage()
 
 
 def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
@@ -102,3 +122,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
     padded_tensor[new_slices] = tensor[orig_slices]
 
     return padded_tensor
+
+
+def tucker_weight_from_conv(up, down, mid):
+    up = up.reshape(up.size(0), up.size(1))
+    down = down.reshape(down.size(0), down.size(1))
+    return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down)
+
+
+def tucker_weight(wa, wb, t):
+    temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
+    return torch.einsum("i j ..., i r -> r j ...", temp, wa)
```

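The new tucker_weight_from_conv helper rebuilds a full convolution weight delta from the (up, down, mid) factors used by Tucker-decomposed conv LoRAs, with the spatial kernel carried by the mid core. A small shape-check sketch (sizes are arbitrary, not from the commit):

```python
import torch
from comfy.weight_adapter.base import tucker_weight_from_conv

out_ch, in_ch, rank, k = 8, 4, 2, 3      # arbitrary example sizes
up = torch.randn(out_ch, rank, 1, 1)     # lora_up weight of a 1x1 conv
down = torch.randn(rank, in_ch, 1, 1)    # lora_down weight of a 1x1 conv
mid = torch.randn(rank, rank, k, k)      # mid core holding the k x k spatial kernel

diff = tucker_weight_from_conv(up, down, mid)
print(diff.shape)                        # torch.Size([8, 4, 3, 3]): a full conv2d weight delta
```
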
comfy/weight_adapter/lora.py

Lines changed: 65 additions & 1 deletion
```diff
@@ -3,7 +3,56 @@
 
 import torch
 import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
+from .base import (
+    WeightAdapterBase,
+    WeightAdapterTrainBase,
+    weight_decompose,
+    pad_tensor_to_shape,
+    tucker_weight_from_conv,
+)
+
+
+class LoraDiff(WeightAdapterTrainBase):
+    def __init__(self, weights):
+        super().__init__()
+        mat1, mat2, alpha, mid, dora_scale, reshape = weights
+        out_dim, rank = mat1.shape[0], mat1.shape[1]
+        rank, in_dim = mat2.shape[0], mat2.shape[1]
+        if mid is not None:
+            convdim = mid.ndim - 2
+            layer = (
+                torch.nn.Conv1d,
+                torch.nn.Conv2d,
+                torch.nn.Conv3d
+            )[convdim]
+        else:
+            layer = torch.nn.Linear
+        self.lora_up = layer(rank, out_dim, bias=False)
+        self.lora_down = layer(in_dim, rank, bias=False)
+        self.lora_up.weight.data.copy_(mat1)
+        self.lora_down.weight.data.copy_(mat2)
+        if mid is not None:
+            self.lora_mid = layer(mid, rank, bias=False)
+            self.lora_mid.weight.data.copy_(mid)
+        else:
+            self.lora_mid = None
+        self.rank = rank
+        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
+
+    def __call__(self, w):
+        org_dtype = w.dtype
+        if self.lora_mid is None:
+            diff = self.lora_up.weight @ self.lora_down.weight
+        else:
+            diff = tucker_weight_from_conv(
+                self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight
+            )
+        scale = self.alpha / self.rank
+        weight = w + scale * diff.reshape(w.shape)
+        return weight.to(org_dtype)
+
+    def passive_memory_usage(self):
+        return sum(param.numel() * param.element_size() for param in self.parameters())
 
 
 class LoRAAdapter(WeightAdapterBase):
@@ -13,6 +62,21 @@ def __init__(self, loaded_keys, weights):
         self.loaded_keys = loaded_keys
         self.weights = weights
 
+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        in_dim = weight.shape[1:].numel()
+        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
+        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
+        torch.nn.init.constant_(mat2, 0.0)
+        return LoraDiff(
+            (mat1, mat2, alpha, None, None, None)
+        )
+
+    def to_train(self):
+        return LoraDiff(self.weights)
+
     @classmethod
     def load(
         cls,
```

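Putting the new pieces together, LoRAAdapter.create_train yields a trainable LoraDiff whose __call__ adds the scaled low-rank delta onto a frozen base weight. A minimal single-step training sketch (the layer size, optimizer, and loss are illustrative, not the commit's training node):

```python
import torch
from comfy.weight_adapter import LoRAAdapter

base_weight = torch.randn(128, 64)        # frozen base layer weight
lora = LoRAAdapter.create_train(base_weight, rank=4, alpha=4.0)

# Only lora_up / lora_down require grad; alpha is a frozen Parameter.
optimizer = torch.optim.AdamW(
    [p for p in lora.parameters() if p.requires_grad], lr=1e-3
)

x = torch.randn(16, 64)                   # dummy batch
target = torch.randn(16, 128)

effective_weight = lora(base_weight)      # base + (alpha / rank) * up @ down
loss = torch.nn.functional.mse_loss(x @ effective_weight.t(), target)
loss.backward()
optimizer.step()
```
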
0 commit comments