@@ -642,18 +642,27 @@ class StableDiffusionGGML {
642
642
}
643
643
644
644
void apply_lora (const std::string& lora_name, float multiplier) {
645
- int64_t t0 = ggml_time_ms ();
646
- std::string st_file_path = path_join (lora_model_dir, lora_name + " .safetensors" );
647
- std::string ckpt_file_path = path_join (lora_model_dir, lora_name + " .ckpt" );
645
+ int64_t t0 = ggml_time_ms ();
646
+
648
647
std::string file_path;
649
- if (file_exists (st_file_path)) {
650
- file_path = st_file_path;
651
- } else if (file_exists (ckpt_file_path)) {
652
- file_path = ckpt_file_path;
648
+ if (!lora_model_dir.empty ()) {
649
+ std::string st_file_path = path_join (lora_model_dir, lora_name + " .safetensors" );
650
+ std::string ckpt_file_path = path_join (lora_model_dir, lora_name + " .ckpt" );
651
+ std::string gguf_file_path = path_join (lora_model_dir, lora_name + " .gguf" );
652
+ if (file_exists (st_file_path)) {
653
+ file_path = st_file_path;
654
+ } else if (file_exists (ckpt_file_path)) {
655
+ file_path = ckpt_file_path;
656
+ } else if (file_exists (gguf_file_path)) {
657
+ file_path = gguf_file_path;
658
+ } else {
659
+ LOG_WARN (" can not find %s, %s, %s for lora %s" , st_file_path.c_str (), ckpt_file_path.c_str (), gguf_file_path.c_str (), lora_name.c_str ());
660
+ return ;
661
+ }
653
662
} else {
654
- LOG_WARN (" can not find %s or %s for lora %s" , st_file_path.c_str (), ckpt_file_path.c_str (), lora_name.c_str ());
655
- return ;
663
+ file_path = lora_name;
656
664
}
665
+
657
666
LoraModel lora (backend, model_wtype, file_path);
658
667
if (!lora.load_from_file ()) {
659
668
LOG_WARN (" load lora tensors from %s failed" , file_path.c_str ());
@@ -673,6 +682,7 @@ class StableDiffusionGGML {
673
682
if (!lora_state.empty () && model_wtype != GGML_TYPE_F16 && model_wtype != GGML_TYPE_F32) {
674
683
LOG_WARN (" In quantized models when applying LoRA, the images have poor quality." );
675
684
}
685
+
676
686
std::unordered_map<std::string, float > lora_state_diff;
677
687
for (auto & kv : lora_state) {
678
688
const std::string& lora_name = kv.first ;
@@ -690,11 +700,9 @@ class StableDiffusionGGML {
690
700
}
691
701
692
702
LOG_INFO (" Attempting to apply %lu LoRAs" , lora_state.size ());
693
-
694
703
for (auto & kv : lora_state_diff) {
695
704
apply_lora (kv.first , kv.second );
696
705
}
697
-
698
706
curr_lora_state = lora_state;
699
707
}
700
708
@@ -980,8 +988,6 @@ class StableDiffusionGGML {
980
988
case VERSION_SD3_MEDIUM:
981
989
case VERSION_SD3_5_MEDIUM:
982
990
case VERSION_SD3_5_LARGE:
983
- C = 32 ;
984
- break ;
985
991
case VERSION_FLUX_DEV:
986
992
case VERSION_FLUX_SCHNELL:
987
993
C = 32 ;
@@ -1163,20 +1169,23 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
1163
1169
int sample_steps = sigmas.size () - 1 ;
1164
1170
1165
1171
// Apply lora
1166
- auto result_pair = extract_and_remove_lora (prompt);
1167
- std::unordered_map<std::string, float > lora_f2m = result_pair.first ; // lora_name -> multiplier
1172
+ int64_t t0, t1;
1173
+ if (!sd_ctx->sd ->lora_model_dir .empty ()) {
1174
+ auto result_pair = extract_and_remove_lora (prompt);
1175
+ std::unordered_map<std::string, float > lora_f2m = result_pair.first ; // lora_name -> multiplier
1168
1176
1169
- for (auto & kv : lora_f2m) {
1170
- LOG_DEBUG (" lora %s:%.2f" , kv.first .c_str (), kv.second );
1171
- }
1177
+ for (auto & kv : lora_f2m) {
1178
+ LOG_DEBUG (" lora %s:%.2f" , kv.first .c_str (), kv.second );
1179
+ }
1172
1180
1173
- prompt = result_pair.second ;
1174
- LOG_DEBUG (" prompt after extract and remove lora: \" %s\" " , prompt.c_str ());
1181
+ prompt = result_pair.second ;
1182
+ LOG_DEBUG (" prompt after extract and remove lora: \" %s\" " , prompt.c_str ());
1175
1183
1176
- int64_t t0 = ggml_time_ms ();
1177
- sd_ctx->sd ->apply_loras (lora_f2m);
1178
- int64_t t1 = ggml_time_ms ();
1179
- LOG_INFO (" apply_loras completed, taking %.2fs" , (t1 - t0) * 1 .0f / 1000 );
1184
+ t0 = ggml_time_ms ();
1185
+ sd_ctx->sd ->apply_loras (lora_f2m);
1186
+ t1 = ggml_time_ms ();
1187
+ LOG_INFO (" apply_loras completed, taking %.2fs" , (t1 - t0) * 1 .0f / 1000 );
1188
+ }
1180
1189
1181
1190
// Photo Maker
1182
1191
std::string prompt_text_only;
@@ -1630,6 +1639,27 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
1630
1639
return result_images;
1631
1640
}
1632
1641
1642
+ void sd_lora_adapters_clear (sd_ctx_t * sd_ctx) {
1643
+ if (sd_ctx == NULL ) {
1644
+ return ;
1645
+ }
1646
+ sd_ctx->sd ->curr_lora_state .clear ();
1647
+ }
1648
+
1649
+ void sd_lora_adapters_apply (sd_ctx_t * sd_ctx, std::vector<sd_lora_adapter_container_t >& lora_adapters) {
1650
+ if (sd_ctx == NULL ) {
1651
+ return ;
1652
+ }
1653
+
1654
+ sd_lora_adapters_clear (sd_ctx);
1655
+
1656
+ std::unordered_map<std::string, float > lora_state;
1657
+ for (const sd_lora_adapter_container_t & lora_adapter : lora_adapters) {
1658
+ lora_state[lora_adapter.path ] = lora_adapter.multiplier ;
1659
+ }
1660
+ sd_ctx->sd ->apply_loras (lora_state);
1661
+ }
1662
+
1633
1663
int sd_get_version (sd_ctx_t * sd_ctx) {
1634
1664
if (sd_ctx == NULL ) {
1635
1665
return VERSION_COUNT;
0 commit comments