
Commit baaa080

fix_qwen3_moe (#10801)
1 parent 2828bc7 commit baaa080

File tree

2 files changed: +4, -2 lines changed


paddlenlp/transformers/moe_gate.py

Lines changed: 1 addition & 1 deletion

@@ -496,7 +496,7 @@ def topkgating(
         top_gate = top_gate * self.routed_scaling_factor

         # get topk mask
-        mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0, dtype="float32"), axis=1)
+        mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0, dtype=gates.dtype), axis=1)
         if hasattr(self.config, "seq_aux") and self.config.seq_aux:
             l_aux = self._cal_seq_aux_loss(gates_ori, self.top_k, top_idx)
         else:
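The change matters when the gate tensor does not run in float32 (e.g. bf16 training): put_along_axis writes the scalar into a zeros_like(gates) destination, and a hard-coded float32 scalar then mismatches the destination dtype. A minimal sketch of the before/after behavior, assuming a bfloat16 gates tensor and a device with bf16 kernels (the tensor names and shapes here are illustrative, not from the commit):

import paddle

# Illustrative setup: 4 tokens routed over 8 experts, gates in bfloat16.
gates = paddle.rand([4, 8]).astype("bfloat16")
top_idx = paddle.topk(gates.astype("float32"), k=2, axis=1)[1]

# Before the fix: the scalar is float32 while the destination mask is
# bfloat16, so put_along_axis can fail with a dtype mismatch (behavior
# depends on the Paddle version).
# mask = paddle.zeros_like(gates).put_along_axis(
#     top_idx, paddle.to_tensor(1.0, dtype="float32"), axis=1
# )

# After the fix: the scalar follows gates.dtype, so the mask stays bfloat16.
mask = paddle.zeros_like(gates).put_along_axis(
    top_idx, paddle.to_tensor(1.0, dtype=gates.dtype), axis=1
)
print(mask.dtype)  # paddle.bfloat16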

paddlenlp/transformers/qwen2/modeling.py

Lines changed: 3 additions & 1 deletion

@@ -590,7 +590,9 @@ def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True, skip_r
                )
                self.k_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=self.has_bias, gather_output=False) # fmt:skip
                self.v_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=self.has_bias, gather_output=False) # fmt:skip
-            self.o_proj = RowParallelLinear(self.hidden_size, self.hidden_size, has_bias=False, input_is_parallel=True)
+            self.o_proj = RowParallelLinear(
+                self.num_attention_heads * self.head_dim, self.hidden_size, has_bias=False, input_is_parallel=True
+            )
        else:
            if self.fuse_attention_qkv:
                self.qkv_proj = Linear(
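The second change widens (or narrows) the input side of o_proj for configurations where num_attention_heads * head_dim differs from hidden_size, e.g. when head_dim is set explicitly in the config. The concatenated attention heads that feed o_proj have width num_attention_heads * head_dim, so hard-coding hidden_size as the input dimension breaks such models. A rough shape check of the fix, with plain paddle.nn.Linear standing in for RowParallelLinear and made-up config numbers:

import paddle
import paddle.nn as nn

# Hypothetical config where heads * head_dim != hidden_size.
hidden_size = 2048
num_attention_heads = 16
head_dim = 256  # 16 * 256 = 4096, not 2048

batch, seq_len = 2, 8
# Concatenated attention heads entering o_proj.
attn_out = paddle.rand([batch, seq_len, num_attention_heads * head_dim])

# Old construction: in_features = hidden_size -> runtime shape mismatch.
# o_proj_old = nn.Linear(hidden_size, hidden_size)
# o_proj_old(attn_out)  # fails: expects last dim 2048, gets 4096

# Fixed construction: in_features matches the attention output width.
o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size)
out = o_proj(attn_out)
print(out.shape)  # [2, 8, 2048]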
