Skip to content

[pull] master from comfyanonymous:master #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions comfy/ldm/aura/mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
import comfy.ldm.common_dit

def modulate(x, shift, scale):
    """Apply a scale-and-shift modulation to ``x``.

    ``shift`` and ``scale`` each get a new axis inserted at index 1 before
    being combined with ``x``, so they broadcast along that axis of ``x``.
    (Presumably ``x`` is (batch, tokens, dim) and ``shift``/``scale`` are
    (batch, dim) -- confirm against the callers.)

    Returns ``x * (1 + scale) + shift`` using the expanded tensors.
    """
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset
Expand Down Expand Up @@ -407,10 +408,7 @@ def unpatchify(self, x, h, w):

def patchify(self, x):
B, C, H, W = x.size()
pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
pad_w = (self.patch_size - W % self.patch_size) % self.patch_size

x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
x = x.view(
B,
C,
Expand Down
8 changes: 8 additions & 0 deletions comfy/ldm/common_dit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import torch

def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
    """Pad the last two dimensions of ``img`` up to multiples of ``patch_size``.

    Padding is only appended after the existing data: at the end of
    dimension -2 and at the end of dimension -1. A dimension that is already
    a multiple of its patch size receives zero padding.

    Args:
        img: Tensor whose last two dimensions are padded.
        patch_size: Pair ``(size_dim_-2, size_dim_-1)`` giving the multiple
            each of the last two dimensions is padded up to.
        padding_mode: Forwarded as ``mode`` to ``torch.nn.functional.pad``.
            While tracing or scripting, ``"circular"`` is replaced by
            ``"reflect"``; every other mode is passed through unchanged.

    Returns:
        The padded tensor.
    """
    # `and` binds tighter than `or`, so the jit checks must be grouped
    # explicitly. Without the parentheses, scripting alone would force
    # "reflect" even when the caller asked for a non-circular mode.
    if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
        padding_mode = "reflect"
    pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
    pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
    # F.pad takes pads for the last dimension first: (left, right, top, bottom).
    return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
8 changes: 3 additions & 5 deletions comfy/ldm/flux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)

from einops import rearrange, repeat
import comfy.ldm.common_dit

@dataclass
class FluxParams:
Expand Down Expand Up @@ -42,7 +43,7 @@ def __init__(self, image_model=None, dtype=None, device=None, operations=None, *
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params
self.in_channels = params.in_channels
self.in_channels = params.in_channels * 2 * 2
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
Expand Down Expand Up @@ -125,10 +126,7 @@ def forward_orig(
def forward(self, x, timestep, context, y, guidance, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
pad_h = (patch_size - h % 2) % patch_size
pad_w = (patch_size - w % 2) % patch_size

x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

Expand Down
5 changes: 2 additions & 3 deletions comfy/ldm/modules/diffusionmodules/mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from einops import rearrange, repeat
from .util import timestep_embedding
import comfy.ops
import comfy.ldm.common_dit

def default(x, y):
if x is not None:
Expand Down Expand Up @@ -111,9 +112,7 @@ def forward(self, x):
# f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
# )
if self.dynamic_img_pad:
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode=self.padding_mode)
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
Expand Down
2 changes: 1 addition & 1 deletion comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
if comfy.model_management.force_channels_last():
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)

Expand All @@ -94,7 +95,6 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
self.concat_keys = ()
logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels))
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
self.memory_usage_factor = model_config.memory_usage_factor

def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion comfy/model_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def detect_unet_config(state_dict, key_prefix):
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
dit_config = {}
dit_config["image_model"] = "flux"
dit_config["in_channels"] = 64
dit_config["in_channels"] = 16
dit_config["vec_in_dim"] = 768
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
Expand Down
4 changes: 2 additions & 2 deletions comfy/text_encoders/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def reset_clip_options(self):

def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_l = token_weight_pairs["l"]
token_weight_pars_t5 = token_weight_pairs["t5xxl"]
token_weight_pairs_t5 = token_weight_pairs["t5xxl"]

t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return t5_out, l_pooled

Expand Down
4 changes: 2 additions & 2 deletions comfy/text_encoders/sd3_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def reset_clip_options(self):
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_l = token_weight_pairs["l"]
token_weight_pairs_g = token_weight_pairs["g"]
token_weight_pars_t5 = token_weight_pairs["t5xxl"]
token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
lg_out = None
pooled = None
out = None
Expand All @@ -108,7 +108,7 @@ def encode_token_weights(self, token_weight_pairs):
pooled = torch.cat((l_pooled, g_pooled), dim=-1)

if self.t5xxl is not None:
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
if lg_out is not None:
out = torch.cat([lg_out, t5_out], dim=-2)
else:
Expand Down
14 changes: 12 additions & 2 deletions comfy_extras/nodes_model_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import comfy.sd
import comfy.model_sampling
import comfy.latent_formats
import nodes
import torch

class LCM(comfy.model_sampling.EPS):
Expand Down Expand Up @@ -174,17 +175,26 @@ class ModelSamplingFlux:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}),
"max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01}),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
}}

RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "advanced/model"

def patch(self, model, shift):
def patch(self, model, max_shift, base_shift, width, height):
m = model.clone()

x1 = 256
x2 = 4096
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
shift = (width * height / (8 * 8 * 2 * 2)) * mm + b

sampling_base = comfy.model_sampling.ModelSamplingFlux
sampling_type = comfy.model_sampling.CONST

Expand Down
4 changes: 2 additions & 2 deletions comfy_extras/nodes_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def __init__(self):

@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
Expand Down