Skip to content

Commit 9f7c92d

Browse files
committed
Merge branch 'master' into svd
2 parents 3aa99ed + 349439f commit 9f7c92d

20 files changed

+1469
-379
lines changed

.clang-format

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ UseTab: Never
33
IndentWidth: 4
44
TabWidth: 4
55
AllowShortIfStatementsOnASingleLine: false
6-
IndentCaseLabels: false
76
ColumnLimit: 0
87
AccessModifierOffset: -4
98
NamespaceIndentation: All

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,4 @@ test/
1010
*.gguf
1111
output*.png
1212
models*
13-
!taesd-model.gguf
1413
*.log

CMakeLists.txt

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,13 @@ option(SD_CUBLAS "sd: cuda backend" OFF)
2828
option(SD_HIPBLAS "sd: rocm backend" OFF)
2929
option(SD_METAL "sd: metal backend" OFF)
3030
option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF)
31-
option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
3231
option(BUILD_SHARED_LIBS "sd: build shared libs" OFF)
3332
#option(SD_BUILD_SERVER "sd: build server example" ON)
3433

3534
if(SD_CUBLAS)
3635
message("Use CUBLAS as backend stable-diffusion")
3736
set(GGML_CUBLAS ON)
3837
add_definitions(-DSD_USE_CUBLAS)
39-
if(SD_FAST_SOFTMAX)
40-
set(GGML_CUDA_FAST_SOFTMAX ON)
41-
endif()
4238
endif()
4339

4440
if(SD_METAL)
@@ -64,7 +60,8 @@ endif()
6460
set(SD_LIB stable-diffusion)
6561

6662
add_library(${SD_LIB} stable-diffusion.h stable-diffusion.cpp model.h model.cpp util.h util.cpp upscaler.cpp tensor.hpp
67-
ggml_extend.hpp clip.hpp common.hpp unet.hpp tae.hpp esrgan.hpp lora.hpp denoiser.hpp rng.hpp rng_philox.hpp)
63+
ggml_extend.hpp clip.hpp common.hpp unet.hpp tae.hpp esrgan.hpp lora.hpp denoiser.hpp rng.hpp rng_philox.hpp
64+
control.hpp preprocessing.hpp)
6865

6966
if(BUILD_SHARED_LIBS)
7067
message("Build shared library")

README.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
3131
- Faster and memory efficient latent decoding with [TAESD](https://github.com/madebyollin/taesd)
3232
- Upscale images generated with [ESRGAN](https://github.com/xinntao/Real-ESRGAN)
3333
- VAE tiling processing to reduce memory usage
34+
- Control Net support with SD 1.5
3435
- Sampling method
3536
- `Euler A`
3637
- `Euler`
@@ -53,9 +54,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
5354
- [ ] More sampling methods
5455
- [ ] Make inference faster
5556
- The current implementation of ggml_conv_2d is slow and has high memory usage
56-
- Implement Winograd Convolution 2D for 3x3 kernel filtering
5757
- [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
58-
- [ ] Implement Textual Inversion (embeddings)
5958
- [ ] Implement Inpainting support
6059
- [ ] k-quants support
6160

@@ -159,16 +158,20 @@ arguments:
159158
-m, --model [MODEL] path to model
160159
--vae [VAE] path to vae
161160
--taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
161+
--control-net [CONTROL_PATH] path to control net model
162+
--embd-dir [EMBEDDING_PATH] path to embeddings.
162163
--upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.
163164
--type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)
164165
If not specified, the default is the type of the weight file.
165166
--lora-model-dir [DIR] lora model directory
166167
-i, --init-img [IMAGE] path to the input image, required by img2img
168+
--control-image [IMAGE] path to image condition, control net
167169
-o, --output OUTPUT path to write result image to (default: ./output.png)
168170
-p, --prompt [PROMPT] the prompt to render
169171
-n, --negative-prompt PROMPT the negative prompt (default: "")
170172
--cfg-scale SCALE unconditional guidance scale: (default: 7.0)
171173
--strength STRENGTH strength for noising/unnoising (default: 0.75)
174+
--control-strength STRENGTH strength to apply Control Net (default: 0.9)
172175
1.0 corresponds to full destruction of information in init image
173176
-H, --height H image height, in pixel space (default: 512)
174177
-W, --width W image width, in pixel space (default: 512)
@@ -182,6 +185,7 @@ arguments:
182185
--clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
183186
<= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
184187
--vae-tiling process vae in tiles to reduce memory usage
188+
--control-net-cpu keep controlnet in cpu (for low vram)
185189
-v, --verbose print extra info
186190
```
187191

assets/control.png

4.28 KB
Loading

assets/control_2.png

6.09 KB
Loading

assets/control_3.png

18.4 KB
Loading

clip.hpp

Lines changed: 137 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define __CLIP_HPP__
33

44
#include "ggml_extend.hpp"
5+
#include "model.h"
56

67
/*================================================== CLIPTokenizer ===================================================*/
78

@@ -67,6 +68,9 @@ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
6768
}
6869

6970
// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
71+
72+
typedef std::function<bool(std::string&, std::vector<int32_t>&)> on_new_token_cb_t;
73+
7074
class CLIPTokenizer {
7175
private:
7276
SDVersion version = VERSION_1_x;
@@ -234,8 +238,11 @@ class CLIPTokenizer {
234238
return result;
235239
}
236240

237-
std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) {
238-
std::vector<int32_t> tokens = encode(text);
241+
std::vector<int> tokenize(std::string text,
242+
on_new_token_cb_t on_new_token_cb,
243+
size_t max_length = 0,
244+
bool padding = false) {
245+
std::vector<int32_t> tokens = encode(text, on_new_token_cb);
239246
tokens.insert(tokens.begin(), BOS_TOKEN_ID);
240247
if (max_length > 0) {
241248
if (tokens.size() > max_length - 1) {
@@ -255,7 +262,7 @@ class CLIPTokenizer {
255262
return tokens;
256263
}
257264

258-
std::vector<int> encode(std::string text) {
265+
std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb) {
259266
std::string original_text = text;
260267
std::vector<int32_t> bpe_tokens;
261268
text = whitespace_clean(text);
@@ -268,6 +275,10 @@ class CLIPTokenizer {
268275
std::string str = text;
269276
std::vector<std::string> token_strs;
270277
while (std::regex_search(str, matches, pat)) {
278+
bool skip = on_new_token_cb(str, bpe_tokens);
279+
if (skip) {
280+
continue;
281+
}
271282
for (auto& token : matches) {
272283
std::string token_str = token.str();
273284
std::u32string utf32_token;
@@ -536,7 +547,13 @@ class CLIPEmbeddings : public GGMLBlock {
536547
num_positions(num_positions) {
537548
}
538549

539-
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* input_ids) {
550+
struct ggml_tensor* get_token_embed_weight() {
551+
return params["token_embedding.weight"];
552+
}
553+
554+
struct ggml_tensor* forward(struct ggml_context* ctx,
555+
struct ggml_tensor* input_ids,
556+
struct ggml_tensor* custom_embed_weight) {
540557
// input_ids: [N, n_token]
541558
auto token_embed_weight = params["token_embedding.weight"];
542559
auto position_embed_weight = params["position_embedding.weight"];
@@ -545,7 +562,7 @@ class CLIPEmbeddings : public GGMLBlock {
545562

546563
// token_embedding + position_embedding
547564
auto x = ggml_add(ctx,
548-
ggml_get_rows(ctx, token_embed_weight, input_ids),
565+
ggml_get_rows(ctx, custom_embed_weight != NULL ? custom_embed_weight : token_embed_weight, input_ids),
549566
position_embed_weight); // [N, n_token, embed_dim]
550567
return x;
551568
}
@@ -667,14 +684,23 @@ class CLIPTextModel : public GGMLBlock {
667684
clip_skip = skip;
668685
}
669686

670-
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* input_ids, size_t max_token_idx = 0, bool return_pooled = false) {
687+
struct ggml_tensor* get_token_embed_weight() {
688+
auto embeddings = std::dynamic_pointer_cast<CLIPEmbeddings>(blocks["embeddings"]);
689+
return embeddings->get_token_embed_weight();
690+
}
691+
692+
struct ggml_tensor* forward(struct ggml_context* ctx,
693+
struct ggml_tensor* input_ids,
694+
struct ggml_tensor* tkn_embeddings,
695+
size_t max_token_idx = 0,
696+
bool return_pooled = false) {
671697
// input_ids: [N, n_token]
672698
auto embeddings = std::dynamic_pointer_cast<CLIPEmbeddings>(blocks["embeddings"]);
673699
auto encoder = std::dynamic_pointer_cast<CLIPEncoder>(blocks["encoder"]);
674700
auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
675701

676-
auto x = embeddings->forward(ctx, input_ids); // [N, n_token, hidden_size]
677-
x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
702+
auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
703+
x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
678704
if (return_pooled || with_final_ln) {
679705
x = final_layer_norm->forward(ctx, x);
680706
}
@@ -695,6 +721,7 @@ class CLIPVisionModel : public GGMLBlock {
695721
void init_params(struct ggml_context* ctx, ggml_type wtype) {
696722
params["visual_projection"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, projection_dim, hidden_size);
697723
}
724+
698725
public:
699726
// network hparams
700727
int32_t num_channels = 3;
@@ -742,10 +769,10 @@ class CLIPVisionModel : public GGMLBlock {
742769
x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
743770

744771
GGML_ASSERT(x->ne[2] == 1);
745-
int64_t max_token_idx = 0;
746-
ggml_tensor* pooled = ggml_view_1d(ctx, x, x->ne[0], x->nb[1] * max_token_idx); // assert N == 1
772+
int64_t max_token_idx = 0;
773+
ggml_tensor* pooled = ggml_view_1d(ctx, x, x->ne[0], x->nb[1] * max_token_idx); // assert N == 1
747774
auto visual_projection = params["visual_projection"];
748-
pooled = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, visual_projection)), pooled);
775+
pooled = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, visual_projection)), pooled);
749776
return pooled; // [N, projection_dim]
750777
}
751778
};
@@ -790,6 +817,11 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
790817
CLIPTextModel text_model;
791818
CLIPTextModel text_model2;
792819

820+
std::string embd_dir;
821+
int32_t num_custom_embeddings = 0;
822+
std::vector<uint8_t> token_embed_custom;
823+
std::vector<std::string> readed_embeddings;
824+
793825
FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
794826
ggml_type wtype,
795827
SDVersion version = VERSION_1_x,
@@ -849,15 +881,53 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
849881
}
850882
}
851883

884+
bool load_embedding(std::string embd_name, std::string embd_path, std::vector<int32_t>& bpe_tokens) {
885+
// the order matters
886+
ModelLoader model_loader;
887+
if (!model_loader.init_from_file(embd_path)) {
888+
LOG_ERROR("embedding '%s' failed", embd_name.c_str());
889+
return false;
890+
}
891+
struct ggml_init_params params;
892+
params.mem_size = 32 * 1024; // max for custom embeddings 32 KB
893+
params.mem_buffer = NULL;
894+
params.no_alloc = false;
895+
struct ggml_context* embd_ctx = ggml_init(params);
896+
struct ggml_tensor* embd = NULL;
897+
auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) {
898+
if (tensor_storage.ne[0] != text_model.hidden_size) {
899+
LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], text_model.hidden_size);
900+
return false;
901+
}
902+
embd = ggml_new_tensor_2d(embd_ctx, wtype, text_model.hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
903+
*dst_tensor = embd;
904+
return true;
905+
};
906+
model_loader.load_tensors(on_load, NULL);
907+
readed_embeddings.push_back(embd_name);
908+
token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd));
909+
memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * text_model.hidden_size * ggml_type_size(wtype)),
910+
embd->data,
911+
ggml_nbytes(embd));
912+
for (int i = 0; i < embd->ne[1]; i++) {
913+
bpe_tokens.push_back(text_model.vocab_size + num_custom_embeddings);
914+
// LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
915+
num_custom_embeddings++;
916+
}
917+
LOG_DEBUG("embedding '%s' applied, custom embeddings: %i", embd_name.c_str(), num_custom_embeddings);
918+
return true;
919+
}
920+
852921
struct ggml_tensor* forward(struct ggml_context* ctx,
853922
struct ggml_tensor* input_ids,
854923
struct ggml_tensor* input_ids2,
924+
struct ggml_tensor* embeddings,
855925
size_t max_token_idx = 0,
856926
bool return_pooled = false) {
857927
if (return_pooled) {
858-
return text_model2.forward(ctx, input_ids2, max_token_idx, return_pooled);
928+
return text_model2.forward(ctx, input_ids2, NULL, max_token_idx, return_pooled);
859929
}
860-
auto hidden_states = text_model.forward(ctx, input_ids); // [N, n_token, hidden_size]
930+
auto hidden_states = text_model.forward(ctx, input_ids, embeddings); // [N, n_token, hidden_size]
861931
// LOG_DEBUG("hidden_states: %d %d %d %d", hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]);
862932
if (version == VERSION_XL) {
863933
hidden_states = ggml_reshape_4d(ctx,
@@ -868,7 +938,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
868938
hidden_states->ne[3]);
869939
hidden_states = ggml_cont(ctx, ggml_permute(ctx, hidden_states, 2, 0, 1, 3));
870940

871-
auto hidden_states2 = text_model2.forward(ctx, input_ids2); // [N, n_token, hidden_size2]
941+
auto hidden_states2 = text_model2.forward(ctx, input_ids2, NULL); // [N, n_token, hidden_size2]
872942
// LOG_DEBUG("hidden_states: %d %d %d %d", hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]);
873943
hidden_states2 = ggml_reshape_4d(ctx,
874944
hidden_states2,
@@ -919,7 +989,34 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
919989
}
920990
}
921991

922-
struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, input_ids2, max_token_idx, return_pooled);
992+
struct ggml_tensor* embeddings = NULL;
993+
994+
if (num_custom_embeddings > 0 && version != VERSION_XL) {
995+
embeddings = ggml_new_tensor_2d(compute_ctx,
996+
wtype,
997+
text_model.hidden_size,
998+
text_model.vocab_size + num_custom_embeddings /* custom placeholder */);
999+
ggml_allocr_alloc(allocr, embeddings);
1000+
if (!ggml_allocr_is_measure(allocr)) {
1001+
// really bad, there is memory inflexibility (this is for host<->device memory conflicts)
1002+
auto token_embed_weight = text_model.get_token_embed_weight();
1003+
void* freeze_data = malloc(ggml_nbytes(token_embed_weight));
1004+
ggml_backend_tensor_get_and_sync(backend,
1005+
token_embed_weight,
1006+
freeze_data,
1007+
0,
1008+
ggml_nbytes(token_embed_weight));
1009+
ggml_backend_tensor_set(embeddings, freeze_data, 0, ggml_nbytes(token_embed_weight));
1010+
free(freeze_data);
1011+
// concatenate custom embeddings
1012+
ggml_backend_tensor_set(embeddings,
1013+
(const void*)token_embed_custom.data(),
1014+
ggml_nbytes(token_embed_weight),
1015+
num_custom_embeddings * text_model.hidden_size * ggml_type_size(wtype));
1016+
}
1017+
}
1018+
1019+
struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, input_ids2, embeddings, max_token_idx, return_pooled);
9231020

9241021
ggml_build_forward_expand(gf, hidden_states);
9251022

@@ -957,12 +1054,36 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
9571054
LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
9581055
}
9591056

1057+
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
1058+
size_t word_end = str.find(",");
1059+
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
1060+
embd_name = trim(embd_name);
1061+
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
1062+
if (embd_path.size() == 0) {
1063+
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
1064+
}
1065+
if (embd_path.size() == 0) {
1066+
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
1067+
}
1068+
if (embd_path.size() > 0) {
1069+
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
1070+
if (word_end != std::string::npos) {
1071+
str = str.substr(word_end);
1072+
} else {
1073+
str = "";
1074+
}
1075+
return true;
1076+
}
1077+
}
1078+
return false;
1079+
};
1080+
9601081
std::vector<int> tokens;
9611082
std::vector<float> weights;
9621083
for (const auto& item : parsed_attention) {
9631084
const std::string& curr_text = item.first;
9641085
float curr_weight = item.second;
965-
std::vector<int> curr_tokens = tokenizer.encode(curr_text);
1086+
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
9661087
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
9671088
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
9681089
}

0 commit comments

Comments
 (0)