Skip to content

Commit 06ea7a3

Browse files
committed
refactor(tx): add lora util
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent e660807 commit 06ea7a3

File tree

2 files changed

+62
-21
lines changed

2 files changed

+62
-21
lines changed

stable-diffusion.cpp

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -684,18 +684,27 @@ class StableDiffusionGGML {
684684
}
685685

686686
void apply_lora(const std::string& lora_name, float multiplier) {
687-
int64_t t0 = ggml_time_ms();
688-
std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
689-
std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt");
687+
int64_t t0 = ggml_time_ms();
688+
690689
std::string file_path;
691-
if (file_exists(st_file_path)) {
692-
file_path = st_file_path;
693-
} else if (file_exists(ckpt_file_path)) {
694-
file_path = ckpt_file_path;
690+
if (!lora_model_dir.empty()) {
691+
std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
692+
std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt");
693+
std::string gguf_file_path = path_join(lora_model_dir, lora_name + ".gguf");
694+
if (file_exists(st_file_path)) {
695+
file_path = st_file_path;
696+
} else if (file_exists(ckpt_file_path)) {
697+
file_path = ckpt_file_path;
698+
} else if (file_exists(gguf_file_path)) {
699+
file_path = gguf_file_path;
700+
} else {
701+
LOG_WARN("can not find %s, %s, %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), gguf_file_path.c_str(), lora_name.c_str());
702+
return;
703+
}
695704
} else {
696-
LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
697-
return;
705+
file_path = lora_name;
698706
}
707+
699708
LoraModel lora(backend, model_wtype, file_path);
700709
if (!lora.load_from_file()) {
701710
LOG_WARN("load lora tensors from %s failed", file_path.c_str());
@@ -1245,21 +1254,24 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
12451254

12461255
int sample_steps = sigmas.size() - 1;
12471256

1248-
// Apply lora
1249-
auto result_pair = extract_and_remove_lora(prompt);
1250-
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
1257+
// Apply lora if provided model directory
1258+
int64_t t0, t1;
1259+
if (!sd_ctx->sd->lora_model_dir.empty()) {
1260+
auto result_pair = extract_and_remove_lora(prompt);
1261+
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
12511262

1252-
for (auto& kv : lora_f2m) {
1253-
LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
1254-
}
1263+
for (auto& kv : lora_f2m) {
1264+
LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
1265+
}
12551266

1256-
prompt = result_pair.second;
1257-
LOG_DEBUG("prompt after extract and remove lora: \"%s\"", prompt.c_str());
1267+
prompt = result_pair.second;
1268+
LOG_DEBUG("prompt after extract and remove lora: \"%s\"", prompt.c_str());
12581269

1259-
int64_t t0 = ggml_time_ms();
1260-
sd_ctx->sd->apply_loras(lora_f2m);
1261-
int64_t t1 = ggml_time_ms();
1262-
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
1270+
t0 = ggml_time_ms();
1271+
sd_ctx->sd->apply_loras(lora_f2m);
1272+
t1 = ggml_time_ms();
1273+
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
1274+
}
12631275

12641276
// Photo Maker
12651277
std::string prompt_text_only;
@@ -1847,6 +1859,27 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
18471859
return result_images;
18481860
}
18491861

1862+
void sd_lora_adapters_clear(sd_ctx_t* sd_ctx) {
1863+
if (sd_ctx == NULL) {
1864+
return;
1865+
}
1866+
sd_ctx->sd->curr_lora_state.clear();
1867+
}
1868+
1869+
void sd_lora_adapters_apply(sd_ctx_t* sd_ctx, std::vector<sd_lora_adapter_container_t>& lora_adapters) {
1870+
if (sd_ctx == NULL) {
1871+
return;
1872+
}
1873+
1874+
sd_lora_adapters_clear(sd_ctx);
1875+
1876+
std::unordered_map<std::string, float> lora_state;
1877+
for (const sd_lora_adapter_container_t& lora_adapter : lora_adapters) {
1878+
lora_state[lora_adapter.path] = lora_adapter.multiplier;
1879+
}
1880+
sd_ctx->sd->apply_loras(lora_state);
1881+
}
1882+
18501883
int sd_get_version(sd_ctx_t* sd_ctx) {
18511884
if (sd_ctx == NULL) {
18521885
return VERSION_COUNT;

stable-diffusion.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,14 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
209209
float strength,
210210
int64_t seed);
211211

212+
typedef struct sd_lora_adapter_container_t {
213+
const char* path;
214+
float multiplier;
215+
} sd_lora_adapter_container_t;
216+
217+
SD_API void sd_lora_adapters_clear(sd_ctx_t* sd_ctx);
218+
SD_API void sd_lora_adapters_apply(sd_ctx_t* sd_ctx, std::vector<sd_lora_adapter_container_t>& lora_adapters);
219+
212220
SD_API int sd_get_version(sd_ctx_t* sd_ctx);
213221

214222
typedef struct upscaler_ctx_t upscaler_ctx_t;

0 commit comments

Comments
 (0)