
Commit f4846c1

【AutoParallel】 add sep strategy for llama (#10837)
* autoparallel add sep strategy for llama
* fix ci
* fix ci error;
* fix ci error;
1 parent 58bed42 commit f4846c1

5 files changed: +190 -40 lines

paddlenlp/trainer/auto_trainer.py

Lines changed: 5 additions & 3 deletions
@@ -29,9 +29,9 @@
 
 from paddlenlp.trainer import Trainer
 
-# from ..transformers.segment_parallel_utils import split_inputs_sequence_dim
-from ..transformers.context_parallel_utils import split_sequence_dim_load_balance
+from ..transformers.context_parallel_utils import auto_split_sequence_dim_load_balance
 from ..transformers.model_utils import clean_model_class_name, unwrap_model
+from ..transformers.segment_parallel_utils import auto_split_inputs_sequence_dim
 from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler
 from ..utils.env import (
     PREFIX_CHECKPOINT_DIR,

@@ -570,8 +570,10 @@ def _inner_training_loop(
                 if step_control % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
                     self.timers and self.timers("forward-backward").start()
+                if self.args.sep_parallel_degree > 1 and self.args.split_inputs_sequence_dim:
+                    inputs = auto_split_inputs_sequence_dim(inputs)
                 if self.args.context_parallel_degree > 1 and self.args.split_inputs_sequence_dim:
-                    inputs = split_sequence_dim_load_balance(inputs)
+                    inputs = auto_split_sequence_dim_load_balance(inputs)
                 tr_loss_step = self.training_step(model, inputs)
 
                 with _exec_mode_guard("dynamic"):
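For orientation, here is a minimal standalone sketch (NumPy only, hypothetical shapes; the real helpers operate on paddle distributed tensors and a process mesh) of how the two splits in the loop above differ: the sep path gives each rank one contiguous slice of the sequence, while the context-parallel path uses a load-balanced split, assumed here to follow the common zigzag pairing for causal attention.

```python
import numpy as np


def contiguous_split(x, rank, degree):
    # sep-style split: each rank keeps one contiguous seq_len / degree slice.
    chunk = x.shape[1] // degree
    return x[:, rank * chunk : (rank + 1) * chunk]


def load_balance_split(x, rank, degree):
    # cp-style split (assumed zigzag): pair chunk i with chunk 2 * degree - 1 - i
    # so every rank sees a mix of early and late tokens under a causal mask.
    chunks = np.split(x, 2 * degree, axis=1)
    return np.concatenate([chunks[rank], chunks[2 * degree - 1 - rank]], axis=1)


tokens = np.arange(16).reshape(1, 16)  # [batch=1, seq_len=16]
print(contiguous_split(tokens, rank=0, degree=4))    # [[0 1 2 3]]
print(load_balance_split(tokens, rank=0, degree=4))  # [[ 0  1 14 15]]
```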

paddlenlp/transformers/context_parallel_utils.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def do_split_sequence_dim_load_balance(data, rank, degree):
     return res
 
 
-def split_sequence_dim_load_balance(inputs):
+def auto_split_sequence_dim_load_balance(inputs):
     """
     for auto_parallel mode
     """

paddlenlp/transformers/llama/modeling_auto.py

Lines changed: 120 additions & 32 deletions
@@ -28,6 +28,8 @@
 from paddle.distributed import fleet
 from paddle.distributed.fleet.utils import recompute
 
+from ..segment_parallel_utils import sep_reshard_layer
+
 try:
     from paddle.incubate.nn.functional import fused_rotary_position_embedding
 except ImportError:
@@ -200,12 +202,24 @@ def scaled_dot_product_attention(
     return (attn_output, attn_weights) if output_attentions else attn_output
 
 
-def get_colwise_placement(has_seq_mesh):
-    return [dist.Replicate(), dist.Replicate(), dist.Shard(1)] if has_seq_mesh else [dist.Replicate(), dist.Shard(1)]
+def get_colwise_placement(has_seq_mesh, has_seq_parallel):
+    if has_seq_mesh:
+        if has_seq_parallel:  # mp+sep is not supported yet
+            return [dist.Replicate(), dist.Replicate(), dist.Replicate()]
+        else:
+            return [dist.Replicate(), dist.Replicate(), dist.Shard(1)]
+    else:
+        return [dist.Replicate(), dist.Shard(1)]
 
 
-def get_rowwise_placement(has_seq_mesh):
-    return [dist.Replicate(), dist.Replicate(), dist.Shard(0)] if has_seq_mesh else [dist.Replicate(), dist.Shard(0)]
+def get_rowwise_placement(has_seq_mesh, has_seq_parallel):
+    if has_seq_mesh:
+        if has_seq_parallel:  # mp+sep is not supported yet
+            return [dist.Replicate(), dist.Replicate(), dist.Replicate()]
+        else:
+            return [dist.Replicate(), dist.Replicate(), dist.Shard(0)]
+    else:
+        return [dist.Replicate(), dist.Shard(0)]
 
 
 def get_replicate_placement(has_seq_mesh):
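As a quick legend for the placements above, a pure-Python sketch (no paddle required) with R standing in for dist.Replicate() and S(i) for dist.Shard(i); the mesh is assumed to be ordered [dp, sep/cp, mp] when has_seq_mesh is true and [dp, mp] otherwise.

```python
def colwise(has_seq_mesh, has_seq_parallel):
    # With sep enabled, tensor-parallel sharding of the weight is skipped
    # (mp + sep is not supported yet), so the weight stays fully replicated.
    if has_seq_mesh:
        return ["R", "R", "R"] if has_seq_parallel else ["R", "R", "S(1)"]
    return ["R", "S(1)"]


def rowwise(has_seq_mesh, has_seq_parallel):
    if has_seq_mesh:
        return ["R", "R", "R"] if has_seq_parallel else ["R", "R", "S(0)"]
    return ["R", "S(0)"]


print(colwise(True, False))   # 3-D mesh, sep off: ['R', 'R', 'S(1)']
print(colwise(True, True))    # 3-D mesh, sep on:  ['R', 'R', 'R']
print(rowwise(False, False))  # 2-D mesh:          ['R', 'S(0)']
```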
@@ -266,28 +280,28 @@ def __init__(self, config, ipp: Optional[int] = None):
             self.gate_up_fused_proj.weight = dist.shard_tensor(
                 self.gate_up_fused_proj.weight,
                 get_mesh(self.ipp),
-                get_colwise_placement(has_seq_mesh),
+                get_colwise_placement(has_seq_mesh, self.config.sep_parallel_degree > 1),
             )
         else:
             self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
             self.gate_proj.weight = dist.shard_tensor(
                 self.gate_proj.weight,
                 get_mesh(self.ipp),
-                get_colwise_placement(has_seq_mesh),
+                get_colwise_placement(has_seq_mesh, self.config.sep_parallel_degree > 1),
             )
 
             self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
             self.up_proj.weight = dist.shard_tensor(
                 self.up_proj.weight,
                 get_mesh(self.ipp),
-                get_colwise_placement(has_seq_mesh),
+                get_colwise_placement(has_seq_mesh, self.config.sep_parallel_degree > 1),
             )
 
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
         self.down_proj.weight = dist.shard_tensor(
             self.down_proj.weight,
             get_mesh(self.ipp),
-            get_rowwise_placement(has_seq_mesh),
+            get_rowwise_placement(has_seq_mesh, self.config.sep_parallel_degree > 1),
         )
 
     def forward(self, x):
@@ -348,7 +362,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.qkv_proj.weight = dist.shard_tensor(
                 self.qkv_proj.weight,
                 get_mesh(self.ipp),
-                get_colwise_placement(self.has_seq_mesh),
+                get_colwise_placement(self.has_seq_mesh, self.config.sep_parallel_degree > 1),
             )
 
         else:
@@ -360,7 +374,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.q_proj.weight = dist.shard_tensor(
                 self.q_proj.weight,
                 get_mesh(self.ipp),
-                get_colwise_placement(self.has_seq_mesh),
+                get_colwise_placement(self.has_seq_mesh, self.config.sep_parallel_degree > 1),
             )
 
             self.k_proj = nn.Linear(
@@ -371,7 +385,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.k_proj.weight = dist.shard_tensor(
                 self.k_proj.weight,
                 get_mesh(self.ipp),
-                get_colwise_placement(self.has_seq_mesh),
+                get_colwise_placement(self.has_seq_mesh, self.config.sep_parallel_degree > 1),
             )
 
             self.v_proj = nn.Linear(
@@ -382,7 +396,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.v_proj.weight = dist.shard_tensor(
                 self.v_proj.weight,
                 get_mesh(self.ipp),
-                get_colwise_placement(self.has_seq_mesh),
+                get_colwise_placement(self.has_seq_mesh, self.config.sep_parallel_degree > 1),
             )
 
         self.o_proj = nn.Linear(
@@ -393,13 +407,16 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
         self.o_proj.weight = dist.shard_tensor(
             self.o_proj.weight,
             get_mesh(self.ipp),
-            get_rowwise_placement(self.has_seq_mesh),
+            get_rowwise_placement(self.has_seq_mesh, self.config.sep_parallel_degree > 1),
         )
 
         if config.rope:
             self._init_rope()
 
         self.config = config
+        if config.sep_parallel_degree > 1:
+            assert self.num_key_value_heads % config.sep_parallel_degree == 0
+            assert self.num_heads % config.sep_parallel_degree == 0
 
     def _init_rope(self):
         if self.config.rope_scaling_type is None:
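The two assertions added above encode a simple constraint: both the query heads and the key/value heads must divide evenly by the sep degree, because the later all-to-all reshard redistributes the head dimension across sep ranks. A quick check with hypothetical head counts:

```python
# Hypothetical head counts; any configuration where either count is not divisible
# by sep_parallel_degree would trip the asserts in __init__ above.
num_heads, num_key_value_heads = 32, 8

for sep_parallel_degree in (2, 4, 8, 3):
    ok = num_heads % sep_parallel_degree == 0 and num_key_value_heads % sep_parallel_degree == 0
    print(f"sep_parallel_degree={sep_parallel_degree}: {'ok' if ok else 'invalid'}")
# 2: ok, 4: ok, 8: ok, 3: invalid
```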
@@ -456,37 +473,108 @@ def forward(
         )
 
         if self.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass():
-            target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
             mix_layer = self.qkv_proj(hidden_states)
-            mix_layer = paddle.reshape_(mix_layer, target_shape)
+            # NOTE for GQA attention fusion (compatible with MHA and MQA):
+            # The weight for qkv_proj is in shape like [hidden_size, hidden_size + 2 * num_kv_heads * head_dim].
+            # After the projection, the mix_layer is in shape like [b, s, hidden_size + 2 * num_kv_heads * head_dim].
+            # Reshape the mix_layer into a shape like [b, s, num_kv_heads, (num_groups + 2) * head_dim],
+            # where num_groups = num_q_heads // num_kv_heads.
+            # Split the mix_layer on the last axis into three sections [num_groups * head_dim, head_dim, head_dim]
+            # to represent the q, k and v respectively.
+            # The q is in the shape like [b, s, num_kv_heads, num_groups * head_dim].
+            # The k and v are in the shape like [b, s, num_kv_heads, head_dim].
+            # Under MHA, the q is ready for the following calculation since num_kv_heads == num_q_heads,
+            # But for the GQA or MQA, q should be reshaped into [b, s, num_q_heads, head_dim].
+            if self.config.sep_parallel_degree > 1:
+                if self.config.sequence_parallel:
+                    raise ValueError(
+                        "Sep parallel cannot be used with sequence parallel, "
+                        "because paddle auto parallel does not support "
+                        "reshard one dim twice."
+                    )
+
+                # [bs, seq_len / sep, num_head, head_dim] -> [bs, seq_len, num_head / sep, head_dim]
+                mix_layer = sep_reshard_layer(
+                    mix_layer,
+                    split_axis=2,
+                    concat_axis=1,
+                )
+                mix_layer = paddle.reshape_(
+                    mix_layer, [0, self.seq_length, -1, (self.num_key_value_groups + 2) * self.head_dim]
+                )  # [bs, seq_len, num_head/k, 3*head_dim], k is sep degree
+            else:
+                target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
+                mix_layer = paddle.reshape_(mix_layer, target_shape)
+
             query_states, key_states, value_states = paddle.split(
                 mix_layer,
                 num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim],
                 axis=-1,
             )
             if self.gqa_or_mqa:
                 query_states = paddle.reshape(query_states, [0, 0, self.num_heads, self.head_dim])
+            if self.config.sequence_parallel and self.config.sep_parallel_degree <= 1:
+                # [seq_len, bs, num_head * head_dim] -> [bs, seq_len, num_head * head_dim] (if sequence_parallel)
+                # FA and rope not support sequence first
+                query_states = paddle.transpose(query_states, [1, 0, 2, 3])
+                key_states = paddle.transpose(key_states, [1, 0, 2, 3])
+                value_states = paddle.transpose(value_states, [1, 0, 2, 3])
         else:
-            target_query_shape = [0, 0, self.num_heads, self.head_dim]
-            target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
-
-            query_states = self.q_proj(hidden_states).reshape(shape=target_query_shape)
-            key_states = self.k_proj(hidden_states).reshape(shape=target_key_value_shape)
-            value_states = self.v_proj(hidden_states).reshape(shape=target_key_value_shape)
-
-            if self.config.sequence_parallel:
-                # [seq_len, bs, num_head * head_dim] -> [bs, seq_len, num_head * head_dim] (if sequence_parallel)
-                # FA and rope not support sequence first
-                query_states = paddle.transpose(query_states, [1, 0, 2, 3])
-                key_states = paddle.transpose(key_states, [1, 0, 2, 3])
-                value_states = paddle.transpose(value_states, [1, 0, 2, 3])
+            if self.config.sep_parallel_degree > 1:
+                query_states = self.q_proj(hidden_states)
+                key_states = self.k_proj(hidden_states)
+                value_states = self.v_proj(hidden_states)
+                if self.config.sequence_parallel:
+                    raise ValueError(
+                        "Sep parallel cannot be used with sequence parallel, "
+                        "because paddle auto parallel does not support "
+                        "reshard one dim twice."
+                    )
 
+                query_states = sep_reshard_layer(
+                    query_states,
+                    split_axis=2,
+                    concat_axis=1,
+                )
+                key_states = sep_reshard_layer(
+                    key_states,
+                    split_axis=2,
+                    concat_axis=1,
+                )
+                value_states = sep_reshard_layer(
+                    value_states,
+                    split_axis=2,
+                    concat_axis=1,
+                )
+                query_states = paddle.reshape(
+                    query_states, shape=[0, self.seq_length, -1, self.head_dim]
+                )  # [bs, seq_len, num_head/k, head_dim], k is sep degree
+                key_states = paddle.reshape(key_states, shape=[0, self.seq_length, -1, self.head_dim])
+                value_states = paddle.reshape(value_states, shape=[0, self.seq_length, -1, self.head_dim])
+            else:
+                target_query_shape = [0, 0, self.num_heads, self.head_dim]
+                target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
+
+                query_states = self.q_proj(hidden_states).reshape(shape=target_query_shape)
+                key_states = self.k_proj(hidden_states).reshape(shape=target_key_value_shape)
+                value_states = self.v_proj(hidden_states).reshape(shape=target_key_value_shape)
+
+                if self.config.sequence_parallel:
+                    # [seq_len, bs, num_head * head_dim] -> [bs, seq_len, num_head * head_dim] (if sequence_parallel)
+                    # FA and rope not support sequence first
+                    query_states = paddle.transpose(query_states, [1, 0, 2, 3])
+                    key_states = paddle.transpose(key_states, [1, 0, 2, 3])
+                    value_states = paddle.transpose(value_states, [1, 0, 2, 3])
         kv_seq_len = key_states.shape[-3]
 
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-3]
 
         if self.config.rope:
+            query_placement = query_states.placements
+            if self.config.sep_parallel_degree > 1:
+                batch_size, seq_length, _, _ = query_states.shape
+                position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
             if self.config.context_parallel_degree > 1:
                 mesh = dist.auto_parallel.get_mesh()
                 group = mesh._get_group("sep")
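The heart of the sep path above is the all-to-all style reshard: before it, each sep rank holds every head for only seq_len / sep tokens; after it, each rank holds every token for only num_heads / sep heads, so attention can see the full sequence. A NumPy simulation of that exchange with hypothetical sizes (the real sep_reshard_layer expresses it as a dist.reshard of placements on the mesh):

```python
import numpy as np

bs, seq, heads, dim, k = 1, 8, 4, 2, 2  # k = sep degree
full = np.arange(bs * seq * heads * dim).reshape(bs, seq, heads, dim)

# Before the reshard: each sep rank owns a contiguous sequence slice, all heads.
before = [full[:, r * (seq // k) : (r + 1) * (seq // k)] for r in range(k)]


def simulated_sep_reshard(shards, split_axis=2, concat_axis=1):
    # Each rank splits its shard along the head axis (split_axis) and sends piece j
    # to rank j; every rank then concatenates what it received along the sequence
    # axis (concat_axis). This is the data movement behind the placement change.
    pieces = [np.split(s, k, axis=split_axis) for s in shards]
    return [
        np.concatenate([pieces[src][dst] for src in range(k)], axis=concat_axis)
        for dst in range(k)
    ]


after = simulated_sep_reshard(before)
print(before[0].shape, "->", after[0].shape)  # (1, 4, 4, 2) -> (1, 8, 2, 2)
assert np.array_equal(after[0], full[:, :, : heads // k])  # rank 0: all tokens, heads 0..1
```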
@@ -516,16 +604,16 @@ def forward(
                     self.rotary_emb,
                     self.config.context_parallel_degree,
                 )
-                if self.has_seq_mesh:
+                if self.config.context_parallel_degree > 1:
                     query_states = dist.reshard(
                         query_states,
                         get_mesh(self.ipp),
-                        [dist.Shard(0), dist.Shard(1), dist.Shard(2)],
+                        query_placement,  # [dist.Shard(0), dist.Shard(1), dist.Shard(2)]
                     )
                     key_states = dist.reshard(
                         key_states,
                         get_mesh(self.ipp),
-                        [dist.Shard(0), dist.Shard(1), dist.Shard(2)],
+                        query_placement,  # [dist.Shard(0), dist.Shard(1), dist.Shard(2)]
                     )
             else:
                 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
@@ -1282,7 +1370,7 @@ def __init__(self, config: LlamaConfig):
             self.weight = dist.shard_tensor(
                 self.weight,
                 get_mesh(-1),
-                get_colwise_placement(has_seq_mesh),
+                get_colwise_placement(has_seq_mesh, self.config.sep_parallel_degree > 1),
             )
 
     def forward(self, hidden_states, tensor_parallel_output=None):

paddlenlp/transformers/llama/modeling_network.py

Lines changed: 16 additions & 4 deletions
@@ -1369,16 +1369,28 @@ def auto_dist_config(self, prefix=""):
             "pp_config": {"split_spec": f"{prefix}llama.layers", "global_spec": f"{prefix}llama.global_layer"},
             "cp_config": {
                 "parallelize_plan": {
-                    f"{prefix}llama.layers.*.self_attn.rope_func": [
-                        PrepareLayerInput(layer_input_rope_hook),
-                        PrepareLayerOutput(layer_output_rope_hook),
-                    ],
                     f"{prefix}llama.layers.*.self_attn.sdpa": dist.ContextParallel(
                         backend="p2p" if self.config.context_parallel_degree > 1 else "all2all"
                     ),
                 }
             },
         }
+        if self.config.context_parallel_degree > 1:
+            config["cp_config"]["parallelize_plan"].update(
+                {
+                    f"{prefix}llama.layers.*.self_attn.rope_func": [
+                        PrepareLayerInput(layer_input_rope_hook),
+                        PrepareLayerOutput(layer_output_rope_hook),
+                    ]
+                }
+            )
+        elif self.config.sep_parallel_degree > 1:
+            # fuse_rope does not support dtensor spmd yet, so the sequence dim needs an extra reshard
+            config["cp_config"]["parallelize_plan"].update(
+                {
+                    f"{prefix}llama.layers.*.self_attn.rope_func": PrepareLayerOutput(layer_output_rope_hook),
+                }
+            )
 
         return config
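The net effect of the branch above is that the rope_func entry of the parallelize plan depends on which sequence strategy is active: context parallel registers both the input and output rope hooks, while sep only reshards the rope output. A plain-Python sketch of the resulting plan, with placeholder strings standing in for the real PrepareLayerInput/PrepareLayerOutput wrappers and dist.ContextParallel object:

```python
def sketch_parallelize_plan(prefix, context_parallel_degree, sep_parallel_degree):
    backend = "p2p" if context_parallel_degree > 1 else "all2all"
    plan = {f"{prefix}llama.layers.*.self_attn.sdpa": f"ContextParallel(backend={backend!r})"}
    if context_parallel_degree > 1:
        # cp wraps both sides of rope: reshard inputs in, reshard outputs back.
        plan[f"{prefix}llama.layers.*.self_attn.rope_func"] = [
            "PrepareLayerInput(layer_input_rope_hook)",
            "PrepareLayerOutput(layer_output_rope_hook)",
        ]
    elif sep_parallel_degree > 1:
        # sep only needs the output hook, since fuse_rope lacks a dtensor spmd rule.
        plan[f"{prefix}llama.layers.*.self_attn.rope_func"] = "PrepareLayerOutput(layer_output_rope_hook)"
    return plan


print(sketch_parallelize_plan("", context_parallel_degree=2, sep_parallel_degree=1))
print(sketch_parallelize_plan("", context_parallel_degree=1, sep_parallel_degree=2))
```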

paddlenlp/transformers/segment_parallel_utils.py

Lines changed: 48 additions & 0 deletions
@@ -135,3 +135,51 @@ def forward(
         )
         reshard_tensor.reshape_(shape)
         return reshard_tensor
+
+
+def sep_reshard_layer(input, split_axis, concat_axis):
+    # [auto_parallel] do an alltoall operation to reshard the input from [Shard(concat_axis)] to [Shard(split_axis)]
+    sep_axis = input.process_mesh.dim_names.index("sep")
+    mp_axis = input.process_mesh.dim_names.index("mp")
+
+    input_placements = input.placements
+    if input_placements[sep_axis] != dist.Shard(concat_axis):
+        raise ValueError(
+            f"Input placements for 'sep' axis should be Shard({concat_axis}), but got {input_placements[sep_axis]}"
+        )
+
+    input_placements[sep_axis] = dist.Shard(split_axis)
+
+    if input_placements[sep_axis] == input_placements[mp_axis]:
+        input_placements[sep_axis] = dist.Shard(split_axis, shard_order=0)
+        input_placements[mp_axis] = dist.Shard(split_axis, shard_order=1)
+    out = dist.reshard(input, input.process_mesh, input_placements)
+    return out
+
+
+def auto_split_inputs_sequence_dim(inputs):
+    def do_split_sequence_dim(data):
+        if data is None:
+            return None
+
+        data_mesh = data.process_mesh
+        data_placements = data.placements
+        sep_axis = data_mesh.dim_names.index("sep")
+        # shard along the sep axis
+        data_placements[sep_axis] = dist.Shard(1)
+        data = dist.reshard(data, data_mesh, data_placements)
+        return data
+
+    if isinstance(inputs, paddle.Tensor):
+        return do_split_sequence_dim(inputs)
+    elif isinstance(inputs, dict):
+        res = {}
+        for k, tensor in inputs.items():
+            res[k] = do_split_sequence_dim(tensor)
+    elif isinstance(inputs, list):
+        res = []
+        for tensor in inputs:
+            res.append(do_split_sequence_dim(tensor))
+    else:
+        raise ValueError(f"the inputs should be a tensor, list or dict, but is type: {type(inputs)}")
+    return res
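Putting the two new helpers together: auto_split_inputs_sequence_dim shards the batch along the sequence dim on the sep mesh axis at the trainer level, and sep_reshard_layer later moves that shard from the sequence axis to the head axis inside attention. A pure-Python sketch of the placement bookkeeping on a hypothetical [dp, sep, mp] mesh, with R/S(i) standing in for dist.Replicate()/dist.Shard(i):

```python
def split_inputs_sequence_dim(placements, sep_axis=1):
    # Trainer side: shard the inputs along the sequence dim (tensor axis 1) on the sep mesh axis.
    placements = list(placements)
    placements[sep_axis] = "S(1)"
    return placements


def sep_reshard(placements, split_axis, concat_axis, sep_axis=1):
    # Attention side: move the sep shard from concat_axis to split_axis (an all-to-all in practice).
    placements = list(placements)
    assert placements[sep_axis] == f"S({concat_axis})", "input must be sequence-sharded on sep"
    placements[sep_axis] = f"S({split_axis})"
    return placements


p = ["R", "R", "R"]                              # replicated inputs
p = split_inputs_sequence_dim(p)                 # ['R', 'S(1)', 'R'] tokens sharded across sep
p = sep_reshard(p, split_axis=2, concat_axis=1)  # ['R', 'S(2)', 'R'] heads sharded, full sequence
print(p)
```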
