Commit f7ca489

Author: root
Fix mtp bug
1 parent 2ff16f8 commit f7ca489

File tree: 1 file changed (+27, -15 lines)


paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 27 additions & 15 deletions
@@ -192,6 +192,7 @@ def forward_without_residual(self, inputs):
 
         if self.send_mtp_embed:
             hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+            self.mtp_embed_shape = inputs_embeds_mtp.shape  # save mtp_embed's shape for the backward pass
 
         return return_args(hidden_states)
 
@@ -227,37 +228,45 @@ def forward(self, inputs):
 
         if self.send_mtp_embed:
             hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+            self.mtp_embed_shape = inputs_embeds_mtp.shape  # save mtp_embed's shape for the backward pass
 
         return return_args(hidden_states)
 
     @paddle.no_grad()
     def backward(self, output_grad):
         (do3,) = output_grad
 
-        assert not self.send_mtp_embed, "not support have mtp have yet"
+        if self.send_mtp_embed:
+            # Split the gradient: the leading part of do3 belongs to hidden_states, the trailing part to inputs_embeds_mtp.
+            hidden_size = do3.shape[-1] - self.mtp_embed_shape[-1]
+            hidden_states_grad = do3[..., :hidden_size]
+            inputs_embeds_mtp_grad = do3[..., hidden_size:]
+        else:
+            hidden_states_grad = do3
+            inputs_embeds_mtp_grad = None
+
         if self.using_post_norm_recompute:
             dx = FP8LinearFunctionBase.fp8_mlp_bwd_norm_rc(
-                do3,
+                hidden_states_grad,
                 self.x,
                 self.shared_experts.norm_weight,
                 self.shared_experts.norm_eps,
                 self.shared_experts.w1,
                 self.shared_experts.w2,
             )
         else:
-            dx = FP8LinearFunctionBase.fp8_mlp_bwd(do3, self.x, self.shared_experts.w1, self.shared_experts.w2)
+            dx = FP8LinearFunctionBase.fp8_mlp_bwd(hidden_states_grad, self.x, self.shared_experts.w1, self.shared_experts.w2)
 
         self.x = None
 
-        residual_grad = do3
-
-        hidden_states_grad = dx
-
+        residual_grad = hidden_states_grad
         l_aux_grad = paddle.ones(1, dtype=self.l_aux.dtype) * self.alpha
+        final_hidden_states_grad = hidden_states_grad
 
-        final_hidden_states_grad = do3
-
-        return (hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad)
+        if self.send_mtp_embed:
+            return (inputs_embeds_mtp_grad, dx, residual_grad, l_aux_grad, final_hidden_states_grad)
+        else:
+            return (dx, residual_grad, l_aux_grad, final_hidden_states_grad)
 
 
 class DecoderLayerNode(ScheduleNode):
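Editor's note: a minimal standalone sketch of the pack-then-split pattern this hunk relies on. forward() concatenates inputs_embeds_mtp onto hidden_states along the last axis and records its shape; backward() then slices the incoming gradient do3 back apart at the same offset. All shapes and the paddle.rand inputs below are invented for illustration; only the concat/slice logic mirrors the patch.

# Illustrative sketch, not code from the repository.
import paddle

hidden_states = paddle.rand([2, 16, 128])      # [batch, seq, hidden], made-up sizes
inputs_embeds_mtp = paddle.rand([2, 16, 128])  # MTP embeddings carried alongside

# forward(): pack both tensors into one activation and remember the MTP width.
packed = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
mtp_embed_shape = inputs_embeds_mtp.shape

# backward(): the upstream gradient covers the packed tensor, so split it at
# hidden_size = total_width - mtp_width, as the added branch does.
do3 = paddle.rand(packed.shape)
hidden_size = do3.shape[-1] - mtp_embed_shape[-1]
hidden_states_grad = do3[..., :hidden_size]
inputs_embeds_mtp_grad = do3[..., hidden_size:]

assert hidden_states_grad.shape == hidden_states.shape
assert inputs_embeds_mtp_grad.shape == inputs_embeds_mtp.shape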
@@ -749,6 +758,9 @@ def attn_backward(self, output_grad):
                 hs_grad,
                 token_probs_grad,
             ) = output_grad
+            inputs_embeds_mtp_grad_shape = hidden_states_grad.shape
+            inputs_embeds_mtp_grad_shape[-1] = -1
+            inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(inputs_embeds_mtp_grad_shape)
         else:
             hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad = output_grad
 
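Editor's note: a small sketch of the reshape the three added lines perform, under the assumption that the MTP gradient arrives with its leading dimensions flattened. A [batch, seq, -1] target shape is built from hidden_states_grad so the last dimension is inferred. The sketch uses Tensor.reshape purely for illustration; the patch itself calls .view on the gradient.

# Illustrative sketch; the flattened layout is an assumption.
import paddle

hidden_states_grad = paddle.rand([2, 16, 128])
inputs_embeds_mtp_grad = paddle.rand([2 * 16, 64])   # assumed flattened MTP gradient

target_shape = hidden_states_grad.shape   # Python list, e.g. [2, 16, 128]
target_shape[-1] = -1                     # let the MTP width be inferred
inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.reshape(target_shape)

assert inputs_embeds_mtp_grad.shape == [2, 16, 64]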
@@ -906,8 +918,8 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
             combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
 
             final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
-            inputs = final_out + combine_fwd_out
-
+            final_out[:, :, :combine_fwd_out.shape[-1]] += combine_fwd_out  # broadcast and add directly
+            inputs = final_out
             combine_fwd_out._record_stream()
 
             paddle.base.core.nvprof_nvtx_pop()
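Editor's note: a hedged sketch of the accumulation change in this hunk. Once final_out also carries the MTP columns, its last dimension is wider than combine_fwd_out, so adding the two whole tensors no longer lines up; the fix accumulates combine_fwd_out in place into the leading slice only, leaving the MTP columns untouched. Shapes below are invented for illustration.

# Illustrative sketch with made-up widths: hidden = 128, MTP = 64.
import paddle

final_out = paddle.zeros([2, 16, 192])       # hidden + MTP packed on the last axis
combine_fwd_out = paddle.ones([2, 16, 128])  # only covers the hidden part

# Accumulate into the hidden slice; the MTP slice is left as-is.
final_out[:, :, :combine_fwd_out.shape[-1]] += combine_fwd_out
inputs = final_out

assert float(final_out[0, 0, 0]) == 1.0   # hidden slice received the update
assert float(final_out[0, 0, -1]) == 0.0  # MTP slice untouched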
@@ -1072,7 +1084,7 @@ def forward(self, args):
         if self.config.send_mtp_embed:
             batch_size, _, hidden_size = hidden_states.shape
             batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1)
-            inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:]
+            inputs_embeds_mtp = hidden_states[..., batch_size_mtp:]
             hidden_states = hidden_states[..., :batch_size_mtp]
 
         has_gradient = not hidden_states.stop_gradient
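Editor's note: to make the one-character slicing change concrete, here is a sketch with invented sizes and num_nextn_predict_layers assumed to be 2. The packed width splits into num_nextn_predict_layers + 1 equal chunks, the first chunk being the real hidden_states; the new slice [..., batch_size_mtp:] keeps every MTP chunk, whereas the old [..., -batch_size_mtp:] kept only the last one (the two are identical only when there is a single MTP chunk). The same change is applied in the two hunks below.

# Illustrative sketch; values are assumptions, not repository configuration.
import paddle

num_nextn_predict_layers = 2                 # assumed for illustration
hidden_states = paddle.rand([2, 16, 384])    # 3 chunks of 128: hidden + 2 MTP chunks

batch_size, _, hidden_size = hidden_states.shape
batch_size_mtp = hidden_size // (num_nextn_predict_layers + 1)   # 128

inputs_embeds_mtp = hidden_states[..., batch_size_mtp:]    # new slice: both MTP chunks, width 256
only_last_chunk = hidden_states[..., -batch_size_mtp:]     # old slice: last chunk only, width 128
hidden_states = hidden_states[..., :batch_size_mtp]        # width 128

assert inputs_embeds_mtp.shape[-1] == 256
assert only_last_chunk.shape[-1] == 128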
@@ -1129,7 +1141,7 @@ def attn_compute(self, args):
 
             batch_size, _, hidden_size = hidden_states.shape
             batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1)
-            inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:]
+            inputs_embeds_mtp = hidden_states[..., batch_size_mtp:]
             hidden_states = hidden_states[..., :batch_size_mtp]
 
         def attn_compute_func(hidden_states):
@@ -1162,7 +1174,7 @@ def attn_compute_for_fusion(self, args):
         # slice from holy tensor
         batch_size, _, hidden_size = hidden_states.shape
         batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1)
-        inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:]
+        inputs_embeds_mtp = hidden_states[..., batch_size_mtp:]
         hidden_states = hidden_states[..., :batch_size_mtp]
 
         hidden_states, residual = self.self_attn_compute(hidden_states)
