Skip to content

Commit 0f54865

Browse files
committed
feat: support sd 3.5 medium
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent fd4dae5 commit 0f54865

File tree

13 files changed

+365
-131
lines changed

13 files changed

+365
-131
lines changed

conditioner.hpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,15 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
6363
: version(version), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
6464
if (clip_skip <= 0) {
6565
clip_skip = 1;
66-
if (version == VERSION_SD2 || version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
66+
if (version == VERSION_SD2 || version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
6767
clip_skip = 2;
6868
}
6969
}
7070
if (version == VERSION_SD1) {
7171
text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip);
7272
} else if (version == VERSION_SD2) {
7373
text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_H_14, clip_skip);
74-
} else if (version == VERSION_SDXL_BASE) {
74+
} else if (version == VERSION_SDXL) {
7575
text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false);
7676
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
7777
} else if (version == VERSION_SDXL_REFINER) {
@@ -83,7 +83,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
8383
if (version != VERSION_SDXL_REFINER) {
8484
text_model->set_clip_skip(clip_skip);
8585
}
86-
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
86+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
8787
text_model2->set_clip_skip(clip_skip);
8888
}
8989
}
@@ -92,7 +92,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
9292
if (version != VERSION_SDXL_REFINER) {
9393
text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model");
9494
}
95-
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
95+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
9696
text_model2->get_param_tensors(tensors, "cond_stage_model.1.transformer.text_model");
9797
}
9898
}
@@ -101,7 +101,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
101101
if (version != VERSION_SDXL_REFINER) {
102102
text_model->alloc_params_buffer();
103103
}
104-
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
104+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
105105
text_model2->alloc_params_buffer();
106106
}
107107
}
@@ -110,7 +110,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
110110
if (version != VERSION_SDXL_REFINER) {
111111
text_model->free_params_buffer();
112112
}
113-
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
113+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
114114
text_model2->free_params_buffer();
115115
}
116116
}
@@ -120,7 +120,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
120120
if (version != VERSION_SDXL_REFINER) {
121121
buffer_size = text_model->get_params_buffer_size();
122122
}
123-
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
123+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
124124
buffer_size += text_model2->get_params_buffer_size();
125125
}
126126
return buffer_size;
@@ -411,7 +411,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
411411
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
412412
struct ggml_tensor* input_ids2 = NULL;
413413
size_t max_token_idx = 0;
414-
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
414+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
415415
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), tokenizer.EOS_TOKEN_ID);
416416
if (it != chunk_tokens.end()) {
417417
std::fill(std::next(it), chunk_tokens.end(), 0);
@@ -438,7 +438,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
438438
&chunk_hidden_states1,
439439
work_ctx);
440440
}
441-
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
441+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
442442
text_model2->compute(n_threads,
443443
input_ids2,
444444
0,
@@ -497,7 +497,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
497497
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
498498

499499
ggml_tensor* vec = NULL;
500-
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
500+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
501501
int out_dim = 256;
502502
vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, adm_in_channels);
503503
// [0:1280]

control.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class ControlNetBlock : public GGMLBlock {
3838
context_dim = 1024;
3939
num_head_channels = 64;
4040
num_heads = -1;
41-
} else if (version == VERSION_SDXL_BASE) {
41+
} else if (version == VERSION_SDXL) {
4242
context_dim = 2048;
4343
attention_resolutions = {4, 2};
4444
channel_mult = {1, 2, 4};
@@ -68,7 +68,7 @@ class ControlNetBlock : public GGMLBlock {
6868
// time_embed_1 is nn.SiLU()
6969
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
7070

71-
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER || version == VERSION_SVD) {
71+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER || version == VERSION_SVD) {
7272
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
7373
// label_emb_1 is nn.SiLU()
7474
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));

denoiser.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ struct AYSSchedule : SigmaSchedule {
175175
LOG_INFO("AYS using SD1.5 noise levels");
176176
inputs = noise_levels[0];
177177
break;
178-
case VERSION_SDXL_BASE:
178+
case VERSION_SDXL:
179179
case VERSION_SDXL_REFINER:
180180
LOG_INFO("AYS using SDXL noise levels");
181181
inputs = noise_levels[1];

diffusion_model.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ struct MMDiTModel : public DiffusionModel {
8080

8181
MMDiTModel(ggml_backend_t backend,
8282
ggml_type wtype,
83-
SDVersion version = VERSION_SD3_2B)
83+
SDVersion version = VERSION_SD3_MEDIUM)
8484
: mmdit(backend, wtype, version) {
8585
}
8686

examples/convert/main.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -489,9 +489,14 @@ int main(int argc, char** argv) {
489489
auto transformer_config = load_json(transformer_config_path);
490490
auto num_layers = transformer_config.at("num_layers").get<int>();
491491
if (num_layers == 38) {
492-
ver = VERSION_SD3_5_8B;
492+
ver = VERSION_SD3_5_LARGE;
493493
} else {
494-
ver = VERSION_SD3_2B;
494+
auto pos_embed_max_size = transformer_config.at("pos_embed_max_size").get<int>();
495+
if (pos_embed_max_size == 384) {
496+
ver = VERSION_SD3_5_LARGE;
497+
} else {
498+
ver = VERSION_SD3_MEDIUM;
499+
}
495500
}
496501
} else if (class_name == "FluxPipeline") {
497502
auto text_encoder_config_path = path_join(params.model_path, "text_encoder/config.json");
@@ -507,7 +512,7 @@ int main(int argc, char** argv) {
507512
ver = VERSION_FLUX_SCHNELL;
508513
}
509514
} else if (class_name == "StableDiffusionXLPipeline" || class_name == "StableDiffusionXLImg2ImgPipeline") {
510-
ver = VERSION_SDXL_BASE;
515+
ver = VERSION_SDXL;
511516
} else if (class_name == "StableDiffusionPipeline") {
512517
auto text_encoder_config_path = path_join(params.model_path, "text_encoder/config.json");
513518
if (!file_exists(text_encoder_config_path)) {
@@ -529,13 +534,14 @@ int main(int argc, char** argv) {
529534
}
530535

531536
switch (ver) {
532-
case VERSION_SD3_5_8B:
533-
case VERSION_SD3_2B:
537+
case VERSION_SD3_5_LARGE:
538+
case VERSION_SD3_5_MEDIUM:
539+
case VERSION_SD3_MEDIUM:
534540
return convert_sd3(params, ver);
535541
case VERSION_FLUX_DEV:
536542
case VERSION_FLUX_SCHNELL:
537543
return convert_flux(params, ver);
538-
case VERSION_SDXL_BASE:
544+
case VERSION_SDXL:
539545
case VERSION_SDXL_REFINER:
540546
return convert_sdxl(params, ver);
541547
case VERSION_SD2:

0 commit comments

Comments
 (0)