Skip to content

Commit 30159a7

Browse files
authored
Save v pred zsnr metadata (comfyanonymous#7840)
1 parent cb9ac3d commit 30159a7

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

comfy/model_sampling.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,14 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps
111111
self.num_timesteps = int(timesteps)
112112
self.linear_start = linear_start
113113
self.linear_end = linear_end
114+
self.zsnr = zsnr
114115

115116
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
116117
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
117118
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
118119

119120
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
120-
if zsnr:
121+
if self.zsnr:
121122
sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
122123

123124
self.set_sigmas(sigmas)

comfy_extras/nodes_model_merging.py

+3
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,9 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
209209
metadata["modelspec.predict_key"] = "epsilon"
210210
elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
211211
metadata["modelspec.predict_key"] = "v"
212+
extra_keys["v_pred"] = torch.tensor([])
213+
if getattr(model_sampling, "zsnr", False):
214+
extra_keys["ztsnr"] = torch.tensor([])
212215

213216
if not args.disable_metadata:
214217
metadata["prompt"] = prompt_info

0 commit comments

Comments
 (0)