diff --git a/comfy/ldm/genmo/joint_model/asymm_models_joint.py b/comfy/ldm/genmo/joint_model/asymm_models_joint.py index 2c46c24bf11..366a8b7133c 100644 --- a/comfy/ldm/genmo/joint_model/asymm_models_joint.py +++ b/comfy/ldm/genmo/joint_model/asymm_models_joint.py @@ -13,7 +13,6 @@ from .layers import ( FeedForward, PatchEmbed, - RMSNorm, TimestepEmbedder, ) @@ -90,10 +89,10 @@ def __init__( # Query and key normalization for stability. assert qk_norm - self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype) - self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype) - self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype) - self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype) + self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype) + self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype) + self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype) + self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype) # Output layers. y features go back down from dim_x -> dim_y. self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype) diff --git a/comfy/ldm/genmo/joint_model/layers.py b/comfy/ldm/genmo/joint_model/layers.py index 51d979559ed..e310bd71783 100644 --- a/comfy/ldm/genmo/joint_model/layers.py +++ b/comfy/ldm/genmo/joint_model/layers.py @@ -151,14 +151,3 @@ def forward(self, x): x = self.norm(x) return x - - -class RMSNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): - super().__init__() - self.eps = eps - self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype)) - self.register_parameter("bias", None) - - def forward(self, x): - return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 66bee748016..fc5ff40c5c9 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -9,7 +9,6 @@ from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope -from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm import comfy.ldm.common_dit import comfy.model_management @@ -49,8 +48,8 @@ def __init__(self, self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) - self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() - self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() + self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() + self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() def forward(self, x, freqs): r""" @@ -114,7 +113,7 @@ def __init__(self, self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) # self.alpha = nn.Parameter(torch.zeros((1, ))) - self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() + self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity() def forward(self, x, context, context_img_len): r"""