@@ -179,6 +179,7 @@ class StableDiffusionGGML {
    std::string lora_model_dir;
    // lora_name => multiplier
    std::unordered_map<std::string, float> curr_lora_state;
+   std::unordered_map<std::string, std::shared_ptr<LoraModel>> curr_loras;

    std::shared_ptr<Denoiser> denoiser = std::make_shared<CompVisDenoiser>();

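The new curr_loras map caches each loaded LoraModel by name, so re-applying a LoRA with a different multiplier can reuse the tensors already in memory instead of re-reading the .safetensors file. A minimal self-contained sketch of that get-or-load pattern follows; the LoraModel stub and the get_or_load() helper are illustrative only, not part of the patch:

// Illustrative sketch: LoraModel here is a stub standing in for the
// project's real class, and get_or_load() is a hypothetical helper that
// mirrors the branch added to apply_lora() below.
#include <memory>
#include <string>
#include <unordered_map>

struct LoraModel {
    float multiplier = 0.f;
    bool load_from_file() { return true; }  // stub: the real code reads tensors from disk
};

std::unordered_map<std::string, std::shared_ptr<LoraModel>> curr_loras;

std::shared_ptr<LoraModel> get_or_load(const std::string& lora_name) {
    auto it = curr_loras.find(lora_name);
    if (it != curr_loras.end()) {
        return it->second;  // cache hit: tensors are still resident
    }
    auto lora = std::make_shared<LoraModel>();
    if (!lora->load_from_file()) {
        return nullptr;  // load failed: leave the cache untouched
    }
    curr_loras[lora_name] = lora;
    return lora;
}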
@@ -773,8 +774,6 @@ class StableDiffusionGGML {
    }

    void apply_lora(const std::string& lora_name, float multiplier) {
-        int64_t t0 = ggml_time_ms();
-
        std::string file_path;
        if (!lora_model_dir.empty()) {
            std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
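The timing start t0 moves out of apply_lora()'s prologue; as the next hunk shows, each path now measures only its own work, so a cache hit no longer reports file-load time.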
@@ -793,44 +792,45 @@ class StableDiffusionGGML {
        } else {
            file_path = lora_name;
        }
-        LoraModel lora(backend, file_path);
-        if (!lora.load_from_file()) {
-            LOG_WARN("load lora tensors from %s failed", file_path.c_str());
+
+        if (curr_loras.find(lora_name) == curr_loras.end()) {
+            LOG_INFO("loading lora from '%s'", file_path.c_str());
+            int64_t t0 = ggml_time_ms();
+            std::shared_ptr<LoraModel> lora = std::make_shared<LoraModel>(backend, file_path);
+            if (!lora->load_from_file()) {
+                LOG_WARN("load lora tensors from %s failed", file_path.c_str());
+                return;
+            }
+            int64_t t1 = ggml_time_ms();
+            LOG_INFO("lora '%s' loaded, taking %.2fs", lora_name.c_str(), (t1 - t0) * 1.0f / 1000);
+            curr_loras[lora_name] = lora;
+        } else if (curr_loras[lora_name]->multiplier == multiplier) {
            return;
        }

-        lora.multiplier = multiplier;
-        // TODO: send version?
-        lora.apply(tensors, version, n_threads);
-        lora.free_params_buffer();
-
+        int64_t t0 = ggml_time_ms();
+        std::shared_ptr<LoraModel> lora = curr_loras[lora_name];
+        lora->multiplier = multiplier;
+        lora->apply(tensors, version, n_threads);
        int64_t t1 = ggml_time_ms();
-
        LOG_INFO("lora '%s' applied, taking %.2fs", lora_name.c_str(), (t1 - t0) * 1.0f / 1000);
    }

    void apply_loras(const std::unordered_map<std::string, float>& lora_state) {
-        if (lora_state.size() > 0 && model_wtype != GGML_TYPE_F16 && model_wtype != GGML_TYPE_F32) {
-            LOG_WARN("In quantized models when applying LoRA, the images have poor quality.");
-        }
        std::unordered_map<std::string, float> lora_state_diff;
        for (auto& kv : lora_state) {
            const std::string& lora_name = kv.first;
            float multiplier = kv.second;

            if (curr_lora_state.find(lora_name) != curr_lora_state.end()) {
                float curr_multiplier = curr_lora_state[lora_name];
-                float multiplier_diff = multiplier - curr_multiplier;
-                if (multiplier_diff != 0.f) {
-                    lora_state_diff[lora_name] = multiplier_diff;
-                }
-            } else {
+                multiplier -= curr_multiplier;
+            }
+            if (multiplier != 0.f) {
                lora_state_diff[lora_name] = multiplier;
            }
        }

-        LOG_INFO("Attempting to apply %lu LoRAs", lora_state.size());
-
        for (auto& kv : lora_state_diff) {
            apply_lora(kv.first, kv.second);
        }
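The rewritten apply_loras() folds the two branches into one: the requested multiplier is first reduced by whatever is already applied, and only non-zero deltas are forwarded to apply_lora(). This relies on LoRA application being additive in the multiplier, so moving from a to b only requires applying b - a. Note also that lora.free_params_buffer() is gone from the apply path, consistent with the cache: the resident parameters are what make a later re-apply cheap. A small self-contained sketch of the delta computation (compute_lora_diff is a hypothetical name, not part of the patch):

// Illustrative sketch of the delta logic. It assumes applying a LoRA with
// multiplier m adds m * delta_W to the base weights, so going from
// multiplier a to b means applying (b - a).
#include <string>
#include <unordered_map>

std::unordered_map<std::string, float> compute_lora_diff(
    const std::unordered_map<std::string, float>& curr_state,
    const std::unordered_map<std::string, float>& requested) {
    std::unordered_map<std::string, float> diff;
    for (const auto& kv : requested) {
        float multiplier = kv.second;
        auto it = curr_state.find(kv.first);
        if (it != curr_state.end()) {
            multiplier -= it->second;  // part of it is already applied
        }
        if (multiplier != 0.f) {
            diff[kv.first] = multiplier;  // only non-zero deltas cost work
        }
    }
    return diff;
}

// Example: curr_state {"A": 0.8}, requested {"A": 0.5, "B": 1.0}
// -> diff {"A": -0.3, "B": 1.0}: "A" is dialed down, "B" is applied fresh.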