
Commit 1938c9e

Run llama dynamic pp perf (#10753)
* llama_with_auto_pp
* support change n_microbatches in json
* fix init model
* keep layers which are not in curr rank
* formal llama_with_auto_pp
* formal llama_with_auto_pp
* style
* formal
* fix
* fix
* fix
* fix
* add condition for hybrid pp
* fix
* add json
* remove old
* update json

---------

Co-authored-by: Waynezee <wangxiangzhe@baidu.com>
1 parent e7420b1 commit 1938c9e

7 files changed: +736 -1 lines changed

llm/auto_parallel/llama/run_pretrain_auto.py

Lines changed: 7 additions & 0 deletions
@@ -41,6 +41,7 @@
     LinearAnnealingWithWarmupDecay,
     LlamaConfig,
     LlamaForCausalLM3DAuto,
+    LlamaForCausalLM3DAutoPP,
     LlamaForCausalLMNet,
     LlamaPretrainingCriterion3DAuto,
     LlamaPretrainingCriterionNet,
@@ -49,6 +50,7 @@
 
 MODEL_CLASSES = {
     "llama": (LlamaConfig, LlamaForCausalLM3DAuto, LlamaPretrainingCriterion3DAuto),
+    "llama_pp": (LlamaConfig, LlamaForCausalLM3DAutoPP, LlamaPretrainingCriterion3DAuto),
     "llama_network": (LlamaConfig, LlamaForCausalLMNet, LlamaPretrainingCriterionNet),
 }
 
@@ -94,6 +96,10 @@ class PreTrainingArguments(AutoTrainingArguments):
         default=False,
         metadata={"help": "Weather to run benchmark by autotuner. True for from_scratch and pad_max_length."},
     )
+    n_microbatches: int = field(
+        default=1,
+        metadata={"help": "Control the num of microbatches in one pp step."},
+    )
 
     def __post_init__(self):
         super().__post_init__()
@@ -637,6 +643,7 @@ def fn(layer):
     )
     trainer = PretrainingTrainer(
         model=model,
+        model_type=model_args.model_type,
         criterion=criterion,
         args=training_args,
         data_collator=data_collator,
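Note: n_microbatches controls how many microbatches one pipeline-parallel step is split into. A minimal, framework-free sketch of that idea (not PaddleNLP code; the helper below is hypothetical):

def split_into_microbatches(batch, n_microbatches):
    # Split one global batch into n_microbatches equal chunks; a pipeline
    # schedule pushes these chunks through the stages one after another so
    # that different stages can work on different microbatches concurrently.
    size = len(batch) // n_microbatches
    return [batch[i * size:(i + 1) * size] for i in range(n_microbatches)]

print(split_into_microbatches(list(range(8)), n_microbatches=4))
# [[0, 1], [2, 3], [4, 5], [6, 7]]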

paddlenlp/trainer/auto_trainer.py

Lines changed: 61 additions & 1 deletion
@@ -29,6 +29,7 @@
 
 from paddlenlp.trainer import Trainer
 
+from ..transformers import get_pp_schedule
 from ..transformers.model_utils import clean_model_class_name, unwrap_model
 from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler
 from ..utils.env import (
@@ -46,6 +47,7 @@
     ShardingOption,
     TrainOutput,
     _exec_mode_guard,
+    check_auto_parallel_pipeline_support,
     get_last_checkpoint,
     has_length,
     speed_metrics,
@@ -77,6 +79,7 @@ def loss_func(loss, outputs):
            kwargs.update({"criterion": loss_func})
        self.auto_dist_config = kwargs.pop("auto_dist_config", None)
        model = kwargs.get("model", None)
+        self.model_type = kwargs.pop("model_type", None)
        assert model is not None
        if kwargs.get("args", None) is not None and kwargs["args"].use_intermediate_api:
            if not parallelize.has_parallelized_model:
@@ -93,12 +96,20 @@ def loss_func(loss, outputs):
                    if not param._is_initialized() and param._init_func is not None:
                        param.initialize()
                kwargs["model"] = model
-
        super().__init__(*args, **kwargs)
        assert self.args.enable_auto_parallel
 
        self.global_mesh = fleet.auto.get_mesh()
        self.comm_group_in_pp = fleet.get_hybrid_communicate_group().get_pipe_parallel_group()
+        if self.args.pipeline_parallel_degree > 1 and check_auto_parallel_pipeline_support(self.model_type):
+            self.pp_schedule = get_pp_schedule(
+                model,
+                self.args.n_microbatches,
+                self.criterion,
+                self.args.pipeline_schedule_mode,
+                self.args.pipeline_parallel_degree,
+                self.comm_group_in_pp,
+            )
        self._in_pir_mode = paddle.base.framework.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]
 
    @classmethod
@@ -703,7 +714,56 @@ def to_list(value):
 
        return (loss, outputs) if return_outputs else loss
 
+    def compute_pipeline_loss(self, model, inputs, return_outputs=False):
+        """
+        How the loss is computed by Trainer. By default, all models return the loss in the first element.
+        Subclass and override for custom behavior.
+        """
+        if self.criterion is not None:
+            if "labels" in inputs:
+                labels = inputs.pop("labels")
+            elif "start_positions" in inputs and "end_positions" in inputs:
+                labels = (inputs.pop("start_positions"), inputs.pop("end_positions"))
+            elif self.args.label_names is not None:
+                labels = []
+                for label in self.label_names:
+                    labels.append(inputs.pop(label))
+                labels = tuple(labels)
+            elif "generator_labels" in inputs:
+                labels = inputs["generator_labels"]
+        else:
+            labels = None
+
+        pp_rank = self.comm_group_in_pp.rank
+        losses = []
+        if pp_rank == 0:  # first pp stage: feed the input data into the pipeline
+            self.pp_schedule.step(**inputs)
+        elif pp_rank == self.args.pipeline_parallel_degree - 1:  # last pp stage: pass in the labels and collect the loss
+            self.pp_schedule.step(target=labels, losses=losses)
+        else:
+            self.pp_schedule.step()
+
+        final_loss = None
+        if len(losses) != 0:
+            final_loss = paddle.stack(losses).mean()
+
+        return final_loss
+
+    def dynamic_auto_parallel_pipeline_training(
+        self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]
+    ) -> paddle.Tensor:
+        assert self.args.pipeline_parallel_degree > 1, "pipeline_parallel_degree must be greater than 1."
+        assert check_auto_parallel_pipeline_support(
+            self.model_type
+        ), "dynamic auto_parallel pipeline only supports special models"
+        with self.autocast_smart_context_manager():
+            loss = self.compute_pipeline_loss(model, inputs)
+
+        return loss
+
    def dynamic_training(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
+        if self.args.pipeline_parallel_degree > 1 and check_auto_parallel_pipeline_support(self.model_type):
+            return self.dynamic_auto_parallel_pipeline_training(model, inputs)
        with self.autocast_smart_context_manager():
            loss = self.compute_loss(model, inputs)

paddlenlp/trainer/trainer_utils.py

Lines changed: 5 additions & 0 deletions
@@ -1253,3 +1253,8 @@ def download_recovery_ckpt_from_pdc(recovery_checkpoint_path, timeout):
        raise RuntimeError(
            f"{PDC_DOWNLOAD_ERROR}; Error occurred when trying to download checkpoint from PDC, recovery_checkpoint_path: {recovery_checkpoint_path}, timeout: {timeout}; error details: {PDCErrorMessageMap[result]}"
        )
+
+
+def check_auto_parallel_pipeline_support(model_type=None):
+    support_types = ["llama_pp"]
+    return model_type in support_types
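Note: the helper is a plain membership test against the model types that support the dynamic pipeline path. A small illustration of the expected behaviour (assuming the import path matches where this commit defines the function):

from paddlenlp.trainer.trainer_utils import check_auto_parallel_pipeline_support

print(check_auto_parallel_pipeline_support("llama_pp"))  # True  -> dynamic auto-parallel pp path
print(check_auto_parallel_pipeline_support("llama"))     # False -> regular training path
print(check_auto_parallel_pipeline_support())            # False (model_type defaults to None)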

paddlenlp/transformers/llama/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 from .configuration import *
 from .modeling import *
 from .modeling_auto import *
+from .modeling_auto_pp import *
 from .modeling_network import *
 from .modeling_pp import *
 from .tokenizer import *
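Note: this wildcard re-export is what lets the other files in this commit pull the new symbols from the transformers package root; assuming modeling_auto_pp defines them (that file is part of this commit but not shown in this section), the imports used elsewhere resolve as:

# Illustrative: how run_pretrain_auto.py and auto_trainer.py reach the new symbols.
from paddlenlp.transformers import LlamaForCausalLM3DAutoPP  # used in run_pretrain_auto.py
# auto_trainer.py imports the schedule builder the same way:
# from ..transformers import get_pp_schedule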
