Skip to content

Commit 4e75394

Browse files
committed
refactor: lora
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent 1a72908 commit 4e75394

File tree

2 files changed

+61
-24
lines changed

2 files changed

+61
-24
lines changed

stable-diffusion.cpp

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -642,18 +642,27 @@ class StableDiffusionGGML {
642642
}
643643

644644
void apply_lora(const std::string& lora_name, float multiplier) {
645-
int64_t t0 = ggml_time_ms();
646-
std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
647-
std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt");
645+
int64_t t0 = ggml_time_ms();
646+
648647
std::string file_path;
649-
if (file_exists(st_file_path)) {
650-
file_path = st_file_path;
651-
} else if (file_exists(ckpt_file_path)) {
652-
file_path = ckpt_file_path;
648+
if (!lora_model_dir.empty()) {
649+
std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
650+
std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt");
651+
std::string gguf_file_path = path_join(lora_model_dir, lora_name + ".gguf");
652+
if (file_exists(st_file_path)) {
653+
file_path = st_file_path;
654+
} else if (file_exists(ckpt_file_path)) {
655+
file_path = ckpt_file_path;
656+
} else if (file_exists(gguf_file_path)) {
657+
file_path = gguf_file_path;
658+
} else {
659+
LOG_WARN("can not find %s, %s, %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), gguf_file_path.c_str(), lora_name.c_str());
660+
return;
661+
}
653662
} else {
654-
LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
655-
return;
663+
file_path = lora_name;
656664
}
665+
657666
LoraModel lora(backend, model_wtype, file_path);
658667
if (!lora.load_from_file()) {
659668
LOG_WARN("load lora tensors from %s failed", file_path.c_str());
@@ -673,6 +682,7 @@ class StableDiffusionGGML {
673682
if (!lora_state.empty() && model_wtype != GGML_TYPE_F16 && model_wtype != GGML_TYPE_F32) {
674683
LOG_WARN("In quantized models when applying LoRA, the images have poor quality.");
675684
}
685+
676686
std::unordered_map<std::string, float> lora_state_diff;
677687
for (auto& kv : lora_state) {
678688
const std::string& lora_name = kv.first;
@@ -690,11 +700,9 @@ class StableDiffusionGGML {
690700
}
691701

692702
LOG_INFO("Attempting to apply %lu LoRAs", lora_state.size());
693-
694703
for (auto& kv : lora_state_diff) {
695704
apply_lora(kv.first, kv.second);
696705
}
697-
698706
curr_lora_state = lora_state;
699707
}
700708

@@ -980,8 +988,6 @@ class StableDiffusionGGML {
980988
case VERSION_SD3_MEDIUM:
981989
case VERSION_SD3_5_MEDIUM:
982990
case VERSION_SD3_5_LARGE:
983-
C = 32;
984-
break;
985991
case VERSION_FLUX_DEV:
986992
case VERSION_FLUX_SCHNELL:
987993
C = 32;
@@ -1163,20 +1169,23 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
11631169
int sample_steps = sigmas.size() - 1;
11641170

11651171
// Apply lora
1166-
auto result_pair = extract_and_remove_lora(prompt);
1167-
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
1172+
int64_t t0, t1;
1173+
if (!sd_ctx->sd->lora_model_dir.empty()) {
1174+
auto result_pair = extract_and_remove_lora(prompt);
1175+
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
11681176

1169-
for (auto& kv : lora_f2m) {
1170-
LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
1171-
}
1177+
for (auto& kv : lora_f2m) {
1178+
LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
1179+
}
11721180

1173-
prompt = result_pair.second;
1174-
LOG_DEBUG("prompt after extract and remove lora: \"%s\"", prompt.c_str());
1181+
prompt = result_pair.second;
1182+
LOG_DEBUG("prompt after extract and remove lora: \"%s\"", prompt.c_str());
11751183

1176-
int64_t t0 = ggml_time_ms();
1177-
sd_ctx->sd->apply_loras(lora_f2m);
1178-
int64_t t1 = ggml_time_ms();
1179-
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
1184+
t0 = ggml_time_ms();
1185+
sd_ctx->sd->apply_loras(lora_f2m);
1186+
t1 = ggml_time_ms();
1187+
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
1188+
}
11801189

11811190
// Photo Maker
11821191
std::string prompt_text_only;
@@ -1630,6 +1639,27 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
16301639
return result_images;
16311640
}
16321641

1642+
void sd_lora_adapters_clear(sd_ctx_t* sd_ctx) {
1643+
if (sd_ctx == NULL) {
1644+
return;
1645+
}
1646+
sd_ctx->sd->curr_lora_state.clear();
1647+
}
1648+
1649+
void sd_lora_adapters_apply(sd_ctx_t* sd_ctx, std::vector<sd_lora_adapter_container_t>& lora_adapters) {
1650+
if (sd_ctx == NULL) {
1651+
return;
1652+
}
1653+
1654+
sd_lora_adapters_clear(sd_ctx);
1655+
1656+
std::unordered_map<std::string, float> lora_state;
1657+
for (const sd_lora_adapter_container_t& lora_adapter : lora_adapters) {
1658+
lora_state[lora_adapter.path] = lora_adapter.multiplier;
1659+
}
1660+
sd_ctx->sd->apply_loras(lora_state);
1661+
}
1662+
16331663
int sd_get_version(sd_ctx_t* sd_ctx) {
16341664
if (sd_ctx == NULL) {
16351665
return VERSION_COUNT;

stable-diffusion.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,13 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
183183
bool normalize_input,
184184
const char* input_id_images_path);
185185

186+
typedef struct sd_lora_adapter_container_t {
187+
const char* path;
188+
float multiplier;
189+
} sd_lora_adapter_container_t;
190+
191+
SD_API void sd_lora_adapters_clear(sd_ctx_t* sd_ctx);
192+
SD_API void sd_lora_adapters_apply(sd_ctx_t* sd_ctx, std::vector<sd_lora_adapter_container_t>& lora_adapters);
186193
SD_API int sd_get_version(sd_ctx_t* sd_ctx);
187194
SD_API sample_method_t sd_get_default_sample_method(sd_ctx_t* sd_ctx);
188195
SD_API int sd_get_default_sample_steps(sd_ctx_t* sd_ctx);

0 commit comments

Comments
 (0)