Skip to content

Commit f0cdd20

Browse files
committed
refactor(tx): add lora util
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent ea688d2 commit f0cdd20

File tree

2 files changed

+61
-21
lines changed

2 files changed

+61
-21
lines changed

stable-diffusion.cpp

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -710,17 +710,25 @@ class StableDiffusionGGML {
710710
}
711711

712712
void apply_lora(const std::string& lora_name, float multiplier) {
713-
int64_t t0 = ggml_time_ms();
714-
std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
715-
std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt");
713+
int64_t t0 = ggml_time_ms();
714+
716715
std::string file_path;
717-
if (file_exists(st_file_path)) {
718-
file_path = st_file_path;
719-
} else if (file_exists(ckpt_file_path)) {
720-
file_path = ckpt_file_path;
716+
if (!lora_model_dir.empty()) {
717+
std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
718+
std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt");
719+
std::string gguf_file_path = path_join(lora_model_dir, lora_name + ".gguf");
720+
if (file_exists(st_file_path)) {
721+
file_path = st_file_path;
722+
} else if (file_exists(ckpt_file_path)) {
723+
file_path = ckpt_file_path;
724+
} else if (file_exists(gguf_file_path)) {
725+
file_path = gguf_file_path;
726+
} else {
727+
LOG_WARN("can not find %s, %s, %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), gguf_file_path.c_str(), lora_name.c_str());
728+
return;
729+
}
721730
} else {
722-
LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
723-
return;
731+
file_path = lora_name;
724732
}
725733
LoraModel lora(backend, file_path);
726734
if (!lora.load_from_file()) {
@@ -1271,21 +1279,24 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
12711279

12721280
int sample_steps = sigmas.size() - 1;
12731281

1274-
// Apply lora
1275-
auto result_pair = extract_and_remove_lora(prompt);
1276-
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
1282+
// Apply lora if provided model directory
1283+
int64_t t0, t1;
1284+
if (!sd_ctx->sd->lora_model_dir.empty()) {
1285+
auto result_pair = extract_and_remove_lora(prompt);
1286+
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
12771287

1278-
for (auto& kv : lora_f2m) {
1279-
LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
1280-
}
1288+
for (auto& kv : lora_f2m) {
1289+
LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
1290+
}
12811291

1282-
prompt = result_pair.second;
1283-
LOG_DEBUG("prompt after extract and remove lora: \"%s\"", prompt.c_str());
1292+
prompt = result_pair.second;
1293+
LOG_DEBUG("prompt after extract and remove lora: \"%s\"", prompt.c_str());
12841294

1285-
int64_t t0 = ggml_time_ms();
1286-
sd_ctx->sd->apply_loras(lora_f2m);
1287-
int64_t t1 = ggml_time_ms();
1288-
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
1295+
t0 = ggml_time_ms();
1296+
sd_ctx->sd->apply_loras(lora_f2m);
1297+
t1 = ggml_time_ms();
1298+
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
1299+
}
12891300

12901301
// Photo Maker
12911302
std::string prompt_text_only;
@@ -1873,6 +1884,27 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
18731884
return result_images;
18741885
}
18751886

1887+
void sd_lora_adapters_clear(sd_ctx_t* sd_ctx) {
1888+
if (sd_ctx == NULL) {
1889+
return;
1890+
}
1891+
sd_ctx->sd->curr_lora_state.clear();
1892+
}
1893+
1894+
void sd_lora_adapters_apply(sd_ctx_t* sd_ctx, std::vector<sd_lora_adapter_container_t>& lora_adapters) {
1895+
if (sd_ctx == NULL) {
1896+
return;
1897+
}
1898+
1899+
sd_lora_adapters_clear(sd_ctx);
1900+
1901+
std::unordered_map<std::string, float> lora_state;
1902+
for (const sd_lora_adapter_container_t& lora_adapter : lora_adapters) {
1903+
lora_state[lora_adapter.path] = lora_adapter.multiplier;
1904+
}
1905+
sd_ctx->sd->apply_loras(lora_state);
1906+
}
1907+
18761908
int sd_get_version(sd_ctx_t* sd_ctx) {
18771909
if (sd_ctx == NULL) {
18781910
return VERSION_COUNT;

stable-diffusion.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,14 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
223223
float strength,
224224
int64_t seed);
225225

226+
typedef struct sd_lora_adapter_container_t {
227+
const char* path;
228+
float multiplier;
229+
} sd_lora_adapter_container_t;
230+
231+
SD_API void sd_lora_adapters_clear(sd_ctx_t* sd_ctx);
232+
SD_API void sd_lora_adapters_apply(sd_ctx_t* sd_ctx, std::vector<sd_lora_adapter_container_t>& lora_adapters);
233+
226234
SD_API int sd_get_version(sd_ctx_t* sd_ctx);
227235

228236
typedef struct upscaler_ctx_t upscaler_ctx_t;

0 commit comments

Comments
 (0)