@@ -52,6 +52,7 @@ def __init__(
         adm_in_channels=None,
         transformer_depth_middle=None,
         transformer_depth_output=None,
+        attn_precision=None,
         device=None,
         operations=comfy.ops.disable_weight_init,
         **kwargs,
@@ -202,7 +203,7 @@ def __init__(
                     SpatialTransformer(
                         ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
                         disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                        use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                        use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
                     )
                 )
         self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -262,7 +263,7 @@ def __init__(
                 mid_block += [SpatialTransformer(  # always uses a self-attn
                                 ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                                 disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
+                                use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
                             ),
                 ResBlock(
                     ch,
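The hunks above only thread the new attn_precision keyword through UNetModel.__init__ into each SpatialTransformer; it is presumably consumed further down in the attention layers. As a rough illustration of the pattern, here is a minimal sketch of how such an optional precision override is typically applied inside an attention function. The function name and signature are assumptions for illustration, not the actual ComfyUI implementation touched by this diff.

# Hypothetical sketch (not ComfyUI's actual attention code): an optional
# attn_precision argument that forces the similarity matmul and softmax to
# run in a chosen dtype (e.g. torch.float32), while None keeps the input dtype.
import torch

def attention(q, k, v, attn_precision=None):
    # Optionally upcast queries/keys before the similarity matmul to avoid
    # fp16 overflow/precision issues.
    if attn_precision is not None:
        q = q.to(attn_precision)
        k = k.to(attn_precision)
    scale = q.shape[-1] ** -0.5
    sim = q @ k.transpose(-2, -1) * scale
    attn = sim.softmax(dim=-1)
    # Cast back to the value dtype before the final matmul.
    return attn.to(v.dtype) @ v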