Skip to content

Commit a96e65d

Browse files
Disable omnigen2 fp16 on older pytorch versions. (comfyanonymous#8672)
1 parent 93a49a4 commit a96e65d

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

comfy/model_management.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,6 +1290,13 @@ def supports_fp8_compute(device=None):
12901290

12911291
return True
12921292

1293+
def extended_fp16_support():
    """Return True when this torch build is new enough for extended fp16 inference.

    Gate for models (e.g. Omnigen2) that only run correctly in fp16 on
    torch >= 2.7; on older builds callers should drop float16 from their
    supported dtypes.
    """
    # TODO: check why some models work with fp16 on newer torch versions but not on older
    return not torch_version_numeric < (2, 7)
1299+
12931300
def soft_empty_cache(force=False):
12941301
global cpu_state
12951302
if cpu_state == CPUState.MPS:

comfy/supported_models.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1197,11 +1197,16 @@ class Omnigen2(supported_models_base.BASE):
11971197
unet_extra_config = {}
11981198
latent_format = latent_formats.Flux
11991199

1200-
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
1200+
supported_inference_dtypes = [torch.bfloat16, torch.float32]
12011201

12021202
vae_key_prefix = ["vae."]
12031203
text_encoder_key_prefix = ["text_encoders."]
12041204

1205+
def __init__(self, unet_config):
    """Initialize the Omnigen2 config; opt into fp16 inference only on torch builds that support it."""
    super().__init__(unet_config)
    if comfy.model_management.extended_fp16_support():
        # Rebind to a fresh instance-level list (prepending fp16 as the preferred
        # dtype) rather than mutating the shared class-level attribute in place.
        self.supported_inference_dtypes = [torch.float16, *self.supported_inference_dtypes]
1209+
12051210
def get_model(self, state_dict, prefix="", device=None):
    """Build and return the Omnigen2 base model for this config on *device*.

    *state_dict* and *prefix* are accepted for interface compatibility but
    are not consumed here.
    """
    return model_base.Omnigen2(self, device=device)

0 commit comments

Comments
 (0)