@@ -684,18 +684,27 @@ class StableDiffusionGGML {
684
684
}
685
685
686
686
void apply_lora (const std::string& lora_name, float multiplier) {
687
- int64_t t0 = ggml_time_ms ();
688
- std::string st_file_path = path_join (lora_model_dir, lora_name + " .safetensors" );
689
- std::string ckpt_file_path = path_join (lora_model_dir, lora_name + " .ckpt" );
687
+ int64_t t0 = ggml_time_ms ();
688
+
690
689
std::string file_path;
691
- if (file_exists (st_file_path)) {
692
- file_path = st_file_path;
693
- } else if (file_exists (ckpt_file_path)) {
694
- file_path = ckpt_file_path;
690
+ if (!lora_model_dir.empty ()) {
691
+ std::string st_file_path = path_join (lora_model_dir, lora_name + " .safetensors" );
692
+ std::string ckpt_file_path = path_join (lora_model_dir, lora_name + " .ckpt" );
693
+ std::string gguf_file_path = path_join (lora_model_dir, lora_name + " .gguf" );
694
+ if (file_exists (st_file_path)) {
695
+ file_path = st_file_path;
696
+ } else if (file_exists (ckpt_file_path)) {
697
+ file_path = ckpt_file_path;
698
+ } else if (file_exists (gguf_file_path)) {
699
+ file_path = gguf_file_path;
700
+ } else {
701
+ LOG_WARN (" can not find %s, %s, %s for lora %s" , st_file_path.c_str (), ckpt_file_path.c_str (), gguf_file_path.c_str (), lora_name.c_str ());
702
+ return ;
703
+ }
695
704
} else {
696
- LOG_WARN (" can not find %s or %s for lora %s" , st_file_path.c_str (), ckpt_file_path.c_str (), lora_name.c_str ());
697
- return ;
705
+ file_path = lora_name;
698
706
}
707
+
699
708
LoraModel lora (backend, model_wtype, file_path);
700
709
if (!lora.load_from_file ()) {
701
710
LOG_WARN (" load lora tensors from %s failed" , file_path.c_str ());
@@ -1245,21 +1254,24 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
1245
1254
1246
1255
int sample_steps = sigmas.size () - 1 ;
1247
1256
1248
- // Apply lora
1249
- auto result_pair = extract_and_remove_lora (prompt);
1250
- std::unordered_map<std::string, float > lora_f2m = result_pair.first ; // lora_name -> multiplier
1257
+ // Apply lora if provided model directory
1258
+ int64_t t0, t1;
1259
+ if (!sd_ctx->sd ->lora_model_dir .empty ()) {
1260
+ auto result_pair = extract_and_remove_lora (prompt);
1261
+ std::unordered_map<std::string, float > lora_f2m = result_pair.first ; // lora_name -> multiplier
1251
1262
1252
- for (auto & kv : lora_f2m) {
1253
- LOG_DEBUG (" lora %s:%.2f" , kv.first .c_str (), kv.second );
1254
- }
1263
+ for (auto & kv : lora_f2m) {
1264
+ LOG_DEBUG (" lora %s:%.2f" , kv.first .c_str (), kv.second );
1265
+ }
1255
1266
1256
- prompt = result_pair.second ;
1257
- LOG_DEBUG (" prompt after extract and remove lora: \" %s\" " , prompt.c_str ());
1267
+ prompt = result_pair.second ;
1268
+ LOG_DEBUG (" prompt after extract and remove lora: \" %s\" " , prompt.c_str ());
1258
1269
1259
- int64_t t0 = ggml_time_ms ();
1260
- sd_ctx->sd ->apply_loras (lora_f2m);
1261
- int64_t t1 = ggml_time_ms ();
1262
- LOG_INFO (" apply_loras completed, taking %.2fs" , (t1 - t0) * 1 .0f / 1000 );
1270
+ t0 = ggml_time_ms ();
1271
+ sd_ctx->sd ->apply_loras (lora_f2m);
1272
+ t1 = ggml_time_ms ();
1273
+ LOG_INFO (" apply_loras completed, taking %.2fs" , (t1 - t0) * 1 .0f / 1000 );
1274
+ }
1263
1275
1264
1276
// Photo Maker
1265
1277
std::string prompt_text_only;
@@ -1847,6 +1859,27 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
1847
1859
return result_images;
1848
1860
}
1849
1861
1862
+ void sd_lora_adapters_clear (sd_ctx_t * sd_ctx) {
1863
+ if (sd_ctx == NULL ) {
1864
+ return ;
1865
+ }
1866
+ sd_ctx->sd ->curr_lora_state .clear ();
1867
+ }
1868
+
1869
+ void sd_lora_adapters_apply (sd_ctx_t * sd_ctx, std::vector<sd_lora_adapter_container_t >& lora_adapters) {
1870
+ if (sd_ctx == NULL ) {
1871
+ return ;
1872
+ }
1873
+
1874
+ sd_lora_adapters_clear (sd_ctx);
1875
+
1876
+ std::unordered_map<std::string, float > lora_state;
1877
+ for (const sd_lora_adapter_container_t & lora_adapter : lora_adapters) {
1878
+ lora_state[lora_adapter.path ] = lora_adapter.multiplier ;
1879
+ }
1880
+ sd_ctx->sd ->apply_loras (lora_state);
1881
+ }
1882
+
1850
1883
int sd_get_version (sd_ctx_t * sd_ctx) {
1851
1884
if (sd_ctx == NULL ) {
1852
1885
return VERSION_COUNT;
0 commit comments