
Commit 13cca92

Unify Qwen2RMSNorm definitions and use RMSNorm from PyTorch
Signed-off-by: cyy <cyyever@outlook.com>
1 parent f4d57f2 commit 13cca92

7 files changed, 76 additions and 61 deletions

src/transformers/models/dots1/modeling_dots1.py

Lines changed: 11 additions & 0 deletions
@@ -22,6 +22,7 @@
 
 import torch
 import torch.nn.functional as F
+from packaging import version
 from torch import nn
 
 from ...activations import ACT2FN
@@ -38,6 +39,7 @@
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import check_model_inputs
+from ...utils.import_utils import get_torch_version
 from .configuration_dots1 import Dots1Config
 
 
@@ -50,8 +52,17 @@ def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.has_rms_norm = version.parse(get_torch_version()) >= version.parse("2.3.0")
 
     def forward(self, hidden_states):
+        if self.has_rms_norm:
+            return F.rms_norm(
+                hidden_states.to(torch.float32, non_blocking=True),
+                [hidden_states.shape[-1]],
+                self.weight,
+                self.variance_epsilon,
+            )
+
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
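For context, the fast path added above relies on torch.nn.functional.rms_norm being available (the diff gates it on torch >= 2.3.0 via get_torch_version). Below is a minimal sketch, not part of the commit, that checks the fused call against the manual fallback when both stay in float32; note that the fallback in the model additionally casts its result back to the input dtype before returning.

import torch
import torch.nn.functional as F

hidden_states = torch.randn(2, 4, 64, dtype=torch.bfloat16)
weight = torch.ones(64)
eps = 1e-6

# Fast path taken by the new forward(): fused RMSNorm on a float32 copy.
fused = F.rms_norm(hidden_states.to(torch.float32), [hidden_states.shape[-1]], weight, eps)

# Fallback path kept for older PyTorch: manual RMS normalization in float32.
x = hidden_states.to(torch.float32)
variance = x.pow(2).mean(-1, keepdim=True)
manual = weight * (x * torch.rsqrt(variance + eps))

# The two paths should agree numerically.
torch.testing.assert_close(fused, manual)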

src/transformers/models/qwen2/modeling_qwen2.py

Lines changed: 13 additions & 0 deletions
@@ -7,6 +7,8 @@
 from typing import Callable, Optional, Union
 
 import torch
+import torch.nn.functional as F
+from packaging import version
 from torch import nn
 
 from ...activations import ACT2FN
@@ -28,6 +30,7 @@
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import check_model_inputs
+from ...utils.import_utils import get_torch_version
 from .configuration_qwen2 import Qwen2Config
 
 
@@ -192,8 +195,17 @@ def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.has_rms_norm = version.parse(get_torch_version()) >= version.parse("2.3.0")
 
     def forward(self, hidden_states):
+        if self.has_rms_norm:
+            return F.rms_norm(
+                hidden_states.to(torch.float32, non_blocking=True),
+                [hidden_states.shape[-1]],
+                self.weight,
+                self.variance_epsilon,
+            )
+
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
@@ -497,6 +509,7 @@ class Qwen2ForQuestionAnswering(GenericForQuestionAnswering, Qwen2PreTrainedMode
     "Qwen2PreTrainedModel",
     "Qwen2Model",
     "Qwen2ForCausalLM",
+    "Qwen2RMSNorm",
     "Qwen2ForSequenceClassification",
     "Qwen2ForTokenClassification",
     "Qwen2ForQuestionAnswering",

src/transformers/models/qwen2/modular_qwen2.py

Lines changed: 35 additions & 0 deletions
@@ -1,10 +1,13 @@
 from typing import Callable, Optional
 
 import torch
+import torch.nn.functional as F
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 
 from ...cache_utils import Cache, DynamicCache
+from ...integrations import use_kernel_forward_from_hub
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
@@ -15,6 +18,7 @@
 from ...utils import TransformersKwargs, auto_docstring, logging
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import check_model_inputs
+from ...utils.import_utils import get_torch_version
 from ..llama.modeling_llama import (
     LlamaAttention,
     LlamaDecoderLayer,
@@ -96,6 +100,36 @@ def forward(
         attn_output = self.o_proj(attn_output)
         return attn_output, attn_weights
 
+@use_kernel_forward_from_hub("RMSNorm")
+class Qwen2RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        Qwen2RMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+        self.has_rms_norm = version.parse(get_torch_version()) >= version.parse("2.3.0")
+
+    def forward(self, hidden_states):
+        if self.has_rms_norm:
+            return F.rms_norm(
+                hidden_states.to(torch.float32, non_blocking=True),
+                [hidden_states.shape[-1]],
+                self.weight,
+                self.variance_epsilon,
+            )
+
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
 
 class Qwen2DecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: Qwen2Config, layer_idx: int):
@@ -206,6 +240,7 @@ class Qwen2ForQuestionAnswering(LlamaForQuestionAnswering):
     "Qwen2PreTrainedModel",
     "Qwen2Model",
     "Qwen2ForCausalLM",
+    "Qwen2RMSNorm",
     "Qwen2ForSequenceClassification",
     "Qwen2ForTokenClassification",
     "Qwen2ForQuestionAnswering",

src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py

Lines changed: 1 addition & 20 deletions
@@ -43,6 +43,7 @@
 from ...utils import TransformersKwargs, auto_docstring, check_torch_load_is_safe, logging
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.hub import cached_file
+from ..qwen2.modeling_qwen2 import Qwen2RMSNorm
 from .configuration_qwen2_5_omni import (
     Qwen2_5OmniAudioEncoderConfig,
     Qwen2_5OmniBigVGANConfig,
@@ -986,26 +987,6 @@ def forward(self, hidden_state):
         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 
-class Qwen2RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        Qwen2RMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer):
     def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None:
         super().__init__()

src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

Lines changed: 1 addition & 20 deletions
@@ -43,6 +43,7 @@
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
 from ...utils.deprecation import deprecate_kwarg
+from ..qwen2.modeling_qwen2 import Qwen2RMSNorm
 from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig
 
 
@@ -103,26 +104,6 @@ def forward(self, seqlen: int) -> torch.Tensor:
         return freqs
 
 
-class Qwen2RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        Qwen2RMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class Qwen2_5_VLPatchMerger(nn.Module):
     def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
         super().__init__()

src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

Lines changed: 3 additions & 21 deletions
@@ -46,6 +46,9 @@
     logging,
 )
 from ...utils.deprecation import deprecate_kwarg
+from ..qwen2.modeling_qwen2 import (
+    Qwen2RMSNorm,
+)
 from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig
 
 
@@ -441,27 +444,6 @@ def forward(
         return hidden_states
 
 
-# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
-class Qwen2RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        Qwen2RMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
 class Qwen2MLP(nn.Module):
     def __init__(self, config):

src/transformers/models/qwen3/modeling_qwen3.py

Lines changed: 12 additions & 0 deletions
@@ -22,6 +22,8 @@
 from typing import Callable, Optional, Union
 
 import torch
+import torch.nn.functional as F
+from packaging import version
 from torch import nn
 
 from ...activations import ACT2FN
@@ -43,6 +45,7 @@
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import check_model_inputs
+from ...utils.import_utils import get_torch_version
 from .configuration_qwen3 import Qwen3Config
 
 
@@ -55,8 +58,17 @@ def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.has_rms_norm = version.parse(get_torch_version()) >= version.parse("2.3.0")
 
     def forward(self, hidden_states):
+        if self.has_rms_norm:
+            return F.rms_norm(
+                hidden_states.to(torch.float32, non_blocking=True),
+                [hidden_states.shape[-1]],
+                self.weight,
+                self.variance_epsilon,
+            )
+
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
