@@ -73,42 +73,52 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
73
73
} else if (sd_version_is_sd2 (version)) {
74
74
text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, " cond_stage_model.transformer.text_model" , OPEN_CLIP_VIT_H_14, clip_skip);
75
75
} else if (sd_version_is_sdxl (version)) {
76
- text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, " cond_stage_model.transformer.text_model" , OPENAI_CLIP_VIT_L_14, clip_skip, false );
76
+ if (version != VERSION_SDXL_REFINER) {
77
+ text_model = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, " cond_stage_model.transformer.text_model" , OPENAI_CLIP_VIT_L_14, clip_skip, false );
78
+ }
77
79
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, " cond_stage_model.1.transformer.text_model" , OPEN_CLIP_VIT_BIGG_14, clip_skip, false );
78
80
}
79
81
}
80
82
81
83
// Diff hunk for set_clip_skip(int clip_skip): forwards clip_skip to the text encoders.
// Removed ('-') lines: text_model was called unconditionally and text_model2 was
// gated on sd_version_is_sdxl(version).
// Added ('+') lines: each encoder is called only when its pointer is non-null.
// The constructor hunk earlier in this diff does not assign text_model when
// version == VERSION_SDXL_REFINER, which is why the null check is needed.
void set_clip_skip (int clip_skip) {
82
- text_model->set_clip_skip (clip_skip);
83
- if (sd_version_is_sdxl (version)) {
84
+ if (text_model) {
85
+ text_model->set_clip_skip (clip_skip);
86
+ }
87
// The version check is replaced by a direct test of the pointer that is dereferenced.
+ if (text_model2) {
84
88
text_model2->set_clip_skip (clip_skip);
85
89
}
86
90
}
87
91
88
92
// Diff hunk for get_param_tensors(tensors): asks each text encoder to register its
// parameter tensors into `tensors` under the prefix string passed alongside.
// Removed ('-') lines: text_model was called unconditionally and text_model2 was
// gated on sd_version_is_sdxl(version).
// Added ('+') lines: each encoder is called only when its pointer is non-null
// (the constructor hunk in this diff does not assign text_model for VERSION_SDXL_REFINER).
void get_param_tensors (std::map<std::string, struct ggml_tensor *>& tensors) {
89
- text_model->get_param_tensors (tensors, " cond_stage_model.transformer.text_model" );
90
- if (sd_version_is_sdxl (version)) {
93
+ if (text_model) {
94
+ text_model->get_param_tensors (tensors, " cond_stage_model.transformer.text_model" );
95
+ }
96
// The second encoder keeps its own ".1." prefix; only the guard condition changes.
+ if (text_model2) {
91
97
text_model2->get_param_tensors (tensors, " cond_stage_model.1.transformer.text_model" );
92
98
}
93
99
}
94
100
95
101
// Diff hunk for alloc_params_buffer(): calls alloc_params_buffer() on each text encoder.
// Removed ('-') lines: text_model was called unconditionally and text_model2 was
// gated on sd_version_is_sdxl(version).
// Added ('+') lines: each encoder is called only when its pointer is non-null, so a
// configuration without text_model (see the VERSION_SDXL_REFINER branch of the
// constructor hunk) no longer dereferences an unassigned pointer.
void alloc_params_buffer () {
96
- text_model->alloc_params_buffer ();
97
- if (sd_version_is_sdxl (version)) {
102
+ if (text_model) {
103
+ text_model->alloc_params_buffer ();
104
+ }
105
+ if (text_model2) {
98
106
text_model2->alloc_params_buffer ();
99
107
}
100
108
}
101
109
102
110
// Diff hunk for free_params_buffer(): calls free_params_buffer() on each text encoder.
// Removed ('-') lines: text_model was called unconditionally and text_model2 was
// gated on sd_version_is_sdxl(version).
// Added ('+') lines: each encoder is called only when its pointer is non-null,
// mirroring the guards added to alloc_params_buffer() in this same diff.
void free_params_buffer () {
103
- text_model->free_params_buffer ();
104
- if (sd_version_is_sdxl (version)) {
111
+ if (text_model) {
112
+ text_model->free_params_buffer ();
113
+ }
114
+ if (text_model2) {
105
115
text_model2->free_params_buffer ();
106
116
}
107
117
}
108
118
109
119
size_t get_params_buffer_size () {
110
- size_t buffer_size = text_model->get_params_buffer_size ();
111
- if (sd_version_is_sdxl (version) ) {
120
+ size_t buffer_size = text_model ? text_model ->get_params_buffer_size () : 0 ;
121
+ if (text_model2 ) {
112
122
buffer_size += text_model2->get_params_buffer_size ();
113
123
}
114
124
return buffer_size;
@@ -131,7 +141,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
131
141
params.no_alloc = false ;
132
142
struct ggml_context * embd_ctx = ggml_init (params);
133
143
struct ggml_tensor * embd = NULL ;
134
- int64_t hidden_size = text_model->model .hidden_size ;
144
+ int64_t hidden_size = text_model ? text_model-> model . hidden_size : text_model2 ->model .hidden_size ;
135
145
auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) {
136
146
if (tensor_storage.ne [0 ] != hidden_size) {
137
147
LOG_DEBUG (" embedding wrong hidden size, got %i, expected %i" , tensor_storage.ne [0 ], hidden_size);
@@ -148,7 +158,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
148
158
embd->data ,
149
159
ggml_nbytes (embd));
150
160
for (int i = 0 ; i < embd->ne [1 ]; i++) {
151
- bpe_tokens.push_back (text_model->model .vocab_size + num_custom_embeddings);
161
+ bpe_tokens.push_back (( text_model ? text_model ->model .vocab_size : text_model2-> model . vocab_size ) + num_custom_embeddings);
152
162
// LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
153
163
num_custom_embeddings++;
154
164
}
@@ -162,7 +172,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
162
172
int32_t image_token,
163
173
bool padding = false ) {
164
174
return tokenize_with_trigger_token (text, num_input_imgs, image_token,
165
- text_model->model .n_token , padding);
175
+ text_model ? text_model-> model . n_token : text_model2 ->model .n_token , padding);
166
176
}
167
177
168
178
std::vector<int > convert_token_to_id (std::string text) {
@@ -311,7 +321,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
311
321
312
322
// Diff hunk for the two-argument tokenize(text, padding) overload: delegates to the
// three-argument tokenize overload, supplying the encoder's n_token as the middle argument.
// Removed ('-') line: n_token was always read from text_model.
// Added ('+') line: n_token is read from text_model when it is non-null, otherwise
// from text_model2.
// NOTE(review): the added line dereferences text_model2 without a check when
// text_model is null; confirm text_model2 is always set in that configuration.
std::pair<std::vector<int >, std::vector<float >> tokenize (std::string text,
313
323
bool padding = false ) {
314
- return tokenize (text, text_model->model .n_token , padding);
324
+ return tokenize (text, text_model ? text_model-> model . n_token : text_model2 ->model .n_token , padding);
315
325
}
316
326
317
327
std::pair<std::vector<int >, std::vector<float >> tokenize (std::string text,
@@ -419,28 +429,31 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
419
429
}
420
430
421
431
{
422
- text_model->compute (n_threads,
423
- input_ids,
424
- num_custom_embeddings,
425
- token_embed_custom.data (),
426
- max_token_idx,
427
- false ,
428
- &chunk_hidden_states1,
429
- work_ctx);
430
- if (sd_version_is_sdxl (version)) {
432
+ if (text_model) {
433
+ text_model->compute (n_threads,
434
+ input_ids,
435
+ num_custom_embeddings,
436
+ token_embed_custom.data (),
437
+ max_token_idx,
438
+ false ,
439
+ &chunk_hidden_states1,
440
+ work_ctx);
441
+ }
442
+ if (text_model2) {
431
443
text_model2->compute (n_threads,
432
- input_ids2,
433
- 0 ,
434
- NULL ,
444
+ text_model ? input_ids2 : input_ids ,
445
+ text_model ? 0 : num_custom_embeddings ,
446
+ text_model ? NULL : token_embed_custom. data () ,
435
447
max_token_idx,
436
448
false ,
437
- &chunk_hidden_states2, work_ctx);
449
+ text_model ? &chunk_hidden_states2 : &chunk_hidden_states1,
450
+ work_ctx);
438
451
// concat
439
- chunk_hidden_states = ggml_tensor_concat (work_ctx, chunk_hidden_states1, chunk_hidden_states2, 0 );
452
+ chunk_hidden_states = text_model ? ggml_tensor_concat (work_ctx, chunk_hidden_states1, chunk_hidden_states2, 0 ) : chunk_hidden_states1 ;
440
453
441
454
if (chunk_idx == 0 ) {
442
455
text_model2->compute (n_threads,
443
- input_ids2,
456
+ text_model ? input_ids2 : input_ids ,
444
457
0 ,
445
458
NULL ,
446
459
max_token_idx,
@@ -486,7 +499,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
486
499
ggml_nelements (hidden_states) / chunk_hidden_states->ne [0 ]);
487
500
488
501
ggml_tensor* vec = NULL ;
489
- if (sd_version_is_sdxl (version)) {
502
+ if (sd_version_is_sdxl (version) && version != VERSION_SDXL_REFINER ) {
490
503
int out_dim = 256 ;
491
504
vec = ggml_new_tensor_1d (work_ctx, GGML_TYPE_F32, adm_in_channels);
492
505
// [0:1280]
0 commit comments