
Commit 2ada171

Unify Qwen2RMSNorm definitions and use RMSNorm from PyTorch
Signed-off-by: cyy <cyyever@outlook.com>
1 parent f4d57f2 commit 2ada171
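
The commit replaces the hand-written RMSNorm forward pass with PyTorch's fused operator where available, keeping the original float32 fallback for older installs. The sketch below is not part of the commit; it only illustrates the equivalence the new fast path relies on, assuming a PyTorch build that ships torch.nn.functional.rms_norm:

import torch
import torch.nn.functional as F

hidden_size, eps = 8, 1e-6
weight = torch.ones(hidden_size)
x = torch.randn(2, 4, hidden_size)

# Hand-written path (the old Qwen2RMSNorm.forward): upcast to float32,
# scale by the reciprocal RMS over the last dimension, cast back.
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
manual = weight * (x.to(torch.float32) * torch.rsqrt(variance + eps)).to(x.dtype)

# Fused operator used on the new fast path.
fused = F.rms_norm(x, normalized_shape=[hidden_size], weight=weight, eps=eps)

# The two paths should agree up to small floating-point differences.
print(torch.allclose(manual, fused, atol=1e-5))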

File tree

8 files changed (+87, -67 lines)


docs/source/en/model_doc/qwen2.md

Lines changed: 5 additions & 0 deletions
@@ -159,6 +159,11 @@ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
 
 [[autodoc]] Qwen2TokenizerFast
 
+## Qwen2RMSNorm
+
+[[autodoc]] Qwen2RMSNorm
+    - forward
+
 ## Qwen2Model
 
 [[autodoc]] Qwen2Model
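
The new docs entry corresponds to Qwen2RMSNorm becoming part of the module's public surface (it is also added to __all__ in modeling_qwen2.py below), so downstream code can import the one shared definition instead of copying the class. A minimal usage sketch, assuming a transformers install that includes this commit:

import torch
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm

norm = Qwen2RMSNorm(hidden_size=64, eps=1e-6)
x = torch.randn(2, 10, 64)
out = norm(x)     # normalized over the last dimension, same shape as x
print(out.shape)  # torch.Size([2, 10, 64])
print(norm)       # extra_repr shows the weight shape and eps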

src/transformers/models/dots1/modeling_dots1.py

Lines changed: 13 additions & 2 deletions
@@ -22,6 +22,7 @@
 
 import torch
 import torch.nn.functional as F
+from packaging import version
 from torch import nn
 
 from ...activations import ACT2FN
@@ -38,20 +39,30 @@
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import check_model_inputs
+from ...utils.import_utils import get_torch_version
 from .configuration_dots1 import Dots1Config
 
 
 @use_kernel_forward_from_hub("RMSNorm")
 class Dots1RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
         """
         Dots1RMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.has_rms_norm = version.parse(get_torch_version()) >= version.parse("2.3.0")
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if self.has_rms_norm:
+            return F.rms_norm(
+                input=hidden_states,
+                normalized_shape=[hidden_states.shape[-1]],
+                weight=self.weight,
+                eps=self.variance_epsilon,
+            )
 
-    def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
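
The has_rms_norm flag is computed once at construction time by comparing the installed torch version against 2.3.0. Outside transformers (where get_torch_version lives in utils.import_utils), the same gate can be reproduced with packaging alone; a minimal sketch, not part of the commit:

import torch
from packaging import version

# True on PyTorch builds new enough for the fused operator; the commit
# gates on 2.3.0, and local version tags like "+cu121" parse fine.
HAS_RMS_NORM = version.parse(torch.__version__) >= version.parse("2.3.0")

def fused_rms_norm_available() -> bool:
    # Belt-and-braces: also confirm the function is actually present.
    return HAS_RMS_NORM and hasattr(torch.nn.functional, "rms_norm")

print(fused_rms_norm_available())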

src/transformers/models/qwen2/modeling_qwen2.py

Lines changed: 15 additions & 2 deletions
@@ -7,6 +7,8 @@
 from typing import Callable, Optional, Union
 
 import torch
+import torch.nn.functional as F
+from packaging import version
 from torch import nn
 
 from ...activations import ACT2FN
@@ -28,6 +30,7 @@
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import check_model_inputs
+from ...utils.import_utils import get_torch_version
 from .configuration_qwen2 import Qwen2Config
 
 
@@ -185,15 +188,24 @@ def forward(
 
 @use_kernel_forward_from_hub("RMSNorm")
 class Qwen2RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
         """
         Qwen2RMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.has_rms_norm = version.parse(get_torch_version()) >= version.parse("2.3.0")
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if self.has_rms_norm:
+            return F.rms_norm(
+                input=hidden_states,
+                normalized_shape=[hidden_states.shape[-1]],
+                weight=self.weight,
+                eps=self.variance_epsilon,
+            )
 
-    def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
@@ -497,6 +509,7 @@ class Qwen2ForQuestionAnswering(GenericForQuestionAnswering, Qwen2PreTrainedMode
     "Qwen2PreTrainedModel",
     "Qwen2Model",
     "Qwen2ForCausalLM",
+    "Qwen2RMSNorm",
     "Qwen2ForSequenceClassification",
     "Qwen2ForTokenClassification",
     "Qwen2ForQuestionAnswering",

src/transformers/models/qwen2/modular_qwen2.py

Lines changed: 35 additions & 0 deletions
@@ -1,10 +1,13 @@
 from typing import Callable, Optional
 
 import torch
+import torch.nn.functional as F
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 
 from ...cache_utils import Cache, DynamicCache
+from ...integrations import use_kernel_forward_from_hub
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
@@ -15,6 +18,7 @@
 from ...utils import TransformersKwargs, auto_docstring, logging
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import check_model_inputs
+from ...utils.import_utils import get_torch_version
 from ..llama.modeling_llama import (
     LlamaAttention,
     LlamaDecoderLayer,
@@ -97,6 +101,36 @@ def forward(
         return attn_output, attn_weights
 
 
+@use_kernel_forward_from_hub("RMSNorm")
+class Qwen2RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
+        """
+        Qwen2RMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+        self.has_rms_norm = version.parse(get_torch_version()) >= version.parse("2.3.0")
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if self.has_rms_norm:
+            return F.rms_norm(
+                input=hidden_states,
+                normalized_shape=[hidden_states.shape[-1]],
+                weight=self.weight,
+                eps=self.variance_epsilon,
+            )
+
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
 class Qwen2DecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: Qwen2Config, layer_idx: int):
         super().__init__()
@@ -206,6 +240,7 @@ class Qwen2ForQuestionAnswering(LlamaForQuestionAnswering):
     "Qwen2PreTrainedModel",
     "Qwen2Model",
     "Qwen2ForCausalLM",
+    "Qwen2RMSNorm",
     "Qwen2ForSequenceClassification",
     "Qwen2ForTokenClassification",
     "Qwen2ForQuestionAnswering",

src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py

Lines changed: 1 addition & 20 deletions
@@ -43,6 +43,7 @@
 from ...utils import TransformersKwargs, auto_docstring, check_torch_load_is_safe, logging
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.hub import cached_file
+from ..qwen2.modeling_qwen2 import Qwen2RMSNorm
 from .configuration_qwen2_5_omni import (
     Qwen2_5OmniAudioEncoderConfig,
     Qwen2_5OmniBigVGANConfig,
@@ -986,26 +987,6 @@ def forward(self, hidden_state):
         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 
-class Qwen2RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        Qwen2RMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer):
     def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None:
         super().__init__()

src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

Lines changed: 1 addition & 20 deletions
@@ -43,6 +43,7 @@
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
 from ...utils.deprecation import deprecate_kwarg
+from ..qwen2.modeling_qwen2 import Qwen2RMSNorm
 from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig
 
 
@@ -103,26 +104,6 @@ def forward(self, seqlen: int) -> torch.Tensor:
         return freqs
 
 
-class Qwen2RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        Qwen2RMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class Qwen2_5_VLPatchMerger(nn.Module):
     def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
         super().__init__()

src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

Lines changed: 3 additions & 21 deletions
@@ -46,6 +46,9 @@
     logging,
 )
 from ...utils.deprecation import deprecate_kwarg
+from ..qwen2.modeling_qwen2 import (
+    Qwen2RMSNorm,
+)
 from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig
 
 
@@ -441,27 +444,6 @@ def forward(
         return hidden_states
 
 
-# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
-class Qwen2RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        Qwen2RMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
 class Qwen2MLP(nn.Module):
     def __init__(self, config):

src/transformers/models/qwen3/modeling_qwen3.py

Lines changed: 14 additions & 2 deletions
@@ -22,6 +22,8 @@
 from typing import Callable, Optional, Union
 
 import torch
+import torch.nn.functional as F
+from packaging import version
 from torch import nn
 
 from ...activations import ACT2FN
@@ -43,20 +45,30 @@
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import check_model_inputs
+from ...utils.import_utils import get_torch_version
 from .configuration_qwen3 import Qwen3Config
 
 
 @use_kernel_forward_from_hub("RMSNorm")
 class Qwen3RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
         """
         Qwen3RMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.has_rms_norm = version.parse(get_torch_version()) >= version.parse("2.3.0")
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if self.has_rms_norm:
+            return F.rms_norm(
+                input=hidden_states,
+                normalized_shape=[hidden_states.shape[-1]],
+                weight=self.weight,
+                eps=self.variance_epsilon,
+            )
 
-    def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
