
Commit b93fda9

Author: phlrain
fp8_quant_cache_and_using_fp8_gemm (#10923)
* update
* update modeling and modeling_pp
* add shared expert fp8 quant cache
* polish code

Co-authored-by: phlrain <phiuhongyu@126.com>
1 parent 9d1130b commit b93fda9
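The diff below moves several of the DeepseekV2 attention projections (the q/kv up- and down-projections and their weight-gradient GEMMs) onto blockwise FP8 GEMMs and adds fp8_quant_weight hooks so FP8 copies of the projection weights are built once and cached. As a rough illustration of the "1x128" blockwise-scaling idea only, here is a minimal NumPy sketch. It is not the PaddleNLP implementation (that path goes through paddle.incubate.nn.functional.fp8_quant_blockwise and FP8LinearFunctionBase in fp8_utils.py); the helpers quant_1x128 and scaled_gemm are invented for the sketch, and the FP8 cast itself is only emulated by clipping, not by rounding to float8.

import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite float8-e4m3 magnitude

def quant_1x128(x):
    # One scale per 1x128 block along the last axis, mirroring quant_method="1x128".
    rows, cols = x.shape
    assert cols % 128 == 0
    blocks = x.reshape(rows, cols // 128, 128)
    scale = np.abs(blocks).max(axis=-1, keepdims=True) / FP8_E4M3_MAX + 1e-12
    q = np.clip(blocks / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)  # stand-in for the fp8 cast
    return q.reshape(rows, cols), np.repeat(scale.squeeze(-1), 128, axis=1)

def scaled_gemm(x_q, x_scale, w):
    # GEMM on quantized activations; the per-block scales are folded back in.
    return (x_q * x_scale) @ w

x = np.random.randn(4, 256).astype(np.float32)
w = np.random.randn(256, 128).astype(np.float32)
x_q, x_scale = quant_1x128(x)
print(np.abs(scaled_gemm(x_q, x_scale, w) - x @ w).max())  # tiny, since only the scaling is emulated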

File tree

3 files changed (+273 additions, -87 deletions)

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 142 additions & 33 deletions
@@ -82,7 +82,7 @@
 
 from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore
 
-from ..fp8_utils import FP8KeepXLinear, FP8Linear, FP8Mlp
+from ..fp8_utils import FP8KeepXLinear, FP8Linear, FP8Mlp, FP8LinearFunctionBase, cache_fp8_weight
 from .fp8_linear import Linear
 
 DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true"
@@ -961,9 +961,9 @@ def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None):
             using_post_norm_recompute=self.using_post_norm_recompute,
         )
 
-        moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group
-        for p in self.experts.parameters():
-            setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
+        # moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group
+        # for p in self.experts.parameters():
+        # setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group})
 
         self.alpha = config.aux_loss_alpha
         if config.n_shared_experts is not None:
@@ -995,6 +995,8 @@ def quantize_weights(weight_list, weight_obj=None):
             """Helper function to quantize a list of weights."""
             if weight_obj is None:
                 weight_obj = weight_list[0]
+            if hasattr( weight_obj, "fp8_weight_stacked"):
+                return
 
             # Quantize without transpose
             fp8_weight, fp8_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(
@@ -1025,6 +1027,9 @@ def quantize_weights(weight_list, weight_obj=None):
             if expert is not None:
                 quantize_weights([expert.w1])
                 quantize_weights([expert.w1])
+
+        if self.config.n_shared_experts is not None:
+            self.shared_experts.fp8_quant_weight()
 
     def forward(self, hidden_states):
         if self.using_post_norm_recompute:
@@ -1189,12 +1194,18 @@ def forward(
 
         bsz = q_init.shape[0]
         q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps)
-        q = paddle.matmul(q_ln_t, q_up_weight)
+        #q = paddle.matmul(q_ln_t, q_up_weight)
+        q_orig_shape = q_ln_t.shape
+        q = FP8LinearFunctionBase.compute_fp8_linear(q_ln_t.reshape([-1, q_orig_shape[-1]]), q_up_weight, weight_transpose=True, return_transpose_only=True)
+        q = q.reshape( q_orig_shape[:-1] + [q_up_weight.shape[-1]])
 
         compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1)
 
         kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps)
-        kv = paddle.matmul(kv_ln_t, kv_up_weight)
+        #kv = paddle.matmul(kv_ln_t, kv_up_weight)
+        kv_orig_shape = kv_ln_t.shape
+        kv = FP8LinearFunctionBase.compute_fp8_linear(kv_ln_t.reshape([-1, kv_orig_shape[-1]]), kv_up_weight, weight_transpose=True, return_transpose_only=True)
+        kv = kv.reshape( kv_orig_shape[:-1] + [kv_up_weight.shape[-1]])
 
         query_states, key_states, value_states = qkv_pre_process(
             q,
@@ -1354,12 +1365,26 @@ def backward(ctx, dout):
             assert False, f"invalid {FA_VERSION=}"
 
         q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps)
-        q = paddle.matmul(q_ln_t, q_up_weight)
+
+
+        q_ln_fp8, q_ln_scale, q_ln_trans_fp8, q_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            q_ln_t.reshape([-1, q_ln_t.shape[-1]]), output_scale_transpose=True,
+            quant_method="1x128", input_transpose=True )
+
+        q_orig_shape = q_ln_t.shape
+        q = FP8LinearFunctionBase.compute_fp8_linear((q_ln_fp8, q_ln_scale), q_up_weight, weight_transpose=True, return_transpose_only=True)
+        q = q.reshape( q_orig_shape[:-1] + [q_up_weight.shape[-1]])
 
         compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1)
 
         kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps)
-        kv = paddle.matmul(kv_ln_t, kv_up_weight)
+
+        kv_ln_fp8, kv_ln_scale, kv_ln_trans_fp8, kv_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            kv_ln_t.reshape([-1, kv_ln_t.shape[-1]]), output_scale_transpose=True,
+            quant_method="1x128", input_transpose=True )
+        kv_orig_shape = kv_ln_t.shape
+        kv = FP8LinearFunctionBase.compute_fp8_linear((kv_ln_fp8, kv_ln_scale), kv_up_weight, weight_transpose=True, return_transpose_only=True)
+        kv = kv.reshape( kv_orig_shape[:-1] + [kv_up_weight.shape[-1]])
 
         paddle.base.core._set_has_grad(True)
         q.stop_gradient = False
q.stop_gradient = False
@@ -1439,20 +1464,29 @@ def backward(ctx, dout):
 
         # call up proj
         if hasattr(kv_up_weight, "main_grad"):
-            d_kv_ln_t = paddle.matmul(d_kv, kv_up_weight, transpose_y=True)
-
-            def kv_up_weight_grad(kv_ln_t, d_kv, kv_up_weight):
-
-                with paddle.no_grad():
-                    w_grad_t = paddle.matmul( kv_ln_t.reshape([-1, kv_ln_t.shape[-1]]), d_kv.reshape([-1, d_kv.shape[-1]]), transpose_x=True)
-
-                    kv_up_weight.main_grad.add_( w_grad_t )
-
+            d_kv_fp8, d_kv_scale, d_kv_t_fp8, d_kv_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                d_kv.reshape([-1, d_kv.shape[-1]]), output_scale_transpose=True,
+                quant_method="1x128", input_transpose=True )
+
+            d_kv_ln_t = FP8LinearFunctionBase.compute_fp8_linear((d_kv_fp8, d_kv_scale), kv_up_weight, weight_transpose=False)
+            d_kv_ln_t = d_kv_ln_t.reshape( d_kv.shape[:-1] + [kv_up_weight.shape[0]])
+
+            def kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight):
+                FP8LinearFunctionBase.kitchen_gemm(
+                    kv_ln_trans_fp8,
+                    kv_ln_trans_scale,
+                    d_kv_t_fp8,
+                    d_kv_t_scale,
+                    True,
+                    True,
+                    kv_up_weight.main_grad,
+                    paddle.float32 )
+
             if WeightGradStore.enabled:
 
-                WeightGradStore.put(partial(kv_up_weight_grad, kv_ln_t, d_kv, kv_up_weight))
+                WeightGradStore.put(partial(kv_up_weight_grad, kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight))
             else:
-                kv_up_weight_grad(kv_ln_t, d_kv, kv_up_weight)
+                kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight)
 
         d_kv_up_weight = None
 
@@ -1467,18 +1501,32 @@ def kv_up_weight_grad(kv_ln_t, d_kv, kv_up_weight):
         d_kv_init = paddle.concat([d_compressed_kv, d_k_pe], axis=-1)
 
         if hasattr(q_up_weight, "main_grad"):
-            d_q_ln_t = paddle.matmul(d_q, q_up_weight, transpose_y=True)
 
-            def q_up_weight_grad(q_ln_t, d_q, q_up_weight):
+            d_q_fp8, d_q_scale, d_q_t_fp8, d_q_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                d_q.reshape([-1, d_q.shape[-1]]), output_scale_transpose=True,
+                quant_method="1x128", input_transpose=True )
+            #d_q_ln_t = paddle.matmul(d_q, q_up_weight, transpose_y=True)
 
-                with paddle.no_grad():
-                    w_grad_t = paddle.matmul( q_ln_t.reshape([-1, q_ln_t.shape[-1]]), d_q.reshape([-1, d_q.shape[-1]]), transpose_x=True)
-                    q_up_weight.main_grad.add_( w_grad_t )
+            d_q_ln_t = FP8LinearFunctionBase.compute_fp8_linear((d_q_fp8, d_q_scale), q_up_weight, weight_transpose=False)
+            d_q_ln_t = d_q_ln_t.reshape( d_q.shape[:-1] + [q_up_weight.shape[0]])
+
+
+            def q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight):
+                FP8LinearFunctionBase.kitchen_gemm(
+                    q_ln_trans_fp8,
+                    q_ln_trans_scale,
+                    d_q_t_fp8,
+                    d_q_t_scale,
+                    True,
+                    True,
+                    q_up_weight.main_grad,
+                    paddle.float32 )
+
 
             if WeightGradStore.enabled:
-                WeightGradStore.put(partial(q_up_weight_grad, q_ln_t, d_q, q_up_weight))
+                WeightGradStore.put(partial(q_up_weight_grad, q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight))
             else:
-                q_up_weight_grad(q_ln_t, d_q, q_up_weight)
+                q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight)
 
         d_q_up_weight = None
 

@@ -1556,7 +1604,17 @@ def __init__(
             softmax_scale,
         )
 
+    def fp8_quant_weight(self):
+        cache_fp8_weight( self.q_up_weight)
+        cache_fp8_weight( self.kv_up_weight)
+
     def forward(self, q_init, kv_init, position_ids):
+
+        seq_len = q_init.shape[1]
+
+        if self.rotary_emb.max_seq_len_cached is None or seq_len > self.rotary_emb.max_seq_len_cached:
+            self.rotary_emb._set_cos_sin_cache(seq_len)
+
 
         return MemroyRecomputeAttnFunc.apply(
             q_init,
@@ -1583,10 +1641,18 @@ class FusedRMSLinearFunc(paddle.autograd.PyLayer):
     def forward(ctx, x, rms_norm_weight, q_down_weight, kv_down_weight, eps):
 
         hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
-        q = paddle.matmul(hidden_states, q_down_weight)
+
+        h_fp8, h_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            hidden_states.reshape([-1, hidden_states.shape[-1]]), output_scale_transpose=True,
+            quant_method="1x128" )
 
-        kv = paddle.matmul(hidden_states, kv_down_weight)
+        h_orig_shape = hidden_states.shape
+        q = FP8LinearFunctionBase.compute_fp8_linear((h_fp8, h_scale), q_down_weight, weight_transpose=True, return_transpose_only=True)
+        q = q.reshape( h_orig_shape[:-1] + [q_down_weight.shape[-1]])
 
+
+        kv = paddle.matmul(hidden_states, kv_down_weight)
+
         ctx.save_for_backward(x, rms_norm_weight, q_down_weight, kv_down_weight)
         ctx.eps = eps
         return q, kv
@@ -1596,11 +1662,41 @@ def backward(ctx, d_q, d_kv):
         x, rms_norm_weight, q_down_weight, kv_down_weight = ctx.saved_tensor()
         eps = ctx.eps
         hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps)
+
+        h_t_fp8, h_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+            hidden_states.reshape([-1, hidden_states.shape[-1]]), output_scale_transpose=True,
+            quant_method="1x128", input_transpose=True, return_transpose_only=True )
+
+        h_grad, d_kv_down_weight = _C_ops.matmul_grad(hidden_states, kv_down_weight, d_kv, False, False)
+
+        if hasattr(q_down_weight, "main_grad"):
+            d_q_fp8, d_q_scale, d_q_t_fp8, d_q_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise(
+                d_q.reshape([-1, d_q.shape[-1]]), output_scale_transpose=True,
+                quant_method="1x128", input_transpose=True )
+            FP8LinearFunctionBase.compute_fp8_linear((d_q_fp8, d_q_scale), q_down_weight, weight_transpose=False, out=h_grad.view( [-1, h_grad.shape[-1]]))
+
+
+            def q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight):
+                FP8LinearFunctionBase.kitchen_gemm(
+                    h_t_fp8,
+                    h_t_scale,
+                    d_q_t_fp8,
+                    d_q_t_scale,
+                    True,
+                    True,
+                    q_down_weight.main_grad,
+                    paddle.float32 )
+
+            if WeightGradStore.enabled:
+                WeightGradStore.put(partial(q_down_weight_grad, h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight))
+            else:
+                q_down_weight_grad( h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight)
 
-        h_grad_0, d_q_down_weight = _C_ops.matmul_grad(hidden_states, q_down_weight, d_q, False, False)
-        h_grad_1, d_kv_down_weight = _C_ops.matmul_grad(hidden_states, kv_down_weight, d_kv, False, False)
+            d_q_down_weight = None
 
-        h_grad = h_grad_0 + h_grad_1
+        else:
+            h_grad_0, d_q_down_weight = _C_ops.matmul_grad(hidden_states, q_down_weight, d_q, False, False)
+            h_grad = h_grad + h_grad_0
 
         dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, rms_norm_weight, invar, h_grad, eps)
 

@@ -1630,6 +1726,10 @@ def __init__(self, hidden_size, q_out_dim, kv_outdim, eps=1e-6) -> None:
             is_bias=False,
         )
         self.eps = eps
+
+    def fp8_quant_weight(self):
+        cache_fp8_weight( self.q_down_weight)
+
 
     def forward(self, x):
 
@@ -1791,6 +1891,15 @@ def linear_dtype_gaurd():
 
         self.attn_func = scaled_dot_product_attention
 
+    def fp8_quant_weight(self):
+
+        if DSV3_USE_ATTEN_RECOMPUTE:
+            self.o_proj.fp8_quant_weight()
+            self.memory_recompute_att.fp8_quant_weight()
+            self.fused_rms_norm_linear.fp8_quant_weight()
+
+
+
     def _init_rope(self):
         if self.config.rope_scaling is None:
             self.rotary_emb = DeepseekV2RotaryEmbedding(
@@ -1884,7 +1993,6 @@ def forward(
         target_key_value_shape = [0, 0, self.num_heads, self.qk_nope_head_dim + self.v_head_dim]
 
         q = q.reshape(shape=target_query_shape)
-        # q.register_hook( print_grad)
         q_nope, q_pe = paddle.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
 
         # DeepSeekV2 kv_lora_rank+qk_rope_head_dim=512+64
@@ -2016,8 +2124,9 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute
     def fp8_quant_weight(self, batch_mode=False):
         """fp8_quant_weight"""
         if isinstance(self.mlp, DeepseekV2MoE):
-            logger.info(f"fp8 quant weight for mlp {type(self.mlp)}")
+            # logger.info(f"fp8 quant weight for mlp {type(self.mlp)}")
             self.mlp.fp8_quant_weight(batch_mode)
+        self.self_attn.fp8_quant_weight()
 
     def forward(
         self,
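Taken together, the modeling.py changes do two things: the q/kv projection GEMMs in forward and backward now take the FP8 path (fp8_quant_blockwise plus FP8LinearFunctionBase.compute_fp8_linear / kitchen_gemm, with weight gradients optionally deferred through WeightGradStore), and each affected layer gains a fp8_quant_weight() method that calls cache_fp8_weight on its projection weights so the FP8 copy is built once rather than on every step. A minimal sketch of that caching guard, assuming a quantizer is passed in; the real cache_fp8_weight lives in fp8_utils.py, only the fp8_weight_stacked attribute name is taken from the diff, and everything else below is illustrative:

def cache_fp8_weight_sketch(weight, quantize_fn):
    # Skip the work if an FP8 copy is already stashed on the parameter.
    if hasattr(weight, "fp8_weight_stacked"):
        return weight.fp8_weight_stacked
    fp8_weight, fp8_scale = quantize_fn(weight)  # e.g. a blockwise FP8 quantizer
    weight.fp8_weight_stacked = (fp8_weight, fp8_scale)
    return weight.fp8_weight_stacked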

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 10 additions & 4 deletions
@@ -887,7 +887,11 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
             mlp_fwd_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
 
         if pp_stream is not None:
+            paddle.base.core.nvprof_nvtx_push("post_process_forward")
+
             final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
+            paddle.base.core.nvprof_nvtx_pop()
+
 
             final_out_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
 
@@ -921,9 +925,7 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         paddle.base.core.nvprof_nvtx_pop()
 
         dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
-        paddle.base.core.nvprof_nvtx_push("post_process_forward")
-
-        paddle.base.core.nvprof_nvtx_pop()
+
         paddle.base.core.nvprof_nvtx_push("attn_backward")
         assert WeightGradStore.funcs_queue.empty()
         WeightGradStore.enabled = True
@@ -938,12 +940,16 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         WeightGradStore.pop()
         assert WeightGradStore.funcs_queue.empty()
 
+        WeightGradStore.enabled = False
+        WeightGradStore.flush()
+        WeightGradStore.pop()
+        assert WeightGradStore.funcs_queue.empty()
         paddle.base.core.nvprof_nvtx_pop()
 
         # residual add
         if pp_stream is None:
             combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
-
+
             final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
             if final_out.shape[-1] != combine_fwd_out.shape[-1]:
                 final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out  # broadcast and add directly
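The modeling_pp.py hunks move the "post_process_forward" NVTX range into the pp_stream branch and add a second WeightGradStore drain after attn_backward, so weight-gradient callbacks queued during the FP8 backward (the kitchen_gemm partials put() above) are executed before the residual add. A simplified stand-in for that deferred-execution pattern, assuming only the put/flush/pop/funcs_queue interface visible in the diff; TinyWeightGradStore is invented for illustration and is not Paddle's WeightGradStore:

import queue
from functools import partial

class TinyWeightGradStore:
    enabled = False
    _cache = []
    funcs_queue = queue.Queue()

    @classmethod
    def put(cls, func):
        cls._cache.append(func)  # defer the weight-grad computation

    @classmethod
    def flush(cls):
        for func in cls._cache:  # hand deferred work to the queue
            cls.funcs_queue.put(func)
        cls._cache = []

    @classmethod
    def pop(cls):
        while not cls.funcs_queue.empty():
            cls.funcs_queue.get()()  # run the deferred callbacks

# Usage mirroring the order in forward_backward: enable, defer, then drain.
TinyWeightGradStore.enabled = True
TinyWeightGradStore.put(partial(print, "deferred dW GEMM"))
TinyWeightGradStore.enabled = False
TinyWeightGradStore.flush()
TinyWeightGradStore.pop()
assert TinyWeightGradStore.funcs_queue.empty()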