Skip to content

[pull] master from comfyanonymous:master #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions comfy/latent_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,24 @@ class Flux(SD3):
def __init__(self):
self.scale_factor = 0.3611
self.shift_factor = 0.1159
self.latent_rgb_factors =[
[-0.0404, 0.0159, 0.0609],
[ 0.0043, 0.0298, 0.0850],
[ 0.0328, -0.0749, -0.0503],
[-0.0245, 0.0085, 0.0549],
[ 0.0966, 0.0894, 0.0530],
[ 0.0035, 0.0399, 0.0123],
[ 0.0583, 0.1184, 0.1262],
[-0.0191, -0.0206, -0.0306],
[-0.0324, 0.0055, 0.1001],
[ 0.0955, 0.0659, -0.0545],
[-0.0504, 0.0231, -0.0013],
[ 0.0500, -0.0008, -0.0088],
[ 0.0982, 0.0941, 0.0976],
[-0.1233, -0.0280, -0.0897],
[-0.0005, -0.0530, -0.0020],
[-0.1273, -0.0932, -0.0680]
]

def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor
Expand Down
14 changes: 10 additions & 4 deletions comfy/ldm/flux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,21 @@ def forward_orig(

def forward(self, x, timestep, context, y, guidance, **kwargs):
bs, c, h, w = x.shape
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
patch_size = 2
pad_h = (patch_size - h % 2) % patch_size
pad_w = (patch_size - w % 2) % patch_size

h_len = (h // 2)
w_len = (w // 2)
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')

img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
17 changes: 4 additions & 13 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
self.concat_keys = ()
logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor

def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
Expand Down Expand Up @@ -252,11 +253,11 @@ def memory_required(self, input_shape):
dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked
area = input_shape[0] * math.prod(input_shape[2:])
return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024)
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * math.prod(input_shape[2:])
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)


def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
Expand Down Expand Up @@ -354,6 +355,7 @@ def encode_adm(self, **kwargs):
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)


class SVD_img2vid(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device)
Expand Down Expand Up @@ -594,17 +596,6 @@ def extra_conds(self, **kwargs):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out

def memory_required(self, input_shape):
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this probably needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * comfy.model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
else:
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * 0.3) * (1024 * 1024)

class AuraFlow(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
Expand Down
15 changes: 11 additions & 4 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,11 +379,15 @@ def free_memory(memory_required, device, keep_loaded=[]):
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()

def load_models_gpu(models, memory_required=0, force_patch_weights=False):
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
global vram_state

inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required)
if minimum_memory_required is None:
minimum_memory_required = extra_mem
else:
minimum_memory_required = max(inference_memory, minimum_memory_required)

models = set(models)

Expand Down Expand Up @@ -446,8 +450,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - extra_mem)))
if model_size <= (current_free_mem - inference_memory): #only switch to lowvram if really necessary
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required)))
if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
lowvram_model_memory = 0

if vram_set_state == VRAMState.NO_VRAM:
Expand Down Expand Up @@ -897,7 +901,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if directml_enabled:
return False

if cpu_mode() or mps_mode():
if mps_mode():
return True

if cpu_mode():
return False

if is_intel_xpu():
Expand Down
4 changes: 3 additions & 1 deletion comfy/sampler_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def prepare_sampling(model, noise_shape, conds):
device = model.load_device
real_model = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
real_model = model.model

return real_model, conds, models
Expand Down
11 changes: 11 additions & 0 deletions comfy/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class SD15(supported_models_base.BASE):
}

latent_format = latent_formats.SD15
memory_usage_factor = 1.0

def process_clip_state_dict(self, state_dict):
k = list(state_dict.keys())
Expand Down Expand Up @@ -77,6 +78,7 @@ class SD20(supported_models_base.BASE):
}

latent_format = latent_formats.SD15
memory_usage_factor = 1.0

def model_type(self, state_dict, prefix=""):
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
Expand Down Expand Up @@ -140,6 +142,7 @@ class SDXLRefiner(supported_models_base.BASE):
}

latent_format = latent_formats.SDXL
memory_usage_factor = 1.0

def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXLRefiner(self, device=device)
Expand Down Expand Up @@ -178,6 +181,8 @@ class SDXL(supported_models_base.BASE):

latent_format = latent_formats.SDXL

memory_usage_factor = 0.7

def model_type(self, state_dict, prefix=""):
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
self.latent_format = latent_formats.SDXL_Playground_2_5()
Expand Down Expand Up @@ -505,6 +510,9 @@ class SD3(supported_models_base.BASE):

unet_extra_config = {}
latent_format = latent_formats.SD3

memory_usage_factor = 1.2

text_encoder_key_prefix = ["text_encoders."]

def get_model(self, state_dict, prefix="", device=None):
Expand Down Expand Up @@ -631,6 +639,9 @@ class Flux(supported_models_base.BASE):

unet_extra_config = {}
latent_format = latent_formats.Flux

memory_usage_factor = 2.6

supported_inference_dtypes = [torch.bfloat16, torch.float32]

vae_key_prefix = ["vae."]
Expand Down
2 changes: 2 additions & 0 deletions comfy/supported_models_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class BASE:
text_encoder_key_prefix = ["cond_stage_model."]
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

memory_usage_factor = 2.0

manual_cast_dtype = None

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion comfy/text_encoders/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def tokenize_with_weights(self, text:str, return_word_ids=False):
return out

def untokenize(self, token_weight_pair):
return self.clip_g.untokenize(token_weight_pair)
return self.clip_l.untokenize(token_weight_pair)

def state_dict(self):
return {}
Expand Down
1 change: 1 addition & 0 deletions comfy_extras/nodes_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import struct
import random
import hashlib
from comfy.cli_args import args

class EmptyLatentAudio:
Expand Down
47 changes: 47 additions & 0 deletions comfy_extras/nodes_flux.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import node_helpers

class CLIPTextEncodeFlux:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"

CATEGORY = "advanced/conditioning/flux"

def encode(self, clip, clip_l, t5xxl, guidance):
tokens = clip.tokenize(clip_l)
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]

output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
cond = output.pop("cond")
output["guidance"] = guidance
return ([[cond, output]], )

class FluxGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
}}

RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"

CATEGORY = "advanced/conditioning/flux"

def append(self, conditioning, guidance):
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
return (c, )


NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
"FluxGuidance": FluxGuidance,
}
28 changes: 28 additions & 0 deletions comfy_extras/nodes_model_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,33 @@ def INPUT_TYPES(s):
def patch_aura(self, model, shift):
return self.patch(model, shift, multiplier=1.0)

class ModelSamplingFlux:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}),
}}

RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "advanced/model"

def patch(self, model, shift):
m = model.clone()

sampling_base = comfy.model_sampling.ModelSamplingFlux
sampling_type = comfy.model_sampling.CONST

class ModelSamplingAdvanced(sampling_base, sampling_type):
pass

model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_parameters(shift=shift)
m.add_object_patch("model_sampling", model_sampling)
return (m, )


class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
Expand Down Expand Up @@ -284,5 +311,6 @@ def rescale_cfg(args):
"ModelSamplingStableCascade": ModelSamplingStableCascade,
"ModelSamplingSD3": ModelSamplingSD3,
"ModelSamplingAuraFlow": ModelSamplingAuraFlow,
"ModelSamplingFlux": ModelSamplingFlux,
"RescaleCFG": RescaleCFG,
}
27 changes: 27 additions & 0 deletions comfy_extras/nodes_model_merging_model_specific.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,36 @@ def INPUT_TYPES(s):

return {"required": arg_dict}

class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"

@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}

argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

arg_dict["img_in."] = argument
arg_dict["time_in."] = argument
arg_dict["guidance_in"] = argument
arg_dict["vector_in."] = argument
arg_dict["txt_in."] = argument

for i in range(19):
arg_dict["double_blocks.{}.".format(i)] = argument

for i in range(38):
arg_dict["single_blocks.{}.".format(i)] = argument

arg_dict["final_layer."] = argument

return {"required": arg_dict}

NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
"ModelMergeSDXL": ModelMergeSDXL,
"ModelMergeSD3_2B": ModelMergeSD3_2B,
"ModelMergeFlux1": ModelMergeFlux1,
}
2 changes: 1 addition & 1 deletion comfy_extras/nodes_pag.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}),
}
}

Expand Down
2 changes: 1 addition & 1 deletion comfy_extras/nodes_sag.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class SelfAttentionGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01}),
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
}}
RETURN_TYPES = ("MODEL",)
Expand Down
10 changes: 8 additions & 2 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,9 +826,14 @@ def INPUT_TYPES(s):
CATEGORY = "advanced/loaders"

def load_unet(self, unet_name, weight_dtype):
weight_dtype = {"default":None, "fp8_e4m3fn":torch.float8_e4m3fn, "fp8_e5m2":torch.float8_e4m3fn}[weight_dtype]
dtype = None
if weight_dtype == "fp8_e4m3fn":
dtype = torch.float8_e4m3fn
elif weight_dtype == "fp8_e5m2":
dtype = torch.float8_e5m2

unet_path = folder_paths.get_full_path("unet", unet_name)
model = comfy.sd.load_unet(unet_path, dtype=weight_dtype)
model = comfy.sd.load_unet(unet_path, dtype=dtype)
return (model,)

class CLIPLoader:
Expand Down Expand Up @@ -2043,6 +2048,7 @@ def init_builtin_extra_nodes():
"nodes_gits.py",
"nodes_controlnet.py",
"nodes_hunyuan.py",
"nodes_flux.py",
]

import_failed = []
Expand Down
Loading