From 63a7e8edba76b30e3c01190345126ae75c94777d Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 3 Aug 2024 11:53:30 -0400
Subject: [PATCH 1/8] More aggressive batch splitting.

---
 comfy/samplers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/samplers.py b/comfy/samplers.py
index 3f763381412..ce4371d5073 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -171,7 +171,7 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options):
         for i in range(1, len(to_batch_temp) + 1):
             batch_amount = to_batch_temp[:len(to_batch_temp)//i]
             input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
-            if model.memory_required(input_shape) < free_memory:
+            if model.memory_required(input_shape) * 1.5 < free_memory:
                 to_batch = batch_amount
                 break

From f123328b826dcd122d307b75288f89ea301fa25b Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 3 Aug 2024 12:39:33 -0400
Subject: [PATCH 2/8] Load T5 in fp8 if it's in fp8 in the Flux checkpoint.

---
 comfy/supported_models.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 681ef95c9e0..94fdcc0d2ad 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -652,7 +652,11 @@ def get_model(self, state_dict, prefix="", device=None):
         return out

     def clip_target(self, state_dict={}):
-        return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.FluxClipModel)
+        pref = self.text_encoder_key_prefix[0]
+        t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
+        if t5_key in state_dict:
+            dtype_t5 = state_dict[t5_key].dtype
+        return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))

 class FluxSchnell(Flux):
     unet_config = {
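Note on patches 1-2: the sampler now keeps a candidate batch only when its estimated memory cost fits free VRAM with 50% headroom (so large batches get split sooner), and the Flux loader reads the dtype of the stored t5xxl weights so an fp8 T5 stays in fp8. Below is a minimal standalone sketch of the batch-selection loop, assuming a hypothetical memory_required() estimator and a free_memory value in bytes (not the real ComfyUI interfaces):

def pick_batch(to_batch_temp, first_shape, memory_required, free_memory, headroom=1.5):
    # Fall back to a single conditioning if nothing larger fits.
    to_batch = to_batch_temp[:1]
    for i in range(1, len(to_batch_temp) + 1):
        # Try progressively smaller slices: all items, half, a third, ...
        batch_amount = to_batch_temp[:len(to_batch_temp) // i]
        input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
        # Keep the first slice whose estimated cost fits with 1.5x headroom.
        if memory_required(input_shape) * headroom < free_memory:
            to_batch = batch_amount
            break
    return to_batch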
From ba9095e5bd7914c2456b2dfe939c06180e97b1ad Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 3 Aug 2024 13:45:19 -0400
Subject: [PATCH 3/8] Automatically use fp8 for diffusion model weights if:

Checkpoint contains weights in fp8.

There isn't enough memory to load the diffusion model in GPU vram.
---
 comfy/model_base.py       |  1 +
 comfy/model_management.py | 22 ++++++++++++++++++++--
 comfy/sd.py               |  3 ++-
 comfy/utils.py            | 12 +++++++++++-
 4 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index ec15e9fcf5f..94f4d333c1c 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -94,6 +94,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
         self.concat_keys = ()
         logging.info("model_type {}".format(model_type.name))
         logging.debug("adm {}".format(self.adm_channels))
+        logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
         self.memory_usage_factor = model_config.memory_usage_factor

     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
diff --git a/comfy/model_management.py b/comfy/model_management.py
index da0b989a853..c0fb1509567 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -527,6 +527,9 @@ def unet_inital_load_device(parameters, dtype):
     else:
         return cpu_dev

+def maximum_vram_for_weights(device=None):
+    return (get_total_memory(device) * 0.8 - minimum_inference_memory())
+
 def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     if args.bf16_unet:
         return torch.bfloat16
@@ -536,6 +539,21 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         return torch.float8_e4m3fn
     if args.fp8_e5m2_unet:
         return torch.float8_e5m2
+
+    fp8_dtype = None
+    try:
+        for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            if dtype in supported_dtypes:
+                fp8_dtype = dtype
+                break
+    except:
+        pass
+
+    if fp8_dtype is not None:
+        free_model_memory = maximum_vram_for_weights(device)
+        if model_params * 2 > free_model_memory:
+            return fp8_dtype
+
     if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
         if torch.float16 in supported_dtypes:
             return torch.float16
@@ -871,7 +889,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
         fp16_works = True

     if fp16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        free_model_memory = maximum_vram_for_weights(device)
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True

@@ -920,7 +938,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
         bf16_works = torch.cuda.is_bf16_supported()

     if bf16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        free_model_memory = maximum_vram_for_weights(device)
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True

diff --git a/comfy/sd.py b/comfy/sd.py
index 41ce18c803f..bf336c8590d 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -510,13 +510,14 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o

     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
     parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
+    weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
     load_device = model_management.get_torch_device()

     model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
     if model_config is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))

-    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=[weight_dtype] + model_config.supported_inference_dtypes)
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

diff --git a/comfy/utils.py b/comfy/utils.py
index 0db9fbb6267..d9fe36f91b2 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -40,9 +40,19 @@ def calculate_parameters(sd, prefix=""):
     params = 0
     for k in sd.keys():
         if k.startswith(prefix):
-            params += sd[k].nelement()
+            w = sd[k]
+            params += w.nelement()
     return params

+def weight_dtype(sd, prefix=""):
+    dtypes = {}
+    for k in sd.keys():
+        if k.startswith(prefix):
+            w = sd[k]
+            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
+
+    return max(dtypes, key=dtypes.get)
+
 def state_dict_key_replace(state_dict, keys_to_replace):
     for x in keys_to_replace:
         if x in state_dict:
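Note on patch 3: unet_dtype() now falls back to an fp8 storage dtype when the checkpoint advertises one and the fp16 footprint of the weights (roughly model_params * 2 bytes) would not fit in the weight budget from maximum_vram_for_weights(), i.e. 80% of total VRAM minus the inference reserve. A hedged sketch of that decision follows; total_vram and inference_reserve are hypothetical example inputs rather than the real model_management helpers, and the final fallback is simplified:

import torch

def pick_unet_dtype(model_params, supported_dtypes, total_vram, inference_reserve):
    # Weight budget mirrors maximum_vram_for_weights() in the patch.
    budget = total_vram * 0.8 - inference_reserve
    fp8_dtype = None
    for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        if dtype in supported_dtypes:
            fp8_dtype = dtype
            break
    # fp16 needs ~2 bytes per parameter; if that overflows the budget and the
    # checkpoint already carries fp8 weights, keep them in fp8 (1 byte each).
    if fp8_dtype is not None and model_params * 2 > budget:
        return fp8_dtype
    return torch.float16 if torch.float16 in supported_dtypes else supported_dtypes[0]

# Example: a 12B-parameter model on a 16 GiB card with 1 GiB reserved.
print(pick_unet_dtype(12e9, [torch.float8_e4m3fn, torch.float16],
                      16 * 1024**3, 1024**3))  # -> torch.float8_e4m3fn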
From 1e68002b87a3fb70afc7030c1b4dc6a31fea965e Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 3 Aug 2024 14:50:20 -0400
Subject: [PATCH 4/8] Cap lowvram to half of free memory.

---
 comfy/model_management.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index c0fb1509567..2008229f2d0 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -450,7 +450,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required)))
+            lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required))
+            lowvram_model_memory = int(min(current_free_mem * 0.5, lowvram_model_memory))
             if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
                 lowvram_model_memory = 0
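Note on patch 4: the lowvram budget is first computed as free memory minus the caller's minimum requirement (floored at 64 MiB), then capped at half of the currently free VRAM. A worked example with hypothetical byte values:

current_free_mem = 8 * 1024**3          # 8 GiB reported free (example value)
minimum_memory_required = 1 * 1024**3   # headroom requested by the caller (example)

lowvram_model_memory = max(64 * (1024 * 1024), current_free_mem - minimum_memory_required)  # 7 GiB
lowvram_model_memory = int(min(current_free_mem * 0.5, lowvram_model_memory))               # capped at 4 GiB

print(lowvram_model_memory // 1024**2, "MiB left for partially loaded weights")  # 4096 MiB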
From 2ba5cc8b867bc1aabe59fdaf0a8489e65012d603 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 3 Aug 2024 15:06:40 -0400
Subject: [PATCH 5/8] Fix some issues.

---
 comfy/model_management.py | 3 +--
 comfy/sd.py               | 6 +++++-
 comfy/utils.py            | 3 +++
 3 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 2008229f2d0..bb4bcbb2196 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -450,8 +450,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required))
-            lowvram_model_memory = int(min(current_free_mem * 0.5, lowvram_model_memory))
+            lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), current_free_mem * 0.5)
             if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
                 lowvram_model_memory = 0
diff --git a/comfy/sd.py b/comfy/sd.py
index bf336c8590d..fac1a487fe3 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -517,7 +517,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if model_config is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))

-    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=[weight_dtype] + model_config.supported_inference_dtypes)
+    unet_weight_dtype = list(model_config.supported_inference_dtypes)
+    if weight_dtype is not None:
+        unet_weight_dtype.append(weight_dtype)
+
+    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
diff --git a/comfy/utils.py b/comfy/utils.py
index d9fe36f91b2..06e09170ac2 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -51,6 +51,9 @@ def weight_dtype(sd, prefix=""):
             w = sd[k]
             dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1

+    if len(dtypes) == 0:
+        return None
+
     return max(dtypes, key=dtypes.get)

 def state_dict_key_replace(state_dict, keys_to_replace):
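Note on patch 5: weight_dtype() picks the most common tensor dtype under a prefix and, with the new guard, returns None when the prefix matches nothing, which is why sd.py now only appends weight_dtype to the supported list when it is not None. A small self-contained sketch with hypothetical state-dict keys:

import torch

def weight_dtype(sd, prefix=""):
    dtypes = {}
    for k, w in sd.items():
        if k.startswith(prefix):
            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
    if len(dtypes) == 0:
        return None  # nothing under this prefix; the caller must handle it
    return max(dtypes, key=dtypes.get)

sd = {"model.diffusion_model.a": torch.zeros(1, dtype=torch.float16),
      "model.diffusion_model.b": torch.zeros(1, dtype=torch.float16),
      "model.diffusion_model.c": torch.zeros(1, dtype=torch.float32)}
print(weight_dtype(sd, "model.diffusion_model."))  # torch.float16
print(weight_dtype(sd, "missing.prefix."))         # None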
From 03c5018c98b9dd2654dc4942a0978ac53e755900 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 3 Aug 2024 15:14:07 -0400
Subject: [PATCH 6/8] Lower lowvram memory to 1/3 of free memory.

---
 comfy/model_management.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index bb4bcbb2196..c4402a8a7b7 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -450,7 +450,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), current_free_mem * 0.5)
+            lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), current_free_mem * 0.33)
             if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
                 lowvram_model_memory = 0

From 91be9c2867ef9ae5b255f038665649536c1e1b8b Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 3 Aug 2024 16:34:27 -0400
Subject: [PATCH 7/8] Tweak lowvram memory formula.

---
 comfy/model_management.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index c4402a8a7b7..b280b149d07 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -450,7 +450,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), current_free_mem * 0.33)
+            lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
             if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
                 lowvram_model_memory = 0

From f7a5107784cded39f92a4bb7553507575e78edbe Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 3 Aug 2024 16:55:38 -0400
Subject: [PATCH 8/8] Fix crash.

---
 comfy/model_management.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index b280b149d07..fb27470152c 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -928,10 +928,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if is_intel_xpu():
         return True

-    if device is None:
-        device = torch.device("cuda")
-
-    props = torch.cuda.get_device_properties(device)
+    props = torch.cuda.get_device_properties("cuda")
     if props.major >= 8:
         return True
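Note on patches 6-8: the lowvram budget settles on the largest of 64 MiB, free memory minus the caller's requirement, and a 40% slice of free memory that is itself capped so at least minimum_inference_memory() stays available (patch 7), while patch 8 avoids a crash by asking torch.cuda.get_device_properties("cuda") for the default device directly. A sketch of the final budget formula with hypothetical byte inputs:

def lowvram_budget(current_free_mem, minimum_memory_required, minimum_inference_memory):
    # Final formula from patch 7; all values are in bytes.
    return max(64 * (1024 * 1024),
               current_free_mem - minimum_memory_required,
               min(current_free_mem * 0.4,
                   current_free_mem - minimum_inference_memory))

# Example: 10 GiB free, the caller asks to keep 9 GiB, 1 GiB inference reserve.
GiB = 1024**3
print(lowvram_budget(10 * GiB, 9 * GiB, 1 * GiB) / GiB)  # 4.0 (the 40% slice wins)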