@@ -63,51 +63,64 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
63
63
: version(version), tokenizer(version == VERSION_SD2 ? 0 : 49407 ), embd_dir(embd_dir), wtype(wtype) {
64
64
if (clip_skip <= 0 ) {
65
65
clip_skip = 1 ;
66
- if (version == VERSION_SD2 || version == VERSION_SDXL ) {
66
+ if (version == VERSION_SD2 || version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER ) {
67
67
clip_skip = 2 ;
68
68
}
69
69
}
70
70
if (version == VERSION_SD1) {
71
71
text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip);
72
72
} else if (version == VERSION_SD2) {
73
73
text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_H_14, clip_skip);
74
- } else if (version == VERSION_SDXL ) {
74
+ } else if (version == VERSION_SDXL_BASE ) {
75
75
text_model = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false );
76
76
text_model2 = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false );
77
+ } else if (version == VERSION_SDXL_REFINER) {
78
+ text_model2 = std::make_shared<CLIPTextModelRunner>(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false );
77
79
}
78
80
}
79
81
80
82
// Forwards `clip_skip` to the CLIP text-model runner(s) selected by `version`.
//
// NOTE(review): this span is a rendered diff hunk. Bare numbers are old/new
// line numbers, `-` lines are the removed code and `+` lines are the code
// added by the patch.
//
// Patched (+) behavior: text_model is only touched when version is not
// VERSION_SDXL_REFINER; text_model2 is touched for VERSION_SDXL_BASE and
// VERSION_SDXL_REFINER. Removed (-) behavior: text_model was always touched,
// text_model2 only for VERSION_SDXL.
// Presumably the refiner guard exists because text_model is never created for
// VERSION_SDXL_REFINER — confirm against the constructor.
void set_clip_skip (int clip_skip) {
81
- text_model->set_clip_skip (clip_skip);
82
- if (version == VERSION_SDXL) {
83
+ if (version != VERSION_SDXL_REFINER) {
84
+ text_model->set_clip_skip (clip_skip);
85
+ }
86
+ if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
83
87
text_model2->set_clip_skip (clip_skip);
84
88
}
85
89
}
86
90
87
91
// Asks the CLIP text-model runner(s) selected by `version` to register their
// parameter tensors into `tensors`, each under its own name prefix.
//
// NOTE(review): this span is a rendered diff hunk. Bare numbers are old/new
// line numbers, `-` lines are the removed code and `+` lines are the code
// added by the patch.
//
// Patched (+) behavior: text_model registers under the
// "cond_stage_model.transformer.text_model" prefix unless version is
// VERSION_SDXL_REFINER; text_model2 registers under the
// "cond_stage_model.1.transformer.text_model" prefix for VERSION_SDXL_BASE and
// VERSION_SDXL_REFINER.
// NOTE(review): the leading space inside the string literals below looks like
// an artifact of how this diff was rendered — confirm against the real file
// before relying on the exact prefix text.
void get_param_tensors (std::map<std::string, struct ggml_tensor *>& tensors) {
88
- text_model->get_param_tensors (tensors, " cond_stage_model.transformer.text_model" );
89
- if (version == VERSION_SDXL) {
92
+ if (version != VERSION_SDXL_REFINER) {
93
+ text_model->get_param_tensors (tensors, " cond_stage_model.transformer.text_model" );
94
+ }
95
+ if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
90
96
text_model2->get_param_tensors (tensors, " cond_stage_model.1.transformer.text_model" );
91
97
}
92
98
}
93
99
94
100
// Calls alloc_params_buffer() on the CLIP text-model runner(s) selected by
// `version`.
//
// NOTE(review): this span is a rendered diff hunk. Bare numbers are old/new
// line numbers, `-` lines are the removed code and `+` lines are the code
// added by the patch.
//
// Patched (+) behavior: text_model is skipped when version is
// VERSION_SDXL_REFINER; text_model2 is used for VERSION_SDXL_BASE and
// VERSION_SDXL_REFINER. Removed (-) behavior: text_model was always used,
// text_model2 only for VERSION_SDXL.
void alloc_params_buffer () {
95
- text_model->alloc_params_buffer ();
96
- if (version == VERSION_SDXL) {
101
+ if (version != VERSION_SDXL_REFINER) {
102
+ text_model->alloc_params_buffer ();
103
+ }
104
+ if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
97
105
text_model2->alloc_params_buffer ();
98
106
}
99
107
}
100
108
101
109
// Calls free_params_buffer() on the CLIP text-model runner(s) selected by
// `version`; uses the same version checks as alloc_params_buffer().
//
// NOTE(review): this span is a rendered diff hunk. Bare numbers are old/new
// line numbers, `-` lines are the removed code and `+` lines are the code
// added by the patch.
//
// Patched (+) behavior: text_model is skipped when version is
// VERSION_SDXL_REFINER; text_model2 is used for VERSION_SDXL_BASE and
// VERSION_SDXL_REFINER. Removed (-) behavior: text_model was always used,
// text_model2 only for VERSION_SDXL.
void free_params_buffer () {
102
- text_model->free_params_buffer ();
103
- if (version == VERSION_SDXL) {
110
+ if (version != VERSION_SDXL_REFINER) {
111
+ text_model->free_params_buffer ();
112
+ }
113
+ if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
104
114
text_model2->free_params_buffer ();
105
115
}
106
116
}
107
117
108
118
size_t get_params_buffer_size () {
109
- size_t buffer_size = text_model->get_params_buffer_size ();
110
- if (version == VERSION_SDXL) {
119
+ size_t buffer_size = 0 ;
120
+ if (version != VERSION_SDXL_REFINER) {
121
+ buffer_size = text_model->get_params_buffer_size ();
122
+ }
123
+ if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
111
124
buffer_size += text_model2->get_params_buffer_size ();
112
125
}
113
126
return buffer_size;
@@ -398,7 +411,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
398
411
auto input_ids = vector_to_ggml_tensor_i32 (work_ctx, chunk_tokens);
399
412
struct ggml_tensor * input_ids2 = NULL ;
400
413
size_t max_token_idx = 0 ;
401
- if (version == VERSION_SDXL ) {
414
+ if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER ) {
402
415
auto it = std::find (chunk_tokens.begin (), chunk_tokens.end (), tokenizer.EOS_TOKEN_ID );
403
416
if (it != chunk_tokens.end ()) {
404
417
std::fill (std::next (it), chunk_tokens.end (), 0 );
@@ -415,15 +428,17 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
415
428
}
416
429
417
430
{
418
- text_model->compute (n_threads,
419
- input_ids,
420
- num_custom_embeddings,
421
- token_embed_custom.data (),
422
- max_token_idx,
423
- false ,
424
- &chunk_hidden_states1,
425
- work_ctx);
426
- if (version == VERSION_SDXL) {
431
+ if (version != VERSION_SDXL_REFINER) {
432
+ text_model->compute (n_threads,
433
+ input_ids,
434
+ num_custom_embeddings,
435
+ token_embed_custom.data (),
436
+ max_token_idx,
437
+ false ,
438
+ &chunk_hidden_states1,
439
+ work_ctx);
440
+ }
441
+ if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER) {
427
442
text_model2->compute (n_threads,
428
443
input_ids2,
429
444
0 ,
@@ -482,7 +497,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
482
497
ggml_nelements (hidden_states) / chunk_hidden_states->ne [0 ]);
483
498
484
499
ggml_tensor* vec = NULL ;
485
- if (version == VERSION_SDXL ) {
500
+ if (version == VERSION_SDXL_BASE || version == VERSION_SDXL_REFINER ) {
486
501
int out_dim = 256 ;
487
502
vec = ggml_new_tensor_1d (work_ctx, GGML_TYPE_F32, adm_in_channels);
488
503
// [0:1280]
@@ -623,6 +638,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner {
623
638
624
639
struct SD3CLIPEmbedder : public Conditioner {
625
640
ggml_type wtype;
641
+ bool compvis_compatiblity;
626
642
CLIPTokenizer clip_l_tokenizer;
627
643
CLIPTokenizer clip_g_tokenizer;
628
644
T5UniGramTokenizer t5_tokenizer;
@@ -632,8 +648,9 @@ struct SD3CLIPEmbedder : public Conditioner {
632
648
633
649
SD3CLIPEmbedder (ggml_backend_t backend,
634
650
ggml_type wtype,
635
- int clip_skip = -1 )
636
- : wtype(wtype), clip_g_tokenizer(0 ) {
651
+ bool compvis_compatiblity = false ,
652
+ int clip_skip = -1 )
653
+ : wtype(wtype), compvis_compatiblity(compvis_compatiblity), clip_g_tokenizer(0 ) {
637
654
if (clip_skip <= 0 ) {
638
655
clip_skip = 2 ;
639
656
}
@@ -648,6 +665,12 @@ struct SD3CLIPEmbedder : public Conditioner {
648
665
}
649
666
650
667
void get_param_tensors (std::map<std::string, struct ggml_tensor *>& tensors) {
668
+ if (compvis_compatiblity) {
669
+ clip_l->get_param_tensors (tensors, " cond_stage_model.transformer.text_model" );
670
+ clip_g->get_param_tensors (tensors, " cond_stage_model.1.transformer.text_model" );
671
+ t5->get_param_tensors (tensors, " cond_stage_model.2.transformer" );
672
+ return ;
673
+ }
651
674
clip_l->get_param_tensors (tensors, " text_encoders.clip_l.transformer.text_model" );
652
675
clip_g->get_param_tensors (tensors, " text_encoders.clip_g.transformer.text_model" );
653
676
t5->get_param_tensors (tensors, " text_encoders.t5xxl.transformer" );
0 commit comments