Skip to content

Commit 7e8732f

Browse files
committed
fix: sdxl refiner
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent 0428156 commit 7e8732f

File tree

2 files changed

+31
-15
lines changed

2 files changed

+31
-15
lines changed

conditioner.hpp

Lines changed: 29 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -143,8 +143,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
143143
params.no_alloc = false;
144144
struct ggml_context* embd_ctx = ggml_init(params);
145145
struct ggml_tensor* embd = NULL;
146-
int64_t hidden_size = text_model->model.hidden_size;
147-
auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) {
146+
int64_t hidden_size = 0;
147+
if (version != VERSION_SDXL_REFINER) {
148+
hidden_size = text_model->model.hidden_size;
149+
} else {
150+
hidden_size = text_model2->model.hidden_size;
151+
}
152+
auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) {
148153
if (tensor_storage.ne[0] != hidden_size) {
149154
LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], hidden_size);
150155
return false;
@@ -160,21 +165,24 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
160165
embd->data,
161166
ggml_nbytes(embd));
162167
for (int i = 0; i < embd->ne[1]; i++) {
163-
bpe_tokens.push_back(text_model->model.vocab_size + num_custom_embeddings);
168+
if (version != VERSION_SDXL_REFINER) {
169+
bpe_tokens.push_back(text_model->model.vocab_size + num_custom_embeddings);
170+
} else {
171+
bpe_tokens.push_back(text_model2->model.vocab_size + num_custom_embeddings);
172+
}
164173
// LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
165174
num_custom_embeddings++;
166175
}
167176
LOG_DEBUG("embedding '%s' applied, custom embeddings: %i", embd_name.c_str(), num_custom_embeddings);
168177
return true;
169178
}
170179

171-
std::tuple<std::vector<int>, std::vector<float>, std::vector<bool>>
172-
tokenize_with_trigger_token(std::string text,
173-
int num_input_imgs,
174-
int32_t image_token,
175-
bool padding = false) {
180+
std::tuple<std::vector<int>, std::vector<float>, std::vector<bool>> tokenize_with_trigger_token(std::string text,
181+
int num_input_imgs,
182+
int32_t image_token,
183+
bool padding = false) {
176184
return tokenize_with_trigger_token(text, num_input_imgs, image_token,
177-
text_model->model.n_token, padding);
185+
version != VERSION_SDXL_REFINER ? text_model->model.n_token : text_model2->model.n_token, padding);
178186
}
179187

180188
std::vector<int> convert_token_to_id(std::string text) {
@@ -320,7 +328,10 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
320328

321329
std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
322330
bool padding = false) {
323-
return tokenize(text, text_model->model.n_token, padding);
331+
if (version != VERSION_SDXL_REFINER) {
332+
return tokenize(text, text_model->model.n_token, padding);
333+
}
334+
return tokenize(text, text_model2->model.n_token, padding);
324335
}
325336

326337
std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
@@ -446,8 +457,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
446457
max_token_idx,
447458
false,
448459
&chunk_hidden_states2, work_ctx);
449-
// concat
450-
chunk_hidden_states = ggml_tensor_concat(work_ctx, chunk_hidden_states1, chunk_hidden_states2, 0);
460+
461+
if (version == VERSION_SDXL) {
462+
// concat
463+
chunk_hidden_states = ggml_tensor_concat(work_ctx, chunk_hidden_states1, chunk_hidden_states2, 0);
464+
} else {
465+
chunk_hidden_states = chunk_hidden_states2;
466+
}
451467

452468
if (chunk_idx == 0) {
453469
text_model2->compute(n_threads,
@@ -497,7 +513,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
497513
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
498514

499515
ggml_tensor* vec = NULL;
500-
if (version == VERSION_SDXL || version == VERSION_SDXL_REFINER) {
516+
if (version == VERSION_SDXL) {
501517
int out_dim = 256;
502518
vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, adm_in_channels);
503519
// [0:1280]

unet.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,8 @@ class UnetModelBlock : public GGMLBlock {
202202
model_channels = 384;
203203
adm_in_channels = 2560;
204204
attention_resolutions = {4, 2};
205-
channel_mult = {1, 2, 4};
206-
transformer_depth = {1, 2, 10};
205+
channel_mult = {1, 2, 4, 4};
206+
transformer_depth = {4, 4, 4, 4};
207207
num_head_channels = 64;
208208
num_heads = -1;
209209
} else if (version == VERSION_SVD) {

0 commit comments

Comments
 (0)