diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 2530f714..26c619b0 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -17,7 +17,8 @@ struct DiffusionModel { std::vector controls = {}, float control_strength = 0.f, struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) = 0; + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) = 0; virtual void alloc_params_buffer() = 0; virtual void free_params_buffer() = 0; virtual void free_compute_buffer() = 0; @@ -70,7 +71,9 @@ struct UNetModel : public DiffusionModel { std::vector controls = {}, float control_strength = 0.f, struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { + (void)skip_layers; // SLG doesn't work with UNet models return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx); } }; @@ -119,8 +122,9 @@ struct MMDiTModel : public DiffusionModel { std::vector controls = {}, float control_strength = 0.f, struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { - return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx); + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { + return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers); } }; @@ -168,8 +172,9 @@ struct FluxModel : public DiffusionModel { std::vector controls = {}, float control_strength = 0.f, struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { - return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx); + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { + return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers); } }; diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index f1bdc698..1f33547e 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -119,6 +119,11 @@ struct SDParams { bool canny_preprocess = false; bool color = false; int upscale_repeats = 1; + + std::vector skip_layers = {7, 8, 9}; + float slg_scale = 0.; + float skip_layer_start = 0.01; + float skip_layer_end = 0.2; }; void print_params(SDParams params) { @@ -151,6 +156,7 @@ void print_params(SDParams params) { printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); printf(" min_cfg: %.2f\n", params.min_cfg); printf(" cfg_scale: %.2f\n", params.cfg_scale); + printf(" slg_scale: %.2f\n", params.slg_scale); printf(" guidance: %.2f\n", params.guidance); printf(" clip_skip: %d\n", params.clip_skip); printf(" width: %d\n", params.width); @@ -197,6 +203,12 @@ void print_usage(int argc, const char* argv[]) { printf(" -p, --prompt [PROMPT] the prompt to render\n"); printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n"); printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n"); + printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n"); + printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n"); + printf(" --skip_layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n"); + printf(" --skip_layer_start START SLG enabling point: (default: 0.01)\n"); + printf(" --skip_layer_end END SLG disabling point: (default: 0.2)\n"); + printf(" SLG will be enabled at step int([STEPS]*[START]) and 
disabled at int([STEPS]*[END])\n"); printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n"); printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%%)\n"); printf(" --control-strength STRENGTH strength to apply Control Net (default: 0.9)\n"); @@ -534,6 +546,61 @@ void parse_args(int argc, const char** argv, SDParams& params) { params.verbose = true; } else if (arg == "--color") { params.color = true; + } else if (arg == "--slg-scale") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.slg_scale = std::stof(argv[i]); + } else if (arg == "--skip-layers") { + if (++i >= argc) { + invalid_arg = true; + break; + } + if (argv[i][0] != '[') { + invalid_arg = true; + break; + } + std::string layers_str = argv[i]; + while (layers_str.back() != ']') { + if (++i >= argc) { + invalid_arg = true; + break; + } + layers_str += " " + std::string(argv[i]); + } + layers_str = layers_str.substr(1, layers_str.size() - 2); + + std::regex regex("[, ]+"); + std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1); + std::sregex_token_iterator end; + std::vector tokens(iter, end); + std::vector layers; + for (const auto& token : tokens) { + try { + layers.push_back(std::stoi(token)); + } catch (const std::invalid_argument& e) { + invalid_arg = true; + break; + } + } + params.skip_layers = layers; + + if (invalid_arg) { + break; + } + } else if (arg == "--skip-layer-start") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.skip_layer_start = std::stof(argv[i]); + } else if (arg == "--skip-layer-end") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.skip_layer_end = std::stof(argv[i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); print_usage(argc, argv); @@ -624,6 +691,16 @@ std::string get_image_params(SDParams params, int64_t seed) { } parameter_string += "Steps: " + std::to_string(params.sample_steps) + ", "; parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", "; + if (params.slg_scale != 0 && params.skip_layers.size() != 0) { + parameter_string += "SLG scale: " + std::to_string(params.cfg_scale) + ", "; + parameter_string += "Skip layers: ["; + for (const auto& layer : params.skip_layers) { + parameter_string += std::to_string(layer) + ", "; + } + parameter_string += "], "; + parameter_string += "Skip layer start: " + std::to_string(params.skip_layer_start) + ", "; + parameter_string += "Skip layer end: " + std::to_string(params.skip_layer_end) + ", "; + } parameter_string += "Guidance: " + std::to_string(params.guidance) + ", "; parameter_string += "Seed: " + std::to_string(seed) + ", "; parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", "; @@ -840,7 +917,11 @@ int main(int argc, const char* argv[]) { params.control_strength, params.style_ratio, params.normalize_input, - params.input_id_images_path.c_str()); + params.input_id_images_path.c_str(), + params.skip_layers, + params.slg_scale, + params.skip_layer_start, + params.skip_layer_end); } else { sd_image_t input_image = {(uint32_t)params.width, (uint32_t)params.height, diff --git a/flux.hpp b/flux.hpp index 73bc345a..89bf7843 100644 --- a/flux.hpp +++ b/flux.hpp @@ -711,7 +711,8 @@ namespace Flux { struct ggml_tensor* timesteps, struct ggml_tensor* y, struct ggml_tensor* guidance, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + std::vector skip_layers = std::vector()) { auto img_in = 
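Note on the --skip-layers syntax parsed above: the flag takes a bracketed list, and because a value like "[7, 8, 9]" may be split by the shell into several argv entries, the parser keeps re-joining arguments until it sees ']' before splitting on commas and spaces; quoting the value, or writing it without spaces as --skip-layers [7,8,9], avoids the re-join entirely. Below is a standalone sketch of just the tokenizing step, assuming the input has already been re-joined; the helper name parse_skip_layers is illustrative and not part of this patch.

#include <regex>
#include <string>
#include <vector>

// Sketch of the --skip-layers tokenization above, applied to an already
// re-joined string such as "[7,8,9]" or "[7, 8, 9]".
static std::vector<int> parse_skip_layers(std::string layers_str) {
    layers_str = layers_str.substr(1, layers_str.size() - 2);  // drop '[' and ']'
    std::regex sep("[, ]+");
    std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), sep, -1);
    std::sregex_token_iterator end;
    std::vector<int> layers;
    for (; iter != end; ++iter) {
        layers.push_back(std::stoi(*iter));  // std::invalid_argument on a bad token
    }
    return layers;
}

// parse_skip_layers("[7,8,9]") and parse_skip_layers("[7, 8, 9]") both yield {7, 8, 9}.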
std::dynamic_pointer_cast(blocks["img_in"]); auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); auto vector_in = std::dynamic_pointer_cast(blocks["vector_in"]); @@ -733,6 +734,10 @@ namespace Flux { txt = txt_in->forward(ctx, txt); for (int i = 0; i < params.depth; i++) { + if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) { + continue; + } + auto block = std::dynamic_pointer_cast(blocks["double_blocks." + std::to_string(i)]); auto img_txt = block->forward(ctx, img, txt, vec, pe); @@ -742,6 +747,9 @@ namespace Flux { auto txt_img = ggml_concat(ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size] for (int i = 0; i < params.depth_single_blocks; i++) { + if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i + params.depth) != skip_layers.end()) { + continue; + } auto block = std::dynamic_pointer_cast(blocks["single_blocks." + std::to_string(i)]); txt_img = block->forward(ctx, txt_img, vec, pe); @@ -769,7 +777,8 @@ namespace Flux { struct ggml_tensor* context, struct ggml_tensor* y, struct ggml_tensor* guidance, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + std::vector skip_layers = std::vector()) { // Forward pass of DiT. // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) // timestep: (N,) tensor of diffusion timesteps @@ -791,7 +800,7 @@ namespace Flux { // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size] - auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe); // [N, h*w, C * patch_size * patch_size] + auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size] // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w] @@ -829,7 +838,8 @@ namespace Flux { struct ggml_tensor* timesteps, struct ggml_tensor* context, struct ggml_tensor* y, - struct ggml_tensor* guidance) { + struct ggml_tensor* guidance, + std::vector skip_layers = std::vector()) { GGML_ASSERT(x->ne[3] == 1); struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false); @@ -856,7 +866,8 @@ namespace Flux { context, y, guidance, - pe); + pe, + skip_layers); ggml_build_forward_expand(gf, out); @@ -870,14 +881,15 @@ namespace Flux { struct ggml_tensor* y, struct ggml_tensor* guidance, struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { // x: [N, in_channels, h, w] // timesteps: [N, ] // context: [N, max_position, hidden_size] // y: [N, adm_in_channels] or [1, adm_in_channels] // guidance: [N, ] auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context, y, guidance); + return build_graph(x, timesteps, context, y, guidance, skip_layers); }; GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); diff --git a/mmdit.hpp b/mmdit.hpp index 3a278dac..35810bad 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -252,6 +252,7 @@ struct DismantledBlock : public GGMLBlock { public: int64_t num_heads; bool pre_only; + bool self_attn; public: DismantledBlock(int64_t hidden_size, @@ -259,14 +260,19 @@ struct DismantledBlock : public GGMLBlock { float mlp_ratio = 4.0, 
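Note on the Flux skip indices checked above: one flat list addresses both block groups, with 0..depth-1 selecting double_blocks and depth..depth+depth_single_blocks-1 selecting single_blocks (for flux-dev, depth is 19 and depth_single_blocks is 38, so index 19 is the first single block, and the CLI default of [7,8,9] only ever skips double blocks). The helper below is only an illustration of that mapping, not part of the patch.

#include <utility>
#include <vector>

// Split a flat skip list into per-group block indices, mirroring the
// i / (i + params.depth) membership tests in forward_orig above.
static std::pair<std::vector<int>, std::vector<int>>
split_flux_skips(const std::vector<int>& skip_layers, int depth, int depth_single_blocks) {
    std::vector<int> double_skips, single_skips;
    for (int idx : skip_layers) {
        if (idx < depth) {
            double_skips.push_back(idx);              // skips double_blocks.<idx>
        } else if (idx < depth + depth_single_blocks) {
            single_skips.push_back(idx - depth);      // skips single_blocks.<idx - depth>
        }
    }
    return {double_skips, single_skips};
}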
std::string qk_norm = "", bool qkv_bias = false, - bool pre_only = false) - : num_heads(num_heads), pre_only(pre_only) { + bool pre_only = false, + bool self_attn = false) + : num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) { // rmsnorm is always Flase // scale_mod_only is always Flase // swiglu is always Flase blocks["norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); blocks["attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only)); + if (self_attn) { + blocks["attn2"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false)); + } + if (!pre_only) { blocks["norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); int64_t mlp_hidden_dim = (int64_t)(hidden_size * mlp_ratio); @@ -277,9 +283,52 @@ struct DismantledBlock : public GGMLBlock { if (pre_only) { n_mods = 2; } + if (self_attn) { + n_mods = 9; + } blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, n_mods * hidden_size)); } + std::tuple, std::vector, std::vector> pre_attention_x(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* c) { + GGML_ASSERT(self_attn); + // x: [N, n_token, hidden_size] + // c: [N, hidden_size] + auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); + auto attn = std::dynamic_pointer_cast(blocks["attn"]); + auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); + auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); + + int64_t n_mods = 9; + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, n_mods * hidden_size] + m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size] + m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size] + + int64_t offset = m->nb[1] * m->ne[1]; + auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] + auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + auto gate_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size] + + auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size] + auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size] + auto gate_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size] + + auto shift_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size] + auto scale_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size] + auto gate_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size] + + auto x_norm = norm1->forward(ctx, x); + + auto attn_in = modulate(ctx, x_norm, shift_msa, scale_msa); + auto qkv = attn->pre_attention(ctx, attn_in); + + auto attn2_in = modulate(ctx, x_norm, shift_msa2, scale_msa2); + auto qkv2 = attn2->pre_attention(ctx, attn2_in); + + return {qkv, qkv2, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2}}; + } + std::pair, std::vector> pre_attention(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* c) { @@ -319,6 +368,44 @@ struct DismantledBlock : public GGMLBlock { } } + struct ggml_tensor* post_attention_x(struct ggml_context* ctx, + struct ggml_tensor* attn_out, + struct ggml_tensor* attn2_out, + struct ggml_tensor* x, + struct ggml_tensor* gate_msa, + struct ggml_tensor* shift_mlp, + struct ggml_tensor* 
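Note on the widened modulation in pre_attention_x above: with self_attn enabled, adaLN_modulation.1 produces nine vectors instead of six, sliced in the order shift/scale/gate for the joint attention, shift/scale/gate for the MLP, and shift/scale/gate for the second, x-only attention. The scalar sketch below restates that layout and the modulation itself, assuming the usual DiT convention x * (1 + scale) + shift implemented by the modulate() helper; the names are illustrative only.

// Order of the nine slices taken from the adaLN_modulation.1 output above.
enum ModSlice {
    SHIFT_MSA, SCALE_MSA, GATE_MSA,      // joint attention (attn)
    SHIFT_MLP, SCALE_MLP, GATE_MLP,      // MLP branch
    SHIFT_MSA2, SCALE_MSA2, GATE_MSA2    // x-only attention (attn2)
};

// Per-element view of modulate(): both attention inputs start from the same
// norm1(x), shifted and scaled by their respective slices.
static float modulate_scalar(float x_norm, float shift, float scale) {
    return x_norm * (1.0f + scale) + shift;
}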
scale_mlp, + struct ggml_tensor* gate_mlp, + struct ggml_tensor* gate_msa2) { + // attn_out: [N, n_token, hidden_size] + // x: [N, n_token, hidden_size] + // gate_msa: [N, hidden_size] + // shift_mlp: [N, hidden_size] + // scale_mlp: [N, hidden_size] + // gate_mlp: [N, hidden_size] + // return: [N, n_token, hidden_size] + GGML_ASSERT(!pre_only); + + auto attn = std::dynamic_pointer_cast(blocks["attn"]); + auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); + auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); + auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); + + gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size] + gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size] + gate_msa2 = ggml_reshape_3d(ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]); // [N, 1, hidden_size] + + attn_out = attn->post_attention(ctx, attn_out); + attn2_out = attn2->post_attention(ctx, attn2_out); + + x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa)); + x = ggml_add(ctx, x, ggml_mul(ctx, attn2_out, gate_msa2)); + auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp)); + x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp)); + + return x; + } + struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* attn_out, struct ggml_tensor* x, @@ -357,29 +444,52 @@ struct DismantledBlock : public GGMLBlock { // return: [N, n_token, hidden_size] auto attn = std::dynamic_pointer_cast(blocks["attn"]); - - auto qkv_intermediates = pre_attention(ctx, x, c); - auto qkv = qkv_intermediates.first; - auto intermediates = qkv_intermediates.second; - - auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim] - x = post_attention(ctx, - attn_out, - intermediates[0], - intermediates[1], - intermediates[2], - intermediates[3], - intermediates[4]); - return x; // [N, n_token, dim] + if (self_attn) { + auto qkv_intermediates = pre_attention_x(ctx, x, c); + // auto qkv = qkv_intermediates.first; + // auto intermediates = qkv_intermediates.second; + // no longer a pair, but a tuple + auto qkv = std::get<0>(qkv_intermediates); + auto qkv2 = std::get<1>(qkv_intermediates); + auto intermediates = std::get<2>(qkv_intermediates); + + auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim] + auto attn2_out = ggml_nn_attention_ext(ctx, qkv2[0], qkv2[1], qkv2[2], num_heads); // [N, n_token, dim] + x = post_attention_x(ctx, + attn_out, + attn2_out, + intermediates[0], + intermediates[1], + intermediates[2], + intermediates[3], + intermediates[4], + intermediates[5]); + return x; // [N, n_token, dim] + } else { + auto qkv_intermediates = pre_attention(ctx, x, c); + auto qkv = qkv_intermediates.first; + auto intermediates = qkv_intermediates.second; + + auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim] + x = post_attention(ctx, + attn_out, + intermediates[0], + intermediates[1], + intermediates[2], + intermediates[3], + intermediates[4]); + return x; // [N, n_token, dim] + } } }; -__STATIC_INLINE__ std::pair block_mixing(struct ggml_context* ctx, - struct ggml_tensor* context, - struct ggml_tensor* x, - struct ggml_tensor* c, - std::shared_ptr context_block, - std::shared_ptr x_block) { +__STATIC_INLINE__ std::pair +block_mixing(struct ggml_context* ctx, + struct ggml_tensor* context, + struct ggml_tensor* x, + struct ggml_tensor* c, 
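The residual path in post_attention_x above then applies three gated updates in sequence: the joint-attention output gated by gate_msa, the second attention output gated by gate_msa2, and finally the MLP branch gated by gate_mlp. A scalar restatement, illustrative only, per hidden element:

// y0 = x  + gate_msa  * attn_out
// y1 = y0 + gate_msa2 * attn2_out
// y2 = y1 + gate_mlp  * mlp(modulate(norm2(y1), shift_mlp, scale_mlp))
static float gated_residual(float x, float branch_out, float gate) {
    return x + gate * branch_out;  // applied once per branch, in the order above
}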
+ std::shared_ptr context_block, + std::shared_ptr x_block) { // context: [N, n_context, hidden_size] // x: [N, n_token, hidden_size] // c: [N, hidden_size] @@ -387,10 +497,18 @@ __STATIC_INLINE__ std::pair block_mixi auto context_qkv = context_qkv_intermediates.first; auto context_intermediates = context_qkv_intermediates.second; - auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c); - auto x_qkv = x_qkv_intermediates.first; - auto x_intermediates = x_qkv_intermediates.second; + std::vector x_qkv, x_qkv2, x_intermediates; + if (x_block->self_attn) { + auto x_qkv_intermediates = x_block->pre_attention_x(ctx, x, c); + x_qkv = std::get<0>(x_qkv_intermediates); + x_qkv2 = std::get<1>(x_qkv_intermediates); + x_intermediates = std::get<2>(x_qkv_intermediates); + } else { + auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c); + x_qkv = x_qkv_intermediates.first; + x_intermediates = x_qkv_intermediates.second; + } std::vector qkv; for (int i = 0; i < 3; i++) { qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1)); @@ -429,13 +547,27 @@ __STATIC_INLINE__ std::pair block_mixi context = NULL; } - x = x_block->post_attention(ctx, - x_attn, - x_intermediates[0], - x_intermediates[1], - x_intermediates[2], - x_intermediates[3], - x_intermediates[4]); + if (x_block->self_attn) { + auto attn2 = ggml_nn_attention_ext(ctx, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads); // [N, n_token, hidden_size] + + x = x_block->post_attention_x(ctx, + x_attn, + attn2, + x_intermediates[0], + x_intermediates[1], + x_intermediates[2], + x_intermediates[3], + x_intermediates[4], + x_intermediates[5]); + } else { + x = x_block->post_attention(ctx, + x_attn, + x_intermediates[0], + x_intermediates[1], + x_intermediates[2], + x_intermediates[3], + x_intermediates[4]); + } return {context, x}; } @@ -447,9 +579,10 @@ struct JointBlock : public GGMLBlock { float mlp_ratio = 4.0, std::string qk_norm = "", bool qkv_bias = false, - bool pre_only = false) { + bool pre_only = false, + bool self_attn_x = false) { blocks["context_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only)); - blocks["x_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false)); + blocks["x_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x)); } std::pair forward(struct ggml_context* ctx, @@ -507,6 +640,7 @@ struct MMDiT : public GGMLBlock { int64_t input_size = -1; int64_t patch_size = 2; int64_t in_channels = 16; + int64_t d_self = -1; // >=0 for MMdiT-X int64_t depth = 24; float mlp_ratio = 4.0f; int64_t adm_in_channels = 2048; @@ -561,6 +695,20 @@ struct MMDiT : public GGMLBlock { context_size = 4096; context_embedder_out_dim = 2432; qk_norm = "rms"; + } else if (version == VERSION_SD3_5_2B) { + input_size = -1; + patch_size = 2; + in_channels = 16; + depth = 24; + d_self = 12; + mlp_ratio = 4.0f; + adm_in_channels = 2048; + out_channels = 16; + pos_embed_max_size = 384; + num_patchs = 147456; + context_size = 4096; + context_embedder_out_dim = 1536; + qk_norm = "rms"; } int64_t default_out_channels = in_channels; hidden_size = 64 * depth; @@ -581,15 +729,17 @@ struct MMDiT : public GGMLBlock { mlp_ratio, qk_norm, true, - i == depth - 1)); + i == depth - 1, + i <= d_self)); } blocks["final_layer"] = std::shared_ptr(new FinalLayer(hidden_size, patch_size, out_channels)); } - struct ggml_tensor* cropped_pos_embed(struct ggml_context* ctx, - int64_t h, - 
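For reference, the SD3.5-medium (2B) branch above gives hidden_size = 64 * depth = 1536, which matches context_embedder_out_dim, and with d_self = 12 the `i <= d_self` check in the construction loop equips joint_blocks 0 through 12 with the dual-attention x_block while the remaining blocks stay plain MMDiT. A compile-time restatement of those derived numbers; struct and variable names are illustrative.

#include <cstdint>

struct MMDiTXSummary {
    int64_t depth;
    int64_t d_self;                    // blocks 0..d_self get the extra x-only attention
    int64_t context_embedder_out_dim;
    constexpr int64_t hidden_size() const { return 64 * depth; }
};

constexpr MMDiTXSummary sd3_5_medium{24, 12, 1536};
static_assert(sd3_5_medium.hidden_size() == 1536, "hidden size must match the context embedder output");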
int64_t w) { + struct ggml_tensor* + cropped_pos_embed(struct ggml_context* ctx, + int64_t h, + int64_t w) { auto pos_embed = params["pos_embed"]; h = (h + 1) / patch_size; @@ -651,7 +801,8 @@ struct MMDiT : public GGMLBlock { struct ggml_tensor* forward_core_with_concat(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* c_mod, - struct ggml_tensor* context) { + struct ggml_tensor* context, + std::vector skip_layers = std::vector()) { // x: [N, H*W, hidden_size] // context: [N, n_context, d_context] // c: [N, hidden_size] @@ -659,6 +810,11 @@ struct MMDiT : public GGMLBlock { auto final_layer = std::dynamic_pointer_cast(blocks["final_layer"]); for (int i = 0; i < depth; i++) { + // skip iteration if i is in skip_layers + if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) { + continue; + } + auto block = std::dynamic_pointer_cast(blocks["joint_blocks." + std::to_string(i)]); auto context_x = block->forward(ctx, context, x, c_mod); @@ -674,8 +830,9 @@ struct MMDiT : public GGMLBlock { struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* t, - struct ggml_tensor* y = NULL, - struct ggml_tensor* context = NULL) { + struct ggml_tensor* y = NULL, + struct ggml_tensor* context = NULL, + std::vector skip_layers = std::vector()) { // Forward pass of DiT. // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) // t: (N,) tensor of diffusion timesteps @@ -706,7 +863,7 @@ struct MMDiT : public GGMLBlock { context = context_embedder->forward(ctx, context); // [N, L, D] aka [N, L, 1536] } - x = forward_core_with_concat(ctx, x, c, context); // (N, H*W, patch_size ** 2 * out_channels) + x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels) x = unpatchify(ctx, x, h, w); // [N, C, H, W] @@ -735,7 +892,8 @@ struct MMDiTRunner : public GGMLRunner { struct ggml_cgraph* build_graph(struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, - struct ggml_tensor* y) { + struct ggml_tensor* y, + std::vector skip_layers = std::vector()) { struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, MMDIT_GRAPH_SIZE, false); x = to_backend(x); @@ -747,7 +905,8 @@ struct MMDiTRunner : public GGMLRunner { x, timesteps, y, - context); + context, + skip_layers); ggml_build_forward_expand(gf, out); @@ -760,13 +919,14 @@ struct MMDiTRunner : public GGMLRunner { struct ggml_tensor* context, struct ggml_tensor* y, struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { // x: [N, in_channels, h, w] // timesteps: [N, ] // context: [N, max_position, hidden_size]([N, 154, 4096]) or [1, max_position, hidden_size] // y: [N, adm_in_channels] or [1, adm_in_channels] auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context, y); + return build_graph(x, timesteps, context, y, skip_layers); }; GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); diff --git a/model.cpp b/model.cpp index 26451cdc..3da1b3a4 100644 --- a/model.cpp +++ b/model.cpp @@ -1373,6 +1373,9 @@ SDVersion ModelLoader::get_sd_version() { if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { is_flux = true; } + if (tensor_storage.name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos) { + return VERSION_SD3_5_2B; + } if 
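Note on the detection order above: the joint_blocks.0.x_block.attn2.* tensors exist only in MMDiT-X checkpoints, so the SD3.5-medium probe has to run before the other SD3 checks, while SD3.5-large is still identified by the presence of a 38th joint block. A minimal restatement of the medium probe; the helper name is ours, not part of the patch.

#include <string>

// True for SD3.5-medium (MMDiT-X) checkpoints: only they carry a second
// x_block attention, and therefore the attn2.ln_q weight.
static bool looks_like_sd3_5_medium(const std::string& tensor_name) {
    return tensor_name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos;
}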
(tensor_storage.name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos) { return VERSION_SD3_5_8B; } diff --git a/model.h b/model.h index 4efbdf81..041245e3 100644 --- a/model.h +++ b/model.h @@ -26,6 +26,7 @@ enum SDVersion { VERSION_FLUX_DEV, VERSION_FLUX_SCHNELL, VERSION_SD3_5_8B, + VERSION_SD3_5_2B, VERSION_COUNT, }; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 4d28a147..079daa04 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -32,7 +32,8 @@ const char* model_version_to_str[] = { "SD3 2B", "Flux Dev", "Flux Schnell", - "SD3.5 8B"}; + "SD3.5 8B", + "SD3.5 2B"}; const char* sampling_methods_str[] = { "Euler A", @@ -288,7 +289,7 @@ class StableDiffusionGGML { "try specifying SDXL VAE FP16 Fix with the --vae parameter. " "You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors"); } - } else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { + } else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { scale_factor = 1.5305f; } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { scale_factor = 0.3611; @@ -311,7 +312,7 @@ class StableDiffusionGGML { } else { clip_backend = backend; bool use_t5xxl = false; - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { use_t5xxl = true; } if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) { @@ -322,7 +323,7 @@ class StableDiffusionGGML { LOG_INFO("CLIP: Using CPU backend"); clip_backend = ggml_backend_cpu_init(); } - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { @@ -520,7 +521,7 @@ class StableDiffusionGGML { is_using_v_parameterization = true; } - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { LOG_INFO("running in FLOW mode"); denoiser = std::make_shared(); } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { @@ -771,7 +772,11 @@ class StableDiffusionGGML { sample_method_t method, const std::vector& sigmas, int start_merge_step, - SDCondition id_cond) { + SDCondition id_cond, + std::vector skip_layers = {}, + float slg_scale = 2.5, + float skip_layer_start = 0.01, + float skip_layer_end = 0.2) { size_t steps = sigmas.size() - 1; // noise = load_tensor_from_file(work_ctx, "./rand0.bin"); // print_ggml_tensor(noise); @@ -782,13 +787,24 @@ class StableDiffusionGGML { struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, noise); bool has_unconditioned = cfg_scale != 1.0 && uncond.c_crossattn != NULL; + bool has_skiplayer = slg_scale != 0.0 && skip_layers.size() > 0; // denoise wrapper struct ggml_tensor* out_cond = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* out_uncond = NULL; + struct ggml_tensor* out_skip = NULL; + if (has_unconditioned) { out_uncond = ggml_dup_tensor(work_ctx, x); } + if (has_skiplayer) { + if (version == 
VERSION_SD3_2B || version == VERSION_SD3_5_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + out_skip = ggml_dup_tensor(work_ctx, x); + } else { + has_skiplayer = false; + LOG_WARN("SLG is incompatible with %s models", model_version_to_str[version]); + } + } struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x); auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* { @@ -869,6 +885,28 @@ class StableDiffusionGGML { &out_uncond); negative_data = (float*)out_uncond->data; } + + int step_count = sigmas.size(); + bool is_skiplayer_step = has_skiplayer && step > (int)(skip_layer_start * step_count) && step < (int)(skip_layer_end * step_count); + float* skip_layer_data = NULL; + if (is_skiplayer_step) { + LOG_DEBUG("Skipping layers at step %d\n", step); + // skip layer (same as conditionned) + diffusion_model->compute(n_threads, + noised_input, + timesteps, + cond.c_crossattn, + cond.c_concat, + cond.c_vector, + guidance_tensor, + -1, + controls, + control_strength, + &out_skip, + NULL, + skip_layers); + skip_layer_data = (float*)out_skip->data; + } float* vec_denoised = (float*)denoised->data; float* vec_input = (float*)input->data; float* positive_data = (float*)out_cond->data; @@ -885,6 +923,9 @@ class StableDiffusionGGML { latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]); } } + if (is_skiplayer_step) { + latent_result = latent_result + (positive_data[i] - skip_layer_data[i]) * slg_scale; + } // v = latent_result, eps = latent_result // denoised = (v * c_out + input * c_skip) or (input + eps * c_out) vec_denoised[i] = latent_result * c_out + vec_input[i] * c_skip; @@ -948,7 +989,7 @@ class StableDiffusionGGML { if (use_tiny_autoencoder) { C = 4; } else { - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) { C = 32; } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { C = 32; @@ -1111,7 +1152,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, float control_strength, float style_ratio, bool normalize_input, - std::string input_id_images_path) { + std::string input_id_images_path, + std::vector skip_layers = {}, + float slg_scale = 2.5, + float skip_layer_start = 0.01, + float skip_layer_end = 0.2) { if (seed < 0) { // Generally, when using the provided command line, the seed is always >0. 
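Back in the denoise loop above: because steps = sigmas.size() - 1, the window test uses step_count = steps + 1, and the extra skip-layer pass runs only for steps strictly between int(skip_layer_start * step_count) and int(skip_layer_end * step_count). A worked check with the defaults (0.01 and 0.2) and 30 sampling steps: step_count is then 31, so SLG is applied on steps 1 through 5 only. The helper name below is illustrative.

// Mirrors the is_skiplayer_step window test above (SLG must also be enabled,
// i.e. slg_scale != 0 and skip_layers non-empty).
static bool in_slg_window(int step, int step_count, float start, float end) {
    return step > (int)(start * step_count) && step < (int)(end * step_count);
}
// in_slg_window(step, 31, 0.01f, 0.2f) is true exactly for step in {1, 2, 3, 4, 5}.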
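The per-element combination of the three model outputs then reduces to classifier-free guidance plus a skip-layer correction that pushes the result away from the degraded, layer-skipped prediction. A scalar sketch, assuming both CFG and SLG are active on the current step; the function name is illustrative.

// latent = uncond + cfg_scale * (cond - uncond) + slg_scale * (cond - skip)
static float guided_latent(float cond, float uncond, float skip,
                           float cfg_scale, float slg_scale) {
    float latent = uncond + cfg_scale * (cond - uncond);  // classifier-free guidance
    latent += slg_scale * (cond - skip);                   // skip-layer guidance
    return latent;
}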
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library @@ -1281,7 +1326,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, // Sample std::vector final_latents; // collect latents to decode int C = 4; - if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { + if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { C = 16; } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { C = 16; @@ -1320,7 +1365,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, sample_method, sigmas, start_merge_step, - id_cond); + id_cond, + skip_layers, + slg_scale, + skip_layer_start, + skip_layer_end); // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin"); // print_ggml_tensor(x_0); int64_t sampling_end = ggml_time_ms(); @@ -1386,7 +1435,11 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, float control_strength, float style_ratio, bool normalize_input, - const char* input_id_images_path_c_str) { + const char* input_id_images_path_c_str, + std::vector skip_layers, + float slg_scale, + float skip_layer_start, + float skip_layer_end) { LOG_DEBUG("txt2img %dx%d", width, height); if (sd_ctx == NULL) { return NULL; @@ -1394,7 +1447,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, struct ggml_init_params params; params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB - if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { + if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { params.mem_size *= 3; } if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { @@ -1420,7 +1473,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); int C = 4; - if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { + if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { C = 16; } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { C = 16; @@ -1428,7 +1481,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, int W = width / 8; int H = height / 8; ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); - if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { + if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { ggml_set_f32(init_latent, 0.0609f); } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { ggml_set_f32(init_latent, 0.1159f); @@ -1454,7 +1507,11 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, control_strength, style_ratio, normalize_input, - input_id_images_path_c_str); + input_id_images_path_c_str, + skip_layers, + slg_scale, + skip_layer_start, + skip_layer_end); size_t t1 = ggml_time_ms(); @@ -1481,7 +1538,11 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, float control_strength, float style_ratio, bool normalize_input, - const char* input_id_images_path_c_str) { + const char* input_id_images_path_c_str, + std::vector skip_layers, + float slg_scale, + float skip_layer_start, + float skip_layer_end) { LOG_DEBUG("img2img %dx%d", width, height); if (sd_ctx == NULL) { return NULL; @@ -1489,7 +1550,7 @@ sd_image_t* 
img2img(sd_ctx_t* sd_ctx, struct ggml_init_params params; params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB - if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { + if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) { params.mem_size *= 2; } if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { @@ -1555,7 +1616,11 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, control_strength, style_ratio, normalize_input, - input_id_images_path_c_str); + input_id_images_path_c_str, + skip_layers, + slg_scale, + skip_layer_start, + skip_layer_end); size_t t2 = ggml_time_ms(); diff --git a/stable-diffusion.h b/stable-diffusion.h index 812e8fc9..b310ee59 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -162,7 +162,11 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx, float control_strength, float style_strength, bool normalize_input, - const char* input_id_images_path); + const char* input_id_images_path, + std::vector skip_layers = {}, + float slg_scale = 2.5, + float skip_layer_start = 0.01, + float skip_layer_end = 0.2); SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, sd_image_t init_image, @@ -182,7 +186,11 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, float control_strength, float style_strength, bool normalize_input, - const char* input_id_images_path); + const char* input_id_images_path, + std::vector skip_layers = {}, + float slg_scale = 2.5, + float skip_layer_start = 0.01, + float skip_layer_end = 0.2); SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, sd_image_t init_image, diff --git a/vae.hpp b/vae.hpp index 42b694cd..50ddf752 100644 --- a/vae.hpp +++ b/vae.hpp @@ -457,7 +457,7 @@ class AutoencodingEngine : public GGMLBlock { bool use_video_decoder = false, SDVersion version = VERSION_SD1) : decode_only(decode_only), use_video_decoder(use_video_decoder) { - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { dd_config.z_channels = 16; use_quant = false; }
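A note for library callers of the headers changed above: the new trailing parameters are defaulted, and since skip_layers defaults to an empty vector, has_skiplayer stays false and existing call sites keep their old behavior without edits. Opting in only requires appending the four arguments; the sketch below shows the values, with the earlier (unchanged) arguments deliberately elided rather than guessed.

#include <vector>

std::vector<int> skip_layers = {7, 8, 9};   // same layers as the CLI default
float slg_scale        = 2.5f;              // header default; the CLI defaults to 0 (disabled)
float skip_layer_start = 0.01f;
float skip_layer_end   = 0.2f;

// sd_image_t* imgs = txt2img(sd_ctx, /* ...existing arguments unchanged... */,
//                            input_id_images_path,
//                            skip_layers, slg_scale, skip_layer_start, skip_layer_end);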