From 56b110304f0da9a76b7b2d69163bd3369e4518fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Wed, 20 Nov 2024 23:00:25 +0100 Subject: [PATCH 1/2] Flux Lite (Freepik) support --- flux.hpp | 3 +++ model.cpp | 20 ++++++++++++++++---- model.h | 1 + stable-diffusion.cpp | 23 ++++++++++++----------- vae.hpp | 2 +- 5 files changed, 33 insertions(+), 16 deletions(-) diff --git a/flux.hpp b/flux.hpp index 73bc345a7..637128f30 100644 --- a/flux.hpp +++ b/flux.hpp @@ -813,6 +813,9 @@ namespace Flux { if (version == VERSION_FLUX_SCHNELL) { flux_params.guidance_embed = false; } + if (version == VERSION_FLUX_LITE){ + flux_params.depth = 8; + } flux = Flux(flux_params); flux.init(params_ctx, wtype); } diff --git a/model.cpp b/model.cpp index 26451cdc5..94eaa801c 100644 --- a/model.cpp +++ b/model.cpp @@ -1364,15 +1364,20 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s SDVersion ModelLoader::get_sd_version() { TensorStorage token_embedding_weight; - bool is_flux = false; - bool is_sd3 = false; + bool is_flux = false; + bool is_schnell = true; + bool is_lite = true; + bool is_sd3 = false; for (auto& tensor_storage : tensor_storages) { if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) { - return VERSION_FLUX_DEV; + is_schnell = false; } if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { is_flux = true; } + if (tensor_storage.name.find("model.diffusion_model.double_blocks.8") != std::string::npos) { + is_lite = false; + } if (tensor_storage.name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos) { return VERSION_SD3_5_8B; } @@ -1400,7 +1405,14 @@ SDVersion ModelLoader::get_sd_version() { } } if (is_flux) { - return VERSION_FLUX_SCHNELL; + if (is_schnell) { + GGML_ASSERT(!is_lite); + return VERSION_FLUX_SCHNELL; + } else if (is_lite) { + return VERSION_FLUX_LITE; + } else { + return VERSION_FLUX_DEV; + } } if (is_sd3) { return VERSION_SD3_2B; diff --git a/model.h b/model.h index 4efbdf813..b1c78bfba 100644 --- a/model.h +++ b/model.h @@ -26,6 +26,7 @@ enum SDVersion { VERSION_FLUX_DEV, VERSION_FLUX_SCHNELL, VERSION_SD3_5_8B, + VERSION_FLUX_LITE, VERSION_COUNT, }; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 4d28a147b..69a5f1f2f 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -32,7 +32,8 @@ const char* model_version_to_str[] = { "SD3 2B", "Flux Dev", "Flux Schnell", - "SD3.5 8B"}; + "SD3.5 8B", + "Flux Lite 8B"}; const char* sampling_methods_str[] = { "Euler A", @@ -290,7 +291,7 @@ class StableDiffusionGGML { } } else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { scale_factor = 1.5305f; - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { scale_factor = 0.3611; // TODO: shift_factor } @@ -311,7 +312,7 @@ class StableDiffusionGGML { } else { clip_backend = backend; bool use_t5xxl = false; - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { use_t5xxl = true; } if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) { @@ -325,7 +326,7 @@ class StableDiffusionGGML { if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); } else { @@ -523,7 +524,7 @@ class StableDiffusionGGML { if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { LOG_INFO("running in FLOW mode"); denoiser = std::make_shared(); - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { LOG_INFO("running in Flux FLOW mode"); float shift = 1.15f; if (version == VERSION_FLUX_SCHNELL) { @@ -950,7 +951,7 @@ class StableDiffusionGGML { } else { if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { C = 32; - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { C = 32; } } @@ -1283,7 +1284,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, int C = 4; if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { C = 16; - } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { C = 16; } int W = width / 8; @@ -1397,7 +1398,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { params.mem_size *= 3; } - if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { params.mem_size *= 4; } if (sd_ctx->sd->stacked_id) { @@ -1422,7 +1423,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, int C = 4; if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { C = 16; - } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { C = 16; } int W = width / 8; @@ -1430,7 +1431,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { ggml_set_f32(init_latent, 0.0609f); - } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { ggml_set_f32(init_latent, 0.1159f); } else { ggml_set_f32(init_latent, 0.f); @@ -1492,7 +1493,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { params.mem_size *= 2; } - if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) { params.mem_size *= 3; } if (sd_ctx->sd->stacked_id) { diff --git a/vae.hpp b/vae.hpp index 42b694cd5..3005598f7 100644 --- a/vae.hpp +++ b/vae.hpp @@ -457,7 +457,7 @@ class AutoencodingEngine : public GGMLBlock { bool use_video_decoder = false, SDVersion version = VERSION_SD1) : decode_only(decode_only), use_video_decoder(use_video_decoder) { - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { dd_config.z_channels = 16; use_quant = false; } From 0da6255ce5665b06a3cac1b1002af5968da84067 Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 23 Nov 2024 11:39:13 +0800 Subject: [PATCH 2/2] format code --- clip.hpp | 2 +- conditioner.hpp | 11 ++++------- flux.hpp | 2 +- stable-diffusion.cpp | 3 +-- vae.hpp | 3 +-- 5 files changed, 8 insertions(+), 13 deletions(-) diff --git a/clip.hpp b/clip.hpp index bf2a8c149..e0d846aa8 100644 --- a/clip.hpp +++ b/clip.hpp @@ -712,7 +712,7 @@ class CLIPTextModel : public GGMLBlock { auto text_projection = params["text_projection"]; ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx); if (text_projection != NULL) { - pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL); + pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL); } else { LOG_DEBUG("Missing text_projection matrix, assuming identity..."); } diff --git a/conditioner.hpp b/conditioner.hpp index 9f9d5ae1f..ea02d377f 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -798,7 +798,7 @@ struct SD3CLIPEmbedder : public Conditioner { } if (chunk_idx == 0) { - auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); + auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); clip_l->compute(n_threads, input_ids, @@ -808,7 +808,6 @@ struct SD3CLIPEmbedder : public Conditioner { true, &pooled_l, work_ctx); - } } @@ -848,7 +847,7 @@ struct SD3CLIPEmbedder : public Conditioner { } if (chunk_idx == 0) { - auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID); + auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID); max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); clip_g->compute(n_threads, input_ids, @@ -858,7 +857,6 @@ struct SD3CLIPEmbedder : public Conditioner { true, &pooled_g, work_ctx); - } } @@ -1096,9 +1094,9 @@ struct FluxCLIPEmbedder : public Conditioner { auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); size_t max_token_idx = 0; - auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); + auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); - + clip_l->compute(n_threads, input_ids, 0, @@ -1107,7 +1105,6 @@ struct FluxCLIPEmbedder : public Conditioner { true, &pooled, work_ctx); - } // t5 diff --git a/flux.hpp b/flux.hpp index 6a65b2604..faea59a4d 100644 --- a/flux.hpp +++ b/flux.hpp @@ -822,7 +822,7 @@ namespace Flux { if (version == VERSION_FLUX_SCHNELL) { flux_params.guidance_embed = false; } - if (version == VERSION_FLUX_LITE){ + if (version == VERSION_FLUX_LITE) { flux_params.depth = 8; } flux = Flux(flux_params); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 097aa3e7b..2297cd377 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -313,8 +313,7 @@ class StableDiffusionGGML { } else { clip_backend = backend; bool use_t5xxl = false; - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B \ - || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { use_t5xxl = true; } if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) { diff --git a/vae.hpp b/vae.hpp index 23df80f6b..8642375f8 100644 --- a/vae.hpp +++ b/vae.hpp @@ -457,8 +457,7 @@ class AutoencodingEngine : public GGMLBlock { bool use_video_decoder = false, SDVersion version = VERSION_SD1) : decode_only(decode_only), use_video_decoder(use_video_decoder) { - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B \ - || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) { dd_config.z_channels = 16; use_quant = false; }