Skip to content

feat: add sd3.5 medium support (+skip layer guidance) #451

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions diffusion_model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ struct DiffusionModel {
std::vector<struct ggml_tensor*> controls = {},
float control_strength = 0.f,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL) = 0;
struct ggml_context* output_ctx = NULL,
std::vector<int> skip_layers = std::vector<int>()) = 0;
virtual void alloc_params_buffer() = 0;
virtual void free_params_buffer() = 0;
virtual void free_compute_buffer() = 0;
Expand Down Expand Up @@ -70,7 +71,9 @@ struct UNetModel : public DiffusionModel {
std::vector<struct ggml_tensor*> controls = {},
float control_strength = 0.f,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL) {
struct ggml_context* output_ctx = NULL,
std::vector<int> skip_layers = std::vector<int>()) {
(void)skip_layers; // SLG doesn't work with UNet models
return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx);
}
};
Expand Down Expand Up @@ -119,8 +122,9 @@ struct MMDiTModel : public DiffusionModel {
std::vector<struct ggml_tensor*> controls = {},
float control_strength = 0.f,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL) {
return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx);
struct ggml_context* output_ctx = NULL,
std::vector<int> skip_layers = std::vector<int>()) {
return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers);
}
};

Expand Down Expand Up @@ -168,8 +172,9 @@ struct FluxModel : public DiffusionModel {
std::vector<struct ggml_tensor*> controls = {},
float control_strength = 0.f,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL) {
return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx);
struct ggml_context* output_ctx = NULL,
std::vector<int> skip_layers = std::vector<int>()) {
return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers);
}
};

Expand Down
83 changes: 82 additions & 1 deletion examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ struct SDParams {
bool canny_preprocess = false;
bool color = false;
int upscale_repeats = 1;

std::vector<int> skip_layers = {7, 8, 9};
float slg_scale = 0.;
float skip_layer_start = 0.01;
float skip_layer_end = 0.2;
};

void print_params(SDParams params) {
Expand Down Expand Up @@ -151,6 +156,7 @@ void print_params(SDParams params) {
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
printf(" min_cfg: %.2f\n", params.min_cfg);
printf(" cfg_scale: %.2f\n", params.cfg_scale);
printf(" slg_scale: %.2f\n", params.slg_scale);
printf(" guidance: %.2f\n", params.guidance);
printf(" clip_skip: %d\n", params.clip_skip);
printf(" width: %d\n", params.width);
Expand Down Expand Up @@ -197,6 +203,12 @@ void print_usage(int argc, const char* argv[]) {
printf(" -p, --prompt [PROMPT] the prompt to render\n");
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
printf(" --skip_layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n");
printf(" --skip_layer_start START SLG enabling point: (default: 0.01)\n");
printf(" --skip_layer_end END SLG disabling point: (default: 0.2)\n");
printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n");
printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%%)\n");
printf(" --control-strength STRENGTH strength to apply Control Net (default: 0.9)\n");
Expand Down Expand Up @@ -534,6 +546,61 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.verbose = true;
} else if (arg == "--color") {
params.color = true;
} else if (arg == "--slg-scale") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.slg_scale = std::stof(argv[i]);
} else if (arg == "--skip-layers") {
if (++i >= argc) {
invalid_arg = true;
break;
}
if (argv[i][0] != '[') {
invalid_arg = true;
break;
}
std::string layers_str = argv[i];
while (layers_str.back() != ']') {
if (++i >= argc) {
invalid_arg = true;
break;
}
layers_str += " " + std::string(argv[i]);
}
layers_str = layers_str.substr(1, layers_str.size() - 2);

std::regex regex("[, ]+");
std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1);
std::sregex_token_iterator end;
std::vector<std::string> tokens(iter, end);
std::vector<int> layers;
for (const auto& token : tokens) {
try {
layers.push_back(std::stoi(token));
} catch (const std::invalid_argument& e) {
invalid_arg = true;
break;
}
}
params.skip_layers = layers;

if (invalid_arg) {
break;
}
} else if (arg == "--skip-layer-start") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.skip_layer_start = std::stof(argv[i]);
} else if (arg == "--skip-layer-end") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.skip_layer_end = std::stof(argv[i]);
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
print_usage(argc, argv);
Expand Down Expand Up @@ -624,6 +691,16 @@ std::string get_image_params(SDParams params, int64_t seed) {
}
parameter_string += "Steps: " + std::to_string(params.sample_steps) + ", ";
parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", ";
if (params.slg_scale != 0 && params.skip_layers.size() != 0) {
parameter_string += "SLG scale: " + std::to_string(params.cfg_scale) + ", ";
parameter_string += "Skip layers: [";
for (const auto& layer : params.skip_layers) {
parameter_string += std::to_string(layer) + ", ";
}
parameter_string += "], ";
parameter_string += "Skip layer start: " + std::to_string(params.skip_layer_start) + ", ";
parameter_string += "Skip layer end: " + std::to_string(params.skip_layer_end) + ", ";
}
parameter_string += "Guidance: " + std::to_string(params.guidance) + ", ";
parameter_string += "Seed: " + std::to_string(seed) + ", ";
parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
Expand Down Expand Up @@ -840,7 +917,11 @@ int main(int argc, const char* argv[]) {
params.control_strength,
params.style_ratio,
params.normalize_input,
params.input_id_images_path.c_str());
params.input_id_images_path.c_str(),
params.skip_layers,
params.slg_scale,
params.skip_layer_start,
params.skip_layer_end);
} else {
sd_image_t input_image = {(uint32_t)params.width,
(uint32_t)params.height,
Expand Down
26 changes: 19 additions & 7 deletions flux.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,8 @@ namespace Flux {
struct ggml_tensor* timesteps,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor* pe) {
struct ggml_tensor* pe,
std::vector<int> skip_layers = std::vector<int>()) {
auto img_in = std::dynamic_pointer_cast<Linear>(blocks["img_in"]);
auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
Expand All @@ -733,6 +734,10 @@ namespace Flux {
txt = txt_in->forward(ctx, txt);

for (int i = 0; i < params.depth; i++) {
if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) {
continue;
}

auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks["double_blocks." + std::to_string(i)]);

auto img_txt = block->forward(ctx, img, txt, vec, pe);
Expand All @@ -742,6 +747,9 @@ namespace Flux {

auto txt_img = ggml_concat(ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size]
for (int i = 0; i < params.depth_single_blocks; i++) {
if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i + params.depth) != skip_layers.end()) {
continue;
}
auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);

txt_img = block->forward(ctx, txt_img, vec, pe);
Expand Down Expand Up @@ -769,7 +777,8 @@ namespace Flux {
struct ggml_tensor* context,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor* pe) {
struct ggml_tensor* pe,
std::vector<int> skip_layers = std::vector<int>()) {
// Forward pass of DiT.
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
// timestep: (N,) tensor of diffusion timesteps
Expand All @@ -791,7 +800,7 @@ namespace Flux {
// img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]

auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe); // [N, h*w, C * patch_size * patch_size]
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]

// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w]
Expand Down Expand Up @@ -829,7 +838,8 @@ namespace Flux {
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* y,
struct ggml_tensor* guidance) {
struct ggml_tensor* guidance,
std::vector<int> skip_layers = std::vector<int>()) {
GGML_ASSERT(x->ne[3] == 1);
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);

Expand All @@ -856,7 +866,8 @@ namespace Flux {
context,
y,
guidance,
pe);
pe,
skip_layers);

ggml_build_forward_expand(gf, out);

Expand All @@ -870,14 +881,15 @@ namespace Flux {
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL) {
struct ggml_context* output_ctx = NULL,
std::vector<int> skip_layers = std::vector<int>()) {
// x: [N, in_channels, h, w]
// timesteps: [N, ]
// context: [N, max_position, hidden_size]
// y: [N, adm_in_channels] or [1, adm_in_channels]
// guidance: [N, ]
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context, y, guidance);
return build_graph(x, timesteps, context, y, guidance, skip_layers);
};

GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
Expand Down
Loading
Loading