Commit 44eff1f

Add auto_parallel context_parallel strategy (#10722)

* add auto_parallel context_parallel and sep
* fix cp; add cp intermediate api demo
* fix modeling_auto
* fix pre-commit errors
* fix ci errors; add cp ut case
* fix ci case loss diff error
* fix PaddleNLP-CI-gpt-3 ci error; cp is not supported on non-A100 GPUs
* fix useless code
* edit ci loss; changed because of the paddle whl; 0708

1 parent 78741f9 commit 44eff1f

8 files changed: +597 −130 lines

llm/auto_parallel/llama/run_pretrain_auto.py

Lines changed: 10 additions & 0 deletions

```diff
@@ -410,12 +410,20 @@ def init_seed(seed: int = 1234, args=None):
         order = ["pp", "dp", "sharding", "mp", "sep"]
     elif args.hybrid_parallel_topo_order == "sharding_first":
         order = ["dp", "sharding", "pp", "mp", "sep"]
+    if args.context_parallel_degree is not None and args.context_parallel_degree > 1:
+        sep_degree = args.context_parallel_degree
+    elif args.sep_parallel_degree is not None and args.sep_parallel_degree > 1:
+        sep_degree = args.sep_parallel_degree
+    else:
+        sep_degree = 1
+    sep_degree = args.sep_parallel_degree if args.sep_parallel_degree > 1 else args.context_parallel_degree
     topo = Topology(
         dist.get_rank(),
         dist.get_world_size(),
         dp_degree=args.dataset_world_size,
         pp_degree=args.pipeline_parallel_degree,
         mp_degree=args.tensor_parallel_degree,
+        sep_degree=sep_degree,
         sharding_degree=1,  # auto_parallel's sharding is not orthogonal with dp, mp and pp
         order=order,
     )
@@ -555,6 +563,8 @@ def main():
     config.tensor_parallel_rank = training_args.tensor_parallel_rank
     config.sharding_parallel_degree = training_args.sharding_parallel_degree
     config.to_static = training_args.to_static
+    config.sep_parallel_degree = training_args.sep_parallel_degree
+    config.context_parallel_degree = training_args.context_parallel_degree

     if training_args.strategy.pipeline.enable and config.virtual_pp_degree > 1:
         pipeline = training_args.strategy.pipeline
```
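The first hunk resolves an effective `sep_degree` from the two degree flags before passing it to `Topology`. A minimal standalone sketch of the precedence in the new if/elif/else block (not code from the commit; values are illustrative):

```python
# Minimal sketch of the precedence in the new if/elif/else block (illustration
# only, not code from the commit): context_parallel_degree wins when > 1,
# then sep_parallel_degree, otherwise the sequence dimension is not split.
def resolve_sep_degree(context_parallel_degree=None, sep_parallel_degree=None):
    if context_parallel_degree is not None and context_parallel_degree > 1:
        return context_parallel_degree
    if sep_parallel_degree is not None and sep_parallel_degree > 1:
        return sep_parallel_degree
    return 1


assert resolve_sep_degree(context_parallel_degree=4) == 4
assert resolve_sep_degree(sep_parallel_degree=2) == 2
assert resolve_sep_degree() == 1
```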

paddlenlp/trainer/auto_trainer.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -29,6 +29,8 @@

 from paddlenlp.trainer import Trainer

+# from ..transformers.segment_parallel_utils import split_inputs_sequence_dim
+from ..transformers.context_parallel_utils import split_sequence_dim_load_balance
 from ..transformers.model_utils import clean_model_class_name, unwrap_model
 from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler
 from ..utils.env import (
@@ -141,6 +143,7 @@ def parallel_model(cls, model, training_args: AutoTrainingArguments):
             "data_sharding_parallel": training_args.dataset_world_size > 1,
             "sharding": training_args.sharding,
             "sharding_mesh_dim": training_args.sharding_parallel_mesh_dimension,
+            "context_parallel": training_args.context_parallel_degree > 1 or training_args.sep_parallel_degree > 1,
         }
         auto_dist_config = model._generate_auto_dist_config(auto_dist_degree)
         model = parallelize.parallelize_model(
@@ -567,7 +570,8 @@ def _inner_training_loop(
                 if step_control % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
                 self.timers and self.timers("forward-backward").start()
-
+                if self.args.context_parallel_degree > 1 and self.args.split_inputs_sequence_dim:
+                    inputs = split_sequence_dim_load_balance(inputs)
                 tr_loss_step = self.training_step(model, inputs)

                 with _exec_mode_guard("dynamic"):
```
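The new branch shards every input along the sequence dimension before `training_step` via `split_sequence_dim_load_balance` (added in `context_parallel_utils.py` below). The underlying `shard_seq_load_balance` comes from Paddle; the chunk pairing sketched here is only one common way to balance causal ring-attention work and is an assumption, not Paddle's verified implementation:

```python
# Conceptual sketch only (assumed scheme, not necessarily what Paddle's
# shard_seq_load_balance does): split the sequence into 2 * cp_degree chunks
# and give rank r the pair (r, 2 * cp_degree - 1 - r), so every rank holds one
# early and one late chunk and causal-attention work stays roughly even.
def load_balanced_chunks(cp_degree: int, rank: int):
    assert 0 <= rank < cp_degree
    return rank, 2 * cp_degree - 1 - rank


# With cp_degree=4, ranks hold chunk pairs (0, 7), (1, 6), (2, 5), (3, 4).
print([load_balanced_chunks(4, r) for r in range(4)])
```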

paddlenlp/trainer/training_args.py

Lines changed: 16 additions & 2 deletions

```diff
@@ -1743,12 +1743,25 @@ def is_segment_parallel_supported():
             amp.custom_white_list = self.amp_custom_white_list if self.amp_custom_white_list is not None else []

             self.strategy = strategy
+            sep_degree = self.sep_parallel_degree if self.sep_parallel_degree > 1 else self.context_parallel_degree
             if self.hybrid_parallel_topo_order == "pp_first":
                 order = ["pp", "dp", "mp"]
-                degree = [self.pipeline_parallel_degree, self.dataset_world_size, self.tensor_parallel_degree]
+
+                degree = [
+                    self.pipeline_parallel_degree,
+                    self.dataset_world_size,
+                    self.tensor_parallel_degree,
+                ]
             elif self.hybrid_parallel_topo_order == "sharding_first":
                 order = ["dp", "pp", "mp"]
-                degree = [self.dataset_world_size, self.pipeline_parallel_degree, self.tensor_parallel_degree]
+                degree = [
+                    self.dataset_world_size,
+                    self.pipeline_parallel_degree,
+                    self.tensor_parallel_degree,
+                ]
+            if sep_degree > 1:
+                order.insert(-1, "sep")
+                degree.insert(-1, sep_degree)
             mesh_dims = list(zip(order, degree))
             fleet.auto.create_mesh(mesh_dims)

@@ -1767,6 +1780,7 @@ def is_segment_parallel_supported():
                 "dp_degree": self.dataset_world_size,
                 "mp_degree": self.tensor_parallel_degree,
                 "pp_degree": self.pipeline_parallel_degree,
+                "sep_degree": sep_degree,
                 "order": order,
             }
             fleet.init(is_collective=True, strategy=strategy)
```
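`order.insert(-1, "sep")` and `degree.insert(-1, sep_degree)` slot the sep/context-parallel axis just before `mp` in the process mesh. A quick illustration with made-up degrees (pp=2, dp=2, sep=2, mp=2, pp_first ordering):

```python
# Illustration of the mesh layout built above; the degrees are made up.
order = ["pp", "dp", "mp"]
degree = [2, 2, 2]  # pp=2, dp=2, mp=2
sep_degree = 2

if sep_degree > 1:
    order.insert(-1, "sep")        # "sep" lands just before "mp"
    degree.insert(-1, sep_degree)

mesh_dims = list(zip(order, degree))
print(mesh_dims)  # [('pp', 2), ('dp', 2), ('sep', 2), ('mp', 2)]
```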

paddlenlp/transformers/context_parallel_utils.py

Lines changed: 20 additions & 0 deletions

```diff
@@ -26,6 +26,7 @@


 import paddle
+from paddle.distributed.auto_parallel.ring_attention import shard_seq_load_balance
 from paddle.distributed.fleet import fleet


@@ -62,3 +63,22 @@ def do_split_sequence_dim_load_balance(data, rank, degree):
     else:
         raise ValueError(f"the inputs should be a list or a dict, but is type: {type(inputs)}")
     return res
+
+
+def split_sequence_dim_load_balance(inputs):
+    """
+    for auto_parallel mode
+    """
+    if isinstance(inputs, paddle.Tensor):
+        return shard_seq_load_balance(inputs, 1)
+    elif isinstance(inputs, dict):
+        res = {}
+        for k, tensor in inputs.items():
+            res[k] = shard_seq_load_balance(tensor, 1)
+    elif isinstance(inputs, list):
+        res = []
+        for tensor in inputs:
+            res.append(shard_seq_load_balance(tensor, 1))
+    else:
+        raise ValueError(f"the inputs should be a list or a dict, but is type: {type(inputs)}")
+    return res
```
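A hypothetical call-site sketch for the new helper, mirroring how `auto_trainer.py` uses it above. It assumes an already-launched auto-parallel job with `context_parallel_degree > 1`; batch shapes and vocabulary size are illustrative:

```python
# Hypothetical usage sketch; assumes an initialized auto-parallel run with
# context_parallel_degree > 1. Shapes and vocab size are illustrative.
import paddle

from paddlenlp.transformers.context_parallel_utils import split_sequence_dim_load_balance

batch = {
    "input_ids": paddle.randint(0, 32000, shape=[2, 4096]),
    "labels": paddle.randint(0, 32000, shape=[2, 4096]),
}
# Each rank keeps only its load-balanced slice of the 4096-token sequence (dim 1).
batch = split_sequence_dim_load_balance(batch)
```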
