diff --git a/README.md b/README.md index a0acedc94..6bb2b9eaf 100644 --- a/README.md +++ b/README.md @@ -12,11 +12,12 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - Super lightweight and without external dependencies - SD1.x, SD2.x, SDXL and SD3 support - !!!The VAE in SDXL encounters NaN issues under FP16, but unfortunately, the ggml_conv_2d only operates under FP16. Hence, a parameter is needed to specify the VAE that has fixed the FP16 NaN issue. You can find it here: [SDXL VAE FP16 Fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors). +- [Flux-dev/Flux-schnell Support](./docs/flux.md) - [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) and [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) support - [PhotoMaker](https://github.com/TencentARC/PhotoMaker) support. - 16-bit, 32-bit float support -- 4-bit, 5-bit and 8-bit integer quantization support +- 2-bit, 3-bit, 4-bit, 5-bit and 8-bit integer quantization support - Accelerated memory-efficient CPU inference - Only requires ~2.3GB when using txt2img with fp16 precision to generate a 512x512 image, enabling Flash Attention just requires ~1.8GB. - AVX, AVX2 and AVX512 support for x86 architectures @@ -57,7 +58,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - The current implementation of ggml_conv_2d is slow and has high memory usage - [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d) - [ ] Implement Inpainting support -- [ ] k-quants support ## Usage @@ -171,7 +171,7 @@ arguments: --normalize-input normalize PHOTOMAKER input id images --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now. --upscale-repeats Run the ESRGAN upscaler this many times (default 1) - --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0) + --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k) If not specified, the default is the type of the weight file. --lora-model-dir [DIR] lora model directory -i, --init-img [IMAGE] path to the input image, required by img2img @@ -198,7 +198,7 @@ arguments: --vae-tiling process vae in tiles to reduce memory usage --control-net-cpu keep controlnet in cpu (for low vram) --canny apply canny preprocessor (edge detection) - --color colors the logging tags according to level + --color Colors the logging tags according to level -v, --verbose print extra info ``` @@ -209,6 +209,7 @@ arguments: # ./bin/sd -m ../models/v1-5-pruned-emaonly.safetensors -p "a lovely cat" # ./bin/sd -m ../models/sd_xl_base_1.0.safetensors --vae ../models/sdxl_vae-fp16-fix.safetensors -H 1024 -W 1024 -p "a lovely cat" -v # ./bin/sd -m ../models/sd3_medium_incl_clips_t5xxlfp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable Diffusion CPP\"' --cfg-scale 4.5 --sampling-method euler -v +# ./bin/sd --diffusion-model ../models/flux1-dev-q3_k.gguf --vae ../models/ae.sft --clip_l ../models/clip_l.safetensors --t5xxl ../models/t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v ``` Using formats of different precisions will yield results of varying quality. diff --git a/assets/flux/flux1-dev-q2_k.png b/assets/flux/flux1-dev-q2_k.png new file mode 100644 index 000000000..1aef6f8c6 Binary files /dev/null and b/assets/flux/flux1-dev-q2_k.png differ diff --git a/assets/flux/flux1-dev-q3_k.png b/assets/flux/flux1-dev-q3_k.png new file mode 100644 index 000000000..352bfc70c Binary files /dev/null and b/assets/flux/flux1-dev-q3_k.png differ diff --git a/assets/flux/flux1-dev-q4_0.png b/assets/flux/flux1-dev-q4_0.png new file mode 100644 index 000000000..1a5ee2b56 Binary files /dev/null and b/assets/flux/flux1-dev-q4_0.png differ diff --git a/assets/flux/flux1-dev-q8_0 with lora.png b/assets/flux/flux1-dev-q8_0 with lora.png new file mode 100644 index 000000000..fb05892aa Binary files /dev/null and b/assets/flux/flux1-dev-q8_0 with lora.png differ diff --git a/assets/flux/flux1-dev-q8_0.png b/assets/flux/flux1-dev-q8_0.png new file mode 100644 index 000000000..3f469d2da Binary files /dev/null and b/assets/flux/flux1-dev-q8_0.png differ diff --git a/assets/flux/flux1-schnell-q8_0.png b/assets/flux/flux1-schnell-q8_0.png new file mode 100644 index 000000000..4ba7dc401 Binary files /dev/null and b/assets/flux/flux1-schnell-q8_0.png differ diff --git a/common.hpp b/common.hpp index bfdcc004c..b18ee51f5 100644 --- a/common.hpp +++ b/common.hpp @@ -367,7 +367,7 @@ class SpatialTransformer : public GGMLBlock { int64_t n_head; int64_t d_head; int64_t depth = 1; // 1 - int64_t context_dim = 768; // hidden_size, 1024 for VERSION_2_x + int64_t context_dim = 768; // hidden_size, 1024 for VERSION_SD2 public: SpatialTransformer(int64_t in_channels, diff --git a/conditioner.hpp b/conditioner.hpp index e01be2b21..0e8f5a3ad 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -43,7 +43,7 @@ struct Conditioner { // ldm.modules.encoders.modules.FrozenCLIPEmbedder // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283 struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { - SDVersion version = VERSION_1_x; + SDVersion version = VERSION_SD1; CLIPTokenizer tokenizer; ggml_type wtype; std::shared_ptr text_model; @@ -58,20 +58,20 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend, ggml_type wtype, const std::string& embd_dir, - SDVersion version = VERSION_1_x, + SDVersion version = VERSION_SD1, int clip_skip = -1) - : version(version), tokenizer(version == VERSION_2_x ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) { + : version(version), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) { if (clip_skip <= 0) { clip_skip = 1; - if (version == VERSION_2_x || version == VERSION_XL) { + if (version == VERSION_SD2 || version == VERSION_SDXL) { clip_skip = 2; } } - if (version == VERSION_1_x) { + if (version == VERSION_SD1) { text_model = std::make_shared(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip); - } else if (version == VERSION_2_x) { + } else if (version == VERSION_SD2) { text_model = std::make_shared(backend, wtype, OPEN_CLIP_VIT_H_14, clip_skip); - } else if (version == VERSION_XL) { + } else if (version == VERSION_SDXL) { text_model = std::make_shared(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false); text_model2 = std::make_shared(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false); } @@ -79,35 +79,35 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { void set_clip_skip(int clip_skip) { text_model->set_clip_skip(clip_skip); - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { text_model2->set_clip_skip(clip_skip); } } void get_param_tensors(std::map& tensors) { text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model"); - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { text_model2->get_param_tensors(tensors, "cond_stage_model.1.transformer.text_model"); } } void alloc_params_buffer() { text_model->alloc_params_buffer(); - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { text_model2->alloc_params_buffer(); } } void free_params_buffer() { text_model->free_params_buffer(); - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { text_model2->free_params_buffer(); } } size_t get_params_buffer_size() { size_t buffer_size = text_model->get_params_buffer_size(); - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { buffer_size += text_model2->get_params_buffer_size(); } return buffer_size; @@ -398,7 +398,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); struct ggml_tensor* input_ids2 = NULL; size_t max_token_idx = 0; - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), tokenizer.EOS_TOKEN_ID); if (it != chunk_tokens.end()) { std::fill(std::next(it), chunk_tokens.end(), 0); @@ -423,7 +423,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { false, &chunk_hidden_states1, work_ctx); - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { text_model2->compute(n_threads, input_ids2, 0, @@ -482,7 +482,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]); ggml_tensor* vec = NULL; - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { int out_dim = 256; vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, adm_in_channels); // [0:1280] @@ -978,4 +978,230 @@ struct SD3CLIPEmbedder : public Conditioner { } }; + +struct FluxCLIPEmbedder : public Conditioner { + ggml_type wtype; + CLIPTokenizer clip_l_tokenizer; + T5UniGramTokenizer t5_tokenizer; + std::shared_ptr clip_l; + std::shared_ptr t5; + + FluxCLIPEmbedder(ggml_backend_t backend, + ggml_type wtype, + int clip_skip = -1) + : wtype(wtype) { + if (clip_skip <= 0) { + clip_skip = 2; + } + clip_l = std::make_shared(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, true); + t5 = std::make_shared(backend, wtype); + } + + void set_clip_skip(int clip_skip) { + clip_l->set_clip_skip(clip_skip); + } + + void get_param_tensors(std::map& tensors) { + clip_l->get_param_tensors(tensors, "text_encoders.clip_l.text_model"); + t5->get_param_tensors(tensors, "text_encoders.t5xxl"); + } + + void alloc_params_buffer() { + clip_l->alloc_params_buffer(); + t5->alloc_params_buffer(); + } + + void free_params_buffer() { + clip_l->free_params_buffer(); + t5->free_params_buffer(); + } + + size_t get_params_buffer_size() { + size_t buffer_size = clip_l->get_params_buffer_size(); + buffer_size += t5->get_params_buffer_size(); + return buffer_size; + } + + std::vector, std::vector>> tokenize(std::string text, + size_t max_length = 0, + bool padding = false) { + auto parsed_attention = parse_prompt_attention(text); + + { + std::stringstream ss; + ss << "["; + for (const auto& item : parsed_attention) { + ss << "['" << item.first << "', " << item.second << "], "; + } + ss << "]"; + LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); + } + + auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool { + return false; + }; + + std::vector clip_l_tokens; + std::vector clip_l_weights; + std::vector t5_tokens; + std::vector t5_weights; + for (const auto& item : parsed_attention) { + const std::string& curr_text = item.first; + float curr_weight = item.second; + + std::vector curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb); + clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end()); + clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight); + + curr_tokens = t5_tokenizer.Encode(curr_text, true); + t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end()); + t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight); + } + + clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding); + t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding); + + // for (int i = 0; i < clip_l_tokens.size(); i++) { + // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", "; + // } + // std::cout << std::endl; + + // for (int i = 0; i < t5_tokens.size(); i++) { + // std::cout << t5_tokens[i] << ":" << t5_weights[i] << ", "; + // } + // std::cout << std::endl; + + return {{clip_l_tokens, clip_l_weights}, {t5_tokens, t5_weights}}; + } + + SDCondition get_learned_condition_common(ggml_context* work_ctx, + int n_threads, + std::vector, std::vector>> token_and_weights, + int clip_skip, + bool force_zero_embeddings = false) { + set_clip_skip(clip_skip); + auto& clip_l_tokens = token_and_weights[0].first; + auto& clip_l_weights = token_and_weights[0].second; + auto& t5_tokens = token_and_weights[1].first; + auto& t5_weights = token_and_weights[1].second; + + int64_t t0 = ggml_time_ms(); + struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096] + struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096] + struct ggml_tensor* pooled = NULL; // [768,] + std::vector hidden_states_vec; + + size_t chunk_len = 256; + size_t chunk_count = t5_tokens.size() / chunk_len; + for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) { + // clip_l + if (chunk_idx == 0) { + size_t chunk_len_l = 77; + std::vector chunk_tokens(clip_l_tokens.begin(), + clip_l_tokens.begin() + chunk_len_l); + std::vector chunk_weights(clip_l_weights.begin(), + clip_l_weights.begin() + chunk_len_l); + + auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); + size_t max_token_idx = 0; + + // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); + // max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); + // clip_l->compute(n_threads, + // input_ids, + // 0, + // NULL, + // max_token_idx, + // true, + // &pooled, + // work_ctx); + + // clip_l.transformer.text_model.text_projection no in file, ignore + // TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection + pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768); + ggml_set_f32(pooled, 0.f); + } + + // t5 + { + std::vector chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len, + t5_tokens.begin() + (chunk_idx + 1) * chunk_len); + std::vector chunk_weights(t5_weights.begin() + chunk_idx * chunk_len, + t5_weights.begin() + (chunk_idx + 1) * chunk_len); + + auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); + + t5->compute(n_threads, + input_ids, + &chunk_hidden_states, + work_ctx); + { + auto tensor = chunk_hidden_states; + float original_mean = ggml_tensor_mean(tensor); + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float value = ggml_tensor_get_f32(tensor, i0, i1, i2); + value *= chunk_weights[i1]; + ggml_tensor_set_f32(tensor, value, i0, i1, i2); + } + } + } + float new_mean = ggml_tensor_mean(tensor); + ggml_tensor_scale(tensor, (original_mean / new_mean)); + } + } + + int64_t t1 = ggml_time_ms(); + LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); + if (force_zero_embeddings) { + float* vec = (float*)chunk_hidden_states->data; + for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) { + vec[i] = 0; + } + } + + hidden_states_vec.insert(hidden_states_vec.end(), + (float*)chunk_hidden_states->data, + ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states)); + } + + hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec); + hidden_states = ggml_reshape_2d(work_ctx, + hidden_states, + chunk_hidden_states->ne[0], + ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]); + return SDCondition(hidden_states, pooled, NULL); + } + + SDCondition get_learned_condition(ggml_context* work_ctx, + int n_threads, + const std::string& text, + int clip_skip, + int width, + int height, + int adm_in_channels = -1, + bool force_zero_embeddings = false) { + auto tokens_and_weights = tokenize(text, 256, true); + return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings); + } + + std::tuple> get_learned_condition_with_trigger(ggml_context* work_ctx, + int n_threads, + const std::string& text, + int clip_skip, + int width, + int height, + int num_input_imgs, + int adm_in_channels = -1, + bool force_zero_embeddings = false) { + GGML_ASSERT(0 && "Not implemented yet!"); + } + + std::string remove_trigger_from_prompt(ggml_context* work_ctx, + const std::string& prompt) { + GGML_ASSERT(0 && "Not implemented yet!"); + } +}; + #endif \ No newline at end of file diff --git a/control.hpp b/control.hpp index 3375e7306..41f31acb7 100644 --- a/control.hpp +++ b/control.hpp @@ -14,7 +14,7 @@ */ class ControlNetBlock : public GGMLBlock { protected: - SDVersion version = VERSION_1_x; + SDVersion version = VERSION_SD1; // network hparams int in_channels = 4; int out_channels = 4; @@ -26,19 +26,19 @@ class ControlNetBlock : public GGMLBlock { int time_embed_dim = 1280; // model_channels*4 int num_heads = 8; int num_head_channels = -1; // channels // num_heads - int context_dim = 768; // 1024 for VERSION_2_x, 2048 for VERSION_XL + int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL public: int model_channels = 320; - int adm_in_channels = 2816; // only for VERSION_XL + int adm_in_channels = 2816; // only for VERSION_SDXL - ControlNetBlock(SDVersion version = VERSION_1_x) + ControlNetBlock(SDVersion version = VERSION_SD1) : version(version) { - if (version == VERSION_2_x) { + if (version == VERSION_SD2) { context_dim = 1024; num_head_channels = 64; num_heads = -1; - } else if (version == VERSION_XL) { + } else if (version == VERSION_SDXL) { context_dim = 2048; attention_resolutions = {4, 2}; channel_mult = {1, 2, 4}; @@ -58,7 +58,7 @@ class ControlNetBlock : public GGMLBlock { // time_embed_1 is nn.SiLU() blocks["time_embed.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); - if (version == VERSION_XL || version == VERSION_SVD) { + if (version == VERSION_SDXL || version == VERSION_SVD) { blocks["label_emb.0.0"] = std::shared_ptr(new Linear(adm_in_channels, time_embed_dim)); // label_emb_1 is nn.SiLU() blocks["label_emb.0.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); @@ -307,7 +307,7 @@ class ControlNetBlock : public GGMLBlock { }; struct ControlNet : public GGMLRunner { - SDVersion version = VERSION_1_x; + SDVersion version = VERSION_SD1; ControlNetBlock control_net; ggml_backend_buffer_t control_buffer = NULL; // keep control output tensors in backend memory @@ -318,7 +318,7 @@ struct ControlNet : public GGMLRunner { ControlNet(ggml_backend_t backend, ggml_type wtype, - SDVersion version = VERSION_1_x) + SDVersion version = VERSION_SD1) : GGMLRunner(backend, wtype), control_net(version) { control_net.init(params_ctx, wtype); } diff --git a/denoiser.hpp b/denoiser.hpp index 26f4c853d..85e4a0bb7 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -8,6 +8,7 @@ // Ref: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/external.py #define TIMESTEPS 1000 +#define FLUX_TIMESTEPS 1000 struct SigmaSchedule { int version = 0; @@ -144,13 +145,13 @@ struct AYSSchedule : SigmaSchedule { std::vector results(n + 1); switch (version) { - case VERSION_2_x: /* fallthrough */ + case VERSION_SD2: /* fallthrough */ LOG_WARN("AYS not designed for SD2.X models"); - case VERSION_1_x: + case VERSION_SD1: LOG_INFO("AYS using SD1.5 noise levels"); inputs = noise_levels[0]; break; - case VERSION_XL: + case VERSION_SDXL: LOG_INFO("AYS using SDXL noise levels"); inputs = noise_levels[1]; break; @@ -350,6 +351,66 @@ struct DiscreteFlowDenoiser : public Denoiser { } }; + +float flux_time_shift(float mu, float sigma, float t) { + return std::exp(mu) / (std::exp(mu) + std::pow((1.0 / t - 1.0), sigma)); +} + +struct FluxFlowDenoiser : public Denoiser { + float sigmas[TIMESTEPS]; + float shift = 1.15f; + + float sigma_data = 1.0f; + + FluxFlowDenoiser(float shift = 1.15f) { + set_parameters(shift); + } + + void set_parameters(float shift = 1.15f) { + this->shift = shift; + for (int i = 1; i < TIMESTEPS + 1; i++) { + sigmas[i - 1] = t_to_sigma(i/TIMESTEPS * TIMESTEPS); + } + } + + float sigma_min() { + return sigmas[0]; + } + + float sigma_max() { + return sigmas[TIMESTEPS - 1]; + } + + float sigma_to_t(float sigma) { + return sigma; + } + + float t_to_sigma(float t) { + t = t + 1; + return flux_time_shift(shift, 1.0f, t / TIMESTEPS); + } + + std::vector get_scalings(float sigma) { + float c_skip = 1.0f; + float c_out = -sigma; + float c_in = 1.0f; + return {c_skip, c_out, c_in}; + } + + // this function will modify noise/latent + ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) { + ggml_tensor_scale(noise, sigma); + ggml_tensor_scale(latent, 1.0f - sigma); + ggml_tensor_add(latent, noise); + return latent; + } + + ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) { + ggml_tensor_scale(latent, 1.0f / (1.0f - sigma)); + return latent; + } +}; + typedef std::function denoise_cb_t; // k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t diff --git a/diffusion_model.hpp b/diffusion_model.hpp index fb2849457..5c214e1d6 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -3,6 +3,7 @@ #include "mmdit.hpp" #include "unet.hpp" +#include "flux.hpp" struct DiffusionModel { virtual void compute(int n_threads, @@ -11,6 +12,7 @@ struct DiffusionModel { struct ggml_tensor* context, struct ggml_tensor* c_concat, struct ggml_tensor* y, + struct ggml_tensor* guidance, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, @@ -29,7 +31,7 @@ struct UNetModel : public DiffusionModel { UNetModel(ggml_backend_t backend, ggml_type wtype, - SDVersion version = VERSION_1_x) + SDVersion version = VERSION_SD1) : unet(backend, wtype, version) { } @@ -63,6 +65,7 @@ struct UNetModel : public DiffusionModel { struct ggml_tensor* context, struct ggml_tensor* c_concat, struct ggml_tensor* y, + struct ggml_tensor* guidance, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, @@ -77,7 +80,7 @@ struct MMDiTModel : public DiffusionModel { MMDiTModel(ggml_backend_t backend, ggml_type wtype, - SDVersion version = VERSION_3_2B) + SDVersion version = VERSION_SD3_2B) : mmdit(backend, wtype, version) { } @@ -111,6 +114,7 @@ struct MMDiTModel : public DiffusionModel { struct ggml_tensor* context, struct ggml_tensor* c_concat, struct ggml_tensor* y, + struct ggml_tensor* guidance, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, @@ -120,4 +124,54 @@ struct MMDiTModel : public DiffusionModel { } }; + +struct FluxModel : public DiffusionModel { + Flux::FluxRunner flux; + + FluxModel(ggml_backend_t backend, + ggml_type wtype, + SDVersion version = VERSION_FLUX_DEV) + : flux(backend, wtype, version) { + } + + void alloc_params_buffer() { + flux.alloc_params_buffer(); + } + + void free_params_buffer() { + flux.free_params_buffer(); + } + + void free_compute_buffer() { + flux.free_compute_buffer(); + } + + void get_param_tensors(std::map& tensors) { + flux.get_param_tensors(tensors, "model.diffusion_model"); + } + + size_t get_params_buffer_size() { + return flux.get_params_buffer_size(); + } + + int64_t get_adm_in_channels() { + return 768; + } + + void compute(int n_threads, + struct ggml_tensor* x, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* c_concat, + struct ggml_tensor* y, + struct ggml_tensor* guidance, + int num_video_frames = -1, + std::vector controls = {}, + float control_strength = 0.f, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL) { + return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx); + } +}; + #endif \ No newline at end of file diff --git a/docs/flux.md b/docs/flux.md new file mode 100644 index 000000000..f324ad17a --- /dev/null +++ b/docs/flux.md @@ -0,0 +1,63 @@ +# How to Use + +You can run Flux using stable-diffusion.cpp with a GPU that has 6GB or even 4GB of VRAM, without needing to offload to RAM. + +## Download weights + +- Download flux-dev from https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors +- Download flux-schnell from https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/main/flux1-schnell.safetensors +- Download vae from https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/ae.safetensors +- Download clip_l from https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/clip_l.safetensors +- Download t5xxl from https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp16.safetensors + +## Convert flux weights + +Using fp16 will lead to overflow, but ggml's support for bf16 is not yet fully developed. Therefore, we need to convert flux to gguf format here, which also saves VRAM. For example: +``` +.\bin\Release\sd.exe -M convert -m ..\..\ComfyUI\models\unet\flux1-dev.sft -o ..\models\flux1-dev-q8_0.gguf -v --type q8_0 +``` + +## Run + +- `--cfg-scale` is recommended to be set to 1. + +### Flux-dev +For example: + +``` + .\bin\Release\sd.exe --diffusion-model ..\models\flux1-dev-q8_0.gguf --vae ..\models\ae.sft --clip_l ..\models\clip_l.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v +``` + +Using formats of different precisions will yield results of varying quality. + +| Type | q8_0 | q4_0 | q3_k | q2_k | +|---- | ---- |---- |---- |---- | +| **Memory** | 12068.09 MB | 6394.53 MB | 4888.16 MB | 3735.73 MB | +| **Result** | ![](../assets/flux/flux1-dev-q8_0.png) |![](../assets/flux/flux1-dev-q4_0.png) |![](../assets/flux/flux1-dev-q3_k.png) |![](../assets/flux/flux1-dev-q2_k.png)| + + + +### Flux-schnell + + +``` + .\bin\Release\sd.exe --diffusion-model ..\models\flux1-schnell-q8_0.gguf --vae ..\models\ae.sft --clip_l ..\models\clip_l.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v --steps 4 +``` + +| q8_0 | +| ---- | +|![](../assets/flux/flux1-schnell-q8_0.png) | + +## Run with LoRA + +Since many flux LoRA training libraries have used various LoRA naming formats, it is possible that not all flux LoRA naming formats are supported. It is recommended to use LoRA with naming formats compatible with ComfyUI. + +### Flux-dev q8_0 with LoRA + +- LoRA model from https://huggingface.co/XLabs-AI/flux-lora-collection/tree/main (using comfy converted version!!!) + +``` +.\bin\Release\sd.exe --diffusion-model ..\models\flux1-dev-q8_0.gguf --vae ...\models\ae.sft --clip_l ..\models\clip_l.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v --lora-model-dir ../models +``` + +![output](../assets/flux/flux1-dev-q8_0%20with%20lora.png) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 6675095b5..1756a976b 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -7,9 +7,8 @@ #include // #include "preprocessing.hpp" -#include "mmdit.hpp" +#include "flux.hpp" #include "stable-diffusion.h" -#include "t5.hpp" #define STB_IMAGE_IMPLEMENTATION #define STB_IMAGE_STATIC @@ -68,6 +67,9 @@ struct SDParams { SDMode mode = TXT2IMG; std::string model_path; + std::string clip_l_path; + std::string t5xxl_path; + std::string diffusion_model_path; std::string vae_path; std::string taesd_path; std::string esrgan_path; @@ -85,6 +87,7 @@ struct SDParams { std::string negative_prompt; float min_cfg = 1.0f; float cfg_scale = 7.0f; + float guidance = 3.5f; float style_ratio = 20.f; int clip_skip = -1; // <= 0 represents unspecified int width = 512; @@ -120,6 +123,9 @@ void print_params(SDParams params) { printf(" mode: %s\n", modes_str[params.mode]); printf(" model_path: %s\n", params.model_path.c_str()); printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified"); + printf(" clip_l_path: %s\n", params.clip_l_path.c_str()); + printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str()); + printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str()); printf(" vae_path: %s\n", params.vae_path.c_str()); printf(" taesd_path: %s\n", params.taesd_path.c_str()); printf(" esrgan_path: %s\n", params.esrgan_path.c_str()); @@ -140,6 +146,7 @@ void print_params(SDParams params) { printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); printf(" min_cfg: %.2f\n", params.min_cfg); printf(" cfg_scale: %.2f\n", params.cfg_scale); + printf(" guidance: %.2f\n", params.guidance); printf(" clip_skip: %d\n", params.clip_skip); printf(" width: %d\n", params.width); printf(" height: %d\n", params.height); @@ -172,7 +179,7 @@ void print_usage(int argc, const char* argv[]) { printf(" --normalize-input normalize PHOTOMAKER input id images\n"); printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n"); printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n"); - printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)\n"); + printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k)\n"); printf(" If not specified, the default is the type of the weight file.\n"); printf(" --lora-model-dir [DIR] lora model directory\n"); printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n"); @@ -240,6 +247,24 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.model_path = argv[i]; + } else if (arg == "--clip_l") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.clip_l_path = argv[i]; + } else if (arg == "--t5xxl") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.t5xxl_path = argv[i]; + } else if (arg == "--diffusion-model") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.diffusion_model_path = argv[i]; } else if (arg == "--vae") { if (++i >= argc) { invalid_arg = true; @@ -302,8 +327,14 @@ void parse_args(int argc, const char** argv, SDParams& params) { params.wtype = SD_TYPE_Q5_1; } else if (type == "q8_0") { params.wtype = SD_TYPE_Q8_0; + } else if (type == "q2_k") { + params.wtype = SD_TYPE_Q2_K; + } else if (type == "q3_k") { + params.wtype = SD_TYPE_Q3_K; + } else if (type == "q4_k") { + params.wtype = SD_TYPE_Q4_K; } else { - fprintf(stderr, "error: invalid weight format %s, must be one of [f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0]\n", + fprintf(stderr, "error: invalid weight format %s, must be one of [f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k]\n", type.c_str()); exit(1); } @@ -359,6 +390,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.cfg_scale = std::stof(argv[i]); + } else if (arg == "--guidance") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.guidance = std::stof(argv[i]); } else if (arg == "--strength") { if (++i >= argc) { invalid_arg = true; @@ -501,8 +538,8 @@ void parse_args(int argc, const char** argv, SDParams& params) { exit(1); } - if (params.model_path.length() == 0) { - fprintf(stderr, "error: the following arguments are required: model_path\n"); + if (params.model_path.length() == 0 && params.diffusion_model_path.length() == 0) { + fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n"); print_usage(argc, argv); exit(1); } @@ -570,6 +607,7 @@ std::string get_image_params(SDParams params, int64_t seed) { } parameter_string += "Steps: " + std::to_string(params.sample_steps) + ", "; parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", "; + parameter_string += "Guidance: " + std::to_string(params.guidance) + ", "; parameter_string += "Seed: " + std::to_string(seed) + ", "; parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", "; parameter_string += "Model: " + sd_basename(params.model_path) + ", "; @@ -717,6 +755,9 @@ int main(int argc, const char* argv[]) { } sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(), + params.clip_l_path.c_str(), + params.t5xxl_path.c_str(), + params.diffusion_model_path.c_str(), params.vae_path.c_str(), params.taesd_path.c_str(), params.controlnet_path.c_str(), @@ -770,6 +811,7 @@ int main(int argc, const char* argv[]) { params.negative_prompt.c_str(), params.clip_skip, params.cfg_scale, + params.guidance, params.width, params.height, params.sample_method, @@ -830,6 +872,7 @@ int main(int argc, const char* argv[]) { params.negative_prompt.c_str(), params.clip_skip, params.cfg_scale, + params.guidance, params.width, params.height, params.sample_method, diff --git a/flux.hpp b/flux.hpp new file mode 100644 index 000000000..3b398b424 --- /dev/null +++ b/flux.hpp @@ -0,0 +1,965 @@ +#ifndef __FLUX_HPP__ +#define __FLUX_HPP__ + +#include + +#include "ggml_extend.hpp" +#include "model.h" + +#define FLUX_GRAPH_SIZE 10240 + +namespace Flux { + +struct MLPEmbedder : public UnaryBlock { +public: + MLPEmbedder(int64_t in_dim, int64_t hidden_dim) { + blocks["in_layer"] = std::shared_ptr(new Linear(in_dim, hidden_dim, true)); + blocks["out_layer"] = std::shared_ptr(new Linear(hidden_dim, hidden_dim, true)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + // x: [..., in_dim] + // return: [..., hidden_dim] + auto in_layer = std::dynamic_pointer_cast(blocks["in_layer"]); + auto out_layer = std::dynamic_pointer_cast(blocks["out_layer"]); + + x = in_layer->forward(ctx, x); + x = ggml_silu_inplace(ctx, x); + x = out_layer->forward(ctx, x); + return x; + } +}; + +class RMSNorm : public UnaryBlock { +protected: + int64_t hidden_size; + float eps; + + void init_params(struct ggml_context* ctx, ggml_type wtype) { + params["scale"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); + } + +public: + RMSNorm(int64_t hidden_size, + float eps = 1e-06f) + : hidden_size(hidden_size), + eps(eps) {} + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* w = params["scale"]; + x = ggml_rms_norm(ctx, x, eps); + x = ggml_mul(ctx, x, w); + return x; + } +}; + + +struct QKNorm : public GGMLBlock { +public: + QKNorm(int64_t dim) { + blocks["query_norm"] = std::shared_ptr(new RMSNorm(dim)); + blocks["key_norm"] = std::shared_ptr(new RMSNorm(dim)); + } + + struct ggml_tensor* query_norm(struct ggml_context* ctx, struct ggml_tensor* x) { + // x: [..., dim] + // return: [..., dim] + auto norm = std::dynamic_pointer_cast(blocks["query_norm"]); + + x = norm->forward(ctx, x); + return x; + } + + struct ggml_tensor* key_norm(struct ggml_context* ctx, struct ggml_tensor* x) { + // x: [..., dim] + // return: [..., dim] + auto norm = std::dynamic_pointer_cast(blocks["key_norm"]); + + x = norm->forward(ctx, x); + return x; + } +}; + +__STATIC_INLINE__ struct ggml_tensor* apply_rope(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* pe) { + // x: [N, L, n_head, d_head] + // pe: [L, d_head/2, 2, 2] + int64_t d_head = x->ne[0]; + int64_t n_head = x->ne[1]; + int64_t L = x->ne[2]; + int64_t N = x->ne[3]; + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, n_head, L, d_head] + x = ggml_reshape_4d(ctx, x, 2, d_head/2, L, n_head * N); // [N * n_head, L, d_head/2, 2] + x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); // [2, N * n_head, L, d_head/2] + + int64_t offset = x->nb[2] * x->ne[2]; + auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0); // [N * n_head, L, d_head/2] + auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 1); // [N * n_head, L, d_head/2] + x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]); // [N * n_head, L, d_head/2, 1] + x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]); // [N * n_head, L, d_head/2, 1] + auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]); + x_0 = ggml_repeat(ctx, x_0, temp_x); // [N * n_head, L, d_head/2, 2] + x_1 = ggml_repeat(ctx, x_1, temp_x); // [N * n_head, L, d_head/2, 2] + + pe = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2)); // [2, L, d_head/2, 2] + offset = pe->nb[2] * pe->ne[2]; + auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0); // [L, d_head/2, 2] + auto pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 1); // [L, d_head/2, 2] + + auto x_out = ggml_add_inplace(ctx, ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1)); // [N * n_head, L, d_head/2, 2] + x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head*N); // [N*n_head, L, d_head] + return x_out; +} + +__STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx, + struct ggml_tensor* q, + struct ggml_tensor* k, + struct ggml_tensor* v, + struct ggml_tensor* pe) { + // q,k,v: [N, L, n_head, d_head] + // pe: [L, d_head/2, 2, 2] + // return: [N, L, n_head*d_head] + q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head] + k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head] + + auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true); // [N, L, n_head*d_head] + return x; +} + +struct SelfAttention : public GGMLBlock { +public: + int64_t num_heads; + +public: + SelfAttention(int64_t dim, + int64_t num_heads = 8, + bool qkv_bias = false) + : num_heads(num_heads) { + int64_t head_dim = dim / num_heads; + blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); + blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); + blocks["proj"] = std::shared_ptr(new Linear(dim, dim)); + } + + std::vector pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) { + auto qkv_proj = std::dynamic_pointer_cast(blocks["qkv"]); + auto norm = std::dynamic_pointer_cast(blocks["norm"]); + + + auto qkv = qkv_proj->forward(ctx, x); + auto qkv_vec = split_qkv(ctx, qkv); + int64_t head_dim = qkv_vec[0]->ne[0] / num_heads; + auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); + auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); + auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); + q = norm->query_norm(ctx, q); + k = norm->key_norm(ctx, k); + return {q, k, v}; + } + + struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) { + auto proj = std::dynamic_pointer_cast(blocks["proj"]); + + x = proj->forward(ctx, x); // [N, n_token, dim] + return x; + } + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe) { + // x: [N, n_token, dim] + // pe: [n_token, d_head/2, 2, 2] + // return [N, n_token, dim] + auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] + x = attention(ctx, qkv[0], qkv[1], qkv[2], pe); // [N, n_token, dim] + x = post_attention(ctx, x); // [N, n_token, dim] + return x; + } +}; + + +struct ModulationOut { + ggml_tensor* shift = NULL; + ggml_tensor* scale = NULL; + ggml_tensor* gate = NULL; + + ModulationOut(ggml_tensor* shift = NULL, ggml_tensor* scale = NULL, ggml_tensor* gate = NULL) + : shift(shift), scale(scale), gate(gate) {} +}; + +struct Modulation : public GGMLBlock { +public: + bool is_double; + int multiplier; +public: + Modulation(int64_t dim, bool is_double): is_double(is_double) { + multiplier = is_double? 6 : 3; + blocks["lin"] = std::shared_ptr(new Linear(dim, dim * multiplier)); + } + + std::vector forward(struct ggml_context* ctx, struct ggml_tensor* vec) { + // x: [N, dim] + // return: [ModulationOut, ModulationOut] + auto lin = std::dynamic_pointer_cast(blocks["lin"]); + + auto out = ggml_silu(ctx, vec); + out = lin->forward(ctx, out); // [N, multiplier*dim] + + auto m = ggml_reshape_3d(ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim] + m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [multiplier, N, dim] + + int64_t offset = m->nb[1] * m->ne[1]; + auto shift_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, dim] + auto scale_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, dim] + auto gate_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, dim] + + if (is_double) { + auto shift_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, dim] + auto scale_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, dim] + auto gate_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, dim] + return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut(shift_1, scale_1, gate_1)}; + } + + return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut()}; + } +}; + +__STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* shift, + struct ggml_tensor* scale) { + // x: [N, L, C] + // scale: [N, C] + // shift: [N, C] + scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C] + shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]); // [N, 1, C] + x = ggml_add(ctx, x, ggml_mul(ctx, x, scale)); + x = ggml_add(ctx, x, shift); + return x; +} + +struct DoubleStreamBlock : public GGMLBlock { +public: + DoubleStreamBlock(int64_t hidden_size, + int64_t num_heads, + float mlp_ratio, + bool qkv_bias = false) { + int64_t mlp_hidden_dim = hidden_size * mlp_ratio; + blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias)); + + blocks["img_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + blocks["img_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim)); + // img_mlp.1 is nn.GELU(approximate="tanh") + blocks["img_mlp.2"] = std::shared_ptr(new Linear(mlp_hidden_dim, hidden_size)); + + blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + blocks["txt_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias)); + + blocks["txt_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + blocks["txt_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim)); + // img_mlp.1 is nn.GELU(approximate="tanh") + blocks["txt_mlp.2"] = std::shared_ptr(new Linear(mlp_hidden_dim, hidden_size)); + } + + std::pair forward(struct ggml_context* ctx, + struct ggml_tensor* img, + struct ggml_tensor* txt, + struct ggml_tensor* vec, + struct ggml_tensor* pe) { + // img: [N, n_img_token, hidden_size] + // txt: [N, n_txt_token, hidden_size] + // pe: [n_img_token + n_txt_token, d_head/2, 2, 2] + // return: ([N, n_img_token, hidden_size], [N, n_txt_token, hidden_size]) + + auto img_mod = std::dynamic_pointer_cast(blocks["img_mod"]); + auto img_norm1 = std::dynamic_pointer_cast(blocks["img_norm1"]); + auto img_attn = std::dynamic_pointer_cast(blocks["img_attn"]); + + auto img_norm2 = std::dynamic_pointer_cast(blocks["img_norm2"]); + auto img_mlp_0 = std::dynamic_pointer_cast(blocks["img_mlp.0"]); + auto img_mlp_2 = std::dynamic_pointer_cast(blocks["img_mlp.2"]); + + auto txt_mod = std::dynamic_pointer_cast(blocks["txt_mod"]); + auto txt_norm1 = std::dynamic_pointer_cast(blocks["txt_norm1"]); + auto txt_attn = std::dynamic_pointer_cast(blocks["txt_attn"]); + + auto txt_norm2 = std::dynamic_pointer_cast(blocks["txt_norm2"]); + auto txt_mlp_0 = std::dynamic_pointer_cast(blocks["txt_mlp.0"]); + auto txt_mlp_2 = std::dynamic_pointer_cast(blocks["txt_mlp.2"]); + + + auto img_mods = img_mod->forward(ctx, vec); + ModulationOut img_mod1 = img_mods[0]; + ModulationOut img_mod2 = img_mods[1]; + auto txt_mods = txt_mod->forward(ctx, vec); + ModulationOut txt_mod1 = txt_mods[0]; + ModulationOut txt_mod2 = txt_mods[1]; + + // prepare image for attention + auto img_modulated = img_norm1->forward(ctx, img); + img_modulated = Flux::modulate(ctx, img_modulated, img_mod1.shift, img_mod1.scale); + auto img_qkv = img_attn->pre_attention(ctx, img_modulated); // q,k,v: [N, n_img_token, n_head, d_head] + auto img_q = img_qkv[0]; + auto img_k = img_qkv[1]; + auto img_v = img_qkv[2]; + + // prepare txt for attention + auto txt_modulated = txt_norm1->forward(ctx, txt); + txt_modulated = Flux::modulate(ctx, txt_modulated, txt_mod1.shift, txt_mod1.scale); + auto txt_qkv = txt_attn->pre_attention(ctx, txt_modulated); // q,k,v: [N, n_txt_token, n_head, d_head] + auto txt_q = txt_qkv[0]; + auto txt_k = txt_qkv[1]; + auto txt_v = txt_qkv[2]; + + // run actual attention + auto q = ggml_concat(ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head] + auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] + auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] + + auto attn = attention(ctx, q, k, v, pe); // [N, n_txt_token + n_img_token, n_head*d_head] + attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] + auto txt_attn_out = ggml_view_3d(ctx, + attn, + attn->ne[0], + attn->ne[1], + txt->ne[1], + attn->nb[1], + attn->nb[2], + 0); // [n_txt_token, N, hidden_size] + txt_attn_out = ggml_cont(ctx, ggml_permute(ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size] + auto img_attn_out = ggml_view_3d(ctx, + attn, + attn->ne[0], + attn->ne[1], + img->ne[1], + attn->nb[1], + attn->nb[2], + attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] + img_attn_out = ggml_cont(ctx, ggml_permute(ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] + + // calculate the img bloks + img = ggml_add(ctx, img, ggml_mul(ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate)); + + auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale)); + img_mlp_out = ggml_gelu_inplace(ctx, img_mlp_out); + img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out); + + img = ggml_add(ctx, img, ggml_mul(ctx, img_mlp_out, img_mod2.gate)); + + // calculate the txt bloks + txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate)); + + auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale)); + txt_mlp_out = ggml_gelu_inplace(ctx, txt_mlp_out); + txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out); + + txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_mlp_out, txt_mod2.gate)); + + return {img, txt}; + } +}; + + +struct SingleStreamBlock : public GGMLBlock { +public: + int64_t num_heads; + int64_t hidden_size; + int64_t mlp_hidden_dim; +public: + SingleStreamBlock(int64_t hidden_size, + int64_t num_heads, + float mlp_ratio = 4.0f, + float qk_scale = 0.f) : + hidden_size(hidden_size), num_heads(num_heads) { + int64_t head_dim = hidden_size / num_heads; + float scale = qk_scale; + if (scale <= 0.f) { + scale = 1 / sqrt((float)head_dim); + } + mlp_hidden_dim = hidden_size * mlp_ratio; + + blocks["linear1"] = std::shared_ptr(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim)); + blocks["linear2"] = std::shared_ptr(new Linear(hidden_size + mlp_hidden_dim, hidden_size)); + blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); + blocks["pre_norm"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + // mlp_act is nn.GELU(approximate="tanh") + blocks["modulation"] = std::shared_ptr(new Modulation(hidden_size, false)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* vec, + struct ggml_tensor* pe) { + // x: [N, n_token, hidden_size] + // pe: [n_token, d_head/2, 2, 2] + // return: [N, n_token, hidden_size] + + auto linear1 = std::dynamic_pointer_cast(blocks["linear1"]); + auto linear2 = std::dynamic_pointer_cast(blocks["linear2"]); + auto norm = std::dynamic_pointer_cast(blocks["norm"]); + auto pre_norm = std::dynamic_pointer_cast(blocks["pre_norm"]); + auto modulation = std::dynamic_pointer_cast(blocks["modulation"]); + + auto mods = modulation->forward(ctx, vec); + ModulationOut mod = mods[0]; + + auto x_mod = Flux::modulate(ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale); + auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim] + qkv_mlp = ggml_cont(ctx, ggml_permute(ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token] + + auto qkv = ggml_view_3d(ctx, + qkv_mlp, + qkv_mlp->ne[0], + qkv_mlp->ne[1], + hidden_size * 3, + qkv_mlp->nb[1], + qkv_mlp->nb[2], + 0); // [hidden_size * 3 , N, n_token] + qkv = ggml_cont(ctx, ggml_permute(ctx, qkv, 1, 2, 0, 3)); // [N, n_token, hidden_size * 3] + auto mlp = ggml_view_3d(ctx, + qkv_mlp, + qkv_mlp->ne[0], + qkv_mlp->ne[1], + mlp_hidden_dim, + qkv_mlp->nb[1], + qkv_mlp->nb[2], + qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim , N, n_token] + mlp = ggml_cont(ctx, ggml_permute(ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim] + + auto qkv_vec = split_qkv(ctx, qkv); // q,k,v: [N, n_token, hidden_size] + int64_t head_dim = hidden_size / num_heads; + auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head] + auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head] + auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] + q = norm->query_norm(ctx, q); + k = norm->key_norm(ctx, k); + auto attn = attention(ctx, q, k, v, pe); // [N, n_token, hidden_size] + + auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim] + auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size] + + output = ggml_add(ctx, x, ggml_mul(ctx, output, mod.gate)); + return output; + } +}; + + +struct LastLayer : public GGMLBlock { +public: + LastLayer(int64_t hidden_size, + int64_t patch_size, + int64_t out_channels) { + blocks["norm_final"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); + blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels)); + blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, 2 * hidden_size)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* c) { + // x: [N, n_token, hidden_size] + // c: [N, hidden_size] + // return: [N, n_token, patch_size * patch_size * out_channels] + auto norm_final = std::dynamic_pointer_cast(blocks["norm_final"]); + auto linear = std::dynamic_pointer_cast(blocks["linear"]); + auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); + + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size] + m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] + m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] + + int64_t offset = m->nb[1] * m->ne[1]; + auto shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] + auto scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + + x = Flux::modulate(ctx, norm_final->forward(ctx, x), shift, scale); + x = linear->forward(ctx, x); + + return x; + } +}; + +struct FluxParams { + int64_t in_channels = 64; + int64_t vec_in_dim=768; + int64_t context_in_dim = 4096; + int64_t hidden_size = 3072; + float mlp_ratio = 4.0f; + int64_t num_heads = 24; + int64_t depth = 19; + int64_t depth_single_blocks = 38; + std::vector axes_dim = {16, 56, 56}; + int64_t axes_dim_sum = 128; + int theta = 10000; + bool qkv_bias = true; + bool guidance_embed = true; +}; + + +struct Flux : public GGMLBlock { +public: + std::vector linspace(float start, float end, int num) { + std::vector result(num); + float step = (end - start) / (num - 1); + for (int i = 0; i < num; ++i) { + result[i] = start + i * step; + } + return result; + } + + std::vector> transpose(const std::vector>& mat) { + int rows = mat.size(); + int cols = mat[0].size(); + std::vector> transposed(cols, std::vector(rows)); + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; ++j) { + transposed[j][i] = mat[i][j]; + } + } + return transposed; + } + + std::vector flatten(const std::vector>& vec) { + std::vector flat_vec; + for (const auto& sub_vec : vec) { + flat_vec.insert(flat_vec.end(), sub_vec.begin(), sub_vec.end()); + } + return flat_vec; + } + + std::vector> rope(const std::vector& pos, int dim, int theta) { + assert(dim % 2 == 0); + int half_dim = dim / 2; + + std::vector scale = linspace(0, (dim * 1.0f - 2) / dim, half_dim); + + std::vector omega(half_dim); + for (int i = 0; i < half_dim; ++i) { + omega[i] = 1.0 / std::pow(theta, scale[i]); + } + + int pos_size = pos.size(); + std::vector> out(pos_size, std::vector(half_dim)); + for (int i = 0; i < pos_size; ++i) { + for (int j = 0; j < half_dim; ++j) { + out[i][j] = pos[i] * omega[j]; + } + } + + std::vector> result(pos_size, std::vector(half_dim * 4)); + for (int i = 0; i < pos_size; ++i) { + for (int j = 0; j < half_dim; ++j) { + result[i][4 * j] = std::cos(out[i][j]); + result[i][4 * j + 1] = -std::sin(out[i][j]); + result[i][4 * j + 2] = std::sin(out[i][j]); + result[i][4 * j + 3] = std::cos(out[i][j]); + } + } + + return result; + } + + // Generate IDs for image patches and text + std::vector> gen_ids(int h, int w, int patch_size, int bs, int context_len) { + int h_len = (h + (patch_size / 2)) / patch_size; + int w_len = (w + (patch_size / 2)) / patch_size; + + std::vector> img_ids(h_len * w_len, std::vector(3, 0.0)); + + std::vector row_ids = linspace(0, h_len - 1, h_len); + std::vector col_ids = linspace(0, w_len - 1, w_len); + + for (int i = 0; i < h_len; ++i) { + for (int j = 0; j < w_len; ++j) { + img_ids[i * w_len + j][1] = row_ids[i]; + img_ids[i * w_len + j][2] = col_ids[j]; + } + } + + std::vector> img_ids_repeated(bs * img_ids.size(), std::vector(3)); + for (int i = 0; i < bs; ++i) { + for (int j = 0; j < img_ids.size(); ++j) { + img_ids_repeated[i * img_ids.size() + j] = img_ids[j]; + } + } + + std::vector> txt_ids(bs * context_len, std::vector(3, 0.0)); + std::vector> ids(bs * (context_len + img_ids.size()), std::vector(3)); + for (int i = 0; i < bs; ++i) { + for (int j = 0; j < context_len; ++j) { + ids[i * (context_len + img_ids.size()) + j] = txt_ids[j]; + } + for (int j = 0; j < img_ids.size(); ++j) { + ids[i * (context_len + img_ids.size()) + context_len + j] = img_ids_repeated[i * img_ids.size() + j]; + } + } + + return ids; + } + + // Generate positional embeddings + std::vector gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector& axes_dim) { + std::vector> ids = gen_ids(h, w, patch_size, bs, context_len); + std::vector> trans_ids = transpose(ids); + size_t pos_len = ids.size(); + int num_axes = axes_dim.size(); + for (int i = 0; i < pos_len; i++) { + // std::cout << trans_ids[0][i] << " " << trans_ids[1][i] << " " << trans_ids[2][i] << std::endl; + } + + + int emb_dim = 0; + for (int d : axes_dim) emb_dim += d / 2; + + std::vector> emb(bs * pos_len, std::vector(emb_dim * 2 * 2, 0.0)); + int offset = 0; + for (int i = 0; i < num_axes; ++i) { + std::vector> rope_emb = rope(trans_ids[i], axes_dim[i], theta); // [bs*pos_len, axes_dim[i]/2 * 2 * 2] + for (int b = 0; b < bs; ++b) { + for (int j = 0; j < pos_len; ++j) { + for (int k = 0; k < rope_emb[0].size(); ++k) { + emb[b * pos_len + j][offset + k] = rope_emb[j][k]; + } + } + } + offset += rope_emb[0].size(); + } + + return flatten(emb); + } +public: + FluxParams params; + Flux() {} + Flux(FluxParams params) : params(params) { + int64_t out_channels = params.in_channels; + int64_t pe_dim = params.hidden_size / params.num_heads; + + blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, params.hidden_size)); + blocks["time_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); + blocks["vector_in"] = std::shared_ptr(new MLPEmbedder(params.vec_in_dim, params.hidden_size)); + if (params.guidance_embed) { + blocks["guidance_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); + } + blocks["txt_in"] = std::shared_ptr(new Linear(params.context_in_dim, params.hidden_size)); + + for (int i = 0; i < params.depth; i++) { + blocks["double_blocks." + std::to_string(i)] = std::shared_ptr(new DoubleStreamBlock(params.hidden_size, + params.num_heads, + params.mlp_ratio, + params.qkv_bias)); + } + + for (int i = 0; i < params.depth_single_blocks; i++) { + blocks["single_blocks." + std::to_string(i)] = std::shared_ptr(new SingleStreamBlock(params.hidden_size, + params.num_heads, + params.mlp_ratio)); + } + + blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, out_channels)); + } + + struct ggml_tensor* patchify(struct ggml_context* ctx, + struct ggml_tensor* x, + int64_t patch_size) { + // x: [N, C, H, W] + // return: [N, h*w, C * patch_size * patch_size] + int64_t N = x->ne[3]; + int64_t C = x->ne[2]; + int64_t H = x->ne[1]; + int64_t W = x->ne[0]; + int64_t p = patch_size; + int64_t h = H / patch_size; + int64_t w = W / patch_size; + + GGML_ASSERT(h * p == H && w * p == W); + + x = ggml_reshape_4d(ctx, x, p, w, p, h*C*N); // [N*C*h, p, w, p] + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, w, p, p] + x = ggml_reshape_4d(ctx, x, p * p, w * h, C, N); // [N, C, h*w, p*p] + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, h*w, C, p*p] + x = ggml_reshape_3d(ctx, x, p*p*C, w*h, N); // [N, h*w, C*p*p] + return x; + } + + struct ggml_tensor* unpatchify(struct ggml_context* ctx, + struct ggml_tensor* x, + int64_t h, + int64_t w, + int64_t patch_size) { + // x: [N, h*w, C*patch_size*patch_size] + // return: [N, C, H, W] + int64_t N = x->ne[2]; + int64_t C = x->ne[0] / patch_size / patch_size; + int64_t H = h * patch_size; + int64_t W = w * patch_size; + int64_t p = patch_size; + + GGML_ASSERT(C * p * p == x->ne[0]); + + x = ggml_reshape_4d(ctx, x, p * p, C, w * h, N); // [N, h*w, C, p*p] + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N, C, h*w, p*p] + x = ggml_reshape_4d(ctx, x, p, p, w, h * C * N); // [N*C*h, w, p, p] + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // [N*C*h, p, w, p] + x = ggml_reshape_4d(ctx, x, W, H, C, N); // [N, C, h*p, w*p] + + return x; + } + + struct ggml_tensor* forward_orig(struct ggml_context* ctx, + struct ggml_tensor* img, + struct ggml_tensor* txt, + struct ggml_tensor* timesteps, + struct ggml_tensor* y, + struct ggml_tensor* guidance, + struct ggml_tensor* pe) { + auto img_in = std::dynamic_pointer_cast(blocks["img_in"]); + auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); + auto vector_in = std::dynamic_pointer_cast(blocks["vector_in"]); + auto txt_in = std::dynamic_pointer_cast(blocks["txt_in"]); + auto final_layer = std::dynamic_pointer_cast(blocks["final_layer"]); + + img = img_in->forward(ctx, img); + auto vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f)); + + if (params.guidance_embed) { + GGML_ASSERT(guidance != NULL); + auto guidance_in = std::dynamic_pointer_cast(blocks["guidance_in"]); + // bf16 and fp16 result is different + auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f); + vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in)); + } + + vec = ggml_add(ctx, vec, vector_in->forward(ctx, y)); + txt = txt_in->forward(ctx, txt); + + for (int i = 0; i < params.depth; i++) { + auto block = std::dynamic_pointer_cast(blocks["double_blocks." + std::to_string(i)]); + + auto img_txt = block->forward(ctx, img, txt, vec, pe); + img = img_txt.first; // [N, n_img_token, hidden_size] + txt = img_txt.second; // [N, n_txt_token, hidden_size] + } + + auto txt_img = ggml_concat(ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size] + for (int i = 0; i < params.depth_single_blocks; i++) { + auto block = std::dynamic_pointer_cast(blocks["single_blocks." + std::to_string(i)]); + + txt_img = block->forward(ctx, txt_img, vec, pe); + } + + txt_img = ggml_cont(ctx, ggml_permute(ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] + img = ggml_view_3d(ctx, + txt_img, + txt_img->ne[0], + txt_img->ne[1], + img->ne[1], + txt_img->nb[1], + txt_img->nb[2], + txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] + img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] + + img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels) + + return img; + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* timestep, + struct ggml_tensor* context, + struct ggml_tensor* y, + struct ggml_tensor* guidance, + struct ggml_tensor* pe) { + // Forward pass of DiT. + // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + // timestep: (N,) tensor of diffusion timesteps + // context: (N, L, D) + // y: (N, adm_in_channels) tensor of class labels + // guidance: (N,) + // pe: (L, d_head/2, 2, 2) + // return: (N, C, H, W) + + GGML_ASSERT(x->ne[3] == 1); + + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + int64_t patch_size = 2; + int pad_h = (patch_size - H % patch_size) % patch_size; + int pad_w = (patch_size - W % patch_size) % patch_size; + x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] + + // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size] + + auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe); // [N, h*w, C * patch_size * patch_size] + + // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) + out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w] + + return out; + } +}; + + +struct FluxRunner : public GGMLRunner { +public: + FluxParams flux_params; + Flux flux; + std::vector pe_vec; // for cache + + FluxRunner(ggml_backend_t backend, + ggml_type wtype, + SDVersion version = VERSION_FLUX_DEV) + : GGMLRunner(backend, wtype) { + if (version == VERSION_FLUX_SCHNELL) { + flux_params.guidance_embed = false; + } + flux = Flux(flux_params); + flux.init(params_ctx, wtype); + } + + std::string get_desc() { + return "flux"; + } + + void get_param_tensors(std::map& tensors, const std::string prefix) { + flux.get_param_tensors(tensors, prefix); + } + + struct ggml_cgraph* build_graph(struct ggml_tensor* x, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* y, + struct ggml_tensor* guidance) { + GGML_ASSERT(x->ne[3] == 1); + struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false); + + x = to_backend(x); + context = to_backend(context); + y = to_backend(y); + timesteps = to_backend(timesteps); + if (flux_params.guidance_embed) { + guidance = to_backend(guidance); + } + + pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], flux_params.theta, flux_params.axes_dim); + int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2; + // LOG_DEBUG("pos_len %d", pos_len); + auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum/2, pos_len); + // pe->data = pe_vec.data(); + // print_ggml_tensor(pe); + // pe->data = NULL; + set_backend_tensor_data(pe, pe_vec.data()); + + + struct ggml_tensor* out = flux.forward(compute_ctx, + x, + timesteps, + context, + y, + guidance, + pe); + + ggml_build_forward_expand(gf, out); + + return gf; + } + + void compute(int n_threads, + struct ggml_tensor* x, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* y, + struct ggml_tensor* guidance, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL) { + // x: [N, in_channels, h, w] + // timesteps: [N, ] + // context: [N, max_position, hidden_size] + // y: [N, adm_in_channels] or [1, adm_in_channels] + // guidance: [N, ] + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph(x, timesteps, context, y, guidance); + }; + + GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + } + + void test() { + struct ggml_init_params params; + params.mem_size = static_cast(20 * 1024 * 1024); // 20 MB + params.mem_buffer = NULL; + params.no_alloc = false; + + struct ggml_context* work_ctx = ggml_init(params); + GGML_ASSERT(work_ctx != NULL); + + { + // cpu f16: + // cuda f16: nan + // cuda q8_0: pass + auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1); + ggml_set_f32(x, 0.01f); + // print_ggml_tensor(x); + + std::vector timesteps_vec(1, 999.f); + auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec); + + std::vector guidance_vec(1, 3.5f); + auto guidance = vector_to_ggml_tensor(work_ctx, guidance_vec); + + auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 256, 1); + ggml_set_f32(context, 0.01f); + // print_ggml_tensor(context); + + auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, 1); + ggml_set_f32(y, 0.01f); + // print_ggml_tensor(y); + + struct ggml_tensor* out = NULL; + + int t0 = ggml_time_ms(); + compute(8, x, timesteps, context, y, guidance, &out, work_ctx); + int t1 = ggml_time_ms(); + + print_ggml_tensor(out); + LOG_DEBUG("flux test done in %dms", t1 - t0); + } + } + + static void load_from_file_and_test(const std::string& file_path) { + // ggml_backend_t backend = ggml_backend_cuda_init(0); + ggml_backend_t backend = ggml_backend_cpu_init(); + ggml_type model_data_type = GGML_TYPE_Q8_0; + std::shared_ptr flux = std::shared_ptr(new FluxRunner(backend, model_data_type)); + { + LOG_INFO("loading from '%s'", file_path.c_str()); + + flux->alloc_params_buffer(); + std::map tensors; + flux->get_param_tensors(tensors, "model.diffusion_model"); + + ModelLoader model_loader; + if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) { + LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); + return; + } + + bool success = model_loader.load_tensors(tensors, backend); + + if (!success) { + LOG_ERROR("load tensors from model loader failed"); + return; + } + + LOG_INFO("flux model loaded"); + } + flux->test(); + } +}; + +} // namespace Flux + +#endif // __FLUX_HPP__ \ No newline at end of file diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 14ad37c0e..dcef98ad6 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -627,6 +627,20 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d_nx1x1(struct ggml_context* return x; // [N, OC, T, OH * OW] } +// qkv: [N, L, 3*C] +// return: ([N, L, C], [N, L, C], [N, L, C]) +__STATIC_INLINE__ std::vector split_qkv(struct ggml_context* ctx, + struct ggml_tensor* qkv) { + qkv = ggml_reshape_4d(ctx, qkv, qkv->ne[0] / 3, 3, qkv->ne[1], qkv->ne[2]); // [N, L, 3, C] + qkv = ggml_cont(ctx, ggml_permute(ctx, qkv, 0, 3, 1, 2)); // [3, N, L, C] + + int64_t offset = qkv->nb[2] * qkv->ne[2]; + auto q = ggml_view_3d(ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], qkv->nb[1], qkv->nb[2], offset * 0); // [N, L, C] + auto k = ggml_view_3d(ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], qkv->nb[1], qkv->nb[2], offset * 1); // [N, L, C] + auto v = ggml_view_3d(ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], qkv->nb[1], qkv->nb[2], offset * 2); // [N, L, C] + return {q, k, v}; +} + // q: [N * n_head, n_token, d_head] // k: [N * n_head, n_k, d_head] // v: [N * n_head, d_head, n_k] @@ -653,9 +667,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx return kqv; } -// q: [N, L_q, C] -// k: [N, L_k, C] -// v: [N, L_k, C] +// q: [N, L_q, C] or [N*n_head, L_q, d_head] +// k: [N, L_k, C] or [N*n_head, L_k, d_head] +// v: [N, L_k, C] or [N, L_k, n_head, d_head] // return: [N, L_q, C] __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx, struct ggml_tensor* q, @@ -663,38 +677,61 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* struct ggml_tensor* v, int64_t n_head, struct ggml_tensor* mask = NULL, - bool diag_mask_inf = false) { - int64_t L_q = q->ne[1]; - int64_t L_k = k->ne[1]; - int64_t C = q->ne[0]; - int64_t N = q->ne[2]; + bool diag_mask_inf = false, + bool skip_reshape = false) { + int64_t L_q; + int64_t L_k; + int64_t C ; + int64_t N ; + int64_t d_head; + if (!skip_reshape) { + L_q = q->ne[1]; + L_k = k->ne[1]; + C = q->ne[0]; + N = q->ne[2]; + d_head = C / n_head; + q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N); // [N, L_q, n_head, d_head] + q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head] + q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N); // [N * n_head, L_q, d_head] + + k = ggml_reshape_4d(ctx, k, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head] + k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] + k = ggml_reshape_3d(ctx, k, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] + + v = ggml_reshape_4d(ctx, v, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head] + } else { + L_q = q->ne[1]; + L_k = k->ne[1]; + d_head = v->ne[0]; + N = v->ne[3]; + C = d_head * n_head; + } - int64_t d_head = C / n_head; float scale = (1.0f / sqrt((float)d_head)); - q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N); // [N, L_q, n_head, d_head] - q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head] - q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N); // [N * n_head, L_q, d_head] - - k = ggml_reshape_4d(ctx, k, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head] - k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] - k = ggml_reshape_3d(ctx, k, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] + bool use_flash_attn = false; + ggml_tensor* kqv = NULL; + if (use_flash_attn) { + v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] + v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] + LOG_DEBUG("k->ne[1] == %d", k->ne[1]); + kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0); + } else { + v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k] + v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N); // [N * n_head, d_head, L_k] - v = ggml_reshape_4d(ctx, v, d_head, n_head, L_k, N); // [N, L_k, n_head, d_head] - v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k] - v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N); // [N * n_head, d_head, L_k] + auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k] + kq = ggml_scale_inplace(ctx, kq, scale); + if (mask) { + kq = ggml_add(ctx, kq, mask); + } + if (diag_mask_inf) { + kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); + } + kq = ggml_soft_max_inplace(ctx, kq); - auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k] - kq = ggml_scale_inplace(ctx, kq, scale); - if (mask) { - kq = ggml_add(ctx, kq, mask); - } - if (diag_mask_inf) { - kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); + kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head] } - kq = ggml_soft_max_inplace(ctx, kq); - - auto kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head] kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N); // [N, n_head, L_q, d_head] kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, L_q, n_head, d_head] @@ -846,7 +883,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_timestep_embedding( struct ggml_context* ctx, struct ggml_tensor* timesteps, int dim, - int max_period = 10000) { + int max_period = 10000, + float time_factor = 1.0f) { + timesteps = ggml_scale(ctx, timesteps, time_factor); return ggml_timestep_embedding(ctx, timesteps, dim, max_period); } @@ -1144,6 +1183,9 @@ class Linear : public UnaryBlock { bool bias; void init_params(struct ggml_context* ctx, ggml_type wtype) { + if (in_features % ggml_blck_size(wtype) != 0) { + wtype = GGML_TYPE_F32; + } params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features); if (bias) { params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_features); diff --git a/lora.hpp b/lora.hpp index edee74cae..309378f38 100644 --- a/lora.hpp +++ b/lora.hpp @@ -12,6 +12,8 @@ struct LoraModel : public GGMLRunner { ModelLoader model_loader; bool load_failed = false; bool applied = false; + std::vector zero_index_vec = {0}; + ggml_tensor* zero_index = NULL; LoraModel(ggml_backend_t backend, ggml_type wtype, @@ -68,9 +70,19 @@ struct LoraModel : public GGMLRunner { return true; } + ggml_tensor* to_f32(ggml_context* ctx, ggml_tensor* a) { + auto out = ggml_reshape_1d(ctx, a, ggml_nelements(a)); + out = ggml_get_rows(ctx, out, zero_index); + out = ggml_reshape(ctx, out, a); + return out; + } + struct ggml_cgraph* build_lora_graph(std::map model_tensors) { struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, LORA_GRAPH_SIZE, false); + zero_index = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, 1); + set_backend_tensor_data(zero_index, zero_index_vec.data()); + std::set applied_lora_tensors; for (auto it : model_tensors) { std::string k_tensor = it.first; @@ -141,15 +153,16 @@ struct LoraModel : public GGMLRunner { GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight)); updown = ggml_scale_inplace(compute_ctx, updown, scale_value); ggml_tensor* final_weight; - // if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) { - // final_weight = ggml_new_tensor(compute_ctx, GGML_TYPE_F32, weight->n_dims, weight->ne); - // final_weight = ggml_cpy_inplace(compute_ctx, weight, final_weight); - // final_weight = ggml_add_inplace(compute_ctx, final_weight, updown); - // final_weight = ggml_cpy_inplace(compute_ctx, final_weight, weight); - // } else { - // final_weight = ggml_add_inplace(compute_ctx, weight, updown); - // } - final_weight = ggml_add_inplace(compute_ctx, weight, updown); // apply directly + if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) { + // final_weight = ggml_new_tensor(compute_ctx, GGML_TYPE_F32, ggml_n_dims(weight), weight->ne); + // final_weight = ggml_cpy(compute_ctx, weight, final_weight); + final_weight = to_f32(compute_ctx, weight); + final_weight = ggml_add_inplace(compute_ctx, final_weight, updown); + final_weight = ggml_cpy(compute_ctx, final_weight, weight); + } else { + final_weight = ggml_add_inplace(compute_ctx, weight, updown); + } + // final_weight = ggml_add_inplace(compute_ctx, weight, updown); // apply directly ggml_build_forward_expand(gf, final_weight); } diff --git a/mmdit.hpp b/mmdit.hpp index 7d7b22d9a..0a4d83107 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -142,20 +142,6 @@ struct VectorEmbedder : public GGMLBlock { } }; -__STATIC_INLINE__ std::vector split_qkv(struct ggml_context* ctx, - struct ggml_tensor* qkv) { - // qkv: [N, L, 3*C] - // return: ([N, L, C], [N, L, C], [N, L, C]) - qkv = ggml_reshape_4d(ctx, qkv, qkv->ne[0] / 3, 3, qkv->ne[1], qkv->ne[2]); // [N, L, 3, C] - qkv = ggml_cont(ctx, ggml_permute(ctx, qkv, 0, 3, 1, 2)); // [3, N, L, C] - - int64_t offset = qkv->nb[2] * qkv->ne[2]; - auto q = ggml_view_3d(ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], qkv->nb[1], qkv->nb[2], offset * 0); // [N, L, C] - auto k = ggml_view_3d(ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], qkv->nb[1], qkv->nb[2], offset * 1); // [N, L, C] - auto v = ggml_view_3d(ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], qkv->nb[1], qkv->nb[2], offset * 2); // [N, L, C] - return {q, k, v}; -} - class SelfAttention : public GGMLBlock { public: int64_t num_heads; @@ -469,7 +455,7 @@ struct FinalLayer : public GGMLBlock { struct MMDiT : public GGMLBlock { // Diffusion model with a Transformer backbone. protected: - SDVersion version = VERSION_3_2B; + SDVersion version = VERSION_SD3_2B; int64_t input_size = -1; int64_t patch_size = 2; int64_t in_channels = 16; @@ -487,7 +473,7 @@ struct MMDiT : public GGMLBlock { } public: - MMDiT(SDVersion version = VERSION_3_2B) + MMDiT(SDVersion version = VERSION_SD3_2B) : version(version) { // input_size is always None // learn_sigma is always False @@ -501,7 +487,7 @@ struct MMDiT : public GGMLBlock { // pos_embed_scaling_factor is not used // pos_embed_offset is not used // context_embedder_config is always {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}} - if (version == VERSION_3_2B) { + if (version == VERSION_SD3_2B) { input_size = -1; patch_size = 2; in_channels = 16; @@ -669,7 +655,7 @@ struct MMDiTRunner : public GGMLRunner { MMDiTRunner(ggml_backend_t backend, ggml_type wtype, - SDVersion version = VERSION_3_2B) + SDVersion version = VERSION_SD3_2B) : GGMLRunner(backend, wtype), mmdit(version) { mmdit.init(params_ctx, wtype); } diff --git a/model.cpp b/model.cpp index 7ab2287bc..f5c070123 100644 --- a/model.cpp +++ b/model.cpp @@ -422,7 +422,10 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) { return key; } -std::string convert_tensor_name(const std::string& name) { +std::string convert_tensor_name(std::string name) { + if (starts_with(name, "diffusion_model")) { + name = "model." + name; + } std::string new_name = name; if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || starts_with(name, "text_encoders.") || ends_with(name, ".vision_model.visual_projection.weight")) { new_name = convert_open_clip_to_hf_clip(name); @@ -554,6 +557,48 @@ float bf16_to_f32(uint16_t bfloat16) { return *reinterpret_cast(&val_bits); } +uint16_t f8_e4m3_to_f16(uint8_t f8) { + // do we need to support uz? + + const uint32_t exponent_bias = 7; + if (f8 == 0xff) { + return ggml_fp32_to_fp16(-NAN); + } else if (f8 == 0x7f) { + return ggml_fp32_to_fp16(NAN); + } + + uint32_t sign = f8 & 0x80; + uint32_t exponent = (f8 & 0x78) >> 3; + uint32_t mantissa = f8 & 0x07; + uint32_t result = sign << 24; + if (exponent == 0) { + if (mantissa > 0) { + exponent = 0x7f - exponent_bias; + + // yes, 2 times + if ((mantissa & 0x04) == 0) { + mantissa &= 0x03; + mantissa <<= 1; + exponent -= 1; + } + if ((mantissa & 0x04) == 0) { + mantissa &= 0x03; + mantissa <<= 1; + exponent -= 1; + } + + result |= (mantissa & 0x03) << 21; + result |= exponent << 23; + } + } else { + result |= mantissa << 20; + exponent += 0x7f - exponent_bias; + result |= exponent << 23; + } + + return ggml_fp32_to_fp16(*reinterpret_cast(&result)); +} + void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) { // support inplace op for (int64_t i = n - 1; i >= 0; i--) { @@ -561,6 +606,13 @@ void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) { } } +void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) { + // support inplace op + for (int64_t i = n - 1; i >= 0; i--) { + dst[i] = f8_e4m3_to_f16(src[i]); + } +} + void convert_tensor(void* src, ggml_type src_type, void* dst, @@ -794,6 +846,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) { ttype = GGML_TYPE_F32; } else if (dtype == "F32") { ttype = GGML_TYPE_F32; + } else if (dtype == "F8_E4M3") { + ttype = GGML_TYPE_F16; } return ttype; } @@ -866,7 +920,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const ggml_type type = str_to_ggml_type(dtype); if (type == GGML_TYPE_COUNT) { - LOG_ERROR("unsupported dtype '%s'", dtype.c_str()); + LOG_ERROR("unsupported dtype '%s' (tensor '%s')", dtype.c_str(), name.c_str()); return false; } @@ -903,6 +957,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const if (dtype == "BF16") { tensor_storage.is_bf16 = true; GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2); + } else if (dtype == "F8_E4M3") { + tensor_storage.is_f8_e4m3 = true; + // f8 -> f16 + GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2); } else { GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size); } @@ -1291,15 +1349,22 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s SDVersion ModelLoader::get_sd_version() { TensorStorage token_embedding_weight; + bool is_flux = false; for (auto& tensor_storage : tensor_storages) { + if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) { + return VERSION_FLUX_DEV; + } + if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { + is_flux = true; + } if (tensor_storage.name.find("model.diffusion_model.joint_blocks.23.") != std::string::npos) { - return VERSION_3_2B; + return VERSION_SD3_2B; } if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) { - return VERSION_XL; + return VERSION_SDXL; } if (tensor_storage.name.find("cond_stage_model.1") != std::string::npos) { - return VERSION_XL; + return VERSION_SDXL; } if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) { return VERSION_SVD; @@ -1315,10 +1380,13 @@ SDVersion ModelLoader::get_sd_version() { // break; } } + if (is_flux) { + return VERSION_FLUX_SCHNELL; + } if (token_embedding_weight.ne[0] == 768) { - return VERSION_1_x; + return VERSION_SD1; } else if (token_embedding_weight.ne[0] == 1024) { - return VERSION_2_x; + return VERSION_SD2; } return VERSION_COUNT; } @@ -1330,8 +1398,68 @@ ggml_type ModelLoader::get_sd_wtype() { } if (tensor_storage.name.find(".weight") != std::string::npos && - (tensor_storage.name.find("time_embed") != std::string::npos) || - tensor_storage.name.find("context_embedder") != std::string::npos) { + (tensor_storage.name.find("time_embed") != std::string::npos || + tensor_storage.name.find("context_embedder") != std::string::npos || + tensor_storage.name.find("time_in") != std::string::npos)) { + return tensor_storage.type; + } + } + return GGML_TYPE_COUNT; +} + +ggml_type ModelLoader::get_conditioner_wtype() { + for (auto& tensor_storage : tensor_storages) { + if (is_unused_tensor(tensor_storage.name)) { + continue; + } + + if ((tensor_storage.name.find("text_encoders") == std::string::npos && + tensor_storage.name.find("cond_stage_model") == std::string::npos && + tensor_storage.name.find("te.text_model.") == std::string::npos && + tensor_storage.name.find("conditioner") == std::string::npos)) { + continue; + } + + if (tensor_storage.name.find(".weight") != std::string::npos) { + return tensor_storage.type; + } + } + return GGML_TYPE_COUNT; +} + + +ggml_type ModelLoader::get_diffusion_model_wtype() { + for (auto& tensor_storage : tensor_storages) { + if (is_unused_tensor(tensor_storage.name)) { + continue; + } + + if (tensor_storage.name.find("model.diffusion_model.") == std::string::npos) { + continue; + } + + if (tensor_storage.name.find(".weight") != std::string::npos && + (tensor_storage.name.find("time_embed") != std::string::npos || + tensor_storage.name.find("context_embedder") != std::string::npos || + tensor_storage.name.find("time_in") != std::string::npos)) { + return tensor_storage.type; + } + } + return GGML_TYPE_COUNT; +} + +ggml_type ModelLoader::get_vae_wtype() { + for (auto& tensor_storage : tensor_storages) { + if (is_unused_tensor(tensor_storage.name)) { + continue; + } + + if (tensor_storage.name.find("vae.") == std::string::npos && + tensor_storage.name.find("first_stage_model") == std::string::npos) { + continue; + } + + if (tensor_storage.name.find(".weight")) { return tensor_storage.type; } } @@ -1467,6 +1595,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend if (tensor_storage.is_bf16) { // inplace op bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e4m3) { + // inplace op + f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements()); } } else { read_buffer.resize(tensor_storage.nbytes()); @@ -1475,6 +1606,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend if (tensor_storage.is_bf16) { // inplace op bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e4m3) { + // inplace op + f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); } convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data, @@ -1487,6 +1621,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend if (tensor_storage.is_bf16) { // inplace op bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e4m3) { + // inplace op + f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); } if (tensor_storage.type == dst_tensor->type) { @@ -1602,7 +1739,7 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type ggml_type tensor_type = tensor_storage.type; if (type != GGML_TYPE_COUNT) { - if (ggml_is_quantized(type) && tensor_storage.ne[0] % 32 != 0) { + if (ggml_is_quantized(type) && tensor_storage.ne[0] % ggml_blck_size(type) != 0) { tensor_type = GGML_TYPE_F16; } else { tensor_type = type; diff --git a/model.h b/model.h index 5bfce308f..f96c067e2 100644 --- a/model.h +++ b/model.h @@ -18,11 +18,13 @@ #define SD_MAX_DIMS 5 enum SDVersion { - VERSION_1_x, - VERSION_2_x, - VERSION_XL, + VERSION_SD1, + VERSION_SD2, + VERSION_SDXL, VERSION_SVD, - VERSION_3_2B, + VERSION_SD3_2B, + VERSION_FLUX_DEV, + VERSION_FLUX_SCHNELL, VERSION_COUNT, }; @@ -30,6 +32,7 @@ struct TensorStorage { std::string name; ggml_type type = GGML_TYPE_F32; bool is_bf16 = false; + bool is_f8_e4m3 = false; int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; int n_dims = 0; @@ -59,7 +62,7 @@ struct TensorStorage { } int64_t nbytes_to_read() const { - if (is_bf16) { + if (is_bf16 || is_f8_e4m3) { return nbytes() / 2; } else { return nbytes(); @@ -107,6 +110,8 @@ struct TensorStorage { const char* type_name = ggml_type_name(type); if (is_bf16) { type_name = "bf16"; + } else if (is_f8_e4m3) { + type_name = "f8_e4m3"; } ss << name << " | " << type_name << " | "; ss << n_dims << " ["; @@ -144,6 +149,9 @@ class ModelLoader { bool init_from_file(const std::string& file_path, const std::string& prefix = ""); SDVersion get_sd_version(); ggml_type get_sd_wtype(); + ggml_type get_conditioner_wtype(); + ggml_type get_diffusion_model_wtype(); + ggml_type get_vae_wtype(); bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend); bool load_tensors(std::map& tensors, ggml_backend_t backend, @@ -155,4 +163,6 @@ class ModelLoader { static std::string load_merges(); static std::string load_t5_tokenizer_json(); }; -#endif // __MODEL_H__ \ No newline at end of file + +#endif // __MODEL_H__ + diff --git a/pmid.hpp b/pmid.hpp index d1d8c3192..381050fef 100644 --- a/pmid.hpp +++ b/pmid.hpp @@ -161,7 +161,7 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection { struct PhotoMakerIDEncoder : public GGMLRunner { public: - SDVersion version = VERSION_XL; + SDVersion version = VERSION_SDXL; PhotoMakerIDEncoderBlock id_encoder; float style_strength; @@ -175,7 +175,7 @@ struct PhotoMakerIDEncoder : public GGMLRunner { std::vector zeros_right; public: - PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_XL, float sty = 20.f) + PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, float sty = 20.f) : GGMLRunner(backend, wtype), version(version), style_strength(sty) { diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 34bf8f527..619da299a 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -25,11 +25,13 @@ // #include "stb_image_write.h" const char* model_version_to_str[] = { - "1.x", - "2.x", - "XL", + "SD 1.x", + "SD 2.x", + "SDXL", "SVD", - "3 2B"}; + "SD3 2B", + "Flux Dev", + "Flux Schnell"}; const char* sampling_methods_str[] = { "Euler A", @@ -67,7 +69,11 @@ class StableDiffusionGGML { ggml_backend_t clip_backend = NULL; ggml_backend_t control_net_backend = NULL; ggml_backend_t vae_backend = NULL; - ggml_type model_data_type = GGML_TYPE_COUNT; + ggml_type model_wtype = GGML_TYPE_COUNT; + ggml_type conditioner_wtype = GGML_TYPE_COUNT; + ggml_type diffusion_model_wtype = GGML_TYPE_COUNT; + ggml_type vae_wtype = GGML_TYPE_COUNT; + SDVersion version; bool vae_decode_only = false; @@ -131,6 +137,9 @@ class StableDiffusionGGML { } bool load_from_file(const std::string& model_path, + const std::string& clip_l_path, + const std::string& t5xxl_path, + const std::string& diffusion_model_path, const std::string& vae_path, const std::string control_net_path, const std::string embeddings_path, @@ -164,14 +173,36 @@ class StableDiffusionGGML { LOG_INFO("Flash Attention enabled"); #endif #endif - LOG_INFO("loading model from '%s'", model_path.c_str()); ModelLoader model_loader; vae_tiling = vae_tiling_; - if (!model_loader.init_from_file(model_path)) { - LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str()); - return false; + if (model_path.size() > 0) { + LOG_INFO("loading model from '%s'", model_path.c_str()); + if (!model_loader.init_from_file(model_path)) { + LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str()); + } + } + + if (clip_l_path.size() > 0) { + LOG_INFO("loading clip_l from '%s'", clip_l_path.c_str()); + if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.")) { + LOG_WARN("loading clip_l from '%s' failed", clip_l_path.c_str()); + } + } + + if (t5xxl_path.size() > 0) { + LOG_INFO("loading t5xxl from '%s'", t5xxl_path.c_str()); + if (!model_loader.init_from_file(t5xxl_path, "text_encoders.t5xxl.")) { + LOG_WARN("loading t5xxl from '%s' failed", t5xxl_path.c_str()); + } + } + + if (diffusion_model_path.size() > 0) { + LOG_INFO("loading diffusion model from '%s'", diffusion_model_path.c_str()); + if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) { + LOG_WARN("loading diffusion model from '%s' failed", diffusion_model_path.c_str()); + } } if (vae_path.size() > 0) { @@ -187,16 +218,45 @@ class StableDiffusionGGML { return false; } - LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]); + LOG_INFO("Version: %s ", model_version_to_str[version]); if (wtype == GGML_TYPE_COUNT) { - model_data_type = model_loader.get_sd_wtype(); + model_wtype = model_loader.get_sd_wtype(); + if (model_wtype == GGML_TYPE_COUNT) { + model_wtype = GGML_TYPE_F32; + LOG_WARN("can not get mode wtype frome weight, use f32"); + } + conditioner_wtype = model_loader.get_conditioner_wtype(); + if (conditioner_wtype == GGML_TYPE_COUNT) { + conditioner_wtype = wtype; + } + diffusion_model_wtype = model_loader.get_diffusion_model_wtype(); + if (diffusion_model_wtype == GGML_TYPE_COUNT) { + diffusion_model_wtype = wtype; + } + vae_wtype = model_loader.get_vae_wtype(); + + if (vae_wtype == GGML_TYPE_COUNT) { + vae_wtype = wtype; + } } else { - model_data_type = wtype; + model_wtype = wtype; + conditioner_wtype = wtype; + diffusion_model_wtype = wtype; + vae_wtype = wtype; + } + + if (version == VERSION_SDXL) { + vae_wtype = GGML_TYPE_F32; } - LOG_INFO("Stable Diffusion weight type: %s", ggml_type_name(model_data_type)); + + LOG_INFO("Weight type: %s", ggml_type_name(model_wtype)); + LOG_INFO("Conditioner weight type: %s", ggml_type_name(conditioner_wtype)); + LOG_INFO("Diffsuion model weight type: %s", ggml_type_name(diffusion_model_wtype)); + LOG_INFO("VAE weight type: %s", ggml_type_name(vae_wtype)); + LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor)); - if (version == VERSION_XL) { + if (version == VERSION_SDXL) { scale_factor = 0.13025f; if (vae_path.size() == 0 && taesd_path.size() == 0) { LOG_WARN( @@ -205,26 +265,33 @@ class StableDiffusionGGML { "try specifying SDXL VAE FP16 Fix with the --vae parameter. " "You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors"); } - } else if (version == VERSION_3_2B) { + } else if (version == VERSION_SD3_2B) { scale_factor = 1.5305f; + } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + scale_factor = 0.3611; + // TODO: shift_factor } if (version == VERSION_SVD) { - clip_vision = std::make_shared(backend, model_data_type); + clip_vision = std::make_shared(backend, conditioner_wtype); clip_vision->alloc_params_buffer(); clip_vision->get_param_tensors(tensors); - diffusion_model = std::make_shared(backend, model_data_type, version); + diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); diffusion_model->alloc_params_buffer(); diffusion_model->get_param_tensors(tensors); - first_stage_model = std::make_shared(backend, model_data_type, vae_decode_only, true, version); + first_stage_model = std::make_shared(backend, vae_wtype, vae_decode_only, true, version); LOG_DEBUG("vae_decode_only %d", vae_decode_only); first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); } else { clip_backend = backend; - if (!ggml_backend_is_cpu(backend) && version == VERSION_3_2B && model_data_type != GGML_TYPE_F32) { + bool use_t5xxl = false; + if (version == VERSION_SD3_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + use_t5xxl = true; + } + if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) { clip_on_cpu = true; LOG_INFO("set clip_on_cpu to true"); } @@ -232,12 +299,15 @@ class StableDiffusionGGML { LOG_INFO("CLIP: Using CPU backend"); clip_backend = ggml_backend_cpu_init(); } - if (version == VERSION_3_2B) { - cond_stage_model = std::make_shared(clip_backend, model_data_type); - diffusion_model = std::make_shared(backend, model_data_type, version); + if (version == VERSION_SD3_2B) { + cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); + diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); + } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); + diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); } else { - cond_stage_model = std::make_shared(clip_backend, model_data_type, embeddings_path, version); - diffusion_model = std::make_shared(backend, model_data_type, version); + cond_stage_model = std::make_shared(clip_backend, conditioner_wtype, embeddings_path, version); + diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); } cond_stage_model->alloc_params_buffer(); cond_stage_model->get_param_tensors(tensors); @@ -245,11 +315,6 @@ class StableDiffusionGGML { diffusion_model->alloc_params_buffer(); diffusion_model->get_param_tensors(tensors); - ggml_type vae_type = model_data_type; - if (version == VERSION_XL) { - vae_type = GGML_TYPE_F32; // avoid nan, not work... - } - if (!use_tiny_autoencoder) { if (vae_on_cpu && !ggml_backend_is_cpu(backend)) { LOG_INFO("VAE Autoencoder: Using CPU backend"); @@ -257,11 +322,11 @@ class StableDiffusionGGML { } else { vae_backend = backend; } - first_stage_model = std::make_shared(vae_backend, vae_type, vae_decode_only, false, version); + first_stage_model = std::make_shared(vae_backend, vae_wtype, vae_decode_only, false, version); first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); } else { - tae_first_stage = std::make_shared(backend, model_data_type, vae_decode_only); + tae_first_stage = std::make_shared(backend, vae_wtype, vae_decode_only); } // first_stage_model->get_param_tensors(tensors, "first_stage_model."); @@ -273,12 +338,12 @@ class StableDiffusionGGML { } else { controlnet_backend = backend; } - control_net = std::make_shared(controlnet_backend, model_data_type, version); + control_net = std::make_shared(controlnet_backend, diffusion_model_wtype, version); } - pmid_model = std::make_shared(clip_backend, model_data_type, version); + pmid_model = std::make_shared(clip_backend, model_wtype, version); if (id_embeddings_path.size() > 0) { - pmid_lora = std::make_shared(backend, model_data_type, id_embeddings_path, ""); + pmid_lora = std::make_shared(backend, model_wtype, id_embeddings_path, ""); if (!pmid_lora->load_from_file(true)) { LOG_WARN("load photomaker lora tensors from %s failed", id_embeddings_path.c_str()); return false; @@ -423,7 +488,7 @@ class StableDiffusionGGML { // check is_using_v_parameterization_for_sd2 bool is_using_v_parameterization = false; - if (version == VERSION_2_x) { + if (version == VERSION_SD2) { if (is_using_v_parameterization_for_sd2(ctx)) { is_using_v_parameterization = true; } @@ -432,9 +497,16 @@ class StableDiffusionGGML { is_using_v_parameterization = true; } - if (version == VERSION_3_2B) { + if (version == VERSION_SD3_2B) { LOG_INFO("running in FLOW mode"); denoiser = std::make_shared(); + } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + LOG_INFO("running in Flux FLOW mode"); + float shift = 1.15f; + if (version == VERSION_FLUX_SCHNELL) { + shift = 1.0f; // TODO: validate + } + denoiser = std::make_shared(shift); } else if (is_using_v_parameterization) { LOG_INFO("running in v-prediction mode"); denoiser = std::make_shared(); @@ -489,7 +561,7 @@ class StableDiffusionGGML { ggml_set_f32(timesteps, 999); int64_t t0 = ggml_time_ms(); struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t); - diffusion_model->compute(n_threads, x_t, timesteps, c, NULL, NULL, -1, {}, 0.f, &out); + diffusion_model->compute(n_threads, x_t, timesteps, c, NULL, NULL, NULL, -1, {}, 0.f, &out); diffusion_model->free_compute_buffer(); double result = 0.f; @@ -522,7 +594,7 @@ class StableDiffusionGGML { LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str()); return; } - LoraModel lora(backend, model_data_type, file_path); + LoraModel lora(backend, model_wtype, file_path); if (!lora.load_from_file()) { LOG_WARN("load lora tensors from %s failed", file_path.c_str()); return; @@ -538,7 +610,7 @@ class StableDiffusionGGML { } void apply_loras(const std::unordered_map& lora_state) { - if (lora_state.size() > 0 && model_data_type != GGML_TYPE_F16 && model_data_type != GGML_TYPE_F32) { + if (lora_state.size() > 0 && model_wtype != GGML_TYPE_F16 && model_wtype != GGML_TYPE_F32) { LOG_WARN("In quantized models when applying LoRA, the images have poor quality."); } std::unordered_map lora_state_diff; @@ -663,6 +735,7 @@ class StableDiffusionGGML { float control_strength, float min_cfg, float cfg_scale, + float guidance, sample_method_t method, const std::vector& sigmas, int start_merge_step, @@ -701,6 +774,8 @@ class StableDiffusionGGML { float t = denoiser->sigma_to_t(sigma); std::vector timesteps_vec(x->ne[3], t); // [N, ] auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec); + std::vector guidance_vec(x->ne[3], guidance); + auto guidance_tensor = vector_to_ggml_tensor(work_ctx, guidance_vec); copy_ggml_tensor(noised_input, input); // noised_input = noised_input * c_in @@ -723,6 +798,7 @@ class StableDiffusionGGML { cond.c_crossattn, cond.c_concat, cond.c_vector, + guidance_tensor, -1, controls, control_strength, @@ -734,6 +810,7 @@ class StableDiffusionGGML { id_cond.c_crossattn, cond.c_concat, id_cond.c_vector, + guidance_tensor, -1, controls, control_strength, @@ -753,6 +830,7 @@ class StableDiffusionGGML { uncond.c_crossattn, uncond.c_concat, uncond.c_vector, + guidance_tensor, -1, controls, control_strength, @@ -838,7 +916,9 @@ class StableDiffusionGGML { if (use_tiny_autoencoder) { C = 4; } else { - if (version == VERSION_3_2B) { + if (version == VERSION_SD3_2B) { + C = 32; + } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { C = 32; } } @@ -904,6 +984,9 @@ struct sd_ctx_t { }; sd_ctx_t* new_sd_ctx(const char* model_path_c_str, + const char* clip_l_path_c_str, + const char* t5xxl_path_c_str, + const char* diffusion_model_path_c_str, const char* vae_path_c_str, const char* taesd_path_c_str, const char* control_net_path_c_str, @@ -925,6 +1008,9 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, return NULL; } std::string model_path(model_path_c_str); + std::string clip_l_path(clip_l_path_c_str); + std::string t5xxl_path(t5xxl_path_c_str); + std::string diffusion_model_path(diffusion_model_path_c_str); std::string vae_path(vae_path_c_str); std::string taesd_path(taesd_path_c_str); std::string control_net_path(control_net_path_c_str); @@ -942,6 +1028,9 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, } if (!sd_ctx->sd->load_from_file(model_path, + clip_l_path, + t5xxl_path_c_str, + diffusion_model_path, vae_path, control_net_path, embd_path, @@ -976,6 +1065,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, std::string negative_prompt, int clip_skip, float cfg_scale, + float guidance, int width, int height, enum sample_method_t sample_method, @@ -1127,7 +1217,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, SDCondition uncond; if (cfg_scale != 1.0) { bool force_zero_embeddings = false; - if (sd_ctx->sd->version == VERSION_XL && negative_prompt.size() == 0) { + if (sd_ctx->sd->version == VERSION_SDXL && negative_prompt.size() == 0) { force_zero_embeddings = true; } uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, @@ -1156,7 +1246,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, // Sample std::vector final_latents; // collect latents to decode int C = 4; - if (sd_ctx->sd->version == VERSION_3_2B) { + if (sd_ctx->sd->version == VERSION_SD3_2B) { + C = 16; + } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { C = 16; } int W = width / 8; @@ -1189,6 +1281,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, control_strength, cfg_scale, cfg_scale, + guidance, sample_method, sigmas, start_merge_step, @@ -1247,6 +1340,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, const char* negative_prompt_c_str, int clip_skip, float cfg_scale, + float guidance, int width, int height, enum sample_method_t sample_method, @@ -1265,9 +1359,12 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, struct ggml_init_params params; params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB - if (sd_ctx->sd->version == VERSION_3_2B) { + if (sd_ctx->sd->version == VERSION_SD3_2B) { params.mem_size *= 3; } + if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + params.mem_size *= 4; + } if (sd_ctx->sd->stacked_id) { params.mem_size += static_cast(10 * 1024 * 1024); // 10 MB } @@ -1288,14 +1385,18 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); int C = 4; - if (sd_ctx->sd->version == VERSION_3_2B) { + if (sd_ctx->sd->version == VERSION_SD3_2B) { + C = 16; + } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { C = 16; } int W = width / 8; int H = height / 8; ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); - if (sd_ctx->sd->version == VERSION_3_2B) { + if (sd_ctx->sd->version == VERSION_SD3_2B) { ggml_set_f32(init_latent, 0.0609f); + } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + ggml_set_f32(init_latent, 0.1159f); } else { ggml_set_f32(init_latent, 0.f); } @@ -1307,6 +1408,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, negative_prompt_c_str, clip_skip, cfg_scale, + guidance, width, height, sample_method, @@ -1332,6 +1434,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, const char* negative_prompt_c_str, int clip_skip, float cfg_scale, + float guidance, int width, int height, sample_method_t sample_method, @@ -1351,9 +1454,12 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, struct ggml_init_params params; params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB - if (sd_ctx->sd->version == VERSION_3_2B) { + if (sd_ctx->sd->version == VERSION_SD3_2B) { params.mem_size *= 2; } + if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + params.mem_size *= 3; + } if (sd_ctx->sd->stacked_id) { params.mem_size += static_cast(10 * 1024 * 1024); // 10 MB } @@ -1403,6 +1509,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, negative_prompt_c_str, clip_skip, cfg_scale, + guidance, width, height, sample_method, @@ -1510,6 +1617,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, 0.f, min_cfg, cfg_scale, + 0.f, sample_method, sigmas, -1, diff --git a/stable-diffusion.h b/stable-diffusion.h index f78748faf..0225b34c1 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -119,6 +119,9 @@ typedef struct { typedef struct sd_ctx_t sd_ctx_t; SD_API sd_ctx_t* new_sd_ctx(const char* model_path, + const char* clip_l_path, + const char* t5xxl_path, + const char* diffusion_model_path, const char* vae_path, const char* taesd_path, const char* control_net_path_c_str, @@ -143,6 +146,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx, const char* negative_prompt, int clip_skip, float cfg_scale, + float guidance, int width, int height, enum sample_method_t sample_method, @@ -161,6 +165,7 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, const char* negative_prompt, int clip_skip, float cfg_scale, + float guidance, int width, int height, enum sample_method_t sample_method, diff --git a/unet.hpp b/unet.hpp index 737a2bbec..94a8ba46a 100644 --- a/unet.hpp +++ b/unet.hpp @@ -166,7 +166,7 @@ class SpatialVideoTransformer : public SpatialTransformer { // ldm.modules.diffusionmodules.openaimodel.UNetModel class UnetModelBlock : public GGMLBlock { protected: - SDVersion version = VERSION_1_x; + SDVersion version = VERSION_SD1; // network hparams int in_channels = 4; int out_channels = 4; @@ -177,19 +177,19 @@ class UnetModelBlock : public GGMLBlock { int time_embed_dim = 1280; // model_channels*4 int num_heads = 8; int num_head_channels = -1; // channels // num_heads - int context_dim = 768; // 1024 for VERSION_2_x, 2048 for VERSION_XL + int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL public: int model_channels = 320; - int adm_in_channels = 2816; // only for VERSION_XL/SVD + int adm_in_channels = 2816; // only for VERSION_SDXL/SVD - UnetModelBlock(SDVersion version = VERSION_1_x) + UnetModelBlock(SDVersion version = VERSION_SD1) : version(version) { - if (version == VERSION_2_x) { + if (version == VERSION_SD2) { context_dim = 1024; num_head_channels = 64; num_heads = -1; - } else if (version == VERSION_XL) { + } else if (version == VERSION_SDXL) { context_dim = 2048; attention_resolutions = {4, 2}; channel_mult = {1, 2, 4}; @@ -211,7 +211,7 @@ class UnetModelBlock : public GGMLBlock { // time_embed_1 is nn.SiLU() blocks["time_embed.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); - if (version == VERSION_XL || version == VERSION_SVD) { + if (version == VERSION_SDXL || version == VERSION_SVD) { blocks["label_emb.0.0"] = std::shared_ptr(new Linear(adm_in_channels, time_embed_dim)); // label_emb_1 is nn.SiLU() blocks["label_emb.0.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); @@ -533,7 +533,7 @@ struct UNetModelRunner : public GGMLRunner { UNetModelRunner(ggml_backend_t backend, ggml_type wtype, - SDVersion version = VERSION_1_x) + SDVersion version = VERSION_SD1) : GGMLRunner(backend, wtype), unet(version) { unet.init(params_ctx, wtype); } diff --git a/vae.hpp b/vae.hpp index cb8112d4a..85319fdee 100644 --- a/vae.hpp +++ b/vae.hpp @@ -455,9 +455,9 @@ class AutoencodingEngine : public GGMLBlock { public: AutoencodingEngine(bool decode_only = true, bool use_video_decoder = false, - SDVersion version = VERSION_1_x) + SDVersion version = VERSION_SD1) : decode_only(decode_only), use_video_decoder(use_video_decoder) { - if (version == VERSION_3_2B) { + if (version == VERSION_SD3_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { dd_config.z_channels = 16; use_quant = false; } @@ -527,7 +527,7 @@ struct AutoEncoderKL : public GGMLRunner { ggml_type wtype, bool decode_only = false, bool use_video_decoder = false, - SDVersion version = VERSION_1_x) + SDVersion version = VERSION_SD1) : decode_only(decode_only), ae(decode_only, use_video_decoder, version), GGMLRunner(backend, wtype) { ae.init(params_ctx, wtype); }