
Revert dsv3_dev to runnable version #10907

Open · wants to merge 4 commits into base: dsv3_dev

154 changes: 24 additions & 130 deletions paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -67,7 +67,6 @@


DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true"
DSV3_USE_FP8_DISPATCH = os.getenv("DSV3_USE_FP8_DISPATCH", "False").lower() == "true"


def parse_args(args):
@@ -158,43 +157,6 @@ def __init__(
assert self.shared_experts is not None
assert self.shared_experts.norm_weight is not None and self.shared_experts.norm_eps is not None

def forward_without_residual(self, inputs):

if isinstance(inputs, list):
inputs = tuple(inputs)

if self.send_mtp_embed:
(inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs
else:
(hidden_states, residual, l_aux, final_hidden_states) = inputs

with paddle.no_grad():
if self.shared_experts is not None:
if self.using_post_norm_recompute:
shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd_norm_rc(
hidden_states,
self.shared_experts.norm_weight,
self.shared_experts.norm_eps,
self.shared_experts.w1,
self.shared_experts.w2,
)
else:
_, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd(
hidden_states, self.shared_experts.w1, self.shared_experts.w2
)
residual = residual + shared_expert_output

self.x = hidden_states
self.l_aux = l_aux

hidden_states = residual
hidden_states.stop_gradient = False

if self.send_mtp_embed:
hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1)

return return_args(hidden_states)

def forward(self, inputs):

if isinstance(inputs, list):
@@ -469,17 +431,9 @@ def __init__(self, forward_nodes, backward_nodes, use_fuion=True):
for f, b in zip(forward_nodes, backward_nodes):
self.nodes.append(schedule_node_class(f, b, f"OverlapedNode_{len(self.nodes)}"))

def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
# print(" fwd pp stream", pp_stream)
event_to_wait = combine_bw_event_to_wait
for i, n in enumerate(self.nodes):
pp_stream_t = pp_stream
if i + 1 != len(self.nodes):
pp_stream_t = None

inputs, output_grad, event_to_wait = n.forward_backward(
inputs, output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t
)
def forward_backward(self, inputs, output_grad, event_to_wait=None):
for n in self.nodes:
inputs, output_grad, event_to_wait = n.forward_backward(inputs, output_grad, event_to_wait)
return inputs, output_grad, None


@@ -632,7 +586,7 @@ def combine_forward(self, inputs, async_finish=False, previous_event=None, alloc
ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret
return ret

def post_process_forward(self, inputs, with_residual=True):
def post_process_forward(self, inputs):
if self.send_mtp_embed:
(inputs_embeds_mtp, hidden_states, residual, l_aux, output_combine) = inputs
else:
@@ -642,10 +596,7 @@ def post_process_forward(self, inputs, with_residual=True):
inputs = (hidden_states, residual, l_aux, final_hidden_states)
inputs = (inputs_embeds_mtp, *inputs) if self.send_mtp_embed else inputs

if with_residual:
inputs = self.post_process_node.forward(inputs)
else:
inputs = self.post_process_node.forward_without_residual(inputs)
inputs = self.post_process_node.forward(inputs)
return inputs

def post_process_backward(self, output_grad, event_to_wait=None):
@@ -664,7 +615,7 @@ def post_process_backward(self, output_grad, event_to_wait=None):
ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret
return ret

def combine_backward(self, output_grad, previous_event=None, async_finish=False, allocate_on_comm_stream=False):
def combine_backward(self, output_grad, async_finish=False, allocate_on_comm_stream=False):
if self.send_mtp_embed:
(
inputs_embeds_mtp_grad,
@@ -675,22 +626,12 @@ def combine_backward(self, output_grad, previous_event=None, async_finish=False,
quant_event,
) = output_grad
else:
(
hidden_states_grad,
residual_grad,
l_aux_grad,
output_combine_grad,
quant_event,
) = output_grad
hidden_states_grad, residual_grad, l_aux_grad, output_combine_grad, quant_event = output_grad

if DSV3_USE_FP8_DISPATCH and quant_event is not None:
combine_backward_wait_event = quant_event
else:
combine_backward_wait_event = previous_event
hidden_states_out_grad = self.fp8_fusion_moe_node.combine_node.backward(
output_combine_grad,
async_finish=async_finish,
previous_event=combine_backward_wait_event,
previous_event=quant_event,
allocate_on_comm_stream=allocate_on_comm_stream and quant_event is not None,
)

@@ -797,34 +738,25 @@ def __init__(self, forward_node, backward_node, name=""):
self.backward_node = backward_node
self.name = name

def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None):
def forward_backward(self, inputs, output_grad, event_to_wait=None):
paddle.base.core.nvprof_nvtx_push("forward_backward")

combine_bwd_event = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)

paddle.base.core.nvprof_nvtx_push("attn_forward")
inputs = self.forward_node.attn_forward(inputs)
paddle.base.core.nvprof_nvtx_pop()
attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)

paddle.base.core.nvprof_nvtx_push("post_process_backward")
output_grad = self.backward_node.post_process_backward(output_grad, combine_bw_event_to_wait)
output_grad = self.backward_node.post_process_backward(output_grad, event_to_wait)
paddle.base.core.nvprof_nvtx_pop()

paddle.base.core.nvprof_nvtx_push("combine_backward")
if combine_bw_event_to_wait is not None:
# print(" event", combine_bw_event_to_wait)
output_grad = self.backward_node.combine_backward(
output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True
)
else:
output_grad = self.backward_node.combine_backward(
output_grad, previous_event=combine_bwd_event, async_finish=True, allocate_on_comm_stream=True
)
output_grad = self.backward_node.combine_backward(output_grad, async_finish=True, allocate_on_comm_stream=True)
# get combine event
combine_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id)
paddle.base.core.nvprof_nvtx_pop()

paddle.base.core.nvprof_nvtx_push("attn_forward")
inputs = self.forward_node.attn_forward(inputs)
paddle.base.core.nvprof_nvtx_pop()

attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)

combine_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
paddle.base.core.nvprof_nvtx_push("mlp_backward_dx")
output_grad = self.backward_node.mlp_backward(output_grad)
@@ -855,61 +787,26 @@ def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, p
paddle.base.core.nvprof_nvtx_push("mlp_forward")
inputs = self.forward_node.mlp_forward(inputs)
paddle.base.core.nvprof_nvtx_pop()
mlp_fwd_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)

if pp_stream is not None:
final_out = self.forward_node.post_process_node.forward_without_residual(inputs)

final_out_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)
inputs_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id)

paddle.base.core.nvprof_nvtx_push("combine_forward")
inputs = self.forward_node.combine_forward(
inputs, previous_event=mlp_fwd_event, async_finish=True, allocate_on_comm_stream=True
inputs, async_finish=True, previous_event=inputs_event, allocate_on_comm_stream=True
)
paddle.base.core.nvprof_nvtx_pop()

combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id)

combine_fwd_out = inputs[-1]

if pp_stream is not None:
send_recv_stream = paddle.device.Stream(stream_base=pp_stream)

# combine_forward_event.custom_stream_wait( pp_stream)
# final_out_event.custom_stream_wait(pp_stream)

paddle.base.core.nvprof_nvtx_push("pp stream add")

with paddle.device.stream_guard(send_recv_stream):
combine_forward_event.current_stream_wait()
final_out_event.current_stream_wait()

inputs = final_out + combine_fwd_out

final_out._record_stream()
combine_fwd_out._record_stream()

paddle.base.core.nvprof_nvtx_pop()

dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id)
paddle.base.core.nvprof_nvtx_push("post_process_forward")

paddle.base.core.nvprof_nvtx_pop()
paddle.base.core.nvprof_nvtx_push("attn_backward")
output_grad = self.backward_node.attn_backward(output_grad)
event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id)

paddle.base.core.nvprof_nvtx_pop()

# residual add
if pp_stream is None:
combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)

final_out = self.forward_node.post_process_node.forward_without_residual(inputs)
inputs = final_out + combine_fwd_out

combine_fwd_out._record_stream()

combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id)
paddle.base.core.nvprof_nvtx_push("post_process_forward")
inputs = self.forward_node.post_process_forward(inputs)
paddle.base.core.nvprof_nvtx_pop()
paddle.base.core.nvprof_nvtx_pop()
return inputs, output_grad, event_to_wait

@@ -1683,10 +1580,7 @@ def overlapped_forward_backward(
forward_inputs = forward_pre_node.forward(forward_inputs)
backward_input_grads = backward_pre_node.backward(backward_input_grads)
forward_inputs, backward_input_grads, _ = overlap_node.forward_backward(
forward_inputs,
backward_input_grads,
combine_bw_event_to_wait=combine_bw_event_to_wait,
pp_stream=pp_stream,
forward_inputs, backward_input_grads, combine_bw_event_to_wait
)
forward_inputs = forward_post_node.forward(forward_inputs)
backward_input_grads = backward_post_node.backward(backward_input_grads)
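
For orientation, below is a minimal, self-contained sketch of the node-chaining pattern this revert restores in OverlapedScheduleChunk.forward_backward: each node consumes the event produced by its predecessor, and the per-node pp_stream handling is gone. The FakeScheduleNode class and run_chunk helper are hypothetical stand-ins, not code from this PR; only the loop shape mirrors the diff above.

    # Hypothetical stand-in for an overlapped schedule node; a real node pairs an
    # attention/MoE forward with the matching backward and returns an event handle.
    class FakeScheduleNode:
        def __init__(self, name):
            self.name = name

        def forward_backward(self, inputs, output_grad, event_to_wait=None):
            # A real node launches compute/comm work here; the sketch only tags the payloads.
            return inputs + [self.name], output_grad + [self.name], f"event_after_{self.name}"


    def run_chunk(nodes, inputs, output_grad, event_to_wait=None):
        # Mirrors the simplified loop in the diff: the event returned by one node
        # is handed to the next, and the chunk itself returns no trailing event.
        for n in nodes:
            inputs, output_grad, event_to_wait = n.forward_backward(inputs, output_grad, event_to_wait)
        return inputs, output_grad, None


    if __name__ == "__main__":
        nodes = [FakeScheduleNode(f"node{i}") for i in range(3)]
        print(run_chunk(nodes, ["x"], ["dy"]))
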
20 changes: 3 additions & 17 deletions paddlenlp/transformers/fp8_utils.py
@@ -917,17 +917,18 @@ def bwd_gate_up_weight(self, do1, input_x, expert_w1, clear_input=False):
@paddle.no_grad()
def forward(self, hs_out, unzipped_probs, tokens_per_expert, origin_token_per_experts, output=None):
self.origin_token_per_experts = origin_token_per_experts
# deal 0 size
dtype = paddle.bfloat16
if hs_out is None:
assert self.input_fp8 is not None
assert self.input_scale is not None
shape = self.input_fp8.shape
dtype = paddle.bfloat16
else:
if isinstance(hs_out, tuple):
shape = hs_out[0].shape
dtype = hs_out[0].dtype
else:
shape = hs_out.shape
dtype = hs_out.dtype

if shape[0] == 0:
o3 = paddle.zeros(shape, dtype=dtype)
@@ -957,12 +958,6 @@ def forward(self, hs_out, unzipped_probs, tokens_per_expert, origin_token_per_ex

@paddle.no_grad()
def backward(self, out_grad):
# deal 0 size
dtype = paddle.bfloat16
shape = out_grad[0].shape if isinstance(out_grad, tuple) else out_grad.shape
if shape[0] == 0:
return paddle.zeros_like(out_grad, dtype=dtype), paddle.zeros_like(self.unzipped_probs, dtype=dtype)

# recompute expert_w2 and expert_w1
expert_w1 = [x.w1 for x in self.experts if x is not None]
expert_w2 = [x.w2 for x in self.experts if x is not None]
@@ -1000,12 +995,6 @@ def backward(self, out_grad):

@paddle.no_grad()
def backward_dx(self, out_grad):
# deal 0 size
dtype = paddle.bfloat16
shape = out_grad[0].shape if isinstance(out_grad, tuple) else out_grad.shape
if shape[0] == 0:
return paddle.zeros_like(out_grad, dtype=dtype), paddle.zeros_like(self.unzipped_probs, dtype=dtype)

# recompute expert_w2 and expert_w1
expert_w1 = [x.w1 for x in self.experts if x is not None]
expert_w2 = [x.w2 for x in self.experts if x is not None]
@@ -1038,9 +1027,6 @@ def backward_dx(self, out_grad):

@paddle.no_grad()
def backward_dw(self):
# deal 0 size
if self.input_fp8 is None or self.input_fp8.shape[0] == 0:
return
# recompute expert_w2 and expert_w1
expert_w1 = [x.w1 for x in self.experts if x is not None]
expert_w2 = [x.w2 for x in self.experts if x is not None]
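
For reference on the fp8_utils.py change, here is a minimal sketch of the shape/dtype selection and empty-batch guard that forward() keeps after this revert. The helper name expert_forward_sketch and its final cast are hypothetical; only the shape[0] == 0 early exit and the paddle.bfloat16 fallback come from the diff.

    import paddle


    def expert_forward_sketch(hs_out, input_fp8=None):
        # Fallback dtype when only the cached FP8 activation is available (as in the diff).
        dtype = paddle.bfloat16
        if hs_out is None:
            assert input_fp8 is not None
            shape = input_fp8.shape
        elif isinstance(hs_out, tuple):
            shape, dtype = hs_out[0].shape, hs_out[0].dtype
        else:
            shape, dtype = hs_out.shape, hs_out.dtype

        # Empty expert batch: skip the expert GEMMs and return a zero tensor of the right shape.
        if shape[0] == 0:
            return paddle.zeros(shape, dtype=dtype)

        # Placeholder for the real FP8 expert MLP; here just a cast to the chosen dtype.
        x = hs_out[0] if isinstance(hs_out, tuple) else hs_out
        return paddle.cast(x, dtype)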