diff --git a/denoiser.hpp b/denoiser.hpp index 66799109..ee4ae517 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -168,24 +168,21 @@ struct AYSSchedule : SigmaSchedule { std::vector inputs; std::vector results(n + 1); - switch (version) { - case VERSION_SD2: /* fallthrough */ - LOG_WARN("AYS not designed for SD2.X models"); - case VERSION_SD1: - LOG_INFO("AYS using SD1.5 noise levels"); - inputs = noise_levels[0]; - break; - case VERSION_SDXL: - LOG_INFO("AYS using SDXL noise levels"); - inputs = noise_levels[1]; - break; - case VERSION_SVD: - LOG_INFO("AYS using SVD noise levels"); - inputs = noise_levels[2]; - break; - default: - LOG_ERROR("Version not compatable with AYS scheduler"); - return results; + if (sd_version_is_sd2((SDVersion)version)) { + LOG_WARN("AYS not designed for SD2.X models"); + } /* fallthrough */ + else if (sd_version_is_sd1((SDVersion)version)) { + LOG_INFO("AYS using SD1.5 noise levels"); + inputs = noise_levels[0]; + } else if (sd_version_is_sdxl((SDVersion)version)) { + LOG_INFO("AYS using SDXL noise levels"); + inputs = noise_levels[1]; + } else if (version == VERSION_SVD) { + LOG_INFO("AYS using SVD noise levels"); + inputs = noise_levels[2]; + } else { + LOG_ERROR("Version not compatable with AYS scheduler"); + return results; } /* Stretches those pre-calculated reference levels out to the desired @@ -346,6 +343,31 @@ struct CompVisVDenoiser : public CompVisDenoiser { } }; +struct EDMVDenoiser : public CompVisVDenoiser { + float min_sigma = 0.002; + float max_sigma = 120.0; + + EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0) : min_sigma(min_sigma), max_sigma(max_sigma) { + schedule = std::make_shared(); + } + + float t_to_sigma(float t) { + return std::exp(t * 4/(float)TIMESTEPS); + } + + float sigma_to_t(float s) { + return 0.25 * std::log(s); + } + + float sigma_min() { + return min_sigma; + } + + float sigma_max() { + return max_sigma; + } +}; + float time_snr_shift(float alpha, float t) { if (alpha == 1.0f) { return t; @@ -1019,7 +1041,7 @@ static void sample_k_diffusion(sample_method_t method, // also needed to invert the behavior of CompVisDenoiser // (k-diffusion's LMSDiscreteScheduler) float beta_start = 0.00085f; - float beta_end = 0.0120f; + float beta_end = 0.0120f; std::vector alphas_cumprod; std::vector compvis_sigmas; @@ -1030,8 +1052,9 @@ static void sample_k_diffusion(sample_method_t method, (i == 0 ? 1.0f : alphas_cumprod[i - 1]) * (1.0f - std::pow(sqrtf(beta_start) + - (sqrtf(beta_end) - sqrtf(beta_start)) * - ((float)i / (TIMESTEPS - 1)), 2)); + (sqrtf(beta_end) - sqrtf(beta_start)) * + ((float)i / (TIMESTEPS - 1)), + 2)); compvis_sigmas[i] = std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]); @@ -1061,7 +1084,8 @@ static void sample_k_diffusion(sample_method_t method, // - pred_prev_sample -> "x_t-1" int timestep = roundf(TIMESTEPS - - i * ((float)TIMESTEPS / steps)) - 1; + i * ((float)TIMESTEPS / steps)) - + 1; // 1. get previous step value (=t-1) int prev_timestep = timestep - TIMESTEPS / steps; // The sigma here is chosen to cause the @@ -1086,10 +1110,9 @@ static void sample_k_diffusion(sample_method_t method, float* vec_x = (float*)x->data; for (int j = 0; j < ggml_nelements(x); j++) { vec_x[j] *= std::sqrt(sigma * sigma + 1) / - sigma; + sigma; } - } - else { + } else { // For the subsequent steps after the first one, // at this point x = latents or x = sample, and // needs to be prescaled with x <- sample / c_in @@ -1127,9 +1150,8 @@ static void sample_k_diffusion(sample_method_t method, float alpha_prod_t = alphas_cumprod[timestep]; // Note final_alpha_cumprod = alphas_cumprod[0] due to // trailing timestep spacing - float alpha_prod_t_prev = prev_timestep >= 0 ? - alphas_cumprod[prev_timestep] : alphas_cumprod[0]; - float beta_prod_t = 1 - alpha_prod_t; + float alpha_prod_t_prev = prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]; + float beta_prod_t = 1 - alpha_prod_t; // 3. compute predicted original sample from predicted // noise also called "predicted x_0" of formula (12) // from https://arxiv.org/pdf/2010.02502.pdf @@ -1145,7 +1167,7 @@ static void sample_k_diffusion(sample_method_t method, vec_pred_original_sample[j] = (vec_x[j] / std::sqrt(sigma * sigma + 1) - std::sqrt(beta_prod_t) * - vec_model_output[j]) * + vec_model_output[j]) * (1 / std::sqrt(alpha_prod_t)); } } @@ -1159,8 +1181,8 @@ static void sample_k_diffusion(sample_method_t method, // sigma_t = sqrt((1 - alpha_t-1)/(1 - alpha_t)) * // sqrt(1 - alpha_t/alpha_t-1) float beta_prod_t_prev = 1 - alpha_prod_t_prev; - float variance = (beta_prod_t_prev / beta_prod_t) * - (1 - alpha_prod_t / alpha_prod_t_prev); + float variance = (beta_prod_t_prev / beta_prod_t) * + (1 - alpha_prod_t / alpha_prod_t_prev); float std_dev_t = eta * std::sqrt(variance); // 6. compute "direction pointing to x_t" of formula // (12) from https://arxiv.org/pdf/2010.02502.pdf @@ -1179,8 +1201,8 @@ static void sample_k_diffusion(sample_method_t method, std::pow(std_dev_t, 2)) * vec_model_output[j]; vec_x[j] = std::sqrt(alpha_prod_t_prev) * - vec_pred_original_sample[j] + - pred_sample_direction; + vec_pred_original_sample[j] + + pred_sample_direction; } } if (eta > 0) { @@ -1208,7 +1230,7 @@ static void sample_k_diffusion(sample_method_t method, // by Semi-Linear Consistency Function with Trajectory // Mapping", arXiv:2402.19159 [cs.CV] float beta_start = 0.00085f; - float beta_end = 0.0120f; + float beta_end = 0.0120f; std::vector alphas_cumprod; std::vector compvis_sigmas; @@ -1219,8 +1241,9 @@ static void sample_k_diffusion(sample_method_t method, (i == 0 ? 1.0f : alphas_cumprod[i - 1]) * (1.0f - std::pow(sqrtf(beta_start) + - (sqrtf(beta_end) - sqrtf(beta_start)) * - ((float)i / (TIMESTEPS - 1)), 2)); + (sqrtf(beta_end) - sqrtf(beta_start)) * + ((float)i / (TIMESTEPS - 1)), + 2)); compvis_sigmas[i] = std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]); @@ -1235,13 +1258,10 @@ static void sample_k_diffusion(sample_method_t method, for (int i = 0; i < steps; i++) { // Analytic form for TCD timesteps int timestep = TIMESTEPS - 1 - - (TIMESTEPS / original_steps) * - (int)floor(i * ((float)original_steps / steps)); + (TIMESTEPS / original_steps) * + (int)floor(i * ((float)original_steps / steps)); // 1. get previous step value - int prev_timestep = i >= steps - 1 ? 0 : - TIMESTEPS - 1 - (TIMESTEPS / original_steps) * - (int)floor((i + 1) * - ((float)original_steps / steps)); + int prev_timestep = i >= steps - 1 ? 0 : TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor((i + 1) * ((float)original_steps / steps)); // Here timestep_s is tau_n' in Algorithm 4. The _s // notation appears to be that from C. Lu, // "DPM-Solver: A Fast ODE Solver for Diffusion @@ -1258,10 +1278,9 @@ static void sample_k_diffusion(sample_method_t method, float* vec_x = (float*)x->data; for (int j = 0; j < ggml_nelements(x); j++) { vec_x[j] *= std::sqrt(sigma * sigma + 1) / - sigma; + sigma; } - } - else { + } else { float* vec_x = (float*)x->data; for (int j = 0; j < ggml_nelements(x); j++) { vec_x[j] *= std::sqrt(sigma * sigma + 1); @@ -1294,15 +1313,14 @@ static void sample_k_diffusion(sample_method_t method, // DPM-Solver. In fact, we have alpha_{t_n} = // \sqrt{\hat{alpha_n}}, [...]" float alpha_prod_t = alphas_cumprod[timestep]; - float beta_prod_t = 1 - alpha_prod_t; + float beta_prod_t = 1 - alpha_prod_t; // Note final_alpha_cumprod = alphas_cumprod[0] since // TCD is always "trailing" - float alpha_prod_t_prev = prev_timestep >= 0 ? - alphas_cumprod[prev_timestep] : alphas_cumprod[0]; + float alpha_prod_t_prev = prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]; // The subscript _s are the only portion in this // section (2) unique to TCD float alpha_prod_s = alphas_cumprod[timestep_s]; - float beta_prod_s = 1 - alpha_prod_s; + float beta_prod_s = 1 - alpha_prod_s; // 3. Compute the predicted noised sample x_s based on // the model parameterization // @@ -1317,7 +1335,7 @@ static void sample_k_diffusion(sample_method_t method, vec_pred_original_sample[j] = (vec_x[j] / std::sqrt(sigma * sigma + 1) - std::sqrt(beta_prod_t) * - vec_model_output[j]) * + vec_model_output[j]) * (1 / std::sqrt(alpha_prod_t)); } } @@ -1339,9 +1357,9 @@ static void sample_k_diffusion(sample_method_t method, // pred_epsilon = model_output vec_x[j] = std::sqrt(alpha_prod_s) * - vec_pred_original_sample[j] + + vec_pred_original_sample[j] + std::sqrt(beta_prod_s) * - vec_model_output[j]; + vec_model_output[j]; } } // 4. Sample and inject noise z ~ N(0, I) for @@ -1357,7 +1375,7 @@ static void sample_k_diffusion(sample_method_t method, // In this case, x is still pred_noised_sample, // continue in-place ggml_tensor_set_f32_randn(noise, rng); - float* vec_x = (float*)x->data; + float* vec_x = (float*)x->data; float* vec_noise = (float*)noise->data; for (int j = 0; j < ggml_nelements(x); j++) { // Corresponding to (35) in Zheng et @@ -1366,10 +1384,10 @@ static void sample_k_diffusion(sample_method_t method, vec_x[j] = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * - vec_x[j] + + vec_x[j] + std::sqrt(1 - alpha_prod_t_prev / - alpha_prod_s) * - vec_noise[j]; + alpha_prod_s) * + vec_noise[j]; } } } diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index af6b2bbd..3652177c 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -92,15 +92,16 @@ struct SDParams { std::string prompt; std::string negative_prompt; - float min_cfg = 1.0f; - float cfg_scale = 7.0f; - float guidance = 3.5f; - float eta = 0.f; - float style_ratio = 20.f; - int clip_skip = -1; // <= 0 represents unspecified - int width = 512; - int height = 512; - int batch_count = 1; + float min_cfg = 1.0f; + float cfg_scale = 7.0f; + float img_cfg_scale = INFINITY; + float guidance = 3.5f; + float eta = 0.f; + float style_ratio = 20.f; + int clip_skip = -1; // <= 0 represents unspecified + int width = 512; + int height = 512; + int batch_count = 1; int video_frames = 6; int motion_bucket_id = 127; @@ -163,6 +164,7 @@ void print_params(SDParams params) { printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); printf(" min_cfg: %.2f\n", params.min_cfg); printf(" cfg_scale: %.2f\n", params.cfg_scale); + printf(" img_cfg_scale: %.2f\n", params.img_cfg_scale); printf(" slg_scale: %.2f\n", params.slg_scale); printf(" guidance: %.2f\n", params.guidance); printf(" eta: %.2f\n", params.eta); @@ -212,7 +214,8 @@ void print_usage(int argc, const char* argv[]) { printf(" -p, --prompt [PROMPT] the prompt to render\n"); printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n"); printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n"); - printf(" --guidance SCALE guidance scale for img2img (default: 3.5)\n"); + printf(" --img-cfg-scale SCALE image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)\n"); + printf(" --guidance SCALE distilled guidance scale for models with guidance input (default: 3.5)\n"); printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n"); printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n"); printf(" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n"); @@ -439,6 +442,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.cfg_scale = std::stof(argv[i]); + } else if (arg == "--img-cfg-scale") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.img_cfg_scale = std::stof(argv[i]); } else if (arg == "--guidance") { if (++i >= argc) { invalid_arg = true; @@ -698,6 +707,10 @@ void parse_args(int argc, const char** argv, SDParams& params) { params.output_path = "output.gguf"; } } + + if (!isfinite(params.img_cfg_scale)) { + params.img_cfg_scale = params.cfg_scale; + } } static std::string sd_basename(const std::string& path) { @@ -792,6 +805,18 @@ int main(int argc, const char* argv[]) { parse_args(argc, argv, params); + sd_guidance_params_t guidance_params = {params.cfg_scale, + params.img_cfg_scale, + params.min_cfg, + params.guidance, + { + params.skip_layers.data(), + params.skip_layers.size(), + params.skip_layer_start, + params.skip_layer_end, + params.slg_scale, + }}; + sd_set_log_callback(sd_log_cb, (void*)¶ms); if (params.verbose) { @@ -949,8 +974,7 @@ int main(int argc, const char* argv[]) { params.prompt.c_str(), params.negative_prompt.c_str(), params.clip_skip, - params.cfg_scale, - params.guidance, + guidance_params, params.eta, params.width, params.height, @@ -962,12 +986,7 @@ int main(int argc, const char* argv[]) { params.control_strength, params.style_ratio, params.normalize_input, - params.input_id_images_path.c_str(), - params.skip_layers.data(), - params.skip_layers.size(), - params.slg_scale, - params.skip_layer_start, - params.skip_layer_end); + params.input_id_images_path.c_str()); } else { sd_image_t input_image = {(uint32_t)params.width, (uint32_t)params.height, @@ -983,8 +1002,7 @@ int main(int argc, const char* argv[]) { params.motion_bucket_id, params.fps, params.augmentation_level, - params.min_cfg, - params.cfg_scale, + guidance_params, params.sample_method, params.sample_steps, params.strength, @@ -1017,8 +1035,7 @@ int main(int argc, const char* argv[]) { params.prompt.c_str(), params.negative_prompt.c_str(), params.clip_skip, - params.cfg_scale, - params.guidance, + guidance_params, params.eta, params.width, params.height, @@ -1031,12 +1048,7 @@ int main(int argc, const char* argv[]) { params.control_strength, params.style_ratio, params.normalize_input, - params.input_id_images_path.c_str(), - params.skip_layers.data(), - params.skip_layers.size(), - params.slg_scale, - params.skip_layer_start, - params.skip_layer_end); + params.input_id_images_path.c_str()); } } @@ -1075,11 +1087,11 @@ int main(int argc, const char* argv[]) { std::string dummy_name, ext, lc_ext; bool is_jpg; - size_t last = params.output_path.find_last_of("."); + size_t last = params.output_path.find_last_of("."); size_t last_path = std::min(params.output_path.find_last_of("/"), params.output_path.find_last_of("\\")); - if (last != std::string::npos // filename has extension - && (last_path == std::string::npos || last > last_path)) { + if (last != std::string::npos // filename has extension + && (last_path == std::string::npos || last > last_path)) { dummy_name = params.output_path.substr(0, last); ext = lc_ext = params.output_path.substr(last); std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower); @@ -1087,7 +1099,7 @@ int main(int argc, const char* argv[]) { } else { dummy_name = params.output_path; ext = lc_ext = ""; - is_jpg = false; + is_jpg = false; } // appending ".png" to absent or unknown extension if (!is_jpg && lc_ext != ".png") { @@ -1099,7 +1111,7 @@ int main(int argc, const char* argv[]) { continue; } std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext; - if(is_jpg) { + if (is_jpg) { stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, results[i].data, 90, get_image_params(params, params.seed + i).c_str()); printf("save result JPEG image to '%s'\n", final_image_path.c_str()); diff --git a/model.cpp b/model.cpp index 24da39f6..56a4769c 100644 --- a/model.cpp +++ b/model.cpp @@ -1539,10 +1539,14 @@ SDVersion ModelLoader::get_sd_version() { } } bool is_inpaint = input_block_weight.ne[2] == 9; + bool is_ip2p = input_block_weight.ne[2] == 8; if (is_xl) { if (is_inpaint) { return VERSION_SDXL_INPAINT; } + if (is_ip2p) { + return VERSION_SDXL_PIX2PIX; + } return VERSION_SDXL; } @@ -1558,6 +1562,9 @@ SDVersion ModelLoader::get_sd_version() { if (is_inpaint) { return VERSION_SD1_INPAINT; } + if (is_ip2p) { + return VERSION_SD1_PIX2PIX; + } return VERSION_SD1; } else if (token_embedding_weight.ne[0] == 1024) { if (is_inpaint) { diff --git a/model.h b/model.h index d7f97653..3708c585 100644 --- a/model.h +++ b/model.h @@ -12,19 +12,21 @@ #include "ggml-backend.h" #include "ggml.h" +#include "gguf.h" #include "json.hpp" #include "zip.h" -#include "gguf.h" #define SD_MAX_DIMS 5 enum SDVersion { VERSION_SD1, VERSION_SD1_INPAINT, + VERSION_SD1_PIX2PIX, VERSION_SD2, VERSION_SD2_INPAINT, VERSION_SDXL, VERSION_SDXL_INPAINT, + VERSION_SDXL_PIX2PIX, VERSION_SVD, VERSION_SD3, VERSION_FLUX, @@ -47,7 +49,7 @@ static inline bool sd_version_is_sd3(SDVersion version) { } static inline bool sd_version_is_sd1(SDVersion version) { - if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT) { + if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX) { return true; } return false; @@ -61,7 +63,7 @@ static inline bool sd_version_is_sd2(SDVersion version) { } static inline bool sd_version_is_sdxl(SDVersion version) { - if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT) { + if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX) { return true; } return false; @@ -81,6 +83,14 @@ static inline bool sd_version_is_dit(SDVersion version) { return false; } +static inline bool sd_version_is_edit(SDVersion version) { + return version == VERSION_SD1_PIX2PIX || version == VERSION_SDXL_PIX2PIX; +} + +static bool sd_version_use_concat(SDVersion version) { + return sd_version_is_edit(version) || sd_version_is_inpaint(version); +} + enum PMVersion { PM_VERSION_1, PM_VERSION_2, diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index e38a6101..47ac0b61 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -27,10 +27,12 @@ const char* model_version_to_str[] = { "SD 1.x", "SD 1.x Inpaint", + "Instruct-Pix2Pix", "SD 2.x", "SD 2.x Inpaint", "SDXL", "SDXL Inpaint", + "SDXL Instruct-Pix2Pix", "SVD", "SD3.x", "Flux", @@ -48,8 +50,7 @@ const char* sampling_methods_str[] = { "iPNDM_v", "LCM", "DDIM \"trailing\"", - "TCD" -}; + "TCD"}; /*================================================== Helper Functions ================================================*/ @@ -104,6 +105,9 @@ class StableDiffusionGGML { bool vae_tiling = false; bool stacked_id = false; + bool is_using_v_parameterization = false; + bool is_using_edm_v_parameterization = false; + std::map tensors; std::string lora_model_dir; @@ -522,12 +526,17 @@ class StableDiffusionGGML { LOG_INFO("loading model from '%s' completed, taking %.2fs", model_path.c_str(), (t1 - t0) * 1.0f / 1000); // check is_using_v_parameterization_for_sd2 - bool is_using_v_parameterization = false; + if (sd_version_is_sd2(version)) { if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) { is_using_v_parameterization = true; } } else if (sd_version_is_sdxl(version)) { + if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) { + // CosXL models + // TODO: get sigma_min and sigma_max values from file + is_using_edm_v_parameterization = true; + } if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) { is_using_v_parameterization = true; } @@ -552,6 +561,9 @@ class StableDiffusionGGML { } else if (is_using_v_parameterization) { LOG_INFO("running in v-prediction mode"); denoiser = std::make_shared(); + } else if (is_using_edm_v_parameterization) { + LOG_INFO("running in v-prediction EDM mode"); + denoiser = std::make_shared(); } else { LOG_INFO("running in eps-prediction mode"); } @@ -682,7 +694,7 @@ class StableDiffusionGGML { float curr_multiplier = kv.second; lora_state_diff[lora_name] -= curr_multiplier; } - + size_t rm = lora_state_diff.size() - lora_state.size(); if (rm != 0) { LOG_INFO("Attempting to apply %lu LoRAs (removing %lu applied LoRAs)", lora_state.size(), rm); @@ -792,19 +804,28 @@ class StableDiffusionGGML { SDCondition uncond, ggml_tensor* control_hint, float control_strength, - float min_cfg, - float cfg_scale, - float guidance, + sd_guidance_params_t guidance, float eta, sample_method_t method, const std::vector& sigmas, int start_merge_step, SDCondition id_cond, - std::vector skip_layers = {}, - float slg_scale = 0, - float skip_layer_start = 0.01, - float skip_layer_end = 0.2, - ggml_tensor* noise_mask = nullptr) { + ggml_tensor* denoise_mask = NULL) { + std::vector skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count); + + // TODO (Pix2Pix): separate image guidance params (right now it's reusing distilled guidance) + + float cfg_scale = guidance.txt_cfg; + float img_cfg_scale = guidance.img_cfg; + float slg_scale = guidance.slg.scale; + + float min_cfg = guidance.min_cfg; + + if (img_cfg_scale != cfg_scale && !sd_version_use_concat(version)) { + LOG_WARN("2-conditioning CFG is not supported with this model, disabling it for better performance..."); + img_cfg_scale = cfg_scale; + } + LOG_DEBUG("Sample"); struct ggml_init_params params; size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]); @@ -826,13 +847,15 @@ class StableDiffusionGGML { struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, noise); - bool has_unconditioned = cfg_scale != 1.0 && uncond.c_crossattn != NULL; + bool has_unconditioned = img_cfg_scale != 1.0 && uncond.c_crossattn != NULL; + bool has_img_guidance = cfg_scale != img_cfg_scale && uncond.c_crossattn != NULL; bool has_skiplayer = slg_scale != 0.0 && skip_layers.size() > 0; // denoise wrapper - struct ggml_tensor* out_cond = ggml_dup_tensor(work_ctx, x); - struct ggml_tensor* out_uncond = NULL; - struct ggml_tensor* out_skip = NULL; + struct ggml_tensor* out_cond = ggml_dup_tensor(work_ctx, x); + struct ggml_tensor* out_uncond = NULL; + struct ggml_tensor* out_skip = NULL; + struct ggml_tensor* out_img_cond = NULL; if (has_unconditioned) { out_uncond = ggml_dup_tensor(work_ctx, x); @@ -845,6 +868,9 @@ class StableDiffusionGGML { LOG_WARN("SLG is incompatible with %s models", model_version_to_str[version]); } } + if (has_img_guidance) { + out_img_cond = ggml_dup_tensor(work_ctx, x); + } struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x); auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* { @@ -862,7 +888,7 @@ class StableDiffusionGGML { float t = denoiser->sigma_to_t(sigma); std::vector timesteps_vec(x->ne[3], t); // [N, ] auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec); - std::vector guidance_vec(x->ne[3], guidance); + std::vector guidance_vec(x->ne[3], guidance.distilled_guidance); auto guidance_tensor = vector_to_ggml_tensor(work_ctx, guidance_vec); copy_ggml_tensor(noised_input, input); @@ -926,8 +952,24 @@ class StableDiffusionGGML { negative_data = (float*)out_uncond->data; } + float* img_cond_data = NULL; + if (has_img_guidance) { + diffusion_model->compute(n_threads, + noised_input, + timesteps, + uncond.c_crossattn, + cond.c_concat, + uncond.c_vector, + guidance_tensor, + -1, + controls, + control_strength, + &out_img_cond); + img_cond_data = (float*)out_img_cond->data; + } + int step_count = sigmas.size(); - bool is_skiplayer_step = has_skiplayer && step > (int)(skip_layer_start * step_count) && step < (int)(skip_layer_end * step_count); + bool is_skiplayer_step = has_skiplayer && step > (int)(guidance.slg.layer_start * step_count) && step < (int)(guidance.slg.layer_end * step_count); float* skip_layer_data = NULL; if (is_skiplayer_step) { LOG_DEBUG("Skipping layers at step %d\n", step); @@ -960,8 +1002,16 @@ class StableDiffusionGGML { int64_t i3 = i / out_cond->ne[0] * out_cond->ne[1] * out_cond->ne[2]; float scale = min_cfg + (cfg_scale - min_cfg) * (i3 * 1.0f / ne3); } else { - latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]); + if (has_img_guidance) { + latent_result = negative_data[i] + img_cfg_scale * (img_cond_data[i] - negative_data[i]) + cfg_scale * (positive_data[i] - img_cond_data[i]); + } else { + // img_cfg_scale == cfg_scale + latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]); + } } + } else if (has_img_guidance) { + // img_cfg_scale == 1 + latent_result = img_cond_data[i] + cfg_scale * (positive_data[i] - img_cond_data[i]); } if (is_skiplayer_step) { latent_result = latent_result + (positive_data[i] - skip_layer_data[i]) * slg_scale; @@ -975,10 +1025,10 @@ class StableDiffusionGGML { pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f); // LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000); } - if (noise_mask != nullptr) { + if (denoise_mask != nullptr) { for (int64_t x = 0; x < denoised->ne[0]; x++) { for (int64_t y = 0; y < denoised->ne[1]; y++) { - float mask = ggml_tensor_get_f32(noise_mask, x, y); + float mask = ggml_tensor_get_f32(denoise_mask, x, y); for (int64_t k = 0; k < denoised->ne[2]; k++) { float init = ggml_tensor_get_f32(init_latent, x, y, k); float den = ggml_tensor_get_f32(denoised, x, y, k); @@ -1035,6 +1085,30 @@ class StableDiffusionGGML { return latent; } + ggml_tensor* + get_first_stage_encoding_mode(ggml_context* work_ctx, ggml_tensor* moments) { + // ldm.modules.distributions.distributions.DiagonalGaussianDistribution.sample + ggml_tensor* latent = ggml_new_tensor_4d(work_ctx, moments->type, moments->ne[0], moments->ne[1], moments->ne[2] / 2, moments->ne[3]); + struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, latent); + ggml_tensor_set_f32_randn(noise, rng); + // noise = load_tensor_from_file(work_ctx, "noise.bin"); + { + float mean = 0; + for (int i = 0; i < latent->ne[3]; i++) { + for (int j = 0; j < latent->ne[2]; j++) { + for (int k = 0; k < latent->ne[1]; k++) { + for (int l = 0; l < latent->ne[0]; l++) { + // mode and mean are the same for gaussians + mean = ggml_tensor_get_f32(moments, l, k, j, i); + ggml_tensor_set_f32(latent, mean, l, k, j, i); + } + } + } + } + } + return latent; + } + ggml_tensor* compute_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode) { int64_t W = x->ne[0]; int64_t H = x->ne[1]; @@ -1195,8 +1269,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, std::string prompt, std::string negative_prompt, int clip_skip, - float cfg_scale, - float guidance, + sd_guidance_params_t guidance, float eta, int width, int height, @@ -1209,11 +1282,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, float style_ratio, bool normalize_input, std::string input_id_images_path, - std::vector skip_layers = {}, - float slg_scale = 0, - float skip_layer_start = 0.01, - float skip_layer_end = 0.2, - ggml_tensor* masked_image = NULL) { + ggml_tensor* concat_latent = NULL, + ggml_tensor* denoise_mask = NULL) { if (seed < 0) { // Generally, when using the provided command line, the seed is always >0. // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library @@ -1361,9 +1431,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, sd_ctx->sd->diffusion_model->get_adm_in_channels()); SDCondition uncond; - if (cfg_scale != 1.0) { + if (guidance.txt_cfg != 1.0 || sd_version_use_concat(sd_ctx->sd->version) && guidance.txt_cfg != guidance.img_cfg) { bool force_zero_embeddings = false; - if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0) { + if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) { force_zero_embeddings = true; } uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, @@ -1400,38 +1470,46 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, int W = width / 8; int H = height / 8; LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]); - ggml_tensor* noise_mask = nullptr; if (sd_version_is_inpaint(sd_ctx->sd->version)) { - if (masked_image == NULL) { - int64_t mask_channels = 1; - if (sd_ctx->sd->version == VERSION_FLUX_FILL) { - mask_channels = 8 * 8; // flatten the whole mask - } - // no mask, set the whole image as masked - masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1); - for (int64_t x = 0; x < masked_image->ne[0]; x++) { - for (int64_t y = 0; y < masked_image->ne[1]; y++) { - if (sd_ctx->sd->version == VERSION_FLUX_FILL) { - // TODO: this might be wrong - for (int64_t c = 0; c < init_latent->ne[2]; c++) { - ggml_tensor_set_f32(masked_image, 0, x, y, c); - } - for (int64_t c = init_latent->ne[2]; c < masked_image->ne[2]; c++) { - ggml_tensor_set_f32(masked_image, 1, x, y, c); - } - } else { - ggml_tensor_set_f32(masked_image, 1, x, y, 0); - for (int64_t c = 1; c < masked_image->ne[2]; c++) { - ggml_tensor_set_f32(masked_image, 0, x, y, c); - } + int64_t mask_channels = 1; + if (sd_ctx->sd->version == VERSION_FLUX_FILL) { + mask_channels = 8 * 8; // flatten the whole mask + } + auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1); + // no mask, set the whole image as masked + for (int64_t x = 0; x < empty_latent->ne[0]; x++) { + for (int64_t y = 0; y < empty_latent->ne[1]; y++) { + if (sd_ctx->sd->version == VERSION_FLUX_FILL) { + // TODO: this might be wrong + for (int64_t c = 0; c < init_latent->ne[2]; c++) { + ggml_tensor_set_f32(empty_latent, 0, x, y, c); + } + for (int64_t c = init_latent->ne[2]; c < empty_latent->ne[2]; c++) { + ggml_tensor_set_f32(empty_latent, 1, x, y, c); + } + } else { + ggml_tensor_set_f32(empty_latent, 1, x, y, 0); + for (int64_t c = 1; c < empty_latent->ne[2]; c++) { + ggml_tensor_set_f32(empty_latent, 0, x, y, c); } } } } - cond.c_concat = masked_image; - uncond.c_concat = masked_image; - } else { - noise_mask = masked_image; + if (concat_latent == NULL) { + concat_latent = empty_latent; + } + cond.c_concat = concat_latent; + uncond.c_concat = empty_latent; + denoise_mask = NULL; + } else if (sd_version_is_edit(sd_ctx->sd->version)) { + auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], init_latent->ne[3]); + ggml_set_f32(empty_latent, 0); + uncond.c_concat = empty_latent; + if (concat_latent == NULL) { + concat_latent = empty_latent; + } + cond.c_concat = concat_latent; + } for (int b = 0; b < batch_count; b++) { int64_t sampling_start = ggml_time_ms(); @@ -1451,6 +1529,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, LOG_INFO("PHOTOMAKER: start_merge_step: %d", start_merge_step); } + // Disable min_cfg + guidance.min_cfg = guidance.txt_cfg; + struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx, x_t, noise, @@ -1458,19 +1539,13 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, uncond, image_hint, control_strength, - cfg_scale, - cfg_scale, guidance, eta, sample_method, sigmas, start_merge_step, id_cond, - skip_layers, - slg_scale, - skip_layer_start, - skip_layer_end, - noise_mask); + denoise_mask); // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin"); // print_ggml_tensor(x_0); @@ -1525,8 +1600,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, const char* prompt_c_str, const char* negative_prompt_c_str, int clip_skip, - float cfg_scale, - float guidance, + sd_guidance_params_t guidance, float eta, int width, int height, @@ -1538,13 +1612,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, float control_strength, float style_ratio, bool normalize_input, - const char* input_id_images_path_c_str, - int* skip_layers = NULL, - size_t skip_layers_count = 0, - float slg_scale = 0, - float skip_layer_start = 0.01, - float skip_layer_end = 0.2) { - std::vector skip_layers_vec(skip_layers, skip_layers + skip_layers_count); + const char* input_id_images_path_c_str) { LOG_DEBUG("txt2img %dx%d", width, height); if (sd_ctx == NULL) { return NULL; @@ -1604,7 +1672,6 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, prompt_c_str, negative_prompt_c_str, clip_skip, - cfg_scale, guidance, eta, width, @@ -1617,11 +1684,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, control_strength, style_ratio, normalize_input, - input_id_images_path_c_str, - skip_layers_vec, - slg_scale, - skip_layer_start, - skip_layer_end); + input_id_images_path_c_str); size_t t1 = ggml_time_ms(); @@ -1636,8 +1699,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, const char* prompt_c_str, const char* negative_prompt_c_str, int clip_skip, - float cfg_scale, - float guidance, + sd_guidance_params_t guidance, float eta, int width, int height, @@ -1650,13 +1712,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, float control_strength, float style_ratio, bool normalize_input, - const char* input_id_images_path_c_str, - int* skip_layers = NULL, - size_t skip_layers_count = 0, - float slg_scale = 0, - float skip_layer_start = 0.01, - float skip_layer_end = 0.2) { - std::vector skip_layers_vec(skip_layers, skip_layers + skip_layers_count); + const char* input_id_images_path_c_str) { LOG_DEBUG("img2img %dx%d", width, height); if (sd_ctx == NULL) { return NULL; @@ -1700,7 +1756,17 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, sd_image_to_tensor(init_image.data, init_img); - ggml_tensor* masked_image; + ggml_tensor* concat_latent; + ggml_tensor* denoise_mask = NULL; + + ggml_tensor* init_latent = NULL; + ggml_tensor* init_moments = NULL; + if (!sd_ctx->sd->use_tiny_autoencoder) { + init_moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img); + init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, init_moments); + } else { + init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); + } if (sd_version_is_inpaint(sd_ctx->sd->version)) { int64_t mask_channels = 1; @@ -1708,23 +1774,25 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, mask_channels = 8 * 8; // flatten the whole mask } ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); + // Restore init_img (encode_first_stage has side effects) TODO: remove the side effects? + sd_image_to_tensor(init_image.data, init_img); sd_apply_mask(init_img, mask_img, masked_img); - ggml_tensor* masked_image_0 = NULL; + ggml_tensor* masked_latent = NULL; if (!sd_ctx->sd->use_tiny_autoencoder) { ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img); - masked_image_0 = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments); + masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments); } else { - masked_image_0 = sd_ctx->sd->encode_first_stage(work_ctx, masked_img); + masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img); } - masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_image_0->ne[0], masked_image_0->ne[1], mask_channels + masked_image_0->ne[2], 1); - for (int ix = 0; ix < masked_image_0->ne[0]; ix++) { - for (int iy = 0; iy < masked_image_0->ne[1]; iy++) { + concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_latent->ne[0], masked_latent->ne[1], mask_channels + masked_latent->ne[2], 1); + for (int ix = 0; ix < masked_latent->ne[0]; ix++) { + for (int iy = 0; iy < masked_latent->ne[1]; iy++) { int mx = ix * 8; int my = iy * 8; if (sd_ctx->sd->version == VERSION_FLUX_FILL) { - for (int k = 0; k < masked_image_0->ne[2]; k++) { - float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k); - ggml_tensor_set_f32(masked_image, v, ix, iy, k); + for (int k = 0; k < masked_latent->ne[2]; k++) { + float v = ggml_tensor_get_f32(masked_latent, ix, iy, k); + ggml_tensor_set_f32(concat_latent, v, ix, iy, k); } // "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image for (int x = 0; x < 8; x++) { @@ -1732,40 +1800,46 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, float m = ggml_tensor_get_f32(mask_img, mx + x, my + y); // TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?) // python code was using "b (h 8) (w 8) -> b (8 8) h w" - ggml_tensor_set_f32(masked_image, m, ix, iy, masked_image_0->ne[2] + x * 8 + y); + ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y); } } } else { float m = ggml_tensor_get_f32(mask_img, mx, my); - ggml_tensor_set_f32(masked_image, m, ix, iy, 0); - for (int k = 0; k < masked_image_0->ne[2]; k++) { - float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k); - ggml_tensor_set_f32(masked_image, v, ix, iy, k + mask_channels); + ggml_tensor_set_f32(concat_latent, m, ix, iy, 0); + for (int k = 0; k < masked_latent->ne[2]; k++) { + float v = ggml_tensor_get_f32(masked_latent, ix, iy, k); + ggml_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels); } } } } - } else { + } else if (sd_version_is_edit(sd_ctx->sd->version)) { + // Not actually masked, we're just highjacking the concat_latent variable since it will be used the same way + if (!sd_ctx->sd->use_tiny_autoencoder) { + if (sd_ctx->sd->is_using_edm_v_parameterization) { + // for CosXL edit + concat_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, init_moments); + } else { + concat_latent = sd_ctx->sd->get_first_stage_encoding_mode(work_ctx, init_moments); + } + } else { + concat_latent = init_latent; + } + } + + { // LOG_WARN("Inpainting with a base model is not great"); - masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1); - for (int ix = 0; ix < masked_image->ne[0]; ix++) { - for (int iy = 0; iy < masked_image->ne[1]; iy++) { + denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1); + for (int ix = 0; ix < denoise_mask->ne[0]; ix++) { + for (int iy = 0; iy < denoise_mask->ne[1]; iy++) { int mx = ix * 8; int my = iy * 8; float m = ggml_tensor_get_f32(mask_img, mx, my); - ggml_tensor_set_f32(masked_image, m, ix, iy); + ggml_tensor_set_f32(denoise_mask, m, ix, iy); } } } - ggml_tensor* init_latent = NULL; - if (!sd_ctx->sd->use_tiny_autoencoder) { - ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img); - init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments); - } else { - init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); - } - print_ggml_tensor(init_latent, true); size_t t1 = ggml_time_ms(); LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); @@ -1784,7 +1858,6 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, prompt_c_str, negative_prompt_c_str, clip_skip, - cfg_scale, guidance, eta, width, @@ -1798,11 +1871,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, style_ratio, normalize_input, input_id_images_path_c_str, - skip_layers_vec, - slg_scale, - skip_layer_start, - skip_layer_end, - masked_image); + concat_latent, + denoise_mask); size_t t2 = ggml_time_ms(); @@ -1819,8 +1889,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, int motion_bucket_id, int fps, float augmentation_level, - float min_cfg, - float cfg_scale, + sd_guidance_params_t guidance, enum sample_method_t sample_method, int sample_steps, float strength, @@ -1897,9 +1966,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, uncond, {}, 0.f, - min_cfg, - cfg_scale, - 0.f, + guidance, 0.f, sample_method, sigmas, diff --git a/stable-diffusion.h b/stable-diffusion.h index 52dcc848..14658a64 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -61,10 +61,10 @@ enum schedule_t { // same as enum ggml_type enum sd_type_t { - SD_TYPE_F32 = 0, - SD_TYPE_F16 = 1, - SD_TYPE_Q4_0 = 2, - SD_TYPE_Q4_1 = 3, + SD_TYPE_F32 = 0, + SD_TYPE_F16 = 1, + SD_TYPE_Q4_0 = 2, + SD_TYPE_Q4_1 = 3, // SD_TYPE_Q4_2 = 4, support has been removed // SD_TYPE_Q4_3 = 5, support has been removed SD_TYPE_Q5_0 = 6, @@ -95,12 +95,12 @@ enum sd_type_t { // SD_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files // SD_TYPE_Q4_0_4_8 = 32, // SD_TYPE_Q4_0_8_8 = 33, - SD_TYPE_TQ1_0 = 34, - SD_TYPE_TQ2_0 = 35, + SD_TYPE_TQ1_0 = 34, + SD_TYPE_TQ2_0 = 35, // SD_TYPE_IQ4_NL_4_4 = 36, // SD_TYPE_IQ4_NL_4_8 = 37, // SD_TYPE_IQ4_NL_8_8 = 38, - SD_TYPE_COUNT = 39, + SD_TYPE_COUNT = 39, }; SD_API const char* sd_type_name(enum sd_type_t type); @@ -129,6 +129,21 @@ typedef struct { typedef struct sd_ctx_t sd_ctx_t; +typedef struct sd_slg_params_t { + int* layers; + size_t layer_count; + float layer_start; + float layer_end; + float scale; +} sd_slg_params_t; +typedef struct sd_guidance_params_t { + float txt_cfg; + float img_cfg; + float min_cfg; + float distilled_guidance; + sd_slg_params_t slg; +} sd_guidance_params_t; + SD_API sd_ctx_t* new_sd_ctx(const char* model_path, const char* clip_l_path, const char* clip_g_path, @@ -158,8 +173,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx, const char* prompt, const char* negative_prompt, int clip_skip, - float cfg_scale, - float guidance, + sd_guidance_params_t guidance, float eta, int width, int height, @@ -171,12 +185,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx, float control_strength, float style_strength, bool normalize_input, - const char* input_id_images_path, - int* skip_layers, - size_t skip_layers_count, - float slg_scale, - float skip_layer_start, - float skip_layer_end); + const char* input_id_images_path); SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, sd_image_t init_image, @@ -184,8 +193,7 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, const char* prompt, const char* negative_prompt, int clip_skip, - float cfg_scale, - float guidance, + sd_guidance_params_t guidance, float eta, int width, int height, @@ -198,12 +206,7 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, float control_strength, float style_strength, bool normalize_input, - const char* input_id_images_path, - int* skip_layers, - size_t skip_layers_count, - float slg_scale, - float skip_layer_start, - float skip_layer_end); + const char* input_id_images_path); SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, sd_image_t init_image, @@ -213,8 +216,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, int motion_bucket_id, int fps, float augmentation_level, - float min_cfg, - float cfg_scale, + sd_guidance_params_t guidance, enum sample_method_t sample_method, int sample_steps, float strength, diff --git a/unet.hpp b/unet.hpp index 31b7fe98..b3fae53a 100644 --- a/unet.hpp +++ b/unet.hpp @@ -207,6 +207,8 @@ class UnetModelBlock : public GGMLBlock { } if (sd_version_is_inpaint(version)) { in_channels = 9; + } else if (sd_version_is_edit(version)) { + in_channels = 8; } // dims is always 2