Skip to content

Commit 78629d6

Browse files
committed
refactor(tx): speed up q4_0 loading
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent: 34671f9 · commit: 78629d6

File tree

4 files changed

+99
-50
lines changed

4 files changed

+99
-50
lines changed

conditioner.hpp

Lines changed: 21 additions & 23 deletions
Original file line number · Diff line number · Diff line change
@@ -48,7 +48,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
4848
SDVersion version = VERSION_SD1;
4949
PMVersion pm_version = PM_VERSION_1;
5050
CLIPTokenizer tokenizer;
51-
ggml_type wtype;
5251
std::shared_ptr<CLIPTextModelRunner> text_model;
5352
std::shared_ptr<CLIPTextModelRunner> text_model2;
5453

@@ -59,7 +58,8 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
5958
std::vector<std::string> readed_embeddings;
6059

6160
FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
62-
ggml_type wtype,
61+
ggml_type clip_l_wtype,
62+
ggml_type clip_g_wtype,
6363
const std::string& embd_dir,
6464
SDVersion version = VERSION_SD1,
6565
PMVersion pv = PM_VERSION_1,
@@ -70,7 +70,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
7070
pm_version(pv),
7171
tokenizer(version == VERSION_SD2 ? 0 : 49407),
7272
embd_dir(embd_dir),
73-
wtype(wtype),
7473
compvis_compatiblity_clip_l(compvis_compatiblity_clip_l),
7574
compvis_compatiblity_clip_g(compvis_compatiblity_clip_g) {
7675
if (clip_skip <= 0) {
@@ -80,14 +79,14 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
8079
}
8180
}
8281
if (version == VERSION_SD1) {
83-
text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip);
82+
text_model = std::make_shared<CLIPTextModelRunner>(backend, clip_l_wtype, OPENAI_CLIP_VIT_L_14, clip_skip);
8483
} else if (version == VERSION_SD2) {
85-
text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_H_14, clip_skip);
84+
text_model = std::make_shared<CLIPTextModelRunner>(backend, clip_l_wtype, OPEN_CLIP_VIT_H_14, clip_skip);
8685
} else if (version == VERSION_SDXL) {
87-
text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false);
88-
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
86+
text_model = std::make_shared<CLIPTextModelRunner>(backend, clip_l_wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false);
87+
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, clip_g_wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
8988
} else if (version == VERSION_SDXL_REFINER) {
90-
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
89+
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, clip_g_wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
9190
}
9291
}
9392

@@ -174,14 +173,14 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
174173
LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], hidden_size);
175174
return false;
176175
}
177-
embd = ggml_new_tensor_2d(embd_ctx, wtype, hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
176+
embd = ggml_new_tensor_2d(embd_ctx, tensor_storage.type, hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
178177
*dst_tensor = embd;
179178
return true;
180179
};
181180
model_loader.load_tensors(on_load, NULL);
182181
readed_embeddings.push_back(embd_name);
183182
token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd));
184-
memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(wtype)),
183+
memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(embd->type)),
185184
embd->data,
186185
ggml_nbytes(embd));
187186
for (int i = 0; i < embd->ne[1]; i++) {
@@ -674,7 +673,6 @@ struct SD3CLIPEmbedder : public Conditioner {
674673
bool compvis_compatiblity_clip_l;
675674
bool compvis_compatiblity_clip_g;
676675
bool compvis_compatiblity_t5xxl;
677-
ggml_type wtype;
678676
CLIPTokenizer clip_l_tokenizer;
679677
CLIPTokenizer clip_g_tokenizer;
680678
T5UniGramTokenizer t5_tokenizer;
@@ -683,22 +681,23 @@ struct SD3CLIPEmbedder : public Conditioner {
683681
std::shared_ptr<T5Runner> t5;
684682

685683
SD3CLIPEmbedder(ggml_backend_t backend,
686-
ggml_type wtype,
684+
ggml_type clip_l_wtype,
685+
ggml_type clip_g_wtype,
686+
ggml_type t5xxl_wtype,
687687
bool compvis_compatiblity_clip_l = false,
688688
bool compvis_compatiblity_clip_g = false,
689689
bool compvis_compatiblity_t5xxl = false,
690690
int clip_skip = -1)
691-
: wtype(wtype),
692-
clip_g_tokenizer(0),
691+
: clip_g_tokenizer(0),
693692
compvis_compatiblity_clip_l(compvis_compatiblity_clip_l),
694693
compvis_compatiblity_clip_g(compvis_compatiblity_clip_g),
695694
compvis_compatiblity_t5xxl(compvis_compatiblity_t5xxl) {
696695
if (clip_skip <= 0) {
697696
clip_skip = 2;
698697
}
699-
clip_l = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false);
700-
clip_g = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
701-
t5 = std::make_shared<T5Runner>(backend, wtype);
698+
clip_l = std::make_shared<CLIPTextModelRunner>(backend, clip_l_wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false);
699+
clip_g = std::make_shared<CLIPTextModelRunner>(backend, clip_g_wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
700+
t5 = std::make_shared<T5Runner>(backend, t5xxl_wtype);
702701
}
703702

704703
void set_clip_skip(int clip_skip) {
@@ -1042,25 +1041,24 @@ struct SD3CLIPEmbedder : public Conditioner {
10421041
struct FluxCLIPEmbedder : public Conditioner {
10431042
bool compvis_compatiblity_clip_l;
10441043
bool compvis_compatiblity_t5xxl;
1045-
ggml_type wtype;
10461044
CLIPTokenizer clip_l_tokenizer;
10471045
T5UniGramTokenizer t5_tokenizer;
10481046
std::shared_ptr<CLIPTextModelRunner> clip_l;
10491047
std::shared_ptr<T5Runner> t5;
10501048

10511049
FluxCLIPEmbedder(ggml_backend_t backend,
1052-
ggml_type wtype,
1050+
ggml_type clip_l_wtype,
1051+
ggml_type t5xxl_wtype,
10531052
bool compvis_compatiblity_clip_l = false,
10541053
bool compvis_compatiblity_t5xxl = false,
10551054
int clip_skip = -1)
1056-
: wtype(wtype),
1057-
compvis_compatiblity_clip_l(compvis_compatiblity_clip_l),
1055+
: compvis_compatiblity_clip_l(compvis_compatiblity_clip_l),
10581056
compvis_compatiblity_t5xxl(compvis_compatiblity_t5xxl) {
10591057
if (clip_skip <= 0) {
10601058
clip_skip = 2;
10611059
}
1062-
clip_l = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, true);
1063-
t5 = std::make_shared<T5Runner>(backend, wtype);
1060+
clip_l = std::make_shared<CLIPTextModelRunner>(backend, clip_l_wtype, OPENAI_CLIP_VIT_L_14, clip_skip, true);
1061+
t5 = std::make_shared<T5Runner>(backend, t5xxl_wtype);
10641062
}
10651063

10661064
void set_clip_skip(int clip_skip) {

model.cpp

Lines changed: 11 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -1691,7 +1691,7 @@ ggml_type ModelLoader::get_sd_wtype() {
16911691
return get_diffusion_model_wtype();
16921692
}
16931693

1694-
ggml_type ModelLoader::get_conditioner_wtype() {
1694+
ggml_type ModelLoader::get_conditioner_wtype(std::vector<std::string> prefixes) {
16951695
for (auto& tensor_storage : tensor_storages) {
16961696
if (is_unused_tensor(tensor_storage.name)) {
16971697
continue;
@@ -1704,6 +1704,16 @@ ggml_type ModelLoader::get_conditioner_wtype() {
17041704
continue;
17051705
}
17061706

1707+
bool goahead = true;
1708+
if (!prefixes.empty()) {
1709+
goahead = std::any_of(prefixes.begin(), prefixes.end(), [&](const std::string& prefix) {
1710+
return tensor_storage.name.find(prefix) != std::string::npos;
1711+
});
1712+
}
1713+
if (!goahead) {
1714+
continue;
1715+
}
1716+
17071717
if (ggml_is_quantized(tensor_storage.type)) {
17081718
return tensor_storage.type;
17091719
}

model.h

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -183,7 +183,7 @@ class ModelLoader {
183183
bool init_from_safetensors_file(const std::string& dir_path, const std::string& file_prefix, ggml_type type, const std::string& prefix = "");
184184
SDVersion get_sd_version();
185185
ggml_type get_sd_wtype();
186-
ggml_type get_conditioner_wtype();
186+
ggml_type get_conditioner_wtype(std::vector<std::string> prefixes = {});
187187
ggml_type get_diffusion_model_wtype();
188188
ggml_type get_vae_wtype();
189189
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend);

stable-diffusion.cpp

Lines changed: 66 additions & 25 deletions
Original file line number · Diff line number · Diff line change
@@ -116,7 +116,9 @@ class StableDiffusionGGML {
116116
ggml_backend_t control_net_backend = NULL;
117117
ggml_backend_t vae_backend = NULL;
118118
ggml_type model_wtype = GGML_TYPE_COUNT;
119-
ggml_type conditioner_wtype = GGML_TYPE_COUNT;
119+
ggml_type clip_l_wtype = GGML_TYPE_COUNT;
120+
ggml_type clip_g_wtype = GGML_TYPE_COUNT;
121+
ggml_type t5xxl_wtype = GGML_TYPE_COUNT;
120122
ggml_type diffusion_model_wtype = GGML_TYPE_COUNT;
121123
ggml_type vae_wtype = GGML_TYPE_COUNT;
122124

@@ -305,32 +307,79 @@ class StableDiffusionGGML {
305307
model_wtype = GGML_TYPE_F32;
306308
LOG_WARN("can not get mode wtype frome weight, use f32");
307309
}
308-
conditioner_wtype = model_loader.get_conditioner_wtype();
309-
if (conditioner_wtype == GGML_TYPE_COUNT) {
310-
conditioner_wtype = wtype;
310+
switch (version) {
311+
case VERSION_SVD:
312+
case VERSION_SD1:
313+
case VERSION_SD2:
314+
case VERSION_SDXL:
315+
case VERSION_SDXL_REFINER: {
316+
if (version != VERSION_SDXL_REFINER) {
317+
clip_l_wtype = model_loader.get_conditioner_wtype({"cond_stage_model.transformer.", "text_encoders.clip_l."});
318+
if (clip_l_wtype == GGML_TYPE_COUNT) {
319+
clip_l_wtype = wtype;
320+
}
321+
}
322+
if (version == VERSION_SDXL_REFINER || version == VERSION_SDXL) {
323+
clip_g_wtype = model_loader.get_conditioner_wtype({"cond_stage_model.1.", "text_encoders.clip_g."});
324+
if (clip_g_wtype == GGML_TYPE_COUNT) {
325+
clip_g_wtype = wtype;
326+
}
327+
}
328+
break;
329+
}
330+
case VERSION_SD3_2B:
331+
case VERSION_SD3_5_2B:
332+
case VERSION_SD3_5_8B: {
333+
clip_l_wtype = model_loader.get_conditioner_wtype({"cond_stage_model.transformer.", "text_encoders.clip_l."});
334+
if (clip_l_wtype == GGML_TYPE_COUNT) {
335+
clip_l_wtype = wtype;
336+
}
337+
clip_g_wtype = model_loader.get_conditioner_wtype({"cond_stage_model.1.", "text_encoders.clip_g."});
338+
if (clip_g_wtype == GGML_TYPE_COUNT) {
339+
clip_g_wtype = wtype;
340+
}
341+
t5xxl_wtype = model_loader.get_conditioner_wtype({"cond_stage_model.2.", "text_encoders.t5xxl."});
342+
if (t5xxl_wtype == GGML_TYPE_COUNT) {
343+
t5xxl_wtype = wtype;
344+
}
345+
break;
346+
}
347+
case VERSION_FLUX_LITE:
348+
case VERSION_FLUX_DEV:
349+
case VERSION_FLUX_SCHNELL: {
350+
clip_l_wtype = model_loader.get_conditioner_wtype({"cond_stage_model.transformer.", "text_encoders.clip_l."});
351+
if (clip_l_wtype == GGML_TYPE_COUNT) {
352+
clip_l_wtype = wtype;
353+
}
354+
t5xxl_wtype = model_loader.get_conditioner_wtype({"cond_stage_model.1.", "text_encoders.t5xxl."});
355+
if (t5xxl_wtype == GGML_TYPE_COUNT) {
356+
t5xxl_wtype = wtype;
357+
}
358+
break;
359+
}
311360
}
361+
312362
diffusion_model_wtype = model_loader.get_diffusion_model_wtype();
313363
if (diffusion_model_wtype == GGML_TYPE_COUNT) {
314364
diffusion_model_wtype = wtype;
315365
}
316366
vae_wtype = model_loader.get_vae_wtype();
317-
318367
if (vae_wtype == GGML_TYPE_COUNT) {
319368
vae_wtype = wtype;
320369
}
321370
} else {
322371
model_wtype = wtype;
323-
conditioner_wtype = wtype;
372+
clip_l_wtype = wtype;
373+
clip_g_wtype = wtype;
374+
t5xxl_wtype = wtype;
324375
diffusion_model_wtype = wtype;
325376
vae_wtype = wtype;
326377
}
327378

328-
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
329-
vae_wtype = GGML_TYPE_F32;
330-
}
331-
332379
LOG_INFO("Weight type: %s", ggml_type_name(model_wtype));
333-
LOG_INFO("Conditioner weight type: %s", ggml_type_name(conditioner_wtype));
380+
LOG_INFO("CLIP_L weight type: %s", ggml_type_name(clip_l_wtype));
381+
LOG_INFO("CLIP_G weight type: %s", ggml_type_name(clip_g_wtype));
382+
LOG_INFO("T5XXL weight type: %s", ggml_type_name(t5xxl_wtype));
334383
LOG_INFO("Diffusion model weight type: %s", ggml_type_name(diffusion_model_wtype));
335384
LOG_INFO("VAE weight type: %s", ggml_type_name(vae_wtype));
336385

@@ -351,7 +400,7 @@ class StableDiffusionGGML {
351400
auto cc_vae = model_loader.has_prefix_tensors("first_stage_model.") && !model_loader.has_prefix_tensors("vae.");
352401

353402
if (version == VERSION_SVD) {
354-
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend, conditioner_wtype);
403+
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend, clip_l_wtype);
355404
clip_vision->alloc_params_buffer();
356405
clip_vision->get_param_tensors(tensors);
357406

@@ -364,15 +413,7 @@ class StableDiffusionGGML {
364413
first_stage_model->alloc_params_buffer();
365414
first_stage_model->get_param_tensors(tensors);
366415
} else {
367-
clip_backend = backend;
368-
bool use_t5xxl = false;
369-
if (sd_version_is_dit(version)) {
370-
use_t5xxl = true;
371-
}
372-
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != diffusion_model_wtype) {
373-
clip_on_cpu = true;
374-
LOG_INFO("set clip_on_cpu to true");
375-
}
416+
clip_backend = backend;
376417
if (clip_on_cpu && !ggml_backend_is_cpu(backend)) {
377418
LOG_INFO("CLIP: Using CPU backend");
378419
clip_backend = ggml_backend_cpu_init();
@@ -384,16 +425,16 @@ class StableDiffusionGGML {
384425
if (diffusion_flash_attn) {
385426
LOG_WARN("flash attention in this diffusion model is currently unsupported!");
386427
}
387-
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype, cc_clip_l, cc_clip_g, cc_t5xxl);
428+
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, clip_l_wtype, clip_g_wtype, t5xxl_wtype, cc_clip_l, cc_clip_g, cc_t5xxl);
388429
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
389430
} else if (sd_version_is_flux(version)) {
390-
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, conditioner_wtype, cc_clip_l, cc_t5xxl);
431+
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, clip_l_wtype, t5xxl_wtype, cc_clip_l, cc_t5xxl);
391432
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version, diffusion_flash_attn);
392433
} else {
393434
if (id_embeddings_path.find("v2") != std::string::npos) {
394-
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version, PM_VERSION_2, cc_clip_l, cc_clip_g);
435+
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, clip_l_wtype, clip_g_wtype, embeddings_path, version, PM_VERSION_2, cc_clip_l, cc_clip_g);
395436
} else {
396-
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version, PM_VERSION_1, cc_clip_l, cc_clip_g);
437+
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, clip_l_wtype, clip_g_wtype, embeddings_path, version, PM_VERSION_1, cc_clip_l, cc_clip_g);
397438
}
398439
diffusion_model = std::make_shared<UNetModel>(backend, diffusion_model_wtype, version, diffusion_flash_attn);
399440
}

0 commit comments

Comments (0)