Skip to content

Commit 2ae914e

Browse files
RevL147luyuxiang
andauthored
fix ordered_save func (#10896)
Co-authored-by: luyuxiang <luyuxiang@YuxiangdeMacBook-Pro.local>
1 parent 5b9e0b3 commit 2ae914e

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2657,7 +2657,7 @@ def _filter_moe_no_sync_optimizer_params(self):
26572657
filter_optimzier_state_dict[op_k] = op_v
26582658
return filter_optimzier_state_dict
26592659

2660-
def _ordered_save(self, state_dict, save_path):
2660+
def _ordered_save(self, state_dict, save_path, signal_path=None):
26612661
group_size = self.args.ordered_save_group_size
26622662
hcg = fleet.get_hybrid_communicate_group()
26632663
if hcg.get_sharding_parallel_world_size() > 1 or hcg.get_model_parallel_world_size() <= 1:
@@ -2677,6 +2677,10 @@ def _ordered_save(self, state_dict, save_path):
26772677
paddle.save(state_dict, save_path)
26782678
dist.barrier(mp_group)
26792679

2680+
if signal_path is not None:
2681+
with open(signal_path, mode="w+") as f:
2682+
f.write("1")
2683+
26802684
def _save_checkpoint(self, model, metrics=None):
26812685
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
26822686
if self.args.enable_zero_cost_checkpoint:

0 commit comments

Comments
 (0)