 from paddlenlp.trainer import Trainer
+from ..transformers import get_pp_schedule
 from ..transformers.model_utils import clean_model_class_name, unwrap_model
 from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler
 from ..utils.env import (
@@ ... @@
     ShardingOption,
     TrainOutput,
     _exec_mode_guard,
+    check_auto_parallel_pipeline_support,
     get_last_checkpoint,
     has_length,
     speed_metrics,
@@ -77,6 +79,7 @@ def loss_func(loss, outputs):
         kwargs.update({"criterion": loss_func})
         self.auto_dist_config = kwargs.pop("auto_dist_config", None)
         model = kwargs.get("model", None)
+        self.model_type = kwargs.pop("model_type", None)
         assert model is not None
         if kwargs.get("args", None) is not None and kwargs["args"].use_intermediate_api:
             if not parallelize.has_parallelized_model:
@@ -93,12 +96,20 @@ def loss_func(loss, outputs):
                     if not param._is_initialized() and param._init_func is not None:
                         param.initialize()
             kwargs["model"] = model
-
         super().__init__(*args, **kwargs)
         assert self.args.enable_auto_parallel
 
         self.global_mesh = fleet.auto.get_mesh()
         self.comm_group_in_pp = fleet.get_hybrid_communicate_group().get_pipe_parallel_group()
+        if self.args.pipeline_parallel_degree > 1 and check_auto_parallel_pipeline_support(self.model_type):
+            self.pp_schedule = get_pp_schedule(
+                model,
+                self.args.n_microbatches,
+                self.criterion,
+                self.args.pipeline_schedule_mode,
+                self.args.pipeline_parallel_degree,
+                self.comm_group_in_pp,
+            )
         self._in_pir_mode = paddle.base.framework.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]
 
     @classmethod
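The schedule is only built when pipeline parallelism is actually enabled and the model family is one the dynamic auto-parallel pipeline knows how to split. For context, a minimal sketch of what a support check along the lines of check_auto_parallel_pipeline_support could look like; the real helper is imported from ..utils.env, and the model types listed below are assumed placeholders, not the actual supported set:

# Illustrative sketch only: the real check_auto_parallel_pipeline_support lives
# in ..utils.env; the whitelist contents here are assumptions for illustration.
_SUPPORTED_PP_MODEL_TYPES = {"llama", "gpt"}  # hypothetical model types

def check_auto_parallel_pipeline_support(model_type):
    # model_type is popped from the Trainer kwargs and may be None; in that
    # case the schedule is not built and training uses the default path.
    return model_type in _SUPPORTED_PP_MODEL_TYPES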
@@ -703,7 +714,56 @@ def to_list(value):
 
         return (loss, outputs) if return_outputs else loss
 
+    def compute_pipeline_loss(self, model, inputs, return_outputs=False):
+        """
+        Compute the loss through the pipeline schedule. Only the last pipeline
+        stage receives the labels and produces a loss; all other ranks return None.
+        Subclass and override for custom behavior.
+        """
+        if self.criterion is not None:
+            if "labels" in inputs:
+                labels = inputs.pop("labels")
+            elif "start_positions" in inputs and "end_positions" in inputs:
+                labels = (inputs.pop("start_positions"), inputs.pop("end_positions"))
+            elif self.args.label_names is not None:
+                labels = []
+                for label in self.label_names:
+                    labels.append(inputs.pop(label))
+                labels = tuple(labels)
+            elif "generator_labels" in inputs:
+                labels = inputs["generator_labels"]
+            else:
+                labels = None
+
+        pp_rank = self.comm_group_in_pp.rank
+        losses = []
+        if pp_rank == 0:
+            # first pipeline stage: feed the input batch into the schedule
+            self.pp_schedule.step(**inputs)
+        elif pp_rank == self.args.pipeline_parallel_degree - 1:
+            # last pipeline stage: pass the labels in and collect the per-microbatch losses
+            self.pp_schedule.step(target=labels, losses=losses)
+        else:
+            # intermediate stages only relay activations
+            self.pp_schedule.step()
+
+        final_loss = None
+        if len(losses) != 0:
+            final_loss = paddle.stack(losses).mean()
+
+        return final_loss
+
+    def dynamic_auto_parallel_pipeline_training(
+        self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]
+    ) -> paddle.Tensor:
+        assert self.args.pipeline_parallel_degree > 1, "pipeline_parallel_degree must be greater than 1."
+        assert check_auto_parallel_pipeline_support(
+            self.model_type
+        ), "dynamic auto_parallel pipeline only supports specific model types"
+        with self.autocast_smart_context_manager():
+            loss = self.compute_pipeline_loss(model, inputs)
+
+        return loss
+
     def dynamic_training(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
+        if self.args.pipeline_parallel_degree > 1 and check_auto_parallel_pipeline_support(self.model_type):
+            return self.dynamic_auto_parallel_pipeline_training(model, inputs)
         with self.autocast_smart_context_manager():
             loss = self.compute_loss(model, inputs)
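Taken together, the new path follows the usual pipeline-parallel calling convention: the first stage consumes the input batch, intermediate stages only relay activations, and the last stage receives the labels and fills a list with one loss per microbatch, which is then averaged. A toy sketch of that last-stage contract, using a hypothetical stand-in for the object returned by get_pp_schedule (the class below is illustrative, not the real schedule type):

import paddle

class ToyPipelineSchedule:
    # Hypothetical stand-in for the schedule returned by get_pp_schedule;
    # it only mimics the last-stage contract used by compute_pipeline_loss.
    def __init__(self, n_microbatches):
        self.n_microbatches = n_microbatches

    def step(self, target=None, losses=None, **inputs):
        # The real schedule runs microbatches through this stage and, on the
        # last pipeline rank, appends one loss tensor per microbatch to `losses`.
        if losses is not None:
            for _ in range(self.n_microbatches):
                losses.append(paddle.to_tensor(0.5))  # placeholder loss value

# Mirrors the last-stage branch of compute_pipeline_loss:
schedule = ToyPipelineSchedule(n_microbatches=4)
losses = []
schedule.step(target=None, losses=losses)
final_loss = paddle.stack(losses).mean() if losses else None
print(float(final_loss))  # 0.5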