From f7ad456bd1ee0a85be5678d4bb7736a9188ec1d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Wed, 28 May 2025 18:24:58 +0200 Subject: [PATCH 01/17] Chroma: Initial commit (broken output) --- flux.hpp | 282 +++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 213 insertions(+), 69 deletions(-) diff --git a/flux.hpp b/flux.hpp index 20ff41096..ca5772fea 100644 --- a/flux.hpp +++ b/flux.hpp @@ -185,6 +185,13 @@ namespace Flux { ModulationOut(ggml_tensor* shift = NULL, ggml_tensor* scale = NULL, ggml_tensor* gate = NULL) : shift(shift), scale(scale), gate(gate) {} + + ModulationOut(struct ggml_context* ctx, ggml_tensor* vec, int64_t offset) { + int64_t stride = vec->nb[1] * vec->ne[1]; + shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim] + scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim] + gate = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 2)); // [N, dim] + } }; struct Modulation : public GGMLBlock { @@ -210,19 +217,12 @@ namespace Flux { auto m = ggml_reshape_3d(ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim] m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [multiplier, N, dim] - int64_t offset = m->nb[1] * m->ne[1]; - auto shift_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, dim] - auto scale_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, dim] - auto gate_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, dim] - + ModulationOut m_0 = ModulationOut(ctx, m, 0); if (is_double) { - auto shift_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, dim] - auto scale_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, dim] - auto gate_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, dim] - return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut(shift_1, scale_1, gate_1)}; + return {m_0, ModulationOut(ctx, m, 3)}; } - return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut()}; + return {m_0, ModulationOut()}; } }; @@ -242,25 +242,33 @@ namespace Flux { struct DoubleStreamBlock : public GGMLBlock { bool flash_attn; + bool prune_mod; + int idx = 0; public: DoubleStreamBlock(int64_t hidden_size, int64_t num_heads, float mlp_ratio, + int idx = 0, bool qkv_bias = false, - bool flash_attn = false) - : flash_attn(flash_attn) { + bool flash_attn = false, + bool prune_mod = false) + : idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) { int64_t mlp_hidden_dim = hidden_size * mlp_ratio; - blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); - blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn)); + if (!prune_mod) { + blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + } + blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn)); blocks["img_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); blocks["img_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim)); // img_mlp.1 is nn.GELU(approximate="tanh") blocks["img_mlp.2"] = std::shared_ptr(new Linear(mlp_hidden_dim, hidden_size)); - blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + if (!prune_mod) { + blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + } blocks["txt_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn)); @@ -270,6 +278,25 @@ namespace Flux { blocks["txt_mlp.2"] = std::shared_ptr(new Linear(mlp_hidden_dim, hidden_size)); } + std::vector get_distil_img_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + // TODO: not hardcoded? + const int single_blocks_count = 38; + const int double_blocks_count = 19; + + int64_t offset = 6 * idx + 3 * single_blocks_count; + return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)}; + } + + std::vector get_distil_txt_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + // TODO: not hardcoded? + const int single_blocks_count = 38; + const int double_blocks_count = 19; + + int64_t offset = 6 * idx + 6 * double_blocks_count + 3 * single_blocks_count; + return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)}; + } + + // TODO: chroma (prune_mod) -> get modulations from offset vec std::pair forward(struct ggml_context* ctx, struct ggml_tensor* img, struct ggml_tensor* txt, @@ -279,8 +306,6 @@ namespace Flux { // txt: [N, n_txt_token, hidden_size] // pe: [n_img_token + n_txt_token, d_head/2, 2, 2] // return: ([N, n_img_token, hidden_size], [N, n_txt_token, hidden_size]) - - auto img_mod = std::dynamic_pointer_cast(blocks["img_mod"]); auto img_norm1 = std::dynamic_pointer_cast(blocks["img_norm1"]); auto img_attn = std::dynamic_pointer_cast(blocks["img_attn"]); @@ -288,7 +313,6 @@ namespace Flux { auto img_mlp_0 = std::dynamic_pointer_cast(blocks["img_mlp.0"]); auto img_mlp_2 = std::dynamic_pointer_cast(blocks["img_mlp.2"]); - auto txt_mod = std::dynamic_pointer_cast(blocks["txt_mod"]); auto txt_norm1 = std::dynamic_pointer_cast(blocks["txt_norm1"]); auto txt_attn = std::dynamic_pointer_cast(blocks["txt_attn"]); @@ -296,10 +320,22 @@ namespace Flux { auto txt_mlp_0 = std::dynamic_pointer_cast(blocks["txt_mlp.0"]); auto txt_mlp_2 = std::dynamic_pointer_cast(blocks["txt_mlp.2"]); - auto img_mods = img_mod->forward(ctx, vec); + std::vector img_mods; + if (prune_mod) { + img_mods = get_distil_img_mod(ctx, vec); + } else { + auto img_mod = std::dynamic_pointer_cast(blocks["img_mod"]); + img_mods = img_mod->forward(ctx, vec); + } ModulationOut img_mod1 = img_mods[0]; ModulationOut img_mod2 = img_mods[1]; - auto txt_mods = txt_mod->forward(ctx, vec); + std::vector txt_mods; + if (prune_mod) { + txt_mods = get_distil_txt_mod(ctx, vec); + } else { + auto txt_mod = std::dynamic_pointer_cast(blocks["txt_mod"]); + txt_mods = txt_mod->forward(ctx, vec); + } ModulationOut txt_mod1 = txt_mods[0]; ModulationOut txt_mod2 = txt_mods[1]; @@ -373,14 +409,18 @@ namespace Flux { int64_t hidden_size; int64_t mlp_hidden_dim; bool flash_attn; + bool prune_mod; + int idx = 0; public: SingleStreamBlock(int64_t hidden_size, int64_t num_heads, float mlp_ratio = 4.0f, + int idx = 0, float qk_scale = 0.f, - bool flash_attn = false) - : hidden_size(hidden_size), num_heads(num_heads), flash_attn(flash_attn) { + bool flash_attn = false, + bool prune_mod = false) + : hidden_size(hidden_size), num_heads(num_heads), idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) { int64_t head_dim = hidden_size / num_heads; float scale = qk_scale; if (scale <= 0.f) { @@ -393,7 +433,14 @@ namespace Flux { blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); blocks["pre_norm"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); // mlp_act is nn.GELU(approximate="tanh") - blocks["modulation"] = std::shared_ptr(new Modulation(hidden_size, false)); + if (!prune_mod) { + blocks["modulation"] = std::shared_ptr(new Modulation(hidden_size, false)); + } + } + + ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + int64_t offset = 3 * idx; + return ModulationOut(ctx, vec, offset); } struct ggml_tensor* forward(struct ggml_context* ctx, @@ -404,15 +451,18 @@ namespace Flux { // pe: [n_token, d_head/2, 2, 2] // return: [N, n_token, hidden_size] - auto linear1 = std::dynamic_pointer_cast(blocks["linear1"]); - auto linear2 = std::dynamic_pointer_cast(blocks["linear2"]); - auto norm = std::dynamic_pointer_cast(blocks["norm"]); - auto pre_norm = std::dynamic_pointer_cast(blocks["pre_norm"]); - auto modulation = std::dynamic_pointer_cast(blocks["modulation"]); - - auto mods = modulation->forward(ctx, vec); - ModulationOut mod = mods[0]; - + auto linear1 = std::dynamic_pointer_cast(blocks["linear1"]); + auto linear2 = std::dynamic_pointer_cast(blocks["linear2"]); + auto norm = std::dynamic_pointer_cast(blocks["norm"]); + auto pre_norm = std::dynamic_pointer_cast(blocks["pre_norm"]); + ModulationOut mod; + if (prune_mod) { + mod = get_distil_mod(ctx, vec); + } else { + auto modulation = std::dynamic_pointer_cast(blocks["modulation"]); + + mod = modulation->forward(ctx, vec)[0]; + } auto x_mod = Flux::modulate(ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale); auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim] qkv_mlp = ggml_cont(ctx, ggml_permute(ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token] @@ -454,13 +504,27 @@ namespace Flux { }; struct LastLayer : public GGMLBlock { + bool prune_mod; + public: LastLayer(int64_t hidden_size, int64_t patch_size, - int64_t out_channels) { - blocks["norm_final"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); - blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels)); - blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, 2 * hidden_size)); + int64_t out_channels, + bool prune_mod = false) : prune_mod(prune_mod) { + blocks["norm_final"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); + blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels)); + if (!prune_mod) { + blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, 2 * hidden_size)); + } + } + + ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + int64_t offset = vec->ne[2] - 2; + int64_t stride = vec->nb[1] * vec->ne[1]; + auto shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim] + auto scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim] + // No gate + return ModulationOut(shift, scale, NULL); } struct ggml_tensor* forward(struct ggml_context* ctx, @@ -469,17 +533,24 @@ namespace Flux { // x: [N, n_token, hidden_size] // c: [N, hidden_size] // return: [N, n_token, patch_size * patch_size * out_channels] - auto norm_final = std::dynamic_pointer_cast(blocks["norm_final"]); - auto linear = std::dynamic_pointer_cast(blocks["linear"]); - auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); - - auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size] - m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] - m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] - - int64_t offset = m->nb[1] * m->ne[1]; - auto shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] - auto scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + auto norm_final = std::dynamic_pointer_cast(blocks["norm_final"]); + auto linear = std::dynamic_pointer_cast(blocks["linear"]); + struct ggml_tensor *shift, *scale; + if (prune_mod) { + auto mod = get_distil_mod(ctx, c); + shift = mod.shift; + scale = mod.scale; + } else { + auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); + + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size] + m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] + m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] + + int64_t offset = m->nb[1] * m->ne[1]; + shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] + scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + } x = Flux::modulate(ctx, norm_final->forward(ctx, x), shift, scale); x = linear->forward(ctx, x); @@ -488,6 +559,34 @@ namespace Flux { } }; + struct ChromaApproximator : public GGMLBlock { + int64_t inner_size = 5120; + int64_t n_layers = 5; + ChromaApproximator(int64_t in_channels = 64, int64_t hidden_size = 3072) { + blocks["in_proj"] = std::shared_ptr(new Linear(in_channels, inner_size, true)); + for (int i = 0; i < n_layers; i++) { + blocks["norms." + std::to_string(i)] = std::shared_ptr(new RMSNorm(inner_size)); + blocks["layers." + std::to_string(i)] = std::shared_ptr(new MLPEmbedder(inner_size, inner_size)); + } + blocks["out_proj"] = std::shared_ptr(new Linear(inner_size, hidden_size, true)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + auto in_proj = std::dynamic_pointer_cast(blocks["in_proj"]); + auto out_proj = std::dynamic_pointer_cast(blocks["out_proj"]); + + x = in_proj->forward(ctx, x); + for (int i = 0; i < n_layers; i++) { + auto norm = std::dynamic_pointer_cast(blocks["norms." + std::to_string(i)]); + auto embed = std::dynamic_pointer_cast(blocks["layers." + std::to_string(i)]); + x = ggml_add_inplace(ctx, x, embed->forward(ctx, norm->forward(ctx, x))); + } + x = out_proj->forward(ctx, x); + + return x; + } + }; + struct FluxParams { int64_t in_channels = 64; int64_t out_channels = 64; @@ -504,6 +603,7 @@ namespace Flux { bool qkv_bias = true; bool guidance_embed = true; bool flash_attn = true; + bool chroma_guidance = false; }; struct Flux : public GGMLBlock { @@ -645,11 +745,15 @@ namespace Flux { : params(params) { int64_t pe_dim = params.hidden_size / params.num_heads; - blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, params.hidden_size, true)); - blocks["time_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); - blocks["vector_in"] = std::shared_ptr(new MLPEmbedder(params.vec_in_dim, params.hidden_size)); - if (params.guidance_embed) { - blocks["guidance_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); + blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, params.hidden_size, true)); + if (params.chroma_guidance) { + blocks["distilled_guidance_layer"] = std::shared_ptr(new ChromaApproximator(params.in_channels, params.hidden_size)); + } else { + blocks["time_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); + blocks["vector_in"] = std::shared_ptr(new MLPEmbedder(params.vec_in_dim, params.hidden_size)); + if (params.guidance_embed) { + blocks["guidance_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); + } } blocks["txt_in"] = std::shared_ptr(new Linear(params.context_in_dim, params.hidden_size, true)); @@ -657,19 +761,24 @@ namespace Flux { blocks["double_blocks." + std::to_string(i)] = std::shared_ptr(new DoubleStreamBlock(params.hidden_size, params.num_heads, params.mlp_ratio, + i, params.qkv_bias, - params.flash_attn)); + params.flash_attn, + params.chroma_guidance)); } for (int i = 0; i < params.depth_single_blocks; i++) { blocks["single_blocks." + std::to_string(i)] = std::shared_ptr(new SingleStreamBlock(params.hidden_size, params.num_heads, params.mlp_ratio, + i, 0.f, - params.flash_attn)); + params.flash_attn, + params.chroma_guidance)); } - blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, params.out_channels)); + // TODO: no modulation for chroma + blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, params.out_channels, params.chroma_guidance)); } struct ggml_tensor* patchify(struct ggml_context* ctx, @@ -728,23 +837,45 @@ namespace Flux { struct ggml_tensor* pe, std::vector skip_layers = std::vector()) { auto img_in = std::dynamic_pointer_cast(blocks["img_in"]); - auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); - auto vector_in = std::dynamic_pointer_cast(blocks["vector_in"]); auto txt_in = std::dynamic_pointer_cast(blocks["txt_in"]); auto final_layer = std::dynamic_pointer_cast(blocks["final_layer"]); - img = img_in->forward(ctx, img); - auto vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f)); + img = img_in->forward(ctx, img); + struct ggml_tensor* vec; + if (params.chroma_guidance) { + int64_t mod_index_length = 344; + auto approx = std::dynamic_pointer_cast(blocks["distilled_guidance_layer"]); + auto distill_timestep = ggml_nn_timestep_embedding(ctx, timesteps, 16, 10000, 1000.f); + auto distill_guidance = ggml_nn_timestep_embedding(ctx, guidance, 16, 10000, 1000.f); + + // auto arrange = ggml_arange(ctx, 0, (float)mod_index_length, 1); // Not working on a lot of backends + auto arrange = y; + auto modulation_index = ggml_nn_timestep_embedding(ctx, arrange, 32, 10000, 1000.f); + + auto timestep_guidance = ggml_concat(ctx, distill_timestep, distill_guidance, 0); + timestep_guidance = ggml_repeat(ctx, distill_timestep, modulation_index); + // TODO Batch broadcast? + + vec = ggml_concat(ctx, timestep_guidance, modulation_index, 0); // [N, 344, 64] + vec = approx->forward(ctx, vec); // [N, 344, hidden_size] + + // Permute for consistency with non-distilled modulation implementation + vec = ggml_cont(ctx, ggml_permute(ctx, vec, 0, 2, 1, 3)); // [344, N, hidden_size] + } else { + auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); + auto vector_in = std::dynamic_pointer_cast(blocks["vector_in"]); + vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f)); + if (params.guidance_embed) { + GGML_ASSERT(guidance != NULL); + auto guidance_in = std::dynamic_pointer_cast(blocks["guidance_in"]); + // bf16 and fp16 result is different + auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f); + vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in)); + } - if (params.guidance_embed) { - GGML_ASSERT(guidance != NULL); - auto guidance_in = std::dynamic_pointer_cast(blocks["guidance_in"]); - // bf16 and fp16 result is different - auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f); - vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in)); + vec = ggml_add(ctx, vec, vector_in->forward(ctx, y)); } - vec = ggml_add(ctx, vec, vector_in->forward(ctx, y)); txt = txt_in->forward(ctx, txt); for (int i = 0; i < params.depth; i++) { @@ -868,6 +999,10 @@ namespace Flux { // not schnell flux_params.guidance_embed = true; } + if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { + // Chroma + flux_params.chroma_guidance = true; + } size_t db = tensor_name.find("double_blocks."); if (db != std::string::npos) { tensor_name = tensor_name.substr(db); // remove prefix @@ -887,7 +1022,9 @@ namespace Flux { } LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks); - if (!flux_params.guidance_embed) { + if (flux_params.chroma_guidance) { + LOG_INFO("Using pruned modulation (Chroma)"); + } else if (!flux_params.guidance_embed) { LOG_INFO("Flux guidance is disabled (Schnell mode)"); } @@ -918,7 +1055,14 @@ namespace Flux { if (c_concat != NULL) { c_concat = to_backend(c_concat); } - y = to_backend(y); + if (!flux_params.chroma_guidance) { + y = to_backend(y); + } else { + // ggml_arrange is not working on some backends, so let's reuse y to precompute it + std::vector range = arange(0, 344); + y = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, range.size()); + set_backend_tensor_data(y, range.data()); + } timesteps = to_backend(timesteps); if (flux_params.guidance_embed) { guidance = to_backend(guidance); From ad39011c60bf5e4163f7595a4a29740be887dfe5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Thu, 29 May 2025 11:31:09 +0200 Subject: [PATCH 02/17] Fix small mistake (still broken) --- flux.hpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/flux.hpp b/flux.hpp index ca5772fea..6fd54327a 100644 --- a/flux.hpp +++ b/flux.hpp @@ -850,12 +850,15 @@ namespace Flux { // auto arrange = ggml_arange(ctx, 0, (float)mod_index_length, 1); // Not working on a lot of backends auto arrange = y; - auto modulation_index = ggml_nn_timestep_embedding(ctx, arrange, 32, 10000, 1000.f); + auto modulation_index = ggml_nn_timestep_embedding(ctx, arrange, 32, 10000, 1000.f);// [1, 344, 32] + + // Batch broadcast (will it ever be useful) + modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2], modulation_index->ne[3]));// [N, 344, 32] - auto timestep_guidance = ggml_concat(ctx, distill_timestep, distill_guidance, 0); - timestep_guidance = ggml_repeat(ctx, distill_timestep, modulation_index); - // TODO Batch broadcast? + auto timestep_guidance = ggml_concat(ctx, distill_timestep, distill_guidance, 0); // [N, 1, 32] + timestep_guidance = ggml_repeat(ctx, timestep_guidance, modulation_index); // [N, 344, 32] + vec = ggml_concat(ctx, timestep_guidance, modulation_index, 0); // [N, 344, 64] vec = approx->forward(ctx, vec); // [N, 344, hidden_size] @@ -1064,7 +1067,7 @@ namespace Flux { set_backend_tensor_data(y, range.data()); } timesteps = to_backend(timesteps); - if (flux_params.guidance_embed) { + if (flux_params.guidance_embed || flux_params.is_chroma) { guidance = to_backend(guidance); } From 93ed7219b5e88e6c4812c8b70f53a53bbbaceaa4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Thu, 29 May 2025 11:31:24 +0200 Subject: [PATCH 03/17] is_chroma --- flux.hpp | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/flux.hpp b/flux.hpp index 6fd54327a..bd63f764a 100644 --- a/flux.hpp +++ b/flux.hpp @@ -603,7 +603,7 @@ namespace Flux { bool qkv_bias = true; bool guidance_embed = true; bool flash_attn = true; - bool chroma_guidance = false; + bool is_chroma = false; }; struct Flux : public GGMLBlock { @@ -746,7 +746,7 @@ namespace Flux { int64_t pe_dim = params.hidden_size / params.num_heads; blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, params.hidden_size, true)); - if (params.chroma_guidance) { + if (params.is_chroma) { blocks["distilled_guidance_layer"] = std::shared_ptr(new ChromaApproximator(params.in_channels, params.hidden_size)); } else { blocks["time_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); @@ -764,7 +764,7 @@ namespace Flux { i, params.qkv_bias, params.flash_attn, - params.chroma_guidance)); + params.is_chroma)); } for (int i = 0; i < params.depth_single_blocks; i++) { @@ -774,11 +774,11 @@ namespace Flux { i, 0.f, params.flash_attn, - params.chroma_guidance)); + params.is_chroma)); } // TODO: no modulation for chroma - blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, params.out_channels, params.chroma_guidance)); + blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, params.out_channels, params.is_chroma)); } struct ggml_tensor* patchify(struct ggml_context* ctx, @@ -842,7 +842,7 @@ namespace Flux { img = img_in->forward(ctx, img); struct ggml_tensor* vec; - if (params.chroma_guidance) { + if (params.is_chroma) { int64_t mod_index_length = 344; auto approx = std::dynamic_pointer_cast(blocks["distilled_guidance_layer"]); auto distill_timestep = ggml_nn_timestep_embedding(ctx, timesteps, 16, 10000, 1000.f); @@ -915,7 +915,6 @@ namespace Flux { img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels) - return img; } @@ -1004,7 +1003,7 @@ namespace Flux { } if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { // Chroma - flux_params.chroma_guidance = true; + flux_params.is_chroma = true; } size_t db = tensor_name.find("double_blocks."); if (db != std::string::npos) { @@ -1025,7 +1024,7 @@ namespace Flux { } LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks); - if (flux_params.chroma_guidance) { + if (flux_params.is_chroma) { LOG_INFO("Using pruned modulation (Chroma)"); } else if (!flux_params.guidance_embed) { LOG_INFO("Flux guidance is disabled (Schnell mode)"); @@ -1058,10 +1057,10 @@ namespace Flux { if (c_concat != NULL) { c_concat = to_backend(c_concat); } - if (!flux_params.chroma_guidance) { + if (!flux_params.is_chroma) { y = to_backend(y); } else { - // ggml_arrange is not working on some backends, so let's reuse y to precompute it + // ggml_arrange is not working on some backends, and y isn't used, so let's reuse y to precompute it std::vector range = arange(0, 344); y = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, range.size()); set_backend_tensor_data(y, range.data()); From d20f77f4d4e5df08a9a09bfd2b823455abb08402 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Fri, 30 May 2025 15:10:12 +0200 Subject: [PATCH 04/17] Reshape before approx --- flux.hpp | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/flux.hpp b/flux.hpp index bd63f764a..c03d0e02f 100644 --- a/flux.hpp +++ b/flux.hpp @@ -603,7 +603,7 @@ namespace Flux { bool qkv_bias = true; bool guidance_embed = true; bool flash_attn = true; - bool is_chroma = false; + bool is_chroma = false; }; struct Flux : public GGMLBlock { @@ -850,20 +850,19 @@ namespace Flux { // auto arrange = ggml_arange(ctx, 0, (float)mod_index_length, 1); // Not working on a lot of backends auto arrange = y; - auto modulation_index = ggml_nn_timestep_embedding(ctx, arrange, 32, 10000, 1000.f);// [1, 344, 32] - + auto modulation_index = ggml_nn_timestep_embedding(ctx, arrange, 32, 10000, 1000.f); // [1, 344, 32] + // Batch broadcast (will it ever be useful) - modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2], modulation_index->ne[3]));// [N, 344, 32] + modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2])); // [N, 344, 32] + auto timestep_guidance = ggml_concat(ctx, distill_timestep, distill_guidance, 0); // [N, 1, 32] + timestep_guidance = ggml_repeat(ctx, timestep_guidance, modulation_index); // [N, 344, 32] - auto timestep_guidance = ggml_concat(ctx, distill_timestep, distill_guidance, 0); // [N, 1, 32] - timestep_guidance = ggml_repeat(ctx, timestep_guidance, modulation_index); // [N, 344, 32] - vec = ggml_concat(ctx, timestep_guidance, modulation_index, 0); // [N, 344, 64] - vec = approx->forward(ctx, vec); // [N, 344, hidden_size] - // Permute for consistency with non-distilled modulation implementation - vec = ggml_cont(ctx, ggml_permute(ctx, vec, 0, 2, 1, 3)); // [344, N, hidden_size] + vec = ggml_cont(ctx, ggml_permute(ctx, vec, 0, 2, 1, 3)); // [344, N, 64] + vec = approx->forward(ctx, vec); // [344, N, hidden_size] + } else { auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); auto vector_in = std::dynamic_pointer_cast(blocks["vector_in"]); From bb1fe2cfe5a98a105cdb5cf286af9a7882716e9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Fri, 11 Apr 2025 01:26:01 +0200 Subject: [PATCH 05/17] format --- flux.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux.hpp b/flux.hpp index c03d0e02f..5b75be4db 100644 --- a/flux.hpp +++ b/flux.hpp @@ -861,7 +861,7 @@ namespace Flux { vec = ggml_concat(ctx, timestep_guidance, modulation_index, 0); // [N, 344, 64] // Permute for consistency with non-distilled modulation implementation vec = ggml_cont(ctx, ggml_permute(ctx, vec, 0, 2, 1, 3)); // [344, N, 64] - vec = approx->forward(ctx, vec); // [344, N, hidden_size] + vec = approx->forward(ctx, vec); // [344, N, hidden_size] } else { auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); From f506a6323b89170b3edada6678bbd963ae65a1f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Fri, 30 May 2025 19:40:09 +0200 Subject: [PATCH 06/17] Fix use_after_free (hopefully) --- flux.hpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/flux.hpp b/flux.hpp index 5b75be4db..9d5b1e4e8 100644 --- a/flux.hpp +++ b/flux.hpp @@ -977,7 +977,8 @@ namespace Flux { public: FluxParams flux_params; Flux flux; - std::vector pe_vec; // for cache + std::vector pe_vec, range; // for cache + SDVersion version; FluxRunner(ggml_backend_t backend, std::map& tensor_types = empty_tensor_types, @@ -1060,8 +1061,8 @@ namespace Flux { y = to_backend(y); } else { // ggml_arrange is not working on some backends, and y isn't used, so let's reuse y to precompute it - std::vector range = arange(0, 344); - y = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, range.size()); + range = arange(0, 344); + y = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, range.size()); set_backend_tensor_data(y, range.data()); } timesteps = to_backend(timesteps); From 836fd723264392e5b9cfc16490533e262113f5a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Sun, 1 Jun 2025 22:42:03 +0200 Subject: [PATCH 07/17] Chroma: Attention masking (no pad) --- conditioner.hpp | 182 ++++++++++++++++++++++++++++++++++++++++++- flux.hpp | 60 +++++++++----- stable-diffusion.cpp | 15 +++- t5.hpp | 47 +++++++---- 4 files changed, 266 insertions(+), 38 deletions(-) diff --git a/conditioner.hpp b/conditioner.hpp index 6e9acdb19..1838d1134 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -747,7 +747,7 @@ struct SD3CLIPEmbedder : public Conditioner { clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding); clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding); - t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding); + t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding); // for (int i = 0; i < clip_l_tokens.size(); i++) { // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", "; @@ -1077,7 +1077,7 @@ struct FluxCLIPEmbedder : public Conditioner { } clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding); - t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding); + t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding); // for (int i = 0; i < clip_l_tokens.size(); i++) { // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", "; @@ -1218,4 +1218,182 @@ struct FluxCLIPEmbedder : public Conditioner { } }; +struct PixArtCLIPEmbedder : public Conditioner { + T5UniGramTokenizer t5_tokenizer; + std::shared_ptr t5; + + PixArtCLIPEmbedder(ggml_backend_t backend, + std::map& tensor_types, + int clip_skip = -1) { + t5 = std::make_shared(backend, tensor_types, "text_encoders.t5xxl.transformer"); + } + + void set_clip_skip(int clip_skip) { + } + + void get_param_tensors(std::map& tensors) { + t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer"); + } + + void alloc_params_buffer() { + t5->alloc_params_buffer(); + } + + void free_params_buffer() { + t5->free_params_buffer(); + } + + size_t get_params_buffer_size() { + size_t buffer_size = 0; + + buffer_size += t5->get_params_buffer_size(); + + return buffer_size; + } + + std::tuple, std::vector, std::vector> tokenize(std::string text, + size_t max_length = 0, + bool padding = false) { + auto parsed_attention = parse_prompt_attention(text); + + { + std::stringstream ss; + ss << "["; + for (const auto& item : parsed_attention) { + ss << "['" << item.first << "', " << item.second << "], "; + } + ss << "]"; + LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); + } + + auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool { + return false; + }; + + std::vector t5_tokens; + std::vector t5_weights; + std::vector t5_mask; + for (const auto& item : parsed_attention) { + const std::string& curr_text = item.first; + float curr_weight = item.second; + + std::vector curr_tokens = t5_tokenizer.Encode(curr_text, true); + t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end()); + t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight); + } + + t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding); + + return {t5_tokens, t5_weights, t5_mask}; + } + + SDCondition get_learned_condition_common(ggml_context* work_ctx, + int n_threads, + std::tuple, std::vector, std::vector> token_and_weights, + int clip_skip, + bool force_zero_embeddings = false) { + auto& t5_tokens = std::get<0>(token_and_weights); + auto& t5_weights = std::get<1>(token_and_weights); + auto& t5_attn_mask_vec = std::get<2>(token_and_weights); + + int64_t t0 = ggml_time_ms(); + struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096] + struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096] + struct ggml_tensor* pooled = NULL; // [768,] + struct ggml_tensor* t5_attn_mask = vector_to_ggml_tensor(work_ctx, t5_attn_mask_vec); // [768,] + + std::vector hidden_states_vec; + + size_t chunk_len = 256; + size_t chunk_count = t5_tokens.size() / chunk_len; + for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) { + // t5 + std::vector chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len, + t5_tokens.begin() + (chunk_idx + 1) * chunk_len); + std::vector chunk_weights(t5_weights.begin() + chunk_idx * chunk_len, + t5_weights.begin() + (chunk_idx + 1) * chunk_len); + std::vector chunk_mask(t5_attn_mask_vec.begin() + chunk_idx * chunk_len, + t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len); + + auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); + auto t5_attn_mask_chunk = vector_to_ggml_tensor(work_ctx, chunk_mask); + + t5->compute(n_threads, + input_ids, + &chunk_hidden_states, + work_ctx, + t5_attn_mask_chunk); + { + auto tensor = chunk_hidden_states; + float original_mean = ggml_tensor_mean(tensor); + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float value = ggml_tensor_get_f32(tensor, i0, i1, i2); + value *= chunk_weights[i1]; + ggml_tensor_set_f32(tensor, value, i0, i1, i2); + } + } + } + float new_mean = ggml_tensor_mean(tensor); + ggml_tensor_scale(tensor, (original_mean / new_mean)); + } + + int64_t t1 = ggml_time_ms(); + LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); + if (force_zero_embeddings) { + float* vec = (float*)chunk_hidden_states->data; + for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) { + vec[i] = 0; + } + } + + hidden_states_vec.insert(hidden_states_vec.end(), + (float*)chunk_hidden_states->data, + ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states)); + } + + if (hidden_states_vec.size() > 0) { + hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec); + hidden_states = ggml_reshape_2d(work_ctx, + hidden_states, + chunk_hidden_states->ne[0], + ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]); + } else { + hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256); + ggml_set_f32(hidden_states, 0.f); + } + return SDCondition(hidden_states, t5_attn_mask, NULL); + } + + SDCondition get_learned_condition(ggml_context* work_ctx, + int n_threads, + const std::string& text, + int clip_skip, + int width, + int height, + int adm_in_channels = -1, + bool force_zero_embeddings = false) { + auto tokens_and_weights = tokenize(text, 512, true); + return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings); + } + + std::tuple> get_learned_condition_with_trigger(ggml_context* work_ctx, + int n_threads, + const std::string& text, + int clip_skip, + int width, + int height, + int num_input_imgs, + int adm_in_channels = -1, + bool force_zero_embeddings = false) { + GGML_ASSERT(0 && "Not implemented yet!"); + } + + std::string remove_trigger_from_prompt(ggml_context* work_ctx, + const std::string& prompt) { + GGML_ASSERT(0 && "Not implemented yet!"); + } +}; + #endif \ No newline at end of file diff --git a/flux.hpp b/flux.hpp index 9d5b1e4e8..198f8ebf3 100644 --- a/flux.hpp +++ b/flux.hpp @@ -117,6 +117,7 @@ namespace Flux { struct ggml_tensor* k, struct ggml_tensor* v, struct ggml_tensor* pe, + struct ggml_tensor* mask, bool flash_attn) { // q,k,v: [N, L, n_head, d_head] // pe: [L, d_head/2, 2, 2] @@ -124,7 +125,7 @@ namespace Flux { q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head] k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head] - auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true, flash_attn); // [N, L, n_head*d_head] + auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], mask, false, true, flash_attn); // [N, L, n_head*d_head] return x; } @@ -167,13 +168,13 @@ namespace Flux { return x; } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe) { + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe, struct ggml_tensor* mask) { // x: [N, n_token, dim] // pe: [n_token, d_head/2, 2, 2] // return [N, n_token, dim] - auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] - x = attention(ctx, qkv[0], qkv[1], qkv[2], pe, flash_attn); // [N, n_token, dim] - x = post_attention(ctx, x); // [N, n_token, dim] + auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] + x = attention(ctx, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim] + x = post_attention(ctx, x); // [N, n_token, dim] return x; } }; @@ -301,7 +302,8 @@ namespace Flux { struct ggml_tensor* img, struct ggml_tensor* txt, struct ggml_tensor* vec, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + struct ggml_tensor* mask = NULL) { // img: [N, n_img_token, hidden_size] // txt: [N, n_txt_token, hidden_size] // pe: [n_img_token + n_txt_token, d_head/2, 2, 2] @@ -360,7 +362,7 @@ namespace Flux { auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head] + auto attn = attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head] attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] auto txt_attn_out = ggml_view_3d(ctx, attn, @@ -446,7 +448,8 @@ namespace Flux { struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* vec, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + struct ggml_tensor* mask = NULL) { // x: [N, n_token, hidden_size] // pe: [n_token, d_head/2, 2, 2] // return: [N, n_token, hidden_size] @@ -493,7 +496,7 @@ namespace Flux { auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] q = norm->query_norm(ctx, q); k = norm->key_norm(ctx, k); - auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_token, hidden_size] + auto attn = attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size] auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim] auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size] @@ -707,6 +710,10 @@ namespace Flux { return ids; } + void chroma_modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) { + // TODO: implement + } + // Generate positional embeddings std::vector gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector& axes_dim) { std::vector> ids = gen_ids(h, w, patch_size, bs, context_len); @@ -835,6 +842,7 @@ namespace Flux { struct ggml_tensor* y, struct ggml_tensor* guidance, struct ggml_tensor* pe, + struct ggml_tensor* arange = NULL, std::vector skip_layers = std::vector()) { auto img_in = std::dynamic_pointer_cast(blocks["img_in"]); auto txt_in = std::dynamic_pointer_cast(blocks["txt_in"]); @@ -842,15 +850,16 @@ namespace Flux { img = img_in->forward(ctx, img); struct ggml_tensor* vec; + struct ggml_tensor* txt_img_mask = NULL; if (params.is_chroma) { int64_t mod_index_length = 344; auto approx = std::dynamic_pointer_cast(blocks["distilled_guidance_layer"]); auto distill_timestep = ggml_nn_timestep_embedding(ctx, timesteps, 16, 10000, 1000.f); auto distill_guidance = ggml_nn_timestep_embedding(ctx, guidance, 16, 10000, 1000.f); - // auto arrange = ggml_arange(ctx, 0, (float)mod_index_length, 1); // Not working on a lot of backends - auto arrange = y; - auto modulation_index = ggml_nn_timestep_embedding(ctx, arrange, 32, 10000, 1000.f); // [1, 344, 32] + // auto arange = ggml_arange(ctx, 0, (float)mod_index_length, 1); // Not working on a lot of backends, precomputing it on CPU instead + GGML_ASSERT(arange != NULL); + auto modulation_index = ggml_nn_timestep_embedding(ctx, arange, 32, 10000, 1000.f); // [1, 344, 32] // Batch broadcast (will it ever be useful) modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2])); // [N, 344, 32] @@ -863,6 +872,9 @@ namespace Flux { vec = ggml_cont(ctx, ggml_permute(ctx, vec, 0, 2, 1, 3)); // [344, N, 64] vec = approx->forward(ctx, vec); // [344, N, hidden_size] + if (y != NULL) { + txt_img_mask = ggml_concat(ctx, y, ggml_scale_inplace(ctx, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, img->ne[1]), 0), 0); + } } else { auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); auto vector_in = std::dynamic_pointer_cast(blocks["vector_in"]); @@ -887,7 +899,7 @@ namespace Flux { auto block = std::dynamic_pointer_cast(blocks["double_blocks." + std::to_string(i)]); - auto img_txt = block->forward(ctx, img, txt, vec, pe); + auto img_txt = block->forward(ctx, img, txt, vec, pe, txt_img_mask); img = img_txt.first; // [N, n_img_token, hidden_size] txt = img_txt.second; // [N, n_txt_token, hidden_size] } @@ -899,7 +911,7 @@ namespace Flux { } auto block = std::dynamic_pointer_cast(blocks["single_blocks." + std::to_string(i)]); - txt_img = block->forward(ctx, txt_img, vec, pe); + txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask); } txt_img = ggml_cont(ctx, ggml_permute(ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] @@ -925,6 +937,7 @@ namespace Flux { struct ggml_tensor* y, struct ggml_tensor* guidance, struct ggml_tensor* pe, + struct ggml_tensor* arange = NULL, std::vector skip_layers = std::vector()) { // Forward pass of DiT. // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) @@ -962,7 +975,7 @@ namespace Flux { img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0); } - auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size] + auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, arange, skip_layers); // [N, h*w, C * patch_size * patch_size] // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w] @@ -1052,19 +1065,23 @@ namespace Flux { GGML_ASSERT(x->ne[3] == 1); struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false); + struct ggml_tensor* precompute_arange = NULL; + x = to_backend(x); context = to_backend(context); if (c_concat != NULL) { c_concat = to_backend(c_concat); } - if (!flux_params.is_chroma) { - y = to_backend(y); - } else { + if (flux_params.is_chroma) { + flux.chroma_modify_mask_to_attend_padding(y, context->ne[1], 1); // ggml_arrange is not working on some backends, and y isn't used, so let's reuse y to precompute it - range = arange(0, 344); - y = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, range.size()); - set_backend_tensor_data(y, range.data()); + range = arange(0, 344); + precompute_arange = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, range.size()); + set_backend_tensor_data(precompute_arange, range.data()); + // y = NULL; } + y = to_backend(y); + timesteps = to_backend(timesteps); if (flux_params.guidance_embed || flux_params.is_chroma) { guidance = to_backend(guidance); @@ -1087,6 +1104,7 @@ namespace Flux { y, guidance, pe, + precompute_arange, skip_layers); ggml_build_forward_expand(gf, out); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index e38a6101f..a593284af 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -334,8 +334,19 @@ class StableDiffusionGGML { cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types); } else if (sd_version_is_flux(version)) { - cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); - diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn); + bool is_chroma = false; + for (auto pair : model_loader.tensor_storages_types) { + if (pair.first.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { + is_chroma = true; + break; + } + } + if (is_chroma) { + cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); + } else { + cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); + } + diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn); } else { if (id_embeddings_path.find("v2") != std::string::npos) { cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2); diff --git a/t5.hpp b/t5.hpp index 2a53e2743..ae1577d1f 100644 --- a/t5.hpp +++ b/t5.hpp @@ -385,6 +385,7 @@ class T5UniGramTokenizer { void pad_tokens(std::vector& tokens, std::vector& weights, + std::vector* attention_mask, size_t max_length = 0, bool padding = false) { if (max_length > 0 && padding) { @@ -397,11 +398,15 @@ class T5UniGramTokenizer { LOG_DEBUG("token length: %llu", length); std::vector new_tokens; std::vector new_weights; + std::vector new_attention_mask; int token_idx = 0; for (int i = 0; i < length; i++) { if (token_idx >= orig_token_num) { break; } + if (attention_mask != nullptr) { + new_attention_mask.push_back(0.0); + } if (i % max_length == max_length - 1) { new_tokens.push_back(eos_id_); new_weights.push_back(1.0); @@ -414,13 +419,23 @@ class T5UniGramTokenizer { new_tokens.push_back(eos_id_); new_weights.push_back(1.0); + if (attention_mask != nullptr) { + new_attention_mask.push_back(0.0); + } + tokens = new_tokens; weights = new_weights; + if (attention_mask != nullptr) { + *attention_mask = new_attention_mask; + } if (padding) { int pad_token_id = pad_id_; tokens.insert(tokens.end(), length - tokens.size(), pad_token_id); weights.insert(weights.end(), length - weights.size(), 1.0); + if (attention_mask != nullptr) { + attention_mask->insert(attention_mask->end(), length - attention_mask->size(), -HUGE_VALF); + } } } } @@ -579,6 +594,7 @@ class T5Attention : public GGMLBlock { } if (past_bias != NULL) { if (mask != NULL) { + mask = ggml_repeat(ctx,mask,past_bias); mask = ggml_add(ctx, mask, past_bias); } else { mask = past_bias; @@ -739,15 +755,17 @@ struct T5Runner : public GGMLRunner { struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* input_ids, - struct ggml_tensor* relative_position_bucket) { + struct ggml_tensor* relative_position_bucket, + struct ggml_tensor* attention_mask = NULL) { size_t N = input_ids->ne[1]; size_t n_token = input_ids->ne[0]; - auto hidden_states = model.forward(ctx, input_ids, NULL, NULL, relative_position_bucket); // [N, n_token, model_dim] + auto hidden_states = model.forward(ctx, input_ids, NULL, attention_mask, relative_position_bucket); // [N, n_token, model_dim] return hidden_states; } - struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids) { + struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids, + struct ggml_tensor* attention_mask = NULL) { struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); input_ids = to_backend(input_ids); @@ -767,7 +785,7 @@ struct T5Runner : public GGMLRunner { input_ids->ne[0]); set_backend_tensor_data(relative_position_bucket, relative_position_bucket_vec.data()); - struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, relative_position_bucket); + struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, relative_position_bucket, attention_mask); ggml_build_forward_expand(gf, hidden_states); @@ -777,9 +795,10 @@ struct T5Runner : public GGMLRunner { void compute(const int n_threads, struct ggml_tensor* input_ids, ggml_tensor** output, - ggml_context* output_ctx = NULL) { + ggml_context* output_ctx = NULL, + struct ggml_tensor* attention_mask = NULL) { auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(input_ids); + return build_graph(input_ids, attention_mask); }; GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); } @@ -877,9 +896,9 @@ struct T5Embedder { model.alloc_params_buffer(); } - std::pair, std::vector> tokenize(std::string text, - size_t max_length = 0, - bool padding = false) { + std::tuple, std::vector, std::vector> tokenize(std::string text, + size_t max_length = 0, + bool padding = false) { auto parsed_attention = parse_prompt_attention(text); { @@ -906,14 +925,16 @@ struct T5Embedder { tokens.push_back(EOS_TOKEN_ID); weights.push_back(1.0); - tokenizer.pad_tokens(tokens, weights, max_length, padding); + std::vector attention_mask; + + tokenizer.pad_tokens(tokens, weights, &attention_mask, max_length, padding); // for (int i = 0; i < tokens.size(); i++) { // std::cout << tokens[i] << ":" << weights[i] << ", "; // } // std::cout << std::endl; - return {tokens, weights}; + return {tokens, weights, attention_mask}; } void test() { @@ -934,8 +955,8 @@ struct T5Embedder { // TODO: fix cuda nan std::string text("a lovely cat"); auto tokens_and_weights = tokenize(text, 77, true); - std::vector& tokens = tokens_and_weights.first; - std::vector& weights = tokens_and_weights.second; + std::vector& tokens = std::get<0>(tokens_and_weights); + std::vector& weights = std::get<1>(tokens_and_weights); for (auto token : tokens) { printf("%d ", token); } From 67cc9966161a213f59909725fcc894e76a4ca9ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Mon, 2 Jun 2025 11:40:27 +0200 Subject: [PATCH 08/17] implement chroma mask padding --- flux.hpp | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/flux.hpp b/flux.hpp index 198f8ebf3..2fe53651b 100644 --- a/flux.hpp +++ b/flux.hpp @@ -711,7 +711,18 @@ namespace Flux { } void chroma_modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) { - // TODO: implement + float* mask_data = (float*)mask->data; + int num_pad = 0; + for (int64_t i = 0; i < max_seq_length; i++) { + if (num_pad >= num_extra_padding) { + break; + } + if (isinf(mask_data[i])) { + mask_data[i] = 0; + ++num_pad; + } + } + // LOG_DEBUG("PAD: %d", num_pad); } // Generate positional embeddings @@ -1073,7 +1084,7 @@ namespace Flux { c_concat = to_backend(c_concat); } if (flux_params.is_chroma) { - flux.chroma_modify_mask_to_attend_padding(y, context->ne[1], 1); + flux.chroma_modify_mask_to_attend_padding(y, ggml_nelements(y), 1); // ggml_arrange is not working on some backends, and y isn't used, so let's reuse y to precompute it range = arange(0, 344); precompute_arange = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, range.size()); From 55a268613bef006228b966d76b4c19f389599a13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Mon, 2 Jun 2025 12:47:31 +0200 Subject: [PATCH 09/17] Use env variable to control chroma padding settings --- conditioner.hpp | 10 ++++++++++ flux.hpp | 25 ++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/conditioner.hpp b/conditioner.hpp index 1838d1134..04d30d42b 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1318,6 +1318,16 @@ struct PixArtCLIPEmbedder : public Conditioner { auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); auto t5_attn_mask_chunk = vector_to_ggml_tensor(work_ctx, chunk_mask); + const char* SD_CHROMA_USE_T5_MASK = getenv("SD_CHROMA_USE_T5_MASK"); + if (SD_CHROMA_USE_T5_MASK != nullptr) { + std::string sd_chroma_use_t5_mask_str = SD_CHROMA_USE_T5_MASK; + if (sd_chroma_use_t5_mask_str == "OFF" || sd_chroma_use_t5_mask_str == "FALSE") { + t5_attn_mask_chunk = NULL; + } else if (sd_chroma_use_t5_mask_str != "ON" && sd_chroma_use_t5_mask_str != "TRUE") { + LOG_WARN("SD_CHROMA_USE_T5_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_T5_MASK); + } + } + t5->compute(n_threads, input_ids, &chunk_hidden_states, diff --git a/flux.hpp b/flux.hpp index 2fe53651b..c774d2113 100644 --- a/flux.hpp +++ b/flux.hpp @@ -1084,7 +1084,30 @@ namespace Flux { c_concat = to_backend(c_concat); } if (flux_params.is_chroma) { - flux.chroma_modify_mask_to_attend_padding(y, ggml_nelements(y), 1); + int mask_pad = 1; + const char* SD_CHROMA_MASK_PAD_OVERRIDE = getenv("SD_CHROMA_MASK_PAD_OVERRIDE"); + if (SD_CHROMA_MASK_PAD_OVERRIDE != nullptr) { + std::string mask_pad_str = SD_CHROMA_MASK_PAD_OVERRIDE; + try { + mask_pad = std::stoi(mask_pad_str); + } catch (const std::invalid_argument&) { + LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable is not a valid integer (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad); + } catch (const std::out_of_range&) { + LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable value is out of range for `int` type (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad); + } + } + flux.chroma_modify_mask_to_attend_padding(y, ggml_nelements(y), mask_pad); + + const char* SD_CHROMA_USE_DIT_MASK = getenv("SD_CHROMA_USE_DIT_MASK"); + if (SD_CHROMA_USE_DIT_MASK != nullptr) { + std::string sd_chroma_use_DiT_mask_str = SD_CHROMA_USE_DIT_MASK; + if (sd_chroma_use_DiT_mask_str == "OFF" || sd_chroma_use_DiT_mask_str == "FALSE") { + y = NULL; + } else if (sd_chroma_use_DiT_mask_str != "ON" && sd_chroma_use_DiT_mask_str != "TRUE") { + LOG_WARN("SD_CHROMA_USE_DIT_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_DIT_MASK); + } + } + // ggml_arrange is not working on some backends, and y isn't used, so let's reuse y to precompute it range = arange(0, 344); precompute_arange = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, range.size()); From 5a96b09ec6655eac0791246648930fa5a39a2bea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Mon, 2 Jun 2025 15:21:02 +0200 Subject: [PATCH 10/17] std::isinf --- flux.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux.hpp b/flux.hpp index c774d2113..876636673 100644 --- a/flux.hpp +++ b/flux.hpp @@ -717,7 +717,7 @@ namespace Flux { if (num_pad >= num_extra_padding) { break; } - if (isinf(mask_data[i])) { + if (std::isinf(mask_data[i])) { mask_data[i] = 0; ++num_pad; } From 2bdc8470f7b0f0e8138269c0ce33d323b4b0b5c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Mon, 2 Jun 2025 17:04:02 +0200 Subject: [PATCH 11/17] Use ggml_pad instead of concat with empty tensor --- flux.hpp | 2 +- t5.hpp | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/flux.hpp b/flux.hpp index 876636673..3e32d913c 100644 --- a/flux.hpp +++ b/flux.hpp @@ -884,7 +884,7 @@ namespace Flux { vec = approx->forward(ctx, vec); // [344, N, hidden_size] if (y != NULL) { - txt_img_mask = ggml_concat(ctx, y, ggml_scale_inplace(ctx, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, img->ne[1]), 0), 0); + txt_img_mask = ggml_pad(ctx, y, img->ne[1], 0, 0, 0); } } else { auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); diff --git a/t5.hpp b/t5.hpp index ae1577d1f..7edaf8041 100644 --- a/t5.hpp +++ b/t5.hpp @@ -434,6 +434,7 @@ class T5UniGramTokenizer { tokens.insert(tokens.end(), length - tokens.size(), pad_token_id); weights.insert(weights.end(), length - weights.size(), 1.0); if (attention_mask != nullptr) { + // maybe keep some padding tokens unmasked? attention_mask->insert(attention_mask->end(), length - attention_mask->size(), -HUGE_VALF); } } @@ -594,7 +595,7 @@ class T5Attention : public GGMLBlock { } if (past_bias != NULL) { if (mask != NULL) { - mask = ggml_repeat(ctx,mask,past_bias); + mask = ggml_repeat(ctx, mask, past_bias); mask = ggml_add(ctx, mask, past_bias); } else { mask = past_bias; From 322fad2df2f630215b6d338cadcc5618110ebade Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Mon, 2 Jun 2025 17:29:47 +0200 Subject: [PATCH 12/17] Optimise masked attention --- ggml_extend.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml_extend.hpp b/ggml_extend.hpp index c5913be4d..94c6a128e 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -876,7 +876,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k] kq = ggml_scale_inplace(ctx, kq, scale); if (mask) { - kq = ggml_add(ctx, kq, mask); + kq = ggml_add_inplace(ctx, kq, mask); } if (diag_mask_inf) { kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); From 42e217d62f3e0276828d9450c1bcb57ec3385ef8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Mon, 2 Jun 2025 20:26:33 +0200 Subject: [PATCH 13/17] Chroma: disable guidance by default --- flux.hpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/flux.hpp b/flux.hpp index 3e32d913c..e00fe58d9 100644 --- a/flux.hpp +++ b/flux.hpp @@ -1084,6 +1084,22 @@ namespace Flux { c_concat = to_backend(c_concat); } if (flux_params.is_chroma) { + const char* SD_CHROMA_ENABLE_GUIDANCE = getenv("SD_CHROMA_ENABLE_GUIDANCE"); + bool disable_guidance = true; + if (SD_CHROMA_ENABLE_GUIDANCE != NULL) { + std::string enable_guidance_str = SD_CHROMA_ENABLE_GUIDANCE; + if (enable_guidance_str == "ON" || enable_guidance_str == "TRUE") { + LOG_WARN("Chroma guidance has been enabled. Image might be broken. (SD_CHROMA_ENABLE_GUIDANCE env variable to \"OFF\" to disable)", SD_CHROMA_ENABLE_GUIDANCE); + disable_guidance = false; + } else if (enable_guidance_str != "OFF" && enable_guidance_str != "FALSE") { + LOG_WARN("SD_CHROMA_ENABLE_GUIDANCE environment variable has unexpected value. Assuming default (\"OFF\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_ENABLE_GUIDANCE); + } + } + if (disable_guidance) { + LOG_DEBUG("Forcing guidance to 0 for chroma model (SD_CHROMA_ENABLE_GUIDANCE env variable to \"ON\" to enable)"); + guidance = ggml_set_f32(guidance, 0); + } + int mask_pad = 1; const char* SD_CHROMA_MASK_PAD_OVERRIDE = getenv("SD_CHROMA_MASK_PAD_OVERRIDE"); if (SD_CHROMA_MASK_PAD_OVERRIDE != nullptr) { From 4fdedd5ac59bf0d5e4d197077eed53929fbe35e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Mon, 2 Jun 2025 23:45:49 +0200 Subject: [PATCH 14/17] Chroma: Fix t5 chunk length --- conditioner.hpp | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/conditioner.hpp b/conditioner.hpp index 04d30d42b..8f67e56b1 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1004,6 +1004,7 @@ struct FluxCLIPEmbedder : public Conditioner { T5UniGramTokenizer t5_tokenizer; std::shared_ptr clip_l; std::shared_ptr t5; + size_t chunk_len = 256; FluxCLIPEmbedder(ggml_backend_t backend, std::map& tensor_types, @@ -1109,7 +1110,6 @@ struct FluxCLIPEmbedder : public Conditioner { struct ggml_tensor* pooled = NULL; // [768,] std::vector hidden_states_vec; - size_t chunk_len = 256; size_t chunk_count = t5_tokens.size() / chunk_len; for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) { // clip_l @@ -1196,7 +1196,7 @@ struct FluxCLIPEmbedder : public Conditioner { int height, int adm_in_channels = -1, bool force_zero_embeddings = false) { - auto tokens_and_weights = tokenize(text, 256, true); + auto tokens_and_weights = tokenize(text, chunk_len, true); return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings); } @@ -1221,6 +1221,7 @@ struct FluxCLIPEmbedder : public Conditioner { struct PixArtCLIPEmbedder : public Conditioner { T5UniGramTokenizer t5_tokenizer; std::shared_ptr t5; + size_t chunk_len = 512; PixArtCLIPEmbedder(ggml_backend_t backend, std::map& tensor_types, @@ -1304,8 +1305,18 @@ struct PixArtCLIPEmbedder : public Conditioner { std::vector hidden_states_vec; - size_t chunk_len = 256; size_t chunk_count = t5_tokens.size() / chunk_len; + + bool use_mask = true; + const char* SD_CHROMA_USE_T5_MASK = getenv("SD_CHROMA_USE_T5_MASK"); + if (SD_CHROMA_USE_T5_MASK != nullptr) { + std::string sd_chroma_use_t5_mask_str = SD_CHROMA_USE_T5_MASK; + if (sd_chroma_use_t5_mask_str == "OFF" || sd_chroma_use_t5_mask_str == "FALSE") { + use_mask = false; + } else if (sd_chroma_use_t5_mask_str != "ON" && sd_chroma_use_t5_mask_str != "TRUE") { + LOG_WARN("SD_CHROMA_USE_T5_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_T5_MASK); + } + } for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) { // t5 std::vector chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len, @@ -1316,17 +1327,7 @@ struct PixArtCLIPEmbedder : public Conditioner { t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len); auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); - auto t5_attn_mask_chunk = vector_to_ggml_tensor(work_ctx, chunk_mask); - - const char* SD_CHROMA_USE_T5_MASK = getenv("SD_CHROMA_USE_T5_MASK"); - if (SD_CHROMA_USE_T5_MASK != nullptr) { - std::string sd_chroma_use_t5_mask_str = SD_CHROMA_USE_T5_MASK; - if (sd_chroma_use_t5_mask_str == "OFF" || sd_chroma_use_t5_mask_str == "FALSE") { - t5_attn_mask_chunk = NULL; - } else if (sd_chroma_use_t5_mask_str != "ON" && sd_chroma_use_t5_mask_str != "TRUE") { - LOG_WARN("SD_CHROMA_USE_T5_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_T5_MASK); - } - } + auto t5_attn_mask_chunk = use_mask ? vector_to_ggml_tensor(work_ctx, chunk_mask) : NULL; t5->compute(n_threads, input_ids, @@ -1384,7 +1385,7 @@ struct PixArtCLIPEmbedder : public Conditioner { int height, int adm_in_channels = -1, bool force_zero_embeddings = false) { - auto tokens_and_weights = tokenize(text, 512, true); + auto tokens_and_weights = tokenize(text, chunk_len, true); return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings); } From 3238fe3fab5502a4799cb06d6934f6a9b6aaf88b Mon Sep 17 00:00:00 2001 From: Green Sky Date: Tue, 3 Jun 2025 15:41:56 +0200 Subject: [PATCH 15/17] fix mask with flash attn --- ggml_extend.hpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 94c6a128e..101a5d1f6 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -864,6 +864,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] v = ggml_cast(ctx, v, GGML_TYPE_F16); + if (mask != nullptr) { + mask = ggml_transpose(ctx, mask); + + if (mask->ne[1] < GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)) { + LOG_DEBUG("mask dims %ld, %ld, %ld, %ld\n", mask->ne[0], mask->ne[1], mask->ne[2], mask->ne[3]); + LOG_DEBUG("needs padding, padding from %ld to %ld\n", mask->ne[1], GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)); + mask = ggml_pad(ctx, mask, 0, GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) - mask->ne[1], 0, 0); + } + + mask = ggml_cast(ctx, mask, GGML_TYPE_F16); + } + kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0); ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32); From efaa137bc0fb6ac098ced5600dfe2b17e754eb4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Tue, 3 Jun 2025 23:09:49 +0200 Subject: [PATCH 16/17] Remove deprecated Todos --- flux.hpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/flux.hpp b/flux.hpp index e00fe58d9..db27083d7 100644 --- a/flux.hpp +++ b/flux.hpp @@ -297,7 +297,6 @@ namespace Flux { return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)}; } - // TODO: chroma (prune_mod) -> get modulations from offset vec std::pair forward(struct ggml_context* ctx, struct ggml_tensor* img, struct ggml_tensor* txt, @@ -795,7 +794,6 @@ namespace Flux { params.is_chroma)); } - // TODO: no modulation for chroma blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, params.out_channels, params.is_chroma)); } From e22c57cc1ee6b5aef01a936ade953e9b202aac20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Wed, 4 Jun 2025 02:31:54 +0200 Subject: [PATCH 17/17] Only include pad into mask once --- conditioner.hpp | 30 ++++++++++++++++++++++++++++++ flux.hpp | 27 --------------------------- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/conditioner.hpp b/conditioner.hpp index 8f67e56b1..f48f4f493 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1288,6 +1288,21 @@ struct PixArtCLIPEmbedder : public Conditioner { return {t5_tokens, t5_weights, t5_mask}; } + void modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) { + float* mask_data = (float*)mask->data; + int num_pad = 0; + for (int64_t i = 0; i < max_seq_length; i++) { + if (num_pad >= num_extra_padding) { + break; + } + if (std::isinf(mask_data[i])) { + mask_data[i] = 0; + ++num_pad; + } + } + // LOG_DEBUG("PAD: %d", num_pad); + } + SDCondition get_learned_condition_common(ggml_context* work_ctx, int n_threads, std::tuple, std::vector, std::vector> token_and_weights, @@ -1374,6 +1389,21 @@ struct PixArtCLIPEmbedder : public Conditioner { hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256); ggml_set_f32(hidden_states, 0.f); } + + int mask_pad = 1; + const char* SD_CHROMA_MASK_PAD_OVERRIDE = getenv("SD_CHROMA_MASK_PAD_OVERRIDE"); + if (SD_CHROMA_MASK_PAD_OVERRIDE != nullptr) { + std::string mask_pad_str = SD_CHROMA_MASK_PAD_OVERRIDE; + try { + mask_pad = std::stoi(mask_pad_str); + } catch (const std::invalid_argument&) { + LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable is not a valid integer (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad); + } catch (const std::out_of_range&) { + LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable value is out of range for `int` type (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad); + } + } + modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad); + return SDCondition(hidden_states, t5_attn_mask, NULL); } diff --git a/flux.hpp b/flux.hpp index db27083d7..a33d9eb02 100644 --- a/flux.hpp +++ b/flux.hpp @@ -709,20 +709,6 @@ namespace Flux { return ids; } - void chroma_modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) { - float* mask_data = (float*)mask->data; - int num_pad = 0; - for (int64_t i = 0; i < max_seq_length; i++) { - if (num_pad >= num_extra_padding) { - break; - } - if (std::isinf(mask_data[i])) { - mask_data[i] = 0; - ++num_pad; - } - } - // LOG_DEBUG("PAD: %d", num_pad); - } // Generate positional embeddings std::vector gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector& axes_dim) { @@ -1098,19 +1084,6 @@ namespace Flux { guidance = ggml_set_f32(guidance, 0); } - int mask_pad = 1; - const char* SD_CHROMA_MASK_PAD_OVERRIDE = getenv("SD_CHROMA_MASK_PAD_OVERRIDE"); - if (SD_CHROMA_MASK_PAD_OVERRIDE != nullptr) { - std::string mask_pad_str = SD_CHROMA_MASK_PAD_OVERRIDE; - try { - mask_pad = std::stoi(mask_pad_str); - } catch (const std::invalid_argument&) { - LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable is not a valid integer (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad); - } catch (const std::out_of_range&) { - LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable value is out of range for `int` type (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad); - } - } - flux.chroma_modify_mask_to_attend_padding(y, ggml_nelements(y), mask_pad); const char* SD_CHROMA_USE_DIT_MASK = getenv("SD_CHROMA_USE_DIT_MASK"); if (SD_CHROMA_USE_DIT_MASK != nullptr) {