diff --git a/denoiser.hpp b/denoiser.hpp index 66799109..ee4ae517 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -168,24 +168,21 @@ struct AYSSchedule : SigmaSchedule { std::vector inputs; std::vector results(n + 1); - switch (version) { - case VERSION_SD2: /* fallthrough */ - LOG_WARN("AYS not designed for SD2.X models"); - case VERSION_SD1: - LOG_INFO("AYS using SD1.5 noise levels"); - inputs = noise_levels[0]; - break; - case VERSION_SDXL: - LOG_INFO("AYS using SDXL noise levels"); - inputs = noise_levels[1]; - break; - case VERSION_SVD: - LOG_INFO("AYS using SVD noise levels"); - inputs = noise_levels[2]; - break; - default: - LOG_ERROR("Version not compatable with AYS scheduler"); - return results; + if (sd_version_is_sd2((SDVersion)version)) { + LOG_WARN("AYS not designed for SD2.X models"); + } /* fallthrough */ + else if (sd_version_is_sd1((SDVersion)version)) { + LOG_INFO("AYS using SD1.5 noise levels"); + inputs = noise_levels[0]; + } else if (sd_version_is_sdxl((SDVersion)version)) { + LOG_INFO("AYS using SDXL noise levels"); + inputs = noise_levels[1]; + } else if (version == VERSION_SVD) { + LOG_INFO("AYS using SVD noise levels"); + inputs = noise_levels[2]; + } else { + LOG_ERROR("Version not compatable with AYS scheduler"); + return results; } /* Stretches those pre-calculated reference levels out to the desired @@ -346,6 +343,31 @@ struct CompVisVDenoiser : public CompVisDenoiser { } }; +struct EDMVDenoiser : public CompVisVDenoiser { + float min_sigma = 0.002; + float max_sigma = 120.0; + + EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0) : min_sigma(min_sigma), max_sigma(max_sigma) { + schedule = std::make_shared(); + } + + float t_to_sigma(float t) { + return std::exp(t * 4/(float)TIMESTEPS); + } + + float sigma_to_t(float s) { + return 0.25 * std::log(s); + } + + float sigma_min() { + return min_sigma; + } + + float sigma_max() { + return max_sigma; + } +}; + float time_snr_shift(float alpha, float t) { if (alpha == 1.0f) { return t; @@ -1019,7 +1041,7 @@ static void sample_k_diffusion(sample_method_t method, // also needed to invert the behavior of CompVisDenoiser // (k-diffusion's LMSDiscreteScheduler) float beta_start = 0.00085f; - float beta_end = 0.0120f; + float beta_end = 0.0120f; std::vector alphas_cumprod; std::vector compvis_sigmas; @@ -1030,8 +1052,9 @@ static void sample_k_diffusion(sample_method_t method, (i == 0 ? 1.0f : alphas_cumprod[i - 1]) * (1.0f - std::pow(sqrtf(beta_start) + - (sqrtf(beta_end) - sqrtf(beta_start)) * - ((float)i / (TIMESTEPS - 1)), 2)); + (sqrtf(beta_end) - sqrtf(beta_start)) * + ((float)i / (TIMESTEPS - 1)), + 2)); compvis_sigmas[i] = std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]); @@ -1061,7 +1084,8 @@ static void sample_k_diffusion(sample_method_t method, // - pred_prev_sample -> "x_t-1" int timestep = roundf(TIMESTEPS - - i * ((float)TIMESTEPS / steps)) - 1; + i * ((float)TIMESTEPS / steps)) - + 1; // 1. get previous step value (=t-1) int prev_timestep = timestep - TIMESTEPS / steps; // The sigma here is chosen to cause the @@ -1086,10 +1110,9 @@ static void sample_k_diffusion(sample_method_t method, float* vec_x = (float*)x->data; for (int j = 0; j < ggml_nelements(x); j++) { vec_x[j] *= std::sqrt(sigma * sigma + 1) / - sigma; + sigma; } - } - else { + } else { // For the subsequent steps after the first one, // at this point x = latents or x = sample, and // needs to be prescaled with x <- sample / c_in @@ -1127,9 +1150,8 @@ static void sample_k_diffusion(sample_method_t method, float alpha_prod_t = alphas_cumprod[timestep]; // Note final_alpha_cumprod = alphas_cumprod[0] due to // trailing timestep spacing - float alpha_prod_t_prev = prev_timestep >= 0 ? - alphas_cumprod[prev_timestep] : alphas_cumprod[0]; - float beta_prod_t = 1 - alpha_prod_t; + float alpha_prod_t_prev = prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]; + float beta_prod_t = 1 - alpha_prod_t; // 3. compute predicted original sample from predicted // noise also called "predicted x_0" of formula (12) // from https://arxiv.org/pdf/2010.02502.pdf @@ -1145,7 +1167,7 @@ static void sample_k_diffusion(sample_method_t method, vec_pred_original_sample[j] = (vec_x[j] / std::sqrt(sigma * sigma + 1) - std::sqrt(beta_prod_t) * - vec_model_output[j]) * + vec_model_output[j]) * (1 / std::sqrt(alpha_prod_t)); } } @@ -1159,8 +1181,8 @@ static void sample_k_diffusion(sample_method_t method, // sigma_t = sqrt((1 - alpha_t-1)/(1 - alpha_t)) * // sqrt(1 - alpha_t/alpha_t-1) float beta_prod_t_prev = 1 - alpha_prod_t_prev; - float variance = (beta_prod_t_prev / beta_prod_t) * - (1 - alpha_prod_t / alpha_prod_t_prev); + float variance = (beta_prod_t_prev / beta_prod_t) * + (1 - alpha_prod_t / alpha_prod_t_prev); float std_dev_t = eta * std::sqrt(variance); // 6. compute "direction pointing to x_t" of formula // (12) from https://arxiv.org/pdf/2010.02502.pdf @@ -1179,8 +1201,8 @@ static void sample_k_diffusion(sample_method_t method, std::pow(std_dev_t, 2)) * vec_model_output[j]; vec_x[j] = std::sqrt(alpha_prod_t_prev) * - vec_pred_original_sample[j] + - pred_sample_direction; + vec_pred_original_sample[j] + + pred_sample_direction; } } if (eta > 0) { @@ -1208,7 +1230,7 @@ static void sample_k_diffusion(sample_method_t method, // by Semi-Linear Consistency Function with Trajectory // Mapping", arXiv:2402.19159 [cs.CV] float beta_start = 0.00085f; - float beta_end = 0.0120f; + float beta_end = 0.0120f; std::vector alphas_cumprod; std::vector compvis_sigmas; @@ -1219,8 +1241,9 @@ static void sample_k_diffusion(sample_method_t method, (i == 0 ? 1.0f : alphas_cumprod[i - 1]) * (1.0f - std::pow(sqrtf(beta_start) + - (sqrtf(beta_end) - sqrtf(beta_start)) * - ((float)i / (TIMESTEPS - 1)), 2)); + (sqrtf(beta_end) - sqrtf(beta_start)) * + ((float)i / (TIMESTEPS - 1)), + 2)); compvis_sigmas[i] = std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]); @@ -1235,13 +1258,10 @@ static void sample_k_diffusion(sample_method_t method, for (int i = 0; i < steps; i++) { // Analytic form for TCD timesteps int timestep = TIMESTEPS - 1 - - (TIMESTEPS / original_steps) * - (int)floor(i * ((float)original_steps / steps)); + (TIMESTEPS / original_steps) * + (int)floor(i * ((float)original_steps / steps)); // 1. get previous step value - int prev_timestep = i >= steps - 1 ? 0 : - TIMESTEPS - 1 - (TIMESTEPS / original_steps) * - (int)floor((i + 1) * - ((float)original_steps / steps)); + int prev_timestep = i >= steps - 1 ? 0 : TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor((i + 1) * ((float)original_steps / steps)); // Here timestep_s is tau_n' in Algorithm 4. The _s // notation appears to be that from C. Lu, // "DPM-Solver: A Fast ODE Solver for Diffusion @@ -1258,10 +1278,9 @@ static void sample_k_diffusion(sample_method_t method, float* vec_x = (float*)x->data; for (int j = 0; j < ggml_nelements(x); j++) { vec_x[j] *= std::sqrt(sigma * sigma + 1) / - sigma; + sigma; } - } - else { + } else { float* vec_x = (float*)x->data; for (int j = 0; j < ggml_nelements(x); j++) { vec_x[j] *= std::sqrt(sigma * sigma + 1); @@ -1294,15 +1313,14 @@ static void sample_k_diffusion(sample_method_t method, // DPM-Solver. In fact, we have alpha_{t_n} = // \sqrt{\hat{alpha_n}}, [...]" float alpha_prod_t = alphas_cumprod[timestep]; - float beta_prod_t = 1 - alpha_prod_t; + float beta_prod_t = 1 - alpha_prod_t; // Note final_alpha_cumprod = alphas_cumprod[0] since // TCD is always "trailing" - float alpha_prod_t_prev = prev_timestep >= 0 ? - alphas_cumprod[prev_timestep] : alphas_cumprod[0]; + float alpha_prod_t_prev = prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]; // The subscript _s are the only portion in this // section (2) unique to TCD float alpha_prod_s = alphas_cumprod[timestep_s]; - float beta_prod_s = 1 - alpha_prod_s; + float beta_prod_s = 1 - alpha_prod_s; // 3. Compute the predicted noised sample x_s based on // the model parameterization // @@ -1317,7 +1335,7 @@ static void sample_k_diffusion(sample_method_t method, vec_pred_original_sample[j] = (vec_x[j] / std::sqrt(sigma * sigma + 1) - std::sqrt(beta_prod_t) * - vec_model_output[j]) * + vec_model_output[j]) * (1 / std::sqrt(alpha_prod_t)); } } @@ -1339,9 +1357,9 @@ static void sample_k_diffusion(sample_method_t method, // pred_epsilon = model_output vec_x[j] = std::sqrt(alpha_prod_s) * - vec_pred_original_sample[j] + + vec_pred_original_sample[j] + std::sqrt(beta_prod_s) * - vec_model_output[j]; + vec_model_output[j]; } } // 4. Sample and inject noise z ~ N(0, I) for @@ -1357,7 +1375,7 @@ static void sample_k_diffusion(sample_method_t method, // In this case, x is still pred_noised_sample, // continue in-place ggml_tensor_set_f32_randn(noise, rng); - float* vec_x = (float*)x->data; + float* vec_x = (float*)x->data; float* vec_noise = (float*)noise->data; for (int j = 0; j < ggml_nelements(x); j++) { // Corresponding to (35) in Zheng et @@ -1366,10 +1384,10 @@ static void sample_k_diffusion(sample_method_t method, vec_x[j] = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * - vec_x[j] + + vec_x[j] + std::sqrt(1 - alpha_prod_t_prev / - alpha_prod_s) * - vec_noise[j]; + alpha_prod_s) * + vec_noise[j]; } } } diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index e38a6101..a366b148 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -104,6 +104,9 @@ class StableDiffusionGGML { bool vae_tiling = false; bool stacked_id = false; + bool is_using_v_parameterization = false; + bool is_using_edm_v_parameterization = false; + std::map tensors; std::string lora_model_dir; @@ -522,12 +525,17 @@ class StableDiffusionGGML { LOG_INFO("loading model from '%s' completed, taking %.2fs", model_path.c_str(), (t1 - t0) * 1.0f / 1000); // check is_using_v_parameterization_for_sd2 - bool is_using_v_parameterization = false; + if (sd_version_is_sd2(version)) { if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) { is_using_v_parameterization = true; } } else if (sd_version_is_sdxl(version)) { + if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) { + // CosXL models + // TODO: get sigma_min and sigma_max values from file + is_using_edm_v_parameterization = true; + } if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) { is_using_v_parameterization = true; } @@ -552,6 +560,9 @@ class StableDiffusionGGML { } else if (is_using_v_parameterization) { LOG_INFO("running in v-prediction mode"); denoiser = std::make_shared(); + } else if (is_using_edm_v_parameterization) { + LOG_INFO("running in v-prediction EDM mode"); + denoiser = std::make_shared(); } else { LOG_INFO("running in eps-prediction mode"); } @@ -1363,7 +1374,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, SDCondition uncond; if (cfg_scale != 1.0) { bool force_zero_embeddings = false; - if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0) { + if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) { force_zero_embeddings = true; } uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,