Skip to content

Commit 701573d

Browse files
committed
feat(tx): streaming ddim
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent b9b2d91 commit 701573d

File tree

2 files changed

+160
-0
lines changed

2 files changed

+160
-0
lines changed

denoiser.hpp

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1835,6 +1835,163 @@ class LCMSampler : public Sampler {
18351835
}
18361836
};
18371837

1838+
class DDIMTrailingSampler : public Sampler {
1839+
// DDIM itself needs alphas_cumprod (DDPM, Ho et al.,
1840+
// arXiv:2006.11239 [cs.LG] with k-diffusion's start and
1841+
// end beta) (which unfortunately k-diffusion's data
1842+
// structure hides from the denoiser), and the sigmas are
1843+
// also needed to invert the behavior of CompVisDenoiser
1844+
// (k-diffusion's LMSDiscreteScheduler)
1845+
private:
1846+
std::vector<double> alphas_cumprod = {};
1847+
std::vector<double> compvis_sigmas = {};
1848+
struct ggml_tensor* noise = nullptr;
1849+
1850+
public:
1851+
DDIMTrailingSampler() {
1852+
alphas_cumprod.reserve(TIMESTEPS);
1853+
compvis_sigmas.reserve(TIMESTEPS);
1854+
for (int i = 0; i < TIMESTEPS; i++) {
1855+
alphas_cumprod[i] =
1856+
(i == 0 ? 1.0f : alphas_cumprod[i - 1]) *
1857+
(1.0f -
1858+
std::pow(sqrtf(0.00085f) +
1859+
(sqrtf(0.0120f) - sqrtf(0.00085f)) *
1860+
((float)i / (TIMESTEPS - 1)),
1861+
2));
1862+
compvis_sigmas[i] =
1863+
std::sqrt((1 - alphas_cumprod[i]) /
1864+
alphas_cumprod[i]);
1865+
}
1866+
};
1867+
~DDIMTrailingSampler() {
1868+
alphas_cumprod.clear();
1869+
compvis_sigmas.clear();
1870+
delete noise;
1871+
}
1872+
1873+
void sample(denoise_cb_t model, ggml_context* work_ctx, ggml_tensor* x, std::vector<float> sigmas, std::shared_ptr<RNG> rng, int i) override {
1874+
const int steps = sigmas.size() - 1;
1875+
// The "trailing" DDIM timestep, see S. Lin et al.,
1876+
// "Common Diffusion Noise Schedules and Sample Steps
1877+
// are Flawed", arXiv:2305.08891 [cs], p. 4, Table
1878+
// 2. Most variables below follow Diffusers naming.
1879+
int timestep = roundf(TIMESTEPS - i * ((float)TIMESTEPS / steps)) - 1;
1880+
int prev_timestep = timestep - TIMESTEPS / steps;
1881+
// The sigma here is chosen to cause the
1882+
// CompVisDenoiser to produce t = timestep
1883+
float sigma = compvis_sigmas[timestep];
1884+
if (i == 0) {
1885+
// The function add_noise intializes x to
1886+
// Diffusers' latents * sigma (as in Diffusers'
1887+
// pipeline) or sample * sigma (Diffusers'
1888+
// scheduler), where this sigma = init_noise_sigma
1889+
// in Diffusers. For DDPM and DDIM however,
1890+
// init_noise_sigma = 1. But the k-diffusion
1891+
// model() also evaluates F_theta(c_in(sigma) x;
1892+
// ...) instead of the bare U-net F_theta, with
1893+
// c_in = 1 / sqrt(sigma^2 + 1), as defined in
1894+
// T. Karras et al., "Elucidating the Design Space
1895+
// of Diffusion-Based Generative Models",
1896+
// arXiv:2206.00364 [cs.CV], p. 3, Table 1. Hence
1897+
// the first call has to be prescaled as x <- x /
1898+
// (c_in * sigma) with the k-diffusion pipeline
1899+
// and CompVisDenoiser.
1900+
float* vec_x = (float*)x->data;
1901+
for (int j = 0; j < ggml_nelements(x); j++) {
1902+
vec_x[j] *= std::sqrt(sigma * sigma + 1) / sigma;
1903+
}
1904+
} else {
1905+
// For the subsequent steps after the first one,
1906+
// at this point x = latents (pipeline) or x =
1907+
// sample (scheduler), and needs to be prescaled
1908+
// with x <- latents / c_in to compensate for
1909+
// model() applying the scale c_in before the
1910+
// U-net F_theta
1911+
float* vec_x = (float*)x->data;
1912+
for (int j = 0; j < ggml_nelements(x); j++) {
1913+
vec_x[j] *= std::sqrt(sigma * sigma + 1);
1914+
}
1915+
}
1916+
// Note model() is the D(x, sigma) as defined in
1917+
// T. Karras et al., arXiv:2206.00364, p. 3, Table 1
1918+
// and p. 8 (7)
1919+
struct ggml_tensor* noise_pred = model(x, sigma, i + 1);
1920+
// Here noise_pred is still the k-diffusion denoiser
1921+
// output, not the U-net output F_theta(c_in(sigma) x;
1922+
// ...) in Karras et al. (2022), whereas Diffusers'
1923+
// noise_pred is F_theta(...). Recover the actual
1924+
// noise_pred, which is also referred to as the
1925+
// "Karras ODE derivative" d or d_cur in several
1926+
// samplers above.
1927+
{
1928+
float* vec_x = (float*)x->data;
1929+
float* vec_noise_pred = (float*)noise_pred->data;
1930+
for (int j = 0; j < ggml_nelements(x); j++) {
1931+
vec_noise_pred[j] = (vec_x[j] - vec_noise_pred[j]) * (1 / sigma);
1932+
}
1933+
}
1934+
// 2. compute alphas, betas
1935+
float alpha_prod_t = alphas_cumprod[timestep];
1936+
// Note final_alpha_cumprod = alphas_cumprod[0]
1937+
float alpha_prod_t_prev = prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0];
1938+
float beta_prod_t = 1 - alpha_prod_t;
1939+
// 3. compute predicted original sample from predicted
1940+
// noise also called "predicted x_0" of formula (12)
1941+
// from https://arxiv.org/pdf/2010.02502.pdf
1942+
struct ggml_tensor* pred_original_sample =
1943+
ggml_dup_tensor(work_ctx, x);
1944+
{
1945+
float* vec_x = (float*)x->data;
1946+
float* vec_noise_pred = (float*)noise_pred->data;
1947+
float* vec_pred_original_sample = (float*)pred_original_sample->data;
1948+
// Note the substitution of latents or sample = x
1949+
// * c_in = x / sqrt(sigma^2 + 1)
1950+
for (int j = 0; j < ggml_nelements(x); j++) {
1951+
vec_pred_original_sample[j] = (vec_x[j] / std::sqrt(sigma * sigma + 1) - std::sqrt(beta_prod_t) * vec_noise_pred[j]) * (1 / std::sqrt(alpha_prod_t));
1952+
}
1953+
}
1954+
// Assuming the "epsilon" prediction type, where below
1955+
// pred_epsilon = noise_pred is inserted, and is not
1956+
// defined/copied explicitly.
1957+
//
1958+
// 5. compute variance: "sigma_t(eta)" -> see formula
1959+
// (16)
1960+
//
1961+
// sigma_t = sqrt((1 - alpha_t-1)/(1 - alpha_t)) *
1962+
// sqrt(1 - alpha_t/alpha_t-1)
1963+
float beta_prod_t_prev = 1 - alpha_prod_t_prev;
1964+
float variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev);
1965+
float std_dev_t = 0 * std::sqrt(variance);
1966+
// 6. compute "direction pointing to x_t" of formula
1967+
// (12) from https://arxiv.org/pdf/2010.02502.pdf
1968+
struct ggml_tensor* pred_sample_direction = ggml_dup_tensor(work_ctx, noise_pred);
1969+
{
1970+
float* vec_noise_pred = (float*)noise_pred->data;
1971+
float* vec_pred_sample_direction = (float*)pred_sample_direction->data;
1972+
for (int j = 0; j < ggml_nelements(x); j++) {
1973+
vec_pred_sample_direction[j] = std::sqrt(1 - alpha_prod_t_prev - std::pow(std_dev_t, 2)) * vec_noise_pred[j];
1974+
}
1975+
}
1976+
// 7. compute x_t without "random noise" of formula
1977+
// (12) from https://arxiv.org/pdf/2010.02502.pdf
1978+
{
1979+
float* vec_pred_original_sample = (float*)pred_original_sample->data;
1980+
float* vec_pred_sample_direction = (float*)pred_sample_direction->data;
1981+
float* vec_x = (float*)x->data;
1982+
for (int j = 0; j < ggml_nelements(x); j++) {
1983+
vec_x[j] = std::sqrt(alpha_prod_t_prev) * vec_pred_original_sample[j] + vec_pred_sample_direction[j];
1984+
}
1985+
}
1986+
// See the note above: x = latents or sample here, and
1987+
// is not scaled by the c_in. For the final output
1988+
// this is correct, but for subsequent iterations, x
1989+
// needs to be prescaled again, since k-diffusion's
1990+
// model() differes from the bare U-net F_theta by the
1991+
// factor c_in.
1992+
}
1993+
};
1994+
18381995
std::shared_ptr<Sampler> get_sampler(sample_method_t method) {
18391996
switch (method) {
18401997
case EULER_A:
@@ -1857,6 +2014,8 @@ std::shared_ptr<Sampler> get_sampler(sample_method_t method) {
18572014
return std::make_shared<IPNDMVSampler>();
18582015
case LCM:
18592016
return std::make_shared<LCMSampler>();
2017+
case DDIM_TRAILING:
2018+
return std::make_shared<DDIMTrailingSampler>();
18602019
default:
18612020
LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);
18622021
abort();

stable-diffusion.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ static const char* sample_methods_argument_str[] = {
8484
"ipndm",
8585
"ipndm_v",
8686
"lcm",
87+
"ddim_trailing",
8788
};
8889

8990
sample_method_t sd_argument_to_sample_method(const char* str) {

0 commit comments

Comments
 (0)