
Commit 01fec2a

refactor: lora apply

Signed-off-by: thxCode <thxcode0824@gmail.com>

1 parent: a5fcfc7

File tree

1 file changed (+21 lines, -21 lines)

stable-diffusion.cpp

Lines changed: 21 additions & 21 deletions
```diff
@@ -179,6 +179,7 @@ class StableDiffusionGGML {
     std::string lora_model_dir;
     // lora_name => multiplier
     std::unordered_map<std::string, float> curr_lora_state;
+    std::unordered_map<std::string, std::shared_ptr<LoraModel>> curr_loras;
 
     std::shared_ptr<Denoiser> denoiser = std::make_shared<CompVisDenoiser>();
 
```
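The new `curr_loras` map is what turns `apply_lora` (below) into a load-once cache: each LoRA file is parsed a single time, and the resulting `LoraModel` stays alive behind a `shared_ptr`, keyed by name. A minimal sketch of that get-or-load pattern, with a stubbed `LoraModel` and a hypothetical `get_or_load` helper standing in for the real class and the logic inside `apply_lora` (the backend and file-parsing details are omitted):

```cpp
#include <cstdio>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stub standing in for the real LoraModel in
// stable-diffusion.cpp, which parses a .safetensors file on a ggml backend.
struct LoraModel {
    float multiplier = 1.0f;
    std::string file_path;
    explicit LoraModel(std::string path) : file_path(std::move(path)) {}
    bool load_from_file() { return true; }  // stubbed: pretend the load succeeded
};

// lora_name => cached, already-loaded model (the member this hunk adds)
std::unordered_map<std::string, std::shared_ptr<LoraModel>> curr_loras;

std::shared_ptr<LoraModel> get_or_load(const std::string& name, const std::string& path) {
    auto it = curr_loras.find(name);
    if (it != curr_loras.end()) {
        return it->second;  // cache hit: reuse the already-parsed tensors
    }
    auto lora = std::make_shared<LoraModel>(path);
    if (!lora->load_from_file()) {
        return nullptr;  // load failed; nothing gets cached
    }
    curr_loras[name] = lora;  // remember it for later apply_lora calls
    return lora;
}

int main() {
    get_or_load("style_a", "style_a.safetensors");  // first call: loads from disk
    get_or_load("style_a", "style_a.safetensors");  // second call: served from the cache
    std::printf("cached loras: %zu\n", curr_loras.size());  // prints 1
}
```

Keeping the models cached is also what makes the new early return in `apply_lora` possible: if the cached entry already carries the requested multiplier, the call is a no-op.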
```diff
@@ -773,8 +774,6 @@ class StableDiffusionGGML {
     }
 
     void apply_lora(const std::string& lora_name, float multiplier) {
-        int64_t t0 = ggml_time_ms();
-
         std::string file_path;
         if (!lora_model_dir.empty()) {
             std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
@@ -793,44 +792,45 @@ class StableDiffusionGGML {
         } else {
             file_path = lora_name;
         }
-        LoraModel lora(backend, file_path);
-        if (!lora.load_from_file()) {
-            LOG_WARN("load lora tensors from %s failed", file_path.c_str());
+
+        if (curr_loras.find(lora_name) == curr_loras.end()) {
+            LOG_INFO("loading lora from '%s'", file_path.c_str());
+            int64_t t0 = ggml_time_ms();
+            std::shared_ptr<LoraModel> lora = std::make_shared<LoraModel>(backend, file_path);
+            if (!lora->load_from_file()) {
+                LOG_WARN("load lora tensors from %s failed", file_path.c_str());
+                return;
+            }
+            int64_t t1 = ggml_time_ms();
+            LOG_INFO("lora '%s' loaded, taking %.2fs", lora_name.c_str(), (t1 - t0) * 1.0f / 1000);
+            curr_loras[lora_name] = lora;
+        } else if (curr_loras[lora_name]->multiplier == multiplier) {
             return;
         }
 
-        lora.multiplier = multiplier;
-        // TODO: send version?
-        lora.apply(tensors, version, n_threads);
-        lora.free_params_buffer();
-
+        int64_t t0 = ggml_time_ms();
+        std::shared_ptr<LoraModel> lora = curr_loras[lora_name];
+        lora->multiplier = multiplier;
+        lora->apply(tensors, version, n_threads);
         int64_t t1 = ggml_time_ms();
-
         LOG_INFO("lora '%s' applied, taking %.2fs", lora_name.c_str(), (t1 - t0) * 1.0f / 1000);
     }
 
     void apply_loras(const std::unordered_map<std::string, float>& lora_state) {
-        if (lora_state.size() > 0 && model_wtype != GGML_TYPE_F16 && model_wtype != GGML_TYPE_F32) {
-            LOG_WARN("In quantized models when applying LoRA, the images have poor quality.");
-        }
         std::unordered_map<std::string, float> lora_state_diff;
         for (auto& kv : lora_state) {
             const std::string& lora_name = kv.first;
             float multiplier = kv.second;
 
             if (curr_lora_state.find(lora_name) != curr_lora_state.end()) {
                 float curr_multiplier = curr_lora_state[lora_name];
-                float multiplier_diff = multiplier - curr_multiplier;
-                if (multiplier_diff != 0.f) {
-                    lora_state_diff[lora_name] = multiplier_diff;
-                }
-            } else {
+                multiplier -= curr_multiplier;
+            }
+            if (multiplier != 0.f) {
                 lora_state_diff[lora_name] = multiplier;
             }
         }
 
-        LOG_INFO("Attempting to apply %lu LoRAs", lora_state.size());
-
        for (auto& kv : lora_state_diff) {
            apply_lora(kv.first, kv.second);
        }
```
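The rewritten `apply_loras` collapses the old if/else into a single delta computation: if a LoRA is already applied at strength c and the caller now requests m, only the difference m - c needs to be applied, and a zero difference is skipped entirely. For example, a LoRA previously applied at 0.8 and re-requested at 0.5 gets applied again with multiplier -0.3, partially undoing itself. A self-contained sketch of just that bookkeeping (the map contents are made-up example values):

```cpp
#include <cstdio>
#include <string>
#include <unordered_map>

int main() {
    // Multipliers currently baked into the model weights.
    std::unordered_map<std::string, float> curr_lora_state = {{"style_a", 0.8f}};
    // Multipliers requested for the next generation.
    std::unordered_map<std::string, float> lora_state = {{"style_a", 0.5f},
                                                         {"style_b", 1.0f}};

    std::unordered_map<std::string, float> lora_state_diff;
    for (auto& kv : lora_state) {
        float multiplier = kv.second;
        auto it = curr_lora_state.find(kv.first);
        if (it != curr_lora_state.end()) {
            multiplier -= it->second;  // already applied: only the delta is needed
        }
        if (multiplier != 0.f) {
            lora_state_diff[kv.first] = multiplier;  // zero deltas are skipped
        }
    }

    // Expected: style_a -> -0.30 (partially undo), style_b -> +1.00 (apply fresh)
    for (auto& kv : lora_state_diff) {
        std::printf("apply_lora(\"%s\", %+.2f)\n", kv.first.c_str(), kv.second);
    }
}
```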
