
Commit 05fff7b

Update lora converter (Snowflake-Labs#27)
1 parent 6528a7f commit 05fff7b

File tree: 1 file changed

training/ds_to_hf_converter.py

Lines changed: 45 additions & 26 deletions
@@ -26,13 +26,15 @@ def convert_moe_model(
     ds_dir: str,
     output_path: str,
     node_rank: int = 8,
+    has_lora: bool = True,
 ) -> None:
     ds_dir = os.path.normpath(ds_dir)
     print(ds_dir)
     parent_directory = os.path.dirname(ds_dir) # assuming the ds_dir points to a global_step directory located inside a checkpoint directory.
     print(parent_directory)
     config = AutoConfig.from_pretrained(parent_directory)
-    lora_scaling_factor = config.ds_lora.lora_alpha / math.sqrt(config.ds_lora.lora_r)
+    if has_lora:
+        lora_scaling_factor = config.ds_lora.lora_alpha / config.ds_lora.lora_r
     # No need for lora and quantization params now.
     config.ds_lora = None
     config.ds_quantization = None
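
For reference, the scaling change in the hunk above swaps the previous alpha / sqrt(r) factor (the rank-stabilised LoRA convention) for the standard LoRA factor alpha / r. A minimal sketch of the two factors, using illustrative values rather than anything read from a real config.ds_lora:

import math

# Illustrative values only; the converter reads lora_alpha and lora_r from config.ds_lora.
lora_alpha, lora_r = 16, 64

old_scaling = lora_alpha / math.sqrt(lora_r)  # factor before this commit -> 2.0
new_scaling = lora_alpha / lora_r             # factor after this commit  -> 0.25
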
@@ -62,26 +64,31 @@ def convert_moe_model(
6264
sd_hf["model.norm.weight"] = sd_m["model.norm.weight"].clone().data
6365
sd_hf["lm_head.weight"] = sd_m["lm_head.weight"].clone().data
6466

65-
# Read all the sharded baseweights
66-
sd_of_base_weights = [None] * node_rank
67-
for rank in range(node_rank):
68-
sd_of_base_weights[rank] = torch.load(os.path.join(ds_dir, f"lora_optimized_linear_sharding_rank_{rank}.pt"), map_location="cpu")
69-
70-
# Confirm all shards have the sames keys of base weights.
71-
combined_base_weight = sd_of_base_weights[0].keys()
72-
for i in range(1, node_rank):
73-
assert sd_of_base_weights[i].keys() == combined_base_weight
74-
75-
# Concatena base weights and merge the lora weights in them as well.
76-
for weight in combined_base_weight:
77-
base_weight = torch.cat([sd_of_base_weights[rank][weight].to('cuda') for rank in range(node_rank)], dim=1).to('cpu')
78-
# now you have a weight like model.layers.5.self_attn.o_proj.weight and you want to create names like
79-
# model.layers.5.self_attn.o_proj.lora_weight_2.weight, and model.layers.5.self_attn.o_proj.lora_weight_1.weight
80-
prefix, suffix = weight.rsplit(".", 1)
81-
lora_weight1 = sd_m[f"{prefix}.lora_weight_1.{suffix}"]
82-
lora_weight2 = sd_m[f"{prefix}.lora_weight_2.{suffix}"]
83-
sd_hf[weight] = merge_lora_weights(base_weight, lora_weight1, lora_weight2, lora_scaling_factor)
84-
67+
if has_lora:
68+
# Read all the sharded baseweights
69+
sd_of_base_weights = [None] * node_rank
70+
for rank in range(node_rank):
71+
sd_of_base_weights[rank] = torch.load(os.path.join(ds_dir, f"lora_optimized_linear_sharding_rank_{rank}.pt"), map_location="cpu")
72+
73+
# Confirm all shards have the sames keys of base weights.
74+
combined_base_weight = sd_of_base_weights[0].keys()
75+
for i in range(1, node_rank):
76+
assert sd_of_base_weights[i].keys() == combined_base_weight
77+
78+
# Concatena base weights and merge the lora weights in them as well.
79+
for weight in combined_base_weight:
80+
base_weight = torch.cat([sd_of_base_weights[rank][weight].to('cuda') for rank in range(node_rank)], dim=1).to('cpu')
81+
# now you have a weight like model.layers.5.self_attn.o_proj.weight and you want to create names like
82+
# model.layers.5.self_attn.o_proj.lora_weight_2.weight, and model.layers.5.self_attn.o_proj.lora_weight_1.weight
83+
prefix, suffix = weight.rsplit(".", 1)
84+
lora_weight1 = sd_m[f"{prefix}.lora_weight_1.{suffix}"]
85+
lora_weight2 = sd_m[f"{prefix}.lora_weight_2.{suffix}"]
86+
sd_hf[weight] = merge_lora_weights(base_weight, lora_weight1, lora_weight2, lora_scaling_factor)
87+
else:
88+
for k in sd_m:
89+
if "deepspeed" not in k:
90+
sd_hf[k] = sd_m[k].clone().data
91+
8592
# Now go over each layer and add weights.
8693
for layer_i in range(n_layers):
8794
print(f"Convert Layer {layer_i + 1} / {n_layers}")
@@ -114,11 +121,15 @@ def convert_moe_model(
             lora_weight_param1 = f"{prefix}.lora_weight_1.{suffix}"
             lora_weight_param2 = f"{prefix}.lora_weight_2.{suffix}"
             new_name = base_weight_param.replace(f"block_sparse_moe.mlp.deepspeed_moe.experts.deepspeed_experts",
-                                                 f"block_sparse_moe.experts")
-            sd_hf[new_name] = merge_lora_weights(sd_expert[base_weight_param],
-                                                 sd_expert[lora_weight_param1],
-                                                 sd_expert[lora_weight_param2],
-                                                 lora_scaling_factor)
+                                                 f"block_sparse_moe.experts")
+            if has_lora:
+                sd_hf[new_name] = merge_lora_weights(sd_expert[base_weight_param],
+                                                     sd_expert[lora_weight_param1],
+                                                     sd_expert[lora_weight_param2],
+                                                     lora_scaling_factor)
+            else:
+                sd_hf[new_name] = sd_expert[base_weight_param]
+


     with torch.device("meta"):
@@ -163,11 +174,19 @@ def main():
         required=True,
         help="Output path for the huggingface converted model.",
     )
+    parser.add_argument(
+        "--no-lora-weights",
+        required=False,
+        action="store_true",
+        help="Convert a checkpoint that was trained without LoRA weights (skip LoRA merging).",
+    )

     args = parser.parse_args()
     convert_moe_model(
         args.ds_model_path,
         args.output_path,
+        node_rank=8,
+        has_lora=not args.no_lora_weights,
     )

 if __name__ == "__main__":
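
Since --no-lora-weights is a store_true argument, argparse exposes it as args.no_lora_weights with a default of False, so has_lora stays True (the merge path) unless the flag is passed. A minimal sketch of that behaviour in isolation (flag name taken from the diff; the help string here is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--no-lora-weights", required=False, action="store_true",
                    help="Skip LoRA merging and copy base weights as-is.")

args = parser.parse_args([])                     # flag absent: no_lora_weights == False
assert not args.no_lora_weights                  # has_lora = not ... -> True (merge path)

args = parser.parse_args(["--no-lora-weights"])  # flag present: no_lora_weights == True
assert args.no_lora_weights                      # has_lora = not ... -> False (plain copy)
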

0 commit comments