
Commit 19d876e

feat: implement DDIM with the "trailing" timestep spacing and TCD (leejet#568)
1 parent f27f2b2 commit 19d876e

File tree: 4 files changed, +400 -3 lines changed

denoiser.hpp

Lines changed: 370 additions & 1 deletion
```diff
@@ -474,7 +474,8 @@ static void sample_k_diffusion(sample_method_t method,
                                ggml_context* work_ctx,
                                ggml_tensor* x,
                                std::vector<float> sigmas,
-                               std::shared_ptr<RNG> rng) {
+                               std::shared_ptr<RNG> rng,
+                               float eta) {
     size_t steps = sigmas.size() - 1;
     // sample_euler_ancestral
     switch (method) {
```
```diff
@@ -1005,6 +1006,374 @@ static void sample_k_diffusion(sample_method_t method,
                 }
             }
         } break;
+        case DDIM_TRAILING:  // Denoising Diffusion Implicit Models
+                             // with the "trailing" timestep spacing
+        {
+            // See J. Song et al., "Denoising Diffusion Implicit
+            // Models", arXiv:2010.02502 [cs.LG]
+            //
+            // DDIM itself needs alphas_cumprod (DDPM, J. Ho et al.,
+            // arXiv:2006.11239 [cs.LG], with k-diffusion's start and
+            // end beta), which k-diffusion's data structure
+            // unfortunately hides from the denoiser; the sigmas are
+            // also needed to invert the behavior of CompVisDenoiser
+            // (k-diffusion's LMSDiscreteScheduler)
+            float beta_start = 0.00085f;
+            float beta_end   = 0.0120f;
+            std::vector<double> alphas_cumprod;
+            std::vector<double> compvis_sigmas;
+
+            // resize rather than reserve, since operator[] below
+            // writes into the vectors
+            alphas_cumprod.resize(TIMESTEPS);
+            compvis_sigmas.resize(TIMESTEPS);
+            for (int i = 0; i < TIMESTEPS; i++) {
+                alphas_cumprod[i] =
+                    (i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
+                    (1.0f -
+                     std::pow(sqrtf(beta_start) +
+                                  (sqrtf(beta_end) - sqrtf(beta_start)) *
+                                      ((float)i / (TIMESTEPS - 1)),
+                              2));
+                compvis_sigmas[i] =
+                    std::sqrt((1 - alphas_cumprod[i]) /
+                              alphas_cumprod[i]);
+            }
```
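For reference, in standard DDPM notation the loop above reproduces CompVis' "scaled linear" beta schedule and its cumulative products, with $T$ = TIMESTEPS:

$$\beta_i = \Big(\sqrt{\beta_{\mathrm{start}}} + \tfrac{i}{T-1}\big(\sqrt{\beta_{\mathrm{end}}} - \sqrt{\beta_{\mathrm{start}}}\big)\Big)^{2}, \qquad \bar\alpha_i = \prod_{k=0}^{i}(1-\beta_k), \qquad \sigma_i = \sqrt{\frac{1-\bar\alpha_i}{\bar\alpha_i}}$$

The $\sigma_i$ here are the CompVis/LMS sigmas used to address the k-diffusion denoiser, not the DDIM $\sigma_t$ of formula (16) further below.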
```diff
+
+            struct ggml_tensor* pred_original_sample =
+                ggml_dup_tensor(work_ctx, x);
+            struct ggml_tensor* variance_noise =
+                ggml_dup_tensor(work_ctx, x);
+
+            for (int i = 0; i < steps; i++) {
+                // The "trailing" DDIM timestep, see S. Lin et al.,
+                // "Common Diffusion Noise Schedules and Sample Steps
+                // are Flawed", arXiv:2305.08891 [cs], p. 4, Table 2.
+                // Most variables below follow the Diffusers naming.
+                //
+                // Diffusers naming vs. Song et al. (2020), p. 5, (12)
+                // and p. 16, (16) (<variable name> -> <name in
+                // paper>):
+                //
+                // - pred_noise_t -> epsilon_theta^(t)(x_t)
+                // - pred_original_sample -> f_theta^(t)(x_t) or x_0
+                // - std_dev_t -> sigma_t (not the LMS sigma)
+                // - eta -> eta (set to 0 at the moment)
+                // - pred_sample_direction -> "direction pointing to
+                //   x_t"
+                // - pred_prev_sample -> "x_t-1"
+                int timestep =
+                    roundf(TIMESTEPS -
+                           i * ((float)TIMESTEPS / steps)) - 1;
+                // 1. get previous step value (=t-1)
+                int prev_timestep = timestep - TIMESTEPS / steps;
+                // The sigma here is chosen to cause the
+                // CompVisDenoiser to produce t = timestep
+                float sigma = compvis_sigmas[timestep];
```
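The "trailing" spacing (Lin et al., Table 2) selects the $N$ sample timesteps from the end of the $T$-step training schedule:

$$t_i = \mathrm{round}\!\Big(T - i\,\frac{T}{N}\Big) - 1, \qquad t_i^{\mathrm{prev}} = t_i - \frac{T}{N}$$

For example, with $T = 1000$ and $N = 10$ this gives $t = 999, 899, \ldots, 99$; the final previous timestep is $-1$, which is why the code falls back to alphas_cumprod[0] below.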
```diff
+                if (i == 0) {
+                    // The function add_noise initializes x to
+                    // Diffusers' latents * sigma (as in Diffusers'
+                    // pipeline) or sample * sigma (Diffusers'
+                    // scheduler), where this sigma = init_noise_sigma
+                    // in Diffusers. For DDPM and DDIM, however,
+                    // init_noise_sigma = 1. But the k-diffusion
+                    // model() also evaluates F_theta(c_in(sigma) x;
+                    // ...) instead of the bare U-net F_theta, with
+                    // c_in = 1 / sqrt(sigma^2 + 1), as defined in
+                    // T. Karras et al., "Elucidating the Design Space
+                    // of Diffusion-Based Generative Models",
+                    // arXiv:2206.00364 [cs.CV], p. 3, Table 1. Hence
+                    // the first call has to be prescaled as x <- x /
+                    // (c_in * sigma) with the k-diffusion pipeline
+                    // and CompVisDenoiser.
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] *= std::sqrt(sigma * sigma + 1) /
+                            sigma;
+                    }
+                }
+                else {
+                    // For the subsequent steps after the first one,
+                    // at this point x = latents or x = sample, and
+                    // needs to be prescaled with x <- sample / c_in
+                    // to compensate for model() applying the scale
+                    // c_in before the U-net F_theta
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] *= std::sqrt(sigma * sigma + 1);
+                    }
+                }
+                // Note (also noise_pred in Diffusers' pipeline):
+                // model_output = model() is the D(x, sigma) as
+                // defined in Karras et al. (2022), p. 3, Table 1 and
+                // p. 8 (7); compare also p. 38 (226) therein.
+                struct ggml_tensor* model_output =
+                    model(x, sigma, i + 1);
+                // Here model_output is still the k-diffusion denoiser
+                // output, not the U-net output F_theta(c_in(sigma) x;
+                // ...) in Karras et al. (2022), whereas Diffusers'
+                // model_output is F_theta(...). Recover the actual
+                // model_output, which is also referred to as the
+                // "Karras ODE derivative" d or d_cur in several
+                // samplers above.
+                {
+                    float* vec_x = (float*)x->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_model_output[j] =
+                            (vec_x[j] - vec_model_output[j]) *
+                            (1 / sigma);
+                    }
+                }
```
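The prescaling and the subtraction above simply invert k-diffusion's epsilon-prediction preconditioning from Karras et al. (2022), Table 1:

$$D(x, \sigma) = x - \sigma\,F_\theta\big(c_{\mathrm{in}}(\sigma)\,x;\,\sigma\big), \qquad c_{\mathrm{in}}(\sigma) = \frac{1}{\sqrt{\sigma^{2}+1}}$$

so once x has been pre-divided by $c_{\mathrm{in}}$, the U-net output is recovered as $F_\theta = (x - D(x,\sigma))/\sigma$.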
```diff
+                // 2. compute alphas, betas
+                float alpha_prod_t = alphas_cumprod[timestep];
+                // Note final_alpha_cumprod = alphas_cumprod[0] due to
+                // trailing timestep spacing
+                float alpha_prod_t_prev = prev_timestep >= 0 ?
+                    alphas_cumprod[prev_timestep] : alphas_cumprod[0];
+                float beta_prod_t = 1 - alpha_prod_t;
+                // 3. compute predicted original sample from predicted
+                // noise, also called "predicted x_0" of formula (12)
+                // from https://arxiv.org/pdf/2010.02502.pdf
+                {
+                    float* vec_x = (float*)x->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    float* vec_pred_original_sample =
+                        (float*)pred_original_sample->data;
+                    // Note the substitution of latents or sample = x
+                    // * c_in = x / sqrt(sigma^2 + 1)
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_pred_original_sample[j] =
+                            (vec_x[j] / std::sqrt(sigma * sigma + 1) -
+                             std::sqrt(beta_prod_t) *
+                                 vec_model_output[j]) *
+                            (1 / std::sqrt(alpha_prod_t));
+                    }
+                }
```
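In the notation of Song et al. (2020), formula (12), this computes the "predicted $x_0$":

$$\hat x_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\;\epsilon_\theta^{(t)}(x_t)}{\sqrt{\bar\alpha_t}}$$

with the extra division by $\sqrt{\sigma^2+1}$ in the code undoing the k-diffusion scaling of x noted above.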
```diff
+                // Assuming the "epsilon" prediction type; below,
+                // pred_epsilon = model_output is substituted in
+                // directly rather than defined/copied explicitly.
+                //
+                // 5. compute variance: "sigma_t(eta)" -> see formula
+                // (16)
+                //
+                // sigma_t = sqrt((1 - alpha_t-1)/(1 - alpha_t)) *
+                // sqrt(1 - alpha_t/alpha_t-1)
+                float beta_prod_t_prev = 1 - alpha_prod_t_prev;
+                float variance = (beta_prod_t_prev / beta_prod_t) *
+                    (1 - alpha_prod_t / alpha_prod_t_prev);
+                float std_dev_t = eta * std::sqrt(variance);
```
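This is formula (16) of Song et al. (2020):

$$\sigma_t(\eta) = \eta\,\sqrt{\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}}\,\sqrt{1-\frac{\bar\alpha_t}{\bar\alpha_{t-1}}}$$

$\eta = 0$ gives deterministic DDIM; $\eta = 1$ recovers the DDPM ancestral variance.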
```diff
+                // 6. compute "direction pointing to x_t" of formula
+                // (12) from https://arxiv.org/pdf/2010.02502.pdf
+                // 7. compute x_t without "random noise" of formula
+                // (12) from https://arxiv.org/pdf/2010.02502.pdf
+                {
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    float* vec_pred_original_sample =
+                        (float*)pred_original_sample->data;
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        // Two-step inner loop without an explicit
+                        // tensor
+                        float pred_sample_direction =
+                            std::sqrt(1 - alpha_prod_t_prev -
+                                      std::pow(std_dev_t, 2)) *
+                            vec_model_output[j];
+                        vec_x[j] = std::sqrt(alpha_prod_t_prev) *
+                                vec_pred_original_sample[j] +
+                            pred_sample_direction;
+                    }
+                }
+                if (eta > 0) {
+                    ggml_tensor_set_f32_randn(variance_noise, rng);
+                    float* vec_variance_noise =
+                        (float*)variance_noise->data;
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] += std_dev_t * vec_variance_noise[j];
+                    }
+                }
+                // See the note above: x = latents or sample here, and
+                // is not scaled by c_in. For the final output this is
+                // correct, but for subsequent iterations x needs to
+                // be prescaled again, since k-diffusion's model()
+                // differs from the bare U-net F_theta by the factor
+                // c_in.
+            }
+        } break;
```
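Steps 6 and 7 together with the noise injection implement the full update of Song et al. (2020), formula (12):

$$x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat x_0 + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^{2}}\;\epsilon_\theta^{(t)}(x_t) + \sigma_t\,\varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I)$$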
```diff
+        case TCD:  // Strategic Stochastic Sampling (Algorithm 4) in
+                   // Trajectory Consistency Distillation
+        {
+            // See J. Zheng et al., "Trajectory Consistency
+            // Distillation: Improved Latent Consistency Distillation
+            // by Semi-Linear Consistency Function with Trajectory
+            // Mapping", arXiv:2402.19159 [cs.CV]
+            float beta_start = 0.00085f;
+            float beta_end   = 0.0120f;
+            std::vector<double> alphas_cumprod;
+            std::vector<double> compvis_sigmas;
+
+            // resize rather than reserve, as in DDIM_TRAILING above
+            alphas_cumprod.resize(TIMESTEPS);
+            compvis_sigmas.resize(TIMESTEPS);
+            for (int i = 0; i < TIMESTEPS; i++) {
+                alphas_cumprod[i] =
+                    (i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
+                    (1.0f -
+                     std::pow(sqrtf(beta_start) +
+                                  (sqrtf(beta_end) - sqrtf(beta_start)) *
+                                      ((float)i / (TIMESTEPS - 1)),
+                              2));
+                compvis_sigmas[i] =
+                    std::sqrt((1 - alphas_cumprod[i]) /
+                              alphas_cumprod[i]);
+            }
+            int original_steps = 50;
+
+            struct ggml_tensor* pred_original_sample =
+                ggml_dup_tensor(work_ctx, x);
+            struct ggml_tensor* noise =
+                ggml_dup_tensor(work_ctx, x);
+
+            for (int i = 0; i < steps; i++) {
+                // Analytic form for TCD timesteps
+                int timestep = TIMESTEPS - 1 -
+                    (TIMESTEPS / original_steps) *
+                    (int)floor(i * ((float)original_steps / steps));
+                // 1. get previous step value
+                int prev_timestep = i >= steps - 1 ? 0 :
+                    TIMESTEPS - 1 - (TIMESTEPS / original_steps) *
+                    (int)floor((i + 1) *
+                               ((float)original_steps / steps));
+                // Here timestep_s is tau_n' in Algorithm 4. The _s
+                // notation appears to come from C. Lu et al.,
+                // "DPM-Solver: A Fast ODE Solver for Diffusion
+                // Probabilistic Model Sampling in Around 10 Steps",
+                // arXiv:2206.00927 [cs.LG], but this notation is not
+                // continued in Algorithm 4, where _n' is used.
+                int timestep_s =
+                    (int)floor((1 - eta) * prev_timestep);
```
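In closed form, with $S$ = original_steps and $N$ = steps, the TCD schedule and the stochastic intermediate timestep are:

$$t_i = T - 1 - \frac{T}{S}\Big\lfloor i\,\frac{S}{N}\Big\rfloor, \qquad \tau_i = \big\lfloor(1-\eta)\,t_i^{\mathrm{prev}}\big\rfloor$$

For example, $T = 1000$, $S = 50$, $N = 4$ gives $t = 999, 759, 499, 259$.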
```diff
+                // Begin k-diffusion specific workaround for
+                // evaluating F_theta(x; ...) from D(x, sigma), same
+                // as in DDIM (and see there for detailed comments)
+                float sigma = compvis_sigmas[timestep];
+                if (i == 0) {
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] *= std::sqrt(sigma * sigma + 1) /
+                            sigma;
+                    }
+                }
+                else {
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_x[j] *= std::sqrt(sigma * sigma + 1);
+                    }
+                }
+                struct ggml_tensor* model_output =
+                    model(x, sigma, i + 1);
+                {
+                    float* vec_x = (float*)x->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_model_output[j] =
+                            (vec_x[j] - vec_model_output[j]) *
+                            (1 / sigma);
+                    }
+                }
+                // 2. compute alphas, betas
+                //
+                // When comparing TCD with DDPM/DDIM, note that Zheng
+                // et al. (2024) follow the DPM-Solver notation for
+                // alpha. One can find the following comment in the
+                // original DPM-Solver code
+                // (https://github.com/LuChengTHU/dpm-solver/):
+                // "**Important**: Please pay special attention for
+                // the args for `alphas_cumprod`: The `alphas_cumprod`
+                // is the \hat{alpha_n} arrays in the notations of
+                // DDPM. [...] Therefore, the notation \hat{alpha_n}
+                // is different from the notation alpha_t in
+                // DPM-Solver. In fact, we have alpha_{t_n} =
+                // \sqrt{\hat{alpha_n}}, [...]"
+                float alpha_prod_t = alphas_cumprod[timestep];
+                float beta_prod_t  = 1 - alpha_prod_t;
+                // Note final_alpha_cumprod = alphas_cumprod[0] since
+                // TCD is always "trailing"
+                float alpha_prod_t_prev = prev_timestep >= 0 ?
+                    alphas_cumprod[prev_timestep] : alphas_cumprod[0];
+                // The subscript _s is the only portion of this
+                // section (2) unique to TCD
+                float alpha_prod_s = alphas_cumprod[timestep_s];
+                float beta_prod_s  = 1 - alpha_prod_s;
```
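In other words, where DPM-Solver (and hence Zheng et al.) write $\alpha_{t_n}$, the relation to the DDPM cumulative products used in this code is:

$$\alpha_{t_n}^{\text{DPM-Solver}} = \sqrt{\bar\alpha_n^{\text{DDPM}}}$$

so alpha_prod_t and friends below are the DDPM $\bar\alpha$, not the DPM-Solver $\alpha$.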
```diff
+                // 3. Compute the predicted noised sample x_s based on
+                // the model parameterization
+                //
+                // This section is also exactly the same as DDIM
+                {
+                    float* vec_x = (float*)x->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    float* vec_pred_original_sample =
+                        (float*)pred_original_sample->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        vec_pred_original_sample[j] =
+                            (vec_x[j] / std::sqrt(sigma * sigma + 1) -
+                             std::sqrt(beta_prod_t) *
+                                 vec_model_output[j]) *
+                            (1 / std::sqrt(alpha_prod_t));
+                    }
+                }
+                // This consistency function step can be difficult to
+                // decipher from Algorithm 4, since there it is stated
+                // simply in terms of a consistency function. The step
+                // is the modified DDIM, i.e. p. 8 (32) in Zheng et
+                // al. (2024), with its eta set to 0 (see the
+                // paragraph immediately thereafter, which states this
+                // somewhat obliquely).
+                {
+                    float* vec_pred_original_sample =
+                        (float*)pred_original_sample->data;
+                    float* vec_model_output =
+                        (float*)model_output->data;
+                    float* vec_x = (float*)x->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        // Substituting x = pred_noised_sample and
+                        // pred_epsilon = model_output
+                        vec_x[j] =
+                            std::sqrt(alpha_prod_s) *
+                                vec_pred_original_sample[j] +
+                            std::sqrt(beta_prod_s) *
+                                vec_model_output[j];
+                    }
+                }
```
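Written out, this is the modified DDIM step (32) of Zheng et al. (2024) with its eta set to 0, jumping from $t$ down to the intermediate timestep $\tau$ (timestep_s):

$$x_{\tau} = \sqrt{\bar\alpha_{\tau}}\,\hat x_0 + \sqrt{1-\bar\alpha_{\tau}}\;\epsilon_\theta(x_t)$$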
```diff
+                // 4. Sample and inject noise z ~ N(0, I) for
+                // multistep inference. Noise is not used on the final
+                // timestep of the timestep schedule, which also means
+                // that noise is not used for one-step sampling. Eta
+                // (referred to as "gamma" in the paper) was
+                // introduced to control the stochasticity in every
+                // step: eta = 0 gives deterministic sampling, whereas
+                // eta = 1 gives fully stochastic sampling.
+                if (eta > 0 && i != steps - 1) {
+                    // In this case, x is still pred_noised_sample;
+                    // continue in place
+                    ggml_tensor_set_f32_randn(noise, rng);
+                    float* vec_x = (float*)x->data;
+                    float* vec_noise = (float*)noise->data;
+                    for (int j = 0; j < ggml_nelements(x); j++) {
+                        // Corresponding to (35) in Zheng et
+                        // al. (2024), substituting x =
+                        // pred_noised_sample
+                        vec_x[j] =
+                            std::sqrt(alpha_prod_t_prev /
+                                      alpha_prod_s) *
+                                vec_x[j] +
+                            std::sqrt(1 - alpha_prod_t_prev /
+                                          alpha_prod_s) *
+                                vec_noise[j];
+                    }
+                }
+            }
+        } break;
```
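The injected noise follows (35) of Zheng et al. (2024), a forward diffusion from $\tau$ back up to $t^{\mathrm{prev}}$:

$$x_{t^{\mathrm{prev}}} = \sqrt{\frac{\bar\alpha_{t^{\mathrm{prev}}}}{\bar\alpha_{\tau}}}\;x_{\tau} + \sqrt{1-\frac{\bar\alpha_{t^{\mathrm{prev}}}}{\bar\alpha_{\tau}}}\;z, \qquad z \sim \mathcal{N}(0, I)$$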
```diff
 
         default:
             LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);
```
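As a quick way to inspect the schedule these samplers walk, here is a minimal standalone sketch (not part of the commit; TIMESTEPS and the betas are copied from the hunk above) that prints the trailing DDIM timesteps with their cumulative alphas and CompVis sigmas:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int TIMESTEPS = 1000;  // training schedule length, T
    const int steps     = 10;    // sampling steps, N
    const double beta_start = 0.00085, beta_end = 0.0120;
    std::vector<double> alphas_cumprod(TIMESTEPS), compvis_sigmas(TIMESTEPS);
    for (int i = 0; i < TIMESTEPS; i++) {
        // "scaled linear" beta schedule, as in the commit
        double beta = std::pow(std::sqrt(beta_start) +
                                   (std::sqrt(beta_end) - std::sqrt(beta_start)) *
                                       ((double)i / (TIMESTEPS - 1)),
                               2);
        alphas_cumprod[i] = (i == 0 ? 1.0 : alphas_cumprod[i - 1]) * (1.0 - beta);
        compvis_sigmas[i] = std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]);
    }
    for (int i = 0; i < steps; i++) {
        // trailing DDIM timestep, same formula as the sampler
        int t = (int)std::lround(TIMESTEPS - i * ((double)TIMESTEPS / steps)) - 1;
        std::printf("i=%d t=%d alpha_bar=%.6f sigma=%.4f\n",
                    i, t, alphas_cumprod[t], compvis_sigmas[t]);
    }
}
```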
