
Commit bfba64b

[Auto Parallel] add llama with auto pp (#10751)
* llama_with_auto_pp
* support change n_microbatches in json
* fix init model
* keep layers which are not in curr rank
* formal llama_with_auto_pp
* formal llama_with_auto_pp
* style
* formal
* fix
* fix
* fix
* fix
* add condition for hybrid pp
* fix
* add test case
* add test case
* add test case
* add llama type pp_schedule
* fix
1 parent 1938c9e commit bfba64b

File tree: 4 files changed, +104 −6 lines

paddlenlp/trainer/auto_trainer.py

Lines changed: 2 additions & 1 deletion
@@ -29,7 +29,6 @@
 
 from paddlenlp.trainer import Trainer
 
-from ..transformers import get_pp_schedule
 from ..transformers.model_utils import clean_model_class_name, unwrap_model
 from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler
 from ..utils.env import (
@@ -49,6 +48,7 @@
     _exec_mode_guard,
     check_auto_parallel_pipeline_support,
     get_last_checkpoint,
+    get_pp_schedule,
     has_length,
     speed_metrics,
 )
@@ -104,6 +104,7 @@ def loss_func(loss, outputs):
         if self.args.pipeline_parallel_degree > 1 and check_auto_parallel_pipeline_support(self.model_type):
             self.pp_schedule = get_pp_schedule(
                 model,
+                self.model_type,
                 self.args.n_microbatches,
                 self.criterion,
                 self.args.pipeline_schedule_mode,

paddlenlp/trainer/trainer_utils.py

Lines changed: 7 additions & 0 deletions
@@ -43,6 +43,7 @@
 from paddlenlp.ops import Topology
 
 from ..trainer.argparser import strtobool
+from ..transformers import get_llama_pp_schedule
 from ..transformers.tokenizer_utils_base import BatchEncoding
 from ..utils.env import PREFIX_CHECKPOINT_DIR, _re_checkpoint  # noqa for compatibility
 from ..utils.fault_tolerance import PDC_DOWNLOAD_ERROR
@@ -1258,3 +1259,9 @@ def download_recovery_ckpt_from_pdc(recovery_checkpoint_path, timeout):
 def check_auto_parallel_pipeline_support(model_type=None):
     support_types = ["llama_pp"]
     return model_type in support_types
+
+
+def get_pp_schedule(model, model_type, n_microbatches, loss_fn, mode, pp_degree, group):
+    assert check_auto_parallel_pipeline_support(model_type)
+    if model_type == "llama_pp":
+        return get_llama_pp_schedule(model, n_microbatches, loss_fn, mode, pp_degree, group)
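The new get_pp_schedule helper is a thin dispatcher keyed on model_type. Below is a minimal, self-contained sketch of the pattern; the get_llama_pp_schedule body is stubbed here for illustration only (in PaddleNLP it is imported from paddlenlp.transformers, as the added import shows).

# Self-contained sketch of the model-type dispatch added in trainer_utils.py.
SUPPORT_TYPES = ["llama_pp"]  # mirrors support_types in check_auto_parallel_pipeline_support


def check_auto_parallel_pipeline_support(model_type=None):
    return model_type in SUPPORT_TYPES


def get_llama_pp_schedule(model, n_microbatches, loss_fn, mode, pp_degree, group):
    # Placeholder: the real builder splits the model into pipeline stages
    # and returns a schedule object (VPP / 1F1B / FThenB).
    return ("llama_pp schedule", mode, n_microbatches, pp_degree)


def get_pp_schedule(model, model_type, n_microbatches, loss_fn, mode, pp_degree, group):
    # The trainer now passes model_type, so further model families can be
    # supported by extending SUPPORT_TYPES and adding another branch here.
    assert check_auto_parallel_pipeline_support(model_type)
    if model_type == "llama_pp":
        return get_llama_pp_schedule(model, n_microbatches, loss_fn, mode, pp_degree, group)


print(get_pp_schedule(None, "llama_pp", 4, None, "FThenB", 2, None))
# -> ('llama_pp schedule', 'FThenB', 4, 2)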

paddlenlp/transformers/llama/modeling_auto_pp.py

Lines changed: 20 additions & 5 deletions
@@ -59,7 +59,7 @@ def swiglu(x, y=None):
 flash_attention = None
 
 __all__ = [
-    "get_pp_schedule",
+    "get_llama_pp_schedule",
     "LlamaForCausalLM3DAutoPP",
 ]
 
@@ -146,10 +146,12 @@ def return_args(hidden_states, attention_mask=None, position_ids=None, alibi=Non
 
 
 class LlamaChunk(nn.Layer):
-    def __init__(self, layers=None, is_first=False):
+    def __init__(self, layers=None, is_first=False, is_last=False):
         super(LlamaChunk, self).__init__()
+        assert not (is_first and is_last)
         self.layers = layers
         self.is_first = is_first
+        self.is_last = is_last
 
     def forward(self, *args, **kwargs):
         if self.is_first:
@@ -161,6 +163,13 @@
             for idx, (decoder_layer) in enumerate(self.layers):
                 outputs = decoder_layer(outputs)
             return outputs
+        elif self.is_last:
+            outputs = args
+            # decoder layers
+            for idx, (decoder_layer) in enumerate(self.layers):
+                outputs = decoder_layer(outputs)
+            if isinstance(outputs, tuple):
+                outputs = outputs[0]
         else:
             outputs = args
             # decoder layers
@@ -182,9 +191,15 @@ def manual_model_split(model, stage_idx, group, mode, pp_degree):
     def _build_stage(model, stage_idx, group):
         new_model = None
         if stage_idx == 0:  # special input handling for the first model chunk
-            new_model = LlamaChunk(layer_lists[:chunk_size], is_first=True)
+            new_model = LlamaChunk(layer_lists[:chunk_size], is_first=True, is_last=False)
+        elif stage_idx == chunk_num - 1:  # special output handling for the last model chunk
+            new_model = LlamaChunk(
+                layer_lists[stage_idx * chunk_size : (stage_idx + 1) * chunk_size], is_first=False, is_last=True
+            )
         else:
-            new_model = LlamaChunk(layer_lists[stage_idx * chunk_size : (stage_idx + 1) * chunk_size], is_first=False)
+            new_model = LlamaChunk(
+                layer_lists[stage_idx * chunk_size : (stage_idx + 1) * chunk_size], is_first=False, is_last=False
+            )
         stage = PipelineStage(new_model, stage_idx, chunk_num, group=group)
         return stage
 
@@ -195,7 +210,7 @@ def _build_stage(model, stage_idx, group):
     return stages
 
 
-def get_pp_schedule(model, n_microbatches, loss_fn, mode, pp_degree, group):
+def get_llama_pp_schedule(model, n_microbatches, loss_fn, mode, pp_degree, group):
     assert mode in ["VPP", "1F1B", "FThenB"]
     stages = manual_model_split(model, group.rank, group, mode, pp_degree)
     if mode == "VPP":

scripts/distribute/ci_case_auto.sh

Lines changed: 75 additions & 0 deletions
@@ -102,6 +102,7 @@ function llama_case_list_auto() {
     # llama_dygraph_auto_bs8_fp32_DP2-MP2
     llama_dygraph_auto_bs8_fp32_DP2-MP2-PP2
     llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2
+    llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2_hybrid_pp
     # llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2_intermediate
     llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2-VPP3_split_bw
     llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2
@@ -695,6 +696,80 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2_intermediate() {
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
     echo "=========== $FUNCNAME run end ==========="
 }
+
+function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2_hybrid_pp() {
+    echo IS_A100 is $IS_A100
+    if [ $IS_A100 -ne 0 ]; then
+        echo "=========== $FUNCNAME run begin ==========="
+        export PYTHONPATH=$root_path/:$PYTHONPATH
+        export FLAGS_call_stack_level=3
+        export NVIDIA_TF32_OVERRIDE=0
+
+        task_name="llama_auto_bs8_fp16_dp2mp2pp2_hybrid_pp"
+        case_out_dir="output/$task_name"
+        case_log_dir="output/$task_name""_log"
+        rm -rf $case_out_dir
+        rm -rf $case_log_dir
+
+        python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" --log_dir $case_log_dir run_pretrain_auto.py \
+            --model_type "llama_pp" \
+            --model_name_or_path "facebook/llama-7b" \
+            --tokenizer_name_or_path "facebook/llama-7b" \
+            --input_dir "./data" \
+            --output_dir $case_out_dir \
+            --split 949,50,1 \
+            --max_seq_length 2048 \
+            --hidden_size 1024 \
+            --intermediate_size 3072 \
+            --num_hidden_layers 8 \
+            --num_attention_heads 32 \
+            --per_device_train_batch_size 4 \
+            --per_device_eval_batch_size 4 \
+            --n_microbatch 4 \
+            --gradient_accumulation_steps 1 \
+            --use_flash_attention 1 \
+            --use_fused_rms_norm 0 \
+            --fp16 1 \
+            --fp16_opt_level "O2" \
+            --amp_master_grad 1 \
+            --scale_loss 1024 \
+            --pipeline_parallel_degree 2 \
+            --pipeline_schedule_mode "FThenB" \
+            --tensor_parallel_degree 2 \
+            --sharding_parallel_degree 1 \
+            --learning_rate 0.0001 \
+            --min_learning_rate 0.00001 \
+            --max_steps 10 \
+            --save_steps 5000 \
+            --weight_decay 0.01 \
+            --warmup_ratio 0.01 \
+            --logging_steps 1 \
+            --dataloader_num_workers 1 \
+            --sharding "" \
+            --eval_steps 1000000 \
+            --disable_tqdm true \
+            --continue_training 0 \
+            --recompute 0 \
+            --do_train \
+            --do_eval \
+            --device "gpu" \
+            --data_impl "mmap" \
+            --enable_auto_parallel 1 \
+            --to_static 0 \
+            --max_grad_norm 0.0 \
+            >>${log_path}/$FUNCNAME 2>&1
+        loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+        ips=-1
+        mem=-1
+        echo "result: loss=$loss ips=$ips mem=$mem"
+        loss_base=9.57178879
+        ips_base=-1
+        mem_base=-1
+        check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
+        echo "=========== $FUNCNAME run end ==========="
+    fi
+}
+
 function llama_dy2st_auto_bs4_bf16_DP1-MP1-PP4-SD2() {
     # Only A100 support this case.
     echo IS_A100 is $IS_A100
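The new CI case extracts the step-10 loss from workerlog.0 with the grep/awk pipeline above and compares it against loss_base via check_result. For reference, a hedged Python equivalent of that log-parsing step; the log line format (a line containing "global_step: 10" with a "loss: <value>," field) is assumed from what the awk pipeline splits on.

# Sketch: extract the step-10 loss the same way the grep/awk pipeline does.
import re


def extract_loss(log_path, step=10):
    step_marker = "global_step: %d" % step
    with open(log_path) as f:
        for line in f:
            if step_marker in line:
                match = re.search(r"loss: ([0-9eE+.\-]+)", line)
                if match:
                    return float(match.group(1))
    return None


# Example (hypothetical path, built from the case_log_dir used above):
# loss = extract_loss("output/llama_auto_bs8_fp16_dp2mp2pp2_hybrid_pp_log/workerlog.0")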
