diff --git a/conditioner.hpp b/conditioner.hpp index 6e9acdb19..f48f4f493 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -747,7 +747,7 @@ struct SD3CLIPEmbedder : public Conditioner { clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding); clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding); - t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding); + t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding); // for (int i = 0; i < clip_l_tokens.size(); i++) { // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", "; @@ -1004,6 +1004,7 @@ struct FluxCLIPEmbedder : public Conditioner { T5UniGramTokenizer t5_tokenizer; std::shared_ptr clip_l; std::shared_ptr t5; + size_t chunk_len = 256; FluxCLIPEmbedder(ggml_backend_t backend, std::map& tensor_types, @@ -1077,7 +1078,7 @@ struct FluxCLIPEmbedder : public Conditioner { } clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding); - t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding); + t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding); // for (int i = 0; i < clip_l_tokens.size(); i++) { // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", "; @@ -1109,7 +1110,6 @@ struct FluxCLIPEmbedder : public Conditioner { struct ggml_tensor* pooled = NULL; // [768,] std::vector hidden_states_vec; - size_t chunk_len = 256; size_t chunk_count = t5_tokens.size() / chunk_len; for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) { // clip_l @@ -1196,7 +1196,226 @@ struct FluxCLIPEmbedder : public Conditioner { int height, int adm_in_channels = -1, bool force_zero_embeddings = false) { - auto tokens_and_weights = tokenize(text, 256, true); + auto tokens_and_weights = tokenize(text, chunk_len, true); + return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings); + } + + std::tuple> get_learned_condition_with_trigger(ggml_context* work_ctx, + int n_threads, + const std::string& text, + int clip_skip, + int width, + int height, + int num_input_imgs, + int adm_in_channels = -1, + bool force_zero_embeddings = false) { + GGML_ASSERT(0 && "Not implemented yet!"); + } + + std::string remove_trigger_from_prompt(ggml_context* work_ctx, + const std::string& prompt) { + GGML_ASSERT(0 && "Not implemented yet!"); + } +}; + +struct PixArtCLIPEmbedder : public Conditioner { + T5UniGramTokenizer t5_tokenizer; + std::shared_ptr t5; + size_t chunk_len = 512; + + PixArtCLIPEmbedder(ggml_backend_t backend, + std::map& tensor_types, + int clip_skip = -1) { + t5 = std::make_shared(backend, tensor_types, "text_encoders.t5xxl.transformer"); + } + + void set_clip_skip(int clip_skip) { + } + + void get_param_tensors(std::map& tensors) { + t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer"); + } + + void alloc_params_buffer() { + t5->alloc_params_buffer(); + } + + void free_params_buffer() { + t5->free_params_buffer(); + } + + size_t get_params_buffer_size() { + size_t buffer_size = 0; + + buffer_size += t5->get_params_buffer_size(); + + return buffer_size; + } + + std::tuple, std::vector, std::vector> tokenize(std::string text, + size_t max_length = 0, + bool padding = false) { + auto parsed_attention = parse_prompt_attention(text); + + { + std::stringstream ss; + ss << "["; + for (const auto& item : parsed_attention) { + ss << "['" << item.first << "', " << item.second << "], "; + } + ss << "]"; + LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); + } + + auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool { + return false; + }; + + std::vector t5_tokens; + std::vector t5_weights; + std::vector t5_mask; + for (const auto& item : parsed_attention) { + const std::string& curr_text = item.first; + float curr_weight = item.second; + + std::vector curr_tokens = t5_tokenizer.Encode(curr_text, true); + t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end()); + t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight); + } + + t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding); + + return {t5_tokens, t5_weights, t5_mask}; + } + + void modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) { + float* mask_data = (float*)mask->data; + int num_pad = 0; + for (int64_t i = 0; i < max_seq_length; i++) { + if (num_pad >= num_extra_padding) { + break; + } + if (std::isinf(mask_data[i])) { + mask_data[i] = 0; + ++num_pad; + } + } + // LOG_DEBUG("PAD: %d", num_pad); + } + + SDCondition get_learned_condition_common(ggml_context* work_ctx, + int n_threads, + std::tuple, std::vector, std::vector> token_and_weights, + int clip_skip, + bool force_zero_embeddings = false) { + auto& t5_tokens = std::get<0>(token_and_weights); + auto& t5_weights = std::get<1>(token_and_weights); + auto& t5_attn_mask_vec = std::get<2>(token_and_weights); + + int64_t t0 = ggml_time_ms(); + struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096] + struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096] + struct ggml_tensor* pooled = NULL; // [768,] + struct ggml_tensor* t5_attn_mask = vector_to_ggml_tensor(work_ctx, t5_attn_mask_vec); // [768,] + + std::vector hidden_states_vec; + + size_t chunk_count = t5_tokens.size() / chunk_len; + + bool use_mask = true; + const char* SD_CHROMA_USE_T5_MASK = getenv("SD_CHROMA_USE_T5_MASK"); + if (SD_CHROMA_USE_T5_MASK != nullptr) { + std::string sd_chroma_use_t5_mask_str = SD_CHROMA_USE_T5_MASK; + if (sd_chroma_use_t5_mask_str == "OFF" || sd_chroma_use_t5_mask_str == "FALSE") { + use_mask = false; + } else if (sd_chroma_use_t5_mask_str != "ON" && sd_chroma_use_t5_mask_str != "TRUE") { + LOG_WARN("SD_CHROMA_USE_T5_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_T5_MASK); + } + } + for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) { + // t5 + std::vector chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len, + t5_tokens.begin() + (chunk_idx + 1) * chunk_len); + std::vector chunk_weights(t5_weights.begin() + chunk_idx * chunk_len, + t5_weights.begin() + (chunk_idx + 1) * chunk_len); + std::vector chunk_mask(t5_attn_mask_vec.begin() + chunk_idx * chunk_len, + t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len); + + auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); + auto t5_attn_mask_chunk = use_mask ? vector_to_ggml_tensor(work_ctx, chunk_mask) : NULL; + + t5->compute(n_threads, + input_ids, + &chunk_hidden_states, + work_ctx, + t5_attn_mask_chunk); + { + auto tensor = chunk_hidden_states; + float original_mean = ggml_tensor_mean(tensor); + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float value = ggml_tensor_get_f32(tensor, i0, i1, i2); + value *= chunk_weights[i1]; + ggml_tensor_set_f32(tensor, value, i0, i1, i2); + } + } + } + float new_mean = ggml_tensor_mean(tensor); + ggml_tensor_scale(tensor, (original_mean / new_mean)); + } + + int64_t t1 = ggml_time_ms(); + LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); + if (force_zero_embeddings) { + float* vec = (float*)chunk_hidden_states->data; + for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) { + vec[i] = 0; + } + } + + hidden_states_vec.insert(hidden_states_vec.end(), + (float*)chunk_hidden_states->data, + ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states)); + } + + if (hidden_states_vec.size() > 0) { + hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec); + hidden_states = ggml_reshape_2d(work_ctx, + hidden_states, + chunk_hidden_states->ne[0], + ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]); + } else { + hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256); + ggml_set_f32(hidden_states, 0.f); + } + + int mask_pad = 1; + const char* SD_CHROMA_MASK_PAD_OVERRIDE = getenv("SD_CHROMA_MASK_PAD_OVERRIDE"); + if (SD_CHROMA_MASK_PAD_OVERRIDE != nullptr) { + std::string mask_pad_str = SD_CHROMA_MASK_PAD_OVERRIDE; + try { + mask_pad = std::stoi(mask_pad_str); + } catch (const std::invalid_argument&) { + LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable is not a valid integer (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad); + } catch (const std::out_of_range&) { + LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable value is out of range for `int` type (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad); + } + } + modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad); + + return SDCondition(hidden_states, t5_attn_mask, NULL); + } + + SDCondition get_learned_condition(ggml_context* work_ctx, + int n_threads, + const std::string& text, + int clip_skip, + int width, + int height, + int adm_in_channels = -1, + bool force_zero_embeddings = false) { + auto tokens_and_weights = tokenize(text, chunk_len, true); return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings); } diff --git a/flux.hpp b/flux.hpp index 20ff41096..a33d9eb02 100644 --- a/flux.hpp +++ b/flux.hpp @@ -117,6 +117,7 @@ namespace Flux { struct ggml_tensor* k, struct ggml_tensor* v, struct ggml_tensor* pe, + struct ggml_tensor* mask, bool flash_attn) { // q,k,v: [N, L, n_head, d_head] // pe: [L, d_head/2, 2, 2] @@ -124,7 +125,7 @@ namespace Flux { q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head] k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head] - auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true, flash_attn); // [N, L, n_head*d_head] + auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], mask, false, true, flash_attn); // [N, L, n_head*d_head] return x; } @@ -167,13 +168,13 @@ namespace Flux { return x; } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe) { + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe, struct ggml_tensor* mask) { // x: [N, n_token, dim] // pe: [n_token, d_head/2, 2, 2] // return [N, n_token, dim] - auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] - x = attention(ctx, qkv[0], qkv[1], qkv[2], pe, flash_attn); // [N, n_token, dim] - x = post_attention(ctx, x); // [N, n_token, dim] + auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] + x = attention(ctx, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim] + x = post_attention(ctx, x); // [N, n_token, dim] return x; } }; @@ -185,6 +186,13 @@ namespace Flux { ModulationOut(ggml_tensor* shift = NULL, ggml_tensor* scale = NULL, ggml_tensor* gate = NULL) : shift(shift), scale(scale), gate(gate) {} + + ModulationOut(struct ggml_context* ctx, ggml_tensor* vec, int64_t offset) { + int64_t stride = vec->nb[1] * vec->ne[1]; + shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim] + scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim] + gate = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 2)); // [N, dim] + } }; struct Modulation : public GGMLBlock { @@ -210,19 +218,12 @@ namespace Flux { auto m = ggml_reshape_3d(ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim] m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [multiplier, N, dim] - int64_t offset = m->nb[1] * m->ne[1]; - auto shift_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, dim] - auto scale_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, dim] - auto gate_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, dim] - + ModulationOut m_0 = ModulationOut(ctx, m, 0); if (is_double) { - auto shift_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, dim] - auto scale_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, dim] - auto gate_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, dim] - return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut(shift_1, scale_1, gate_1)}; + return {m_0, ModulationOut(ctx, m, 3)}; } - return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut()}; + return {m_0, ModulationOut()}; } }; @@ -242,25 +243,33 @@ namespace Flux { struct DoubleStreamBlock : public GGMLBlock { bool flash_attn; + bool prune_mod; + int idx = 0; public: DoubleStreamBlock(int64_t hidden_size, int64_t num_heads, float mlp_ratio, + int idx = 0, bool qkv_bias = false, - bool flash_attn = false) - : flash_attn(flash_attn) { + bool flash_attn = false, + bool prune_mod = false) + : idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) { int64_t mlp_hidden_dim = hidden_size * mlp_ratio; - blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); - blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn)); + if (!prune_mod) { + blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + } + blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn)); blocks["img_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); blocks["img_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim)); // img_mlp.1 is nn.GELU(approximate="tanh") blocks["img_mlp.2"] = std::shared_ptr(new Linear(mlp_hidden_dim, hidden_size)); - blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + if (!prune_mod) { + blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + } blocks["txt_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn)); @@ -270,17 +279,34 @@ namespace Flux { blocks["txt_mlp.2"] = std::shared_ptr(new Linear(mlp_hidden_dim, hidden_size)); } + std::vector get_distil_img_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + // TODO: not hardcoded? + const int single_blocks_count = 38; + const int double_blocks_count = 19; + + int64_t offset = 6 * idx + 3 * single_blocks_count; + return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)}; + } + + std::vector get_distil_txt_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + // TODO: not hardcoded? + const int single_blocks_count = 38; + const int double_blocks_count = 19; + + int64_t offset = 6 * idx + 6 * double_blocks_count + 3 * single_blocks_count; + return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)}; + } + std::pair forward(struct ggml_context* ctx, struct ggml_tensor* img, struct ggml_tensor* txt, struct ggml_tensor* vec, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + struct ggml_tensor* mask = NULL) { // img: [N, n_img_token, hidden_size] // txt: [N, n_txt_token, hidden_size] // pe: [n_img_token + n_txt_token, d_head/2, 2, 2] // return: ([N, n_img_token, hidden_size], [N, n_txt_token, hidden_size]) - - auto img_mod = std::dynamic_pointer_cast(blocks["img_mod"]); auto img_norm1 = std::dynamic_pointer_cast(blocks["img_norm1"]); auto img_attn = std::dynamic_pointer_cast(blocks["img_attn"]); @@ -288,7 +314,6 @@ namespace Flux { auto img_mlp_0 = std::dynamic_pointer_cast(blocks["img_mlp.0"]); auto img_mlp_2 = std::dynamic_pointer_cast(blocks["img_mlp.2"]); - auto txt_mod = std::dynamic_pointer_cast(blocks["txt_mod"]); auto txt_norm1 = std::dynamic_pointer_cast(blocks["txt_norm1"]); auto txt_attn = std::dynamic_pointer_cast(blocks["txt_attn"]); @@ -296,10 +321,22 @@ namespace Flux { auto txt_mlp_0 = std::dynamic_pointer_cast(blocks["txt_mlp.0"]); auto txt_mlp_2 = std::dynamic_pointer_cast(blocks["txt_mlp.2"]); - auto img_mods = img_mod->forward(ctx, vec); + std::vector img_mods; + if (prune_mod) { + img_mods = get_distil_img_mod(ctx, vec); + } else { + auto img_mod = std::dynamic_pointer_cast(blocks["img_mod"]); + img_mods = img_mod->forward(ctx, vec); + } ModulationOut img_mod1 = img_mods[0]; ModulationOut img_mod2 = img_mods[1]; - auto txt_mods = txt_mod->forward(ctx, vec); + std::vector txt_mods; + if (prune_mod) { + txt_mods = get_distil_txt_mod(ctx, vec); + } else { + auto txt_mod = std::dynamic_pointer_cast(blocks["txt_mod"]); + txt_mods = txt_mod->forward(ctx, vec); + } ModulationOut txt_mod1 = txt_mods[0]; ModulationOut txt_mod2 = txt_mods[1]; @@ -324,7 +361,7 @@ namespace Flux { auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head] + auto attn = attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head] attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] auto txt_attn_out = ggml_view_3d(ctx, attn, @@ -373,14 +410,18 @@ namespace Flux { int64_t hidden_size; int64_t mlp_hidden_dim; bool flash_attn; + bool prune_mod; + int idx = 0; public: SingleStreamBlock(int64_t hidden_size, int64_t num_heads, float mlp_ratio = 4.0f, + int idx = 0, float qk_scale = 0.f, - bool flash_attn = false) - : hidden_size(hidden_size), num_heads(num_heads), flash_attn(flash_attn) { + bool flash_attn = false, + bool prune_mod = false) + : hidden_size(hidden_size), num_heads(num_heads), idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) { int64_t head_dim = hidden_size / num_heads; float scale = qk_scale; if (scale <= 0.f) { @@ -393,26 +434,37 @@ namespace Flux { blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); blocks["pre_norm"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); // mlp_act is nn.GELU(approximate="tanh") - blocks["modulation"] = std::shared_ptr(new Modulation(hidden_size, false)); + if (!prune_mod) { + blocks["modulation"] = std::shared_ptr(new Modulation(hidden_size, false)); + } + } + + ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + int64_t offset = 3 * idx; + return ModulationOut(ctx, vec, offset); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* vec, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + struct ggml_tensor* mask = NULL) { // x: [N, n_token, hidden_size] // pe: [n_token, d_head/2, 2, 2] // return: [N, n_token, hidden_size] - auto linear1 = std::dynamic_pointer_cast(blocks["linear1"]); - auto linear2 = std::dynamic_pointer_cast(blocks["linear2"]); - auto norm = std::dynamic_pointer_cast(blocks["norm"]); - auto pre_norm = std::dynamic_pointer_cast(blocks["pre_norm"]); - auto modulation = std::dynamic_pointer_cast(blocks["modulation"]); - - auto mods = modulation->forward(ctx, vec); - ModulationOut mod = mods[0]; - + auto linear1 = std::dynamic_pointer_cast(blocks["linear1"]); + auto linear2 = std::dynamic_pointer_cast(blocks["linear2"]); + auto norm = std::dynamic_pointer_cast(blocks["norm"]); + auto pre_norm = std::dynamic_pointer_cast(blocks["pre_norm"]); + ModulationOut mod; + if (prune_mod) { + mod = get_distil_mod(ctx, vec); + } else { + auto modulation = std::dynamic_pointer_cast(blocks["modulation"]); + + mod = modulation->forward(ctx, vec)[0]; + } auto x_mod = Flux::modulate(ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale); auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim] qkv_mlp = ggml_cont(ctx, ggml_permute(ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token] @@ -443,7 +495,7 @@ namespace Flux { auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] q = norm->query_norm(ctx, q); k = norm->key_norm(ctx, k); - auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_token, hidden_size] + auto attn = attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size] auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim] auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size] @@ -454,13 +506,27 @@ namespace Flux { }; struct LastLayer : public GGMLBlock { + bool prune_mod; + public: LastLayer(int64_t hidden_size, int64_t patch_size, - int64_t out_channels) { - blocks["norm_final"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); - blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels)); - blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, 2 * hidden_size)); + int64_t out_channels, + bool prune_mod = false) : prune_mod(prune_mod) { + blocks["norm_final"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); + blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels)); + if (!prune_mod) { + blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, 2 * hidden_size)); + } + } + + ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + int64_t offset = vec->ne[2] - 2; + int64_t stride = vec->nb[1] * vec->ne[1]; + auto shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim] + auto scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim] + // No gate + return ModulationOut(shift, scale, NULL); } struct ggml_tensor* forward(struct ggml_context* ctx, @@ -469,17 +535,24 @@ namespace Flux { // x: [N, n_token, hidden_size] // c: [N, hidden_size] // return: [N, n_token, patch_size * patch_size * out_channels] - auto norm_final = std::dynamic_pointer_cast(blocks["norm_final"]); - auto linear = std::dynamic_pointer_cast(blocks["linear"]); - auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); - - auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size] - m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] - m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] - - int64_t offset = m->nb[1] * m->ne[1]; - auto shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] - auto scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + auto norm_final = std::dynamic_pointer_cast(blocks["norm_final"]); + auto linear = std::dynamic_pointer_cast(blocks["linear"]); + struct ggml_tensor *shift, *scale; + if (prune_mod) { + auto mod = get_distil_mod(ctx, c); + shift = mod.shift; + scale = mod.scale; + } else { + auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); + + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size] + m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] + m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] + + int64_t offset = m->nb[1] * m->ne[1]; + shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] + scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + } x = Flux::modulate(ctx, norm_final->forward(ctx, x), shift, scale); x = linear->forward(ctx, x); @@ -488,6 +561,34 @@ namespace Flux { } }; + struct ChromaApproximator : public GGMLBlock { + int64_t inner_size = 5120; + int64_t n_layers = 5; + ChromaApproximator(int64_t in_channels = 64, int64_t hidden_size = 3072) { + blocks["in_proj"] = std::shared_ptr(new Linear(in_channels, inner_size, true)); + for (int i = 0; i < n_layers; i++) { + blocks["norms." + std::to_string(i)] = std::shared_ptr(new RMSNorm(inner_size)); + blocks["layers." + std::to_string(i)] = std::shared_ptr(new MLPEmbedder(inner_size, inner_size)); + } + blocks["out_proj"] = std::shared_ptr(new Linear(inner_size, hidden_size, true)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + auto in_proj = std::dynamic_pointer_cast(blocks["in_proj"]); + auto out_proj = std::dynamic_pointer_cast(blocks["out_proj"]); + + x = in_proj->forward(ctx, x); + for (int i = 0; i < n_layers; i++) { + auto norm = std::dynamic_pointer_cast(blocks["norms." + std::to_string(i)]); + auto embed = std::dynamic_pointer_cast(blocks["layers." + std::to_string(i)]); + x = ggml_add_inplace(ctx, x, embed->forward(ctx, norm->forward(ctx, x))); + } + x = out_proj->forward(ctx, x); + + return x; + } + }; + struct FluxParams { int64_t in_channels = 64; int64_t out_channels = 64; @@ -504,6 +605,7 @@ namespace Flux { bool qkv_bias = true; bool guidance_embed = true; bool flash_attn = true; + bool is_chroma = false; }; struct Flux : public GGMLBlock { @@ -607,6 +709,7 @@ namespace Flux { return ids; } + // Generate positional embeddings std::vector gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector& axes_dim) { std::vector> ids = gen_ids(h, w, patch_size, bs, context_len); @@ -645,11 +748,15 @@ namespace Flux { : params(params) { int64_t pe_dim = params.hidden_size / params.num_heads; - blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, params.hidden_size, true)); - blocks["time_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); - blocks["vector_in"] = std::shared_ptr(new MLPEmbedder(params.vec_in_dim, params.hidden_size)); - if (params.guidance_embed) { - blocks["guidance_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); + blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, params.hidden_size, true)); + if (params.is_chroma) { + blocks["distilled_guidance_layer"] = std::shared_ptr(new ChromaApproximator(params.in_channels, params.hidden_size)); + } else { + blocks["time_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); + blocks["vector_in"] = std::shared_ptr(new MLPEmbedder(params.vec_in_dim, params.hidden_size)); + if (params.guidance_embed) { + blocks["guidance_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); + } } blocks["txt_in"] = std::shared_ptr(new Linear(params.context_in_dim, params.hidden_size, true)); @@ -657,19 +764,23 @@ namespace Flux { blocks["double_blocks." + std::to_string(i)] = std::shared_ptr(new DoubleStreamBlock(params.hidden_size, params.num_heads, params.mlp_ratio, + i, params.qkv_bias, - params.flash_attn)); + params.flash_attn, + params.is_chroma)); } for (int i = 0; i < params.depth_single_blocks; i++) { blocks["single_blocks." + std::to_string(i)] = std::shared_ptr(new SingleStreamBlock(params.hidden_size, params.num_heads, params.mlp_ratio, + i, 0.f, - params.flash_attn)); + params.flash_attn, + params.is_chroma)); } - blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, params.out_channels)); + blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, params.out_channels, params.is_chroma)); } struct ggml_tensor* patchify(struct ggml_context* ctx, @@ -726,25 +837,54 @@ namespace Flux { struct ggml_tensor* y, struct ggml_tensor* guidance, struct ggml_tensor* pe, + struct ggml_tensor* arange = NULL, std::vector skip_layers = std::vector()) { auto img_in = std::dynamic_pointer_cast(blocks["img_in"]); - auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); - auto vector_in = std::dynamic_pointer_cast(blocks["vector_in"]); auto txt_in = std::dynamic_pointer_cast(blocks["txt_in"]); auto final_layer = std::dynamic_pointer_cast(blocks["final_layer"]); - img = img_in->forward(ctx, img); - auto vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f)); + img = img_in->forward(ctx, img); + struct ggml_tensor* vec; + struct ggml_tensor* txt_img_mask = NULL; + if (params.is_chroma) { + int64_t mod_index_length = 344; + auto approx = std::dynamic_pointer_cast(blocks["distilled_guidance_layer"]); + auto distill_timestep = ggml_nn_timestep_embedding(ctx, timesteps, 16, 10000, 1000.f); + auto distill_guidance = ggml_nn_timestep_embedding(ctx, guidance, 16, 10000, 1000.f); + + // auto arange = ggml_arange(ctx, 0, (float)mod_index_length, 1); // Not working on a lot of backends, precomputing it on CPU instead + GGML_ASSERT(arange != NULL); + auto modulation_index = ggml_nn_timestep_embedding(ctx, arange, 32, 10000, 1000.f); // [1, 344, 32] + + // Batch broadcast (will it ever be useful) + modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2])); // [N, 344, 32] + + auto timestep_guidance = ggml_concat(ctx, distill_timestep, distill_guidance, 0); // [N, 1, 32] + timestep_guidance = ggml_repeat(ctx, timestep_guidance, modulation_index); // [N, 344, 32] + + vec = ggml_concat(ctx, timestep_guidance, modulation_index, 0); // [N, 344, 64] + // Permute for consistency with non-distilled modulation implementation + vec = ggml_cont(ctx, ggml_permute(ctx, vec, 0, 2, 1, 3)); // [344, N, 64] + vec = approx->forward(ctx, vec); // [344, N, hidden_size] + + if (y != NULL) { + txt_img_mask = ggml_pad(ctx, y, img->ne[1], 0, 0, 0); + } + } else { + auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); + auto vector_in = std::dynamic_pointer_cast(blocks["vector_in"]); + vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f)); + if (params.guidance_embed) { + GGML_ASSERT(guidance != NULL); + auto guidance_in = std::dynamic_pointer_cast(blocks["guidance_in"]); + // bf16 and fp16 result is different + auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f); + vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in)); + } - if (params.guidance_embed) { - GGML_ASSERT(guidance != NULL); - auto guidance_in = std::dynamic_pointer_cast(blocks["guidance_in"]); - // bf16 and fp16 result is different - auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f); - vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in)); + vec = ggml_add(ctx, vec, vector_in->forward(ctx, y)); } - vec = ggml_add(ctx, vec, vector_in->forward(ctx, y)); txt = txt_in->forward(ctx, txt); for (int i = 0; i < params.depth; i++) { @@ -754,7 +894,7 @@ namespace Flux { auto block = std::dynamic_pointer_cast(blocks["double_blocks." + std::to_string(i)]); - auto img_txt = block->forward(ctx, img, txt, vec, pe); + auto img_txt = block->forward(ctx, img, txt, vec, pe, txt_img_mask); img = img_txt.first; // [N, n_img_token, hidden_size] txt = img_txt.second; // [N, n_txt_token, hidden_size] } @@ -766,7 +906,7 @@ namespace Flux { } auto block = std::dynamic_pointer_cast(blocks["single_blocks." + std::to_string(i)]); - txt_img = block->forward(ctx, txt_img, vec, pe); + txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask); } txt_img = ggml_cont(ctx, ggml_permute(ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] @@ -781,7 +921,6 @@ namespace Flux { img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels) - return img; } @@ -793,6 +932,7 @@ namespace Flux { struct ggml_tensor* y, struct ggml_tensor* guidance, struct ggml_tensor* pe, + struct ggml_tensor* arange = NULL, std::vector skip_layers = std::vector()) { // Forward pass of DiT. // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) @@ -830,7 +970,7 @@ namespace Flux { img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0); } - auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size] + auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, arange, skip_layers); // [N, h*w, C * patch_size * patch_size] // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w] @@ -845,7 +985,8 @@ namespace Flux { public: FluxParams flux_params; Flux flux; - std::vector pe_vec; // for cache + std::vector pe_vec, range; // for cache + SDVersion version; FluxRunner(ggml_backend_t backend, std::map& tensor_types = empty_tensor_types, @@ -868,6 +1009,10 @@ namespace Flux { // not schnell flux_params.guidance_embed = true; } + if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { + // Chroma + flux_params.is_chroma = true; + } size_t db = tensor_name.find("double_blocks."); if (db != std::string::npos) { tensor_name = tensor_name.substr(db); // remove prefix @@ -887,7 +1032,9 @@ namespace Flux { } LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks); - if (!flux_params.guidance_embed) { + if (flux_params.is_chroma) { + LOG_INFO("Using pruned modulation (Chroma)"); + } else if (!flux_params.guidance_embed) { LOG_INFO("Flux guidance is disabled (Schnell mode)"); } @@ -913,14 +1060,51 @@ namespace Flux { GGML_ASSERT(x->ne[3] == 1); struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false); + struct ggml_tensor* precompute_arange = NULL; + x = to_backend(x); context = to_backend(context); if (c_concat != NULL) { c_concat = to_backend(c_concat); } - y = to_backend(y); + if (flux_params.is_chroma) { + const char* SD_CHROMA_ENABLE_GUIDANCE = getenv("SD_CHROMA_ENABLE_GUIDANCE"); + bool disable_guidance = true; + if (SD_CHROMA_ENABLE_GUIDANCE != NULL) { + std::string enable_guidance_str = SD_CHROMA_ENABLE_GUIDANCE; + if (enable_guidance_str == "ON" || enable_guidance_str == "TRUE") { + LOG_WARN("Chroma guidance has been enabled. Image might be broken. (SD_CHROMA_ENABLE_GUIDANCE env variable to \"OFF\" to disable)", SD_CHROMA_ENABLE_GUIDANCE); + disable_guidance = false; + } else if (enable_guidance_str != "OFF" && enable_guidance_str != "FALSE") { + LOG_WARN("SD_CHROMA_ENABLE_GUIDANCE environment variable has unexpected value. Assuming default (\"OFF\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_ENABLE_GUIDANCE); + } + } + if (disable_guidance) { + LOG_DEBUG("Forcing guidance to 0 for chroma model (SD_CHROMA_ENABLE_GUIDANCE env variable to \"ON\" to enable)"); + guidance = ggml_set_f32(guidance, 0); + } + + + const char* SD_CHROMA_USE_DIT_MASK = getenv("SD_CHROMA_USE_DIT_MASK"); + if (SD_CHROMA_USE_DIT_MASK != nullptr) { + std::string sd_chroma_use_DiT_mask_str = SD_CHROMA_USE_DIT_MASK; + if (sd_chroma_use_DiT_mask_str == "OFF" || sd_chroma_use_DiT_mask_str == "FALSE") { + y = NULL; + } else if (sd_chroma_use_DiT_mask_str != "ON" && sd_chroma_use_DiT_mask_str != "TRUE") { + LOG_WARN("SD_CHROMA_USE_DIT_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_DIT_MASK); + } + } + + // ggml_arrange is not working on some backends, and y isn't used, so let's reuse y to precompute it + range = arange(0, 344); + precompute_arange = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, range.size()); + set_backend_tensor_data(precompute_arange, range.data()); + // y = NULL; + } + y = to_backend(y); + timesteps = to_backend(timesteps); - if (flux_params.guidance_embed) { + if (flux_params.guidance_embed || flux_params.is_chroma) { guidance = to_backend(guidance); } @@ -941,6 +1125,7 @@ namespace Flux { y, guidance, pe, + precompute_arange, skip_layers); ggml_build_forward_expand(gf, out); diff --git a/ggml_extend.hpp b/ggml_extend.hpp index c5913be4d..101a5d1f6 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -864,6 +864,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] v = ggml_cast(ctx, v, GGML_TYPE_F16); + if (mask != nullptr) { + mask = ggml_transpose(ctx, mask); + + if (mask->ne[1] < GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)) { + LOG_DEBUG("mask dims %ld, %ld, %ld, %ld\n", mask->ne[0], mask->ne[1], mask->ne[2], mask->ne[3]); + LOG_DEBUG("needs padding, padding from %ld to %ld\n", mask->ne[1], GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)); + mask = ggml_pad(ctx, mask, 0, GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) - mask->ne[1], 0, 0); + } + + mask = ggml_cast(ctx, mask, GGML_TYPE_F16); + } + kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0); ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32); @@ -876,7 +888,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k] kq = ggml_scale_inplace(ctx, kq, scale); if (mask) { - kq = ggml_add(ctx, kq, mask); + kq = ggml_add_inplace(ctx, kq, mask); } if (diag_mask_inf) { kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index e38a6101f..a593284af 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -334,8 +334,19 @@ class StableDiffusionGGML { cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types); } else if (sd_version_is_flux(version)) { - cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); - diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn); + bool is_chroma = false; + for (auto pair : model_loader.tensor_storages_types) { + if (pair.first.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { + is_chroma = true; + break; + } + } + if (is_chroma) { + cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); + } else { + cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); + } + diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn); } else { if (id_embeddings_path.find("v2") != std::string::npos) { cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2); diff --git a/t5.hpp b/t5.hpp index 2a53e2743..7edaf8041 100644 --- a/t5.hpp +++ b/t5.hpp @@ -385,6 +385,7 @@ class T5UniGramTokenizer { void pad_tokens(std::vector& tokens, std::vector& weights, + std::vector* attention_mask, size_t max_length = 0, bool padding = false) { if (max_length > 0 && padding) { @@ -397,11 +398,15 @@ class T5UniGramTokenizer { LOG_DEBUG("token length: %llu", length); std::vector new_tokens; std::vector new_weights; + std::vector new_attention_mask; int token_idx = 0; for (int i = 0; i < length; i++) { if (token_idx >= orig_token_num) { break; } + if (attention_mask != nullptr) { + new_attention_mask.push_back(0.0); + } if (i % max_length == max_length - 1) { new_tokens.push_back(eos_id_); new_weights.push_back(1.0); @@ -414,13 +419,24 @@ class T5UniGramTokenizer { new_tokens.push_back(eos_id_); new_weights.push_back(1.0); + if (attention_mask != nullptr) { + new_attention_mask.push_back(0.0); + } + tokens = new_tokens; weights = new_weights; + if (attention_mask != nullptr) { + *attention_mask = new_attention_mask; + } if (padding) { int pad_token_id = pad_id_; tokens.insert(tokens.end(), length - tokens.size(), pad_token_id); weights.insert(weights.end(), length - weights.size(), 1.0); + if (attention_mask != nullptr) { + // maybe keep some padding tokens unmasked? + attention_mask->insert(attention_mask->end(), length - attention_mask->size(), -HUGE_VALF); + } } } } @@ -579,6 +595,7 @@ class T5Attention : public GGMLBlock { } if (past_bias != NULL) { if (mask != NULL) { + mask = ggml_repeat(ctx, mask, past_bias); mask = ggml_add(ctx, mask, past_bias); } else { mask = past_bias; @@ -739,15 +756,17 @@ struct T5Runner : public GGMLRunner { struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* input_ids, - struct ggml_tensor* relative_position_bucket) { + struct ggml_tensor* relative_position_bucket, + struct ggml_tensor* attention_mask = NULL) { size_t N = input_ids->ne[1]; size_t n_token = input_ids->ne[0]; - auto hidden_states = model.forward(ctx, input_ids, NULL, NULL, relative_position_bucket); // [N, n_token, model_dim] + auto hidden_states = model.forward(ctx, input_ids, NULL, attention_mask, relative_position_bucket); // [N, n_token, model_dim] return hidden_states; } - struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids) { + struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids, + struct ggml_tensor* attention_mask = NULL) { struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); input_ids = to_backend(input_ids); @@ -767,7 +786,7 @@ struct T5Runner : public GGMLRunner { input_ids->ne[0]); set_backend_tensor_data(relative_position_bucket, relative_position_bucket_vec.data()); - struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, relative_position_bucket); + struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, relative_position_bucket, attention_mask); ggml_build_forward_expand(gf, hidden_states); @@ -777,9 +796,10 @@ struct T5Runner : public GGMLRunner { void compute(const int n_threads, struct ggml_tensor* input_ids, ggml_tensor** output, - ggml_context* output_ctx = NULL) { + ggml_context* output_ctx = NULL, + struct ggml_tensor* attention_mask = NULL) { auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(input_ids); + return build_graph(input_ids, attention_mask); }; GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); } @@ -877,9 +897,9 @@ struct T5Embedder { model.alloc_params_buffer(); } - std::pair, std::vector> tokenize(std::string text, - size_t max_length = 0, - bool padding = false) { + std::tuple, std::vector, std::vector> tokenize(std::string text, + size_t max_length = 0, + bool padding = false) { auto parsed_attention = parse_prompt_attention(text); { @@ -906,14 +926,16 @@ struct T5Embedder { tokens.push_back(EOS_TOKEN_ID); weights.push_back(1.0); - tokenizer.pad_tokens(tokens, weights, max_length, padding); + std::vector attention_mask; + + tokenizer.pad_tokens(tokens, weights, &attention_mask, max_length, padding); // for (int i = 0; i < tokens.size(); i++) { // std::cout << tokens[i] << ":" << weights[i] << ", "; // } // std::cout << std::endl; - return {tokens, weights}; + return {tokens, weights, attention_mask}; } void test() { @@ -934,8 +956,8 @@ struct T5Embedder { // TODO: fix cuda nan std::string text("a lovely cat"); auto tokens_and_weights = tokenize(text, 77, true); - std::vector& tokens = tokens_and_weights.first; - std::vector& weights = tokens_and_weights.second; + std::vector& tokens = std::get<0>(tokens_and_weights); + std::vector& weights = std::get<1>(tokens_and_weights); for (auto token : tokens) { printf("%d ", token); }