 from paddle.distributed import fleet
 from paddle.distributed.fleet.utils import recompute

+from ..segment_parallel_utils import sep_reshard_layer
+
 try:
     from paddle.incubate.nn.functional import fused_rotary_position_embedding
 except ImportError:
@@ -200,12 +202,24 @@ def scaled_dot_product_attention(
     return (attn_output, attn_weights) if output_attentions else attn_output


-def get_colwise_placement(has_seq_mesh):
-    return [dist.Replicate(), dist.Replicate(), dist.Shard(1)] if has_seq_mesh else [dist.Replicate(), dist.Shard(1)]
+def get_colwise_placement(has_seq_mesh, has_seq_parallel):
+    if has_seq_mesh:
+        if has_seq_parallel:  # mp + sep is not supported together yet
+            return [dist.Replicate(), dist.Replicate(), dist.Replicate()]
+        else:
+            return [dist.Replicate(), dist.Replicate(), dist.Shard(1)]
+    else:
+        return [dist.Replicate(), dist.Shard(1)]


-def get_rowwise_placement(has_seq_mesh):
-    return [dist.Replicate(), dist.Replicate(), dist.Shard(0)] if has_seq_mesh else [dist.Replicate(), dist.Shard(0)]
+def get_rowwise_placement(has_seq_mesh, has_seq_parallel):
+    if has_seq_mesh:
+        if has_seq_parallel:  # mp + sep is not supported together yet
+            return [dist.Replicate(), dist.Replicate(), dist.Replicate()]
+        else:
+            return [dist.Replicate(), dist.Replicate(), dist.Shard(0)]
+    else:
+        return [dist.Replicate(), dist.Shard(0)]


 def get_replicate_placement(has_seq_mesh):
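For reviewers who want to sanity-check the new placement logic, here is a small stand-alone sketch. Assumptions: the helper is re-declared locally so the snippet runs on its own, and the mesh axes are ordered `(dp, sep, mp)` when `has_seq_mesh` is true; only the returned placement lists come from this diff.

```python
# Hedged sketch, not part of the PR: mirrors get_colwise_placement so it can run
# without importing the model file; the (dp, sep, mp) axis order is an assumption.
import paddle.distributed as dist


def get_colwise_placement(has_seq_mesh, has_seq_parallel):
    if has_seq_mesh:
        if has_seq_parallel:  # mp + sep not supported together yet -> keep the weight replicated
            return [dist.Replicate(), dist.Replicate(), dist.Replicate()]
        return [dist.Replicate(), dist.Replicate(), dist.Shard(1)]
    return [dist.Replicate(), dist.Shard(1)]


# 2-D mesh (dp, mp): the column-parallel weight is split along dim 1 on the mp axis.
print(get_colwise_placement(has_seq_mesh=False, has_seq_parallel=False))
# 3-D mesh (dp, sep, mp) without sep parallelism: same split, one extra replicated axis.
print(get_colwise_placement(has_seq_mesh=True, has_seq_parallel=False))
# 3-D mesh with sep parallelism on: fully replicated, matching the mp + sep restriction noted above.
print(get_colwise_placement(has_seq_mesh=True, has_seq_parallel=True))
```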
@@ -266,28 +280,28 @@ def __init__(self, config, ipp: Optional[int] = None):
             self.gate_up_fused_proj.weight = dist.shard_tensor(
                 self.gate_up_fused_proj.weight,
                 get_mesh(self.ipp),
-                get_colwise_placement(has_seq_mesh),
+                get_colwise_placement(has_seq_mesh, self.config.sep_parallel_degree > 1),
             )
         else:
             self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
             self.gate_proj.weight = dist.shard_tensor(
                 self.gate_proj.weight,
                 get_mesh(self.ipp),
-                get_colwise_placement(has_seq_mesh),
+                get_colwise_placement(has_seq_mesh, self.config.sep_parallel_degree > 1),
             )

             self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
             self.up_proj.weight = dist.shard_tensor(
                 self.up_proj.weight,
                 get_mesh(self.ipp),
-                get_colwise_placement(has_seq_mesh),
+                get_colwise_placement(has_seq_mesh, self.config.sep_parallel_degree > 1),
             )

         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
         self.down_proj.weight = dist.shard_tensor(
             self.down_proj.weight,
             get_mesh(self.ipp),
-            get_rowwise_placement(has_seq_mesh),
+            get_rowwise_placement(has_seq_mesh, self.config.sep_parallel_degree > 1),
         )

     def forward(self, x):
@@ -348,7 +362,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.qkv_proj.weight = dist.shard_tensor(
                 self.qkv_proj.weight,
                 get_mesh(self.ipp),
-                get_colwise_placement(self.has_seq_mesh),
+                get_colwise_placement(self.has_seq_mesh, self.config.sep_parallel_degree > 1),
             )

         else:
@@ -360,7 +374,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.q_proj.weight = dist.shard_tensor(
                 self.q_proj.weight,
                 get_mesh(self.ipp),
-                get_colwise_placement(self.has_seq_mesh),
+                get_colwise_placement(self.has_seq_mesh, self.config.sep_parallel_degree > 1),
             )

             self.k_proj = nn.Linear(
@@ -371,7 +385,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.k_proj.weight = dist.shard_tensor(
                 self.k_proj.weight,
                 get_mesh(self.ipp),
-                get_colwise_placement(self.has_seq_mesh),
+                get_colwise_placement(self.has_seq_mesh, self.config.sep_parallel_degree > 1),
             )

             self.v_proj = nn.Linear(
@@ -382,7 +396,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
             self.v_proj.weight = dist.shard_tensor(
                 self.v_proj.weight,
                 get_mesh(self.ipp),
-                get_colwise_placement(self.has_seq_mesh),
+                get_colwise_placement(self.has_seq_mesh, self.config.sep_parallel_degree > 1),
             )

         self.o_proj = nn.Linear(
@@ -393,13 +407,16 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
         self.o_proj.weight = dist.shard_tensor(
             self.o_proj.weight,
             get_mesh(self.ipp),
-            get_rowwise_placement(self.has_seq_mesh),
+            get_rowwise_placement(self.has_seq_mesh, self.config.sep_parallel_degree > 1),
         )

         if config.rope:
             self._init_rope()

         self.config = config
+        if config.sep_parallel_degree > 1:
+            assert self.num_key_value_heads % config.sep_parallel_degree == 0
+            assert self.num_heads % config.sep_parallel_degree == 0

     def _init_rope(self):
         if self.config.rope_scaling_type is None:
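The two new asserts only guard head divisibility. A quick, illustrative check with made-up head counts (not values from this PR):

```python
# Illustrative only: shows what the new asserts require, with made-up head counts.
num_heads, num_key_value_heads = 32, 8

for sep_parallel_degree in (4, 3):
    q_ok = num_heads % sep_parallel_degree == 0
    kv_ok = num_key_value_heads % sep_parallel_degree == 0
    per_rank_q = num_heads // sep_parallel_degree if q_ok else None
    per_rank_kv = num_key_value_heads // sep_parallel_degree if kv_ok else None
    print(f"sep={sep_parallel_degree}: q heads/rank={per_rank_q}, kv heads/rank={per_rank_kv}, asserts pass={q_ok and kv_ok}")
```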
@@ -456,37 +473,108 @@ def forward(
         )

         if self.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass():
-            target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
             mix_layer = self.qkv_proj(hidden_states)
-            mix_layer = paddle.reshape_(mix_layer, target_shape)
+            # NOTE for GQA attention fusion (compatible with MHA and MQA):
+            # The weight of qkv_proj has a shape like [hidden_size, hidden_size + 2 * num_kv_heads * head_dim].
+            # After the projection, mix_layer has a shape like [b, s, hidden_size + 2 * num_kv_heads * head_dim].
+            # Reshape mix_layer into a shape like [b, s, num_kv_heads, (num_groups + 2) * head_dim],
+            # where num_groups = num_q_heads // num_kv_heads.
+            # Split mix_layer on the last axis into three sections [num_groups * head_dim, head_dim, head_dim]
+            # that represent q, k and v respectively.
+            # q then has a shape like [b, s, num_kv_heads, num_groups * head_dim],
+            # while k and v have a shape like [b, s, num_kv_heads, head_dim].
+            # Under MHA, q is ready for the following calculation since num_kv_heads == num_q_heads,
+            # but for GQA or MQA, q must be reshaped into [b, s, num_q_heads, head_dim].
+            if self.config.sep_parallel_degree > 1:
+                if self.config.sequence_parallel:
+                    raise ValueError(
+                        "Sep parallel cannot be used with sequence parallel, "
+                        "because paddle auto parallel does not support "
+                        "resharding one dim twice."
+                    )
+
+                # [bs, seq_len / sep, num_head, head_dim] -> [bs, seq_len, num_head / sep, head_dim]
+                mix_layer = sep_reshard_layer(
+                    mix_layer,
+                    split_axis=2,
+                    concat_axis=1,
+                )
+                mix_layer = paddle.reshape_(
+                    mix_layer, [0, self.seq_length, -1, (self.num_key_value_groups + 2) * self.head_dim]
+                )  # [bs, seq_len, num_kv_heads / sep, (num_groups + 2) * head_dim], sep is the sep-parallel degree
+            else:
+                target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim]
+                mix_layer = paddle.reshape_(mix_layer, target_shape)
+
             query_states, key_states, value_states = paddle.split(
                 mix_layer,
                 num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim],
                 axis=-1,
             )
             if self.gqa_or_mqa:
                 query_states = paddle.reshape(query_states, [0, 0, self.num_heads, self.head_dim])
+            if self.config.sequence_parallel and self.config.sep_parallel_degree <= 1:
+                # [seq_len, bs, num_head * head_dim] -> [bs, seq_len, num_head * head_dim] (if sequence_parallel)
+                # FA and rope do not support a sequence-first layout
+                query_states = paddle.transpose(query_states, [1, 0, 2, 3])
+                key_states = paddle.transpose(key_states, [1, 0, 2, 3])
+                value_states = paddle.transpose(value_states, [1, 0, 2, 3])
         else:
-            target_query_shape = [0, 0, self.num_heads, self.head_dim]
-            target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
-
-            query_states = self.q_proj(hidden_states).reshape(shape=target_query_shape)
-            key_states = self.k_proj(hidden_states).reshape(shape=target_key_value_shape)
-            value_states = self.v_proj(hidden_states).reshape(shape=target_key_value_shape)
-
-            if self.config.sequence_parallel:
-                # [seq_len, bs, num_head * head_dim] -> [bs, seq_len, num_head * head_dim] (if sequence_parallel)
-                # FA and rope not support sequence first
-                query_states = paddle.transpose(query_states, [1, 0, 2, 3])
-                key_states = paddle.transpose(key_states, [1, 0, 2, 3])
-                value_states = paddle.transpose(value_states, [1, 0, 2, 3])
+            if self.config.sep_parallel_degree > 1:
+                query_states = self.q_proj(hidden_states)
+                key_states = self.k_proj(hidden_states)
+                value_states = self.v_proj(hidden_states)
+                if self.config.sequence_parallel:
+                    raise ValueError(
+                        "Sep parallel cannot be used with sequence parallel, "
+                        "because paddle auto parallel does not support "
+                        "resharding one dim twice."
+                    )

+                query_states = sep_reshard_layer(
+                    query_states,
+                    split_axis=2,
+                    concat_axis=1,
+                )
+                key_states = sep_reshard_layer(
+                    key_states,
+                    split_axis=2,
+                    concat_axis=1,
+                )
+                value_states = sep_reshard_layer(
+                    value_states,
+                    split_axis=2,
+                    concat_axis=1,
+                )
+                query_states = paddle.reshape(
+                    query_states, shape=[0, self.seq_length, -1, self.head_dim]
+                )  # [bs, seq_len, num_heads / sep, head_dim], sep is the sep-parallel degree
+                key_states = paddle.reshape(key_states, shape=[0, self.seq_length, -1, self.head_dim])
+                value_states = paddle.reshape(value_states, shape=[0, self.seq_length, -1, self.head_dim])
+            else:
+                target_query_shape = [0, 0, self.num_heads, self.head_dim]
+                target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim]
+
+                query_states = self.q_proj(hidden_states).reshape(shape=target_query_shape)
+                key_states = self.k_proj(hidden_states).reshape(shape=target_key_value_shape)
+                value_states = self.v_proj(hidden_states).reshape(shape=target_key_value_shape)
+
+                if self.config.sequence_parallel:
+                    # [seq_len, bs, num_head * head_dim] -> [bs, seq_len, num_head * head_dim] (if sequence_parallel)
+                    # FA and rope do not support a sequence-first layout
+                    query_states = paddle.transpose(query_states, [1, 0, 2, 3])
+                    key_states = paddle.transpose(key_states, [1, 0, 2, 3])
+                    value_states = paddle.transpose(value_states, [1, 0, 2, 3])
         kv_seq_len = key_states.shape[-3]

         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-3]

         if self.config.rope:
+            query_placement = query_states.placements
+            if self.config.sep_parallel_degree > 1:
+                batch_size, seq_length, _, _ = query_states.shape
+                position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
             if self.config.context_parallel_degree > 1:
                 mesh = dist.auto_parallel.get_mesh()
                 group = mesh._get_group("sep")
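To make the fused-QKV NOTE concrete, here is a stand-alone CPU walk-through of the same reshape/split arithmetic. The sizes are made up, and it deliberately skips `sep_reshard_layer` and anything distributed.

```python
# Hedged sketch: reproduces only the reshape/split arithmetic described in the NOTE above,
# with made-up sizes; it does not touch sep_reshard_layer or any distributed state.
import paddle

b, s, head_dim = 2, 16, 8
num_q_heads, num_kv_heads = 8, 2                 # GQA: 4 query heads share each KV head
num_groups = num_q_heads // num_kv_heads         # num_key_value_groups in the model code
hidden_size = num_q_heads * head_dim

# Output of qkv_proj: [b, s, hidden_size + 2 * num_kv_heads * head_dim]
mix_layer = paddle.randn([b, s, hidden_size + 2 * num_kv_heads * head_dim])

# [b, s, num_kv_heads, (num_groups + 2) * head_dim]; 0 keeps the corresponding input dim.
mix_layer = paddle.reshape(mix_layer, [0, 0, num_kv_heads, (num_groups + 2) * head_dim])

q, k, v = paddle.split(
    mix_layer,
    num_or_sections=[num_groups * head_dim, head_dim, head_dim],
    axis=-1,
)
# GQA/MQA: fold the per-group dimension back into the query-head dimension.
q = paddle.reshape(q, [0, 0, num_q_heads, head_dim])

print(q.shape, k.shape, v.shape)
# [2, 16, 8, 8] [2, 16, 2, 8] [2, 16, 2, 8]
```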
@@ -516,16 +604,16 @@ def forward(
                     self.rotary_emb,
                     self.config.context_parallel_degree,
                 )
-                if self.has_seq_mesh:
+                if self.config.context_parallel_degree > 1:
                     query_states = dist.reshard(
                         query_states,
                         get_mesh(self.ipp),
-                        [dist.Shard(0), dist.Shard(1), dist.Shard(2)],
+                        query_placement,  # [dist.Shard(0), dist.Shard(1), dist.Shard(2)],
                     )
                     key_states = dist.reshard(
                         key_states,
                         get_mesh(self.ipp),
-                        [dist.Shard(0), dist.Shard(1), dist.Shard(2)],
+                        query_placement,  # [dist.Shard(0), dist.Shard(1), dist.Shard(2)],
                     )
             else:
                 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
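Replacing the hard-coded `[Shard(0), Shard(1), Shard(2)]` with the saved `query_placement` is a capture-then-restore pattern. A hedged sketch of that pattern follows; the helper name is mine, and it assumes an already-initialized auto-parallel program, so it is not runnable as a plain single-process script.

```python
# Sketch only, not from the PR: capture a dist tensor's placements before an op
# that may change them, then reshard the result back to those placements afterwards.
import paddle.distributed as dist


def restore_placements_after(op, x, mesh):
    original_placements = x.placements  # e.g. [Shard(0), Shard(1), Shard(2)] before rope
    out = op(x)
    return dist.reshard(out, mesh, original_placements)
```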
@@ -1282,7 +1370,7 @@ def __init__(self, config: LlamaConfig):
         self.weight = dist.shard_tensor(
             self.weight,
             get_mesh(-1),
-            get_colwise_placement(has_seq_mesh),
+            get_colwise_placement(has_seq_mesh, self.config.sep_parallel_degree > 1),
         )

     def forward(self, hidden_states, tensor_parallel_output=None):