Skip to content

Commit 78ad76f

Browse files
authored
feat: add SDXL support (leejet#117)
* add SDXL support
* fix the issue with generating large images
1 parent 004dfbe commit 78ad76f

File tree

5 files changed

+669
-347
lines changed

5 files changed

+669
-347
lines changed

.gitmodules

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
[submodule "ggml"]
22
path = ggml
3-
url = https://github.com/FSSRepo/ggml.git
3+
url = https://github.com/leejet/ggml.git

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
1010

1111
- Plain C/C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp)
1212
- Super lightweight and without external dependencies
13-
- SD1.x and SD2.x support
14-
- [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) support
13+
- SD1.x, SD2.x and SDXL support
14+
- [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) and [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) support
1515
- 16-bit, 32-bit float support
1616
- 4-bit, 5-bit and 8-bit integer quantization support
1717
- Accelerated memory-efficient CPU inference
@@ -302,3 +302,4 @@ Thank you to all the people who have already contributed to stable-diffusion.cpp
302302
- [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
303303
- [k-diffusion](https://github.com/crowsonkb/k-diffusion)
304304
- [latent-consistency-model](https://github.com/luosiallen/latent-consistency-model)
305+
- [generative-models](https://github.com/Stability-AI/generative-models/)

model.cpp

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,9 @@ const char* unused_tensors[] = {
7878
"cond_stage_model.transformer.text_model.embeddings.position_ids",
7979
"cond_stage_model.model.logit_scale",
8080
"cond_stage_model.model.text_projection",
81+
"conditioner.embedders.0.transformer.text_model.embeddings.position_ids",
8182
"conditioner.embedders.0.model.logit_scale",
82-
"conditioner.embedders.0.model.text_projection",
83+
"conditioner.embedders.1.model.logit_scale",
8384
"model.diffusion_model.time_embedding.cond_proj.weight",
8485
"unet.time_embedding.cond_proj.weight",
8586
"model_ema.decay",
@@ -100,11 +101,11 @@ bool is_unused_tensor(std::string name) {
100101
}
101102

102103
std::unordered_map<std::string, std::string> open_clip_to_hf_clip_model = {
103-
{"cond_stage_model.model.ln_final.bias", "cond_stage_model.transformer.text_model.final_layer_norm.bias"},
104-
{"cond_stage_model.model.ln_final.weight", "cond_stage_model.transformer.text_model.final_layer_norm.weight"},
105-
{"cond_stage_model.model.positional_embedding", "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight"},
106-
{"cond_stage_model.model.token_embedding.weight", "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"},
107-
104+
{"model.ln_final.bias", "transformer.text_model.final_layer_norm.bias"},
105+
{"model.ln_final.weight", "transformer.text_model.final_layer_norm.weight"},
106+
{"model.positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"},
107+
{"model.token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"},
108+
{"model.text_projection", "transformer.text_model.text_projection"},
108109
};
109110

110111
std::unordered_map<std::string, std::string> open_clip_to_hk_clip_resblock = {
@@ -133,11 +134,21 @@ std::unordered_map<std::string, std::string> vae_decoder_name_map = {
133134

134135
std::string convert_open_clip_to_hf_clip(const std::string& name) {
135136
std::string new_name = name;
137+
std::string prefix;
136138
if (starts_with(new_name, "conditioner.embedders.0.")) {
137-
new_name = "cond_stage_model." + new_name.substr(strlen("conditioner.embedders.0."));
139+
prefix = "cond_stage_model.";
140+
new_name = new_name.substr(strlen("conditioner.embedders.0."));
141+
} else if (starts_with(new_name, "conditioner.embedders.1.")) {
142+
prefix = "cond_stage_model.1.";
143+
new_name = new_name.substr(strlen("conditioner.embedders.0."));
144+
} else if (starts_with(new_name, "cond_stage_model.")) {
145+
prefix = "cond_stage_model.";
146+
new_name = new_name.substr(strlen("cond_stage_model."));
147+
} else {
148+
return new_name;
138149
}
139-
std::string open_clip_resblock_prefix = "cond_stage_model.model.transformer.resblocks.";
140-
std::string hf_clip_resblock_prefix = "cond_stage_model.transformer.text_model.encoder.layers.";
150+
std::string open_clip_resblock_prefix = "model.transformer.resblocks.";
151+
std::string hf_clip_resblock_prefix = "transformer.text_model.encoder.layers.";
141152

142153
if (open_clip_to_hf_clip_model.find(new_name) != open_clip_to_hf_clip_model.end()) {
143154
new_name = open_clip_to_hf_clip_model[new_name];
@@ -156,7 +167,7 @@ std::string convert_open_clip_to_hf_clip(const std::string& name) {
156167
}
157168
}
158169

159-
return new_name;
170+
return prefix + new_name;
160171
}
161172

162173
std::string convert_vae_decoder_name(const std::string& name) {
@@ -358,7 +369,7 @@ std::string convert_diffusers_name_to_compvis(const std::string& key, char seq)
358369

359370
std::string convert_tensor_name(const std::string& name) {
360371
std::string new_name;
361-
if (starts_with(name, "cond_stage_model.model") || starts_with(name, "conditioner.embedders.0.model")) {
372+
if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.")) {
362373
new_name = convert_open_clip_to_hf_clip(name);
363374
} else if (starts_with(name, "first_stage_model.decoder")) {
364375
new_name = convert_vae_decoder_name(name);
@@ -419,7 +430,7 @@ void preprocess_tensor(TensorStorage tensor_storage,
419430

420431
tensor_storage.name = new_name;
421432

422-
if (starts_with(new_name, "cond_stage_model.transformer.text_model.encoder.layers.") &&
433+
if (new_name.find("transformer.text_model.encoder.layers.") != std::string::npos &&
423434
ends_with(new_name, "attn.in_proj_weight")) {
424435
size_t prefix_size = new_name.find("attn.in_proj_weight");
425436
std::string prefix = new_name.substr(0, prefix_size);
@@ -431,7 +442,7 @@ void preprocess_tensor(TensorStorage tensor_storage,
431442

432443
processed_tensor_storages.insert(processed_tensor_storages.end(), chunks.begin(), chunks.end());
433444

434-
} else if (starts_with(new_name, "cond_stage_model.transformer.text_model.encoder.layers.") &&
445+
} else if (new_name.find("transformer.text_model.encoder.layers.") != std::string::npos &&
435446
ends_with(new_name, "attn.in_proj_bias")) {
436447
size_t prefix_size = new_name.find("attn.in_proj_bias");
437448
std::string prefix = new_name.substr(0, prefix_size);
@@ -1163,15 +1174,20 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
11631174
}
11641175

11651176
SDVersion ModelLoader::get_sd_version() {
1177+
// return VERSION_1_x;
11661178
TensorStorage token_embedding_weight;
11671179
for (auto& tensor_storage : tensor_storages) {
1180+
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) {
1181+
return VERSION_XL;
1182+
}
11681183
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
11691184
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
11701185
tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
11711186
tensor_storage.name == "te.text_model.embeddings.token_embedding.weight" ||
1172-
tensor_storage.name == "conditioner.embedders.0.model.token_embedding.weight") {
1187+
tensor_storage.name == "conditioner.embedders.0.model.token_embedding.weight" ||
1188+
tensor_storage.name == "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight") {
11731189
token_embedding_weight = tensor_storage;
1174-
break;
1190+
// break;
11751191
}
11761192
}
11771193
if (token_embedding_weight.ne[0] == 768) {
@@ -1275,7 +1291,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
12751291
}
12761292

12771293
for (auto& tensor_storage : processed_tensor_storages) {
1278-
// LOG_DEBUG("%s", name.c_str());
1294+
// LOG_DEBUG("%s", tensor_storage.name.c_str());
12791295

12801296
ggml_tensor* dst_tensor = NULL;
12811297

0 commit comments

Comments (0)