Commit bd62138

feat: adapt to more weight formats
1 parent 3a25179 commit bd62138

File tree

2 files changed: 28 additions & 11 deletions

models/.gitignore

Lines changed: 1 addition & 0 deletions
```diff
@@ -1,4 +1,5 @@
 *.bin
 *.ckpt
 *.safetensor
+*.safetensors
 *.log
```
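The added pattern matters because gitignore globs are literal about the extension: the existing `*.safetensor` rule does not match files ending in the plural `.safetensors`, which is the usual extension for safetensors checkpoints. A quick way to convince yourself, using Python's `fnmatch` as a rough stand-in for gitignore-style matching (the filename is just an example):

```python
import fnmatch

# example checkpoint filename; the glob must match the full extension
name = "v2-1_768-ema-pruned.safetensors"
print(fnmatch.fnmatch(name, "*.safetensor"))   # False - old pattern misses it
print(fnmatch.fnmatch(name, "*.safetensors"))  # True  - new pattern catches it
```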

models/convert.py

Lines changed: 27 additions & 11 deletions
```diff
@@ -179,9 +179,9 @@ def preprocess(state_dict):
     state_dict["alphas_cumprod"] = alphas_cumprod
 
     new_state_dict = {}
-    for name in state_dict.keys():
+    for name, w in state_dict.items():
         # ignore unused tensors
-        if not isinstance(state_dict[name], torch.Tensor):
+        if not isinstance(w, torch.Tensor):
             continue
         skip = False
         for unused_tensor in unused_tensors:
```
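Switching the loop from `state_dict.keys()` to `state_dict.items()` binds each tensor to a local `w`, so later branches can rebind it (for example, the BF16 cast in the next hunk) and the repeated `state_dict[name]` lookups can be dropped. A condensed sketch of the resulting loop skeleton on a toy state dict (illustrative names and shapes, not the full conversion):

```python
import torch

# toy stand-in for a real checkpoint: one bf16 tensor plus some non-tensor metadata
state_dict = {"model.some.weight": torch.randn(2, 2, dtype=torch.bfloat16),
              "global_step": 10000}

new_state_dict = {}
for name, w in state_dict.items():       # one binding per entry instead of repeated lookups
    if not isinstance(w, torch.Tensor):  # ignore non-tensor entries
        continue
    if w.dtype == torch.bfloat16:        # later steps may rebind w...
        w = w.to(torch.float16)
    new_state_dict[name] = w             # ...and the rebound value is what gets stored
print(new_state_dict["model.some.weight"].dtype)  # torch.float16
```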
```diff
@@ -190,13 +190,25 @@ def preprocess(state_dict):
                 break
         if skip:
             continue
-
+
+        # convert BF16 to FP16
+        if w.dtype == torch.bfloat16:
+            w = w.to(torch.float16)
+
         # convert open_clip to hf CLIPTextModel (for SD2.x)
         open_clip_to_hf_clip_model = {
             "cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias",
             "cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight",
             "cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
             "cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
+            "first_stage_model.decoder.mid.attn_1.to_k.bias": "first_stage_model.decoder.mid.attn_1.k.bias",
+            "first_stage_model.decoder.mid.attn_1.to_k.weight": "first_stage_model.decoder.mid.attn_1.k.weight",
+            "first_stage_model.decoder.mid.attn_1.to_out.0.bias": "first_stage_model.decoder.mid.attn_1.proj_out.bias",
+            "first_stage_model.decoder.mid.attn_1.to_out.0.weight": "first_stage_model.decoder.mid.attn_1.proj_out.weight",
+            "first_stage_model.decoder.mid.attn_1.to_q.bias": "first_stage_model.decoder.mid.attn_1.q.bias",
+            "first_stage_model.decoder.mid.attn_1.to_q.weight": "first_stage_model.decoder.mid.attn_1.q.weight",
+            "first_stage_model.decoder.mid.attn_1.to_v.bias": "first_stage_model.decoder.mid.attn_1.v.bias",
+            "first_stage_model.decoder.mid.attn_1.to_v.weight": "first_stage_model.decoder.mid.attn_1.v.weight",
         }
         open_clip_to_hk_clip_resblock = {
             "attn.out_proj.bias": "self_attn.out_proj.bias",
```
```diff
@@ -214,22 +226,19 @@ def preprocess(state_dict):
         hf_clip_resblock_prefix = "cond_stage_model.transformer.text_model.encoder.layers."
         if name in open_clip_to_hf_clip_model:
             new_name = open_clip_to_hf_clip_model[name]
-            new_state_dict[new_name] = state_dict[name]
             print(f"preprocess {name} => {new_name}")
-            continue
+            name = new_name
         if name.startswith(open_clip_resblock_prefix):
             remain = name[len(open_clip_resblock_prefix):]
             idx = remain.split(".")[0]
             suffix = remain[len(idx)+1:]
             if suffix == "attn.in_proj_weight":
-                w = state_dict[name]
                 w_q, w_k, w_v = w.chunk(3)
                 for new_suffix, new_w in zip(["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"], [w_q, w_k, w_v]):
                     new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
                     new_state_dict[new_name] = new_w
                     print(f"preprocess {name}{w.size()} => {new_name}{new_w.size()}")
             elif suffix == "attn.in_proj_bias":
-                w = state_dict[name]
                 w_q, w_k, w_v = w.chunk(3)
                 for new_suffix, new_w in zip(["self_attn.q_proj.bias", "self_attn.k_proj.bias", "self_attn.v_proj.bias"], [w_q, w_k, w_v]):
                     new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
```
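The key behavioral change here is replacing the early `continue` with `name = new_name`: a renamed tensor is no longer written out immediately but keeps flowing through the rest of the loop, which is what lets the freshly renamed VAE `attn_1` weights reach the linear-to-conv2d reshape added in the next hunk. The surrounding context also shows the existing `chunk(3)` split of open_clip's fused `attn.in_proj_weight` into separate q/k/v projections; a toy-sized sketch of that split (SD2.x's text encoder actually uses width 1024):

```python
import torch

embed_dim = 8                                            # toy width for the sketch
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)   # fused q/k/v projection

# chunk(3) splits along dim 0 into three [embed_dim, embed_dim] blocks, matching the
# separate q_proj / k_proj / v_proj weights expected by the HF CLIPTextModel layout
w_q, w_k, w_v = in_proj_weight.chunk(3)
print(w_q.shape, w_k.shape, w_v.shape)  # each torch.Size([8, 8])
```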
```diff
@@ -238,20 +247,27 @@ def preprocess(state_dict):
             else:
                 new_suffix = open_clip_to_hk_clip_resblock[suffix]
                 new_name = hf_clip_resblock_prefix + idx + "." + new_suffix
-                new_state_dict[new_name] = state_dict[name]
+                new_state_dict[new_name] = w
                 print(f"preprocess {name} => {new_name}")
             continue
 
         # convert unet transformer linear to conv2d 1x1
         if name.startswith("model.diffusion_model.") and (name.endswith("proj_in.weight") or name.endswith("proj_out.weight")):
-            w = state_dict[name]
-            if len(state_dict[name].shape) == 2:
+            if len(w.shape) == 2:
+                new_w = w.unsqueeze(2).unsqueeze(3)
+                new_state_dict[name] = new_w
+                print(f"preprocess {name} {w.size()} => {name} {new_w.size()}")
+            continue
+
+        # convert vae attn block linear to conv2d 1x1
+        if name.startswith("first_stage_model.") and "attn_1" in name:
+            if len(w.shape) == 2:
                 new_w = w.unsqueeze(2).unsqueeze(3)
                 new_state_dict[name] = new_w
                 print(f"preprocess {name} {w.size()} => {name} {new_w.size()}")
             continue
 
-        new_state_dict[name] = state_dict[name]
+        new_state_dict[name] = w
     return new_state_dict
 
 def convert(model_path, out_type = None, out_file=None):
```
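The new `first_stage_model` / `attn_1` branch reuses the same trick already applied to the UNet `proj_in`/`proj_out` weights: a 2-D linear weight of shape `[out, in]` is unsqueezed to `[out, in, 1, 1]`, which is exactly the weight layout of a 1x1 Conv2d and computes the same per-position projection. A small numerical check of that equivalence (toy shapes, plain PyTorch):

```python
import torch
import torch.nn.functional as F

w = torch.randn(16, 8)        # linear weight: [out_features, in_features]
x = torch.randn(1, 8, 4, 4)   # feature map:   [N, C, H, W]

# apply the linear weight at every spatial position...
y_linear = torch.einsum("oc,nchw->nohw", w, x)
# ...versus a 1x1 convolution with the unsqueezed weight [out, in, 1, 1]
y_conv = F.conv2d(x, w.unsqueeze(2).unsqueeze(3))

print(torch.allclose(y_linear, y_conv, atol=1e-6))  # True (up to float rounding)
```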
