Skip to content

Commit 63f027e

Browse files
committed
refactor(tx): adjust wtype log
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent 75c6149 commit 63f027e

File tree

3 files changed

+75
-32
lines changed

3 files changed

+75
-32
lines changed

model.cpp

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,23 +1564,10 @@ SDVersion ModelLoader::get_sd_version() {
15641564
}
15651565

15661566
ggml_type ModelLoader::get_sd_wtype() {
1567-
for (auto& tensor_storage : tensor_storages) {
1568-
if (is_unused_tensor(tensor_storage.name)) {
1569-
continue;
1570-
}
1571-
1572-
if (ggml_is_quantized(tensor_storage.type)) {
1573-
return tensor_storage.type;
1574-
}
1575-
1576-
if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
1577-
return tensor_storage.type;
1578-
}
1579-
}
1580-
return GGML_TYPE_COUNT;
1567+
return get_diffusion_model_wtype();
15811568
}
15821569

1583-
ggml_type ModelLoader::get_conditioner_wtype() {
1570+
ggml_type ModelLoader::get_conditioner_wtype(std::vector<std::string> prefixes) {
15841571
for (auto& tensor_storage : tensor_storages) {
15851572
if (is_unused_tensor(tensor_storage.name)) {
15861573
continue;
@@ -1593,6 +1580,16 @@ ggml_type ModelLoader::get_conditioner_wtype() {
15931580
continue;
15941581
}
15951582

1583+
bool goahead = true;
1584+
if (!prefixes.empty()) {
1585+
goahead = std::any_of(prefixes.begin(), prefixes.end(), [&](const std::string& prefix) {
1586+
return tensor_storage.name.find(prefix) != std::string::npos;
1587+
});
1588+
}
1589+
if (!goahead) {
1590+
continue;
1591+
}
1592+
15961593
if (ggml_is_quantized(tensor_storage.type)) {
15971594
return tensor_storage.type;
15981595
}

model.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ class ModelLoader {
212212
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
213213
SDVersion get_sd_version();
214214
ggml_type get_sd_wtype();
215-
ggml_type get_conditioner_wtype();
215+
ggml_type get_conditioner_wtype(std::vector<std::string> prefixes = {});
216216
ggml_type get_diffusion_model_wtype();
217217
ggml_type get_vae_wtype();
218218
void set_wtype_override(ggml_type wtype, std::string prefix = "");

stable-diffusion.cpp

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ class StableDiffusionGGML {
144144
ggml_backend_t control_net_backend = NULL;
145145
ggml_backend_t vae_backend = NULL;
146146
ggml_type model_wtype = GGML_TYPE_COUNT;
147-
ggml_type conditioner_wtype = GGML_TYPE_COUNT;
147+
ggml_type clip_l_wtype = GGML_TYPE_COUNT;
148+
ggml_type clip_g_wtype = GGML_TYPE_COUNT;
149+
ggml_type t5xxl_wtype = GGML_TYPE_COUNT;
148150
ggml_type diffusion_model_wtype = GGML_TYPE_COUNT;
149151
ggml_type vae_wtype = GGML_TYPE_COUNT;
150152

@@ -337,9 +339,56 @@ class StableDiffusionGGML {
337339
model_wtype = GGML_TYPE_F32;
338340
LOG_WARN("can not get mode wtype frome weight, use f32");
339341
}
340-
conditioner_wtype = model_loader.get_conditioner_wtype();
341-
if (conditioner_wtype == GGML_TYPE_COUNT) {
342-
conditioner_wtype = wtype;
342+
switch (version) {
343+
case VERSION_SVD:
344+
case VERSION_SD1:
345+
case VERSION_SD1_INPAINT:
346+
case VERSION_SD2:
347+
case VERSION_SD2_INPAINT:
348+
case VERSION_SDXL:
349+
case VERSION_SDXL_REFINER:
350+
case VERSION_SDXL_INPAINT: {
351+
if (version != VERSION_SDXL_REFINER) {
352+
clip_l_wtype = model_loader.get_conditioner_wtype({"cond_stage_model.transformer.", "text_encoders.clip_l."});
353+
if (clip_l_wtype == GGML_TYPE_COUNT) {
354+
clip_l_wtype = wtype;
355+
}
356+
}
357+
if (sd_version_is_sdxl(version)) {
358+
clip_g_wtype = model_loader.get_conditioner_wtype({"cond_stage_model.1.", "text_encoders.clip_g."});
359+
if (clip_g_wtype == GGML_TYPE_COUNT) {
360+
clip_g_wtype = wtype;
361+
}
362+
}
363+
break;
364+
}
365+
case VERSION_SD3: {
366+
clip_l_wtype = model_loader.get_conditioner_wtype({"cond_stage_model.transformer.", "text_encoders.clip_l."});
367+
if (clip_l_wtype == GGML_TYPE_COUNT) {
368+
clip_l_wtype = wtype;
369+
}
370+
clip_g_wtype = model_loader.get_conditioner_wtype({"cond_stage_model.1.", "text_encoders.clip_g."});
371+
if (clip_g_wtype == GGML_TYPE_COUNT) {
372+
clip_g_wtype = wtype;
373+
}
374+
t5xxl_wtype = model_loader.get_conditioner_wtype({"cond_stage_model.2.", "text_encoders.t5xxl."});
375+
if (t5xxl_wtype == GGML_TYPE_COUNT) {
376+
t5xxl_wtype = wtype;
377+
}
378+
break;
379+
}
380+
case VERSION_FLUX:
381+
case VERSION_FLUX_FILL: {
382+
clip_l_wtype = model_loader.get_conditioner_wtype({"cond_stage_model.transformer.", "text_encoders.clip_l."});
383+
if (clip_l_wtype == GGML_TYPE_COUNT) {
384+
clip_l_wtype = wtype;
385+
}
386+
t5xxl_wtype = model_loader.get_conditioner_wtype({"cond_stage_model.1.", "text_encoders.t5xxl."});
387+
if (t5xxl_wtype == GGML_TYPE_COUNT) {
388+
t5xxl_wtype = wtype;
389+
}
390+
break;
391+
}
343392
}
344393
diffusion_model_wtype = model_loader.get_diffusion_model_wtype();
345394
if (diffusion_model_wtype == GGML_TYPE_COUNT) {
@@ -352,7 +401,9 @@ class StableDiffusionGGML {
352401
}
353402
} else {
354403
model_wtype = wtype;
355-
conditioner_wtype = wtype;
404+
clip_l_wtype = wtype;
405+
clip_g_wtype = wtype;
406+
t5xxl_wtype = wtype;
356407
diffusion_model_wtype = wtype;
357408
vae_wtype = wtype;
358409
model_loader.set_wtype_override(wtype);
@@ -363,22 +414,17 @@ class StableDiffusionGGML {
363414
model_loader.set_wtype_override(GGML_TYPE_F32, "vae.");
364415
}
365416

366-
LOG_INFO("Weight type: %s", model_wtype != SD_TYPE_COUNT ? ggml_type_name(model_wtype) : "??");
367-
LOG_INFO("Conditioner weight type: %s", conditioner_wtype != SD_TYPE_COUNT ? ggml_type_name(conditioner_wtype) : "??");
368-
LOG_INFO("Diffusion model weight type: %s", diffusion_model_wtype != SD_TYPE_COUNT ? ggml_type_name(diffusion_model_wtype) : "??");
369-
LOG_INFO("VAE weight type: %s", vae_wtype != SD_TYPE_COUNT ? ggml_type_name(vae_wtype) : "??");
417+
LOG_INFO("Weight type: %s", model_wtype != GGML_TYPE_COUNT ? ggml_type_name(model_wtype) : "??");
418+
LOG_INFO("CLIP_L weight type: %s", clip_l_wtype != GGML_TYPE_COUNT ? ggml_type_name(clip_l_wtype) : "??");
419+
LOG_INFO("CLIP_G weight type: %s", clip_g_wtype != GGML_TYPE_COUNT ? ggml_type_name(clip_g_wtype) : "??");
420+
LOG_INFO("T5XXL weight type: %s", t5xxl_wtype != GGML_TYPE_COUNT ? ggml_type_name(t5xxl_wtype) : "??");
421+
LOG_INFO("Diffusion model weight type: %s", diffusion_model_wtype != GGML_TYPE_COUNT ? ggml_type_name(diffusion_model_wtype) : "??");
422+
LOG_INFO("VAE weight type: %s", vae_wtype != GGML_TYPE_COUNT ? ggml_type_name(vae_wtype) : "??");
370423

371424
LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));
372425

373426
if (sd_version_is_sdxl(version)) {
374427
scale_factor = 0.13025f;
375-
if (vae_path.size() == 0 && taesd_path.size() == 0) {
376-
LOG_WARN(
377-
"!!!It looks like you are using SDXL model. "
378-
"If you find that the generated images are completely black, "
379-
"try specifying SDXL VAE FP16 Fix with the --vae parameter. "
380-
"You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors");
381-
}
382428
} else if (sd_version_is_sd3(version)) {
383429
scale_factor = 1.5305f;
384430
} else if (sd_version_is_flux(version)) {

0 commit comments

Comments
 (0)