@@ -192,6 +192,7 @@ def forward_without_residual(self, inputs):
 
         if self.send_mtp_embed:
             hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+            self.mtp_embed_shape = inputs_embeds_mtp.shape  # save the mtp_embed shape for the backward pass
 
         return return_args(hidden_states)
 
@@ -227,37 +228,45 @@ def forward(self, inputs):
 
         if self.send_mtp_embed:
             hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)
+            self.mtp_embed_shape = inputs_embeds_mtp.shape  # save the mtp_embed shape for the backward pass
 
         return return_args(hidden_states)
 
     @paddle.no_grad()
     def backward(self, output_grad):
         (do3,) = output_grad
 
-        assert not self.send_mtp_embed, "not support have mtp have yet"
+        if self.send_mtp_embed:
+            # Split the gradient: the leading part of do3 maps to hidden_states, the trailing part to inputs_embeds_mtp
+            hidden_size = do3.shape[-1] - self.mtp_embed_shape[-1]
+            hidden_states_grad = do3[..., :hidden_size]
+            inputs_embeds_mtp_grad = do3[..., hidden_size:]
+        else:
+            hidden_states_grad = do3
+            inputs_embeds_mtp_grad = None
+
         if self.using_post_norm_recompute:
             dx = FP8LinearFunctionBase.fp8_mlp_bwd_norm_rc(
-                do3,
+                hidden_states_grad,
                 self.x,
                 self.shared_experts.norm_weight,
                 self.shared_experts.norm_eps,
                 self.shared_experts.w1,
                 self.shared_experts.w2,
             )
         else:
-            dx = FP8LinearFunctionBase.fp8_mlp_bwd(do3, self.x, self.shared_experts.w1, self.shared_experts.w2)
+            dx = FP8LinearFunctionBase.fp8_mlp_bwd(hidden_states_grad, self.x, self.shared_experts.w1, self.shared_experts.w2)
 
         self.x = None
 
-        residual_grad = do3
-
-        hidden_states_grad = dx
-
+        residual_grad = hidden_states_grad
         l_aux_grad = paddle.ones(1, dtype=self.l_aux.dtype) * self.alpha
+        final_hidden_states_grad = hidden_states_grad
 
-        final_hidden_states_grad = do3
-
-        return (hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad)
+        if self.send_mtp_embed:
+            return (inputs_embeds_mtp_grad, dx, residual_grad, l_aux_grad, final_hidden_states_grad)
+        else:
+            return (dx, residual_grad, l_aux_grad, final_hidden_states_grad)
 
 
 class DecoderLayerNode(ScheduleNode):
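
For context, here is a minimal standalone sketch (Paddle dygraph, made-up shapes, not part of the patch) of why slicing do3 along the last axis recovers the gradients of the earlier paddle.concat:

import paddle

# Illustrative stand-ins for hidden_states and inputs_embeds_mtp.
hidden = paddle.randn([2, 4, 8])
mtp_embed = paddle.randn([2, 4, 8])
hidden.stop_gradient = False
mtp_embed.stop_gradient = False

# Forward mirrors the node: concatenate along the last axis.
merged = paddle.concat([hidden, mtp_embed], axis=-1)
upstream = paddle.ones_like(merged)  # stands in for do3
merged.backward(upstream)

# The gradient of a concat is the upstream gradient split back at the same
# offsets, which is what the new slicing in backward() does.
hidden_size = upstream.shape[-1] - mtp_embed.shape[-1]
assert paddle.allclose(hidden.grad, upstream[..., :hidden_size])
assert paddle.allclose(mtp_embed.grad, upstream[..., hidden_size:])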
@@ -749,6 +758,9 @@ def attn_backward(self, output_grad):
                 hs_grad,
                 token_probs_grad,
             ) = output_grad
+            inputs_embeds_mtp_grad_shape = hidden_states_grad.shape
+            inputs_embeds_mtp_grad_shape[-1] = -1
+            inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(inputs_embeds_mtp_grad_shape)
         else:
             hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad = output_grad
 
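
As a side note, a small sketch (made-up shapes, Paddle dygraph assumed; exact view() semantics depend on the Paddle version) of the reshape pattern added above: in dygraph mode Tensor.shape is a plain Python list, so its last entry can be set to -1 and passed to view() to infer that dimension:

import paddle

# Illustrative stand-ins for hidden_states_grad and a flattened inputs_embeds_mtp_grad.
hidden_states_grad = paddle.randn([2, 4, 8])
inputs_embeds_mtp_grad = paddle.randn([2, 4 * 8])

target_shape = hidden_states_grad.shape  # a Python list, e.g. [2, 4, 8]
target_shape[-1] = -1                    # let view() infer the last dimension
inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(target_shape)

print(inputs_embeds_mtp_grad.shape)      # [2, 4, 8]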
@@ -906,8 +918,8 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
         combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
 
         final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
-        inputs = final_out + combine_fwd_out
-
+        final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out  # broadcast and add directly
+        inputs = final_out
         combine_fwd_out._record_stream()
 
         paddle.base.core.nvprof_nvtx_pop()
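
A minimal sketch (made-up widths, not part of the patch) of the sliced in-place add: once the mtp embeds are appended along the last axis, combine_fwd_out only matches the leading hidden-size channels, so adding into that slice leaves the appended part untouched:

import paddle

hidden_size, mtp_size = 8, 8  # illustrative widths
final_out = paddle.randn([2, 4, hidden_size + mtp_size])
combine_fwd_out = paddle.randn([2, 4, hidden_size])

mtp_part_before = final_out[:, :, hidden_size:].clone()

# Add only into the hidden_states portion of the concatenated tensor.
final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out

# The appended mtp-embed slice is unchanged.
assert paddle.allclose(final_out[:, :, hidden_size:], mtp_part_before)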
@@ -1072,7 +1084,7 @@ def forward(self, args):
         if self.config.send_mtp_embed:
             batch_size, _, hidden_size = hidden_states.shape
             batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1)
-            inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:]
+            inputs_embeds_mtp = hidden_states[..., batch_size_mtp:]
             hidden_states = hidden_states[..., :batch_size_mtp]
 
         has_gradient = not hidden_states.stop_gradient
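
A quick arithmetic check (illustrative numbers only) of the two slices: [..., batch_size_mtp:] keeps every channel after the first batch_size_mtp, while the previous [..., -batch_size_mtp:] kept only the trailing batch_size_mtp channels; the two agree only when num_nextn_predict_layers == 1:

base_hidden = 8                 # hypothetical per-layer hidden width
num_nextn_predict_layers = 2    # hypothetical, > 1 to show the difference

hidden_size = base_hidden * (num_nextn_predict_layers + 1)      # 24 packed channels
batch_size_mtp = hidden_size // (num_nextn_predict_layers + 1)  # 8

new_slice_width = hidden_size - batch_size_mtp  # width of [..., batch_size_mtp:] -> 16
old_slice_width = batch_size_mtp                # width of [..., -batch_size_mtp:] -> 8

print(new_slice_width, old_slice_width)  # equal only when num_nextn_predict_layers == 1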
@@ -1129,7 +1141,7 @@ def attn_compute(self, args):
 
             batch_size, _, hidden_size = hidden_states.shape
             batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1)
-            inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:]
+            inputs_embeds_mtp = hidden_states[..., batch_size_mtp:]
             hidden_states = hidden_states[..., :batch_size_mtp]
 
         def attn_compute_func(hidden_states):
@@ -1162,7 +1174,7 @@ def attn_compute_for_fusion(self, args):
             # slice from holy tensor
             batch_size, _, hidden_size = hidden_states.shape
             batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1)
-            inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:]
+            inputs_embeds_mtp = hidden_states[..., batch_size_mtp:]
             hidden_states = hidden_states[..., :batch_size_mtp]
 
         hidden_states, residual = self.self_attn_compute(hidden_states)