Skip to content

Commit efa5a71

Browse files
Reduce memory usage when applying DoRA (comfyanonymous/ComfyUI#3557)
1 parent 58c9838 commit efa5a71

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

comfy/model_patcher.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from comfy.types import UnetWrapperFunction
1010

1111

12-
def apply_weight_decompose(dora_scale, weight):
12+
def weight_decompose_scale(dora_scale, weight):
1313
weight_norm = (
1414
weight.transpose(0, 1)
1515
.reshape(weight.shape[1], -1)
@@ -18,7 +18,7 @@ def apply_weight_decompose(dora_scale, weight):
1818
.transpose(0, 1)
1919
)
2020

21-
return weight * (dora_scale / weight_norm).type(weight.dtype)
21+
return (dora_scale / weight_norm).type(weight.dtype)
2222

2323
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
2424
to = model_options["transformer_options"].copy()
@@ -365,7 +365,7 @@ def calculate_weight(self, patches, weight, key):
365365
try:
366366
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
367367
if dora_scale is not None:
368-
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
368+
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
369369
except Exception as e:
370370
logging.error("ERROR {} {} {}".format(patch_type, key, e))
371371
elif patch_type == "lokr":
@@ -407,7 +407,7 @@ def calculate_weight(self, patches, weight, key):
407407
try:
408408
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
409409
if dora_scale is not None:
410-
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
410+
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
411411
except Exception as e:
412412
logging.error("ERROR {} {} {}".format(patch_type, key, e))
413413
elif patch_type == "loha":
@@ -439,7 +439,7 @@ def calculate_weight(self, patches, weight, key):
439439
try:
440440
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
441441
if dora_scale is not None:
442-
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
442+
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
443443
except Exception as e:
444444
logging.error("ERROR {} {} {}".format(patch_type, key, e))
445445
elif patch_type == "glora":
@@ -456,7 +456,7 @@ def calculate_weight(self, patches, weight, key):
456456
try:
457457
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
458458
if dora_scale is not None:
459-
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
459+
weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
460460
except Exception as e:
461461
logging.error("ERROR {} {} {}".format(patch_type, key, e))
462462
else:

0 commit comments

Comments (0)