@@ -412,9 +412,13 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
             ds.pop(0)
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
-        cur_order = min(i + 1, order)
-        coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
-        x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
+        if sigmas[i + 1] == 0:
+            # Denoising step
+            x = denoised
+        else:
+            cur_order = min(i + 1, order)
+            coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
+            x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
     return x
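
Every hunk in this change applies the same pattern: when the next sigma in the schedule is zero, the higher-order update is skipped and the model's denoised prediction is taken directly (the `# Denoising step` branches). A minimal sketch of how that relates to a plain first-order step, assuming the to_d convention used in this file, d = (x - denoised) / sigma; `euler_step` below is a hypothetical helper for illustration only, not part of the patch:

    def euler_step(x, denoised, sigma, sigma_next):
        d = (x - denoised) / sigma  # to_d convention: slope toward the denoised prediction
        return x + d * (sigma_next - sigma)

For sigma_next == 0 this returns x - (x - denoised) == denoised, so `x = denoised` is exactly the first-order step to sigma 0; the patch uses it on the final step in place of the multistep/LMS extrapolation.
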
@@ -1067,7 +1071,9 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
         d_cur = (x_cur - denoised) / t_cur
 
         order = min(max_order, i + 1)
-        if order == 1:      # First Euler step.
+        if t_next == 0:     # Denoising step
+            x_next = denoised
+        elif order == 1:    # First Euler step.
             x_next = x_cur + (t_next - t_cur) * d_cur
         elif order == 2:    # Use one history point.
             x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
@@ -1085,6 +1091,7 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None,
 
     return x_next
 
+
 #From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
 #under Apache 2 license
 def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
@@ -1108,7 +1115,9 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
         d_cur = (x_cur - denoised) / t_cur
 
         order = min(max_order, i + 1)
-        if order == 1:      # First Euler step.
+        if t_next == 0:     # Denoising step
+            x_next = denoised
+        elif order == 1:    # First Euler step.
             x_next = x_cur + (t_next - t_cur) * d_cur
         elif order == 2:    # Use one history point.
             h_n = (t_next - t_cur)
@@ -1148,6 +1157,7 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non
 
     return x_next
 
+
 #From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
 #under Apache 2 license
 @torch.no_grad()
@@ -1198,6 +1208,7 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,
 
     return x_next
 
+
 @torch.no_grad()
 def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
     extra_args = {} if extra_args is None else extra_args
@@ -1404,6 +1415,7 @@ def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=N
 def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
 
+
 @torch.no_grad()
 def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
     """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
@@ -1430,19 +1442,19 @@ def post_cfg_function(args):
         if callback is not None:
             callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
         dt = sigmas[i + 1] - sigmas[i]
-        if i == 0:
+        if sigmas[i + 1] == 0:
+            # Denoising step
+            x = denoised
+        else:
             # Euler method
             if cfg_pp:
                 x = denoised + d * sigmas[i + 1]
             else:
                 x = x + d * dt
-        else:
-            # Gradient estimation
-            if cfg_pp:
+
+            if i >= 1:
+                # Gradient estimation
                 d_bar = (ge_gamma - 1) * (d - old_d)
-                x = denoised + d * sigmas[i + 1] + d_bar * dt
-            else:
-                d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
                 x = x + d_bar * dt
         old_d = d
     return x
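
The gradient-estimation hunk also restructures the update so the Euler step is always applied and, from the second step on, the correction d_bar = (ge_gamma - 1) * (d - old_d) is added on top. For the non-cfg_pp path this matches the removed d_bar = ge_gamma * d + (1 - ge_gamma) * old_d form; a quick check (illustration only; `old_update` / `new_update` are hypothetical names, not functions in this file):

    import torch

    def old_update(x, d, old_d, dt, ge_gamma):
        # removed form: blended direction applied in one step
        d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
        return x + d_bar * dt

    def new_update(x, d, old_d, dt, ge_gamma):
        # new form: Euler step first, then the i >= 1 correction
        x = x + d * dt
        d_bar = (ge_gamma - 1) * (d - old_d)
        return x + d_bar * dt

    x, d, old_d = torch.randn(4), torch.randn(4), torch.randn(4)
    assert torch.allclose(old_update(x, d, old_d, -0.3, 2.0), new_update(x, d, old_d, -0.3, 2.0), atol=1e-6)

The cfg_pp path reduces the same way, since its Euler step already sets x = denoised + d * sigmas[i + 1] before the correction is added.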