
Commit 45c339d

did it in preprocess_lora
1 parent 9e14d55

File tree

1 file changed (+7, -7 lines)


models/convert.py

Lines changed: 7 additions & 7 deletions
@@ -192,7 +192,7 @@ def preprocess(state_dict):
         if skip:
             continue
 
-        # # convert BF16 to FP16
+        # convert BF16 to FP16
         if w.dtype == torch.bfloat16:
             w = w.to(torch.float16)
 
@@ -342,6 +342,11 @@ def preprocess_lora(state_dict):
     for name, w in state_dict.items():
         if not isinstance(w, torch.Tensor):
             continue
+
+        # convert BF16 to FP16
+        if w.dtype == torch.bfloat16:
+            w = w.to(torch.float16)
+
         name_without_network_parts, network_part = name.split(".", 1)
         new_name_without_network_parts = convert_diffusers_name_to_compvis(name_without_network_parts)
         if new_name_without_network_parts == None:
@@ -422,12 +427,7 @@ def convert(model_path, out_type = None, out_file=None, lora=False):
         if name in unused_tensors:
             continue
 
-        data_tmp = state_dict[name]
-        if data_tmp.dtype == torch.bfloat16:
-            # numpy does not support bf16, so we conservatively upcast to f32
-            data = data_tmp.float().numpy()
-        else:
-            data = data_tmp.numpy()
+        data = state_dict[name].numpy()
 
         n_dims = len(data.shape)
         shape = data.shape
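
For context on why the deletion in convert() is safe: NumPy has no bfloat16 dtype, so calling .numpy() on a BF16 tensor raises a TypeError, which is what the removed "numpy does not support bf16" fallback worked around. With preprocess_lora now converting BF16 weights to FP16, every tensor reaching the .numpy() call already has a NumPy-compatible dtype. A minimal sketch, not part of the commit, that reproduces the failure and the fix:

import torch

# NumPy cannot represent bfloat16, so a direct conversion raises TypeError:
t = torch.zeros(4, dtype=torch.bfloat16)
try:
    t.numpy()
except TypeError as e:
    print("bf16 -> numpy failed:", e)

# Converting to float16 first, as preprocess_lora now does, succeeds
# because NumPy supports float16 natively:
arr = t.to(torch.float16).numpy()
print(arr.dtype)  # float16

One trade-off: BF16 has a wider exponent range than FP16, so magnitudes above FP16's maximum (about 65504) would overflow to inf here, whereas the old upcast to F32 preserved them; LoRA weights normally sit far below that range.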
