@@ -765,17 +765,25 @@ class StableDiffusionGGML {
765
765
}
766
766
767
767
void apply_lora (const std::string& lora_name, float multiplier) {
768
- int64_t t0 = ggml_time_ms ();
769
- std::string st_file_path = path_join (lora_model_dir, lora_name + " .safetensors" );
770
- std::string ckpt_file_path = path_join (lora_model_dir, lora_name + " .ckpt" );
768
+ int64_t t0 = ggml_time_ms ();
769
+
771
770
std::string file_path;
772
- if (file_exists (st_file_path)) {
773
- file_path = st_file_path;
774
- } else if (file_exists (ckpt_file_path)) {
775
- file_path = ckpt_file_path;
771
+ if (!lora_model_dir.empty ()) {
772
+ std::string st_file_path = path_join (lora_model_dir, lora_name + " .safetensors" );
773
+ std::string ckpt_file_path = path_join (lora_model_dir, lora_name + " .ckpt" );
774
+ std::string gguf_file_path = path_join (lora_model_dir, lora_name + " .gguf" );
775
+ if (file_exists (st_file_path)) {
776
+ file_path = st_file_path;
777
+ } else if (file_exists (ckpt_file_path)) {
778
+ file_path = ckpt_file_path;
779
+ } else if (file_exists (gguf_file_path)) {
780
+ file_path = gguf_file_path;
781
+ } else {
782
+ LOG_WARN (" can not find %s, %s, %s for lora %s" , st_file_path.c_str (), ckpt_file_path.c_str (), gguf_file_path.c_str (), lora_name.c_str ());
783
+ return ;
784
+ }
776
785
} else {
777
- LOG_WARN (" can not find %s or %s for lora %s" , st_file_path.c_str (), ckpt_file_path.c_str (), lora_name.c_str ());
778
- return ;
786
+ file_path = lora_name;
779
787
}
780
788
LoraModel lora (backend, file_path);
781
789
if (!lora.load_from_file ()) {
@@ -1474,21 +1482,24 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
1474
1482
1475
1483
int sample_steps = sigmas.size () - 1 ;
1476
1484
1477
- // Apply lora
1478
- auto result_pair = extract_and_remove_lora (prompt);
1479
- std::unordered_map<std::string, float > lora_f2m = result_pair.first ; // lora_name -> multiplier
1485
+ // Apply lora if provided model directory
1486
+ int64_t t0, t1;
1487
+ if (!sd_ctx->sd ->lora_model_dir .empty ()) {
1488
+ auto result_pair = extract_and_remove_lora (prompt);
1489
+ std::unordered_map<std::string, float > lora_f2m = result_pair.first ; // lora_name -> multiplier
1480
1490
1481
- for (auto & kv : lora_f2m) {
1482
- LOG_DEBUG (" lora %s:%.2f" , kv.first .c_str (), kv.second );
1483
- }
1491
+ for (auto & kv : lora_f2m) {
1492
+ LOG_DEBUG (" lora %s:%.2f" , kv.first .c_str (), kv.second );
1493
+ }
1484
1494
1485
- prompt = result_pair.second ;
1486
- LOG_DEBUG (" prompt after extract and remove lora: \" %s\" " , prompt.c_str ());
1495
+ prompt = result_pair.second ;
1496
+ LOG_DEBUG (" prompt after extract and remove lora: \" %s\" " , prompt.c_str ());
1487
1497
1488
- int64_t t0 = ggml_time_ms ();
1489
- sd_ctx->sd ->apply_loras (lora_f2m);
1490
- int64_t t1 = ggml_time_ms ();
1491
- LOG_INFO (" apply_loras completed, taking %.2fs" , (t1 - t0) * 1 .0f / 1000 );
1498
+ t0 = ggml_time_ms ();
1499
+ sd_ctx->sd ->apply_loras (lora_f2m);
1500
+ t1 = ggml_time_ms ();
1501
+ LOG_INFO (" apply_loras completed, taking %.2fs" , (t1 - t0) * 1 .0f / 1000 );
1502
+ }
1492
1503
1493
1504
// Photo Maker
1494
1505
std::string prompt_text_only;
@@ -2209,6 +2220,27 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
2209
2220
return result_images;
2210
2221
}
2211
2222
2223
+ void sd_lora_adapters_clear (sd_ctx_t * sd_ctx) {
2224
+ if (sd_ctx == NULL ) {
2225
+ return ;
2226
+ }
2227
+ sd_ctx->sd ->curr_lora_state .clear ();
2228
+ }
2229
+
2230
+ void sd_lora_adapters_apply (sd_ctx_t * sd_ctx, std::vector<sd_lora_adapter_container_t >& lora_adapters) {
2231
+ if (sd_ctx == NULL ) {
2232
+ return ;
2233
+ }
2234
+
2235
+ sd_lora_adapters_clear (sd_ctx);
2236
+
2237
+ std::unordered_map<std::string, float > lora_state;
2238
+ for (const sd_lora_adapter_container_t & lora_adapter : lora_adapters) {
2239
+ lora_state[lora_adapter.path ] = lora_adapter.multiplier ;
2240
+ }
2241
+ sd_ctx->sd ->apply_loras (lora_state);
2242
+ }
2243
+
2212
2244
int sd_get_version (sd_ctx_t * sd_ctx) {
2213
2245
if (sd_ctx == NULL ) {
2214
2246
return VERSION_COUNT;
0 commit comments