Skip to content

Commit f782ffe

Browse files
committed
feat: convert
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent 461ff9c commit f782ffe

16 files changed

+934
-180
lines changed

LICENSE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
MIT License
22

3+
Copyright (c) 2024 thxCode
34
Copyright (c) 2023 leejet
45

56
Permission is hereby granted, free of charge, to any person obtaining a copy

conditioner.hpp

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -63,51 +63,64 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
6363
: version(version), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
6464
if (clip_skip <= 0) {
6565
clip_skip = 1;
66-
if (version == VERSION_SD2 || version == VERSION_SDXL) {
66+
if (version == VERSION_SD2 || version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
6767
clip_skip = 2;
6868
}
6969
}
7070
if (version == VERSION_SD1) {
7171
text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip);
7272
} else if (version == VERSION_SD2) {
7373
text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_H_14, clip_skip);
74-
} else if (version == VERSION_SDXL) {
74+
} else if (version == VERSION_SDXL_BASE) {
7575
text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false);
7676
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
77+
} else if (version == VERSION_SDXL_REFINER) {
78+
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
7779
}
7880
}
7981

8082
void set_clip_skip(int clip_skip) {
81-
text_model->set_clip_skip(clip_skip);
82-
if (version == VERSION_SDXL) {
83+
if (version != VERSION_SDXL_REFINER) {
84+
text_model->set_clip_skip(clip_skip);
85+
}
86+
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
8387
text_model2->set_clip_skip(clip_skip);
8488
}
8589
}
8690

8791
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
88-
text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model");
89-
if (version == VERSION_SDXL) {
92+
if (version != VERSION_SDXL_REFINER) {
93+
text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model");
94+
}
95+
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
9096
text_model2->get_param_tensors(tensors, "cond_stage_model.1.transformer.text_model");
9197
}
9298
}
9399

94100
void alloc_params_buffer() {
95-
text_model->alloc_params_buffer();
96-
if (version == VERSION_SDXL) {
101+
if (version != VERSION_SDXL_REFINER) {
102+
text_model->alloc_params_buffer();
103+
}
104+
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
97105
text_model2->alloc_params_buffer();
98106
}
99107
}
100108

101109
void free_params_buffer() {
102-
text_model->free_params_buffer();
103-
if (version == VERSION_SDXL) {
110+
if (version != VERSION_SDXL_REFINER) {
111+
text_model->free_params_buffer();
112+
}
113+
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
104114
text_model2->free_params_buffer();
105115
}
106116
}
107117

108118
size_t get_params_buffer_size() {
109-
size_t buffer_size = text_model->get_params_buffer_size();
110-
if (version == VERSION_SDXL) {
119+
size_t buffer_size = 0;
120+
if (version != VERSION_SDXL_REFINER) {
121+
buffer_size = text_model->get_params_buffer_size();
122+
}
123+
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
111124
buffer_size += text_model2->get_params_buffer_size();
112125
}
113126
return buffer_size;
@@ -398,7 +411,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
398411
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
399412
struct ggml_tensor* input_ids2 = NULL;
400413
size_t max_token_idx = 0;
401-
if (version == VERSION_SDXL) {
414+
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
402415
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), tokenizer.EOS_TOKEN_ID);
403416
if (it != chunk_tokens.end()) {
404417
std::fill(std::next(it), chunk_tokens.end(), 0);
@@ -415,15 +428,17 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
415428
}
416429

417430
{
418-
text_model->compute(n_threads,
419-
input_ids,
420-
num_custom_embeddings,
421-
token_embed_custom.data(),
422-
max_token_idx,
423-
false,
424-
&chunk_hidden_states1,
425-
work_ctx);
426-
if (version == VERSION_SDXL) {
431+
if (version != VERSION_SDXL_REFINER) {
432+
text_model->compute(n_threads,
433+
input_ids,
434+
num_custom_embeddings,
435+
token_embed_custom.data(),
436+
max_token_idx,
437+
false,
438+
&chunk_hidden_states1,
439+
work_ctx);
440+
}
441+
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
427442
text_model2->compute(n_threads,
428443
input_ids2,
429444
0,
@@ -482,7 +497,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
482497
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
483498

484499
ggml_tensor* vec = NULL;
485-
if (version == VERSION_SDXL) {
500+
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
486501
int out_dim = 256;
487502
vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, adm_in_channels);
488503
// [0:1280]
@@ -623,6 +638,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
623638

624639
struct SD3CLIPEmbedder : public Conditioner {
625640
ggml_type wtype;
641+
bool compvis_compatiblity;
626642
CLIPTokenizer clip_l_tokenizer;
627643
CLIPTokenizer clip_g_tokenizer;
628644
T5UniGramTokenizer t5_tokenizer;
@@ -632,8 +648,9 @@ struct SD3CLIPEmbedder : public Conditioner {
632648

633649
SD3CLIPEmbedder(ggml_backend_t backend,
634650
ggml_type wtype,
635-
int clip_skip = -1)
636-
: wtype(wtype), clip_g_tokenizer(0) {
651+
bool compvis_compatiblity = false,
652+
int clip_skip = -1)
653+
: wtype(wtype), compvis_compatiblity(compvis_compatiblity), clip_g_tokenizer(0) {
637654
if (clip_skip <= 0) {
638655
clip_skip = 2;
639656
}
@@ -648,6 +665,12 @@ struct SD3CLIPEmbedder : public Conditioner {
648665
}
649666

650667
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
668+
if (compvis_compatiblity) {
669+
clip_l->get_param_tensors(tensors, "cond_stage_model.transformer.text_model");
670+
clip_g->get_param_tensors(tensors, "cond_stage_model.1.transformer.text_model");
671+
t5->get_param_tensors(tensors, "cond_stage_model.2.transformer");
672+
return;
673+
}
651674
clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
652675
clip_g->get_param_tensors(tensors, "text_encoders.clip_g.transformer.text_model");
653676
t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");

control.hpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,28 +23,38 @@ class ControlNetBlock : public GGMLBlock {
2323
std::vector<int> attention_resolutions = {4, 2, 1};
2424
std::vector<int> channel_mult = {1, 2, 4, 4};
2525
std::vector<int> transformer_depth = {1, 1, 1, 1};
26-
int time_embed_dim = 1280; // model_channels*4
26+
int time_embed_dim = 1280; // model_channels*4, 1536 for VERSION_SDXL_REFINER
2727
int num_heads = 8;
2828
int num_head_channels = -1; // channels // num_heads
29-
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
29+
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL_BASE, 1280 for VERSION_SDXL_REFINER
3030

3131
public:
32-
int model_channels = 320;
33-
int adm_in_channels = 2816; // only for VERSION_SDXL
32+
int model_channels = 320; // 384 for VERSION_SDXL_REFINER
33+
int adm_in_channels = 2816; // 2816 for VERSION_SDXL_BASE/SVD, 2560 for VERSION_SDXL_REFINER
3434

3535
ControlNetBlock(SDVersion version = VERSION_SD1)
3636
: version(version) {
3737
if (version == VERSION_SD2) {
3838
context_dim = 1024;
3939
num_head_channels = 64;
4040
num_heads = -1;
41-
} else if (version == VERSION_SDXL) {
41+
} else if (version == VERSION_SDXL_BASE) {
4242
context_dim = 2048;
4343
attention_resolutions = {4, 2};
4444
channel_mult = {1, 2, 4};
4545
transformer_depth = {1, 2, 10};
4646
num_head_channels = 64;
4747
num_heads = -1;
48+
} else if (version == VERSION_SDXL_REFINER) {
49+
time_embed_dim = 1536;
50+
context_dim = 1280;
51+
model_channels = 384;
52+
adm_in_channels = 2560;
53+
attention_resolutions = {4, 2};
54+
channel_mult = {1, 2, 4};
55+
transformer_depth = {1, 2, 10};
56+
num_head_channels = 64;
57+
num_heads = -1;
4858
} else if (version == VERSION_SVD) {
4959
in_channels = 8;
5060
out_channels = 4;
@@ -58,7 +68,7 @@ class ControlNetBlock : public GGMLBlock {
5868
// time_embed_1 is nn.SiLU()
5969
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
6070

61-
if (version == VERSION_SDXL || version == VERSION_SVD) {
71+
if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER || version == VERSION_SVD) {
6272
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
6373
// label_emb_1 is nn.SiLU()
6474
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));

denoiser.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,8 @@ struct AYSSchedule : SigmaSchedule {
175175
LOG_INFO("AYS using SD1.5 noise levels");
176176
inputs = noise_levels[0];
177177
break;
178-
case VERSION_SDXL:
178+
case VERSION_SDXL_BASE:
179+
case VERSION_SDXL_REFINER:
179180
LOG_INFO("AYS using SDXL noise levels");
180181
inputs = noise_levels[1];
181182
break;

examples/cli/main.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ struct SDParams {
5050
std::string embeddings_path;
5151
std::string stacked_id_embeddings_path;
5252
std::string input_id_images_path;
53-
sd_type_t wtype = SD_TYPE_COUNT;
53+
ggml_type wtype = GGML_TYPE_COUNT;
5454
std::string lora_model_dir;
5555
std::string output_path = "output.png";
5656
std::string input_path;
@@ -90,7 +90,7 @@ void print_params(SDParams params) {
9090
printf(" n_threads: %d\n", params.n_threads);
9191
printf(" mode: %s\n", modes_str[params.mode]);
9292
printf(" model_path: %s\n", params.model_path.c_str());
93-
printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified");
93+
printf(" wtype: %s\n", params.wtype < GGML_TYPE_COUNT ? ggml_type_name(params.wtype) : "unspecified");
9494
printf(" clip_l_path: %s\n", params.clip_l_path.c_str());
9595
printf(" clip_g_path: %s\n", params.clip_g_path.c_str());
9696
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
@@ -295,25 +295,25 @@ void parse_args(int argc, const char** argv, SDParams& params) {
295295
}
296296
std::string type = argv[i];
297297
if (type == "f32") {
298-
params.wtype = SD_TYPE_F32;
298+
params.wtype = GGML_TYPE_F32;
299299
} else if (type == "f16") {
300-
params.wtype = SD_TYPE_F16;
300+
params.wtype = GGML_TYPE_F16;
301301
} else if (type == "q4_0") {
302-
params.wtype = SD_TYPE_Q4_0;
302+
params.wtype = GGML_TYPE_Q4_0;
303303
} else if (type == "q4_1") {
304-
params.wtype = SD_TYPE_Q4_1;
304+
params.wtype = GGML_TYPE_Q4_1;
305305
} else if (type == "q5_0") {
306-
params.wtype = SD_TYPE_Q5_0;
306+
params.wtype = GGML_TYPE_Q5_0;
307307
} else if (type == "q5_1") {
308-
params.wtype = SD_TYPE_Q5_1;
308+
params.wtype = GGML_TYPE_Q5_1;
309309
} else if (type == "q8_0") {
310-
params.wtype = SD_TYPE_Q8_0;
310+
params.wtype = GGML_TYPE_Q8_0;
311311
} else if (type == "q2_k") {
312-
params.wtype = SD_TYPE_Q2_K;
312+
params.wtype = GGML_TYPE_Q2_K;
313313
} else if (type == "q3_k") {
314-
params.wtype = SD_TYPE_Q3_K;
314+
params.wtype = GGML_TYPE_Q3_K;
315315
} else if (type == "q4_k") {
316-
params.wtype = SD_TYPE_Q4_K;
316+
params.wtype = GGML_TYPE_Q4_K;
317317
} else {
318318
fprintf(stderr, "error: invalid weight format %s, must be one of [f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k]\n",
319319
type.c_str());
@@ -650,7 +650,7 @@ int main(int argc, const char* argv[]) {
650650

651651
parse_args(argc, argv, params);
652652

653-
sd_set_log_callback(sd_log_cb, (void*)&params);
653+
sd_log_set(sd_log_cb, (void*)&params);
654654

655655
if (params.verbose) {
656656
print_params(params);

examples/convert/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(TARGET stable-diffusion-convert)
2+
3+
add_executable(${TARGET} main.cpp)
4+
install(TARGETS ${TARGET} RUNTIME)
5+
target_link_libraries(${TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT})
6+
target_compile_features(${TARGET} PUBLIC cxx_std_11)

0 commit comments

Comments
 (0)