Change Qwen2RMSNorm to RMSNorm from PyTorch #40066

Open
wants to merge 2 commits into base: main

5 changes: 5 additions & 0 deletions docs/source/en/model_doc/qwen2.md
@@ -159,6 +159,11 @@ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
 
 [[autodoc]] Qwen2TokenizerFast
 
+## Qwen2RMSNorm
+
+[[autodoc]] Qwen2RMSNorm
+    - forward
+
 ## Qwen2Model
 
 [[autodoc]] Qwen2Model

4 changes: 2 additions & 2 deletions src/transformers/models/dots1/modeling_dots1.py
@@ -43,15 +43,15 @@
 
 @use_kernel_forward_from_hub("RMSNorm")
 class Dots1RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
         """
         Dots1RMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)

5 changes: 3 additions & 2 deletions src/transformers/models/qwen2/modeling_qwen2.py
@@ -185,15 +185,15 @@ def forward(
 
 @use_kernel_forward_from_hub("RMSNorm")
 class Qwen2RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
         """
         Qwen2RMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
@@ -497,6 +497,7 @@ class Qwen2ForQuestionAnswering(GenericForQuestionAnswering, Qwen2PreTrainedMode
     "Qwen2PreTrainedModel",
     "Qwen2Model",
     "Qwen2ForCausalLM",
+    "Qwen2RMSNorm",
     "Qwen2ForSequenceClassification",
     "Qwen2ForTokenClassification",
     "Qwen2ForQuestionAnswering",

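With Qwen2RMSNorm added to the module's __all__, the class becomes part of the documented public surface. A minimal usage sketch (illustrative only: the hidden size and tensor shapes are made up, and it assumes a transformers install laid out as in this diff):

import torch
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm

hidden_size = 64                                 # illustrative size, not tied to any checkpoint
norm = Qwen2RMSNorm(hidden_size, eps=1e-6)

hidden_states = torch.randn(2, 8, hidden_size)   # (batch, seq_len, hidden_size)
normalized = norm(hidden_states)

print(normalized.shape)  # torch.Size([2, 8, 64]); the shape is preserved
print(norm)              # the module repr reports the weight shape and eps via extra_repr
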
33 changes: 33 additions & 0 deletions src/transformers/models/qwen2/modular_qwen2.py
@@ -2,9 +2,11 @@
 
 import torch
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 
 from ...cache_utils import Cache, DynamicCache
+from ...integrations import use_kernel_forward_from_hub
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
@@ -15,6 +17,7 @@
 from ...utils import TransformersKwargs, auto_docstring, logging
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import check_model_inputs
+from ...utils.import_utils import get_torch_version
 from ..llama.modeling_llama import (
     LlamaAttention,
     LlamaDecoderLayer,
@@ -97,6 +100,35 @@ def forward(
         return attn_output, attn_weights
 
 
+if version.parse(get_torch_version()) >= version.parse("2.4.0"):
+
+    class Qwen2RMSNorm(nn.RMSNorm):
+        def __init__(self, hidden_size, eps: float = 1e-6) -> None:
+            super().__init__(normalized_shape=hidden_size, eps=eps, elementwise_affine=True)
+
+else:
+
+    @use_kernel_forward_from_hub("RMSNorm")
+    class Qwen2RMSNorm(nn.Module):
+        def __init__(self, hidden_size, eps: float = 1e-6) -> None:
+            """
+            Qwen2RMSNorm is equivalent to T5LayerNorm
+            """
+            super().__init__()
+            self.weight = nn.Parameter(torch.ones(hidden_size))
+            self.variance_epsilon = eps
+
+        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+            input_dtype = hidden_states.dtype
+            hidden_states = hidden_states.to(torch.float32)
+            variance = hidden_states.pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+            return self.weight * hidden_states.to(input_dtype)
+
+        def extra_repr(self):
+            return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
 class Qwen2DecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: Qwen2Config, layer_idx: int):
         super().__init__()
@@ -206,6 +238,7 @@ class Qwen2ForQuestionAnswering(LlamaForQuestionAnswering):
     "Qwen2PreTrainedModel",
     "Qwen2Model",
     "Qwen2ForCausalLM",
+    "Qwen2RMSNorm",
     "Qwen2ForSequenceClassification",
     "Qwen2ForTokenClassification",
     "Qwen2ForQuestionAnswering",

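As a sanity check for this change, here is a hedged, self-contained sketch comparing the hand-written fallback against torch.nn.RMSNorm, which ships with PyTorch 2.4 and later. ManualRMSNorm simply restates the fallback implementation above, and the sizes are arbitrary:

import torch
from torch import nn


class ManualRMSNorm(nn.Module):
    # Mirrors the fallback Qwen2RMSNorm defined above.
    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


hidden_size, eps = 128, 1e-6
x = torch.randn(2, 5, hidden_size)

manual = ManualRMSNorm(hidden_size, eps)
builtin = nn.RMSNorm(normalized_shape=hidden_size, eps=eps, elementwise_affine=True)

# Both compute x / sqrt(mean(x^2) + eps) scaled by a learned weight,
# so their outputs should agree to within floating-point tolerance.
torch.testing.assert_close(manual(x), builtin(x), rtol=1e-5, atol=1e-5)
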
21 changes: 1 addition & 20 deletions src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -43,6 +43,7 @@
 from ...utils import TransformersKwargs, auto_docstring, check_torch_load_is_safe, logging
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.hub import cached_file
+from ..qwen2.modeling_qwen2 import Qwen2RMSNorm
 from .configuration_qwen2_5_omni import (
     Qwen2_5OmniAudioEncoderConfig,
     Qwen2_5OmniBigVGANConfig,
@@ -986,26 +987,6 @@ def forward(self, hidden_state):
         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 
-class Qwen2RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        Qwen2RMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer):
     def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None:
         super().__init__()

21 changes: 1 addition & 20 deletions src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -43,6 +43,7 @@
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
 from ...utils.deprecation import deprecate_kwarg
+from ..qwen2.modeling_qwen2 import Qwen2RMSNorm
 from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig
 
 
@@ -103,26 +104,6 @@ def forward(self, seqlen: int) -> torch.Tensor:
         return freqs
 
 
-class Qwen2RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        Qwen2RMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class Qwen2_5_VLPatchMerger(nn.Module):
     def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
         super().__init__()

24 changes: 3 additions & 21 deletions src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -46,6 +46,9 @@
     logging,
 )
 from ...utils.deprecation import deprecate_kwarg
+from ..qwen2.modeling_qwen2 import (
+    Qwen2RMSNorm,
+)
 from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig
 
 
@@ -441,27 +444,6 @@ def forward(
         return hidden_states
 
 
-# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
-class Qwen2RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        Qwen2RMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
 class Qwen2MLP(nn.Module):
     def __init__(self, config):

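Because the vision-language and omni modeling files now import Qwen2RMSNorm instead of keeping local copies, the name resolves to a single shared class. A quick hypothetical check, assuming this branch is installed:

from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm as SharedNorm
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2RMSNorm as VLNorm

# With the duplicated definition removed, the re-exported name is the exact same object.
assert VLNorm is SharedNorm
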
4 changes: 2 additions & 2 deletions src/transformers/models/qwen3/modeling_qwen3.py
@@ -48,15 +48,15 @@
 
 @use_kernel_forward_from_hub("RMSNorm")
 class Qwen3RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
         """
         Qwen3RMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)