Commit 8aff010

DDIM and TCD

    commit 8086045
    Author: Yue Shi Lai <yslai@users.noreply.github.com>
    Date:   Tue Jan 14 01:50:56 2025 -0800

        Implement TCD, avoid repeated allocation in DDIM, implement eta
        parameter for DDIM and TCD, minor comment clarification

    commit bf353f1
    Author: Yue Shi Lai <yslai@users.noreply.github.com>
    Date:   Sat Jan 11 23:30:12 2025 -0800

        Implement DDIM with the "trailing" timestep spacing

1 parent 62a5d08 commit 8aff010
5 files changed (+394 −4 lines)
denoiser.hpp

Lines changed: 363 additions & 1 deletion
@@ -479,7 +479,8 @@ static void sample_k_diffusion(sample_method_t method,
                                ggml_context* work_ctx,
                                ggml_tensor* x,
                                std::vector<float> sigmas,
-                               std::shared_ptr<RNG> rng) {
+                               std::shared_ptr<RNG> rng,
+                               float eta) {
     size_t steps = sigmas.size() - 1;
     // sample_euler_ancestral
     switch (method) {
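The only signature change is the new eta argument, consumed by the DDIM_TRAILING and TCD cases added below; the existing samplers ignore it. A hypothetical call might look like the sketch below. This is illustrative only: the real call sites live in the other four changed files, which this page does not show, and it assumes (as the hunk context suggests) that the denoiser callback invoked as model() in the body sits between the method and work_ctx parameters.

    // Sketch of a caller, under the assumptions stated above; "model"
    // stands for whatever denoiser callback the surrounding code passes.
    sample_k_diffusion(DDIM_TRAILING, model, work_ctx, x, sigmas, rng,
                       0.f /* eta = 0: deterministic DDIM */);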
@@ -1010,6 +1011,367 @@ static void sample_k_diffusion(sample_method_t method,
                 }
             }
         } break;
+        case DDIM_TRAILING:  // Denoising Diffusion Implicit Models
+                             // with the "trailing" timestep spacing
+        {
+            // DDIM itself needs alphas_cumprod (DDPM, Ho et al.,
+            // arXiv:2006.11239 [cs.LG], with k-diffusion's start and
+            // end beta), which unfortunately k-diffusion's data
+            // structure hides from the denoiser. The sigmas are also
+            // needed to invert the behavior of CompVisDenoiser
+            // (k-diffusion's LMSDiscreteScheduler).
+            float beta_start = 0.00085f;
+            float beta_end   = 0.0120f;
+            std::vector<double> alphas_cumprod;
+            std::vector<double> compvis_sigmas;
+
+            alphas_cumprod.resize(TIMESTEPS);
+            compvis_sigmas.resize(TIMESTEPS);
+            for (int i = 0; i < TIMESTEPS; i++) {
+                alphas_cumprod[i] =
+                    (i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
+                    (1.0f -
+                     std::pow(sqrtf(beta_start) +
+                              (sqrtf(beta_end) - sqrtf(beta_start)) *
+                              ((float)i / (TIMESTEPS - 1)), 2));
+                compvis_sigmas[i] =
+                    std::sqrt((1 - alphas_cumprod[i]) /
+                              alphas_cumprod[i]);
+            }
+
+            struct ggml_tensor* pred_original_sample =
+                ggml_dup_tensor(work_ctx, x);
+            struct ggml_tensor* variance_noise =
+                ggml_dup_tensor(work_ctx, x);
+
+            for (int i = 0; i < steps; i++) {
+                // The "trailing" DDIM timestep, see S. Lin et al.,
+                // "Common Diffusion Noise Schedules and Sample Steps
+                // are Flawed", arXiv:2305.08891 [cs], p. 4, Table 2.
+                // Most variables below follow Diffusers naming.
+                //
+                // Diffusers naming vs. J. Song et al., "Denoising
+                // Diffusion Implicit Models", arXiv:2010.02502, p. 5,
+                // (12) and p. 16, (16) (<variable name> -> <name in
+                // paper>):
+                //
+                // - pred_noise_t -> epsilon_theta^(t)(x_t)
+                // - pred_original_sample -> f_theta^(t)(x_t) or x_0
+                // - std_dev_t -> sigma_t (not the LMS sigma)
+                // - eta -> eta (set to 0 at the moment)
+                // - pred_sample_direction -> "direction pointing to
+                //   x_t"
+                // - pred_prev_sample -> "x_t-1"
+                int timestep =
+                    roundf(TIMESTEPS -
+                           i * ((float)TIMESTEPS / steps)) - 1;
+                // 1. get previous step value (=t-1)
+                int prev_timestep = timestep - TIMESTEPS / steps;
+                // The sigma here is chosen to cause the
+                // CompVisDenoiser to produce t = timestep
+                float sigma = compvis_sigmas[timestep];
+                if (i == 0) {
+                    // The function add_noise initializes x to
+                    // Diffusers' latents * sigma (as in Diffusers'
+                    // pipeline) or sample * sigma (Diffusers'
+                    // scheduler), where this sigma = init_noise_sigma
+                    // in Diffusers. For DDPM and DDIM however,
+                    // init_noise_sigma = 1. But the k-diffusion
+                    // model() also evaluates F_theta(c_in(sigma) x;
+                    // ...) instead of the bare U-net F_theta, with
+                    // c_in = 1 / sqrt(sigma^2 + 1), as defined in
+                    // T. Karras et al., "Elucidating the Design Space
+                    // of Diffusion-Based Generative Models",
+                    // arXiv:2206.00364 [cs.CV], p. 3, Table 1. Hence
+                    // the first call has to be prescaled as
+                    // x <- x / (c_in * sigma) with the k-diffusion
+                    // pipeline and CompVisDenoiser.
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] *= std::sqrt(sigma * sigma + 1) /
+                            sigma;
+                    }
+                }
+                else {
+                    // For the subsequent steps after the first one,
+                    // at this point x = latents or x = sample, and
+                    // needs to be prescaled with x <- sample / c_in
+                    // to compensate for model() applying the scale
+                    // c_in before the U-net F_theta
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] *= std::sqrt(sigma * sigma + 1);
+                    }
+                }
+                // Note (also noise_pred in Diffusers' pipeline):
+                // model_output = model() is the D(x, sigma) as
+                // defined in T. Karras et al., arXiv:2206.00364,
+                // p. 3, Table 1 and p. 8 (7); compare also p. 38
+                // (226) therein.
+                struct ggml_tensor* model_output =
+                    model(x, sigma, i + 1);
+                // Here model_output is still the k-diffusion denoiser
+                // output, not the U-net output F_theta(c_in(sigma) x;
+                // ...) in Karras et al. (2022), whereas Diffusers'
+                // model_output is F_theta(...). Recover the actual
+                // model_output, which is also referred to as the
+                // "Karras ODE derivative" d or d_cur in several
+                // samplers above.
+                {
+                    float* vec_x = (float*)x->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_model_output[j] =
+                            (vec_x[j] - vec_model_output[j]) *
+                            (1 / sigma);
+                    }
+                }
+                // 2. compute alphas, betas
+                float alpha_prod_t = alphas_cumprod[timestep];
+                // Note final_alpha_cumprod = alphas_cumprod[0] due to
+                // trailing timestep spacing
+                float alpha_prod_t_prev = prev_timestep >= 0 ?
+                    alphas_cumprod[prev_timestep] : alphas_cumprod[0];
+                float beta_prod_t = 1 - alpha_prod_t;
+                // 3. compute predicted original sample from predicted
+                // noise, also called "predicted x_0" of formula (12)
+                // from https://arxiv.org/pdf/2010.02502.pdf
+                {
+                    float* vec_x = (float*)x->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    float* vec_pred_original_sample =
+                        (float*)pred_original_sample->data;
+                    // Note the substitution of latents or sample =
+                    // x * c_in = x / sqrt(sigma^2 + 1)
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_pred_original_sample[j] =
+                            (vec_x[j] / std::sqrt(sigma * sigma + 1) -
+                             std::sqrt(beta_prod_t) *
+                             vec_model_output[j]) *
+                            (1 / std::sqrt(alpha_prod_t));
+                    }
+                }
+                // Assuming the "epsilon" prediction type, where below
+                // pred_epsilon = model_output is inserted, and is not
+                // defined/copied explicitly.
+                //
+                // 5. compute variance: "sigma_t(eta)" -> see formula
+                // (16)
+                //
+                // sigma_t = sqrt((1 - alpha_t-1)/(1 - alpha_t)) *
+                //           sqrt(1 - alpha_t/alpha_t-1)
+                float beta_prod_t_prev = 1 - alpha_prod_t_prev;
+                float variance = (beta_prod_t_prev / beta_prod_t) *
+                    (1 - alpha_prod_t / alpha_prod_t_prev);
+                float std_dev_t = eta * std::sqrt(variance);
+                // 6. compute "direction pointing to x_t" of formula
+                // (12) from https://arxiv.org/pdf/2010.02502.pdf
+                // 7. compute x_t without "random noise" of formula
+                // (12) from https://arxiv.org/pdf/2010.02502.pdf
+                {
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    float* vec_pred_original_sample =
+                        (float*)pred_original_sample->data;
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        // Two-step inner loop without an explicit
+                        // tensor
+                        float pred_sample_direction =
+                            std::sqrt(1 - alpha_prod_t_prev -
+                                      std::pow(std_dev_t, 2)) *
+                            vec_model_output[j];
+                        vec_x[j] = std::sqrt(alpha_prod_t_prev) *
+                            vec_pred_original_sample[j] +
+                            pred_sample_direction;
+                    }
+                }
+                if (eta > 0) {
+                    ggml_tensor_set_f32_randn(variance_noise, rng);
+                    float* vec_variance_noise =
+                        (float*)variance_noise->data;
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] += std_dev_t * vec_variance_noise[j];
+                    }
+                }
+                // See the note above: x = latents or sample here, and
+                // is not scaled by the c_in. For the final output
+                // this is correct, but for subsequent iterations, x
+                // needs to be prescaled again, since k-diffusion's
+                // model() differs from the bare U-net F_theta by the
+                // factor c_in.
+            }
+        } break;
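For reference, steps 3, 5, 6, and 7 above assemble into eq. (12) of Song et al. (arXiv:2010.02502), with sigma_t(eta) from eq. (16). In the paper's notation, alpha_t denotes the cumulative product (alpha_prod_t above), and the last term is only added in the if (eta > 0) branch:

    x_{t-1} = \sqrt{\alpha_{t-1}} \underbrace{\left(
                  \frac{x_t - \sqrt{1-\alpha_t}\,\epsilon_\theta^{(t)}(x_t)}
                       {\sqrt{\alpha_t}}
              \right)}_{\texttt{pred\_original\_sample}}
            + \underbrace{\sqrt{1-\alpha_{t-1}-\sigma_t^2}\;
                  \epsilon_\theta^{(t)}(x_t)}_{\texttt{pred\_sample\_direction}}
            + \sigma_t \varepsilon_t,
    \qquad
    \sigma_t(\eta) = \eta
        \sqrt{\frac{1-\alpha_{t-1}}{1-\alpha_t}}
        \sqrt{1-\frac{\alpha_t}{\alpha_{t-1}}}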
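A minimal standalone sketch (not part of the commit) of the trailing-spacing arithmetic used above, assuming TIMESTEPS = 1000 as in DDPM; with steps = 10 it prints t = 999, 899, ..., 99, matching the "trailing" row of Lin et al., Table 2:

    #include <cmath>
    #include <cstdio>

    int main() {
        const int TIMESTEPS = 1000;  // assumed; defined elsewhere in denoiser.hpp
        const int steps     = 10;
        for (int i = 0; i < steps; i++) {
            // Same expressions as in the DDIM_TRAILING loop above
            int timestep = (int)std::roundf(
                TIMESTEPS - i * ((float)TIMESTEPS / steps)) - 1;
            int prev_timestep = timestep - TIMESTEPS / steps;
            // prev_timestep < 0 on the last step; the sampler then
            // falls back to alphas_cumprod[0] (final_alpha_cumprod)
            std::printf("i = %d: t = %d, t_prev = %d\n",
                        i, timestep, prev_timestep);
        }
        return 0;
    }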
+        case TCD:  // Strategic Stochastic Sampling (Algorithm 4) in
+                   // Trajectory Consistency Distillation
+        {
+            float beta_start = 0.00085f;
+            float beta_end   = 0.0120f;
+            std::vector<double> alphas_cumprod;
+            std::vector<double> compvis_sigmas;
+
+            alphas_cumprod.resize(TIMESTEPS);
+            compvis_sigmas.resize(TIMESTEPS);
+            for (int i = 0; i < TIMESTEPS; i++) {
+                alphas_cumprod[i] =
+                    (i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
+                    (1.0f -
+                     std::pow(sqrtf(beta_start) +
+                              (sqrtf(beta_end) - sqrtf(beta_start)) *
+                              ((float)i / (TIMESTEPS - 1)), 2));
+                compvis_sigmas[i] =
+                    std::sqrt((1 - alphas_cumprod[i]) /
+                              alphas_cumprod[i]);
+            }
+            int original_steps = 50;
+
+            struct ggml_tensor* pred_original_sample =
+                ggml_dup_tensor(work_ctx, x);
+            struct ggml_tensor* noise =
+                ggml_dup_tensor(work_ctx, x);
+
+            for (int i = 0; i < steps; i++) {
+                // Analytic form for TCD timesteps
+                int timestep = TIMESTEPS - 1 -
+                    (TIMESTEPS / original_steps) *
+                    (int)floor(i * ((float)original_steps / steps));
+                // 1. get previous step value
+                int prev_timestep = i >= steps - 1 ? 0 :
+                    TIMESTEPS - 1 - (TIMESTEPS / original_steps) *
+                    (int)floor((i + 1) *
+                               ((float)original_steps / steps));
+                // Here timestep_s is tau_n' in Algorithm 4. The _s
+                // notation appears to be that from DPM-Solver, C. Lu,
+                // arXiv:2206.00927 [cs.LG], but this notation is not
+                // continued in Algorithm 4, where _n' is used.
+                int timestep_s =
+                    (int)floor((1 - eta) * prev_timestep);
+                // Begin k-diffusion specific workaround for
+                // evaluating F_theta(x; ...) from D(x, sigma), same
+                // as in DDIM (and see there for detailed comments)
+                float sigma = compvis_sigmas[timestep];
+                if (i == 0) {
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] *= std::sqrt(sigma * sigma + 1) /
+                            sigma;
+                    }
+                }
+                else {
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] *= std::sqrt(sigma * sigma + 1);
+                    }
+                }
+                struct ggml_tensor* model_output =
+                    model(x, sigma, i + 1);
+                {
+                    float* vec_x = (float*)x->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_model_output[j] =
+                            (vec_x[j] - vec_model_output[j]) *
+                            (1 / sigma);
+                    }
+                }
+                // 2. compute alphas, betas
+                //
+                // When comparing TCD with DDPM/DDIM, note that Zheng
+                // et al. (2024) follows the DPM-Solver notation for
+                // alpha. One can find the following comment in the
+                // original DPM-Solver code
+                // (https://github.com/LuChengTHU/dpm-solver/):
+                // "**Important**: Please pay special attention for
+                // the args for `alphas_cumprod`: The `alphas_cumprod`
+                // is the \hat{alpha_n} arrays in the notations of
+                // DDPM. [...] Therefore, the notation \hat{alpha_n}
+                // is different from the notation alpha_t in
+                // DPM-Solver. In fact, we have alpha_{t_n} =
+                // \sqrt{\hat{alpha_n}}, [...]"
+                float alpha_prod_t = alphas_cumprod[timestep];
+                float beta_prod_t = 1 - alpha_prod_t;
+                // Note final_alpha_cumprod = alphas_cumprod[0] since
+                // TCD is always "trailing"
+                float alpha_prod_t_prev = prev_timestep >= 0 ?
+                    alphas_cumprod[prev_timestep] : alphas_cumprod[0];
+                // The subscript _s is the only portion in this
+                // section (2) unique to TCD
+                float alpha_prod_s = alphas_cumprod[timestep_s];
+                float beta_prod_s = 1 - alpha_prod_s;
+                // 3. Compute the predicted noised sample x_s based on
+                // the model parameterization
+                //
+                // This section is also exactly the same as DDIM
+                {
+                    float* vec_x = (float*)x->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    float* vec_pred_original_sample =
+                        (float*)pred_original_sample->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_pred_original_sample[j] =
+                            (vec_x[j] / std::sqrt(sigma * sigma + 1) -
+                             std::sqrt(beta_prod_t) *
+                             vec_model_output[j]) *
+                            (1 / std::sqrt(alpha_prod_t));
+                    }
+                }
+                // This consistency function step can be difficult to
+                // decipher from Algorithm 4, as it involves a
+                // difficult notation ("|->"). In Diffusers it is
+                // borrowed verbatim (with the same comments below for
+                // step (4)) from LCMScheduler's noise injection step;
+                // compare S. Luo et al., arXiv:2310.04378, p. 14,
+                // Algorithm 3.
+                {
+                    float* vec_pred_original_sample =
+                        (float*)pred_original_sample->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        // Substituting x = pred_noised_sample and
+                        // pred_epsilon = model_output
+                        vec_x[j] =
+                            std::sqrt(alpha_prod_s) *
+                            vec_pred_original_sample[j] +
+                            std::sqrt(beta_prod_s) *
+                            vec_model_output[j];
+                    }
+                }
+                // 4. Sample and inject noise z ~ N(0, I) for
+                // multi-step inference. Noise is not used on the
+                // final timestep of the timestep schedule. This also
+                // means that noise is not used for one-step sampling.
+                // Eta (referred to as "gamma" in the paper) was
+                // introduced to control the stochasticity in every
+                // step. When eta = 0, it represents deterministic
+                // sampling, whereas eta = 1 indicates full stochastic
+                // sampling.
+                if (eta > 0 && i != steps - 1) {
+                    // In this case, x is still pred_noised_sample;
+                    // continue in-place
+                    ggml_tensor_set_f32_randn(noise, rng);
+                    float* vec_x = (float*)x->data;
+                    float* vec_noise = (float*)noise->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        // Corresponding to (35) in Zheng et
+                        // al. (2024), substituting x =
+                        // pred_noised_sample
+                        vec_x[j] =
+                            std::sqrt(alpha_prod_t_prev /
+                                      alpha_prod_s) *
+                            vec_x[j] +
+                            std::sqrt(1 - alpha_prod_t_prev /
+                                          alpha_prod_s) *
+                            vec_noise[j];
+                    }
+                }
+            }
+        } break;
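Put together, one TCD step above implements Algorithm 4 of Zheng et al., with eta playing the role of the paper's gamma and hat-alpha = alphas_cumprod per the DPM-Solver note in the comments: first jump to the intermediate timestep tau_n' = floor((1 - gamma) t_{n-1}) via the consistency function, then re-noise up to t_{n-1}. The re-noising line is skipped when eta = 0 or on the final step, leaving x at the consistency-function output:

    x_{\tau_n'} = \sqrt{\hat\alpha_{\tau_n'}}\,
          \underbrace{f_\theta(x_{t_n})}_{\texttt{pred\_original\_sample}}
        + \sqrt{1-\hat\alpha_{\tau_n'}}\,
          \underbrace{\epsilon_\theta(x_{t_n})}_{\texttt{model\_output}},
    \qquad
    x_{t_{n-1}} = \sqrt{\frac{\hat\alpha_{t_{n-1}}}{\hat\alpha_{\tau_n'}}}\,
          x_{\tau_n'}
        + \sqrt{1-\frac{\hat\alpha_{t_{n-1}}}{\hat\alpha_{\tau_n'}}}\, z,
    \quad z \sim \mathcal{N}(0, I)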
 
         default:
             LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);