Skip to content

Commit f41f323

Browse files
authored
Add the denoising step to several samplers (comfyanonymous#8780)
1 parent f74fc4d commit f41f323

File tree

1 file changed

+24
-12
lines changed

1 file changed

+24
-12
lines changed

comfy/k_diffusion/sampling.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -412,9 +412,13 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
412412
ds.pop(0)
413413
if callback is not None:
414414
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
415-
cur_order = min(i + 1, order)
416-
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
417-
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
415+
if sigmas[i + 1] == 0:
416+
# Denoising step
417+
x = denoised
418+
else:
419+
cur_order = min(i + 1, order)
420+
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
421+
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
418422
return x
419423

420424

@@ -1067,7 +1071,9 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
10671071
d_cur = (x_cur - denoised) / t_cur
10681072

10691073
order = min(max_order, i+1)
1070-
if order == 1: # First Euler step.
1074+
if t_next == 0: # Denoising step
1075+
x_next = denoised
1076+
elif order == 1: # First Euler step.
10711077
x_next = x_cur + (t_next - t_cur) * d_cur
10721078
elif order == 2: # Use one history point.
10731079
x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
@@ -1085,6 +1091,7 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
10851091

10861092
return x_next
10871093

1094+
10881095
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
10891096
#under Apache 2 license
10901097
def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
@@ -1108,7 +1115,9 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
11081115
d_cur = (x_cur - denoised) / t_cur
11091116

11101117
order = min(max_order, i+1)
1111-
if order == 1: # First Euler step.
1118+
if t_next == 0: # Denoising step
1119+
x_next = denoised
1120+
elif order == 1: # First Euler step.
11121121
x_next = x_cur + (t_next - t_cur) * d_cur
11131122
elif order == 2: # Use one history point.
11141123
h_n = (t_next - t_cur)
@@ -1148,6 +1157,7 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
11481157

11491158
return x_next
11501159

1160+
11511161
#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
11521162
#under Apache 2 license
11531163
@torch.no_grad()
@@ -1198,6 +1208,7 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
11981208

11991209
return x_next
12001210

1211+
12011212
@torch.no_grad()
12021213
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
12031214
extra_args = {} if extra_args is None else extra_args
@@ -1404,6 +1415,7 @@ def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=N
14041415
def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
14051416
return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
14061417

1418+
14071419
@torch.no_grad()
14081420
def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
14091421
"""Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
@@ -1430,19 +1442,19 @@ def post_cfg_function(args):
14301442
if callback is not None:
14311443
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
14321444
dt = sigmas[i + 1] - sigmas[i]
1433-
if i == 0:
1445+
if sigmas[i + 1] == 0:
1446+
# Denoising step
1447+
x = denoised
1448+
else:
14341449
# Euler method
14351450
if cfg_pp:
14361451
x = denoised + d * sigmas[i + 1]
14371452
else:
14381453
x = x + d * dt
1439-
else:
1440-
# Gradient estimation
1441-
if cfg_pp:
1454+
1455+
if i >= 1:
1456+
# Gradient estimation
14421457
d_bar = (ge_gamma - 1) * (d - old_d)
1443-
x = denoised + d * sigmas[i + 1] + d_bar * dt
1444-
else:
1445-
d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
14461458
x = x + d_bar * dt
14471459
old_d = d
14481460
return x

0 commit comments

Comments
 (0)