@@ -710,17 +710,25 @@ class StableDiffusionGGML {
710
710
}
711
711
712
712
void apply_lora (const std::string& lora_name, float multiplier) {
713
- int64_t t0 = ggml_time_ms ();
714
- std::string st_file_path = path_join (lora_model_dir, lora_name + " .safetensors" );
715
- std::string ckpt_file_path = path_join (lora_model_dir, lora_name + " .ckpt" );
713
+ int64_t t0 = ggml_time_ms ();
714
+
716
715
std::string file_path;
717
- if (file_exists (st_file_path)) {
718
- file_path = st_file_path;
719
- } else if (file_exists (ckpt_file_path)) {
720
- file_path = ckpt_file_path;
716
+ if (!lora_model_dir.empty ()) {
717
+ std::string st_file_path = path_join (lora_model_dir, lora_name + " .safetensors" );
718
+ std::string ckpt_file_path = path_join (lora_model_dir, lora_name + " .ckpt" );
719
+ std::string gguf_file_path = path_join (lora_model_dir, lora_name + " .gguf" );
720
+ if (file_exists (st_file_path)) {
721
+ file_path = st_file_path;
722
+ } else if (file_exists (ckpt_file_path)) {
723
+ file_path = ckpt_file_path;
724
+ } else if (file_exists (gguf_file_path)) {
725
+ file_path = gguf_file_path;
726
+ } else {
727
+ LOG_WARN (" can not find %s, %s, %s for lora %s" , st_file_path.c_str (), ckpt_file_path.c_str (), gguf_file_path.c_str (), lora_name.c_str ());
728
+ return ;
729
+ }
721
730
} else {
722
- LOG_WARN (" can not find %s or %s for lora %s" , st_file_path.c_str (), ckpt_file_path.c_str (), lora_name.c_str ());
723
- return ;
731
+ file_path = lora_name;
724
732
}
725
733
LoraModel lora (backend, file_path);
726
734
if (!lora.load_from_file ()) {
@@ -1271,21 +1279,24 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
1271
1279
1272
1280
int sample_steps = sigmas.size () - 1 ;
1273
1281
1274
- // Apply lora
1275
- auto result_pair = extract_and_remove_lora (prompt);
1276
- std::unordered_map<std::string, float > lora_f2m = result_pair.first ; // lora_name -> multiplier
1282
+ // Apply lora if provided model directory
1283
+ int64_t t0, t1;
1284
+ if (!sd_ctx->sd ->lora_model_dir .empty ()) {
1285
+ auto result_pair = extract_and_remove_lora (prompt);
1286
+ std::unordered_map<std::string, float > lora_f2m = result_pair.first ; // lora_name -> multiplier
1277
1287
1278
- for (auto & kv : lora_f2m) {
1279
- LOG_DEBUG (" lora %s:%.2f" , kv.first .c_str (), kv.second );
1280
- }
1288
+ for (auto & kv : lora_f2m) {
1289
+ LOG_DEBUG (" lora %s:%.2f" , kv.first .c_str (), kv.second );
1290
+ }
1281
1291
1282
- prompt = result_pair.second ;
1283
- LOG_DEBUG (" prompt after extract and remove lora: \" %s\" " , prompt.c_str ());
1292
+ prompt = result_pair.second ;
1293
+ LOG_DEBUG (" prompt after extract and remove lora: \" %s\" " , prompt.c_str ());
1284
1294
1285
- int64_t t0 = ggml_time_ms ();
1286
- sd_ctx->sd ->apply_loras (lora_f2m);
1287
- int64_t t1 = ggml_time_ms ();
1288
- LOG_INFO (" apply_loras completed, taking %.2fs" , (t1 - t0) * 1 .0f / 1000 );
1295
+ t0 = ggml_time_ms ();
1296
+ sd_ctx->sd ->apply_loras (lora_f2m);
1297
+ t1 = ggml_time_ms ();
1298
+ LOG_INFO (" apply_loras completed, taking %.2fs" , (t1 - t0) * 1 .0f / 1000 );
1299
+ }
1289
1300
1290
1301
// Photo Maker
1291
1302
std::string prompt_text_only;
@@ -1873,6 +1884,27 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
1873
1884
return result_images;
1874
1885
}
1875
1886
1887
+ void sd_lora_adapters_clear (sd_ctx_t * sd_ctx) {
1888
+ if (sd_ctx == NULL ) {
1889
+ return ;
1890
+ }
1891
+ sd_ctx->sd ->curr_lora_state .clear ();
1892
+ }
1893
+
1894
+ void sd_lora_adapters_apply (sd_ctx_t * sd_ctx, std::vector<sd_lora_adapter_container_t >& lora_adapters) {
1895
+ if (sd_ctx == NULL ) {
1896
+ return ;
1897
+ }
1898
+
1899
+ sd_lora_adapters_clear (sd_ctx);
1900
+
1901
+ std::unordered_map<std::string, float > lora_state;
1902
+ for (const sd_lora_adapter_container_t & lora_adapter : lora_adapters) {
1903
+ lora_state[lora_adapter.path ] = lora_adapter.multiplier ;
1904
+ }
1905
+ sd_ctx->sd ->apply_loras (lora_state);
1906
+ }
1907
+
1876
1908
int sd_get_version (sd_ctx_t * sd_ctx) {
1877
1909
if (sd_ctx == NULL ) {
1878
1910
return VERSION_COUNT;
0 commit comments