
Commit 53545aa

[Auto-Parallel] optimize the performance of GPT-3 (#10780)
1 parent: 24a8887

2 files changed: +10 lines, −1 line


llm/auto_parallel/gpt-3/run_pretrain_auto.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -545,6 +545,7 @@ def main():
     config.use_recompute = training_args.recompute
     config.tensor_parallel_degree = training_args.tensor_parallel_degree
     config.tensor_parallel_rank = training_args.tensor_parallel_rank
+    config.to_static = training_args.to_static

     if training_args.strategy.pipeline.enable and config.virtual_pp_degree > 1:
         pipeline = training_args.strategy.pipeline
@@ -554,6 +555,14 @@ def main():
     config.hidden_dropout_prob = model_args.hidden_dropout_prob
     config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob
     print("Final pre-training config:", config)
+    if (
+        "replace_with_parallel_cross_entropy" in training_args.tensor_parallel_config
+        and config.tensor_parallel_degree > 1
+        and config.to_static is False
+    ):
+        from llm.utils.replace_ops import replace_cross_entropy
+
+        replace_cross_entropy()

     # Set the dtype for loading model
     dtype = "float32"
```

llm/utils/replace_ops.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -106,7 +106,7 @@ def parallel_cross_entropy(
         "1. soft_label=False is set for parallel computation (current value: {}) \n"
         "2. Input tensor is properly sharded (current sharding status: {}) \n".format(
             soft_label,
-            input_placement,
+            input.placements,
         )
     )
```
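The one-line fix here repairs the diagnostic itself: `input_placement` was never defined in this scope, so whenever the guard fired, the user saw a `NameError` instead of the intended message; `input.placements` reads the sharding status directly off the distributed tensor. A hedged sketch of how such a guard behaves (the surrounding structure is assumed from the diff context):

```python
# Structure assumed from the diff context; `input.placements` exists on
# paddle auto-parallel dist tensors (e.g. from paddle.distributed.shard_tensor).
def _check_parallel_ce_inputs(input, soft_label):
    placements = getattr(input, "placements", None)
    if soft_label or not placements:
        raise AssertionError(
            "parallel_cross_entropy requires:\n"
            "1. soft_label=False is set for parallel computation (current value: {}) \n"
            "2. Input tensor is properly sharded (current sharding status: {}) \n".format(
                soft_label,
                placements,  # was `input_placement` (undefined) before the fix
            )
        )
```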
