@@ -9,7 +9,7 @@
 from comfy.types import UnetWrapperFunction
 
 
-def apply_weight_decompose(dora_scale, weight):
+def weight_decompose_scale(dora_scale, weight):
     weight_norm = (
         weight.transpose(0, 1)
         .reshape(weight.shape[1], -1)
@@ -18,7 +18,7 @@ def apply_weight_decompose(dora_scale, weight):
         .transpose(0, 1)
     )
 
-    return weight * (dora_scale / weight_norm).type(weight.dtype)
+    return (dora_scale / weight_norm).type(weight.dtype)
 
 def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
     to = model_options["transformer_options"].copy()
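The refactor changes the helper's contract: it now returns only the DoRA rescale factor, and callers multiply it into the weight themselves. A minimal sketch of what that factor does for a plain 2-D weight, assuming the lines collapsed between these two hunks compute a column-wise L2 norm and that dora_scale holds one trained magnitude per input column (the 2-D helper name and shapes below are illustrative, not from the diff):

    import torch

    def weight_decompose_scale_2d(dora_scale, weight):
        # Current per-column L2 norm of the weight: [out, in] -> [1, in].
        weight_norm = weight.norm(dim=0, keepdim=True)
        # Factor that renormalizes each column to its trained DoRA magnitude.
        return (dora_scale / weight_norm).type(weight.dtype)

    weight = torch.randn(4, 8)            # hypothetical [out, in] weight
    dora_scale = torch.full((1, 8), 2.0)  # hypothetical target magnitudes
    weight *= weight_decompose_scale_2d(dora_scale, weight)
    print(weight.norm(dim=0))             # every column norm is now 2.0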
@@ -365,7 +365,7 @@ def calculate_weight(self, patches, weight, key):
                 try:
                     weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
                     if dora_scale is not None:
-                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
+                        weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "lokr":
@@ -407,7 +407,7 @@ def calculate_weight(self, patches, weight, key):
                 try:
                     weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
                     if dora_scale is not None:
-                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
+                        weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "loha":
@@ -439,7 +439,7 @@ def calculate_weight(self, patches, weight, key):
                 try:
                     weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
                     if dora_scale is not None:
-                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
+                        weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "glora":
@@ -456,7 +456,7 @@ def calculate_weight(self, patches, weight, key):
                 try:
                     weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
                     if dora_scale is not None:
-                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
+                        weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             else:
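All four call sites (lora, lokr, loha, glora) change the same way: instead of reassigning weight to the helper's already-scaled result, they multiply the returned factor into weight in place, which avoids allocating a second full-size copy of the weight. A small sketch of the equivalence, with a stand-in tensor for the helper's result:

    import torch

    weight = torch.randn(4, 8)
    scale = torch.rand(1, 8) + 0.5   # stand-in for weight_decompose_scale(...)

    expected = weight * scale        # what the old helper returned
    weight *= scale                  # what the new call sites do
    assert torch.allclose(weight, expected)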