@@ -179,9 +179,9 @@ def preprocess(state_dict):
179
179
state_dict ["alphas_cumprod" ] = alphas_cumprod
180
180
181
181
new_state_dict = {}
182
- for name in state_dict .keys ():
182
+ for name , w in state_dict .items ():
183
183
# ignore unused tensors
184
- if not isinstance (state_dict [ name ] , torch .Tensor ):
184
+ if not isinstance (w , torch .Tensor ):
185
185
continue
186
186
skip = False
187
187
for unused_tensor in unused_tensors :
@@ -190,13 +190,25 @@ def preprocess(state_dict):
190
190
break
191
191
if skip :
192
192
continue
193
-
193
+
194
+ # # convert BF16 to FP16
195
+ if w .dtype == torch .bfloat16 :
196
+ w = w .to (torch .float16 )
197
+
194
198
# convert open_clip to hf CLIPTextModel (for SD2.x)
195
199
open_clip_to_hf_clip_model = {
196
200
"cond_stage_model.model.ln_final.bias" : "cond_stage_model.transformer.text_model.final_layer_norm.bias" ,
197
201
"cond_stage_model.model.ln_final.weight" : "cond_stage_model.transformer.text_model.final_layer_norm.weight" ,
198
202
"cond_stage_model.model.positional_embedding" : "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight" ,
199
203
"cond_stage_model.model.token_embedding.weight" : "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ,
204
+ "first_stage_model.decoder.mid.attn_1.to_k.bias" : "first_stage_model.decoder.mid.attn_1.k.bias" ,
205
+ "first_stage_model.decoder.mid.attn_1.to_k.weight" : "first_stage_model.decoder.mid.attn_1.k.weight" ,
206
+ "first_stage_model.decoder.mid.attn_1.to_out.0.bias" : "first_stage_model.decoder.mid.attn_1.proj_out.bias" ,
207
+ "first_stage_model.decoder.mid.attn_1.to_out.0.weight" : "first_stage_model.decoder.mid.attn_1.proj_out.weight" ,
208
+ "first_stage_model.decoder.mid.attn_1.to_q.bias" : "first_stage_model.decoder.mid.attn_1.q.bias" ,
209
+ "first_stage_model.decoder.mid.attn_1.to_q.weight" : "first_stage_model.decoder.mid.attn_1.q.weight" ,
210
+ "first_stage_model.decoder.mid.attn_1.to_v.bias" : "first_stage_model.decoder.mid.attn_1.v.bias" ,
211
+ "first_stage_model.decoder.mid.attn_1.to_v.weight" : "first_stage_model.decoder.mid.attn_1.v.weight" ,
200
212
}
201
213
open_clip_to_hk_clip_resblock = {
202
214
"attn.out_proj.bias" : "self_attn.out_proj.bias" ,
@@ -214,22 +226,19 @@ def preprocess(state_dict):
214
226
hf_clip_resblock_prefix = "cond_stage_model.transformer.text_model.encoder.layers."
215
227
if name in open_clip_to_hf_clip_model :
216
228
new_name = open_clip_to_hf_clip_model [name ]
217
- new_state_dict [new_name ] = state_dict [name ]
218
229
print (f"preprocess { name } => { new_name } " )
219
- continue
230
+ name = new_name
220
231
if name .startswith (open_clip_resblock_prefix ):
221
232
remain = name [len (open_clip_resblock_prefix ):]
222
233
idx = remain .split ("." )[0 ]
223
234
suffix = remain [len (idx )+ 1 :]
224
235
if suffix == "attn.in_proj_weight" :
225
- w = state_dict [name ]
226
236
w_q , w_k , w_v = w .chunk (3 )
227
237
for new_suffix , new_w in zip (["self_attn.q_proj.weight" , "self_attn.k_proj.weight" , "self_attn.v_proj.weight" ], [w_q , w_k , w_v ]):
228
238
new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
229
239
new_state_dict [new_name ] = new_w
230
240
print (f"preprocess { name } { w .size ()} => { new_name } { new_w .size ()} " )
231
241
elif suffix == "attn.in_proj_bias" :
232
- w = state_dict [name ]
233
242
w_q , w_k , w_v = w .chunk (3 )
234
243
for new_suffix , new_w in zip (["self_attn.q_proj.bias" , "self_attn.k_proj.bias" , "self_attn.v_proj.bias" ], [w_q , w_k , w_v ]):
235
244
new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
@@ -238,20 +247,27 @@ def preprocess(state_dict):
238
247
else :
239
248
new_suffix = open_clip_to_hk_clip_resblock [suffix ]
240
249
new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
241
- new_state_dict [new_name ] = state_dict [ name ]
250
+ new_state_dict [new_name ] = w
242
251
print (f"preprocess { name } => { new_name } " )
243
252
continue
244
253
245
254
# convert unet transformer linear to conv2d 1x1
246
255
if name .startswith ("model.diffusion_model." ) and (name .endswith ("proj_in.weight" ) or name .endswith ("proj_out.weight" )):
247
- w = state_dict [name ]
248
- if len (state_dict [name ].shape ) == 2 :
256
+ if len (w .shape ) == 2 :
257
+ new_w = w .unsqueeze (2 ).unsqueeze (3 )
258
+ new_state_dict [name ] = new_w
259
+ print (f"preprocess { name } { w .size ()} => { name } { new_w .size ()} " )
260
+ continue
261
+
262
+ # convert vae attn block linear to conv2d 1x1
263
+ if name .startswith ("first_stage_model." ) and "attn_1" in name :
264
+ if len (w .shape ) == 2 :
249
265
new_w = w .unsqueeze (2 ).unsqueeze (3 )
250
266
new_state_dict [name ] = new_w
251
267
print (f"preprocess { name } { w .size ()} => { name } { new_w .size ()} " )
252
268
continue
253
269
254
- new_state_dict [name ] = state_dict [ name ]
270
+ new_state_dict [name ] = w
255
271
return new_state_dict
256
272
257
273
def convert (model_path , out_type = None , out_file = None ):
0 commit comments