Skip to content

Commit 8f303b1

Browse files
committed
feat: sdxl refiner
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent c0788c8 commit 8f303b1

File tree

7 files changed

+103
-42
lines changed

7 files changed

+103
-42
lines changed

conditioner.hpp

Lines changed: 47 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
6565
: version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
6666
if (clip_skip <= 0) {
6767
clip_skip = 1;
68-
if (version == VERSION_SD2 || version == VERSION_SDXL) {
68+
if (version == VERSION_SD2 || version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
6969
clip_skip = 2;
7070
}
7171
}
@@ -76,40 +76,53 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
7676
} else if (version == VERSION_SDXL) {
7777
text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false);
7878
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
79+
} else if (version == VERSION_SDXL_REFINER) {
80+
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
7981
}
8082
}
8183

8284
void set_clip_skip(int clip_skip) {
83-
text_model->set_clip_skip(clip_skip);
84-
if (version == VERSION_SDXL) {
85+
if (version != VERSION_SDXL_REFINER) {
86+
text_model->set_clip_skip(clip_skip);
87+
}
88+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
8589
text_model2->set_clip_skip(clip_skip);
8690
}
8791
}
8892

8993
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
90-
text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model");
91-
if (version == VERSION_SDXL) {
94+
if (version != VERSION_SDXL_REFINER) {
95+
text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model");
96+
}
97+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
9298
text_model2->get_param_tensors(tensors, "cond_stage_model.1.transformer.text_model");
9399
}
94100
}
95101

96102
void alloc_params_buffer() {
97-
text_model->alloc_params_buffer();
98-
if (version == VERSION_SDXL) {
103+
if (version != VERSION_SDXL_REFINER) {
104+
text_model->alloc_params_buffer();
105+
}
106+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
99107
text_model2->alloc_params_buffer();
100108
}
101109
}
102110

103111
void free_params_buffer() {
104-
text_model->free_params_buffer();
105-
if (version == VERSION_SDXL) {
112+
if (version != VERSION_SDXL_REFINER) {
113+
text_model->free_params_buffer();
114+
}
115+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
106116
text_model2->free_params_buffer();
107117
}
108118
}
109119

110120
size_t get_params_buffer_size() {
111-
size_t buffer_size = text_model->get_params_buffer_size();
112-
if (version == VERSION_SDXL) {
121+
size_t buffer_size = 0;
122+
if (version != VERSION_SDXL_REFINER) {
123+
buffer_size = text_model->get_params_buffer_size();
124+
}
125+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
113126
buffer_size += text_model2->get_params_buffer_size();
114127
}
115128
return buffer_size;
@@ -132,8 +145,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
132145
params.no_alloc = false;
133146
struct ggml_context* embd_ctx = ggml_init(params);
134147
struct ggml_tensor* embd = NULL;
135-
int64_t hidden_size = text_model->model.hidden_size;
136-
auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) {
148+
int64_t hidden_size = 0;
149+
if (version != VERSION_SDXL_REFINER) {
150+
hidden_size = text_model->model.hidden_size;
151+
} else {
152+
hidden_size = text_model2->model.hidden_size;
153+
}
154+
auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) {
137155
if (tensor_storage.ne[0] != hidden_size) {
138156
LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], hidden_size);
139157
return false;
@@ -149,7 +167,11 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
149167
embd->data,
150168
ggml_nbytes(embd));
151169
for (int i = 0; i < embd->ne[1]; i++) {
152-
bpe_tokens.push_back(text_model->model.vocab_size + num_custom_embeddings);
170+
if (version != VERSION_SDXL_REFINER) {
171+
bpe_tokens.push_back(text_model->model.vocab_size + num_custom_embeddings);
172+
} else {
173+
bpe_tokens.push_back(text_model2->model.vocab_size + num_custom_embeddings);
174+
}
153175
// LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
154176
num_custom_embeddings++;
155177
}
@@ -163,7 +185,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
163185
int32_t image_token,
164186
bool padding = false) {
165187
return tokenize_with_trigger_token(text, num_input_imgs, image_token,
166-
text_model->model.n_token, padding);
188+
version != VERSION_SDXL_REFINER ? text_model->model.n_token : text_model2->model.n_token, padding);
167189
}
168190

169191
std::vector<int> convert_token_to_id(std::string text) {
@@ -312,7 +334,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
312334

313335
std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
314336
bool padding = false) {
315-
return tokenize(text, text_model->model.n_token, padding);
337+
return tokenize(text, version != VERSION_SDXL_REFINER ? text_model->model.n_token : text_model2->model.n_token, padding);
316338
}
317339

318340
std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
@@ -403,7 +425,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
403425
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
404426
struct ggml_tensor* input_ids2 = NULL;
405427
size_t max_token_idx = 0;
406-
if (version == VERSION_SDXL) {
428+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
407429
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), tokenizer.EOS_TOKEN_ID);
408430
if (it != chunk_tokens.end()) {
409431
std::fill(std::next(it), chunk_tokens.end(), 0);
@@ -428,16 +450,20 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
428450
false,
429451
&chunk_hidden_states1,
430452
work_ctx);
431-
if (version == VERSION_SDXL) {
453+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
432454
text_model2->compute(n_threads,
433455
input_ids2,
434456
0,
435457
NULL,
436458
max_token_idx,
437459
false,
438460
&chunk_hidden_states2, work_ctx);
439-
// concat
440-
chunk_hidden_states = ggml_tensor_concat(work_ctx, chunk_hidden_states1, chunk_hidden_states2, 0);
461+
if (version == VERSION_SDXL) {
462+
// concat
463+
chunk_hidden_states = ggml_tensor_concat(work_ctx, chunk_hidden_states1, chunk_hidden_states2, 0);
464+
} else {
465+
chunk_hidden_states = chunk_hidden_states2;
466+
}
441467

442468
if (chunk_idx == 0) {
443469
text_model2->compute(n_threads,
@@ -487,7 +513,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
487513
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
488514

489515
ggml_tensor* vec = NULL;
490-
if (version == VERSION_SDXL) {
516+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
491517
int out_dim = 256;
492518
vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, adm_in_channels);
493519
// [0:1280]

control.hpp

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -23,14 +23,14 @@ class ControlNetBlock : public GGMLBlock {
2323
std::vector<int> attention_resolutions = {4, 2, 1};
2424
std::vector<int> channel_mult = {1, 2, 4, 4};
2525
std::vector<int> transformer_depth = {1, 1, 1, 1};
26-
int time_embed_dim = 1280; // model_channels*4
26+
int time_embed_dim = 1280; // model_channels*4, 1536 for VERSION_SDXL_REFINER
2727
int num_heads = 8;
2828
int num_head_channels = -1; // channels // num_heads
29-
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
29+
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL, 1280 for VERSION_SDXL_REFINER
3030

3131
public:
32-
int model_channels = 320;
33-
int adm_in_channels = 2816; // only for VERSION_SDXL
32+
int model_channels = 320; // 384 for VERSION_SDXL_REFINER
33+
int adm_in_channels = 2816; // 2816 for VERSION_SDXL/SVD, 2560 for VERSION_SDXL_REFINER
3434

3535
ControlNetBlock(SDVersion version = VERSION_SD1)
3636
: version(version) {
@@ -45,6 +45,16 @@ class ControlNetBlock : public GGMLBlock {
4545
transformer_depth = {1, 2, 10};
4646
num_head_channels = 64;
4747
num_heads = -1;
48+
} else if (version == VERSION_SDXL_REFINER) {
49+
time_embed_dim = 1536;
50+
context_dim = 1280;
51+
model_channels = 384;
52+
adm_in_channels = 2560;
53+
attention_resolutions = {4, 2};
54+
channel_mult = {1, 2, 4, 4};
55+
transformer_depth = {4, 4, 4, 4};
56+
num_head_channels = 64;
57+
num_heads = -1;
4858
} else if (version == VERSION_SVD) {
4959
in_channels = 8;
5060
out_channels = 4;
@@ -58,7 +68,7 @@ class ControlNetBlock : public GGMLBlock {
5868
// time_embed_1 is nn.SiLU()
5969
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
6070

61-
if (version == VERSION_SDXL || version == VERSION_SVD) {
71+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER || version == VERSION_SVD) {
6272
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
6373
// label_emb_1 is nn.SiLU()
6474
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));

denoiser.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -176,6 +176,7 @@ struct AYSSchedule : SigmaSchedule {
176176
inputs = noise_levels[0];
177177
break;
178178
case VERSION_SDXL:
179+
case VERSION_SDXL_REFINER:
179180
LOG_INFO("AYS using SDXL noise levels");
180181
inputs = noise_levels[1];
181182
break;

model.cpp

Lines changed: 20 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1463,10 +1463,12 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
14631463

14641464
SDVersion ModelLoader::get_sd_version() {
14651465
TensorStorage token_embedding_weight;
1466-
bool is_flux = false;
1467-
bool is_schnell = true;
1468-
bool is_lite = true;
1469-
bool is_sd3 = false;
1466+
bool is_flux = false;
1467+
bool is_schnell = true;
1468+
bool is_lite = true;
1469+
bool is_sdxl = false;
1470+
bool is_sdxl_base = false;
1471+
bool is_sd3 = false;
14701472
for (auto& tensor_storage : tensor_storages) {
14711473
if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
14721474
is_schnell = false;
@@ -1486,11 +1488,15 @@ SDVersion ModelLoader::get_sd_version() {
14861488
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.23.") != std::string::npos) {
14871489
is_sd3 = true;
14881490
}
1489-
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) {
1490-
return VERSION_SDXL;
1491+
if (tensor_storage.name == "conditioner.embedders.0.model.token_embedding.weight" ||
1492+
tensor_storage.name == "cond_stage_model.1.transformer.text_model.embeddings.token_embedding.weight") {
1493+
if (tensor_storage.ne[0] == 1280) {
1494+
is_sdxl = true;
1495+
}
14911496
}
1492-
if (tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
1493-
return VERSION_SDXL;
1497+
if ((tensor_storage.name == "conditioner.embedders.1.model.token_embedding.weight" && tensor_storage.ne[0] == 1280) ||
1498+
(tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" && tensor_storage.ne[0] == 768)) {
1499+
is_sdxl_base = true;
14941500
}
14951501
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
14961502
return VERSION_SVD;
@@ -1519,6 +1525,12 @@ SDVersion ModelLoader::get_sd_version() {
15191525
if (is_sd3) {
15201526
return VERSION_SD3_2B;
15211527
}
1528+
if (is_sdxl && !is_sdxl_base) {
1529+
return VERSION_SDXL_REFINER;
1530+
}
1531+
if (is_sdxl) {
1532+
return VERSION_SDXL;
1533+
}
15221534
if (token_embedding_weight.ne[0] == 768) {
15231535
return VERSION_SD1;
15241536
} else if (token_embedding_weight.ne[0] == 1024) {

model.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ enum SDVersion {
2121
VERSION_SD1,
2222
VERSION_SD2,
2323
VERSION_SDXL,
24+
VERSION_SDXL_REFINER,
2425
VERSION_SVD,
2526
VERSION_SD3_2B,
2627
VERSION_FLUX_DEV,

stable-diffusion.cpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@ const char* model_version_to_str[] = {
2828
"SD 1.x",
2929
"SD 2.x",
3030
"SDXL",
31+
"SDXL Refiner",
3132
"SVD",
3233
"SD3 2B",
3334
"Flux Dev",
@@ -328,7 +329,7 @@ class StableDiffusionGGML {
328329
vae_wtype = wtype;
329330
}
330331

331-
if (version == VERSION_SDXL) {
332+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
332333
vae_wtype = GGML_TYPE_F32;
333334
}
334335

@@ -339,7 +340,7 @@ class StableDiffusionGGML {
339340

340341
LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));
341342

342-
if (version == VERSION_SDXL) {
343+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
343344
scale_factor = 0.13025f;
344345
if (vae_path.size() == 0 && taesd_path.size() == 0) {
345346
LOG_WARN(
@@ -1378,7 +1379,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
13781379
SDCondition uncond;
13791380
if (cfg_scale != 1.0) {
13801381
bool force_zero_embeddings = false;
1381-
if (sd_ctx->sd->version == VERSION_SDXL && negative_prompt.size() == 0) {
1382+
if ((sd_ctx->sd->version == VERSION_SDXL || sd_ctx->sd->version == VERSION_SDXL_REFINER) && negative_prompt.size() == 0) {
13821383
force_zero_embeddings = true;
13831384
}
13841385
uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,

unet.hpp

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -174,14 +174,14 @@ class UnetModelBlock : public GGMLBlock {
174174
std::vector<int> attention_resolutions = {4, 2, 1};
175175
std::vector<int> channel_mult = {1, 2, 4, 4};
176176
std::vector<int> transformer_depth = {1, 1, 1, 1};
177-
int time_embed_dim = 1280; // model_channels*4
177+
int time_embed_dim = 1280; // model_channels*4, 1536 for VERSION_SDXL_REFINER
178178
int num_heads = 8;
179179
int num_head_channels = -1; // channels // num_heads
180-
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
180+
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL, 1280 for VERSION_SDXL_REFINER
181181

182182
public:
183-
int model_channels = 320;
184-
int adm_in_channels = 2816; // only for VERSION_SDXL/SVD
183+
int model_channels = 320; // 384 for VERSION_SDXL_REFINER
184+
int adm_in_channels = 2816; // 2816 for VERSION_SDXL/SVD, 2560 for VERSION_SDXL_REFINER
185185

186186
UnetModelBlock(SDVersion version = VERSION_SD1, bool flash_attn = false)
187187
: version(version) {
@@ -196,6 +196,16 @@ class UnetModelBlock : public GGMLBlock {
196196
transformer_depth = {1, 2, 10};
197197
num_head_channels = 64;
198198
num_heads = -1;
199+
} else if (version == VERSION_SDXL_REFINER) {
200+
time_embed_dim = 1536;
201+
context_dim = 1280;
202+
model_channels = 384;
203+
adm_in_channels = 2560;
204+
attention_resolutions = {4, 2};
205+
channel_mult = {1, 2, 4, 4};
206+
transformer_depth = {4, 4, 4, 4};
207+
num_head_channels = 64;
208+
num_heads = -1;
199209
} else if (version == VERSION_SVD) {
200210
in_channels = 8;
201211
out_channels = 4;
@@ -211,7 +221,7 @@ class UnetModelBlock : public GGMLBlock {
211221
// time_embed_1 is nn.SiLU()
212222
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
213223

214-
if (version == VERSION_SDXL || version == VERSION_SVD) {
224+
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER || version == VERSION_SVD) {
215225
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
216226
// label_emb_1 is nn.SiLU()
217227
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));

0 commit comments

Comments
 (0)