diff --git a/.gitmodules b/.gitmodules
index d9d943713..d5788ea42 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,3 @@
[submodule "ggml"]
path = ggml
- url = https://github.com/ggerganov/ggml.git
+ url = https://github.com/leejet/ggml.git
diff --git a/README.md b/README.md
index 8f4a5f3b1..0fb37c075 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
-
+
# stable-diffusion.cpp
@@ -10,7 +10,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
- Plain C/C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp)
- Super lightweight and without external dependencies
-- SD1.x, SD2.x and SDXL support
+- SD1.x, SD2.x, SDXL and SD3 support
- !!!The VAE in SDXL encounters NaN issues under FP16, but unfortunately, the ggml_conv_2d only operates under FP16. Hence, a parameter is needed to specify the VAE that has fixed the FP16 NaN issue. You can find it here: [SDXL VAE FP16 Fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors).
- [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) and [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) support
@@ -86,11 +86,13 @@ git submodule update
- Stable Diffusion v1.4 from https://huggingface.co/CompVis/stable-diffusion-v-1-4-original
- Stable Diffusion v1.5 from https://huggingface.co/runwayml/stable-diffusion-v1-5
- Stable Diffuison v2.1 from https://huggingface.co/stabilityai/stable-diffusion-2-1
+ - Stable Diffusion 3 2B from https://huggingface.co/stabilityai/stable-diffusion-3-medium
```shell
curl -L -O https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
# curl -L -O https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors
# curl -L -O https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-nonema-pruned.safetensors
+ # curl -L -O https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/sd3_medium_incl_clips_t5xxlfp16.safetensors
```
### Build
@@ -226,6 +228,7 @@ For example:
./bin/sd -m ../models/sd-v1-4.ckpt -p "a lovely cat"
# ./bin/sd -m ../models/v1-5-pruned-emaonly.safetensors -p "a lovely cat"
# ./bin/sd -m ../models/sd_xl_base_1.0.safetensors --vae ../models/sdxl_vae-fp16-fix.safetensors -H 1024 -W 1024 -p "a lovely cat" -v
+# ./bin/sd -m ../models/sd3_medium_incl_clips_t5xxlfp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable Diffusion CPP\"' --cfg-scale 4.5 --sampling-method euler -v
```
Using formats of different precisions will yield results of varying quality.
@@ -384,6 +387,7 @@ Thank you to all the people who have already contributed to stable-diffusion.cpp
- [ggml](https://github.com/ggerganov/ggml)
- [stable-diffusion](https://github.com/CompVis/stable-diffusion)
+- [sd3-ref](https://github.com/Stability-AI/sd3-ref)
- [stable-diffusion-stability-ai](https://github.com/Stability-AI/stablediffusion)
- [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
- [ComfyUI](https://github.com/comfyanonymous/ComfyUI)
diff --git a/assets/cat_with_sd_cpp_20184.png b/assets/cat_with_sd_cpp_20184.png
new file mode 100644
index 000000000..04a82bef8
Binary files /dev/null and b/assets/cat_with_sd_cpp_20184.png differ
diff --git a/assets/cat_with_sd_cpp_42.png b/assets/cat_with_sd_cpp_42.png
new file mode 100644
index 000000000..6368d5427
Binary files /dev/null and b/assets/cat_with_sd_cpp_42.png differ
diff --git a/clip.hpp b/clip.hpp
index cf82fdfec..664da58c6 100644
--- a/clip.hpp
+++ b/clip.hpp
@@ -31,16 +31,6 @@ std::pair, std::string> extract_and_remov
return std::make_pair(filename2multiplier, text);
}
-const std::string UNK_TOKEN = "<|endoftext|>";
-const std::string BOS_TOKEN = "<|startoftext|>";
-const std::string EOS_TOKEN = "<|endoftext|>";
-const std::string PAD_TOEKN = "<|endoftext|>";
-
-const int UNK_TOKEN_ID = 49407;
-const int BOS_TOKEN_ID = 49406;
-const int EOS_TOKEN_ID = 49407;
-const int PAD_TOKEN_ID = 49407;
-
std::vector> bytes_to_unicode() {
std::vector> byte_unicode_pairs;
std::set byte_set;
@@ -73,7 +63,6 @@ typedef std::function&)> on_new_token_cb
class CLIPTokenizer {
private:
- SDVersion version = VERSION_1_x;
std::map byte_encoder;
std::map byte_decoder;
std::map encoder;
@@ -83,6 +72,18 @@ class CLIPTokenizer {
int encoder_len;
int bpe_len;
+public:
+ const std::string UNK_TOKEN = "<|endoftext|>";
+ const std::string BOS_TOKEN = "<|startoftext|>";
+ const std::string EOS_TOKEN = "<|endoftext|>";
+ const std::string PAD_TOEKN = "<|endoftext|>";
+
+ const int UNK_TOKEN_ID = 49407;
+ const int BOS_TOKEN_ID = 49406;
+ const int EOS_TOKEN_ID = 49407;
+ const int PAD_TOKEN_ID = 49407;
+
+private:
static std::string strip(const std::string& str) {
std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f");
std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f");
@@ -117,8 +118,14 @@ class CLIPTokenizer {
}
public:
- CLIPTokenizer(SDVersion version = VERSION_1_x)
- : version(version) {}
+ CLIPTokenizer(int pad_token_id = 49407, const std::string& merges_utf8_str = "")
+ : PAD_TOKEN_ID(pad_token_id) {
+ if (merges_utf8_str.size() > 0) {
+ load_from_merges(merges_utf8_str);
+ } else {
+ load_from_merges(ModelLoader::load_merges());
+ }
+ }
void load_from_merges(const std::string& merges_utf8_str) {
auto byte_unicode_pairs = bytes_to_unicode();
@@ -283,11 +290,7 @@ class CLIPTokenizer {
} else {
tokens.push_back(EOS_TOKEN_ID);
if (padding) {
- int pad_token_id = PAD_TOKEN_ID;
- if (version == VERSION_2_x) {
- pad_token_id = 0;
- }
- tokens.insert(tokens.end(), max_length - tokens.size(), pad_token_id);
+ tokens.insert(tokens.end(), max_length - tokens.size(), PAD_TOKEN_ID);
}
}
}
@@ -295,6 +298,51 @@ class CLIPTokenizer {
return tokens;
}
+ void pad_tokens(std::vector& tokens,
+ std::vector& weights,
+ size_t max_length = 0,
+ bool padding = false) {
+ if (max_length > 0 && padding) {
+ size_t n = std::ceil(tokens.size() * 1.0 / (max_length - 2));
+ if (n == 0) {
+ n = 1;
+ }
+ size_t length = max_length * n;
+ LOG_DEBUG("token length: %llu", length);
+ std::vector new_tokens;
+ std::vector new_weights;
+ new_tokens.push_back(BOS_TOKEN_ID);
+ new_weights.push_back(1.0);
+ int token_idx = 0;
+ for (int i = 1; i < length; i++) {
+ if (token_idx >= tokens.size()) {
+ break;
+ }
+ if (i % max_length == 0) {
+ new_tokens.push_back(BOS_TOKEN_ID);
+ new_weights.push_back(1.0);
+ } else if (i % max_length == max_length - 1) {
+ new_tokens.push_back(EOS_TOKEN_ID);
+ new_weights.push_back(1.0);
+ } else {
+ new_tokens.push_back(tokens[token_idx]);
+ new_weights.push_back(weights[token_idx]);
+ token_idx++;
+ }
+ }
+
+ new_tokens.push_back(EOS_TOKEN_ID);
+ new_weights.push_back(1.0);
+ tokens = new_tokens;
+ weights = new_weights;
+
+ if (padding) {
+ tokens.insert(tokens.end(), length - tokens.size(), PAD_TOKEN_ID);
+ weights.insert(weights.end(), length - weights.size(), 1.0);
+ }
+ }
+ }
+
std::string decode(const std::vector& tokens) {
std::string text = "";
for (int t : tokens) {
@@ -371,113 +419,6 @@ class CLIPTokenizer {
}
};
-// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/prompt_parser.py#L345
-//
-// Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
-// Accepted tokens are:
-// (abc) - increases attention to abc by a multiplier of 1.1
-// (abc:3.12) - increases attention to abc by a multiplier of 3.12
-// [abc] - decreases attention to abc by a multiplier of 1.1
-// \( - literal character '('
-// \[ - literal character '['
-// \) - literal character ')'
-// \] - literal character ']'
-// \\ - literal character '\'
-// anything else - just text
-//
-// >>> parse_prompt_attention('normal text')
-// [['normal text', 1.0]]
-// >>> parse_prompt_attention('an (important) word')
-// [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
-// >>> parse_prompt_attention('(unbalanced')
-// [['unbalanced', 1.1]]
-// >>> parse_prompt_attention('\(literal\]')
-// [['(literal]', 1.0]]
-// >>> parse_prompt_attention('(unnecessary)(parens)')
-// [['unnecessaryparens', 1.1]]
-// >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
-// [['a ', 1.0],
-// ['house', 1.5730000000000004],
-// [' ', 1.1],
-// ['on', 1.0],
-// [' a ', 1.1],
-// ['hill', 0.55],
-// [', sun, ', 1.1],
-// ['sky', 1.4641000000000006],
-// ['.', 1.1]]
-std::vector> parse_prompt_attention(const std::string& text) {
- std::vector> res;
- std::vector round_brackets;
- std::vector square_brackets;
-
- float round_bracket_multiplier = 1.1f;
- float square_bracket_multiplier = 1 / 1.1f;
-
- std::regex re_attention(R"(\\\(|\\\)|\\\[|\\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|\]|[^\\()\[\]:]+|:)");
- std::regex re_break(R"(\s*\bBREAK\b\s*)");
-
- auto multiply_range = [&](int start_position, float multiplier) {
- for (int p = start_position; p < res.size(); ++p) {
- res[p].second *= multiplier;
- }
- };
-
- std::smatch m;
- std::string remaining_text = text;
-
- while (std::regex_search(remaining_text, m, re_attention)) {
- std::string text = m[0];
- std::string weight = m[1];
-
- if (text == "(") {
- round_brackets.push_back((int)res.size());
- } else if (text == "[") {
- square_brackets.push_back((int)res.size());
- } else if (!weight.empty()) {
- if (!round_brackets.empty()) {
- multiply_range(round_brackets.back(), std::stof(weight));
- round_brackets.pop_back();
- }
- } else if (text == ")" && !round_brackets.empty()) {
- multiply_range(round_brackets.back(), round_bracket_multiplier);
- round_brackets.pop_back();
- } else if (text == "]" && !square_brackets.empty()) {
- multiply_range(square_brackets.back(), square_bracket_multiplier);
- square_brackets.pop_back();
- } else if (text == "\\(") {
- res.push_back({text.substr(1), 1.0f});
- } else {
- res.push_back({text, 1.0f});
- }
-
- remaining_text = m.suffix();
- }
-
- for (int pos : round_brackets) {
- multiply_range(pos, round_bracket_multiplier);
- }
-
- for (int pos : square_brackets) {
- multiply_range(pos, square_bracket_multiplier);
- }
-
- if (res.empty()) {
- res.push_back({"", 1.0f});
- }
-
- int i = 0;
- while (i + 1 < res.size()) {
- if (res[i].second == res[i + 1].second) {
- res[i].first += res[i + 1].first;
- res.erase(res.begin() + i + 1);
- } else {
- ++i;
- }
- }
-
- return res;
-}
-
/*================================================ FrozenCLIPEmbedder ================================================*/
// Ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py
@@ -527,7 +468,7 @@ struct CLIPLayer : public GGMLBlock {
: d_model(d_model),
n_head(n_head),
intermediate_size(intermediate_size) {
- blocks["self_attn"] = std::shared_ptr(new MultiheadAttention(d_model, n_head, true));
+ blocks["self_attn"] = std::shared_ptr(new MultiheadAttention(d_model, n_head, true, true));
blocks["layer_norm1"] = std::shared_ptr(new LayerNorm(d_model));
blocks["layer_norm2"] = std::shared_ptr(new LayerNorm(d_model));
@@ -897,42 +838,16 @@ class CLIPVisionModelProjection : public GGMLBlock {
}
};
-// ldm.modules.encoders.modules.FrozenCLIPEmbedder
-// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
-struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
- SDVersion version = VERSION_1_x;
- CLIPTokenizer tokenizer;
- CLIPTextModel text_model;
- CLIPTextModel text_model2;
-
- std::string embd_dir;
- int32_t num_custom_embeddings = 0;
- std::vector token_embed_custom;
- std::vector readed_embeddings;
-
- FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
- ggml_type wtype,
- SDVersion version = VERSION_1_x,
- int clip_skip = -1)
- : GGMLModule(backend, wtype), version(version), tokenizer(version) {
- if (clip_skip <= 0) {
- clip_skip = 1;
- if (version == VERSION_2_x || version == VERSION_XL) {
- clip_skip = 2;
- }
- }
- if (version == VERSION_1_x) {
- text_model = CLIPTextModel(OPENAI_CLIP_VIT_L_14, clip_skip);
- text_model.init(params_ctx, wtype);
- } else if (version == VERSION_2_x) {
- text_model = CLIPTextModel(OPEN_CLIP_VIT_H_14, clip_skip);
- text_model.init(params_ctx, wtype);
- } else if (version == VERSION_XL) {
- text_model = CLIPTextModel(OPENAI_CLIP_VIT_L_14, clip_skip, false);
- text_model2 = CLIPTextModel(OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
- text_model.init(params_ctx, wtype);
- text_model2.init(params_ctx, wtype);
- }
+struct CLIPTextModelRunner : public GGMLRunner {
+ CLIPTextModel model;
+
+ CLIPTextModelRunner(ggml_backend_t backend,
+ ggml_type wtype,
+ CLIPVersion version = OPENAI_CLIP_VIT_L_14,
+ int clip_skip_value = 1,
+ bool with_final_ln = true)
+ : GGMLRunner(backend, wtype), model(version, clip_skip_value, with_final_ln) {
+ model.init(params_ctx, wtype);
}
std::string get_desc() {
@@ -940,140 +855,52 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
}
void set_clip_skip(int clip_skip) {
- text_model.set_clip_skip(clip_skip);
- if (version == VERSION_XL) {
- text_model2.set_clip_skip(clip_skip);
- }
+ model.set_clip_skip(clip_skip);
}
void get_param_tensors(std::map& tensors, const std::string prefix) {
- text_model.get_param_tensors(tensors, prefix + "transformer.text_model");
- if (version == VERSION_XL) {
- text_model2.get_param_tensors(tensors, prefix + "1.transformer.text_model");
- }
- }
-
- bool load_embedding(std::string embd_name, std::string embd_path, std::vector& bpe_tokens) {
- // the order matters
- ModelLoader model_loader;
- if (!model_loader.init_from_file(embd_path)) {
- LOG_ERROR("embedding '%s' failed", embd_name.c_str());
- return false;
- }
- if (std::find(readed_embeddings.begin(), readed_embeddings.end(), embd_name) != readed_embeddings.end()) {
- LOG_DEBUG("embedding already read in: %s", embd_name.c_str());
- return true;
- }
- struct ggml_init_params params;
- params.mem_size = 10 * 1024 * 1024; // max for custom embeddings 10 MB
- params.mem_buffer = NULL;
- params.no_alloc = false;
- struct ggml_context* embd_ctx = ggml_init(params);
- struct ggml_tensor* embd = NULL;
- auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) {
- if (tensor_storage.ne[0] != text_model.hidden_size) {
- LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], text_model.hidden_size);
- return false;
- }
- embd = ggml_new_tensor_2d(embd_ctx, wtype, text_model.hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
- *dst_tensor = embd;
- return true;
- };
- model_loader.load_tensors(on_load, NULL);
- readed_embeddings.push_back(embd_name);
- token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd));
- memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * text_model.hidden_size * ggml_type_size(wtype)),
- embd->data,
- ggml_nbytes(embd));
- for (int i = 0; i < embd->ne[1]; i++) {
- bpe_tokens.push_back(text_model.vocab_size + num_custom_embeddings);
- // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
- num_custom_embeddings++;
- }
- LOG_DEBUG("embedding '%s' applied, custom embeddings: %i", embd_name.c_str(), num_custom_embeddings);
- return true;
+ model.get_param_tensors(tensors, prefix);
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* input_ids,
- struct ggml_tensor* input_ids2,
struct ggml_tensor* embeddings,
size_t max_token_idx = 0,
bool return_pooled = false) {
size_t N = input_ids->ne[1];
size_t n_token = input_ids->ne[0];
- if (input_ids != NULL && input_ids->ne[0] > text_model.n_token) {
- GGML_ASSERT(input_ids->ne[0] % text_model.n_token == 0);
- input_ids = ggml_reshape_2d(ctx, input_ids, text_model.n_token, input_ids->ne[0] / text_model.n_token);
- }
- if (input_ids2 != NULL && input_ids2->ne[0] > text_model2.n_token) {
- GGML_ASSERT(input_ids2->ne[0] % text_model2.n_token == 0);
- input_ids2 = ggml_reshape_2d(ctx, input_ids2, text_model2.n_token, input_ids2->ne[0] / text_model2.n_token);
- }
-
- if (return_pooled) {
- return text_model2.forward(ctx, input_ids2, NULL, max_token_idx, return_pooled);
+ if (input_ids->ne[0] > model.n_token) {
+ GGML_ASSERT(input_ids->ne[0] % model.n_token == 0);
+ input_ids = ggml_reshape_2d(ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token);
}
- auto hidden_states = text_model.forward(ctx, input_ids, embeddings); // [N, n_token, hidden_size]
- // LOG_DEBUG("hidden_states: %d %d %d %d", hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]);
- if (version == VERSION_XL) {
- hidden_states = ggml_reshape_4d(ctx,
- hidden_states,
- hidden_states->ne[0],
- hidden_states->ne[1],
- hidden_states->ne[2],
- hidden_states->ne[3]);
- hidden_states = ggml_cont(ctx, ggml_permute(ctx, hidden_states, 2, 0, 1, 3));
-
- auto hidden_states2 = text_model2.forward(ctx, input_ids2, NULL); // [N, n_token, hidden_size2]
- // LOG_DEBUG("hidden_states: %d %d %d %d", hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]);
- hidden_states2 = ggml_reshape_4d(ctx,
- hidden_states2,
- hidden_states2->ne[0],
- hidden_states2->ne[1],
- hidden_states2->ne[2],
- hidden_states2->ne[3]);
- hidden_states2 = ggml_cont(ctx, ggml_permute(ctx, hidden_states2, 2, 0, 1, 3));
-
- hidden_states = ggml_concat(ctx, hidden_states, hidden_states2, 2); // [N, n_token, hidden_size + hidden_size2]
-
- hidden_states = ggml_cont(ctx, ggml_permute(ctx, hidden_states, 1, 2, 0, 3));
- }
- hidden_states = ggml_reshape_3d(ctx, hidden_states, hidden_states->ne[0], n_token, N);
- // LOG_DEBUG("hidden_states: %d %d %d %d", hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]);
- return hidden_states;
+ return model.forward(ctx, input_ids, embeddings, max_token_idx, return_pooled);
}
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
- struct ggml_tensor* input_ids2 = NULL,
- size_t max_token_idx = 0,
- bool return_pooled = false) {
+ int num_custom_embeddings = 0,
+ void* custom_embeddings_data = NULL,
+ size_t max_token_idx = 0,
+ bool return_pooled = false) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
- input_ids2 = to_backend(input_ids2);
- if (!return_pooled) {
- input_ids = to_backend(input_ids);
- }
+ input_ids = to_backend(input_ids);
struct ggml_tensor* embeddings = NULL;
- if (num_custom_embeddings > 0 && version != VERSION_XL) {
- auto custom_embeddings = ggml_new_tensor_3d(compute_ctx,
+ if (num_custom_embeddings > 0 && custom_embeddings_data != NULL) {
+ auto custom_embeddings = ggml_new_tensor_2d(compute_ctx,
wtype,
- text_model.hidden_size,
- 1,
+ model.hidden_size,
num_custom_embeddings);
- set_backend_tensor_data(custom_embeddings, token_embed_custom.data());
+ set_backend_tensor_data(custom_embeddings, custom_embeddings_data);
- auto token_embed_weight = text_model.get_token_embed_weight();
- token_embed_weight = ggml_reshape_3d(compute_ctx, token_embed_weight, token_embed_weight->ne[0], 1, token_embed_weight->ne[1]);
+ auto token_embed_weight = model.get_token_embed_weight();
// concatenate custom embeddings
- embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 2);
- embeddings = ggml_reshape_2d(compute_ctx, embeddings, embeddings->ne[0], embeddings->ne[2]);
+ embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1);
}
- struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, input_ids2, embeddings, max_token_idx, return_pooled);
+ struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, embeddings, max_token_idx, return_pooled);
ggml_build_forward_expand(gf, hidden_states);
@@ -1082,317 +909,16 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
void compute(const int n_threads,
struct ggml_tensor* input_ids,
- struct ggml_tensor* input_ids2,
+ int num_custom_embeddings,
+ void* custom_embeddings_data,
size_t max_token_idx,
bool return_pooled,
ggml_tensor** output,
ggml_context* output_ctx = NULL) {
auto get_graph = [&]() -> struct ggml_cgraph* {
- return build_graph(input_ids, input_ids2, max_token_idx, return_pooled);
- };
- GGMLModule::compute(get_graph, n_threads, true, output, output_ctx);
- }
-
- std::pair, std::vector> tokenize(std::string text,
- bool padding = false) {
- return tokenize(text, text_model.n_token, padding);
- }
-
- std::tuple, std::vector, std::vector>
- tokenize_with_trigger_token(std::string text,
- int num_input_imgs,
- int32_t image_token,
- bool padding = false) {
- return tokenize_with_trigger_token(text, num_input_imgs, image_token,
- text_model.n_token, padding);
- }
-
- std::vector convert_token_to_id(std::string text) {
- auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool {
- size_t word_end = str.find(",");
- std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
- embd_name = trim(embd_name);
- std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
- if (embd_path.size() == 0) {
- embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
- }
- if (embd_path.size() == 0) {
- embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
- }
- if (embd_path.size() > 0) {
- if (load_embedding(embd_name, embd_path, bpe_tokens)) {
- if (word_end != std::string::npos) {
- str = str.substr(word_end);
- } else {
- str = "";
- }
- return true;
- }
- }
- return false;
- };
- std::vector curr_tokens = tokenizer.encode(text, on_new_token_cb);
- return curr_tokens;
- }
-
- std::string decode(const std::vector& tokens) {
- return tokenizer.decode(tokens);
- }
-
- void pad_tokens(std::vector& tokens,
- std::vector& weights,
- size_t max_length = 0,
- bool padding = false) {
- if (max_length > 0 && padding) {
- size_t n = std::ceil(tokens.size() * 1.0 / (max_length - 2));
- if (n == 0) {
- n = 1;
- }
- size_t length = max_length * n;
- LOG_DEBUG("token length: %llu", length);
- std::vector new_tokens;
- std::vector new_weights;
- new_tokens.push_back(BOS_TOKEN_ID);
- new_weights.push_back(1.0);
- int token_idx = 0;
- for (int i = 1; i < length; i++) {
- if (token_idx >= tokens.size()) {
- break;
- }
- if (i % max_length == 0) {
- new_tokens.push_back(BOS_TOKEN_ID);
- new_weights.push_back(1.0);
- } else if (i % max_length == max_length - 1) {
- new_tokens.push_back(EOS_TOKEN_ID);
- new_weights.push_back(1.0);
- } else {
- new_tokens.push_back(tokens[token_idx]);
- new_weights.push_back(weights[token_idx]);
- token_idx++;
- }
- }
-
- new_tokens.push_back(EOS_TOKEN_ID);
- new_weights.push_back(1.0);
- tokens = new_tokens;
- weights = new_weights;
-
- if (padding) {
- int pad_token_id = PAD_TOKEN_ID;
- if (version == VERSION_2_x) {
- pad_token_id = 0;
- }
- tokens.insert(tokens.end(), length - tokens.size(), pad_token_id);
- weights.insert(weights.end(), length - weights.size(), 1.0);
- }
- }
- }
-
- std::tuple, std::vector, std::vector>
- tokenize_with_trigger_token(std::string text,
- int num_input_imgs,
- int32_t image_token,
- size_t max_length = 0,
- bool padding = false) {
- auto parsed_attention = parse_prompt_attention(text);
-
- {
- std::stringstream ss;
- ss << "[";
- for (const auto& item : parsed_attention) {
- ss << "['" << item.first << "', " << item.second << "], ";
- }
- ss << "]";
- LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
- }
-
- auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool {
- size_t word_end = str.find(",");
- std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
- embd_name = trim(embd_name);
- std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
- if (embd_path.size() == 0) {
- embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
- }
- if (embd_path.size() == 0) {
- embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
- }
- if (embd_path.size() > 0) {
- if (load_embedding(embd_name, embd_path, bpe_tokens)) {
- if (word_end != std::string::npos) {
- str = str.substr(word_end);
- } else {
- str = "";
- }
- return true;
- }
- }
- return false;
- };
-
- std::vector tokens;
- std::vector weights;
- std::vector class_token_mask;
- int32_t class_idx = -1, tokens_acc = 0;
- for (const auto& item : parsed_attention) {
- std::vector class_token_index;
- std::vector clean_input_ids;
- const std::string& curr_text = item.first;
- float curr_weight = item.second;
- // printf(" %s: %f \n", curr_text.c_str(), curr_weight);
- std::vector curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
- int32_t clean_index = 0;
- for (uint32_t i = 0; i < curr_tokens.size(); i++) {
- int token_id = curr_tokens[i];
- if (token_id == image_token)
- class_token_index.push_back(clean_index - 1);
- else {
- clean_input_ids.push_back(token_id);
- clean_index++;
- }
- }
- // GGML_ASSERT(class_token_index.size() == 1); // PhotoMaker currently does not support multiple
- // trigger words in a single prompt.
- if (class_token_index.size() == 1) {
- // Expand the class word token and corresponding mask
- int class_token = clean_input_ids[class_token_index[0]];
- class_idx = tokens_acc + class_token_index[0];
- std::vector clean_input_ids_tmp;
- for (uint32_t i = 0; i < class_token_index[0]; i++)
- clean_input_ids_tmp.push_back(clean_input_ids[i]);
- for (uint32_t i = 0; i < num_input_imgs; i++)
- clean_input_ids_tmp.push_back(class_token);
- for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++)
- clean_input_ids_tmp.push_back(clean_input_ids[i]);
- clean_input_ids.clear();
- clean_input_ids = clean_input_ids_tmp;
- }
- tokens_acc += clean_index;
- tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end());
- weights.insert(weights.end(), clean_input_ids.size(), curr_weight);
- }
- tokens.insert(tokens.begin(), BOS_TOKEN_ID);
- weights.insert(weights.begin(), 1.0);
-
- pad_tokens(tokens, weights, max_length, padding);
-
- for (uint32_t i = 0; i < tokens.size(); i++) {
- if (class_idx + 1 <= i && i < class_idx + 1 + num_input_imgs)
- class_token_mask.push_back(true);
- else
- class_token_mask.push_back(false);
- }
-
- // printf("[");
- // for (int i = 0; i < tokens.size(); i++) {
- // printf("%d, ", class_token_mask[i] ? 1 : 0);
- // }
- // printf("]\n");
-
- // for (int i = 0; i < tokens.size(); i++) {
- // std::cout << tokens[i] << ":" << weights[i] << ", ";
- // }
- // std::cout << std::endl;
-
- return std::make_tuple(tokens, weights, class_token_mask);
- }
-
- std::pair, std::vector> tokenize(std::string text,
- size_t max_length = 0,
- bool padding = false) {
- auto parsed_attention = parse_prompt_attention(text);
-
- {
- std::stringstream ss;
- ss << "[";
- for (const auto& item : parsed_attention) {
- ss << "['" << item.first << "', " << item.second << "], ";
- }
- ss << "]";
- LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
- }
-
- auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool {
- size_t word_end = str.find(",");
- std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
- embd_name = trim(embd_name);
- std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
- if (embd_path.size() == 0) {
- embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
- }
- if (embd_path.size() == 0) {
- embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
- }
- if (embd_path.size() > 0) {
- if (load_embedding(embd_name, embd_path, bpe_tokens)) {
- if (word_end != std::string::npos) {
- str = str.substr(word_end);
- } else {
- str = "";
- }
- return true;
- }
- }
- return false;
- };
-
- std::vector tokens;
- std::vector weights;
- for (const auto& item : parsed_attention) {
- const std::string& curr_text = item.first;
- float curr_weight = item.second;
- std::vector curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
- tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
- weights.insert(weights.end(), curr_tokens.size(), curr_weight);
- }
-
- pad_tokens(tokens, weights, max_length, padding);
-
- // for (int i = 0; i < tokens.size(); i++) {
- // std::cout << tokens[i] << ":" << weights[i] << ", ";
- // }
- // std::cout << std::endl;
-
- return {tokens, weights};
- }
-};
-
-struct FrozenCLIPVisionEmbedder : public GGMLModule {
- CLIPVisionModelProjection vision_model;
-
- FrozenCLIPVisionEmbedder(ggml_backend_t backend, ggml_type wtype)
- : vision_model(OPEN_CLIP_VIT_H_14, true), GGMLModule(backend, wtype) {
- vision_model.init(params_ctx, wtype);
- }
-
- std::string get_desc() {
- return "clip_vision";
- }
-
- void get_param_tensors(std::map& tensors, const std::string prefix) {
- vision_model.get_param_tensors(tensors, prefix + "transformer");
- }
-
- struct ggml_cgraph* build_graph(struct ggml_tensor* pixel_values) {
- struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
-
- pixel_values = to_backend(pixel_values);
-
- struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, pixel_values);
-
- ggml_build_forward_expand(gf, hidden_states);
-
- return gf;
- }
-
- void compute(const int n_threads,
- ggml_tensor* pixel_values,
- ggml_tensor** output,
- ggml_context* output_ctx) {
- auto get_graph = [&]() -> struct ggml_cgraph* {
- return build_graph(pixel_values);
+ return build_graph(input_ids, num_custom_embeddings, custom_embeddings_data, max_token_idx, return_pooled);
};
- GGMLModule::compute(get_graph, n_threads, true, output, output_ctx);
+ GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
}
};
diff --git a/common.hpp b/common.hpp
index 30b213ee5..bfdcc004c 100644
--- a/common.hpp
+++ b/common.hpp
@@ -279,26 +279,11 @@ class CrossAttention : public GGMLBlock {
int64_t n_context = context->ne[1];
int64_t inner_dim = d_head * n_head;
- auto q = to_q->forward(ctx, x); // [N, n_token, inner_dim]
- q = ggml_reshape_4d(ctx, q, d_head, n_head, n_token, n); // [N, n_token, n_head, d_head]
- q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, n_token, d_head]
- q = ggml_reshape_3d(ctx, q, d_head, n_token, n_head * n); // [N * n_head, n_token, d_head]
-
- auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
- k = ggml_reshape_4d(ctx, k, d_head, n_head, n_context, n); // [N, n_context, n_head, d_head]
- k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, n_context, d_head]
- k = ggml_reshape_3d(ctx, k, d_head, n_context, n_head * n); // [N * n_head, n_context, d_head]
-
- auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
- v = ggml_reshape_4d(ctx, v, d_head, n_head, n_context, n); // [N, n_context, n_head, d_head]
- v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, n_context]
- v = ggml_reshape_3d(ctx, v, n_context, d_head, n_head * n); // [N * n_head, d_head, n_context]
-
- auto kqv = ggml_nn_attention(ctx, q, k, v, false); // [N * n_head, n_token, d_head]
- kqv = ggml_reshape_4d(ctx, kqv, d_head, n_token, n_head, n);
- kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, n_token, n_head, d_head]
-
- x = ggml_reshape_3d(ctx, kqv, d_head * n_head, n_token, n); // [N, n_token, inner_dim]
+ auto q = to_q->forward(ctx, x); // [N, n_token, inner_dim]
+ auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
+ auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
+
+ x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false); // [N, n_token, inner_dim]
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
return x;
diff --git a/conditioner.hpp b/conditioner.hpp
new file mode 100644
index 000000000..e01be2b21
--- /dev/null
+++ b/conditioner.hpp
@@ -0,0 +1,981 @@
+#ifndef __CONDITIONER_HPP__
+#define __CONDITIONER_HPP__
+
+#include "clip.hpp"
+#include "t5.hpp"
+
+struct SDCondition {
+ struct ggml_tensor* c_crossattn = NULL; // aka context
+ struct ggml_tensor* c_vector = NULL; // aka y
+ struct ggml_tensor* c_concat = NULL;
+
+ SDCondition() = default;
+ SDCondition(struct ggml_tensor* c_crossattn, struct ggml_tensor* c_vector, struct ggml_tensor* c_concat) :
+ c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat) {}
+};
+
+struct Conditioner {
+ virtual SDCondition get_learned_condition(ggml_context* work_ctx,
+ int n_threads,
+ const std::string& text,
+ int clip_skip,
+ int width,
+ int height,
+ int adm_in_channels = -1,
+ bool force_zero_embeddings = false) = 0;
+ virtual void alloc_params_buffer() = 0;
+ virtual void free_params_buffer() = 0;
+ virtual void get_param_tensors(std::map& tensors) = 0;
+ virtual size_t get_params_buffer_size() = 0;
+ virtual std::tuple> get_learned_condition_with_trigger(ggml_context* work_ctx,
+ int n_threads,
+ const std::string& text,
+ int clip_skip,
+ int width,
+ int height,
+ int num_input_imgs,
+ int adm_in_channels = -1,
+ bool force_zero_embeddings = false) = 0;
+ virtual std::string remove_trigger_from_prompt(ggml_context* work_ctx,
+ const std::string& prompt) = 0;
+};
+
+// ldm.modules.encoders.modules.FrozenCLIPEmbedder
+// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
+struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
+ SDVersion version = VERSION_1_x;
+ CLIPTokenizer tokenizer;
+ ggml_type wtype;
+ std::shared_ptr text_model;
+ std::shared_ptr text_model2;
+
+ std::string trigger_word = "img"; // should be user settable
+ std::string embd_dir;
+ int32_t num_custom_embeddings = 0;
+ std::vector token_embed_custom;
+ std::vector readed_embeddings;
+
+ FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
+ ggml_type wtype,
+ const std::string& embd_dir,
+ SDVersion version = VERSION_1_x,
+ int clip_skip = -1)
+ : version(version), tokenizer(version == VERSION_2_x ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
+ if (clip_skip <= 0) {
+ clip_skip = 1;
+ if (version == VERSION_2_x || version == VERSION_XL) {
+ clip_skip = 2;
+ }
+ }
+ if (version == VERSION_1_x) {
+ text_model = std::make_shared(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip);
+ } else if (version == VERSION_2_x) {
+ text_model = std::make_shared(backend, wtype, OPEN_CLIP_VIT_H_14, clip_skip);
+ } else if (version == VERSION_XL) {
+ text_model = std::make_shared(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false);
+ text_model2 = std::make_shared(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
+ }
+ }
+
+ void set_clip_skip(int clip_skip) {
+ text_model->set_clip_skip(clip_skip);
+ if (version == VERSION_XL) {
+ text_model2->set_clip_skip(clip_skip);
+ }
+ }
+
+ void get_param_tensors(std::map& tensors) {
+ text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model");
+ if (version == VERSION_XL) {
+ text_model2->get_param_tensors(tensors, "cond_stage_model.1.transformer.text_model");
+ }
+ }
+
+ void alloc_params_buffer() {
+ text_model->alloc_params_buffer();
+ if (version == VERSION_XL) {
+ text_model2->alloc_params_buffer();
+ }
+ }
+
+ void free_params_buffer() {
+ text_model->free_params_buffer();
+ if (version == VERSION_XL) {
+ text_model2->free_params_buffer();
+ }
+ }
+
+ size_t get_params_buffer_size() {
+ size_t buffer_size = text_model->get_params_buffer_size();
+ if (version == VERSION_XL) {
+ buffer_size += text_model2->get_params_buffer_size();
+ }
+ return buffer_size;
+ }
+
+ bool load_embedding(std::string embd_name, std::string embd_path, std::vector& bpe_tokens) {
+ // the order matters
+ ModelLoader model_loader;
+ if (!model_loader.init_from_file(embd_path)) {
+ LOG_ERROR("embedding '%s' failed", embd_name.c_str());
+ return false;
+ }
+ if (std::find(readed_embeddings.begin(), readed_embeddings.end(), embd_name) != readed_embeddings.end()) {
+ LOG_DEBUG("embedding already read in: %s", embd_name.c_str());
+ return true;
+ }
+ struct ggml_init_params params;
+ params.mem_size = 10 * 1024 * 1024; // max for custom embeddings 10 MB
+ params.mem_buffer = NULL;
+ params.no_alloc = false;
+ struct ggml_context* embd_ctx = ggml_init(params);
+ struct ggml_tensor* embd = NULL;
+ int64_t hidden_size = text_model->model.hidden_size;
+ auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) {
+ if (tensor_storage.ne[0] != hidden_size) {
+ LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], hidden_size);
+ return false;
+ }
+ embd = ggml_new_tensor_2d(embd_ctx, wtype, hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1);
+ *dst_tensor = embd;
+ return true;
+ };
+ model_loader.load_tensors(on_load, NULL);
+ readed_embeddings.push_back(embd_name);
+ token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd));
+ memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(wtype)),
+ embd->data,
+ ggml_nbytes(embd));
+ for (int i = 0; i < embd->ne[1]; i++) {
+ bpe_tokens.push_back(text_model->model.vocab_size + num_custom_embeddings);
+ // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
+ num_custom_embeddings++;
+ }
+ LOG_DEBUG("embedding '%s' applied, custom embeddings: %i", embd_name.c_str(), num_custom_embeddings);
+ return true;
+ }
+
+ std::tuple, std::vector, std::vector>
+ tokenize_with_trigger_token(std::string text,
+ int num_input_imgs,
+ int32_t image_token,
+ bool padding = false) {
+ return tokenize_with_trigger_token(text, num_input_imgs, image_token,
+ text_model->model.n_token, padding);
+ }
+
+ std::vector convert_token_to_id(std::string text) {
+ auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool {
+ size_t word_end = str.find(",");
+ std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
+ embd_name = trim(embd_name);
+ std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
+ if (embd_path.size() == 0) {
+ embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
+ }
+ if (embd_path.size() == 0) {
+ embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
+ }
+ if (embd_path.size() > 0) {
+ if (load_embedding(embd_name, embd_path, bpe_tokens)) {
+ if (word_end != std::string::npos) {
+ str = str.substr(word_end);
+ } else {
+ str = "";
+ }
+ return true;
+ }
+ }
+ return false;
+ };
+ std::vector curr_tokens = tokenizer.encode(text, on_new_token_cb);
+ return curr_tokens;
+ }
+
+ std::string decode(const std::vector& tokens) {
+ return tokenizer.decode(tokens);
+ }
+
+ std::tuple, std::vector, std::vector>
+ tokenize_with_trigger_token(std::string text,
+ int num_input_imgs,
+ int32_t image_token,
+ size_t max_length = 0,
+ bool padding = false) {
+ auto parsed_attention = parse_prompt_attention(text);
+
+ {
+ std::stringstream ss;
+ ss << "[";
+ for (const auto& item : parsed_attention) {
+ ss << "['" << item.first << "', " << item.second << "], ";
+ }
+ ss << "]";
+ LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
+ }
+
+ auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool {
+ size_t word_end = str.find(",");
+ std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
+ embd_name = trim(embd_name);
+ std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
+ if (embd_path.size() == 0) {
+ embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
+ }
+ if (embd_path.size() == 0) {
+ embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
+ }
+ if (embd_path.size() > 0) {
+ if (load_embedding(embd_name, embd_path, bpe_tokens)) {
+ if (word_end != std::string::npos) {
+ str = str.substr(word_end);
+ } else {
+ str = "";
+ }
+ return true;
+ }
+ }
+ return false;
+ };
+
+ std::vector tokens;
+ std::vector weights;
+ std::vector class_token_mask;
+ int32_t class_idx = -1, tokens_acc = 0;
+ for (const auto& item : parsed_attention) {
+ std::vector class_token_index;
+ std::vector clean_input_ids;
+ const std::string& curr_text = item.first;
+ float curr_weight = item.second;
+ // printf(" %s: %f \n", curr_text.c_str(), curr_weight);
+ std::vector curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
+ int32_t clean_index = 0;
+ for (uint32_t i = 0; i < curr_tokens.size(); i++) {
+ int token_id = curr_tokens[i];
+ if (token_id == image_token)
+ class_token_index.push_back(clean_index - 1);
+ else {
+ clean_input_ids.push_back(token_id);
+ clean_index++;
+ }
+ }
+ // GGML_ASSERT(class_token_index.size() == 1); // PhotoMaker currently does not support multiple
+ // trigger words in a single prompt.
+ if (class_token_index.size() == 1) {
+ // Expand the class word token and corresponding mask
+ int class_token = clean_input_ids[class_token_index[0]];
+ class_idx = tokens_acc + class_token_index[0];
+ std::vector clean_input_ids_tmp;
+ for (uint32_t i = 0; i < class_token_index[0]; i++)
+ clean_input_ids_tmp.push_back(clean_input_ids[i]);
+ for (uint32_t i = 0; i < num_input_imgs; i++)
+ clean_input_ids_tmp.push_back(class_token);
+ for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++)
+ clean_input_ids_tmp.push_back(clean_input_ids[i]);
+ clean_input_ids.clear();
+ clean_input_ids = clean_input_ids_tmp;
+ }
+ tokens_acc += clean_index;
+ tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end());
+ weights.insert(weights.end(), clean_input_ids.size(), curr_weight);
+ }
+ tokens.insert(tokens.begin(), tokenizer.BOS_TOKEN_ID);
+ weights.insert(weights.begin(), 1.0);
+
+ tokenizer.pad_tokens(tokens, weights, max_length, padding);
+
+ for (uint32_t i = 0; i < tokens.size(); i++) {
+ if (class_idx + 1 <= i && i < class_idx + 1 + num_input_imgs)
+ class_token_mask.push_back(true);
+ else
+ class_token_mask.push_back(false);
+ }
+
+ // printf("[");
+ // for (int i = 0; i < tokens.size(); i++) {
+ // printf("%d, ", class_token_mask[i] ? 1 : 0);
+ // }
+ // printf("]\n");
+
+ // for (int i = 0; i < tokens.size(); i++) {
+ // std::cout << tokens[i] << ":" << weights[i] << ", ";
+ // }
+ // std::cout << std::endl;
+
+ return std::make_tuple(tokens, weights, class_token_mask);
+ }
+
+ std::pair, std::vector> tokenize(std::string text,
+ bool padding = false) {
+ return tokenize(text, text_model->model.n_token, padding);
+ }
+
+ std::pair, std::vector> tokenize(std::string text,
+ size_t max_length = 0,
+ bool padding = false) {
+ auto parsed_attention = parse_prompt_attention(text);
+
+ {
+ std::stringstream ss;
+ ss << "[";
+ for (const auto& item : parsed_attention) {
+ ss << "['" << item.first << "', " << item.second << "], ";
+ }
+ ss << "]";
+ LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
+ }
+
+ auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool {
+ size_t word_end = str.find(",");
+ std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
+ embd_name = trim(embd_name);
+ std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
+ if (embd_path.size() == 0) {
+ embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
+ }
+ if (embd_path.size() == 0) {
+ embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
+ }
+ if (embd_path.size() > 0) {
+ if (load_embedding(embd_name, embd_path, bpe_tokens)) {
+ if (word_end != std::string::npos) {
+ str = str.substr(word_end);
+ } else {
+ str = "";
+ }
+ return true;
+ }
+ }
+ return false;
+ };
+
+ std::vector tokens;
+ std::vector weights;
+ for (const auto& item : parsed_attention) {
+ const std::string& curr_text = item.first;
+ float curr_weight = item.second;
+ std::vector curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
+ tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
+ weights.insert(weights.end(), curr_tokens.size(), curr_weight);
+ }
+
+ tokenizer.pad_tokens(tokens, weights, max_length, padding);
+
+ // for (int i = 0; i < tokens.size(); i++) {
+ // std::cout << tokens[i] << ":" << weights[i] << ", ";
+ // }
+ // std::cout << std::endl;
+
+ return {tokens, weights};
+ }
+
+ SDCondition get_learned_condition_common(ggml_context* work_ctx,
+ int n_threads,
+ std::vector& tokens,
+ std::vector& weights,
+ int clip_skip,
+ int width,
+ int height,
+ int adm_in_channels = -1,
+ bool force_zero_embeddings = false) {
+ set_clip_skip(clip_skip);
+ int64_t t0 = ggml_time_ms();
+ struct ggml_tensor* hidden_states = NULL; // [N, n_token, hidden_size]
+ struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, hidden_size] or [n_token, hidden_size + hidden_size2]
+ struct ggml_tensor* chunk_hidden_states1 = NULL; // [n_token, hidden_size]
+ struct ggml_tensor* chunk_hidden_states2 = NULL; // [n_token, hidden_size2]
+ struct ggml_tensor* pooled = NULL;
+ std::vector hidden_states_vec;
+
+ size_t chunk_len = 77;
+ size_t chunk_count = tokens.size() / chunk_len;
+ for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
+ std::vector chunk_tokens(tokens.begin() + chunk_idx * chunk_len,
+ tokens.begin() + (chunk_idx + 1) * chunk_len);
+ std::vector chunk_weights(weights.begin() + chunk_idx * chunk_len,
+ weights.begin() + (chunk_idx + 1) * chunk_len);
+
+ auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
+ struct ggml_tensor* input_ids2 = NULL;
+ size_t max_token_idx = 0;
+ if (version == VERSION_XL) {
+ auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), tokenizer.EOS_TOKEN_ID);
+ if (it != chunk_tokens.end()) {
+ std::fill(std::next(it), chunk_tokens.end(), 0);
+ }
+
+ max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
+
+ input_ids2 = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
+
+ // for (int i = 0; i < chunk_tokens.size(); i++) {
+ // printf("%d ", chunk_tokens[i]);
+ // }
+ // printf("\n");
+ }
+
+ {
+ text_model->compute(n_threads,
+ input_ids,
+ num_custom_embeddings,
+ token_embed_custom.data(),
+ max_token_idx,
+ false,
+ &chunk_hidden_states1,
+ work_ctx);
+ if (version == VERSION_XL) {
+ text_model2->compute(n_threads,
+ input_ids2,
+ 0,
+ NULL,
+ max_token_idx,
+ false,
+ &chunk_hidden_states2, work_ctx);
+ // concat
+ chunk_hidden_states = ggml_tensor_concat(work_ctx, chunk_hidden_states1, chunk_hidden_states2, 0);
+
+ if (chunk_idx == 0) {
+ text_model2->compute(n_threads,
+ input_ids2,
+ 0,
+ NULL,
+ max_token_idx,
+ true,
+ &pooled,
+ work_ctx);
+ }
+ } else {
+ chunk_hidden_states = chunk_hidden_states1;
+ }
+ }
+
+ int64_t t1 = ggml_time_ms();
+ LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
+ ggml_tensor* result = ggml_dup_tensor(work_ctx, chunk_hidden_states);
+ {
+ float original_mean = ggml_tensor_mean(chunk_hidden_states);
+ for (int i2 = 0; i2 < chunk_hidden_states->ne[2]; i2++) {
+ for (int i1 = 0; i1 < chunk_hidden_states->ne[1]; i1++) {
+ for (int i0 = 0; i0 < chunk_hidden_states->ne[0]; i0++) {
+ float value = ggml_tensor_get_f32(chunk_hidden_states, i0, i1, i2);
+ value *= chunk_weights[i1];
+ ggml_tensor_set_f32(result, value, i0, i1, i2);
+ }
+ }
+ }
+ float new_mean = ggml_tensor_mean(result);
+ ggml_tensor_scale(result, (original_mean / new_mean));
+ }
+ if (force_zero_embeddings) {
+ float* vec = (float*)result->data;
+ for (int i = 0; i < ggml_nelements(result); i++) {
+ vec[i] = 0;
+ }
+ }
+ hidden_states_vec.insert(hidden_states_vec.end(), (float*)result->data, ((float*)result->data) + ggml_nelements(result));
+ }
+
+ hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
+ hidden_states = ggml_reshape_2d(work_ctx,
+ hidden_states,
+ chunk_hidden_states->ne[0],
+ ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
+
+ ggml_tensor* vec = NULL;
+ if (version == VERSION_XL) {
+ int out_dim = 256;
+ vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, adm_in_channels);
+ // [0:1280]
+ size_t offset = 0;
+ memcpy(vec->data, pooled->data, ggml_nbytes(pooled));
+ offset += ggml_nbytes(pooled);
+
+ // original_size_as_tuple
+ float orig_width = (float)width;
+ float orig_height = (float)height;
+ std::vector timesteps = {orig_height, orig_width};
+
+ ggml_tensor* embed_view = ggml_view_2d(work_ctx, vec, out_dim, 2, ggml_type_size(GGML_TYPE_F32) * out_dim, offset);
+ offset += ggml_nbytes(embed_view);
+ set_timestep_embedding(timesteps, embed_view, out_dim);
+ // print_ggml_tensor(ggml_reshape_1d(work_ctx, embed_view, out_dim * 2));
+ // crop_coords_top_left
+ float crop_coord_top = 0.f;
+ float crop_coord_left = 0.f;
+ timesteps = {crop_coord_top, crop_coord_left};
+ embed_view = ggml_view_2d(work_ctx, vec, out_dim, 2, ggml_type_size(GGML_TYPE_F32) * out_dim, offset);
+ offset += ggml_nbytes(embed_view);
+ set_timestep_embedding(timesteps, embed_view, out_dim);
+ // print_ggml_tensor(ggml_reshape_1d(work_ctx, embed_view, out_dim * 2));
+ // target_size_as_tuple
+ float target_width = (float)width;
+ float target_height = (float)height;
+ timesteps = {target_height, target_width};
+ embed_view = ggml_view_2d(work_ctx, vec, out_dim, 2, ggml_type_size(GGML_TYPE_F32) * out_dim, offset);
+ offset += ggml_nbytes(embed_view);
+ set_timestep_embedding(timesteps, embed_view, out_dim);
+ // print_ggml_tensor(ggml_reshape_1d(work_ctx, embed_view, out_dim * 2));
+ GGML_ASSERT(offset == ggml_nbytes(vec));
+ }
+ // print_ggml_tensor(result);
+ return SDCondition(hidden_states, vec, NULL);
+ }
+
+ std::tuple>
+ get_learned_condition_with_trigger(ggml_context* work_ctx,
+ int n_threads,
+ const std::string& text,
+ int clip_skip,
+ int width,
+ int height,
+ int num_input_imgs,
+ int adm_in_channels = -1,
+ bool force_zero_embeddings = false) {
+ auto image_tokens = convert_token_to_id(trigger_word);
+ // if(image_tokens.size() == 1){
+ // printf(" image token id is: %d \n", image_tokens[0]);
+ // }
+ GGML_ASSERT(image_tokens.size() == 1);
+ auto tokens_and_weights = tokenize_with_trigger_token(text,
+ num_input_imgs,
+ image_tokens[0],
+ true);
+ std::vector& tokens = std::get<0>(tokens_and_weights);
+ std::vector& weights = std::get<1>(tokens_and_weights);
+ std::vector& clsm = std::get<2>(tokens_and_weights);
+ // printf("tokens: \n");
+ // for(int i = 0; i < tokens.size(); ++i)
+ // printf("%d ", tokens[i]);
+ // printf("\n");
+ // printf("clsm: \n");
+ // for(int i = 0; i < clsm.size(); ++i)
+ // printf("%d ", clsm[i]?1:0);
+ // printf("\n");
+ auto cond = get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, force_zero_embeddings);
+ return std::make_tuple(cond, clsm);
+ }
+
+ std::string remove_trigger_from_prompt(ggml_context* work_ctx,
+ const std::string& prompt) {
+ auto image_tokens = convert_token_to_id(trigger_word);
+ GGML_ASSERT(image_tokens.size() == 1);
+ auto tokens_and_weights = tokenize(prompt, false);
+ std::vector& tokens = tokens_and_weights.first;
+ auto it = std::find(tokens.begin(), tokens.end(), image_tokens[0]);
+ GGML_ASSERT(it != tokens.end()); // prompt must have trigger word
+ tokens.erase(it);
+ return decode(tokens);
+ }
+
+ SDCondition get_learned_condition(ggml_context* work_ctx,
+ int n_threads,
+ const std::string& text,
+ int clip_skip,
+ int width,
+ int height,
+ int adm_in_channels = -1,
+ bool force_zero_embeddings = false) {
+ auto tokens_and_weights = tokenize(text, true);
+ std::vector& tokens = tokens_and_weights.first;
+ std::vector& weights = tokens_and_weights.second;
+ return get_learned_condition_common(work_ctx, n_threads, tokens, weights, clip_skip, width, height, adm_in_channels, force_zero_embeddings);
+ }
+};
+
+struct FrozenCLIPVisionEmbedder : public GGMLRunner {
+ CLIPVisionModelProjection vision_model;
+
+ FrozenCLIPVisionEmbedder(ggml_backend_t backend, ggml_type wtype)
+ : vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend, wtype) {
+ vision_model.init(params_ctx, wtype);
+ }
+
+ std::string get_desc() {
+ return "clip_vision";
+ }
+
+ void get_param_tensors(std::map& tensors) {
+ vision_model.get_param_tensors(tensors, "cond_stage_model.transformer");
+ }
+
+ struct ggml_cgraph* build_graph(struct ggml_tensor* pixel_values) {
+ struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
+
+ pixel_values = to_backend(pixel_values);
+
+ struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, pixel_values);
+
+ ggml_build_forward_expand(gf, hidden_states);
+
+ return gf;
+ }
+
+ void compute(const int n_threads,
+ ggml_tensor* pixel_values,
+ ggml_tensor** output,
+ ggml_context* output_ctx) {
+ auto get_graph = [&]() -> struct ggml_cgraph* {
+ return build_graph(pixel_values);
+ };
+ GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
+ }
+};
+
+struct SD3CLIPEmbedder : public Conditioner {
+ ggml_type wtype;
+ CLIPTokenizer clip_l_tokenizer;
+ CLIPTokenizer clip_g_tokenizer;
+ T5UniGramTokenizer t5_tokenizer;
+ std::shared_ptr clip_l;
+ std::shared_ptr clip_g;
+ std::shared_ptr t5;
+
+ SD3CLIPEmbedder(ggml_backend_t backend,
+ ggml_type wtype,
+ int clip_skip = -1)
+ : wtype(wtype), clip_g_tokenizer(0) {
+ if (clip_skip <= 0) {
+ clip_skip = 2;
+ }
+ clip_l = std::make_shared(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false);
+ clip_g = std::make_shared(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
+ t5 = std::make_shared(backend, wtype);
+ }
+
+ void set_clip_skip(int clip_skip) {
+ clip_l->set_clip_skip(clip_skip);
+ clip_g->set_clip_skip(clip_skip);
+ }
+
+ void get_param_tensors(std::map& tensors) {
+ clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
+ clip_g->get_param_tensors(tensors, "text_encoders.clip_g.transformer.text_model");
+ t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+ }
+
+ void alloc_params_buffer() {
+ clip_l->alloc_params_buffer();
+ clip_g->alloc_params_buffer();
+ t5->alloc_params_buffer();
+ }
+
+ void free_params_buffer() {
+ clip_l->free_params_buffer();
+ clip_g->free_params_buffer();
+ t5->free_params_buffer();
+ }
+
+ size_t get_params_buffer_size() {
+ size_t buffer_size = clip_l->get_params_buffer_size();
+ buffer_size += clip_g->get_params_buffer_size();
+ buffer_size += t5->get_params_buffer_size();
+ return buffer_size;
+ }
+
+ std::vector, std::vector>> tokenize(std::string text,
+ size_t max_length = 0,
+ bool padding = false) {
+ auto parsed_attention = parse_prompt_attention(text);
+
+ {
+ std::stringstream ss;
+ ss << "[";
+ for (const auto& item : parsed_attention) {
+ ss << "['" << item.first << "', " << item.second << "], ";
+ }
+ ss << "]";
+ LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
+ }
+
+ auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool {
+ return false;
+ };
+
+ std::vector clip_l_tokens;
+ std::vector clip_l_weights;
+ std::vector clip_g_tokens;
+ std::vector clip_g_weights;
+ std::vector t5_tokens;
+ std::vector t5_weights;
+ for (const auto& item : parsed_attention) {
+ const std::string& curr_text = item.first;
+ float curr_weight = item.second;
+
+ std::vector curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
+ clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+ clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
+
+ curr_tokens = clip_g_tokenizer.encode(curr_text, on_new_token_cb);
+ clip_g_tokens.insert(clip_g_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+ clip_g_weights.insert(clip_g_weights.end(), curr_tokens.size(), curr_weight);
+
+ curr_tokens = t5_tokenizer.Encode(curr_text, true);
+ t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+ t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
+ }
+
+ clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding);
+ clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding);
+ t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding);
+
+ // for (int i = 0; i < clip_l_tokens.size(); i++) {
+ // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
+ // }
+ // std::cout << std::endl;
+
+ // for (int i = 0; i < clip_g_tokens.size(); i++) {
+ // std::cout << clip_g_tokens[i] << ":" << clip_g_weights[i] << ", ";
+ // }
+ // std::cout << std::endl;
+
+ // for (int i = 0; i < t5_tokens.size(); i++) {
+ // std::cout << t5_tokens[i] << ":" << t5_weights[i] << ", ";
+ // }
+ // std::cout << std::endl;
+
+ return {{clip_l_tokens, clip_l_weights}, {clip_g_tokens, clip_g_weights}, {t5_tokens, t5_weights}};
+ }
+
+ SDCondition get_learned_condition_common(ggml_context* work_ctx,
+ int n_threads,
+ std::vector, std::vector>> token_and_weights,
+ int clip_skip,
+ bool force_zero_embeddings = false) {
+ set_clip_skip(clip_skip);
+ auto& clip_l_tokens = token_and_weights[0].first;
+ auto& clip_l_weights = token_and_weights[0].second;
+ auto& clip_g_tokens = token_and_weights[1].first;
+ auto& clip_g_weights = token_and_weights[1].second;
+ auto& t5_tokens = token_and_weights[2].first;
+ auto& t5_weights = token_and_weights[2].second;
+
+ int64_t t0 = ggml_time_ms();
+ struct ggml_tensor* hidden_states = NULL; // [N, n_token*2, 4096]
+ struct ggml_tensor* chunk_hidden_states = NULL; // [n_token*2, 4096]
+ struct ggml_tensor* chunk_hidden_states_l = NULL; // [n_token, hidden_size_l]
+ struct ggml_tensor* chunk_hidden_states_g = NULL; // [n_token, hidden_size_g]
+ struct ggml_tensor* chunk_hidden_states_t5 = NULL; // [n_token, hidden_size_t5]
+ struct ggml_tensor* pooled = NULL;
+ struct ggml_tensor* pooled_l = NULL; // [768,]
+ struct ggml_tensor* pooled_g = NULL; // [1280,]
+ std::vector hidden_states_vec;
+
+ size_t chunk_len = 77;
+ size_t chunk_count = clip_l_tokens.size() / chunk_len;
+ for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
+ // clip_l
+ {
+ std::vector chunk_tokens(clip_l_tokens.begin() + chunk_idx * chunk_len,
+ clip_l_tokens.begin() + (chunk_idx + 1) * chunk_len);
+ std::vector chunk_weights(clip_l_weights.begin() + chunk_idx * chunk_len,
+ clip_l_weights.begin() + (chunk_idx + 1) * chunk_len);
+
+ auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
+ size_t max_token_idx = 0;
+
+ clip_l->compute(n_threads,
+ input_ids,
+ 0,
+ NULL,
+ max_token_idx,
+ false,
+ &chunk_hidden_states_l,
+ work_ctx);
+ {
+ auto tensor = chunk_hidden_states_l;
+ float original_mean = ggml_tensor_mean(tensor);
+ for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+ for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+ for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+ float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
+ value *= chunk_weights[i1];
+ ggml_tensor_set_f32(tensor, value, i0, i1, i2);
+ }
+ }
+ }
+ float new_mean = ggml_tensor_mean(tensor);
+ ggml_tensor_scale(tensor, (original_mean / new_mean));
+ }
+
+ if (chunk_idx == 0) {
+ // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
+ // max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
+ // clip_l->compute(n_threads,
+ // input_ids,
+ // 0,
+ // NULL,
+ // max_token_idx,
+ // true,
+ // &pooled_l,
+ // work_ctx);
+
+ // clip_l.transformer.text_model.text_projection no in file, ignore
+ // TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection
+ pooled_l = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
+ ggml_set_f32(pooled_l, 0.f);
+ }
+ }
+
+ // clip_g
+ {
+ std::vector chunk_tokens(clip_g_tokens.begin() + chunk_idx * chunk_len,
+ clip_g_tokens.begin() + (chunk_idx + 1) * chunk_len);
+ std::vector chunk_weights(clip_g_weights.begin() + chunk_idx * chunk_len,
+ clip_g_weights.begin() + (chunk_idx + 1) * chunk_len);
+
+ auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
+ size_t max_token_idx = 0;
+
+ clip_g->compute(n_threads,
+ input_ids,
+ 0,
+ NULL,
+ max_token_idx,
+ false,
+ &chunk_hidden_states_g,
+ work_ctx);
+
+ {
+ auto tensor = chunk_hidden_states_g;
+ float original_mean = ggml_tensor_mean(tensor);
+ for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+ for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+ for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+ float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
+ value *= chunk_weights[i1];
+ ggml_tensor_set_f32(tensor, value, i0, i1, i2);
+ }
+ }
+ }
+ float new_mean = ggml_tensor_mean(tensor);
+ ggml_tensor_scale(tensor, (original_mean / new_mean));
+ }
+
+ if (chunk_idx == 0) {
+ // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
+ // max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
+ // clip_g->compute(n_threads,
+ // input_ids,
+ // 0,
+ // NULL,
+ // max_token_idx,
+ // true,
+ // &pooled_g,
+ // work_ctx);
+ // clip_l.transformer.text_model.text_projection no in file, ignore pooled_g too
+
+ // TODO: fix pooled_g
+ pooled_g = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1280);
+ ggml_set_f32(pooled_g, 0.f);
+ }
+ }
+
+ // t5
+ {
+ std::vector chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
+ t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
+ std::vector chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
+ t5_weights.begin() + (chunk_idx + 1) * chunk_len);
+
+ auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
+
+ t5->compute(n_threads,
+ input_ids,
+ &chunk_hidden_states_t5,
+ work_ctx);
+ {
+ auto tensor = chunk_hidden_states_t5;
+ float original_mean = ggml_tensor_mean(tensor);
+ for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+ for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+ for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+ float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
+ value *= chunk_weights[i1];
+ ggml_tensor_set_f32(tensor, value, i0, i1, i2);
+ }
+ }
+ }
+ float new_mean = ggml_tensor_mean(tensor);
+ ggml_tensor_scale(tensor, (original_mean / new_mean));
+ }
+ }
+
+ auto chunk_hidden_states_lg_pad = ggml_new_tensor_3d(work_ctx,
+ chunk_hidden_states_l->type,
+ 4096,
+ chunk_hidden_states_l->ne[1],
+ chunk_hidden_states_l->ne[2]); // [n_token, 4096]
+
+ for (int i2 = 0; i2 < chunk_hidden_states_lg_pad->ne[2]; i2++) {
+ for (int i1 = 0; i1 < chunk_hidden_states_lg_pad->ne[1]; i1++) {
+ for (int i0 = 0; i0 < chunk_hidden_states_lg_pad->ne[0]; i0++) {
+ float value = 0.f;
+ if (i0 < chunk_hidden_states_l->ne[0]) {
+ value = ggml_tensor_get_f32(chunk_hidden_states_l, i0, i1, i2);
+ } else if (i0 < chunk_hidden_states_l->ne[0] + chunk_hidden_states_g->ne[0]) {
+ value = ggml_tensor_get_f32(chunk_hidden_states_g, i0 - chunk_hidden_states_l->ne[0], i1, i2);
+ }
+ ggml_tensor_set_f32(chunk_hidden_states_lg_pad, value, i0, i1, i2);
+ }
+ }
+ }
+
+ chunk_hidden_states = ggml_tensor_concat(work_ctx, chunk_hidden_states_lg_pad, chunk_hidden_states_t5, 1); // [n_token*2, 4096]
+
+ if (chunk_idx == 0) {
+ pooled = ggml_tensor_concat(work_ctx, pooled_l, pooled_g, 0); // [768 + 1280]
+ }
+
+ int64_t t1 = ggml_time_ms();
+ LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
+ if (force_zero_embeddings) {
+ float* vec = (float*)chunk_hidden_states->data;
+ for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) {
+ vec[i] = 0;
+ }
+ }
+
+ hidden_states_vec.insert(hidden_states_vec.end(),
+ (float*)chunk_hidden_states->data,
+ ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
+ }
+
+ hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
+ hidden_states = ggml_reshape_2d(work_ctx,
+ hidden_states,
+ chunk_hidden_states->ne[0],
+ ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
+ return SDCondition(hidden_states, pooled, NULL);
+ }
+
+ SDCondition get_learned_condition(ggml_context* work_ctx,
+ int n_threads,
+ const std::string& text,
+ int clip_skip,
+ int width,
+ int height,
+ int adm_in_channels = -1,
+ bool force_zero_embeddings = false) {
+ auto tokens_and_weights = tokenize(text, 77, true);
+ return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
+ }
+
+ std::tuple> get_learned_condition_with_trigger(ggml_context* work_ctx,
+ int n_threads,
+ const std::string& text,
+ int clip_skip,
+ int width,
+ int height,
+ int num_input_imgs,
+ int adm_in_channels = -1,
+ bool force_zero_embeddings = false) {
+ GGML_ASSERT(0 && "Not implemented yet!");
+ }
+
+ std::string remove_trigger_from_prompt(ggml_context* work_ctx,
+ const std::string& prompt) {
+ GGML_ASSERT(0 && "Not implemented yet!");
+ }
+};
+
+#endif
\ No newline at end of file
diff --git a/control.hpp b/control.hpp
index c2523d801..3375e7306 100644
--- a/control.hpp
+++ b/control.hpp
@@ -306,7 +306,7 @@ class ControlNetBlock : public GGMLBlock {
}
};
-struct ControlNet : public GGMLModule {
+struct ControlNet : public GGMLRunner {
SDVersion version = VERSION_1_x;
ControlNetBlock control_net;
@@ -319,7 +319,7 @@ struct ControlNet : public GGMLModule {
ControlNet(ggml_backend_t backend,
ggml_type wtype,
SDVersion version = VERSION_1_x)
- : GGMLModule(backend, wtype), control_net(version) {
+ : GGMLRunner(backend, wtype), control_net(version) {
control_net.init(params_ctx, wtype);
}
@@ -426,7 +426,7 @@ struct ControlNet : public GGMLModule {
return build_graph(x, hint, timesteps, context, y);
};
- GGMLModule::compute(get_graph, n_threads, false, output, output_ctx);
+ GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
guided_hint_cached = true;
}
diff --git a/denoiser.hpp b/denoiser.hpp
index a55c43038..26f4c853d 100644
--- a/denoiser.hpp
+++ b/denoiser.hpp
@@ -10,50 +10,14 @@
#define TIMESTEPS 1000
struct SigmaSchedule {
- float alphas_cumprod[TIMESTEPS];
- float sigmas[TIMESTEPS];
- float log_sigmas[TIMESTEPS];
int version = 0;
+ typedef std::function t_to_sigma_t;
- virtual std::vector get_sigmas(uint32_t n) = 0;
-
- float sigma_to_t(float sigma) {
- float log_sigma = std::log(sigma);
- std::vector dists;
- dists.reserve(TIMESTEPS);
- for (float log_sigma_val : log_sigmas) {
- dists.push_back(log_sigma - log_sigma_val);
- }
-
- int low_idx = 0;
- for (size_t i = 0; i < TIMESTEPS; i++) {
- if (dists[i] >= 0) {
- low_idx++;
- }
- }
- low_idx = std::min(std::max(low_idx - 1, 0), TIMESTEPS - 2);
- int high_idx = low_idx + 1;
-
- float low = log_sigmas[low_idx];
- float high = log_sigmas[high_idx];
- float w = (low - log_sigma) / (low - high);
- w = std::max(0.f, std::min(1.f, w));
- float t = (1.0f - w) * low_idx + w * high_idx;
-
- return t;
- }
-
- float t_to_sigma(float t) {
- int low_idx = static_cast(std::floor(t));
- int high_idx = static_cast(std::ceil(t));
- float w = t - static_cast(low_idx);
- float log_sigma = (1.0f - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx];
- return std::exp(log_sigma);
- }
+ virtual std::vector get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) = 0;
};
struct DiscreteSchedule : SigmaSchedule {
- std::vector get_sigmas(uint32_t n) {
+ std::vector get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) {
std::vector result;
int t_max = TIMESTEPS - 1;
@@ -161,7 +125,7 @@ struct AYSSchedule : SigmaSchedule {
return results;
}
- std::vector get_sigmas(uint32_t len) {
+ std::vector get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) {
const std::vector noise_levels[] = {
/* SD1.5 */
{14.6146412293f, 6.4745760956f, 3.8636745985f, 2.6946151520f,
@@ -177,7 +141,7 @@ struct AYSSchedule : SigmaSchedule {
};
std::vector inputs;
- std::vector results(len + 1);
+ std::vector results(n + 1);
switch (version) {
case VERSION_2_x: /* fallthrough */
@@ -201,26 +165,24 @@ struct AYSSchedule : SigmaSchedule {
/* Stretches those pre-calculated reference levels out to the desired
* size using log-linear interpolation */
- if ((len + 1) != inputs.size()) {
- results = log_linear_interpolation(inputs, len + 1);
+ if ((n + 1) != inputs.size()) {
+ results = log_linear_interpolation(inputs, n + 1);
} else {
results = inputs;
}
/* Not sure if this is strictly neccessary */
- results[len] = 0.0f;
+ results[n] = 0.0f;
return results;
}
};
struct KarrasSchedule : SigmaSchedule {
- std::vector get_sigmas(uint32_t n) {
+ std::vector get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) {
// These *COULD* be function arguments here,
// but does anybody ever bother to touch them?
- float sigma_min = 0.1f;
- float sigma_max = 10.f;
- float rho = 7.f;
+ float rho = 7.f;
std::vector result(n + 1);
@@ -236,23 +198,89 @@ struct KarrasSchedule : SigmaSchedule {
};
struct Denoiser {
- std::shared_ptr schedule = std::make_shared();
- virtual std::vector get_scalings(float sigma) = 0;
+ std::shared_ptr schedule = std::make_shared();
+ virtual float sigma_min() = 0;
+ virtual float sigma_max() = 0;
+ virtual float sigma_to_t(float sigma) = 0;
+ virtual float t_to_sigma(float t) = 0;
+ virtual std::vector get_scalings(float sigma) = 0;
+ virtual ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) = 0;
+ virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) = 0;
+
+ virtual std::vector get_sigmas(uint32_t n) {
+ auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
+ return schedule->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
+ }
};
struct CompVisDenoiser : public Denoiser {
+ float sigmas[TIMESTEPS];
+ float log_sigmas[TIMESTEPS];
+
float sigma_data = 1.0f;
+ float sigma_min() {
+ return sigmas[0];
+ }
+
+ float sigma_max() {
+ return sigmas[TIMESTEPS - 1];
+ }
+
+ float sigma_to_t(float sigma) {
+ float log_sigma = std::log(sigma);
+ std::vector dists;
+ dists.reserve(TIMESTEPS);
+ for (float log_sigma_val : log_sigmas) {
+ dists.push_back(log_sigma - log_sigma_val);
+ }
+
+ int low_idx = 0;
+ for (size_t i = 0; i < TIMESTEPS; i++) {
+ if (dists[i] >= 0) {
+ low_idx++;
+ }
+ }
+ low_idx = std::min(std::max(low_idx - 1, 0), TIMESTEPS - 2);
+ int high_idx = low_idx + 1;
+
+ float low = log_sigmas[low_idx];
+ float high = log_sigmas[high_idx];
+ float w = (low - log_sigma) / (low - high);
+ w = std::max(0.f, std::min(1.f, w));
+ float t = (1.0f - w) * low_idx + w * high_idx;
+
+ return t;
+ }
+
+ float t_to_sigma(float t) {
+ int low_idx = static_cast(std::floor(t));
+ int high_idx = static_cast(std::ceil(t));
+ float w = t - static_cast(low_idx);
+ float log_sigma = (1.0f - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx];
+ return std::exp(log_sigma);
+ }
+
std::vector get_scalings(float sigma) {
- float c_out = -sigma;
- float c_in = 1.0f / std::sqrt(sigma * sigma + sigma_data * sigma_data);
- return {c_out, c_in};
+ float c_skip = 1.0f;
+ float c_out = -sigma;
+ float c_in = 1.0f / std::sqrt(sigma * sigma + sigma_data * sigma_data);
+ return {c_skip, c_out, c_in};
}
-};
-struct CompVisVDenoiser : public Denoiser {
- float sigma_data = 1.0f;
+ // this function will modify noise/latent
+ ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) {
+ ggml_tensor_scale(noise, sigma);
+ ggml_tensor_add(latent, noise);
+ return latent;
+ }
+
+ ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) {
+ return latent;
+ }
+};
+struct CompVisVDenoiser : public CompVisDenoiser {
std::vector get_scalings(float sigma) {
float c_skip = sigma_data * sigma_data / (sigma * sigma + sigma_data * sigma_data);
float c_out = -sigma * sigma_data / std::sqrt(sigma * sigma + sigma_data * sigma_data);
@@ -261,6 +289,67 @@ struct CompVisVDenoiser : public Denoiser {
}
};
+float time_snr_shift(float alpha, float t) {
+ if (alpha == 1.0f) {
+ return t;
+ }
+ return alpha * t / (1 + (alpha - 1) * t);
+}
+
+struct DiscreteFlowDenoiser : public Denoiser {
+ float sigmas[TIMESTEPS];
+ float shift = 3.0f;
+
+ float sigma_data = 1.0f;
+
+ DiscreteFlowDenoiser() {
+ set_parameters();
+ }
+
+ void set_parameters() {
+ for (int i = 1; i < TIMESTEPS + 1; i++) {
+ sigmas[i - 1] = t_to_sigma(i);
+ }
+ }
+
+ float sigma_min() {
+ return sigmas[0];
+ }
+
+ float sigma_max() {
+ return sigmas[TIMESTEPS - 1];
+ }
+
+ float sigma_to_t(float sigma) {
+ return sigma * 1000.f;
+ }
+
+ float t_to_sigma(float t) {
+ t = t + 1;
+ return time_snr_shift(shift, t / 1000.f);
+ }
+
+ std::vector get_scalings(float sigma) {
+ float c_skip = 1.0f;
+ float c_out = -sigma;
+ float c_in = 1.0f;
+ return {c_skip, c_out, c_in};
+ }
+
+ // this function will modify noise/latent
+ ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) {
+ ggml_tensor_scale(noise, sigma);
+ ggml_tensor_scale(latent, 1.0f - sigma);
+ ggml_tensor_add(latent, noise);
+ return latent;
+ }
+
+ ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) {
+ ggml_tensor_scale(latent, 1.0f / (1.0f - sigma));
+ return latent;
+ }
+};
+
typedef std::function denoise_cb_t;
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
diff --git a/diffusion_model.hpp b/diffusion_model.hpp
new file mode 100644
index 000000000..fb2849457
--- /dev/null
+++ b/diffusion_model.hpp
@@ -0,0 +1,123 @@
+#ifndef __DIFFUSION_MODEL_H__
+#define __DIFFUSION_MODEL_H__
+
+#include "mmdit.hpp"
+#include "unet.hpp"
+
+struct DiffusionModel {
+ virtual void compute(int n_threads,
+ struct ggml_tensor* x,
+ struct ggml_tensor* timesteps,
+ struct ggml_tensor* context,
+ struct ggml_tensor* c_concat,
+ struct ggml_tensor* y,
+ int num_video_frames = -1,
+ std::vector controls = {},
+ float control_strength = 0.f,
+ struct ggml_tensor** output = NULL,
+ struct ggml_context* output_ctx = NULL) = 0;
+ virtual void alloc_params_buffer() = 0;
+ virtual void free_params_buffer() = 0;
+ virtual void free_compute_buffer() = 0;
+ virtual void get_param_tensors(std::map& tensors) = 0;
+ virtual size_t get_params_buffer_size() = 0;
+ virtual int64_t get_adm_in_channels() = 0;
+};
+
+struct UNetModel : public DiffusionModel {
+ UNetModelRunner unet;
+
+ UNetModel(ggml_backend_t backend,
+ ggml_type wtype,
+ SDVersion version = VERSION_1_x)
+ : unet(backend, wtype, version) {
+ }
+
+ void alloc_params_buffer() {
+ unet.alloc_params_buffer();
+ }
+
+ void free_params_buffer() {
+ unet.free_params_buffer();
+ }
+
+ void free_compute_buffer() {
+ unet.free_compute_buffer();
+ }
+
+ void get_param_tensors(std::map& tensors) {
+ unet.get_param_tensors(tensors, "model.diffusion_model");
+ }
+
+ size_t get_params_buffer_size() {
+ return unet.get_params_buffer_size();
+ }
+
+ int64_t get_adm_in_channels() {
+ return unet.unet.adm_in_channels;
+ }
+
+ void compute(int n_threads,
+ struct ggml_tensor* x,
+ struct ggml_tensor* timesteps,
+ struct ggml_tensor* context,
+ struct ggml_tensor* c_concat,
+ struct ggml_tensor* y,
+ int num_video_frames = -1,
+ std::vector controls = {},
+ float control_strength = 0.f,
+ struct ggml_tensor** output = NULL,
+ struct ggml_context* output_ctx = NULL) {
+ return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx);
+ }
+};
+
+struct MMDiTModel : public DiffusionModel {
+ MMDiTRunner mmdit;
+
+ MMDiTModel(ggml_backend_t backend,
+ ggml_type wtype,
+ SDVersion version = VERSION_3_2B)
+ : mmdit(backend, wtype, version) {
+ }
+
+ void alloc_params_buffer() {
+ mmdit.alloc_params_buffer();
+ }
+
+ void free_params_buffer() {
+ mmdit.free_params_buffer();
+ }
+
+ void free_compute_buffer() {
+ mmdit.free_compute_buffer();
+ }
+
+ void get_param_tensors(std::map& tensors) {
+ mmdit.get_param_tensors(tensors, "model.diffusion_model");
+ }
+
+ size_t get_params_buffer_size() {
+ return mmdit.get_params_buffer_size();
+ }
+
+ int64_t get_adm_in_channels() {
+ return 768 + 1280;
+ }
+
+ void compute(int n_threads,
+ struct ggml_tensor* x,
+ struct ggml_tensor* timesteps,
+ struct ggml_tensor* context,
+ struct ggml_tensor* c_concat,
+ struct ggml_tensor* y,
+ int num_video_frames = -1,
+ std::vector controls = {},
+ float control_strength = 0.f,
+ struct ggml_tensor** output = NULL,
+ struct ggml_context* output_ctx = NULL) {
+ return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx);
+ }
+};
+
+#endif
\ No newline at end of file
diff --git a/esrgan.hpp b/esrgan.hpp
index 234de9ec6..33fcf09a4 100644
--- a/esrgan.hpp
+++ b/esrgan.hpp
@@ -137,14 +137,14 @@ class RRDBNet : public GGMLBlock {
}
};
-struct ESRGAN : public GGMLModule {
+struct ESRGAN : public GGMLRunner {
RRDBNet rrdb_net;
int scale = 4;
int tile_size = 128; // avoid cuda OOM for 4gb VRAM
ESRGAN(ggml_backend_t backend,
ggml_type wtype)
- : GGMLModule(backend, wtype) {
+ : GGMLRunner(backend, wtype) {
rrdb_net.init(params_ctx, wtype);
}
@@ -191,7 +191,7 @@ struct ESRGAN : public GGMLModule {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x);
};
- GGMLModule::compute(get_graph, n_threads, false, output, output_ctx);
+ GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
}
};
diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index 565af74a8..6675095b5 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -7,7 +7,9 @@
#include
// #include "preprocessing.hpp"
+#include "mmdit.hpp"
#include "stable-diffusion.h"
+#include "t5.hpp"
#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_STATIC
@@ -626,6 +628,7 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
int main(int argc, const char* argv[]) {
SDParams params;
+
parse_args(argc, argv, params);
sd_set_log_callback(sd_log_cb, (void*)¶ms);
diff --git a/ggml b/ggml
index 9d562d712..34a63747c 160000
--- a/ggml
+++ b/ggml
@@ -1 +1 @@
-Subproject commit 9d562d712513c77a4de44ad0428be62bc3f2a9cf
+Subproject commit 34a63747c4f0edf952267c3d0c1c1ef3dd9fe827
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index dbe93031d..1c82d10eb 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -75,6 +75,16 @@ __STATIC_INLINE__ float ggml_tensor_get_f32(const ggml_tensor* tensor, int l, in
return *(float*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]);
}
+__STATIC_INLINE__ int ggml_tensor_get_i32(const ggml_tensor* tensor, int l, int k = 0, int j = 0, int i = 0) {
+ if (tensor->buffer != NULL) {
+ float value;
+ ggml_backend_tensor_get(tensor, &value, i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0], sizeof(int));
+ return value;
+ }
+ GGML_ASSERT(tensor->nb[0] == sizeof(int));
+ return *(int*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]);
+}
+
__STATIC_INLINE__ ggml_fp16_t ggml_tensor_get_f16(const ggml_tensor* tensor, int l, int k = 0, int j = 0, int i = 0) {
GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
return *(ggml_fp16_t*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]);
@@ -126,6 +136,8 @@ __STATIC_INLINE__ void print_ggml_tensor(struct ggml_tensor* tensor, bool shape_
printf(" [%d, %d, %d, %d] = %f\n", i, j, k, l, ggml_tensor_get_f32(tensor, l, k, j, i));
} else if (tensor->type == GGML_TYPE_F16) {
printf(" [%d, %d, %d, %d] = %i\n", i, j, k, l, ggml_tensor_get_f16(tensor, l, k, j, i));
+ } else if (tensor->type == GGML_TYPE_I32) {
+ printf(" [%d, %d, %d, %d] = %i\n", i, j, k, l, ggml_tensor_get_i32(tensor, l, k, j, i));
}
fflush(stdout);
}
@@ -401,6 +413,42 @@ __STATIC_INLINE__ void ggml_tensor_clamp(struct ggml_tensor* src, float min, flo
}
}
+__STATIC_INLINE__ struct ggml_tensor* ggml_tensor_concat(struct ggml_context* ctx,
+ struct ggml_tensor* a,
+ struct ggml_tensor* b,
+ int dim) {
+ int64_t ne[GGML_MAX_DIMS];
+ for (int d = 0; d < GGML_MAX_DIMS; ++d) {
+ if (d == dim) {
+ ne[d] = a->ne[d] + b->ne[d];
+ continue;
+ }
+ GGML_ASSERT(a->ne[d] == b->ne[d]);
+ ne[d] = a->ne[d];
+ }
+ struct ggml_tensor* result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
+ int64_t o[4] = {0, 0, 0, 0};
+ o[dim] = a->ne[dim];
+
+ float v;
+ for (int i3 = 0; i3 < result->ne[3]; i3++) {
+ for (int i2 = 0; i2 < result->ne[2]; i2++) {
+ for (int i1 = 0; i1 < result->ne[1]; i1++) {
+ for (int i0 = 0; i0 < result->ne[0]; i0++) {
+ if (i0 < a->ne[0] && i1 < a->ne[1] && i2 < a->ne[2] && i3 < a->ne[3]) {
+ v = ggml_tensor_get_f32(a, i0, i1, i2, i3);
+ } else {
+ v = ggml_tensor_get_f32(b, i0 - o[0], i1 - o[1], i2 - o[2], i3 - o[3]);
+ }
+
+ ggml_tensor_set_f32(result, v, i0, i1, i2, i3);
+ }
+ }
+ }
+ }
+ return result;
+}
+
// convert values from [0, 1] to [-1, 1]
__STATIC_INLINE__ void ggml_tensor_scale_input(struct ggml_tensor* src) {
int64_t nelements = ggml_nelements(src);
@@ -605,6 +653,56 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
return kqv;
}
+// q: [N, L_q, C]
+// k: [N, L_k, C]
+// v: [N, L_k, C]
+// return: [N, L_q, C]
+__STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
+ struct ggml_tensor* q,
+ struct ggml_tensor* k,
+ struct ggml_tensor* v,
+ int64_t n_head,
+ struct ggml_tensor* mask = NULL,
+ bool diag_mask_inf = false) {
+ int64_t L_q = q->ne[1];
+ int64_t L_k = k->ne[1];
+ int64_t C = q->ne[0];
+ int64_t N = q->ne[2];
+
+ int64_t d_head = C / n_head;
+ float scale = (1.0f / sqrt((float)d_head));
+
+ q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N); // [N, L_q, n_head, d_head]
+ q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head]
+ q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N); // [N * n_head, L_q, d_head]
+
+ k = ggml_reshape_4d(ctx, k, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head]
+ k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, L_k, d_head]
+ k = ggml_reshape_3d(ctx, k, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head]
+
+ v = ggml_reshape_4d(ctx, v, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head]
+ v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k]
+ v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N); // [N * n_head, d_head, L_k]
+
+ auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k]
+ kq = ggml_scale_inplace(ctx, kq, scale);
+ if (mask) {
+ kq = ggml_add(ctx, kq, mask);
+ }
+ if (diag_mask_inf) {
+ kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
+ }
+ kq = ggml_soft_max_inplace(ctx, kq);
+
+ auto kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head]
+
+ kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N); // [N, n_head, L_q, d_head]
+ kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, L_q, n_head, d_head]
+ kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N); // [N, L_q, C]
+
+ return kqv;
+}
+
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_layer_norm(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* w,
@@ -764,7 +862,7 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
#define MAX_PARAMS_TENSOR_NUM 15360
#define MAX_GRAPH_SIZE 15360
-struct GGMLModule {
+struct GGMLRunner {
protected:
typedef std::function get_graph_cb_t;
@@ -852,12 +950,12 @@ struct GGMLModule {
public:
virtual std::string get_desc() = 0;
- GGMLModule(ggml_backend_t backend, ggml_type wtype = GGML_TYPE_F32)
+ GGMLRunner(ggml_backend_t backend, ggml_type wtype = GGML_TYPE_F32)
: backend(backend), wtype(wtype) {
alloc_params_ctx();
}
- virtual ~GGMLModule() {
+ virtual ~GGMLRunner() {
free_params_buffer();
free_compute_buffer();
free_params_ctx();
@@ -873,7 +971,9 @@ struct GGMLModule {
size_t num_tensors = ggml_tensor_num(params_ctx);
params_buffer = ggml_backend_alloc_ctx_tensors(params_ctx, backend);
if (params_buffer == NULL) {
- LOG_ERROR("%s alloc params backend buffer failed", get_desc().c_str());
+ LOG_ERROR("%s alloc params backend buffer failed, num_tensors = %i",
+ get_desc().c_str(),
+ num_tensors);
return false;
}
size_t params_buffer_size = ggml_backend_buffer_get_size(params_buffer);
@@ -1068,6 +1168,40 @@ class Linear : public UnaryBlock {
}
};
+class Embedding : public UnaryBlock {
+protected:
+ int64_t embedding_dim;
+ int64_t num_embeddings;
+
+ void init_params(struct ggml_context* ctx, ggml_type wtype) {
+ params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings);
+ }
+
+public:
+ Embedding(int64_t num_embeddings, int64_t embedding_dim)
+ : embedding_dim(embedding_dim),
+ num_embeddings(num_embeddings) {
+ }
+
+ struct ggml_tensor* forward(struct ggml_context* ctx,
+ struct ggml_tensor* input_ids) {
+ // input_ids: [N, n_token]
+ auto weight = params["weight"];
+
+ // There are issues with ggml batch inference, so we are expanding it here first.
+ // TODO: fix ggml batch inference
+ int64_t n = input_ids->ne[1];
+ input_ids = ggml_reshape_1d(ctx, input_ids, input_ids->ne[0] * input_ids->ne[1]);
+
+ input_ids = ggml_reshape_3d(ctx, input_ids, input_ids->ne[0], 1, input_ids->ne[1]);
+ auto embedding = ggml_get_rows(ctx, weight, input_ids);
+ embedding = ggml_reshape_3d(ctx, embedding, embedding->ne[0], embedding->ne[1] / n, n);
+
+ // [N, n_token, embedding_dim]
+ return embedding;
+ }
+};
+
class Conv2d : public UnaryBlock {
protected:
int64_t in_channels;
@@ -1241,53 +1375,44 @@ class MultiheadAttention : public GGMLBlock {
protected:
int64_t embed_dim;
int64_t n_head;
- bool bias;
- bool mask;
+ std::string q_proj_name;
+ std::string k_proj_name;
+ std::string v_proj_name;
+ std::string out_proj_name;
public:
MultiheadAttention(int64_t embed_dim,
int64_t n_head,
- bool bias = true)
+ bool qkv_proj_bias = true,
+ bool out_proj_bias = true,
+ std::string q_proj_name = "q_proj",
+ std::string k_proj_name = "k_proj",
+ std::string v_proj_name = "v_proj",
+ std::string out_proj_name = "out_proj")
: embed_dim(embed_dim),
n_head(n_head),
- bias(bias) {
- blocks["q_proj"] = std::shared_ptr(new Linear(embed_dim, embed_dim, bias));
- blocks["k_proj"] = std::shared_ptr(new Linear(embed_dim, embed_dim, bias));
- blocks["v_proj"] = std::shared_ptr(new Linear(embed_dim, embed_dim, bias));
- blocks["out_proj"] = std::shared_ptr(new Linear(embed_dim, embed_dim, bias));
+ q_proj_name(q_proj_name),
+ k_proj_name(k_proj_name),
+ v_proj_name(v_proj_name),
+ out_proj_name(out_proj_name) {
+ blocks[q_proj_name] = std::shared_ptr(new Linear(embed_dim, embed_dim, qkv_proj_bias));
+ blocks[k_proj_name] = std::shared_ptr(new Linear(embed_dim, embed_dim, qkv_proj_bias));
+ blocks[v_proj_name] = std::shared_ptr(new Linear(embed_dim, embed_dim, qkv_proj_bias));
+ blocks[out_proj_name] = std::shared_ptr(new Linear(embed_dim, embed_dim, out_proj_bias));
}
// x: [N, n_token, embed_dim]
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, bool mask = false) {
- auto q_proj = std::dynamic_pointer_cast(blocks["q_proj"]);
- auto k_proj = std::dynamic_pointer_cast(blocks["k_proj"]);
- auto v_proj = std::dynamic_pointer_cast(blocks["v_proj"]);
- auto out_proj = std::dynamic_pointer_cast(blocks["out_proj"]);
-
- int64_t N = x->ne[2];
- int64_t n_token = x->ne[1];
- int64_t d_head = embed_dim / n_head;
+ auto q_proj = std::dynamic_pointer_cast(blocks[q_proj_name]);
+ auto k_proj = std::dynamic_pointer_cast(blocks[k_proj_name]);
+ auto v_proj = std::dynamic_pointer_cast(blocks[v_proj_name]);
+ auto out_proj = std::dynamic_pointer_cast(blocks[out_proj_name]);
struct ggml_tensor* q = q_proj->forward(ctx, x);
- q = ggml_reshape_4d(ctx, q, d_head, n_head, n_token, N); // [N, n_token, n_head, d_head]
- q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, n_token, d_head]
- q = ggml_reshape_3d(ctx, q, d_head, n_token, n_head * N); // [N * n_head, n_token, d_head]
-
struct ggml_tensor* k = k_proj->forward(ctx, x);
- k = ggml_reshape_4d(ctx, k, d_head, n_head, n_token, N); // [N, n_token, n_head, d_head]
- k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, n_token, d_head]
- k = ggml_reshape_3d(ctx, k, d_head, n_token, n_head * N); // [N * n_head, n_token, d_head]
-
struct ggml_tensor* v = v_proj->forward(ctx, x);
- v = ggml_reshape_4d(ctx, v, d_head, n_head, n_token, N); // [N, n_token, n_head, d_head]
- v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, n_token]
- v = ggml_reshape_3d(ctx, v, n_token, d_head, n_head * N); // [N * n_head, d_head, n_token]
-
- struct ggml_tensor* kqv = ggml_nn_attention(ctx, q, k, v, mask); // [N * n_head, n_token, d_head]
- kqv = ggml_reshape_4d(ctx, kqv, d_head, n_token, n_head, N);
- kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, n_token, n_head, d_head]
- x = ggml_reshape_3d(ctx, kqv, d_head * n_head, n_token, N); // [N, n_token, d_head * n_head]
+ x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, mask); // [N, n_token, embed_dim]
x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
return x;
diff --git a/lora.hpp b/lora.hpp
index 9b5cfe22b..edee74cae 100644
--- a/lora.hpp
+++ b/lora.hpp
@@ -5,7 +5,7 @@
#define LORA_GRAPH_SIZE 10240
-struct LoraModel : public GGMLModule {
+struct LoraModel : public GGMLRunner {
float multiplier = 1.0f;
std::map lora_tensors;
std::string file_path;
@@ -17,7 +17,7 @@ struct LoraModel : public GGMLModule {
ggml_type wtype,
const std::string& file_path = "",
const std::string& prefix = "")
- : file_path(file_path), GGMLModule(backend, wtype) {
+ : file_path(file_path), GGMLRunner(backend, wtype) {
if (!model_loader.init_from_file(file_path, prefix)) {
load_failed = true;
}
@@ -182,7 +182,7 @@ struct LoraModel : public GGMLModule {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_lora_graph(model_tensors);
};
- GGMLModule::compute(get_graph, n_threads, true);
+ GGMLRunner::compute(get_graph, n_threads, true);
}
};
diff --git a/mmdit.hpp b/mmdit.hpp
new file mode 100644
index 000000000..7d7b22d9a
--- /dev/null
+++ b/mmdit.hpp
@@ -0,0 +1,795 @@
+#ifndef __MMDIT_HPP__
+#define __MMDIT_HPP__
+
+#include "ggml_extend.hpp"
+#include "model.h"
+
+#define MMDIT_GRAPH_SIZE 10240
+
+struct Mlp : public GGMLBlock {
+public:
+ Mlp(int64_t in_features,
+ int64_t hidden_features = -1,
+ int64_t out_features = -1,
+ bool bias = true) {
+ // act_layer is always lambda: nn.GELU(approximate="tanh")
+ // norm_layer is always None
+ // use_conv is always False
+ if (hidden_features == -1) {
+ hidden_features = in_features;
+ }
+ if (out_features == -1) {
+ out_features = in_features;
+ }
+ blocks["fc1"] = std::shared_ptr(new Linear(in_features, hidden_features, bias));
+ blocks["fc2"] = std::shared_ptr(new Linear(hidden_features, out_features, bias));
+ }
+
+ struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+ // x: [N, n_token, in_features]
+ auto fc1 = std::dynamic_pointer_cast(blocks["fc1"]);
+ auto fc2 = std::dynamic_pointer_cast(blocks["fc2"]);
+
+ x = fc1->forward(ctx, x);
+ x = ggml_gelu_inplace(ctx, x);
+ x = fc2->forward(ctx, x);
+ return x;
+ }
+};
+
+struct PatchEmbed : public GGMLBlock {
+ // 2D Image to Patch Embedding
+protected:
+ bool flatten;
+ bool dynamic_img_pad;
+ int patch_size;
+
+public:
+ PatchEmbed(int64_t img_size = 224,
+ int patch_size = 16,
+ int64_t in_chans = 3,
+ int64_t embed_dim = 1536,
+ bool bias = true,
+ bool flatten = true,
+ bool dynamic_img_pad = true)
+ : patch_size(patch_size),
+ flatten(flatten),
+ dynamic_img_pad(dynamic_img_pad) {
+ // img_size is always None
+ // patch_size is always 2
+ // in_chans is always 16
+ // norm_layer is always False
+ // strict_img_size is always true, but not used
+
+ blocks["proj"] = std::shared_ptr(new Conv2d(in_chans,
+ embed_dim,
+ {patch_size, patch_size},
+ {patch_size, patch_size},
+ {0, 0},
+ {1, 1},
+ bias));
+ }
+
+ struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+ // x: [N, C, H, W]
+ // return: [N, H*W, embed_dim]
+ auto proj = std::dynamic_pointer_cast(blocks["proj"]);
+
+ if (dynamic_img_pad) {
+ int64_t W = x->ne[0];
+ int64_t H = x->ne[1];
+ int pad_h = (patch_size - H % patch_size) % patch_size;
+ int pad_w = (patch_size - W % patch_size) % patch_size;
+ x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // TODO: reflect pad mode
+ }
+ x = proj->forward(ctx, x);
+
+ if (flatten) {
+ x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
+ x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3));
+ }
+ return x;
+ }
+};
+
+struct TimestepEmbedder : public GGMLBlock {
+ // Embeds scalar timesteps into vector representations.
+protected:
+ int64_t frequency_embedding_size;
+
+public:
+ TimestepEmbedder(int64_t hidden_size,
+ int64_t frequency_embedding_size = 256)
+ : frequency_embedding_size(frequency_embedding_size) {
+ blocks["mlp.0"] = std::shared_ptr(new Linear(frequency_embedding_size, hidden_size));
+ blocks["mlp.2"] = std::shared_ptr(new Linear(hidden_size, hidden_size));
+ }
+
+ struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* t) {
+ // t: [N, ]
+ // return: [N, hidden_size]
+ auto mlp_0 = std::dynamic_pointer_cast(blocks["mlp.0"]);
+ auto mlp_2 = std::dynamic_pointer_cast(blocks["mlp.2"]);
+
+ auto t_freq = ggml_nn_timestep_embedding(ctx, t, frequency_embedding_size); // [N, frequency_embedding_size]
+
+ auto t_emb = mlp_0->forward(ctx, t_freq);
+ t_emb = ggml_silu_inplace(ctx, t_emb);
+ t_emb = mlp_2->forward(ctx, t_emb);
+ return t_emb;
+ }
+};
+
+struct VectorEmbedder : public GGMLBlock {
+ // Embeds a flat vector of dimension input_dim
+public:
+ VectorEmbedder(int64_t input_dim,
+ int64_t hidden_size) {
+ blocks["mlp.0"] = std::shared_ptr(new Linear(input_dim, hidden_size));
+ blocks["mlp.2"] = std::shared_ptr(new Linear(hidden_size, hidden_size));
+ }
+
+ struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+ // x: [N, input_dim]
+ // return: [N, hidden_size]
+ auto mlp_0 = std::dynamic_pointer_cast(blocks["mlp.0"]);
+ auto mlp_2 = std::dynamic_pointer_cast(blocks["mlp.2"]);
+
+ x = mlp_0->forward(ctx, x);
+ x = ggml_silu_inplace(ctx, x);
+ x = mlp_2->forward(ctx, x);
+ return x;
+ }
+};
+
+__STATIC_INLINE__ std::vector split_qkv(struct ggml_context* ctx,
+ struct ggml_tensor* qkv) {
+ // qkv: [N, L, 3*C]
+ // return: ([N, L, C], [N, L, C], [N, L, C])
+ qkv = ggml_reshape_4d(ctx, qkv, qkv->ne[0] / 3, 3, qkv->ne[1], qkv->ne[2]); // [N, L, 3, C]
+ qkv = ggml_cont(ctx, ggml_permute(ctx, qkv, 0, 3, 1, 2)); // [3, N, L, C]
+
+ int64_t offset = qkv->nb[2] * qkv->ne[2];
+ auto q = ggml_view_3d(ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], qkv->nb[1], qkv->nb[2], offset * 0); // [N, L, C]
+ auto k = ggml_view_3d(ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], qkv->nb[1], qkv->nb[2], offset * 1); // [N, L, C]
+ auto v = ggml_view_3d(ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], qkv->nb[1], qkv->nb[2], offset * 2); // [N, L, C]
+ return {q, k, v};
+}
+
+class SelfAttention : public GGMLBlock {
+public:
+ int64_t num_heads;
+ bool pre_only;
+
+public:
+ SelfAttention(int64_t dim,
+ int64_t num_heads = 8,
+ bool qkv_bias = false,
+ bool pre_only = false)
+ : num_heads(num_heads), pre_only(pre_only) {
+ // qk_norm is always None
+ blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias));
+ if (!pre_only) {
+ blocks["proj"] = std::shared_ptr(new Linear(dim, dim));
+ }
+ }
+
+ std::vector pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
+ auto qkv_proj = std::dynamic_pointer_cast(blocks["qkv"]);
+
+ auto qkv = qkv_proj->forward(ctx, x);
+ return split_qkv(ctx, qkv);
+ }
+
+ struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) {
+ GGML_ASSERT(!pre_only);
+
+ auto proj = std::dynamic_pointer_cast(blocks["proj"]);
+
+ x = proj->forward(ctx, x); // [N, n_token, dim]
+ return x;
+ }
+
+ // x: [N, n_token, dim]
+ struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+ auto qkv = pre_attention(ctx, x);
+ x = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
+ x = post_attention(ctx, x); // [N, n_token, dim]
+ return x;
+ }
+};
+
+__STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx,
+ struct ggml_tensor* x,
+ struct ggml_tensor* shift,
+ struct ggml_tensor* scale) {
+ // x: [N, L, C]
+ // scale: [N, C]
+ // shift: [N, C]
+ scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C]
+ shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]); // [N, 1, C]
+ x = ggml_add(ctx, x, ggml_mul(ctx, x, scale));
+ x = ggml_add(ctx, x, shift);
+ return x;
+}
+
+struct DismantledBlock : public GGMLBlock {
+ // A DiT block with gated adaptive layer norm (adaLN) conditioning.
+public:
+ int64_t num_heads;
+ bool pre_only;
+
+public:
+ DismantledBlock(int64_t hidden_size,
+ int64_t num_heads,
+ float mlp_ratio = 4.0,
+ bool qkv_bias = false,
+ bool pre_only = false)
+ : num_heads(num_heads), pre_only(pre_only) {
+ // rmsnorm is always Flase
+ // scale_mod_only is always Flase
+ // swiglu is always Flase
+ // qk_norm is always Flase
+ blocks["norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false));
+ blocks["attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, pre_only));
+
+ if (!pre_only) {
+ blocks["norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false));
+ int64_t mlp_hidden_dim = (int64_t)(hidden_size * mlp_ratio);
+ blocks["mlp"] = std::shared_ptr(new Mlp(hidden_size, mlp_hidden_dim));
+ }
+
+ int64_t n_mods = 6;
+ if (pre_only) {
+ n_mods = 2;
+ }
+ blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, n_mods * hidden_size));
+ }
+
+ std::pair, std::vector> pre_attention(struct ggml_context* ctx,
+ struct ggml_tensor* x,
+ struct ggml_tensor* c) {
+ // x: [N, n_token, hidden_size]
+ // c: [N, hidden_size]
+ auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]);
+ auto attn = std::dynamic_pointer_cast(blocks["attn"]);
+ auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]);
+
+ int64_t n_mods = 6;
+ if (pre_only) {
+ n_mods = 2;
+ }
+ auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, n_mods * hidden_size]
+ m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size]
+ m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size]
+
+ int64_t offset = m->nb[1] * m->ne[1];
+ auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
+ auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
+ if (!pre_only) {
+ auto gate_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size]
+ auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size]
+ auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size]
+ auto gate_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size]
+
+ auto attn_in = modulate(ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
+
+ auto qkv = attn->pre_attention(ctx, attn_in);
+
+ return {qkv, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp}};
+ } else {
+ auto attn_in = modulate(ctx, norm1->forward(ctx, x), shift_msa, scale_msa);
+ auto qkv = attn->pre_attention(ctx, attn_in);
+
+ return {qkv, {NULL, NULL, NULL, NULL, NULL}};
+ }
+ }
+
+ struct ggml_tensor* post_attention(struct ggml_context* ctx,
+ struct ggml_tensor* attn_out,
+ struct ggml_tensor* x,
+ struct ggml_tensor* gate_msa,
+ struct ggml_tensor* shift_mlp,
+ struct ggml_tensor* scale_mlp,
+ struct ggml_tensor* gate_mlp) {
+ // attn_out: [N, n_token, hidden_size]
+ // x: [N, n_token, hidden_size]
+ // gate_msa: [N, hidden_size]
+ // shift_mlp: [N, hidden_size]
+ // scale_mlp: [N, hidden_size]
+ // gate_mlp: [N, hidden_size]
+ // return: [N, n_token, hidden_size]
+ GGML_ASSERT(!pre_only);
+
+ auto attn = std::dynamic_pointer_cast(blocks["attn"]);
+ auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]);
+ auto mlp = std::dynamic_pointer_cast(blocks["mlp"]);
+
+ gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size]
+ gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size]
+
+ attn_out = attn->post_attention(ctx, attn_out);
+
+ x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa));
+ auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp));
+ x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp));
+
+ return x;
+ }
+
+ struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* c) {
+ // x: [N, n_token, hidden_size]
+ // c: [N, hidden_size]
+ // return: [N, n_token, hidden_size]
+
+ auto attn = std::dynamic_pointer_cast(blocks["attn"]);
+
+ auto qkv_intermediates = pre_attention(ctx, x, c);
+ auto qkv = qkv_intermediates.first;
+ auto intermediates = qkv_intermediates.second;
+
+ auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
+ x = post_attention(ctx,
+ attn_out,
+ intermediates[0],
+ intermediates[1],
+ intermediates[2],
+ intermediates[3],
+ intermediates[4]);
+ return x; // [N, n_token, dim]
+ }
+};
+
+__STATIC_INLINE__ std::pair block_mixing(struct ggml_context* ctx,
+ struct ggml_tensor* context,
+ struct ggml_tensor* x,
+ struct ggml_tensor* c,
+ std::shared_ptr context_block,
+ std::shared_ptr x_block) {
+ // context: [N, n_context, hidden_size]
+ // x: [N, n_token, hidden_size]
+ // c: [N, hidden_size]
+ auto context_qkv_intermediates = context_block->pre_attention(ctx, context, c);
+ auto context_qkv = context_qkv_intermediates.first;
+ auto context_intermediates = context_qkv_intermediates.second;
+
+ auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c);
+ auto x_qkv = x_qkv_intermediates.first;
+ auto x_intermediates = x_qkv_intermediates.second;
+
+ std::vector qkv;
+ for (int i = 0; i < 3; i++) {
+ qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
+ }
+
+ auto attn = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], x_block->num_heads); // [N, n_context + n_token, hidden_size]
+ attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
+ auto context_attn = ggml_view_3d(ctx,
+ attn,
+ attn->ne[0],
+ attn->ne[1],
+ context->ne[1],
+ attn->nb[1],
+ attn->nb[2],
+ 0); // [n_context, N, hidden_size]
+ context_attn = ggml_cont(ctx, ggml_permute(ctx, context_attn, 0, 2, 1, 3)); // [N, n_context, hidden_size]
+ auto x_attn = ggml_view_3d(ctx,
+ attn,
+ attn->ne[0],
+ attn->ne[1],
+ x->ne[1],
+ attn->nb[1],
+ attn->nb[2],
+ attn->nb[2] * context->ne[1]); // [n_token, N, hidden_size]
+ x_attn = ggml_cont(ctx, ggml_permute(ctx, x_attn, 0, 2, 1, 3)); // [N, n_token, hidden_size]
+
+ if (!context_block->pre_only) {
+ context = context_block->post_attention(ctx,
+ context_attn,
+ context_intermediates[0],
+ context_intermediates[1],
+ context_intermediates[2],
+ context_intermediates[3],
+ context_intermediates[4]);
+ } else {
+ context = NULL;
+ }
+
+ x = x_block->post_attention(ctx,
+ x_attn,
+ x_intermediates[0],
+ x_intermediates[1],
+ x_intermediates[2],
+ x_intermediates[3],
+ x_intermediates[4]);
+
+ return {context, x};
+}
+
+struct JointBlock : public GGMLBlock {
+public:
+ JointBlock(int64_t hidden_size,
+ int64_t num_heads,
+ float mlp_ratio = 4.0,
+ bool qkv_bias = false,
+ bool pre_only = false) {
+ // qk_norm is always Flase
+ blocks["context_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qkv_bias, pre_only));
+ blocks["x_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qkv_bias, false));
+ }
+
+ std::pair forward(struct ggml_context* ctx,
+ struct ggml_tensor* context,
+ struct ggml_tensor* x,
+ struct ggml_tensor* c) {
+ auto context_block = std::dynamic_pointer_cast(blocks["context_block"]);
+ auto x_block = std::dynamic_pointer_cast(blocks["x_block"]);
+
+ return block_mixing(ctx, context, x, c, context_block, x_block);
+ }
+};
+
+struct FinalLayer : public GGMLBlock {
+ // The final layer of DiT.
+public:
+ FinalLayer(int64_t hidden_size,
+ int64_t patch_size,
+ int64_t out_channels) {
+ // total_out_channels is always None
+ blocks["norm_final"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false));
+ blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels));
+ blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, 2 * hidden_size));
+ }
+
+ struct ggml_tensor* forward(struct ggml_context* ctx,
+ struct ggml_tensor* x,
+ struct ggml_tensor* c) {
+ // x: [N, n_token, hidden_size]
+ // c: [N, hidden_size]
+ // return: [N, n_token, patch_size * patch_size * out_channels]
+ auto norm_final = std::dynamic_pointer_cast(blocks["norm_final"]);
+ auto linear = std::dynamic_pointer_cast(blocks["linear"]);
+ auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]);
+
+ auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size]
+ m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
+ m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
+
+ int64_t offset = m->nb[1] * m->ne[1];
+ auto shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
+ auto scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
+
+ x = modulate(ctx, norm_final->forward(ctx, x), shift, scale);
+ x = linear->forward(ctx, x);
+
+ return x;
+ }
+};
+
+struct MMDiT : public GGMLBlock {
+ // Diffusion model with a Transformer backbone.
+protected:
+ SDVersion version = VERSION_3_2B;
+ int64_t input_size = -1;
+ int64_t patch_size = 2;
+ int64_t in_channels = 16;
+ int64_t depth = 24;
+ float mlp_ratio = 4.0f;
+ int64_t adm_in_channels = 2048;
+ int64_t out_channels = 16;
+ int64_t pos_embed_max_size = 192;
+ int64_t num_patchs = 36864; // 192 * 192
+ int64_t context_size = 4096;
+ int64_t hidden_size;
+
+ void init_params(struct ggml_context* ctx, ggml_type wtype) {
+ params["pos_embed"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hidden_size, num_patchs, 1);
+ }
+
+public:
+ MMDiT(SDVersion version = VERSION_3_2B)
+ : version(version) {
+ // input_size is always None
+ // learn_sigma is always False
+ // register_length is alwalys 0
+ // rmsnorm is alwalys False
+ // scale_mod_only is alwalys False
+ // swiglu is alwalys False
+ // qk_norm is always None
+ // qkv_bias is always True
+ // context_processor_layers is always None
+ // pos_embed_scaling_factor is not used
+ // pos_embed_offset is not used
+ // context_embedder_config is always {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}}
+ if (version == VERSION_3_2B) {
+ input_size = -1;
+ patch_size = 2;
+ in_channels = 16;
+ depth = 24;
+ mlp_ratio = 4.0f;
+ adm_in_channels = 2048;
+ out_channels = 16;
+ pos_embed_max_size = 192;
+ num_patchs = 36864; // 192 * 192
+ context_size = 4096;
+ }
+ int64_t default_out_channels = in_channels;
+ hidden_size = 64 * depth;
+ int64_t num_heads = depth;
+
+ blocks["x_embedder"] = std::shared_ptr(new PatchEmbed(input_size, patch_size, in_channels, hidden_size, true));
+ blocks["t_embedder"] = std::shared_ptr(new TimestepEmbedder(hidden_size));
+
+ if (adm_in_channels != -1) {
+ blocks["y_embedder"] = std::shared_ptr(new VectorEmbedder(adm_in_channels, hidden_size));
+ }
+
+ blocks["context_embedder"] = std::shared_ptr(new Linear(4096, 1536));
+
+ for (int i = 0; i < depth; i++) {
+ blocks["joint_blocks." + std::to_string(i)] = std::shared_ptr(new JointBlock(hidden_size,
+ num_heads,
+ mlp_ratio,
+ true,
+ i == depth - 1));
+ }
+
+ blocks["final_layer"] = std::shared_ptr(new FinalLayer(hidden_size, patch_size, out_channels));
+ }
+
+ struct ggml_tensor* cropped_pos_embed(struct ggml_context* ctx,
+ int64_t h,
+ int64_t w) {
+ auto pos_embed = params["pos_embed"];
+
+ h = (h + 1) / patch_size;
+ w = (w + 1) / patch_size;
+
+ GGML_ASSERT(h <= pos_embed_max_size && h > 0);
+ GGML_ASSERT(w <= pos_embed_max_size && w > 0);
+
+ int64_t top = (pos_embed_max_size - h) / 2;
+ int64_t left = (pos_embed_max_size - w) / 2;
+
+ auto spatial_pos_embed = ggml_reshape_3d(ctx, pos_embed, hidden_size, pos_embed_max_size, pos_embed_max_size);
+
+ // spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
+ spatial_pos_embed = ggml_view_3d(ctx,
+ spatial_pos_embed,
+ hidden_size,
+ pos_embed_max_size,
+ h,
+ spatial_pos_embed->nb[1],
+ spatial_pos_embed->nb[2],
+ spatial_pos_embed->nb[2] * top); // [h, pos_embed_max_size, hidden_size]
+ spatial_pos_embed = ggml_cont(ctx, ggml_permute(ctx, spatial_pos_embed, 0, 2, 1, 3)); // [pos_embed_max_size, h, hidden_size]
+ spatial_pos_embed = ggml_view_3d(ctx,
+ spatial_pos_embed,
+ hidden_size,
+ h,
+ w,
+ spatial_pos_embed->nb[1],
+ spatial_pos_embed->nb[2],
+ spatial_pos_embed->nb[2] * left); // [w, h, hidden_size]
+ spatial_pos_embed = ggml_cont(ctx, ggml_permute(ctx, spatial_pos_embed, 0, 2, 1, 3)); // [h, w, hidden_size]
+ spatial_pos_embed = ggml_reshape_3d(ctx, spatial_pos_embed, hidden_size, h * w, 1); // [1, h*w, hidden_size]
+ return spatial_pos_embed;
+ }
+
+ struct ggml_tensor* unpatchify(struct ggml_context* ctx,
+ struct ggml_tensor* x,
+ int64_t h,
+ int64_t w) {
+ // x: [N, H*W, patch_size * patch_size * C]
+ // return: [N, C, H, W]
+ int64_t n = x->ne[2];
+ int64_t c = out_channels;
+ int64_t p = patch_size;
+ h = (h + 1) / p;
+ w = (w + 1) / p;
+
+ GGML_ASSERT(h * w == x->ne[1]);
+
+ x = ggml_reshape_4d(ctx, x, c, p * p, w * h, n); // [N, H*W, P*P, C]
+ x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3)); // [N, C, H*W, P*P]
+ x = ggml_reshape_4d(ctx, x, p, p, w, h * c * n); // [N*C*H, W, P, P]
+ x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*H, P, W, P]
+ x = ggml_reshape_4d(ctx, x, p * w, p * h, c, n); // [N, C, H*P, W*P]
+ return x;
+ }
+
+ struct ggml_tensor* forward_core_with_concat(struct ggml_context* ctx,
+ struct ggml_tensor* x,
+ struct ggml_tensor* c_mod,
+ struct ggml_tensor* context) {
+ // x: [N, H*W, hidden_size]
+ // context: [N, n_context, d_context]
+ // c: [N, hidden_size]
+ // return: [N, N*W, patch_size * patch_size * out_channels]
+ auto final_layer = std::dynamic_pointer_cast(blocks["final_layer"]);
+
+ for (int i = 0; i < depth; i++) {
+ auto block = std::dynamic_pointer_cast