From 80a44b97f5cbcb890896e2b9e65d177f1ac6a588 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 4 May 2025 03:39:23 -0700 Subject: [PATCH] Change lumina to native RMSNorm. (#7935) --- comfy/ldm/lumina/model.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index ccd5d2c0ea0..f8dc4d7db6f 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -8,7 +8,7 @@ import torch.nn.functional as F import comfy.ldm.common_dit -from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm +from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND @@ -64,8 +64,8 @@ def __init__( ) if qk_norm: - self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings) - self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings) + self.q_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.k_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) else: self.q_norm = self.k_norm = nn.Identity() @@ -242,11 +242,11 @@ def __init__( operation_settings=operation_settings, ) self.layer_id = layer_id - self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings) - self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings) + self.attention_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.ffn_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) - self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings) - self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings) + self.attention_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.ffn_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.modulation = modulation if modulation: @@ -431,7 +431,7 @@ def __init__( self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings) self.cap_embedder = nn.Sequential( - RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings), + operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").Linear( cap_feat_dim, dim, @@ -457,7 +457,7 @@ def __init__( for layer_id in range(n_layers) ] ) - self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings) + self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings) assert (dim // n_heads) == sum(axes_dims)