diff --git a/model.cpp b/model.cpp index 24da39f6d..55fe8f8e3 100644 --- a/model.cpp +++ b/model.cpp @@ -1477,6 +1477,15 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s return true; } +bool ModelLoader::model_is_unet() { + for (auto& tensor_storage : tensor_storages) { + if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) { + return true; + } + } + return false; +} + SDVersion ModelLoader::get_sd_version() { TensorStorage token_embedding_weight, input_block_weight; bool input_block_checked = false; diff --git a/model.h b/model.h index d7f976533..87343d77a 100644 --- a/model.h +++ b/model.h @@ -210,6 +210,7 @@ class ModelLoader { std::map tensor_storages_types; bool init_from_file(const std::string& file_path, const std::string& prefix = ""); + bool model_is_unet(); SDVersion get_sd_version(); ggml_type get_sd_wtype(); ggml_type get_conditioner_wtype(); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index e38a6101f..a14d68806 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -200,16 +200,25 @@ class StableDiffusionGGML { } } + if (diffusion_model_path.size() > 0) { + LOG_INFO("loading diffusion model from '%s'", diffusion_model_path.c_str()); + if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) { + LOG_WARN("loading diffusion model from '%s' failed", diffusion_model_path.c_str()); + } + } + + bool is_unet = model_loader.model_is_unet(); + if (clip_l_path.size() > 0) { LOG_INFO("loading clip_l from '%s'", clip_l_path.c_str()); - if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.transformer.")) { + if (!model_loader.init_from_file(clip_l_path, is_unet ? "cond_stage_model.transformer." : "text_encoders.clip_l.transformer.")) { LOG_WARN("loading clip_l from '%s' failed", clip_l_path.c_str()); } } if (clip_g_path.size() > 0) { LOG_INFO("loading clip_g from '%s'", clip_g_path.c_str()); - if (!model_loader.init_from_file(clip_g_path, "text_encoders.clip_g.transformer.")) { + if (!model_loader.init_from_file(clip_g_path, is_unet ? "cond_stage_model.1.transformer." : "text_encoders.clip_g.transformer.")) { LOG_WARN("loading clip_g from '%s' failed", clip_g_path.c_str()); } } @@ -221,13 +230,6 @@ class StableDiffusionGGML { } } - if (diffusion_model_path.size() > 0) { - LOG_INFO("loading diffusion model from '%s'", diffusion_model_path.c_str()); - if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) { - LOG_WARN("loading diffusion model from '%s' failed", diffusion_model_path.c_str()); - } - } - if (vae_path.size() > 0) { LOG_INFO("loading vae from '%s'", vae_path.c_str()); if (!model_loader.init_from_file(vae_path, "vae.")) {