from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore

-from ..fp8_utils import FP8KeepXLinear, FP8Linear, FP8Mlp
+from ..fp8_utils import FP8KeepXLinear, FP8Linear, FP8Mlp, FP8LinearFunctionBase, cache_fp8_weight
from .fp8_linear import Linear

DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true"
@@ -961,9 +961,9 @@ def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None):
            using_post_norm_recompute=self.using_post_norm_recompute,
        )

-        moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group
-        for p in self.experts.parameters():
-            setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
+        # moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group
+        # for p in self.experts.parameters():
+        #     setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})

        self.alpha = config.aux_loss_alpha
        if config.n_shared_experts is not None:
@@ -995,6 +995,8 @@ def quantize_weights(weight_list, weight_obj=None):
            """Helper function to quantize a list of weights."""
            if weight_obj is None:
                weight_obj = weight_list[0]
+            if hasattr(weight_obj, "fp8_weight_stacked"):
+                return

            # Quantize without transpose
            fp8_weight, fp8_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(
@@ -1025,6 +1027,9 @@ def quantize_weights(weight_list, weight_obj=None):
            if expert is not None:
                quantize_weights([expert.w1])
                quantize_weights([expert.w1])
+
+        if self.config.n_shared_experts is not None:
+            self.shared_experts.fp8_quant_weight()

    def forward(self, hidden_states):
        if self.using_post_norm_recompute:
@@ -1189,12 +1194,18 @@ def forward(
        bsz = q_init.shape[0]
        q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps)
-        q = paddle.matmul(q_ln_t, q_up_weight)
+        # q = paddle.matmul(q_ln_t, q_up_weight)
+        q_orig_shape = q_ln_t.shape
+        q = FP8LinearFunctionBase.compute_fp8_linear(q_ln_t.reshape([-1, q_orig_shape[-1]]), q_up_weight, weight_transpose=True, return_transpose_only=True)
+        q = q.reshape(q_orig_shape[:-1] + [q_up_weight.shape[-1]])

        compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1)

        kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps)
-        kv = paddle.matmul(kv_ln_t, kv_up_weight)
+        # kv = paddle.matmul(kv_ln_t, kv_up_weight)
+        kv_orig_shape = kv_ln_t.shape
+        kv = FP8LinearFunctionBase.compute_fp8_linear(kv_ln_t.reshape([-1, kv_orig_shape[-1]]), kv_up_weight, weight_transpose=True, return_transpose_only=True)
+        kv = kv.reshape(kv_orig_shape[:-1] + [kv_up_weight.shape[-1]])

        query_states, key_states, value_states = qkv_pre_process(
            q,
@@ -1354,12 +1365,26 @@ def backward(ctx, dout):
            assert False, f"invalid {FA_VERSION=}"

        q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps)
-        q = paddle.matmul(q_ln_t, q_up_weight)
+
+
+        q_ln_fp8, q_ln_scale, q_ln_trans_fp8, q_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            q_ln_t.reshape([-1, q_ln_t.shape[-1]]), output_scale_transpose=True,
+            quant_method="1x128", input_transpose=True)
+
+        q_orig_shape = q_ln_t.shape
+        q = FP8LinearFunctionBase.compute_fp8_linear((q_ln_fp8, q_ln_scale), q_up_weight, weight_transpose=True, return_transpose_only=True)
+        q = q.reshape(q_orig_shape[:-1] + [q_up_weight.shape[-1]])

        compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1)

        kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps)
-        kv = paddle.matmul(kv_ln_t, kv_up_weight)
+
+        kv_ln_fp8, kv_ln_scale, kv_ln_trans_fp8, kv_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            kv_ln_t.reshape([-1, kv_ln_t.shape[-1]]), output_scale_transpose=True,
+            quant_method="1x128", input_transpose=True)
+        kv_orig_shape = kv_ln_t.shape
+        kv = FP8LinearFunctionBase.compute_fp8_linear((kv_ln_fp8, kv_ln_scale), kv_up_weight, weight_transpose=True, return_transpose_only=True)
+        kv = kv.reshape(kv_orig_shape[:-1] + [kv_up_weight.shape[-1]])

        paddle.base.core._set_has_grad(True)
        q.stop_gradient = False
@@ -1439,20 +1464,29 @@ def backward(ctx, dout):

        # call up proj
        if hasattr(kv_up_weight, "main_grad"):
-            d_kv_ln_t = paddle.matmul(d_kv, kv_up_weight, transpose_y=True)
-
-            def kv_up_weight_grad(kv_ln_t, d_kv, kv_up_weight):
-
-                with paddle.no_grad():
-                    w_grad_t = paddle.matmul(kv_ln_t.reshape([-1, kv_ln_t.shape[-1]]), d_kv.reshape([-1, d_kv.shape[-1]]), transpose_x=True)
-
-                    kv_up_weight.main_grad.add_(w_grad_t)
-
+            d_kv_fp8, d_kv_scale, d_kv_t_fp8, d_kv_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                d_kv.reshape([-1, d_kv.shape[-1]]), output_scale_transpose=True,
+                quant_method="1x128", input_transpose=True)
+
+            d_kv_ln_t = FP8LinearFunctionBase.compute_fp8_linear((d_kv_fp8, d_kv_scale), kv_up_weight, weight_transpose=False)
+            d_kv_ln_t = d_kv_ln_t.reshape(d_kv.shape[:-1] + [kv_up_weight.shape[0]])
+
+            def kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight):
+                FP8LinearFunctionBase.kitchen_gemm(
+                    kv_ln_trans_fp8,
+                    kv_ln_trans_scale,
+                    d_kv_t_fp8,
+                    d_kv_t_scale,
+                    True,
+                    True,
+                    kv_up_weight.main_grad,
+                    paddle.float32)
+
            if WeightGradStore.enabled:
-                WeightGradStore.put(partial(kv_up_weight_grad, kv_ln_t, d_kv, kv_up_weight))
+                WeightGradStore.put(partial(kv_up_weight_grad, kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight))
            else:
-                kv_up_weight_grad(kv_ln_t, d_kv, kv_up_weight)
+                kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight)

            d_kv_up_weight = None
@@ -1467,18 +1501,32 @@ def kv_up_weight_grad(kv_ln_t, d_kv, kv_up_weight):
            d_kv_init = paddle.concat([d_compressed_kv, d_k_pe], axis=-1)

        if hasattr(q_up_weight, "main_grad"):
-            d_q_ln_t = paddle.matmul(d_q, q_up_weight, transpose_y=True)
-
-            def q_up_weight_grad(q_ln_t, d_q, q_up_weight):
+            d_q_fp8, d_q_scale, d_q_t_fp8, d_q_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                d_q.reshape([-1, d_q.shape[-1]]), output_scale_transpose=True,
+                quant_method="1x128", input_transpose=True)
+            # d_q_ln_t = paddle.matmul(d_q, q_up_weight, transpose_y=True)

-                with paddle.no_grad():
-                    w_grad_t = paddle.matmul(q_ln_t.reshape([-1, q_ln_t.shape[-1]]), d_q.reshape([-1, d_q.shape[-1]]), transpose_x=True)
-                    q_up_weight.main_grad.add_(w_grad_t)
+            d_q_ln_t = FP8LinearFunctionBase.compute_fp8_linear((d_q_fp8, d_q_scale), q_up_weight, weight_transpose=False)
+            d_q_ln_t = d_q_ln_t.reshape(d_q.shape[:-1] + [q_up_weight.shape[0]])
+
+
+            def q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight):
+                FP8LinearFunctionBase.kitchen_gemm(
+                    q_ln_trans_fp8,
+                    q_ln_trans_scale,
+                    d_q_t_fp8,
+                    d_q_t_scale,
+                    True,
+                    True,
+                    q_up_weight.main_grad,
+                    paddle.float32)
+

            if WeightGradStore.enabled:
-                WeightGradStore.put(partial(q_up_weight_grad, q_ln_t, d_q, q_up_weight))
+                WeightGradStore.put(partial(q_up_weight_grad, q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight))
            else:
-                q_up_weight_grad(q_ln_t, d_q, q_up_weight)
+                q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight)

            d_q_up_weight = None
@@ -1556,7 +1604,17 @@ def __init__(
            softmax_scale,
        )

+    def fp8_quant_weight(self):
+        cache_fp8_weight(self.q_up_weight)
+        cache_fp8_weight(self.kv_up_weight)
+
    def forward(self, q_init, kv_init, position_ids):
+
+        seq_len = q_init.shape[1]
+
+        if self.rotary_emb.max_seq_len_cached is None or seq_len > self.rotary_emb.max_seq_len_cached:
+            self.rotary_emb._set_cos_sin_cache(seq_len)
+

        return MemroyRecomputeAttnFunc.apply(
            q_init,
@@ -1583,10 +1641,18 @@ class FusedRMSLinearFunc(paddle.autograd.PyLayer):
    def forward(ctx, x, rms_norm_weight, q_down_weight, kv_down_weight, eps):

        hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
-        q = paddle.matmul(hidden_states, q_down_weight)
+
+        h_fp8, h_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            hidden_states.reshape([-1, hidden_states.shape[-1]]), output_scale_transpose=True,
+            quant_method="1x128")

-        kv = paddle.matmul(hidden_states, kv_down_weight)
+        h_orig_shape = hidden_states.shape
+        q = FP8LinearFunctionBase.compute_fp8_linear((h_fp8, h_scale), q_down_weight, weight_transpose=True, return_transpose_only=True)
+        q = q.reshape(h_orig_shape[:-1] + [q_down_weight.shape[-1]])

+
+        kv = paddle.matmul(hidden_states, kv_down_weight)
+
        ctx.save_for_backward(x, rms_norm_weight, q_down_weight, kv_down_weight)
        ctx.eps = eps
        return q, kv
@@ -1596,11 +1662,41 @@ def backward(ctx, d_q, d_kv):
        x, rms_norm_weight, q_down_weight, kv_down_weight = ctx.saved_tensor()
        eps = ctx.eps
        hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
+
+        h_t_fp8, h_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            hidden_states.reshape([-1, hidden_states.shape[-1]]), output_scale_transpose=True,
+            quant_method="1x128", input_transpose=True, return_transpose_only=True)
+
+        h_grad, d_kv_down_weight = _C_ops.matmul_grad(hidden_states, kv_down_weight, d_kv, False, False)
+
+        if hasattr(q_down_weight, "main_grad"):
+            d_q_fp8, d_q_scale, d_q_t_fp8, d_q_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                d_q.reshape([-1, d_q.shape[-1]]), output_scale_transpose=True,
+                quant_method="1x128", input_transpose=True)
+            FP8LinearFunctionBase.compute_fp8_linear((d_q_fp8, d_q_scale), q_down_weight, weight_transpose=False, out=h_grad.view([-1, h_grad.shape[-1]]))
+
+
+            def q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight):
+                FP8LinearFunctionBase.kitchen_gemm(
+                    h_t_fp8,
+                    h_t_scale,
+                    d_q_t_fp8,
+                    d_q_t_scale,
+                    True,
+                    True,
+                    q_down_weight.main_grad,
+                    paddle.float32)
+
+            if WeightGradStore.enabled:
+                WeightGradStore.put(partial(q_down_weight_grad, h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight))
+            else:
+                q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight)

-        h_grad_0, d_q_down_weight = _C_ops.matmul_grad(hidden_states, q_down_weight, d_q, False, False)
-        h_grad_1, d_kv_down_weight = _C_ops.matmul_grad(hidden_states, kv_down_weight, d_kv, False, False)
+            d_q_down_weight = None

-        h_grad = h_grad_0 + h_grad_1
+        else:
+            h_grad_0, d_q_down_weight = _C_ops.matmul_grad(hidden_states, q_down_weight, d_q, False, False)
+            h_grad = h_grad + h_grad_0

        dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, rms_norm_weight, invar, h_grad, eps)
@@ -1630,6 +1726,10 @@ def __init__(self, hidden_size, q_out_dim, kv_outdim, eps=1e-6) -> None:
            is_bias=False,
        )
        self.eps = eps
+
+    def fp8_quant_weight(self):
+        cache_fp8_weight(self.q_down_weight)
+

    def forward(self, x):
@@ -1791,6 +1891,15 @@ def linear_dtype_gaurd():

        self.attn_func = scaled_dot_product_attention

+    def fp8_quant_weight(self):
+
+        if DSV3_USE_ATTEN_RECOMPUTE:
+            self.o_proj.fp8_quant_weight()
+            self.memory_recompute_att.fp8_quant_weight()
+            self.fused_rms_norm_linear.fp8_quant_weight()
+
+
+
    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = DeepseekV2RotaryEmbedding(
@@ -1884,7 +1993,6 @@ def forward(
        target_key_value_shape = [0, 0, self.num_heads, self.qk_nope_head_dim + self.v_head_dim]

        q = q.reshape(shape=target_query_shape)
-        # q.register_hook( print_grad)
        q_nope, q_pe = paddle.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)

        # DeepSeekV2 kv_lora_rank+qk_rope_head_dim=512+64
@@ -2016,8 +2124,9 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute
    def fp8_quant_weight(self, batch_mode=False):
        """fp8_quant_weight"""
        if isinstance(self.mlp, DeepseekV2MoE):
-            logger.info(f"fp8 quant weight for mlp {type(self.mlp)}")
+            # logger.info(f"fp8 quant weight for mlp {type(self.mlp)}")
            self.mlp.fp8_quant_weight(batch_mode)
+        self.self_attn.fp8_quant_weight()

    def forward(
        self,