Skip to content

Commit 7d9176a

Browse files
committed
feat(tx): sdxl refiner
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent afae8f5 commit 7d9176a

File tree

7 files changed

+89
-40
lines changed

7 files changed

+89
-40
lines changed

conditioner.hpp

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -73,42 +73,52 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
7373
} else if (sd_version_is_sd2(version)) {
7474
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, clip_skip);
7575
} else if (sd_version_is_sdxl(version)) {
76-
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
76+
if (version != VERSION_SDXL_REFINER) {
77+
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
78+
}
7779
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
7880
}
7981
}
8082

8183
void set_clip_skip(int clip_skip) {
82-
text_model->set_clip_skip(clip_skip);
83-
if (sd_version_is_sdxl(version)) {
84+
if (text_model) {
85+
text_model->set_clip_skip(clip_skip);
86+
}
87+
if (text_model2) {
8488
text_model2->set_clip_skip(clip_skip);
8589
}
8690
}
8791

8892
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
89-
text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model");
90-
if (sd_version_is_sdxl(version)) {
93+
if (text_model) {
94+
text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model");
95+
}
96+
if (text_model2) {
9197
text_model2->get_param_tensors(tensors, "cond_stage_model.1.transformer.text_model");
9298
}
9399
}
94100

95101
void alloc_params_buffer() {
96-
text_model->alloc_params_buffer();
97-
if (sd_version_is_sdxl(version)) {
102+
if (text_model) {
103+
text_model->alloc_params_buffer();
104+
}
105+
if (text_model2) {
98106
text_model2->alloc_params_buffer();
99107
}
100108
}
101109

102110
void free_params_buffer() {
103-
text_model->free_params_buffer();
104-
if (sd_version_is_sdxl(version)) {
111+
if (text_model) {
112+
text_model->free_params_buffer();
113+
}
114+
if (text_model2) {
105115
text_model2->free_params_buffer();
106116
}
107117
}
108118

109119
size_t get_params_buffer_size() {
110-
size_t buffer_size = text_model->get_params_buffer_size();
111-
if (sd_version_is_sdxl(version)) {
120+
size_t buffer_size = text_model ? text_model->get_params_buffer_size() : 0;
121+
if (text_model2) {
112122
buffer_size += text_model2->get_params_buffer_size();
113123
}
114124
return buffer_size;
@@ -131,7 +141,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
131141
params.no_alloc = false;
132142
struct ggml_context* embd_ctx = ggml_init(params);
133143
struct ggml_tensor* embd = NULL;
134-
int64_t hidden_size = text_model->model.hidden_size;
144+
int64_t hidden_size = text_model ? text_model->model.hidden_size : text_model2->model.hidden_size;
135145
auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) {
136146
if (tensor_storage.ne[0] != hidden_size) {
137147
LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], hidden_size);
@@ -148,7 +158,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
148158
embd->data,
149159
ggml_nbytes(embd));
150160
for (int i = 0; i < embd->ne[1]; i++) {
151-
bpe_tokens.push_back(text_model->model.vocab_size + num_custom_embeddings);
161+
bpe_tokens.push_back((text_model ? text_model->model.vocab_size : text_model2->model.vocab_size) + num_custom_embeddings);
152162
// LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
153163
num_custom_embeddings++;
154164
}
@@ -162,7 +172,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
162172
int32_t image_token,
163173
bool padding = false) {
164174
return tokenize_with_trigger_token(text, num_input_imgs, image_token,
165-
text_model->model.n_token, padding);
175+
text_model ? text_model->model.n_token : text_model2->model.n_token, padding);
166176
}
167177

168178
std::vector<int> convert_token_to_id(std::string text) {
@@ -311,7 +321,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
311321

312322
std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
313323
bool padding = false) {
314-
return tokenize(text, text_model->model.n_token, padding);
324+
return tokenize(text, text_model ? text_model->model.n_token : text_model2->model.n_token, padding);
315325
}
316326

317327
std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
@@ -419,28 +429,31 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
419429
}
420430

421431
{
422-
text_model->compute(n_threads,
423-
input_ids,
424-
num_custom_embeddings,
425-
token_embed_custom.data(),
426-
max_token_idx,
427-
false,
428-
&chunk_hidden_states1,
429-
work_ctx);
430-
if (sd_version_is_sdxl(version)) {
432+
if (text_model) {
433+
text_model->compute(n_threads,
434+
input_ids,
435+
num_custom_embeddings,
436+
token_embed_custom.data(),
437+
max_token_idx,
438+
false,
439+
&chunk_hidden_states1,
440+
work_ctx);
441+
}
442+
if (text_model2) {
431443
text_model2->compute(n_threads,
432-
input_ids2,
433-
0,
434-
NULL,
444+
text_model ? input_ids2 : input_ids,
445+
text_model ? 0 : num_custom_embeddings,
446+
text_model ? NULL : token_embed_custom.data(),
435447
max_token_idx,
436448
false,
437-
&chunk_hidden_states2, work_ctx);
449+
text_model ? &chunk_hidden_states2 : &chunk_hidden_states1,
450+
work_ctx);
438451
// concat
439-
chunk_hidden_states = ggml_tensor_concat(work_ctx, chunk_hidden_states1, chunk_hidden_states2, 0);
452+
chunk_hidden_states = text_model ? ggml_tensor_concat(work_ctx, chunk_hidden_states1, chunk_hidden_states2, 0) : chunk_hidden_states1;
440453

441454
if (chunk_idx == 0) {
442455
text_model2->compute(n_threads,
443-
input_ids2,
456+
text_model ? input_ids2 : input_ids,
444457
0,
445458
NULL,
446459
max_token_idx,
@@ -486,7 +499,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
486499
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
487500

488501
ggml_tensor* vec = NULL;
489-
if (sd_version_is_sdxl(version)) {
502+
if (sd_version_is_sdxl(version) && version != VERSION_SDXL_REFINER) {
490503
int out_dim = 256;
491504
vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, adm_in_channels);
492505
// [0:1280]

control.hpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ class ControlNetBlock : public GGMLBlock {
2323
std::vector<int> attention_resolutions = {4, 2, 1};
2424
std::vector<int> channel_mult = {1, 2, 4, 4};
2525
std::vector<int> transformer_depth = {1, 1, 1, 1};
26-
int time_embed_dim = 1280; // model_channels*4
26+
int time_embed_dim = 1280; // model_channels*4, 1536 for VERSION_SDXL_REFINER
2727
int num_heads = 8;
2828
int num_head_channels = -1; // channels // num_heads
29-
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
29+
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL, 1280 for VERSION_SDXL_REFINER
3030

3131
public:
32-
int model_channels = 320;
33-
int adm_in_channels = 2816; // only for VERSION_SDXL
32+
int model_channels = 320; // 384 for VERSION_SDXL_REFINER
33+
int adm_in_channels = 2816; // 2816 for VERSION_SDXL/SVD, 2560 for VERSION_SDXL_REFINER
3434

3535
ControlNetBlock(SDVersion version = VERSION_SD1)
3636
: version(version) {
@@ -45,6 +45,17 @@ class ControlNetBlock : public GGMLBlock {
4545
transformer_depth = {1, 2, 10};
4646
num_head_channels = 64;
4747
num_heads = -1;
48+
if (version == VERSION_SDXL_REFINER) {
49+
time_embed_dim = 1536;
50+
context_dim = 1280;
51+
model_channels = 384;
52+
adm_in_channels = 2560;
53+
attention_resolutions = {4, 2};
54+
channel_mult = {1, 2, 4, 4};
55+
transformer_depth = {4, 4, 4, 4};
56+
num_head_channels = 64;
57+
num_heads = -1;
58+
}
4859
} else if (version == VERSION_SVD) {
4960
in_channels = 8;
5061
out_channels = 4;

denoiser.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,12 +170,17 @@ struct AYSSchedule : SigmaSchedule {
170170

171171
switch (version) {
172172
case VERSION_SD2: /* fallthrough */
173+
case VERSION_SD2_INPAINT:
173174
LOG_WARN("AYS not designed for SD2.X models");
175+
return results;
174176
case VERSION_SD1:
177+
case VERSION_SD1_INPAINT:
175178
LOG_INFO("AYS using SD1.5 noise levels");
176179
inputs = noise_levels[0];
177180
break;
178181
case VERSION_SDXL:
182+
case VERSION_SDXL_REFINER:
183+
case VERSION_SDXL_INPAINT:
179184
LOG_INFO("AYS using SDXL noise levels");
180185
inputs = noise_levels[1];
181186
break;

model.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,6 +1474,7 @@ SDVersion ModelLoader::get_sd_version() {
14741474

14751475
bool is_xl = false;
14761476
bool is_flux = false;
1477+
bool is_refiner = false;
14771478

14781479
#define found_family (is_xl || is_flux)
14791480
for (auto& tensor_storage : tensor_storages) {
@@ -1505,6 +1506,9 @@ SDVersion ModelLoader::get_sd_version() {
15051506
}
15061507
}
15071508
}
1509+
if (tensor_storage.name.find("model.diffusion_model.output_blocks.11.0.") != std::string::npos) {
1510+
is_refiner = true;
1511+
}
15081512
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
15091513
return VERSION_SVD;
15101514
}
@@ -1528,6 +1532,9 @@ SDVersion ModelLoader::get_sd_version() {
15281532
}
15291533
bool is_inpaint = input_block_weight.ne[2] == 9;
15301534
if (is_xl) {
1535+
if (is_refiner) {
1536+
return VERSION_SDXL_REFINER;
1537+
}
15311538
if (is_inpaint) {
15321539
return VERSION_SDXL_INPAINT;
15331540
}

model.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ enum SDVersion {
2323
VERSION_SD2,
2424
VERSION_SD2_INPAINT,
2525
VERSION_SDXL,
26+
VERSION_SDXL_REFINER,
2627
VERSION_SDXL_INPAINT,
2728
VERSION_SVD,
2829
VERSION_SD3,
@@ -60,7 +61,7 @@ static inline bool sd_version_is_sd2(SDVersion version) {
6061
}
6162

6263
static inline bool sd_version_is_sdxl(SDVersion version) {
63-
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT) {
64+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER || version == VERSION_SDXL_INPAINT) {
6465
return true;
6566
}
6667
return false;

stable-diffusion.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ const char* model_version_to_str[] = {
3232
"SD 2.x",
3333
"SD 2.x Inpaint",
3434
"SDXL",
35+
"SDXL Refiner",
3536
"SDXL Inpaint",
3637
"SVD",
3738
"SD3.x",

unet.hpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,14 +175,14 @@ class UnetModelBlock : public GGMLBlock {
175175
std::vector<int> attention_resolutions = {4, 2, 1};
176176
std::vector<int> channel_mult = {1, 2, 4, 4};
177177
std::vector<int> transformer_depth = {1, 1, 1, 1};
178-
int time_embed_dim = 1280; // model_channels*4
178+
int time_embed_dim = 1280; // model_channels*4, 1536 for VERSION_SDXL_REFINER
179179
int num_heads = 8;
180180
int num_head_channels = -1; // channels // num_heads
181-
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
181+
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL, 1280 for VERSION_SDXL_REFINER
182182

183183
public:
184-
int model_channels = 320;
185-
int adm_in_channels = 2816; // only for VERSION_SDXL/SVD
184+
int model_channels = 320; // 384 for VERSION_SDXL_REFINER
185+
int adm_in_channels = 2816; // 2816 for VERSION_SDXL/SVD, 2560 for VERSION_SDXL_REFINER
186186

187187
UnetModelBlock(SDVersion version = VERSION_SD1, std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types, bool flash_attn = false)
188188
: version(version) {
@@ -197,6 +197,17 @@ class UnetModelBlock : public GGMLBlock {
197197
transformer_depth = {1, 2, 10};
198198
num_head_channels = 64;
199199
num_heads = -1;
200+
if (version == VERSION_SDXL_REFINER) {
201+
time_embed_dim = 1536;
202+
context_dim = 1280;
203+
model_channels = 384;
204+
adm_in_channels = 2560;
205+
attention_resolutions = {4, 2};
206+
channel_mult = {1, 2, 4, 4};
207+
transformer_depth = {4, 4, 4, 4};
208+
num_head_channels = 64;
209+
num_heads = -1;
210+
}
200211
} else if (version == VERSION_SVD) {
201212
in_channels = 8;
202213
out_channels = 4;

0 commit comments

Comments
 (0)