Skip to content

Commit c458bfd

Browse files
committed
refactor(tx): add lora util
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent 33518e2 commit c458bfd

File tree

2 files changed

+61
-21
lines changed

2 files changed

+61
-21
lines changed

stable-diffusion.cpp

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -765,17 +765,25 @@ class StableDiffusionGGML {
765765
}
766766

767767
void apply_lora(const std::string& lora_name, float multiplier) {
768-
int64_t t0 = ggml_time_ms();
769-
std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
770-
std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt");
768+
int64_t t0 = ggml_time_ms();
769+
771770
std::string file_path;
772-
if (file_exists(st_file_path)) {
773-
file_path = st_file_path;
774-
} else if (file_exists(ckpt_file_path)) {
775-
file_path = ckpt_file_path;
771+
if (!lora_model_dir.empty()) {
772+
std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
773+
std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt");
774+
std::string gguf_file_path = path_join(lora_model_dir, lora_name + ".gguf");
775+
if (file_exists(st_file_path)) {
776+
file_path = st_file_path;
777+
} else if (file_exists(ckpt_file_path)) {
778+
file_path = ckpt_file_path;
779+
} else if (file_exists(gguf_file_path)) {
780+
file_path = gguf_file_path;
781+
} else {
782+
LOG_WARN("can not find %s, %s, %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), gguf_file_path.c_str(), lora_name.c_str());
783+
return;
784+
}
776785
} else {
777-
LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
778-
return;
786+
file_path = lora_name;
779787
}
780788
LoraModel lora(backend, file_path);
781789
if (!lora.load_from_file()) {
@@ -1474,21 +1482,24 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
14741482

14751483
int sample_steps = sigmas.size() - 1;
14761484

1477-
// Apply lora
1478-
auto result_pair = extract_and_remove_lora(prompt);
1479-
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
1485+
// Apply lora if provided model directory
1486+
int64_t t0, t1;
1487+
if (!sd_ctx->sd->lora_model_dir.empty()) {
1488+
auto result_pair = extract_and_remove_lora(prompt);
1489+
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
14801490

1481-
for (auto& kv : lora_f2m) {
1482-
LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
1483-
}
1491+
for (auto& kv : lora_f2m) {
1492+
LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
1493+
}
14841494

1485-
prompt = result_pair.second;
1486-
LOG_DEBUG("prompt after extract and remove lora: \"%s\"", prompt.c_str());
1495+
prompt = result_pair.second;
1496+
LOG_DEBUG("prompt after extract and remove lora: \"%s\"", prompt.c_str());
14871497

1488-
int64_t t0 = ggml_time_ms();
1489-
sd_ctx->sd->apply_loras(lora_f2m);
1490-
int64_t t1 = ggml_time_ms();
1491-
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
1498+
t0 = ggml_time_ms();
1499+
sd_ctx->sd->apply_loras(lora_f2m);
1500+
t1 = ggml_time_ms();
1501+
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
1502+
}
14921503

14931504
// Photo Maker
14941505
std::string prompt_text_only;
@@ -2209,6 +2220,27 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
22092220
return result_images;
22102221
}
22112222

2223+
void sd_lora_adapters_clear(sd_ctx_t* sd_ctx) {
2224+
if (sd_ctx == NULL) {
2225+
return;
2226+
}
2227+
sd_ctx->sd->curr_lora_state.clear();
2228+
}
2229+
2230+
void sd_lora_adapters_apply(sd_ctx_t* sd_ctx, std::vector<sd_lora_adapter_container_t>& lora_adapters) {
2231+
if (sd_ctx == NULL) {
2232+
return;
2233+
}
2234+
2235+
sd_lora_adapters_clear(sd_ctx);
2236+
2237+
std::unordered_map<std::string, float> lora_state;
2238+
for (const sd_lora_adapter_container_t& lora_adapter : lora_adapters) {
2239+
lora_state[lora_adapter.path] = lora_adapter.multiplier;
2240+
}
2241+
sd_ctx->sd->apply_loras(lora_state);
2242+
}
2243+
22122244
int sd_get_version(sd_ctx_t* sd_ctx) {
22132245
if (sd_ctx == NULL) {
22142246
return VERSION_COUNT;

stable-diffusion.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,14 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
241241
float strength,
242242
int64_t seed);
243243

244+
typedef struct sd_lora_adapter_container_t {
245+
const char* path;
246+
float multiplier;
247+
} sd_lora_adapter_container_t;
248+
249+
SD_API void sd_lora_adapters_clear(sd_ctx_t* sd_ctx);
250+
SD_API void sd_lora_adapters_apply(sd_ctx_t* sd_ctx, std::vector<sd_lora_adapter_container_t>& lora_adapters);
251+
244252
SD_API int sd_get_version(sd_ctx_t* sd_ctx);
245253

246254
typedef struct upscaler_ctx_t upscaler_ctx_t;

0 commit comments

Comments
 (0)