2 files changed, +10 −1

@@ -545,6 +545,7 @@ def main():
     config.use_recompute = training_args.recompute
     config.tensor_parallel_degree = training_args.tensor_parallel_degree
     config.tensor_parallel_rank = training_args.tensor_parallel_rank
+    config.to_static = training_args.to_static

     if training_args.strategy.pipeline.enable and config.virtual_pp_degree > 1:
         pipeline = training_args.strategy.pipeline
@@ -554,6 +555,14 @@ def main():
     config.hidden_dropout_prob = model_args.hidden_dropout_prob
     config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob
     print("Final pre-training config:", config)
+    if (
+        "replace_with_parallel_cross_entropy" in training_args.tensor_parallel_config
+        and config.tensor_parallel_degree > 1
+        and config.to_static is False
+    ):
+        from llm.utils.replace_ops import replace_cross_entropy
+
+        replace_cross_entropy()

     # Set the dtype for loading model
     dtype = "float32"

@@ -106,7 +106,7 @@ def parallel_cross_entropy(
         "1. soft_label=False is set for parallel computation (current value: {})\n"
         "2. Input tensor is properly sharded (current sharding status: {})\n".format(
             soft_label,
-            input_placement,
+            input.placements,
         )
     )
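
The fix replaces the undefined `input_placement` with the distributed tensor's actual `placements` attribute, so the error message prints the real sharding status. A minimal illustration of what that attribute reports under Paddle's auto-parallel API; the mesh layout and shapes are arbitrary example values:

```python
# Illustration of Tensor.placements on a Paddle auto-parallel dist tensor.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["mp"])
logits = paddle.randn([4, 1024])
# Shard the class/vocab (last) axis across the "mp" mesh dimension.
dist_logits = dist.shard_tensor(logits, mesh, [dist.Shard(1)])
print(dist_logits.placements)  # e.g. [Shard(dim=1)]
```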