Commit 8086045
Implement TCD, avoid repeated allocation in DDIM, implement eta parameter for DDIM and TCD, minor comment clarification
1 parent bf353f1
4 files changed: +263 −44 lines changed

denoiser.hpp
Lines changed: 236 additions & 42 deletions
@@ -474,7 +474,8 @@ static void sample_k_diffusion(sample_method_t method,
                                ggml_context* work_ctx,
                                ggml_tensor* x,
                                std::vector<float> sigmas,
-                               std::shared_ptr<RNG> rng) {
+                               std::shared_ptr<RNG> rng,
+                               float eta) {
     size_t steps = sigmas.size() - 1;
     // sample_euler_ancestral
     switch (method) {
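With this change callers choose the DDIM/TCD stochasticity through `eta` instead of a hard-coded 0. A minimal, hypothetical call-site sketch follows; the `model` callable sits in the elided part of the signature between `method` and `work_ctx`, so its name, its exact type, and the `DDIM_TRAILING` case label are assumptions here, not taken from the diff:

    // Sketch of a call site (names assumed): eta = 0.0f keeps DDIM
    // fully deterministic; eta > 0 mixes fresh noise into each step.
    sample_k_diffusion(DDIM_TRAILING, model, work_ctx, x, sigmas, rng,
                       0.0f /* eta */);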
@@ -1014,6 +1015,8 @@ static void sample_k_diffusion(sample_method_t method,
             // structure hides from the denoiser), and the sigmas are
             // also needed to invert the behavior of CompVisDenoiser
             // (k-diffusion's LMSDiscreteScheduler)
+            float beta_start = 0.00085f;
+            float beta_end = 0.0120f;
             std::vector<double> alphas_cumprod;
             std::vector<double> compvis_sigmas;
 
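The two new constants are the usual Stable Diffusion scaled-linear schedule endpoints, previously inlined as magic numbers. In LaTeX, the quantities the schedule loop (continued in the next hunk) derives from them, with T = TIMESTEPS:

    \beta_i = \Bigl(\sqrt{\beta_{\mathrm{start}}}
              + \tfrac{i}{T-1}\bigl(\sqrt{\beta_{\mathrm{end}}}
              - \sqrt{\beta_{\mathrm{start}}}\bigr)\Bigr)^{2}, \qquad
    \bar{\alpha}_i = \prod_{k=0}^{i}(1-\beta_k), \qquad
    \sigma_i = \sqrt{\frac{1-\bar{\alpha}_i}{\bar{\alpha}_i}}

The last expression is the CompVis-style sigma stored in compvis_sigmas, used below to invert the CompVisDenoiser wrapping.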
@@ -1023,21 +1026,41 @@ static void sample_k_diffusion(sample_method_t method,
                 alphas_cumprod[i] =
                     (i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
                     (1.0f -
-                     std::pow(sqrtf(0.00085f) +
-                              (sqrtf(0.0120f) - sqrtf(0.00085f)) *
+                     std::pow(sqrtf(beta_start) +
+                              (sqrtf(beta_end) - sqrtf(beta_start)) *
                               ((float)i / (TIMESTEPS - 1)), 2));
                 compvis_sigmas[i] =
                     std::sqrt((1 - alphas_cumprod[i]) /
                               alphas_cumprod[i]);
             }
+
+            struct ggml_tensor* pred_original_sample =
+                ggml_dup_tensor(work_ctx, x);
+            struct ggml_tensor* variance_noise =
+                ggml_dup_tensor(work_ctx, x);
+
             for (int i = 0; i < steps; i++) {
                 // The "trailing" DDIM timestep, see S. Lin et al.,
                 // "Common Diffusion Noise Schedules and Sample Steps
                 // are Flawed", arXiv:2305.08891 [cs], p. 4, Table
-                // 2. Most variables below follow Diffusers naming.
+                // 2. Most variables below follow Diffusers naming
+                //
+                // Diffusers naming vs. J. Song et al., "Denoising
+                // Diffusion Implicit Models", arXiv:2010.02502, p. 5,
+                // (12) and p. 16, (16) (<variable name> -> <name in
+                // paper>):
+                //
+                // - pred_noise_t -> epsilon_theta^(t)(x_t)
+                // - pred_original_sample -> f_theta^(t)(x_t) or x_0
+                // - std_dev_t -> sigma_t (not the LMS sigma)
+                // - eta -> eta (set to 0 at the moment)
+                // - pred_sample_direction -> "direction pointing to
+                //   x_t"
+                // - pred_prev_sample -> "x_t-1"
                 int timestep =
                     roundf(TIMESTEPS -
                            i * ((float)TIMESTEPS / steps)) - 1;
+                // 1. get previous step value (=t-1)
                 int prev_timestep = timestep - TIMESTEPS / steps;
                 // The sigma here is chosen to cause the
                 // CompVisDenoiser to produce t = timestep
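For reference, the update these variables assemble is formula (12) of arXiv:2010.02502, in the paper's notation (its alpha_t is DDPM's cumulative product):

    x_{t-1} = \sqrt{\alpha_{t-1}}
              \underbrace{\left(\frac{x_t - \sqrt{1-\alpha_t}\,
              \epsilon_\theta^{(t)}(x_t)}{\sqrt{\alpha_t}}\right)}
              _{\text{pred\_original\_sample}}
            + \underbrace{\sqrt{1-\alpha_{t-1}-\sigma_t^{2}}\,
              \epsilon_\theta^{(t)}(x_t)}
              _{\text{pred\_sample\_direction}}
            + \underbrace{\sigma_t\, z_t}_{\text{variance noise}}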
@@ -1066,51 +1089,53 @@ static void sample_k_diffusion(sample_method_t method,
                 }
                 else {
                     // For the subsequent steps after the first one,
-                    // at this point x = latents (pipeline) or x =
-                    // sample (scheduler), and needs to be prescaled
-                    // with x <- latents / c_in to compensate for
-                    // model() applying the scale c_in before the
-                    // U-net F_theta
+                    // at this point x = latents or x = sample, and
+                    // needs to be prescaled with x <- sample / c_in
+                    // to compensate for model() applying the scale
+                    // c_in before the U-net F_theta
                     float* vec_x = (float*)x->data;
                     for (int j = 0; j < ggml_nelements(x); j++) {
                         vec_x[j] *= std::sqrt(sigma * sigma + 1);
                     }
                 }
-                // Note model() is the D(x, sigma) as defined in
-                // T. Karras et al., arXiv:2206.00364, p. 3, Table 1
-                // and p. 8 (7)
-                struct ggml_tensor* noise_pred =
+                // Note (also noise_pred in Diffusers' pipeline)
+                // model_output = model() is the D(x, sigma) as
+                // defined in T. Karras et al., arXiv:2206.00364,
+                // p. 3, Table 1 and p. 8 (7), compare also p. 38
+                // (226) therein.
+                struct ggml_tensor* model_output =
                     model(x, sigma, i + 1);
-                // Here noise_pred is still the k-diffusion denoiser
+                // Here model_output is still the k-diffusion denoiser
                 // output, not the U-net output F_theta(c_in(sigma) x;
                 // ...) in Karras et al. (2022), whereas Diffusers'
-                // noise_pred is F_theta(...). Recover the actual
-                // noise_pred, which is also referred to as the
+                // model_output is F_theta(...). Recover the actual
+                // model_output, which is also referred to as the
                 // "Karras ODE derivative" d or d_cur in several
                 // samplers above.
                 {
                     float* vec_x = (float*)x->data;
-                    float* vec_noise_pred = (float*)noise_pred->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
                     for (int j = 0; j < ggml_nelements(x); j++) {
-                        vec_noise_pred[j] =
-                            (vec_x[j] - vec_noise_pred[j]) *
+                        vec_model_output[j] =
+                            (vec_x[j] - vec_model_output[j]) *
                             (1 / sigma);
                     }
                 }
                 // 2. compute alphas, betas
                 float alpha_prod_t = alphas_cumprod[timestep];
-                // Note final_alpha_cumprod = alphas_cumprod[0]
+                // Note final_alpha_cumprod = alphas_cumprod[0] due to
+                // trailing timestep spacing
                 float alpha_prod_t_prev = prev_timestep >= 0 ?
                     alphas_cumprod[prev_timestep] : alphas_cumprod[0];
                 float beta_prod_t = 1 - alpha_prod_t;
                 // 3. compute predicted original sample from predicted
                 // noise also called "predicted x_0" of formula (12)
                 // from https://arxiv.org/pdf/2010.02502.pdf
-                struct ggml_tensor* pred_original_sample =
-                    ggml_dup_tensor(work_ctx, x);
                 {
                     float* vec_x = (float*)x->data;
-                    float* vec_noise_pred = (float*)noise_pred->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
                     float* vec_pred_original_sample =
                         (float*)pred_original_sample->data;
                     // Note the substitution of latents or sample = x
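The two loops in this hunk undo the k-diffusion wrapping: with c_in(sigma) = 1/sqrt(sigma^2 + 1) and D(x, sigma) the denoiser output, they compute

    \hat{\epsilon} = \frac{x - D(x,\sigma)}{\sigma}, \qquad
    \hat{x}_0 = \frac{c_{\mathrm{in}}(\sigma)\,x
                - \sqrt{1-\bar{\alpha}_t}\,\hat{\epsilon}}
                {\sqrt{\bar{\alpha}_t}}

where x is the prescaled sample fed to model(). Hoisting pred_original_sample out of the step loop (together with variance_noise above) is what removes the per-step ggml_dup_tensor allocations mentioned in the commit message.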
@@ -1119,12 +1144,12 @@ static void sample_k_diffusion(sample_method_t method,
                         vec_pred_original_sample[j] =
                             (vec_x[j] / std::sqrt(sigma * sigma + 1) -
                              std::sqrt(beta_prod_t) *
-                             vec_noise_pred[j]) *
+                             vec_model_output[j]) *
                             (1 / std::sqrt(alpha_prod_t));
                     }
                 }
                 // Assuming the "epsilon" prediction type, where below
-                // pred_epsilon = noise_pred is inserted, and is not
+                // pred_epsilon = model_output is inserted, and is not
                 // defined/copied explicitly.
                 //
                 // 5. compute variance: "sigma_t(eta)" -> see formula
@@ -1135,34 +1160,35 @@ static void sample_k_diffusion(sample_method_t method,
                 float beta_prod_t_prev = 1 - alpha_prod_t_prev;
                 float variance = (beta_prod_t_prev / beta_prod_t) *
                     (1 - alpha_prod_t / alpha_prod_t_prev);
-                float std_dev_t = 0 * std::sqrt(variance);
+                float std_dev_t = eta * std::sqrt(variance);
                 // 6. compute "direction pointing to x_t" of formula
                 // (12) from https://arxiv.org/pdf/2010.02502.pdf
-                struct ggml_tensor* pred_sample_direction =
-                    ggml_dup_tensor(work_ctx, noise_pred);
-                {
-                    float* vec_noise_pred = (float*)noise_pred->data;
-                    float* vec_pred_sample_direction =
-                        (float*)pred_sample_direction->data;
-                    for (int j = 0; j < ggml_nelements(x); j++) {
-                        vec_pred_sample_direction[j] =
-                            std::sqrt(1 - alpha_prod_t_prev -
-                                      std::pow(std_dev_t, 2)) *
-                            vec_noise_pred[j];
-                    }
-                }
                 // 7. compute x_t without "random noise" of formula
                 // (12) from https://arxiv.org/pdf/2010.02502.pdf
                 {
+                    float* vec_model_output = (float*)model_output->data;
                     float* vec_pred_original_sample =
                         (float*)pred_original_sample->data;
-                    float* vec_pred_sample_direction =
-                        (float*)pred_sample_direction->data;
                     float* vec_x = (float*)x->data;
                     for (int j = 0; j < ggml_nelements(x); j++) {
+                        // Two-step inner loop without an explicit
+                        // tensor
+                        float pred_sample_direction =
+                            std::sqrt(1 - alpha_prod_t_prev -
+                                      std::pow(std_dev_t, 2)) *
+                            vec_model_output[j];
                         vec_x[j] = std::sqrt(alpha_prod_t_prev) *
                             vec_pred_original_sample[j] +
-                            vec_pred_sample_direction[j];
+                            pred_sample_direction;
+                    }
+                }
+                if (eta > 0) {
+                    ggml_tensor_set_f32_randn(variance_noise, rng);
+                    float* vec_variance_noise =
+                        (float*)variance_noise->data;
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] += std_dev_t * vec_variance_noise[j];
                     }
                 }
                 // See the note above: x = latents or sample here, and
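std_dev_t now implements formula (16) of arXiv:2010.02502 rather than its eta = 0 special case:

    \sigma_t(\eta) = \eta\,
        \sqrt{\frac{1-\alpha_{t-1}}{1-\alpha_t}}\,
        \sqrt{1-\frac{\alpha_t}{\alpha_{t-1}}}

eta = 0 reproduces the old deterministic behavior, and the eta > 0 guard keeps that path free of the extra RNG draw; eta = 1 yields DDPM-like ancestral variance.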
@@ -1173,6 +1199,174 @@ static void sample_k_diffusion(sample_method_t method,
                 // factor c_in.
             }
         } break;
+        case TCD: // Strategic Stochastic Sampling (Algorithm 4) in
+                  // Trajectory Consistency Distillation
+        {
+            float beta_start = 0.00085f;
+            float beta_end = 0.0120f;
+            std::vector<double> alphas_cumprod;
+            std::vector<double> compvis_sigmas;
+
+            alphas_cumprod.resize(TIMESTEPS);
+            compvis_sigmas.resize(TIMESTEPS);
+            for (int i = 0; i < TIMESTEPS; i++) {
+                alphas_cumprod[i] =
+                    (i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
+                    (1.0f -
+                     std::pow(sqrtf(beta_start) +
+                              (sqrtf(beta_end) - sqrtf(beta_start)) *
+                              ((float)i / (TIMESTEPS - 1)), 2));
+                compvis_sigmas[i] =
+                    std::sqrt((1 - alphas_cumprod[i]) /
+                              alphas_cumprod[i]);
+            }
+            int original_steps = 50;
+
+            struct ggml_tensor* pred_original_sample =
+                ggml_dup_tensor(work_ctx, x);
+            struct ggml_tensor* noise =
+                ggml_dup_tensor(work_ctx, x);
+
+            for (int i = 0; i < steps; i++) {
+                // Analytic form for TCD timesteps
+                int timestep = TIMESTEPS - 1 -
+                    (TIMESTEPS / original_steps) *
+                    (int)floor(i * ((float)original_steps / steps));
+                // 1. get previous step value
+                int prev_timestep = i >= steps - 1 ? 0 :
+                    TIMESTEPS - 1 - (TIMESTEPS / original_steps) *
+                    (int)floor((i + 1) *
+                               ((float)original_steps / steps));
+                // Here timestep_s is tau_n' in Algorithm 4. The _s
+                // notation appears to be that from DPM-Solver, C. Lu,
+                // arXiv:2206.00927 [cs.LG], but this notation is not
+                // continued in Algorithm 4, where _n' is used.
+                int timestep_s =
+                    (int)floor((1 - eta) * prev_timestep);
+                // Begin k-diffusion specific workaround for
+                // evaluating F_theta(x; ...) from D(x, sigma), same
+                // as in DDIM (and see there for detailed comments)
+                float sigma = compvis_sigmas[timestep];
+                if (i == 0) {
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] *= std::sqrt(sigma * sigma + 1) /
+                            sigma;
+                    }
+                }
+                else {
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] *= std::sqrt(sigma * sigma + 1);
+                    }
+                }
+                struct ggml_tensor* model_output =
+                    model(x, sigma, i + 1);
+                {
+                    float* vec_x = (float*)x->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_model_output[j] =
+                            (vec_x[j] - vec_model_output[j]) *
+                            (1 / sigma);
+                    }
+                }
+                // 2. compute alphas, betas
+                //
+                // When comparing TCD with DDPM/DDIM note that Zheng
+                // et al. (2024) follows the DPM-Solver notation for
+                // alpha. One can find the following comment in the
+                // original DPM-Solver code
+                // (https://github.com/LuChengTHU/dpm-solver/):
+                // "**Important**: Please pay special attention for
+                // the args for `alphas_cumprod`: The `alphas_cumprod`
+                // is the \hat{alpha_n} arrays in the notations of
+                // DDPM. [...] Therefore, the notation \hat{alpha_n}
+                // is different from the notation alpha_t in
+                // DPM-Solver. In fact, we have alpha_{t_n} =
+                // \sqrt{\hat{alpha_n}}, [...]"
+                float alpha_prod_t = alphas_cumprod[timestep];
+                float beta_prod_t = 1 - alpha_prod_t;
+                // Note final_alpha_cumprod = alphas_cumprod[0] since
+                // TCD is always "trailing"
+                float alpha_prod_t_prev = prev_timestep >= 0 ?
+                    alphas_cumprod[prev_timestep] : alphas_cumprod[0];
+                // The subscript _s is the only portion in this
+                // section (2) unique to TCD
+                float alpha_prod_s = alphas_cumprod[timestep_s];
+                float beta_prod_s = 1 - alpha_prod_s;
+                // 3. Compute the predicted noised sample x_s based on
+                // the model parameterization
+                //
+                // This section is also exactly the same as DDIM
+                {
+                    float* vec_x = (float*)x->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    float* vec_pred_original_sample =
+                        (float*)pred_original_sample->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_pred_original_sample[j] =
+                            (vec_x[j] / std::sqrt(sigma * sigma + 1) -
+                             std::sqrt(beta_prod_t) *
+                             vec_model_output[j]) *
+                            (1 / std::sqrt(alpha_prod_t));
+                    }
+                }
+                // This consistency function step can be difficult to
+                // decipher from Algorithm 4, as it involves a
+                // difficult notation ("|->"). In Diffusers it is
+                // borrowed verbatim (with the same comments below for
+                // step (4)) from LCMScheduler's noise injection step,
+                // compare S. Luo et al., arXiv:2310.04378 p. 14,
+                // Algorithm 3.
+                {
+                    float* vec_pred_original_sample =
+                        (float*)pred_original_sample->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        // Substituting x = pred_noised_sample and
+                        // pred_epsilon = model_output
+                        vec_x[j] =
+                            std::sqrt(alpha_prod_s) *
+                            vec_pred_original_sample[j] +
+                            std::sqrt(beta_prod_s) *
+                            vec_model_output[j];
+                    }
+                }
+                // 4. Sample and inject noise z ~ N(0, I) for
+                // MultiStep Inference. Noise is not used on the final
+                // timestep of the timestep schedule. This also means
+                // that noise is not used for one-step sampling. Eta
+                // (referred to as "gamma" in the paper) was
+                // introduced to control the stochasticity in every
+                // step. When eta = 0, it represents deterministic
+                // sampling, whereas eta = 1 indicates full stochastic
+                // sampling.
+                if (eta > 0 && i != steps - 1) {
+                    // In this case, x is still pred_noised_sample,
+                    // continue in-place
+                    ggml_tensor_set_f32_randn(noise, rng);
+                    float* vec_x = (float*)x->data;
+                    float* vec_noise = (float*)noise->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        // Corresponding to (35) in Zheng et
+                        // al. (2024), substituting x =
+                        // pred_noised_sample
+                        vec_x[j] =
+                            std::sqrt(alpha_prod_t_prev /
+                                      alpha_prod_s) *
+                            vec_x[j] +
+                            std::sqrt(1 - alpha_prod_t_prev /
+                                          alpha_prod_s) *
+                            vec_noise[j];
+                    }
+                }
+            }
+        } break;
 
         default:
             LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);
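The TCD noise-injection step at the end corresponds to (35) in Zheng et al. (2024), substituting x = pred_noised_sample; writing \bar\alpha for the code's alphas_cumprod and s = floor((1 - eta) * tau_{n-1}):

    x_{\tau_{n-1}} =
        \sqrt{\frac{\bar{\alpha}_{\tau_{n-1}}}{\bar{\alpha}_s}}\, x_s
      + \sqrt{1-\frac{\bar{\alpha}_{\tau_{n-1}}}{\bar{\alpha}_s}}\, z,
    \qquad z \sim \mathcal{N}(0, I)

At eta = 0 this gives s = tau_{n-1}, the noise coefficient vanishes, and sampling is deterministic; at eta = 1 it gives s = 0 and each step is re-noised all the way back, i.e. fully stochastic sampling, matching the comment for step (4).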
