diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fe1410891..a16e692ec 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -155,17 +155,17 @@ jobs: matrix: include: - build: "noavx" - defines: "-DGGML_AVX=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF -DSD_BUILD_SHARED_LIBS=ON" + defines: "-DGGML_NATIVE=OFF -DGGML_AVX=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF -DSD_BUILD_SHARED_LIBS=ON" - build: "avx2" - defines: "-DGGML_AVX2=ON -DSD_BUILD_SHARED_LIBS=ON" + defines: "-DGGML_NATIVE=OFF -DGGML_AVX2=ON -DSD_BUILD_SHARED_LIBS=ON" - build: "avx" - defines: "-DGGML_AVX2=OFF -DSD_BUILD_SHARED_LIBS=ON" + defines: "-DGGML_NATIVE=OFF -DGGML_AVX=ON -DGGML_AVX2=OFF -DSD_BUILD_SHARED_LIBS=ON" - build: "avx512" - defines: "-DGGML_AVX512=ON -DSD_BUILD_SHARED_LIBS=ON" + defines: "-DGGML_NATIVE=OFF -DGGML_AVX512=ON -DGGML_AVX=ON -DGGML_AVX2=ON -DSD_BUILD_SHARED_LIBS=ON" - build: "cuda12" - defines: "-DSD_CUBLAS=ON -DSD_BUILD_SHARED_LIBS=ON" - - build: "rocm5.5" - defines: '-G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS="gfx1100;gfx1102;gfx1030" -DSD_BUILD_SHARED_LIBS=ON' + defines: "-DSD_CUDA=ON -DSD_BUILD_SHARED_LIBS=ON -DCMAKE_CUDA_ARCHITECTURES=90;89;80;75" + # - build: "rocm5.5" + # defines: '-G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS="gfx1100;gfx1102;gfx1030" -DSD_BUILD_SHARED_LIBS=ON' - build: 'vulkan' defines: "-DSD_VULKAN=ON -DSD_BUILD_SHARED_LIBS=ON" steps: @@ -178,9 +178,9 @@ jobs: - name: Install cuda-toolkit id: cuda-toolkit if: ${{ matrix.build == 'cuda12' }} - uses: Jimver/cuda-toolkit@v0.2.11 + uses: Jimver/cuda-toolkit@v0.2.19 with: - cuda: "12.2.0" + cuda: "12.6.2" method: "network" sub-packages: '["nvcc", "cudart", "cublas", "cublas_dev", "thrust", "visual_studio_integration"]' diff --git a/CMakeLists.txt b/CMakeLists.txt index c993e7c96..782a893e4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,20 +24,20 @@ endif() # general #option(SD_BUILD_TESTS "sd: build tests" ${SD_STANDALONE}) option(SD_BUILD_EXAMPLES "sd: build examples" ${SD_STANDALONE}) -option(SD_CUBLAS "sd: cuda backend" OFF) +option(SD_CUDA "sd: cuda backend" OFF) option(SD_HIPBLAS "sd: rocm backend" OFF) option(SD_METAL "sd: metal backend" OFF) option(SD_VULKAN "sd: vulkan backend" OFF) option(SD_SYCL "sd: sycl backend" OFF) -option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF) +option(SD_MUSA "sd: musa backend" OFF) option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF) option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF) #option(SD_BUILD_SERVER "sd: build server example" ON) -if(SD_CUBLAS) - message("-- Use CUBLAS as backend stable-diffusion") +if(SD_CUDA) + message("-- Use CUDA as backend stable-diffusion") set(GGML_CUDA ON) - add_definitions(-DSD_USE_CUBLAS) + add_definitions(-DSD_USE_CUDA) endif() if(SD_METAL) @@ -54,21 +54,25 @@ endif () if (SD_HIPBLAS) message("-- Use HIPBLAS as backend stable-diffusion") - set(GGML_HIPBLAS ON) - add_definitions(-DSD_USE_CUBLAS) + set(GGML_HIP ON) + add_definitions(-DSD_USE_CUDA) if(SD_FAST_SOFTMAX) set(GGML_CUDA_FAST_SOFTMAX ON) endif() endif () -if(SD_FLASH_ATTN) - message("-- Use Flash Attention for memory optimization") - add_definitions(-DSD_USE_FLASH_ATTENTION) +if(SD_MUSA) + message("-- Use MUSA as backend stable-diffusion") + set(GGML_MUSA ON) + add_definitions(-DSD_USE_CUDA) + if(SD_FAST_SOFTMAX) + set(GGML_CUDA_FAST_SOFTMAX ON) + endif() endif() set(SD_LIB stable-diffusion) -file(GLOB SD_LIB_SOURCES +file(GLOB SD_LIB_SOURCES "*.h" "*.cpp" "*.hpp" @@ -92,6 +96,7 @@ endif() if(SD_SYCL) message("-- Use SYCL as backend stable-diffusion") set(GGML_SYCL ON) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -fsycl") add_definitions(-DSD_USE_SYCL) # disable fast-math on host, see: # https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-10/fp-model-fp.html diff --git a/Dockerfile.musa b/Dockerfile.musa new file mode 100644 index 000000000..0adcb7ee5 --- /dev/null +++ b/Dockerfile.musa @@ -0,0 +1,19 @@ +ARG MUSA_VERSION=rc3.1.1 + +FROM mthreads/musa:${MUSA_VERSION}-devel-ubuntu22.04 as build + +RUN apt-get update && apt-get install -y cmake + +WORKDIR /sd.cpp + +COPY . . + +RUN mkdir build && cd build && \ + cmake .. -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_MUSA=ON -DCMAKE_BUILD_TYPE=Release && \ + cmake --build . --config Release + +FROM mthreads/musa:${MUSA_VERSION}-runtime-ubuntu22.04 as runtime + +COPY --from=build /sd.cpp/build/bin/sd /sd + +ENTRYPOINT [ "/sd" ] \ No newline at end of file diff --git a/README.md b/README.md index 0fe607e27..553fb7f8f 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Inference of Stable Diffusion and Flux in pure C/C++ - Plain C/C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp) - Super lightweight and without external dependencies -- SD1.x, SD2.x, SDXL and SD3 support +- SD1.x, SD2.x, SDXL and [SD3/SD3.5](./docs/sd3.md) support - !!!The VAE in SDXL encounters NaN issues under FP16, but unfortunately, the ggml_conv_2d only operates under FP16. Hence, a parameter is needed to specify the VAE that has fixed the FP16 NaN issue. You can find it here: [SDXL VAE FP16 Fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors). - [Flux-dev/Flux-schnell Support](./docs/flux.md) @@ -24,7 +24,7 @@ Inference of Stable Diffusion and Flux in pure C/C++ - Full CUDA, Metal, Vulkan and SYCL backend for GPU acceleration. - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs models - No need to convert to `.ggml` or `.gguf` anymore! -- Flash Attention for memory usage optimization (only cpu for now) +- Flash Attention for memory usage optimization - Original `txt2img` and `img2img` mode - Negative prompt - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now) @@ -113,12 +113,12 @@ cmake .. -DGGML_OPENBLAS=ON cmake --build . --config Release ``` -##### Using CUBLAS +##### Using CUDA This provides BLAS acceleration using the CUDA cores of your Nvidia GPU. Make sure to have the CUDA toolkit installed. You can download it from your Linux distro's package manager (e.g. `apt install nvidia-cuda-toolkit`) or from here: [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads). Recommended to have at least 4 GB of VRAM. ``` -cmake .. -DSD_CUBLAS=ON +cmake .. -DSD_CUDA=ON cmake --build . --config Release ``` @@ -132,6 +132,14 @@ cmake .. -G "Ninja" -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_H cmake --build . --config Release ``` +##### Using MUSA + +This provides BLAS acceleration using the MUSA cores of your Moore Threads GPU. Make sure to have the MUSA toolkit installed. + +```bash +cmake .. -DCMAKE_C_COMPILER=/usr/local/musa/bin/clang -DCMAKE_CXX_COMPILER=/usr/local/musa/bin/clang++ -DSD_MUSA=ON -DCMAKE_BUILD_TYPE=Release +cmake --build . --config Release +``` ##### Using Metal @@ -182,11 +190,21 @@ Example of text2img by using SYCL backend: ##### Using Flash Attention -Enabling flash attention reduces memory usage by at least 400 MB. At the moment, it is not supported when CUBLAS is enabled because the kernel implementation is missing. +Enabling flash attention for the diffusion model reduces memory usage by varying amounts of MB. +eg.: + - flux 768x768 ~600mb + - SD2 768x768 ~1400mb +For most backends, it slows things down, but for cuda it generally speeds it up too. +At the moment, it is only supported for some models and some backends (like cpu, cuda/rocm, metal). + +Run by adding `--diffusion-fa` to the arguments and watch for: ``` -cmake .. -DSD_FLASH_ATTN=ON -cmake --build . --config Release +[INFO ] stable-diffusion.cpp:312 - Using flash attention in the diffusion model +``` +and the compute buffer shrink in the debug log: +``` +[DEBUG] ggml_extend.hpp:1004 - flux compute buffer size: 650.00 MB(VRAM) ``` ### Run @@ -197,23 +215,24 @@ usage: ./bin/sd [arguments] arguments: -h, --help show this help message and exit -M, --mode [MODEL] run mode (txt2img or img2img or convert, default: txt2img) - -t, --threads N number of threads to use during computation (default: -1). + -t, --threads N number of threads to use during computation (default: -1) If threads <= 0, then threads will be set to the number of CPU physical cores -m, --model [MODEL] path to full model --diffusion-model path to the standalone diffusion model --clip_l path to the clip-l text encoder - --t5xxl path to the the t5xxl text encoder. + --clip_g path to the clip-l text encoder + --t5xxl path to the the t5xxl text encoder --vae [VAE] path to vae --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) --control-net [CONTROL_PATH] path to control net model - --embd-dir [EMBEDDING_PATH] path to embeddings. - --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings. - --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir. + --embd-dir [EMBEDDING_PATH] path to embeddings + --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings + --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir --normalize-input normalize PHOTOMAKER input id images - --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now. + --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now --upscale-repeats Run the ESRGAN upscaler this many times (default 1) --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k) - If not specified, the default is the type of the weight file. + If not specified, the default is the type of the weight file --lora-model-dir [DIR] lora model directory -i, --init-img [IMAGE] path to the input image, required by img2img --control-image [IMAGE] path to image condition, control net @@ -221,6 +240,10 @@ arguments: -p, --prompt [PROMPT] the prompt to render -n, --negative-prompt PROMPT the negative prompt (default: "") --cfg-scale SCALE unconditional guidance scale: (default: 7.0) + --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9]) + --skip-layer-start START SLG enabling point: (default: 0.01) + --skip-layer-end END SLG disabling point: (default: 0.2) + SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END]) --strength STRENGTH strength for noising/unnoising (default: 0.75) --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%) --control-strength STRENGTH strength to apply Control Net (default: 0.9) @@ -232,13 +255,16 @@ arguments: --steps STEPS number of sample steps (default: 20) --rng {std_default, cuda} RNG (default: cuda) -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0) - -b, --batch-count COUNT number of images to generate. + -b, --batch-count COUNT number of images to generate --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete) --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1) <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x --vae-tiling process vae in tiles to reduce memory usage --vae-on-cpu keep vae in cpu (for low vram) - --clip-on-cpu keep clip in cpu (for low vram). + --clip-on-cpu keep clip in cpu (for low vram) + --diffusion-fa use flash attention in the diffusion model (for low vram) + Might lower quality, since it implies converting k and v to f16. + This might crash if it is not supported by the backend. --control-net-cpu keep controlnet in cpu (for low vram) --canny apply canny preprocessor (edge detection) --color Colors the logging tags according to level @@ -253,6 +279,7 @@ arguments: # ./bin/sd -m ../models/sd_xl_base_1.0.safetensors --vae ../models/sdxl_vae-fp16-fix.safetensors -H 1024 -W 1024 -p "a lovely cat" -v # ./bin/sd -m ../models/sd3_medium_incl_clips_t5xxlfp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable Diffusion CPP\"' --cfg-scale 4.5 --sampling-method euler -v # ./bin/sd --diffusion-model ../models/flux1-dev-q3_k.gguf --vae ../models/ae.sft --clip_l ../models/clip_l.safetensors --t5xxl ../models/t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v +# ./bin/sd -m ..\models\sd3.5_large.safetensors --clip_l ..\models\clip_l.safetensors --clip_g ..\models\clip_g.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable diffusion 3.5 Large\"' --cfg-scale 4.5 --sampling-method euler -v ``` Using formats of different precisions will yield results of varying quality. @@ -290,12 +317,16 @@ These projects wrap `stable-diffusion.cpp` for easier use in other languages/fra * Golang: [seasonjs/stable-diffusion](https://github.com/seasonjs/stable-diffusion) * C#: [DarthAffe/StableDiffusion.NET](https://github.com/DarthAffe/StableDiffusion.NET) +* Python: [william-murray1204/stable-diffusion-cpp-python](https://github.com/william-murray1204/stable-diffusion-cpp-python) +* Rust: [newfla/diffusion-rs](https://github.com/newfla/diffusion-rs) ## UIs These projects use `stable-diffusion.cpp` as a backend for their image generation. - [Jellybox](https://jellybox.com) +- [Stable Diffusion GUI](https://github.com/fszontagh/sd.cpp.gui.wx) +- [Stable Diffusion CLI-GUI](https://github.com/piallai/stable-diffusion.cpp) ## Contributors diff --git a/assets/sd3.5_large.png b/assets/sd3.5_large.png new file mode 100644 index 000000000..b76b13225 Binary files /dev/null and b/assets/sd3.5_large.png differ diff --git a/clip.hpp b/clip.hpp index f9ac631a8..2307ee3c5 100644 --- a/clip.hpp +++ b/clip.hpp @@ -343,6 +343,13 @@ class CLIPTokenizer { } } + std::string clean_up_tokenization(std::string& text) { + std::regex pattern(R"( ,)"); + // Replace " ," with "," + std::string result = std::regex_replace(text, pattern, ","); + return result; + } + std::string decode(const std::vector& tokens) { std::string text = ""; for (int t : tokens) { @@ -351,8 +358,12 @@ class CLIPTokenizer { std::u32string ts = decoder[t]; // printf("%d, %s \n", t, utf32_to_utf8(ts).c_str()); std::string s = utf32_to_utf8(ts); - if (s.length() >= 4 && ends_with(s, "")) { - text += " " + s.replace(s.length() - 4, s.length() - 1, ""); + if (s.length() >= 4) { + if (ends_with(s, "")) { + text += s.replace(s.length() - 4, s.length() - 1, "") + " "; + } else { + text += s; + } } else { text += " " + s; } @@ -364,6 +375,7 @@ class CLIPTokenizer { // std::string s((char *)bytes.data()); // std::string s = ""; + text = clean_up_tokenization(text); return trim(text); } @@ -533,9 +545,12 @@ class CLIPEmbeddings : public GGMLBlock { int64_t vocab_size; int64_t num_positions; - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["token_embedding.weight"] = ggml_new_tensor_2d(ctx, wtype, embed_dim, vocab_size); - params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, num_positions); + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type token_wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "token_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "token_embedding.weight"] : GGML_TYPE_F32; + enum ggml_type position_wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32; + + params["token_embedding.weight"] = ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size); + params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions); } public: @@ -579,11 +594,14 @@ class CLIPVisionEmbeddings : public GGMLBlock { int64_t image_size; int64_t num_patches; int64_t num_positions; + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type patch_wtype = GGML_TYPE_F16; // tensor_types.find(prefix + "patch_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "patch_embedding.weight"] : GGML_TYPE_F16; + enum ggml_type class_wtype = GGML_TYPE_F32; // tensor_types.find(prefix + "class_embedding") != tensor_types.end() ? tensor_types[prefix + "class_embedding"] : GGML_TYPE_F32; + enum ggml_type position_wtype = GGML_TYPE_F32; // tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32; - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["patch_embedding.weight"] = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, patch_size, patch_size, num_channels, embed_dim); - params["class_embedding"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, embed_dim); - params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, num_positions); + params["patch_embedding.weight"] = ggml_new_tensor_4d(ctx, patch_wtype, patch_size, patch_size, num_channels, embed_dim); + params["class_embedding"] = ggml_new_tensor_1d(ctx, class_wtype, embed_dim); + params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions); } public: @@ -639,9 +657,10 @@ enum CLIPVersion { class CLIPTextModel : public GGMLBlock { protected: - void init_params(struct ggml_context* ctx, ggml_type wtype) { + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { if (version == OPEN_CLIP_VIT_BIGG_14) { - params["text_projection"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, projection_dim, hidden_size); + enum ggml_type wtype = GGML_TYPE_F32; // tensor_types.find(prefix + "text_projection") != tensor_types.end() ? tensor_types[prefix + "text_projection"] : GGML_TYPE_F32; + params["text_projection"] = ggml_new_tensor_2d(ctx, wtype, projection_dim, hidden_size); } } @@ -711,8 +730,12 @@ class CLIPTextModel : public GGMLBlock { if (return_pooled) { auto text_projection = params["text_projection"]; ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx); - pooled = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, text_projection)), pooled); - return pooled; + if (text_projection != NULL) { + pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL); + } else { + LOG_DEBUG("Missing text_projection matrix, assuming identity..."); + } + return pooled; // [hidden_size, 1, 1] } return x; // [N, n_token, hidden_size] @@ -761,14 +784,17 @@ class CLIPVisionModel : public GGMLBlock { auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim] x = pre_layernorm->forward(ctx, x); x = encoder->forward(ctx, x, -1, false); - x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size] + // print_ggml_tensor(x, true, "ClipVisionModel x: "); + auto last_hidden_state = x; + x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size] GGML_ASSERT(x->ne[3] == 1); if (return_pooled) { ggml_tensor* pooled = ggml_cont(ctx, ggml_view_2d(ctx, x, x->ne[0], x->ne[2], x->nb[2], 0)); return pooled; // [N, hidden_size] } else { - return x; // [N, n_token, hidden_size] + // return x; // [N, n_token, hidden_size] + return last_hidden_state; // [N, n_token, hidden_size] } } }; @@ -779,9 +805,9 @@ class CLIPProjection : public UnaryBlock { int64_t out_features; bool transpose_weight; - void init_params(struct ggml_context* ctx, ggml_type wtype) { + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type wtype = tensor_types.find(prefix + "weight") != tensor_types.end() ? tensor_types[prefix + "weight"] : GGML_TYPE_F32; if (transpose_weight) { - LOG_ERROR("transpose_weight"); params["weight"] = ggml_new_tensor_2d(ctx, wtype, out_features, in_features); } else { params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features); @@ -842,12 +868,13 @@ struct CLIPTextModelRunner : public GGMLRunner { CLIPTextModel model; CLIPTextModelRunner(ggml_backend_t backend, - ggml_type wtype, + std::map& tensor_types, + const std::string prefix, CLIPVersion version = OPENAI_CLIP_VIT_L_14, int clip_skip_value = 1, bool with_final_ln = true) - : GGMLRunner(backend, wtype), model(version, clip_skip_value, with_final_ln) { - model.init(params_ctx, wtype); + : GGMLRunner(backend), model(version, clip_skip_value, with_final_ln) { + model.init(params_ctx, tensor_types, prefix); } std::string get_desc() { @@ -889,13 +916,13 @@ struct CLIPTextModelRunner : public GGMLRunner { struct ggml_tensor* embeddings = NULL; if (num_custom_embeddings > 0 && custom_embeddings_data != NULL) { - auto custom_embeddings = ggml_new_tensor_2d(compute_ctx, - wtype, - model.hidden_size, - num_custom_embeddings); + auto token_embed_weight = model.get_token_embed_weight(); + auto custom_embeddings = ggml_new_tensor_2d(compute_ctx, + token_embed_weight->type, + model.hidden_size, + num_custom_embeddings); set_backend_tensor_data(custom_embeddings, custom_embeddings_data); - auto token_embed_weight = model.get_token_embed_weight(); // concatenate custom embeddings embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1); } diff --git a/common.hpp b/common.hpp index b18ee51f5..337b4a0c4 100644 --- a/common.hpp +++ b/common.hpp @@ -182,9 +182,11 @@ class GEGLU : public GGMLBlock { int64_t dim_in; int64_t dim_out; - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2); - params["proj.bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim_out * 2); + void init_params(struct ggml_context* ctx, std::map& tensor_types, std::string prefix = "") { + enum ggml_type wtype = (tensor_types.find(prefix + "proj.weight") != tensor_types.end()) ? tensor_types[prefix + "proj.weight"] : GGML_TYPE_F32; + enum ggml_type bias_wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "proj.bias") != tensor_types.end()) ? tensor_types[prefix + "proj.bias"] : GGML_TYPE_F32; + params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2); + params["proj.bias"] = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2); } public: @@ -245,16 +247,19 @@ class CrossAttention : public GGMLBlock { int64_t context_dim; int64_t n_head; int64_t d_head; + bool flash_attn; public: CrossAttention(int64_t query_dim, int64_t context_dim, int64_t n_head, - int64_t d_head) + int64_t d_head, + bool flash_attn = false) : n_head(n_head), d_head(d_head), query_dim(query_dim), - context_dim(context_dim) { + context_dim(context_dim), + flash_attn(flash_attn) { int64_t inner_dim = d_head * n_head; blocks["to_q"] = std::shared_ptr(new Linear(query_dim, inner_dim, false)); @@ -283,7 +288,7 @@ class CrossAttention : public GGMLBlock { auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim] auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim] - x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false); // [N, n_token, inner_dim] + x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn); // [N, n_token, inner_dim] x = to_out_0->forward(ctx, x); // [N, n_token, query_dim] return x; @@ -301,15 +306,16 @@ class BasicTransformerBlock : public GGMLBlock { int64_t n_head, int64_t d_head, int64_t context_dim, - bool ff_in = false) + bool ff_in = false, + bool flash_attn = false) : n_head(n_head), d_head(d_head), ff_in(ff_in) { // disable_self_attn is always False // disable_temporal_crossattention is always False // switch_temporal_ca_to_sa is always False // inner_dim is always None or equal to dim // gated_ff is always True - blocks["attn1"] = std::shared_ptr(new CrossAttention(dim, dim, n_head, d_head)); - blocks["attn2"] = std::shared_ptr(new CrossAttention(dim, context_dim, n_head, d_head)); + blocks["attn1"] = std::shared_ptr(new CrossAttention(dim, dim, n_head, d_head, flash_attn)); + blocks["attn2"] = std::shared_ptr(new CrossAttention(dim, context_dim, n_head, d_head, flash_attn)); blocks["ff"] = std::shared_ptr(new FeedForward(dim, dim)); blocks["norm1"] = std::shared_ptr(new LayerNorm(dim)); blocks["norm2"] = std::shared_ptr(new LayerNorm(dim)); @@ -374,7 +380,8 @@ class SpatialTransformer : public GGMLBlock { int64_t n_head, int64_t d_head, int64_t depth, - int64_t context_dim) + int64_t context_dim, + bool flash_attn = false) : in_channels(in_channels), n_head(n_head), d_head(d_head), @@ -388,7 +395,7 @@ class SpatialTransformer : public GGMLBlock { for (int i = 0; i < depth; i++) { std::string name = "transformer_blocks." + std::to_string(i); - blocks[name] = std::shared_ptr(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim)); + blocks[name] = std::shared_ptr(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn)); } blocks["proj_out"] = std::shared_ptr(new Conv2d(inner_dim, in_channels, {1, 1})); @@ -433,8 +440,10 @@ class SpatialTransformer : public GGMLBlock { class AlphaBlender : public GGMLBlock { protected: - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["mix_factor"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + void init_params(struct ggml_context* ctx, std::map& tensor_types, std::string prefix = "") { + // Get the type of the "mix_factor" tensor from the input tensors map with the specified prefix + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.ypes.find(prefix + "mix_factor") != tensor_types.end()) ? tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32; + params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1); } float get_alpha() { @@ -511,4 +520,4 @@ class VideoResBlock : public ResBlock { } }; -#endif // __COMMON_HPP__ \ No newline at end of file +#endif // __COMMON_HPP__ diff --git a/conditioner.hpp b/conditioner.hpp index 43d0a6d55..6e9acdb19 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -43,71 +43,73 @@ struct Conditioner { // ldm.modules.encoders.modules.FrozenCLIPEmbedder // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283 struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { - SDVersion version = VERSION_SD1; + SDVersion version = VERSION_SD1; + PMVersion pm_version = PM_VERSION_1; CLIPTokenizer tokenizer; - ggml_type wtype; std::shared_ptr text_model; std::shared_ptr text_model2; std::string trigger_word = "img"; // should be user settable std::string embd_dir; - int32_t num_custom_embeddings = 0; + int32_t num_custom_embeddings = 0; + int32_t num_custom_embeddings_2 = 0; std::vector token_embed_custom; std::vector readed_embeddings; FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend, - ggml_type wtype, + std::map& tensor_types, const std::string& embd_dir, SDVersion version = VERSION_SD1, + PMVersion pv = PM_VERSION_1, int clip_skip = -1) - : version(version), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) { + : version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) { if (clip_skip <= 0) { clip_skip = 1; - if (version == VERSION_SD2 || version == VERSION_SDXL) { + if (sd_version_is_sd2(version) || sd_version_is_sdxl(version)) { clip_skip = 2; } } - if (version == VERSION_SD1) { - text_model = std::make_shared(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip); - } else if (version == VERSION_SD2) { - text_model = std::make_shared(backend, wtype, OPEN_CLIP_VIT_H_14, clip_skip); - } else if (version == VERSION_SDXL) { - text_model = std::make_shared(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false); - text_model2 = std::make_shared(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false); + if (sd_version_is_sd1(version)) { + text_model = std::make_shared(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip); + } else if (sd_version_is_sd2(version)) { + text_model = std::make_shared(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, clip_skip); + } else if (sd_version_is_sdxl(version)) { + text_model = std::make_shared(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false); + text_model2 = std::make_shared(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false); } } void set_clip_skip(int clip_skip) { text_model->set_clip_skip(clip_skip); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { text_model2->set_clip_skip(clip_skip); } } void get_param_tensors(std::map& tensors) { text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model"); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { text_model2->get_param_tensors(tensors, "cond_stage_model.1.transformer.text_model"); } } void alloc_params_buffer() { text_model->alloc_params_buffer(); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { text_model2->alloc_params_buffer(); } } void free_params_buffer() { text_model->free_params_buffer(); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { text_model2->free_params_buffer(); } } size_t get_params_buffer_size() { size_t buffer_size = text_model->get_params_buffer_size(); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { buffer_size += text_model2->get_params_buffer_size(); } return buffer_size; @@ -130,28 +132,55 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { params.no_alloc = false; struct ggml_context* embd_ctx = ggml_init(params); struct ggml_tensor* embd = NULL; - int64_t hidden_size = text_model->model.hidden_size; + struct ggml_tensor* embd2 = NULL; auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) { - if (tensor_storage.ne[0] != hidden_size) { - LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], hidden_size); - return false; + if (tensor_storage.ne[0] != text_model->model.hidden_size) { + if (text_model2) { + if (tensor_storage.ne[0] == text_model2->model.hidden_size) { + embd2 = ggml_new_tensor_2d(embd_ctx, tensor_storage.type, text_model2->model.hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1); + *dst_tensor = embd2; + } else { + LOG_DEBUG("embedding wrong hidden size, got %i, expected %i or %i", tensor_storage.ne[0], text_model->model.hidden_size, text_model2->model.hidden_size); + return false; + } + } else { + LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], text_model->model.hidden_size); + return false; + } + } else { + embd = ggml_new_tensor_2d(embd_ctx, tensor_storage.type, text_model->model.hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1); + *dst_tensor = embd; } - embd = ggml_new_tensor_2d(embd_ctx, wtype, hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1); - *dst_tensor = embd; return true; }; model_loader.load_tensors(on_load, NULL); readed_embeddings.push_back(embd_name); - token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd)); - memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(wtype)), - embd->data, - ggml_nbytes(embd)); - for (int i = 0; i < embd->ne[1]; i++) { - bpe_tokens.push_back(text_model->model.vocab_size + num_custom_embeddings); - // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings); - num_custom_embeddings++; + if (embd) { + int64_t hidden_size = text_model->model.hidden_size; + token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd)); + memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(embd->type)), + embd->data, + ggml_nbytes(embd)); + for (int i = 0; i < embd->ne[1]; i++) { + bpe_tokens.push_back(text_model->model.vocab_size + num_custom_embeddings); + // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings); + num_custom_embeddings++; + } + LOG_DEBUG("embedding '%s' applied, custom embeddings: %i", embd_name.c_str(), num_custom_embeddings); + } + if (embd2) { + int64_t hidden_size = text_model2->model.hidden_size; + token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd2)); + memcpy((void*)(token_embed_custom.data() + num_custom_embeddings_2 * hidden_size * ggml_type_size(embd2->type)), + embd2->data, + ggml_nbytes(embd2)); + for (int i = 0; i < embd2->ne[1]; i++) { + bpe_tokens.push_back(text_model2->model.vocab_size + num_custom_embeddings_2); + // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings); + num_custom_embeddings_2++; + } + LOG_DEBUG("embedding '%s' applied, custom embeddings: %i (text model 2)", embd_name.c_str(), num_custom_embeddings_2); } - LOG_DEBUG("embedding '%s' applied, custom embeddings: %i", embd_name.c_str(), num_custom_embeddings); return true; } @@ -268,7 +297,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { std::vector clean_input_ids_tmp; for (uint32_t i = 0; i < class_token_index[0]; i++) clean_input_ids_tmp.push_back(clean_input_ids[i]); - for (uint32_t i = 0; i < num_input_imgs; i++) + for (uint32_t i = 0; i < (pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++) clean_input_ids_tmp.push_back(class_token); for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++) clean_input_ids_tmp.push_back(clean_input_ids[i]); @@ -279,13 +308,16 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end()); weights.insert(weights.end(), clean_input_ids.size(), curr_weight); } - tokens.insert(tokens.begin(), tokenizer.BOS_TOKEN_ID); - weights.insert(weights.begin(), 1.0); + // BUG!! double couting, pad_tokens will add BOS at the beginning + // tokens.insert(tokens.begin(), tokenizer.BOS_TOKEN_ID); + // weights.insert(weights.begin(), 1.0); tokenizer.pad_tokens(tokens, weights, max_length, padding); - + int offset = pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs; for (uint32_t i = 0; i < tokens.size(); i++) { - if (class_idx + 1 <= i && i < class_idx + 1 + num_input_imgs) + // if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs + if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs + // hardcode for now class_token_mask.push_back(true); else class_token_mask.push_back(false); @@ -398,7 +430,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); struct ggml_tensor* input_ids2 = NULL; size_t max_token_idx = 0; - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), tokenizer.EOS_TOKEN_ID); if (it != chunk_tokens.end()) { std::fill(std::next(it), chunk_tokens.end(), 0); @@ -423,7 +455,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { false, &chunk_hidden_states1, work_ctx); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { text_model2->compute(n_threads, input_ids2, 0, @@ -482,7 +514,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]); ggml_tensor* vec = NULL; - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { int out_dim = 256; vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, adm_in_channels); // [0:1280] @@ -585,9 +617,9 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { struct FrozenCLIPVisionEmbedder : public GGMLRunner { CLIPVisionModelProjection vision_model; - FrozenCLIPVisionEmbedder(ggml_backend_t backend, ggml_type wtype) - : vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend, wtype) { - vision_model.init(params_ctx, wtype); + FrozenCLIPVisionEmbedder(ggml_backend_t backend, std::map& tensor_types) + : vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend) { + vision_model.init(params_ctx, tensor_types, "cond_stage_model.transformer"); } std::string get_desc() { @@ -622,7 +654,6 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner { }; struct SD3CLIPEmbedder : public Conditioner { - ggml_type wtype; CLIPTokenizer clip_l_tokenizer; CLIPTokenizer clip_g_tokenizer; T5UniGramTokenizer t5_tokenizer; @@ -631,15 +662,15 @@ struct SD3CLIPEmbedder : public Conditioner { std::shared_ptr t5; SD3CLIPEmbedder(ggml_backend_t backend, - ggml_type wtype, + std::map& tensor_types, int clip_skip = -1) - : wtype(wtype), clip_g_tokenizer(0) { + : clip_g_tokenizer(0) { if (clip_skip <= 0) { clip_skip = 2; } - clip_l = std::make_shared(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false); - clip_g = std::make_shared(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false); - t5 = std::make_shared(backend, wtype); + clip_l = std::make_shared(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false); + clip_g = std::make_shared(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false); + t5 = std::make_shared(backend, tensor_types, "text_encoders.t5xxl.transformer"); } void set_clip_skip(int clip_skip) { @@ -798,21 +829,16 @@ struct SD3CLIPEmbedder : public Conditioner { } if (chunk_idx == 0) { - // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); - // max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); - // clip_l->compute(n_threads, - // input_ids, - // 0, - // NULL, - // max_token_idx, - // true, - // &pooled_l, - // work_ctx); - - // clip_l.transformer.text_model.text_projection no in file, ignore - // TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection - pooled_l = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768); - ggml_set_f32(pooled_l, 0.f); + auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); + max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); + clip_l->compute(n_threads, + input_ids, + 0, + NULL, + max_token_idx, + true, + &pooled_l, + work_ctx); } } @@ -852,21 +878,16 @@ struct SD3CLIPEmbedder : public Conditioner { } if (chunk_idx == 0) { - // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID); - // max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); - // clip_g->compute(n_threads, - // input_ids, - // 0, - // NULL, - // max_token_idx, - // true, - // &pooled_g, - // work_ctx); - // clip_l.transformer.text_model.text_projection no in file, ignore pooled_g too - - // TODO: fix pooled_g - pooled_g = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1280); - ggml_set_f32(pooled_g, 0.f); + auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID); + max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); + clip_g->compute(n_threads, + input_ids, + 0, + NULL, + max_token_idx, + true, + &pooled_g, + work_ctx); } } @@ -979,21 +1000,19 @@ struct SD3CLIPEmbedder : public Conditioner { }; struct FluxCLIPEmbedder : public Conditioner { - ggml_type wtype; CLIPTokenizer clip_l_tokenizer; T5UniGramTokenizer t5_tokenizer; std::shared_ptr clip_l; std::shared_ptr t5; FluxCLIPEmbedder(ggml_backend_t backend, - ggml_type wtype, - int clip_skip = -1) - : wtype(wtype) { + std::map& tensor_types, + int clip_skip = -1) { if (clip_skip <= 0) { clip_skip = 2; } - clip_l = std::make_shared(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, true); - t5 = std::make_shared(backend, wtype); + clip_l = std::make_shared(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, true); + t5 = std::make_shared(backend, tensor_types, "text_encoders.t5xxl.transformer"); } void set_clip_skip(int clip_skip) { @@ -1001,8 +1020,8 @@ struct FluxCLIPEmbedder : public Conditioner { } void get_param_tensors(std::map& tensors) { - clip_l->get_param_tensors(tensors, "text_encoders.clip_l.text_model"); - t5->get_param_tensors(tensors, "text_encoders.t5xxl"); + clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model"); + t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer"); } void alloc_params_buffer() { @@ -1104,21 +1123,17 @@ struct FluxCLIPEmbedder : public Conditioner { auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); size_t max_token_idx = 0; - // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); - // max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); - // clip_l->compute(n_threads, - // input_ids, - // 0, - // NULL, - // max_token_idx, - // true, - // &pooled, - // work_ctx); - - // clip_l.transformer.text_model.text_projection no in file, ignore - // TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection - pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768); - ggml_set_f32(pooled, 0.f); + auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); + max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); + + clip_l->compute(n_threads, + input_ids, + 0, + NULL, + max_token_idx, + true, + &pooled, + work_ctx); } // t5 diff --git a/control.hpp b/control.hpp index 41f31acb7..23b75feff 100644 --- a/control.hpp +++ b/control.hpp @@ -34,11 +34,11 @@ class ControlNetBlock : public GGMLBlock { ControlNetBlock(SDVersion version = VERSION_SD1) : version(version) { - if (version == VERSION_SD2) { + if (sd_version_is_sd2(version)) { context_dim = 1024; num_head_channels = 64; num_heads = -1; - } else if (version == VERSION_SDXL) { + } else if (sd_version_is_sdxl(version)) { context_dim = 2048; attention_resolutions = {4, 2}; channel_mult = {1, 2, 4}; @@ -58,7 +58,7 @@ class ControlNetBlock : public GGMLBlock { // time_embed_1 is nn.SiLU() blocks["time_embed.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); - if (version == VERSION_SDXL || version == VERSION_SVD) { + if (sd_version_is_sdxl(version) || version == VERSION_SVD) { blocks["label_emb.0.0"] = std::shared_ptr(new Linear(adm_in_channels, time_embed_dim)); // label_emb_1 is nn.SiLU() blocks["label_emb.0.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); @@ -317,10 +317,10 @@ struct ControlNet : public GGMLRunner { bool guided_hint_cached = false; ControlNet(ggml_backend_t backend, - ggml_type wtype, + std::map& tensor_types, SDVersion version = VERSION_SD1) - : GGMLRunner(backend, wtype), control_net(version) { - control_net.init(params_ctx, wtype); + : GGMLRunner(backend), control_net(version) { + control_net.init(params_ctx, tensor_types, ""); } ~ControlNet() { diff --git a/denoiser.hpp b/denoiser.hpp index 287b10934..66799109d 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -49,7 +49,7 @@ struct ExponentialSchedule : SigmaSchedule { // Calculate step size float log_sigma_min = std::log(sigma_min); float log_sigma_max = std::log(sigma_max); - float step = (log_sigma_max - log_sigma_min) / (n - 1); + float step = (log_sigma_max - log_sigma_min) / (n - 1); // Fill sigmas with exponential values for (uint32_t i = 0; i < n; ++i) { @@ -205,7 +205,7 @@ struct AYSSchedule : SigmaSchedule { /* * GITS Scheduler: https://github.com/zju-pi/diff-sampler/tree/main/gits-main -*/ + */ struct GITSSchedule : SigmaSchedule { std::vector get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) { if (sigma_max <= 0.0f) { @@ -221,7 +221,7 @@ struct GITSSchedule : SigmaSchedule { // Calculate the index based on the coefficient int index = static_cast((coeff - 0.80f) / 0.05f); // Ensure the index is within bounds - index = std::max(0, std::min(index, static_cast(GITS_NOISE.size() - 1))); + index = std::max(0, std::min(index, static_cast(GITS_NOISE.size() - 1))); const std::vector>& selected_noise = *GITS_NOISE[index]; if (n <= 20) { @@ -474,7 +474,8 @@ static void sample_k_diffusion(sample_method_t method, ggml_context* work_ctx, ggml_tensor* x, std::vector sigmas, - std::shared_ptr rng) { + std::shared_ptr rng, + float eta) { size_t steps = sigmas.size() - 1; // sample_euler_ancestral switch (method) { @@ -823,24 +824,24 @@ static void sample_k_diffusion(sample_method_t method, } break; case IPNDM: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main { - int max_order = 4; + int max_order = 4; ggml_tensor* x_next = x; std::vector buffer_model; for (int i = 0; i < steps; i++) { - float sigma = sigmas[i]; + float sigma = sigmas[i]; float sigma_next = sigmas[i + 1]; ggml_tensor* x_cur = x_next; - float* vec_x_cur = (float*)x_cur->data; - float* vec_x_next = (float*)x_next->data; + float* vec_x_cur = (float*)x_cur->data; + float* vec_x_next = (float*)x_next->data; // Denoising step ggml_tensor* denoised = model(x_cur, sigma, i + 1); - float* vec_denoised = (float*)denoised->data; + float* vec_denoised = (float*)denoised->data; // d_cur = (x_cur - denoised) / sigma struct ggml_tensor* d_cur = ggml_dup_tensor(work_ctx, x_cur); - float* vec_d_cur = (float*)d_cur->data; + float* vec_d_cur = (float*)d_cur->data; for (int j = 0; j < ggml_nelements(d_cur); j++) { vec_d_cur[j] = (vec_x_cur[j] - vec_denoised[j]) / sigma; @@ -857,34 +858,31 @@ static void sample_k_diffusion(sample_method_t method, break; case 2: // Use one history point - { - float* vec_d_prev1 = (float*)buffer_model.back()->data; - for (int j = 0; j < ggml_nelements(x_next); j++) { - vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (3 * vec_d_cur[j] - vec_d_prev1[j]) / 2; - } + { + float* vec_d_prev1 = (float*)buffer_model.back()->data; + for (int j = 0; j < ggml_nelements(x_next); j++) { + vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (3 * vec_d_cur[j] - vec_d_prev1[j]) / 2; } - break; + } break; case 3: // Use two history points - { - float* vec_d_prev1 = (float*)buffer_model.back()->data; - float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data; - for (int j = 0; j < ggml_nelements(x_next); j++) { - vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (23 * vec_d_cur[j] - 16 * vec_d_prev1[j] + 5 * vec_d_prev2[j]) / 12; - } + { + float* vec_d_prev1 = (float*)buffer_model.back()->data; + float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data; + for (int j = 0; j < ggml_nelements(x_next); j++) { + vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (23 * vec_d_cur[j] - 16 * vec_d_prev1[j] + 5 * vec_d_prev2[j]) / 12; } - break; + } break; case 4: // Use three history points - { - float* vec_d_prev1 = (float*)buffer_model.back()->data; - float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data; - float* vec_d_prev3 = (float*)buffer_model[buffer_model.size() - 3]->data; - for (int j = 0; j < ggml_nelements(x_next); j++) { - vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (55 * vec_d_cur[j] - 59 * vec_d_prev1[j] + 37 * vec_d_prev2[j] - 9 * vec_d_prev3[j]) / 24; - } + { + float* vec_d_prev1 = (float*)buffer_model.back()->data; + float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data; + float* vec_d_prev3 = (float*)buffer_model[buffer_model.size() - 3]->data; + for (int j = 0; j < ggml_nelements(x_next); j++) { + vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (55 * vec_d_cur[j] - 59 * vec_d_prev1[j] + 37 * vec_d_prev2[j] - 9 * vec_d_prev3[j]) / 24; } - break; + } break; } // Manage buffer_model @@ -906,27 +904,27 @@ static void sample_k_diffusion(sample_method_t method, ggml_tensor* x_next = x; for (int i = 0; i < steps; i++) { - float sigma = sigmas[i]; + float sigma = sigmas[i]; float t_next = sigmas[i + 1]; // Denoising step - ggml_tensor* denoised = model(x, sigma, i + 1); - float* vec_denoised = (float*)denoised->data; + ggml_tensor* denoised = model(x, sigma, i + 1); + float* vec_denoised = (float*)denoised->data; struct ggml_tensor* d_cur = ggml_dup_tensor(work_ctx, x); - float* vec_d_cur = (float*)d_cur->data; - float* vec_x = (float*)x->data; + float* vec_d_cur = (float*)d_cur->data; + float* vec_x = (float*)x->data; // d_cur = (x - denoised) / sigma for (int j = 0; j < ggml_nelements(d_cur); j++) { vec_d_cur[j] = (vec_x[j] - vec_denoised[j]) / sigma; } - int order = std::min(max_order, i + 1); - float h_n = t_next - sigma; + int order = std::min(max_order, i + 1); + float h_n = t_next - sigma; float h_n_1 = (i > 0) ? (sigma - sigmas[i - 1]) : h_n; switch (order) { - case 1: // First Euler step + case 1: // First Euler step for (int j = 0; j < ggml_nelements(x_next); j++) { vec_x[j] += vec_d_cur[j] * h_n; } @@ -941,7 +939,7 @@ static void sample_k_diffusion(sample_method_t method, } case 3: { - float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1; + float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1; float* vec_d_prev1 = (float*)buffer_model.back()->data; float* vec_d_prev2 = (buffer_model.size() > 1) ? (float*)buffer_model[buffer_model.size() - 2]->data : vec_d_prev1; for (int j = 0; j < ggml_nelements(x_next); j++) { @@ -951,8 +949,8 @@ static void sample_k_diffusion(sample_method_t method, } case 4: { - float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1; - float h_n_3 = (i > 2) ? (sigmas[i - 2] - sigmas[i - 3]) : h_n_2; + float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1; + float h_n_3 = (i > 2) ? (sigmas[i - 2] - sigmas[i - 3]) : h_n_2; float* vec_d_prev1 = (float*)buffer_model.back()->data; float* vec_d_prev2 = (buffer_model.size() > 1) ? (float*)buffer_model[buffer_model.size() - 2]->data : vec_d_prev1; float* vec_d_prev3 = (buffer_model.size() > 2) ? (float*)buffer_model[buffer_model.size() - 3]->data : vec_d_prev2; @@ -1008,6 +1006,374 @@ static void sample_k_diffusion(sample_method_t method, } } } break; + case DDIM_TRAILING: // Denoising Diffusion Implicit Models + // with the "trailing" timestep spacing + { + // See J. Song et al., "Denoising Diffusion Implicit + // Models", arXiv:2010.02502 [cs.LG] + // + // DDIM itself needs alphas_cumprod (DDPM, J. Ho et al., + // arXiv:2006.11239 [cs.LG] with k-diffusion's start and + // end beta) (which unfortunately k-diffusion's data + // structure hides from the denoiser), and the sigmas are + // also needed to invert the behavior of CompVisDenoiser + // (k-diffusion's LMSDiscreteScheduler) + float beta_start = 0.00085f; + float beta_end = 0.0120f; + std::vector alphas_cumprod; + std::vector compvis_sigmas; + + alphas_cumprod.reserve(TIMESTEPS); + compvis_sigmas.reserve(TIMESTEPS); + for (int i = 0; i < TIMESTEPS; i++) { + alphas_cumprod[i] = + (i == 0 ? 1.0f : alphas_cumprod[i - 1]) * + (1.0f - + std::pow(sqrtf(beta_start) + + (sqrtf(beta_end) - sqrtf(beta_start)) * + ((float)i / (TIMESTEPS - 1)), 2)); + compvis_sigmas[i] = + std::sqrt((1 - alphas_cumprod[i]) / + alphas_cumprod[i]); + } + + struct ggml_tensor* pred_original_sample = + ggml_dup_tensor(work_ctx, x); + struct ggml_tensor* variance_noise = + ggml_dup_tensor(work_ctx, x); + + for (int i = 0; i < steps; i++) { + // The "trailing" DDIM timestep, see S. Lin et al., + // "Common Diffusion Noise Schedules and Sample Steps + // are Flawed", arXiv:2305.08891 [cs], p. 4, Table + // 2. Most variables below follow Diffusers naming + // + // Diffuser naming vs. Song et al. (2010), p. 5, (12) + // and p. 16, (16) ( -> ): + // + // - pred_noise_t -> epsilon_theta^(t)(x_t) + // - pred_original_sample -> f_theta^(t)(x_t) or x_0 + // - std_dev_t -> sigma_t (not the LMS sigma) + // - eta -> eta (set to 0 at the moment) + // - pred_sample_direction -> "direction pointing to + // x_t" + // - pred_prev_sample -> "x_t-1" + int timestep = + roundf(TIMESTEPS - + i * ((float)TIMESTEPS / steps)) - 1; + // 1. get previous step value (=t-1) + int prev_timestep = timestep - TIMESTEPS / steps; + // The sigma here is chosen to cause the + // CompVisDenoiser to produce t = timestep + float sigma = compvis_sigmas[timestep]; + if (i == 0) { + // The function add_noise intializes x to + // Diffusers' latents * sigma (as in Diffusers' + // pipeline) or sample * sigma (Diffusers' + // scheduler), where this sigma = init_noise_sigma + // in Diffusers. For DDPM and DDIM however, + // init_noise_sigma = 1. But the k-diffusion + // model() also evaluates F_theta(c_in(sigma) x; + // ...) instead of the bare U-net F_theta, with + // c_in = 1 / sqrt(sigma^2 + 1), as defined in + // T. Karras et al., "Elucidating the Design Space + // of Diffusion-Based Generative Models", + // arXiv:2206.00364 [cs.CV], p. 3, Table 1. Hence + // the first call has to be prescaled as x <- x / + // (c_in * sigma) with the k-diffusion pipeline + // and CompVisDenoiser. + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_x[j] *= std::sqrt(sigma * sigma + 1) / + sigma; + } + } + else { + // For the subsequent steps after the first one, + // at this point x = latents or x = sample, and + // needs to be prescaled with x <- sample / c_in + // to compensate for model() applying the scale + // c_in before the U-net F_theta + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_x[j] *= std::sqrt(sigma * sigma + 1); + } + } + // Note (also noise_pred in Diffuser's pipeline) + // model_output = model() is the D(x, sigma) as + // defined in Karras et al. (2022), p. 3, Table 1 and + // p. 8 (7), compare also p. 38 (226) therein. + struct ggml_tensor* model_output = + model(x, sigma, i + 1); + // Here model_output is still the k-diffusion denoiser + // output, not the U-net output F_theta(c_in(sigma) x; + // ...) in Karras et al. (2022), whereas Diffusers' + // model_output is F_theta(...). Recover the actual + // model_output, which is also referred to as the + // "Karras ODE derivative" d or d_cur in several + // samplers above. + { + float* vec_x = (float*)x->data; + float* vec_model_output = + (float*)model_output->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_model_output[j] = + (vec_x[j] - vec_model_output[j]) * + (1 / sigma); + } + } + // 2. compute alphas, betas + float alpha_prod_t = alphas_cumprod[timestep]; + // Note final_alpha_cumprod = alphas_cumprod[0] due to + // trailing timestep spacing + float alpha_prod_t_prev = prev_timestep >= 0 ? + alphas_cumprod[prev_timestep] : alphas_cumprod[0]; + float beta_prod_t = 1 - alpha_prod_t; + // 3. compute predicted original sample from predicted + // noise also called "predicted x_0" of formula (12) + // from https://arxiv.org/pdf/2010.02502.pdf + { + float* vec_x = (float*)x->data; + float* vec_model_output = + (float*)model_output->data; + float* vec_pred_original_sample = + (float*)pred_original_sample->data; + // Note the substitution of latents or sample = x + // * c_in = x / sqrt(sigma^2 + 1) + for (int j = 0; j < ggml_nelements(x); j++) { + vec_pred_original_sample[j] = + (vec_x[j] / std::sqrt(sigma * sigma + 1) - + std::sqrt(beta_prod_t) * + vec_model_output[j]) * + (1 / std::sqrt(alpha_prod_t)); + } + } + // Assuming the "epsilon" prediction type, where below + // pred_epsilon = model_output is inserted, and is not + // defined/copied explicitly. + // + // 5. compute variance: "sigma_t(eta)" -> see formula + // (16) + // + // sigma_t = sqrt((1 - alpha_t-1)/(1 - alpha_t)) * + // sqrt(1 - alpha_t/alpha_t-1) + float beta_prod_t_prev = 1 - alpha_prod_t_prev; + float variance = (beta_prod_t_prev / beta_prod_t) * + (1 - alpha_prod_t / alpha_prod_t_prev); + float std_dev_t = eta * std::sqrt(variance); + // 6. compute "direction pointing to x_t" of formula + // (12) from https://arxiv.org/pdf/2010.02502.pdf + // 7. compute x_t without "random noise" of formula + // (12) from https://arxiv.org/pdf/2010.02502.pdf + { + float* vec_model_output = (float*)model_output->data; + float* vec_pred_original_sample = + (float*)pred_original_sample->data; + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + // Two step inner loop without an explicit + // tensor + float pred_sample_direction = + std::sqrt(1 - alpha_prod_t_prev - + std::pow(std_dev_t, 2)) * + vec_model_output[j]; + vec_x[j] = std::sqrt(alpha_prod_t_prev) * + vec_pred_original_sample[j] + + pred_sample_direction; + } + } + if (eta > 0) { + ggml_tensor_set_f32_randn(variance_noise, rng); + float* vec_variance_noise = + (float*)variance_noise->data; + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_x[j] += std_dev_t * vec_variance_noise[j]; + } + } + // See the note above: x = latents or sample here, and + // is not scaled by the c_in. For the final output + // this is correct, but for subsequent iterations, x + // needs to be prescaled again, since k-diffusion's + // model() differes from the bare U-net F_theta by the + // factor c_in. + } + } break; + case TCD: // Strategic Stochastic Sampling (Algorithm 4) in + // Trajectory Consistency Distillation + { + // See J. Zheng et al., "Trajectory Consistency + // Distillation: Improved Latent Consistency Distillation + // by Semi-Linear Consistency Function with Trajectory + // Mapping", arXiv:2402.19159 [cs.CV] + float beta_start = 0.00085f; + float beta_end = 0.0120f; + std::vector alphas_cumprod; + std::vector compvis_sigmas; + + alphas_cumprod.reserve(TIMESTEPS); + compvis_sigmas.reserve(TIMESTEPS); + for (int i = 0; i < TIMESTEPS; i++) { + alphas_cumprod[i] = + (i == 0 ? 1.0f : alphas_cumprod[i - 1]) * + (1.0f - + std::pow(sqrtf(beta_start) + + (sqrtf(beta_end) - sqrtf(beta_start)) * + ((float)i / (TIMESTEPS - 1)), 2)); + compvis_sigmas[i] = + std::sqrt((1 - alphas_cumprod[i]) / + alphas_cumprod[i]); + } + int original_steps = 50; + + struct ggml_tensor* pred_original_sample = + ggml_dup_tensor(work_ctx, x); + struct ggml_tensor* noise = + ggml_dup_tensor(work_ctx, x); + + for (int i = 0; i < steps; i++) { + // Analytic form for TCD timesteps + int timestep = TIMESTEPS - 1 - + (TIMESTEPS / original_steps) * + (int)floor(i * ((float)original_steps / steps)); + // 1. get previous step value + int prev_timestep = i >= steps - 1 ? 0 : + TIMESTEPS - 1 - (TIMESTEPS / original_steps) * + (int)floor((i + 1) * + ((float)original_steps / steps)); + // Here timestep_s is tau_n' in Algorithm 4. The _s + // notation appears to be that from C. Lu, + // "DPM-Solver: A Fast ODE Solver for Diffusion + // Probabilistic Model Sampling in Around 10 Steps", + // arXiv:2206.00927 [cs.LG], but this notation is not + // continued in Algorithm 4, where _n' is used. + int timestep_s = + (int)floor((1 - eta) * prev_timestep); + // Begin k-diffusion specific workaround for + // evaluating F_theta(x; ...) from D(x, sigma), same + // as in DDIM (and see there for detailed comments) + float sigma = compvis_sigmas[timestep]; + if (i == 0) { + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_x[j] *= std::sqrt(sigma * sigma + 1) / + sigma; + } + } + else { + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_x[j] *= std::sqrt(sigma * sigma + 1); + } + } + struct ggml_tensor* model_output = + model(x, sigma, i + 1); + { + float* vec_x = (float*)x->data; + float* vec_model_output = + (float*)model_output->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_model_output[j] = + (vec_x[j] - vec_model_output[j]) * + (1 / sigma); + } + } + // 2. compute alphas, betas + // + // When comparing TCD with DDPM/DDIM note that Zheng + // et al. (2024) follows the DPM-Solver notation for + // alpha. One can find the following comment in the + // original DPM-Solver code + // (https://github.com/LuChengTHU/dpm-solver/): + // "**Important**: Please pay special attention for + // the args for `alphas_cumprod`: The `alphas_cumprod` + // is the \hat{alpha_n} arrays in the notations of + // DDPM. [...] Therefore, the notation \hat{alpha_n} + // is different from the notation alpha_t in + // DPM-Solver. In fact, we have alpha_{t_n} = + // \sqrt{\hat{alpha_n}}, [...]" + float alpha_prod_t = alphas_cumprod[timestep]; + float beta_prod_t = 1 - alpha_prod_t; + // Note final_alpha_cumprod = alphas_cumprod[0] since + // TCD is always "trailing" + float alpha_prod_t_prev = prev_timestep >= 0 ? + alphas_cumprod[prev_timestep] : alphas_cumprod[0]; + // The subscript _s are the only portion in this + // section (2) unique to TCD + float alpha_prod_s = alphas_cumprod[timestep_s]; + float beta_prod_s = 1 - alpha_prod_s; + // 3. Compute the predicted noised sample x_s based on + // the model parameterization + // + // This section is also exactly the same as DDIM + { + float* vec_x = (float*)x->data; + float* vec_model_output = + (float*)model_output->data; + float* vec_pred_original_sample = + (float*)pred_original_sample->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_pred_original_sample[j] = + (vec_x[j] / std::sqrt(sigma * sigma + 1) - + std::sqrt(beta_prod_t) * + vec_model_output[j]) * + (1 / std::sqrt(alpha_prod_t)); + } + } + // This consistency function step can be difficult to + // decipher from Algorithm 4, as it is simply stated + // using a consistency function. This step is the + // modified DDIM, i.e. p. 8 (32) in Zheng et + // al. (2024), with eta set to 0 (see the paragraph + // immediately thereafter that states this somewhat + // obliquely). + { + float* vec_pred_original_sample = + (float*)pred_original_sample->data; + float* vec_model_output = + (float*)model_output->data; + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + // Substituting x = pred_noised_sample and + // pred_epsilon = model_output + vec_x[j] = + std::sqrt(alpha_prod_s) * + vec_pred_original_sample[j] + + std::sqrt(beta_prod_s) * + vec_model_output[j]; + } + } + // 4. Sample and inject noise z ~ N(0, I) for + // MultiStep Inference Noise is not used on the final + // timestep of the timestep schedule. This also means + // that noise is not used for one-step sampling. Eta + // (referred to as "gamma" in the paper) was + // introduced to control the stochasticity in every + // step. When eta = 0, it represents deterministic + // sampling, whereas eta = 1 indicates full stochastic + // sampling. + if (eta > 0 && i != steps - 1) { + // In this case, x is still pred_noised_sample, + // continue in-place + ggml_tensor_set_f32_randn(noise, rng); + float* vec_x = (float*)x->data; + float* vec_noise = (float*)noise->data; + for (int j = 0; j < ggml_nelements(x); j++) { + // Corresponding to (35) in Zheng et + // al. (2024), substituting x = + // pred_noised_sample + vec_x[j] = + std::sqrt(alpha_prod_t_prev / + alpha_prod_s) * + vec_x[j] + + std::sqrt(1 - alpha_prod_t_prev / + alpha_prod_s) * + vec_noise[j]; + } + } + } + } break; default: LOG_ERROR("Attempting to sample with nonexisting sample method %i", method); diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 2530f7149..ee4d88f0c 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -17,7 +17,8 @@ struct DiffusionModel { std::vector controls = {}, float control_strength = 0.f, struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) = 0; + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) = 0; virtual void alloc_params_buffer() = 0; virtual void free_params_buffer() = 0; virtual void free_compute_buffer() = 0; @@ -30,9 +31,10 @@ struct UNetModel : public DiffusionModel { UNetModelRunner unet; UNetModel(ggml_backend_t backend, - ggml_type wtype, - SDVersion version = VERSION_SD1) - : unet(backend, wtype, version) { + std::map& tensor_types, + SDVersion version = VERSION_SD1, + bool flash_attn = false) + : unet(backend, tensor_types, "model.diffusion_model", version, flash_attn) { } void alloc_params_buffer() { @@ -70,7 +72,9 @@ struct UNetModel : public DiffusionModel { std::vector controls = {}, float control_strength = 0.f, struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { + (void)skip_layers; // SLG doesn't work with UNet models return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx); } }; @@ -79,9 +83,8 @@ struct MMDiTModel : public DiffusionModel { MMDiTRunner mmdit; MMDiTModel(ggml_backend_t backend, - ggml_type wtype, - SDVersion version = VERSION_SD3_2B) - : mmdit(backend, wtype, version) { + std::map& tensor_types) + : mmdit(backend, tensor_types, "model.diffusion_model") { } void alloc_params_buffer() { @@ -119,8 +122,9 @@ struct MMDiTModel : public DiffusionModel { std::vector controls = {}, float control_strength = 0.f, struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { - return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx); + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { + return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers); } }; @@ -128,9 +132,10 @@ struct FluxModel : public DiffusionModel { Flux::FluxRunner flux; FluxModel(ggml_backend_t backend, - ggml_type wtype, - SDVersion version = VERSION_FLUX_DEV) - : flux(backend, wtype, version) { + std::map& tensor_types, + SDVersion version = VERSION_FLUX, + bool flash_attn = false) + : flux(backend, tensor_types, "model.diffusion_model", version, flash_attn) { } void alloc_params_buffer() { @@ -168,9 +173,10 @@ struct FluxModel : public DiffusionModel { std::vector controls = {}, float control_strength = 0.f, struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { - return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx); + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { + return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers); } }; -#endif \ No newline at end of file +#endif diff --git a/docs/photo_maker.md b/docs/photo_maker.md index b69ad97d9..8305a33bd 100644 --- a/docs/photo_maker.md +++ b/docs/photo_maker.md @@ -29,4 +29,26 @@ Example: ```bash bin/sd -m ../models/sdxlUnstableDiffusers_v11.safetensors --vae ../models/sdxl_vae.safetensors --stacked-id-embd-dir ../models/photomaker-v1.safetensors --input-id-images-dir ../assets/photomaker_examples/scarletthead_woman -p "a girl img, retro futurism, retro game art style but extremely beautiful, intricate details, masterpiece, best quality, space-themed, cosmic, celestial, stars, galaxies, nebulas, planets, science fiction, highly detailed" -n "realistic, photo-realistic, worst quality, greyscale, bad anatomy, bad hands, error, text" --cfg-scale 5.0 --sampling-method euler -H 1024 -W 1024 --style-ratio 10 --vae-on-cpu -o output.png -``` \ No newline at end of file +``` + +## PhotoMaker Version 2 + +[PhotoMaker Version 2 (PMV2)](https://github.com/TencentARC/PhotoMaker/blob/main/README_pmv2.md) has some key improvements. Unfortunately it has a very heavy dependency which makes running it a bit involved in ```SD.cpp```. + +Running PMV2 is now a two-step process: + +- Run a python script ```face_detect.py``` to obtain **id_embeds** for the given input images +``` +python face_detect.py input_image_dir +``` +An ```id_embeds.safetensors``` file will be generated in ```input_images_dir``` + +**Note: this step is only needed to run once; the same ```id_embeds``` can be reused** + +- Run the same command as in version 1 but replacing ```photomaker-v1.safetensors``` with ```photomaker-v2.safetensors```. + + You can download ```photomaker-v2.safetensors``` from [here](https://huggingface.co/bssrdf/PhotoMakerV2) + +- All the command line parameters from Version 1 remain the same for Version 2 + + diff --git a/docs/sd3.md b/docs/sd3.md new file mode 100644 index 000000000..777511d4b --- /dev/null +++ b/docs/sd3.md @@ -0,0 +1,20 @@ +# How to Use + +## Download weights + +- Download sd3.5_large from https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/sd3.5_large.safetensors +- Download clip_g from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/clip_g.safetensors +- Download clip_l from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/clip_l.safetensors +- Download t5xxl from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/t5xxl_fp16.safetensors + + +## Run + +### SD3.5 Large +For example: + +``` +.\bin\Release\sd.exe -m ..\models\sd3.5_large.safetensors --clip_l ..\models\clip_l.safetensors --clip_g ..\models\clip_g.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable diffusion 3.5 Large\"' --cfg-scale 4.5 --sampling-method euler -v +``` + +![](../assets/sd3.5_large.png) \ No newline at end of file diff --git a/esrgan.hpp b/esrgan.hpp index 33fcf09a4..989d15fee 100644 --- a/esrgan.hpp +++ b/esrgan.hpp @@ -142,10 +142,9 @@ struct ESRGAN : public GGMLRunner { int scale = 4; int tile_size = 128; // avoid cuda OOM for 4gb VRAM - ESRGAN(ggml_backend_t backend, - ggml_type wtype) - : GGMLRunner(backend, wtype) { - rrdb_net.init(params_ctx, wtype); + ESRGAN(ggml_backend_t backend, std::map& tensor_types) + : GGMLRunner(backend) { + rrdb_net.init(params_ctx, tensor_types, ""); } std::string get_desc() { diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index ceae27b83..af6b2bbdb 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -39,6 +39,8 @@ const char* sample_method_str[] = { "ipndm", "ipndm_v", "lcm", + "ddim_trailing", + "tcd", }; // Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h @@ -69,9 +71,9 @@ enum SDMode { struct SDParams { int n_threads = -1; SDMode mode = TXT2IMG; - std::string model_path; std::string clip_l_path; + std::string clip_g_path; std::string t5xxl_path; std::string diffusion_model_path; std::string vae_path; @@ -85,6 +87,7 @@ struct SDParams { std::string lora_model_dir; std::string output_path = "output.png"; std::string input_path; + std::string mask_path; std::string control_image_path; std::string prompt; @@ -92,6 +95,7 @@ struct SDParams { float min_cfg = 1.0f; float cfg_scale = 7.0f; float guidance = 3.5f; + float eta = 0.f; float style_ratio = 20.f; int clip_skip = -1; // <= 0 represents unspecified int width = 512; @@ -116,9 +120,15 @@ struct SDParams { bool normalize_input = false; bool clip_on_cpu = false; bool vae_on_cpu = false; + bool diffusion_flash_attn = false; bool canny_preprocess = false; bool color = false; int upscale_repeats = 1; + + std::vector skip_layers = {7, 8, 9}; + float slg_scale = 0.f; + float skip_layer_start = 0.01f; + float skip_layer_end = 0.2f; }; void print_params(SDParams params) { @@ -128,6 +138,7 @@ void print_params(SDParams params) { printf(" model_path: %s\n", params.model_path.c_str()); printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified"); printf(" clip_l_path: %s\n", params.clip_l_path.c_str()); + printf(" clip_g_path: %s\n", params.clip_g_path.c_str()); printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str()); printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str()); printf(" vae_path: %s\n", params.vae_path.c_str()); @@ -141,16 +152,20 @@ void print_params(SDParams params) { printf(" normalize input image : %s\n", params.normalize_input ? "true" : "false"); printf(" output_path: %s\n", params.output_path.c_str()); printf(" init_img: %s\n", params.input_path.c_str()); + printf(" mask_img: %s\n", params.mask_path.c_str()); printf(" control_image: %s\n", params.control_image_path.c_str()); printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false"); printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false"); printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false"); + printf(" diffusion flash attention:%s\n", params.diffusion_flash_attn ? "true" : "false"); printf(" strength(control): %.2f\n", params.control_strength); printf(" prompt: %s\n", params.prompt.c_str()); printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); printf(" min_cfg: %.2f\n", params.min_cfg); printf(" cfg_scale: %.2f\n", params.cfg_scale); + printf(" slg_scale: %.2f\n", params.slg_scale); printf(" guidance: %.2f\n", params.guidance); + printf(" eta: %.2f\n", params.eta); printf(" clip_skip: %d\n", params.clip_skip); printf(" width: %d\n", params.width); printf(" height: %d\n", params.height); @@ -171,48 +186,61 @@ void print_usage(int argc, const char* argv[]) { printf("arguments:\n"); printf(" -h, --help show this help message and exit\n"); printf(" -M, --mode [MODEL] run mode (txt2img or img2img or convert, default: txt2img)\n"); - printf(" -t, --threads N number of threads to use during computation (default: -1).\n"); + printf(" -t, --threads N number of threads to use during computation (default: -1)\n"); printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n"); printf(" -m, --model [MODEL] path to full model\n"); printf(" --diffusion-model path to the standalone diffusion model\n"); printf(" --clip_l path to the clip-l text encoder\n"); - printf(" --t5xxl path to the the t5xxl text encoder.\n"); + printf(" --clip_g path to the clip-g text encoder\n"); + printf(" --t5xxl path to the the t5xxl text encoder\n"); printf(" --vae [VAE] path to vae\n"); printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n"); printf(" --control-net [CONTROL_PATH] path to control net model\n"); - printf(" --embd-dir [EMBEDDING_PATH] path to embeddings.\n"); - printf(" --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings.\n"); - printf(" --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir.\n"); + printf(" --embd-dir [EMBEDDING_PATH] path to embeddings\n"); + printf(" --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings\n"); + printf(" --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir\n"); printf(" --normalize-input normalize PHOTOMAKER input id images\n"); - printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n"); + printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n"); printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n"); - printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k)\n"); - printf(" If not specified, the default is the type of the weight file.\n"); + printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n"); + printf(" If not specified, the default is the type of the weight file\n"); printf(" --lora-model-dir [DIR] lora model directory\n"); printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n"); + printf(" --mask [MASK] path to the mask image, required by img2img with mask\n"); printf(" --control-image [IMAGE] path to image condition, control net\n"); printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n"); printf(" -p, --prompt [PROMPT] the prompt to render\n"); printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n"); printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n"); + printf(" --guidance SCALE guidance scale for img2img (default: 3.5)\n"); + printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n"); + printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n"); + printf(" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n"); + printf(" --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n"); + printf(" --skip-layer-start START SLG enabling point: (default: 0.01)\n"); + printf(" --skip-layer-end END SLG disabling point: (default: 0.2)\n"); + printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n"); printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n"); printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%%)\n"); printf(" --control-strength STRENGTH strength to apply Control Net (default: 0.9)\n"); printf(" 1.0 corresponds to full destruction of information in init image\n"); printf(" -H, --height H image height, in pixel space (default: 512)\n"); printf(" -W, --width W image width, in pixel space (default: 512)\n"); - printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm}\n"); + printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n"); printf(" sampling method (default: \"euler_a\")\n"); printf(" --steps STEPS number of sample steps (default: 20)\n"); printf(" --rng {std_default, cuda} RNG (default: cuda)\n"); printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n"); - printf(" -b, --batch-count COUNT number of images to generate.\n"); + printf(" -b, --batch-count COUNT number of images to generate\n"); printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n"); printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n"); printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n"); printf(" --vae-tiling process vae in tiles to reduce memory usage\n"); printf(" --vae-on-cpu keep vae in cpu (for low vram)\n"); - printf(" --clip-on-cpu keep clip in cpu (for low vram).\n"); + printf(" --clip-on-cpu keep clip in cpu (for low vram)\n"); + printf(" --diffusion-fa use flash attention in the diffusion model (for low vram)\n"); + printf(" Might lower quality, since it implies converting k and v to f16.\n"); + printf(" This might crash if it is not supported by the backend.\n"); printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n"); printf(" --canny apply canny preprocessor (edge detection)\n"); printf(" --color Colors the logging tags according to level\n"); @@ -262,6 +290,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.clip_l_path = argv[i]; + } else if (arg == "--clip_g") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.clip_g_path = argv[i]; } else if (arg == "--t5xxl") { if (++i >= argc) { invalid_arg = true; @@ -321,30 +355,30 @@ void parse_args(int argc, const char** argv, SDParams& params) { invalid_arg = true; break; } - std::string type = argv[i]; - if (type == "f32") { - params.wtype = SD_TYPE_F32; - } else if (type == "f16") { - params.wtype = SD_TYPE_F16; - } else if (type == "q4_0") { - params.wtype = SD_TYPE_Q4_0; - } else if (type == "q4_1") { - params.wtype = SD_TYPE_Q4_1; - } else if (type == "q5_0") { - params.wtype = SD_TYPE_Q5_0; - } else if (type == "q5_1") { - params.wtype = SD_TYPE_Q5_1; - } else if (type == "q8_0") { - params.wtype = SD_TYPE_Q8_0; - } else if (type == "q2_k") { - params.wtype = SD_TYPE_Q2_K; - } else if (type == "q3_k") { - params.wtype = SD_TYPE_Q3_K; - } else if (type == "q4_k") { - params.wtype = SD_TYPE_Q4_K; - } else { - fprintf(stderr, "error: invalid weight format %s, must be one of [f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k]\n", - type.c_str()); + std::string type = argv[i]; + bool found = false; + std::string valid_types = ""; + for (size_t i = 0; i < SD_TYPE_COUNT; i++) { + auto trait = ggml_get_type_traits((ggml_type)i); + std::string name(trait->type_name); + if (name == "f32" || trait->to_float && trait->type_size) { + if (i) + valid_types += ", "; + valid_types += name; + if (type == name) { + if (ggml_quantize_requires_imatrix((ggml_type)i)) { + printf("\033[35;1m[WARNING]\033[0m: type %s requires imatrix to work properly. A dummy imatrix will be used, expect poor quality.\n", trait->type_name); + } + params.wtype = (enum sd_type_t)i; + found = true; + break; + } + } + } + if (!found) { + fprintf(stderr, "error: invalid weight format %s, must be one of [%s]\n", + type.c_str(), + valid_types.c_str()); exit(1); } } else if (arg == "--lora-model-dir") { @@ -359,6 +393,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.input_path = argv[i]; + } else if (arg == "--mask") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.mask_path = argv[i]; } else if (arg == "--control-image") { if (++i >= argc) { invalid_arg = true; @@ -405,6 +445,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.guidance = std::stof(argv[i]); + } else if (arg == "--eta") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.eta = std::stof(argv[i]); } else if (arg == "--strength") { if (++i >= argc) { invalid_arg = true; @@ -457,6 +503,8 @@ void parse_args(int argc, const char** argv, SDParams& params) { params.clip_on_cpu = true; // will slow down get_learned_condiotion but necessary for low MEM GPUs } else if (arg == "--vae-on-cpu") { params.vae_on_cpu = true; // will slow down latent decoding but necessary for low MEM GPUs + } else if (arg == "--diffusion-fa") { + params.diffusion_flash_attn = true; // can reduce MEM significantly } else if (arg == "--canny") { params.canny_preprocess = true; } else if (arg == "-b" || arg == "--batch-count") { @@ -526,6 +574,61 @@ void parse_args(int argc, const char** argv, SDParams& params) { params.verbose = true; } else if (arg == "--color") { params.color = true; + } else if (arg == "--slg-scale") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.slg_scale = std::stof(argv[i]); + } else if (arg == "--skip-layers") { + if (++i >= argc) { + invalid_arg = true; + break; + } + if (argv[i][0] != '[') { + invalid_arg = true; + break; + } + std::string layers_str = argv[i]; + while (layers_str.back() != ']') { + if (++i >= argc) { + invalid_arg = true; + break; + } + layers_str += " " + std::string(argv[i]); + } + layers_str = layers_str.substr(1, layers_str.size() - 2); + + std::regex regex("[, ]+"); + std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1); + std::sregex_token_iterator end; + std::vector tokens(iter, end); + std::vector layers; + for (const auto& token : tokens) { + try { + layers.push_back(std::stoi(token)); + } catch (const std::invalid_argument& e) { + invalid_arg = true; + break; + } + } + params.skip_layers = layers; + + if (invalid_arg) { + break; + } + } else if (arg == "--skip-layer-start") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.skip_layer_start = std::stof(argv[i]); + } else if (arg == "--skip-layer-end") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.skip_layer_end = std::stof(argv[i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); print_usage(argc, argv); @@ -616,7 +719,18 @@ std::string get_image_params(SDParams params, int64_t seed) { } parameter_string += "Steps: " + std::to_string(params.sample_steps) + ", "; parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", "; + if (params.slg_scale != 0 && params.skip_layers.size() != 0) { + parameter_string += "SLG scale: " + std::to_string(params.cfg_scale) + ", "; + parameter_string += "Skip layers: ["; + for (const auto& layer : params.skip_layers) { + parameter_string += std::to_string(layer) + ", "; + } + parameter_string += "], "; + parameter_string += "Skip layer start: " + std::to_string(params.skip_layer_start) + ", "; + parameter_string += "Skip layer end: " + std::to_string(params.skip_layer_end) + ", "; + } parameter_string += "Guidance: " + std::to_string(params.guidance) + ", "; + parameter_string += "Eta: " + std::to_string(params.eta) + ", "; parameter_string += "Seed: " + std::to_string(seed) + ", "; parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", "; parameter_string += "Model: " + sd_basename(params.model_path) + ", "; @@ -711,6 +825,8 @@ int main(int argc, const char* argv[]) { bool vae_decode_only = true; uint8_t* input_image_buffer = NULL; uint8_t* control_image_buffer = NULL; + uint8_t* mask_image_buffer = NULL; + if (params.mode == IMG2IMG || params.mode == IMG2VID) { vae_decode_only = false; @@ -765,6 +881,7 @@ int main(int argc, const char* argv[]) { sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(), params.clip_l_path.c_str(), + params.clip_g_path.c_str(), params.t5xxl_path.c_str(), params.diffusion_model_path.c_str(), params.vae_path.c_str(), @@ -782,7 +899,8 @@ int main(int argc, const char* argv[]) { params.schedule, params.clip_on_cpu, params.control_net_cpu, - params.vae_on_cpu); + params.vae_on_cpu, + params.diffusion_flash_attn); if (sd_ctx == NULL) { printf("new_sd_ctx_t failed\n"); @@ -813,6 +931,18 @@ int main(int argc, const char* argv[]) { } } + std::vector default_mask_image_vec(params.width * params.height, 255); + if (params.mask_path != "") { + int c = 0; + mask_image_buffer = stbi_load(params.mask_path.c_str(), ¶ms.width, ¶ms.height, &c, 1); + } else { + mask_image_buffer = default_mask_image_vec.data(); + } + sd_image_t mask_image = {(uint32_t)params.width, + (uint32_t)params.height, + 1, + mask_image_buffer}; + sd_image_t* results; if (params.mode == TXT2IMG) { results = txt2img(sd_ctx, @@ -821,6 +951,7 @@ int main(int argc, const char* argv[]) { params.clip_skip, params.cfg_scale, params.guidance, + params.eta, params.width, params.height, params.sample_method, @@ -831,7 +962,12 @@ int main(int argc, const char* argv[]) { params.control_strength, params.style_ratio, params.normalize_input, - params.input_id_images_path.c_str()); + params.input_id_images_path.c_str(), + params.skip_layers.data(), + params.skip_layers.size(), + params.slg_scale, + params.skip_layer_start, + params.skip_layer_end); } else { sd_image_t input_image = {(uint32_t)params.width, (uint32_t)params.height, @@ -877,11 +1013,13 @@ int main(int argc, const char* argv[]) { } else { results = img2img(sd_ctx, input_image, + mask_image, params.prompt.c_str(), params.negative_prompt.c_str(), params.clip_skip, params.cfg_scale, params.guidance, + params.eta, params.width, params.height, params.sample_method, @@ -893,7 +1031,12 @@ int main(int argc, const char* argv[]) { params.control_strength, params.style_ratio, params.normalize_input, - params.input_id_images_path.c_str()); + params.input_id_images_path.c_str(), + params.skip_layers.data(), + params.skip_layers.size(), + params.slg_scale, + params.skip_layer_start, + params.skip_layer_end); } } @@ -906,8 +1049,7 @@ int main(int argc, const char* argv[]) { int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth if (params.esrgan_path.size() > 0 && params.upscale_repeats > 0) { upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(), - params.n_threads, - params.wtype); + params.n_threads); if (upscaler_ctx == NULL) { printf("new_upscaler_ctx failed\n"); @@ -931,16 +1073,41 @@ int main(int argc, const char* argv[]) { } } - size_t last = params.output_path.find_last_of("."); - std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path; + std::string dummy_name, ext, lc_ext; + bool is_jpg; + size_t last = params.output_path.find_last_of("."); + size_t last_path = std::min(params.output_path.find_last_of("/"), + params.output_path.find_last_of("\\")); + if (last != std::string::npos // filename has extension + && (last_path == std::string::npos || last > last_path)) { + dummy_name = params.output_path.substr(0, last); + ext = lc_ext = params.output_path.substr(last); + std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower); + is_jpg = lc_ext == ".jpg" || lc_ext == ".jpeg" || lc_ext == ".jpe"; + } else { + dummy_name = params.output_path; + ext = lc_ext = ""; + is_jpg = false; + } + // appending ".png" to absent or unknown extension + if (!is_jpg && lc_ext != ".png") { + dummy_name += ext; + ext = ".png"; + } for (int i = 0; i < params.batch_count; i++) { if (results[i].data == NULL) { continue; } - std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ".png" : dummy_name + ".png"; - stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, - results[i].data, 0, get_image_params(params, params.seed + i).c_str()); - printf("save result image to '%s'\n", final_image_path.c_str()); + std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext; + if(is_jpg) { + stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, + results[i].data, 90, get_image_params(params, params.seed + i).c_str()); + printf("save result JPEG image to '%s'\n", final_image_path.c_str()); + } else { + stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, + results[i].data, 0, get_image_params(params, params.seed + i).c_str()); + printf("save result PNG image to '%s'\n", final_image_path.c_str()); + } free(results[i].data); results[i].data = NULL; } diff --git a/face_detect.py b/face_detect.py new file mode 100644 index 000000000..7131af31f --- /dev/null +++ b/face_detect.py @@ -0,0 +1,88 @@ +import os +import sys + +import numpy as np +import torch +from diffusers.utils import load_image +# pip install insightface==0.7.3 +from insightface.app import FaceAnalysis +from insightface.data import get_image as ins_get_image +from safetensors.torch import save_file + +### +# https://github.com/cubiq/ComfyUI_IPAdapter_plus/issues/165#issue-2055829543 +### +class FaceAnalysis2(FaceAnalysis): + # NOTE: allows setting det_size for each detection call. + # the model allows it but the wrapping code from insightface + # doesn't show it, and people end up loading duplicate models + # for different sizes where there is absolutely no need to + def get(self, img, max_num=0, det_size=(640, 640)): + if det_size is not None: + self.det_model.input_size = det_size + + return super().get(img, max_num) + +def analyze_faces(face_analysis: FaceAnalysis, img_data: np.ndarray, det_size=(640, 640)): + # NOTE: try detect faces, if no faces detected, lower det_size until it does + detection_sizes = [None] + [(size, size) for size in range(640, 256, -64)] + [(256, 256)] + + for size in detection_sizes: + faces = face_analysis.get(img_data, det_size=size) + if len(faces) > 0: + return faces + + return [] + +if __name__ == "__main__": + #face_detector = FaceAnalysis2(providers=['CUDAExecutionProvider'], allowed_modules=['detection', 'recognition']) + face_detector = FaceAnalysis2(providers=['CPUExecutionProvider'], allowed_modules=['detection', 'recognition']) + face_detector.prepare(ctx_id=0, det_size=(640, 640)) + #input_folder_name = './scarletthead_woman' + input_folder_name = sys.argv[1] + image_basename_list = os.listdir(input_folder_name) + image_path_list = sorted([os.path.join(input_folder_name, basename) for basename in image_basename_list]) + + input_id_images = [] + for image_path in image_path_list: + input_id_images.append(load_image(image_path)) + + id_embed_list = [] + + for img in input_id_images: + img = np.array(img) + img = img[:, :, ::-1] + faces = analyze_faces(face_detector, img) + if len(faces) > 0: + id_embed_list.append(torch.from_numpy((faces[0]['embedding']))) + + if len(id_embed_list) == 0: + raise ValueError(f"No face detected in input image pool") + + id_embeds = torch.stack(id_embed_list) + + # for r in id_embeds: + # print(r) + # #torch.save(id_embeds, input_folder_name+'/id_embeds.pt'); + # weights = dict() + # weights["id_embeds"] = id_embeds + # save_file(weights, input_folder_name+'/id_embeds.safetensors') + + binary_data = id_embeds.numpy().tobytes() + two = 4 + zero = 0 + one = 1 + tensor_name = "id_embeds" +# Write binary data to a file + with open(input_folder_name+'/id_embeds.bin', "wb") as f: + f.write(two.to_bytes(4, byteorder='little')) + f.write((len(tensor_name)).to_bytes(4, byteorder='little')) + f.write(zero.to_bytes(4, byteorder='little')) + f.write((id_embeds.shape[1]).to_bytes(4, byteorder='little')) + f.write((id_embeds.shape[0]).to_bytes(4, byteorder='little')) + f.write(one.to_bytes(4, byteorder='little')) + f.write(one.to_bytes(4, byteorder='little')) + f.write(tensor_name.encode('ascii')) + f.write(binary_data) + + \ No newline at end of file diff --git a/flux.hpp b/flux.hpp index 73bc345a7..20ff41096 100644 --- a/flux.hpp +++ b/flux.hpp @@ -35,8 +35,9 @@ namespace Flux { int64_t hidden_size; float eps; - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["scale"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + ggml_type wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "scale") != tensor_types.end()) ? tensor_types[prefix + "scale"] : GGML_TYPE_F32; + params["scale"] = ggml_new_tensor_1d(ctx, wtype, hidden_size); } public: @@ -115,25 +116,28 @@ namespace Flux { struct ggml_tensor* q, struct ggml_tensor* k, struct ggml_tensor* v, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + bool flash_attn) { // q,k,v: [N, L, n_head, d_head] // pe: [L, d_head/2, 2, 2] // return: [N, L, n_head*d_head] q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head] k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head] - auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true); // [N, L, n_head*d_head] + auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true, flash_attn); // [N, L, n_head*d_head] return x; } struct SelfAttention : public GGMLBlock { public: int64_t num_heads; + bool flash_attn; public: SelfAttention(int64_t dim, int64_t num_heads = 8, - bool qkv_bias = false) + bool qkv_bias = false, + bool flash_attn = false) : num_heads(num_heads) { int64_t head_dim = dim / num_heads; blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); @@ -167,9 +171,9 @@ namespace Flux { // x: [N, n_token, dim] // pe: [n_token, d_head/2, 2, 2] // return [N, n_token, dim] - auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] - x = attention(ctx, qkv[0], qkv[1], qkv[2], pe); // [N, n_token, dim] - x = post_attention(ctx, x); // [N, n_token, dim] + auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] + x = attention(ctx, qkv[0], qkv[1], qkv[2], pe, flash_attn); // [N, n_token, dim] + x = post_attention(ctx, x); // [N, n_token, dim] return x; } }; @@ -237,15 +241,19 @@ namespace Flux { } struct DoubleStreamBlock : public GGMLBlock { + bool flash_attn; + public: DoubleStreamBlock(int64_t hidden_size, int64_t num_heads, float mlp_ratio, - bool qkv_bias = false) { + bool qkv_bias = false, + bool flash_attn = false) + : flash_attn(flash_attn) { int64_t mlp_hidden_dim = hidden_size * mlp_ratio; blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias)); + blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn)); blocks["img_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); blocks["img_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim)); @@ -254,7 +262,7 @@ namespace Flux { blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); blocks["txt_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias)); + blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn)); blocks["txt_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); blocks["txt_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim)); @@ -316,7 +324,7 @@ namespace Flux { auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto attn = attention(ctx, q, k, v, pe); // [N, n_txt_token + n_img_token, n_head*d_head] + auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head] attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] auto txt_attn_out = ggml_view_3d(ctx, attn, @@ -364,13 +372,15 @@ namespace Flux { int64_t num_heads; int64_t hidden_size; int64_t mlp_hidden_dim; + bool flash_attn; public: SingleStreamBlock(int64_t hidden_size, int64_t num_heads, float mlp_ratio = 4.0f, - float qk_scale = 0.f) - : hidden_size(hidden_size), num_heads(num_heads) { + float qk_scale = 0.f, + bool flash_attn = false) + : hidden_size(hidden_size), num_heads(num_heads), flash_attn(flash_attn) { int64_t head_dim = hidden_size / num_heads; float scale = qk_scale; if (scale <= 0.f) { @@ -433,7 +443,7 @@ namespace Flux { auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] q = norm->query_norm(ctx, q); k = norm->key_norm(ctx, k); - auto attn = attention(ctx, q, k, v, pe); // [N, n_token, hidden_size] + auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_token, hidden_size] auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim] auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size] @@ -480,6 +490,7 @@ namespace Flux { struct FluxParams { int64_t in_channels = 64; + int64_t out_channels = 64; int64_t vec_in_dim = 768; int64_t context_in_dim = 4096; int64_t hidden_size = 3072; @@ -492,6 +503,7 @@ namespace Flux { int theta = 10000; bool qkv_bias = true; bool guidance_embed = true; + bool flash_attn = true; }; struct Flux : public GGMLBlock { @@ -631,8 +643,7 @@ namespace Flux { Flux() {} Flux(FluxParams params) : params(params) { - int64_t out_channels = params.in_channels; - int64_t pe_dim = params.hidden_size / params.num_heads; + int64_t pe_dim = params.hidden_size / params.num_heads; blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, params.hidden_size, true)); blocks["time_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); @@ -646,16 +657,19 @@ namespace Flux { blocks["double_blocks." + std::to_string(i)] = std::shared_ptr(new DoubleStreamBlock(params.hidden_size, params.num_heads, params.mlp_ratio, - params.qkv_bias)); + params.qkv_bias, + params.flash_attn)); } for (int i = 0; i < params.depth_single_blocks; i++) { blocks["single_blocks." + std::to_string(i)] = std::shared_ptr(new SingleStreamBlock(params.hidden_size, params.num_heads, - params.mlp_ratio)); + params.mlp_ratio, + 0.f, + params.flash_attn)); } - blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, out_channels)); + blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, params.out_channels)); } struct ggml_tensor* patchify(struct ggml_context* ctx, @@ -711,7 +725,8 @@ namespace Flux { struct ggml_tensor* timesteps, struct ggml_tensor* y, struct ggml_tensor* guidance, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + std::vector skip_layers = std::vector()) { auto img_in = std::dynamic_pointer_cast(blocks["img_in"]); auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); auto vector_in = std::dynamic_pointer_cast(blocks["vector_in"]); @@ -733,6 +748,10 @@ namespace Flux { txt = txt_in->forward(ctx, txt); for (int i = 0; i < params.depth; i++) { + if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) { + continue; + } + auto block = std::dynamic_pointer_cast(blocks["double_blocks." + std::to_string(i)]); auto img_txt = block->forward(ctx, img, txt, vec, pe); @@ -742,6 +761,9 @@ namespace Flux { auto txt_img = ggml_concat(ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size] for (int i = 0; i < params.depth_single_blocks; i++) { + if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i + params.depth) != skip_layers.end()) { + continue; + } auto block = std::dynamic_pointer_cast(blocks["single_blocks." + std::to_string(i)]); txt_img = block->forward(ctx, txt_img, vec, pe); @@ -767,13 +789,16 @@ namespace Flux { struct ggml_tensor* x, struct ggml_tensor* timestep, struct ggml_tensor* context, + struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + std::vector skip_layers = std::vector()) { // Forward pass of DiT. // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) // timestep: (N,) tensor of diffusion timesteps // context: (N, L, D) + // c_concat: NULL, or for (N,C+M, H, W) for Fill // y: (N, adm_in_channels) tensor of class labels // guidance: (N,) // pe: (L, d_head/2, 2, 2) @@ -783,6 +808,7 @@ namespace Flux { int64_t W = x->ne[0]; int64_t H = x->ne[1]; + int64_t C = x->ne[2]; int64_t patch_size = 2; int pad_h = (patch_size - H % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size; @@ -791,7 +817,20 @@ namespace Flux { // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size] - auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe); // [N, h*w, C * patch_size * patch_size] + if (c_concat != NULL) { + ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0); + ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); + + masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0); + mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0); + + masked = patchify(ctx, masked, patch_size); + mask = patchify(ctx, mask, patch_size); + + img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0); + } + + auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size] // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w] @@ -801,20 +840,59 @@ namespace Flux { }; struct FluxRunner : public GGMLRunner { + static std::map empty_tensor_types; + public: FluxParams flux_params; Flux flux; std::vector pe_vec; // for cache FluxRunner(ggml_backend_t backend, - ggml_type wtype, - SDVersion version = VERSION_FLUX_DEV) - : GGMLRunner(backend, wtype) { - if (version == VERSION_FLUX_SCHNELL) { - flux_params.guidance_embed = false; + std::map& tensor_types = empty_tensor_types, + const std::string prefix = "", + SDVersion version = VERSION_FLUX, + bool flash_attn = false) + : GGMLRunner(backend) { + flux_params.flash_attn = flash_attn; + flux_params.guidance_embed = false; + flux_params.depth = 0; + flux_params.depth_single_blocks = 0; + if (version == VERSION_FLUX_FILL) { + flux_params.in_channels = 384; } + for (auto pair : tensor_types) { + std::string tensor_name = pair.first; + if (tensor_name.find("model.diffusion_model.") == std::string::npos) + continue; + if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) { + // not schnell + flux_params.guidance_embed = true; + } + size_t db = tensor_name.find("double_blocks."); + if (db != std::string::npos) { + tensor_name = tensor_name.substr(db); // remove prefix + int block_depth = atoi(tensor_name.substr(14, tensor_name.find(".", 14)).c_str()); + if (block_depth + 1 > flux_params.depth) { + flux_params.depth = block_depth + 1; + } + } + size_t sb = tensor_name.find("single_blocks."); + if (sb != std::string::npos) { + tensor_name = tensor_name.substr(sb); // remove prefix + int block_depth = atoi(tensor_name.substr(14, tensor_name.find(".", 14)).c_str()); + if (block_depth + 1 > flux_params.depth_single_blocks) { + flux_params.depth_single_blocks = block_depth + 1; + } + } + } + + LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks); + if (!flux_params.guidance_embed) { + LOG_INFO("Flux guidance is disabled (Schnell mode)"); + } + flux = Flux(flux_params); - flux.init(params_ctx, wtype); + flux.init(params_ctx, tensor_types, prefix); } std::string get_desc() { @@ -828,13 +906,18 @@ namespace Flux { struct ggml_cgraph* build_graph(struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, + struct ggml_tensor* c_concat, struct ggml_tensor* y, - struct ggml_tensor* guidance) { + struct ggml_tensor* guidance, + std::vector skip_layers = std::vector()) { GGML_ASSERT(x->ne[3] == 1); struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false); - x = to_backend(x); - context = to_backend(context); + x = to_backend(x); + context = to_backend(context); + if (c_concat != NULL) { + c_concat = to_backend(c_concat); + } y = to_backend(y); timesteps = to_backend(timesteps); if (flux_params.guidance_embed) { @@ -854,9 +937,11 @@ namespace Flux { x, timesteps, context, + c_concat, y, guidance, - pe); + pe, + skip_layers); ggml_build_forward_expand(gf, out); @@ -867,17 +952,19 @@ namespace Flux { struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, + struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { // x: [N, in_channels, h, w] // timesteps: [N, ] // context: [N, max_position, hidden_size] // y: [N, adm_in_channels] or [1, adm_in_channels] // guidance: [N, ] auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context, y, guidance); + return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers); }; GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); @@ -917,7 +1004,7 @@ namespace Flux { struct ggml_tensor* out = NULL; int t0 = ggml_time_ms(); - compute(8, x, timesteps, context, y, guidance, &out, work_ctx); + compute(8, x, timesteps, context, NULL, y, guidance, &out, work_ctx); int t1 = ggml_time_ms(); print_ggml_tensor(out); @@ -929,7 +1016,7 @@ namespace Flux { // ggml_backend_t backend = ggml_backend_cuda_init(0); ggml_backend_t backend = ggml_backend_cpu_init(); ggml_type model_data_type = GGML_TYPE_Q8_0; - std::shared_ptr flux = std::shared_ptr(new FluxRunner(backend, model_data_type)); + std::shared_ptr flux = std::shared_ptr(new FluxRunner(backend)); { LOG_INFO("loading from '%s'", file_path.c_str()); @@ -958,4 +1045,4 @@ namespace Flux { } // namespace Flux -#endif // __FLUX_HPP__ \ No newline at end of file +#endif // __FLUX_HPP__ diff --git a/ggml b/ggml index 21d3a308f..ff9052988 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit 21d3a308fcb7f31cb9beceaeebad4fb622f3c337 +Subproject commit ff9052988b76e137bcf92bb335733933ca196ac0 diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 810f2b9ef..c5913be4d 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -22,9 +22,12 @@ #include "ggml-alloc.h" #include "ggml-backend.h" +#include "ggml-cpu.h" #include "ggml.h" -#ifdef SD_USE_CUBLAS +#include "model.h" + +#ifdef SD_USE_CUDA #include "ggml-cuda.h" #endif @@ -49,6 +52,71 @@ #define __STATIC_INLINE__ static inline #endif +// n-mode trensor-matrix product +// example: 2-mode product +// A: [ne03, k, ne01, ne00] +// B: k rows, m columns => [k, m] +// result is [ne03, m, ne01, ne00] +__STATIC_INLINE__ struct ggml_tensor* ggml_mul_n_mode(struct ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b, int mode = 0) { + // reshape A + // swap 0th and nth axis + a = ggml_cont(ctx, ggml_permute(ctx, a, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0)); + int ne1 = a->ne[1]; + int ne2 = a->ne[2]; + int ne3 = a->ne[3]; + // make 2D + a = ggml_cont(ctx, ggml_reshape_2d(ctx, a, a->ne[0], (ne3 * ne2 * ne1))); + + struct ggml_tensor* result = ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, a, b))); + + // reshape output (same shape as a after permutation except first dim) + result = ggml_reshape_4d(ctx, result, result->ne[0], ne1, ne2, ne3); + // swap back 0th and nth axis + result = ggml_permute(ctx, result, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0); + return result; +} + +__STATIC_INLINE__ struct ggml_tensor* ggml_merge_lora(ggml_context* ctx, struct ggml_tensor* lora_down, struct ggml_tensor* lora_up, struct ggml_tensor* lora_mid = NULL) { + struct ggml_tensor* updown; + // flat lora tensors to multiply it + int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1]; + lora_up = ggml_reshape_2d(ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows); + auto lora_down_n_dims = ggml_n_dims(lora_down); + // assume n_dims should always be a multiple of 2 (otherwise rank 1 doesn't work) + lora_down_n_dims = (lora_down_n_dims + lora_down_n_dims % 2); + int64_t lora_down_rows = lora_down->ne[lora_down_n_dims - 1]; + lora_down = ggml_reshape_2d(ctx, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows); + + // ggml_mul_mat requires tensor b transposed + lora_down = ggml_cont(ctx, ggml_transpose(ctx, lora_down)); + if (lora_mid == NULL) { + updown = ggml_mul_mat(ctx, lora_up, lora_down); + updown = ggml_cont(ctx, ggml_transpose(ctx, updown)); + } else { + // undoing tucker decomposition for conv layers. + // lora_mid has shape (3, 3, Rank, Rank) + // lora_down has shape (Rank, In, 1, 1) + // lora_up has shape (Rank, Out, 1, 1) + // conv layer shape is (3, 3, Out, In) + updown = ggml_mul_n_mode(ctx, ggml_mul_n_mode(ctx, lora_mid, lora_down, 3), lora_up, 2); + updown = ggml_cont(ctx, updown); + } + return updown; +} + +// Kronecker product +// [ne03,ne02,ne01,ne00] x [ne13,ne12,ne11,ne10] => [ne03*ne13,ne02*ne12,ne01*ne11,ne00*ne10] +__STATIC_INLINE__ struct ggml_tensor* ggml_kronecker(ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b) { + return ggml_mul(ctx, + ggml_upscale_ext(ctx, + a, + a->ne[0] * b->ne[0], + a->ne[1] * b->ne[1], + a->ne[2] * b->ne[2], + a->ne[3] * b->ne[3]), + b); +} + __STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void* user_data) { (void)level; (void)user_data; @@ -100,17 +168,11 @@ __STATIC_INLINE__ ggml_fp16_t ggml_tensor_get_f16(const ggml_tensor* tensor, int static struct ggml_tensor* get_tensor_from_graph(struct ggml_cgraph* gf, const char* name) { struct ggml_tensor* res = NULL; - for (int i = 0; i < gf->n_nodes; i++) { - // printf("%d, %s \n", i, gf->nodes[i]->name); - if (strcmp(ggml_get_name(gf->nodes[i]), name) == 0) { - res = gf->nodes[i]; - break; - } - } - for (int i = 0; i < gf->n_leafs; i++) { - // printf("%d, %s \n", i, gf->leafs[i]->name); - if (strcmp(ggml_get_name(gf->leafs[i]), name) == 0) { - res = gf->leafs[i]; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + struct ggml_tensor* node = ggml_graph_node(gf, i); + // printf("%d, %s \n", i, ggml_get_name(node)); + if (strcmp(ggml_get_name(node), name) == 0) { + res = node; break; } } @@ -293,6 +355,44 @@ __STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data, } } +__STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data, + struct ggml_tensor* output, + bool scale = true) { + int64_t width = output->ne[0]; + int64_t height = output->ne[1]; + int64_t channels = output->ne[2]; + GGML_ASSERT(channels == 1 && output->type == GGML_TYPE_F32); + for (int iy = 0; iy < height; iy++) { + for (int ix = 0; ix < width; ix++) { + float value = *(image_data + iy * width * channels + ix); + if (scale) { + value /= 255.f; + } + ggml_tensor_set_f32(output, value, ix, iy); + } + } +} + +__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data, + struct ggml_tensor* mask, + struct ggml_tensor* output) { + int64_t width = output->ne[0]; + int64_t height = output->ne[1]; + int64_t channels = output->ne[2]; + GGML_ASSERT(output->type == GGML_TYPE_F32); + for (int ix = 0; ix < width; ix++) { + for (int iy = 0; iy < height; iy++) { + float m = ggml_tensor_get_f32(mask, ix, iy); + m = round(m); // inpaint models need binary masks + ggml_tensor_set_f32(mask, m, ix, iy); + for (int k = 0; k < channels; k++) { + float value = (1 - m) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5; + ggml_tensor_set_f32(output, value, ix, iy, k); + } + } + } +} + __STATIC_INLINE__ void sd_mul_images_to_tensor(const uint8_t* image_data, struct ggml_tensor* output, int idx, @@ -368,8 +468,8 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input, int64_t height = input->ne[1]; int64_t channels = input->ne[2]; - int64_t img_width = output->ne[0]; - int64_t img_height = output->ne[1]; + int64_t img_width = output->ne[0]; + int64_t img_height = output->ne[1]; GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32); for (int iy = 0; iy < height; iy++) { @@ -380,7 +480,7 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input, float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k); const float x_f_0 = (x > 0) ? ix / float(overlap) : 1; - const float x_f_1 = (x < (img_width - width)) ? (width - ix) / float(overlap) : 1 ; + const float x_f_1 = (x < (img_width - width)) ? (width - ix) / float(overlap) : 1; const float y_f_0 = (y > 0) ? iy / float(overlap) : 1; const float y_f_1 = (y < (img_height - height)) ? (height - iy) / float(overlap) : 1; @@ -390,8 +490,7 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input, ggml_tensor_set_f32( output, old_value + new_value * ggml_smootherstep_f32(y_f) * ggml_smootherstep_f32(x_f), - x + ix, y + iy, k - ); + x + ix, y + iy, k); } else { ggml_tensor_set_f32(output, new_value, x + ix, y + iy, k); } @@ -676,18 +775,16 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx struct ggml_tensor* k, struct ggml_tensor* v, bool mask = false) { -#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL) +#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUDA) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL) struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head] #else - float d_head = (float)q->ne[0]; - + float d_head = (float)q->ne[0]; struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, n_token, n_k] kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrt(d_head)); if (mask) { kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); } - kq = ggml_soft_max_inplace(ctx, kq); - + kq = ggml_soft_max_inplace(ctx, kq); struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, n_token, d_head] #endif return kqv; @@ -704,7 +801,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* int64_t n_head, struct ggml_tensor* mask = NULL, bool diag_mask_inf = false, - bool skip_reshape = false) { + bool skip_reshape = false, + bool flash_attn = false) { int64_t L_q; int64_t L_k; int64_t C; @@ -735,13 +833,42 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* float scale = (1.0f / sqrt((float)d_head)); - bool use_flash_attn = false; - ggml_tensor* kqv = NULL; - if (use_flash_attn) { + // if (flash_attn) { + // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N); + // } + // is there anything oddly shaped?? ping Green-Sky if you can trip this assert + GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0)); + + bool can_use_flash_attn = true; + can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0; + can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0; // double check + + // cuda max d_head seems to be 256, cpu does seem to work with 512 + can_use_flash_attn = can_use_flash_attn && d_head <= 256; // double check + + if (mask != nullptr) { + // TODO(Green-Sky): figure out if we can bend t5 to work too + can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1; + can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1; + } + + // TODO(Green-Sky): more pad or disable for funny tensor shapes + + ggml_tensor* kqv = nullptr; + // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn); + if (can_use_flash_attn && flash_attn) { + // LOG_DEBUG("using flash attention"); + k = ggml_cast(ctx, k, GGML_TYPE_F16); + v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] - LOG_DEBUG("k->ne[1] == %d", k->ne[1]); + v = ggml_cast(ctx, v, GGML_TYPE_F16); + kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0); + ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32); + + // kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0); + kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0); } else { v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k] v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N); // [N * n_head, d_head, L_k] @@ -757,10 +884,12 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* kq = ggml_soft_max_inplace(ctx, kq); kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head] + + kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N); // [N, n_head, L_q, d_head] + kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3); // [N, L_q, n_head, d_head] } - kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N); // [N, n_head, L_q, d_head] - kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, L_q, n_head, d_head] + kqv = ggml_cont(ctx, kqv); kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N); // [N, L_q, C] return kqv; @@ -802,7 +931,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ct } __STATIC_INLINE__ void ggml_backend_tensor_get_and_sync(ggml_backend_t backend, const struct ggml_tensor* tensor, void* data, size_t offset, size_t size) { -#if defined(SD_USE_CUBLAS) || defined(SD_USE_SYCL) +#if defined(SD_USE_CUDA) || defined(SD_USE_SYCL) if (!ggml_backend_is_cpu(backend)) { ggml_backend_tensor_get_async(backend, tensor, data, offset, size); ggml_backend_synchronize(backend); @@ -925,8 +1054,8 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) { } /* SDXL with LoRA requires more space */ -#define MAX_PARAMS_TENSOR_NUM 15360 -#define MAX_GRAPH_SIZE 15360 +#define MAX_PARAMS_TENSOR_NUM 32768 +#define MAX_GRAPH_SIZE 32768 struct GGMLRunner { protected: @@ -940,7 +1069,6 @@ struct GGMLRunner { std::map backend_tensor_data_map; - ggml_type wtype = GGML_TYPE_F32; ggml_backend_t backend = NULL; void alloc_params_ctx() { @@ -1016,8 +1144,8 @@ struct GGMLRunner { public: virtual std::string get_desc() = 0; - GGMLRunner(ggml_backend_t backend, ggml_type wtype = GGML_TYPE_F32) - : backend(backend), wtype(wtype) { + GGMLRunner(ggml_backend_t backend) + : backend(backend) { alloc_params_ctx(); } @@ -1048,6 +1176,11 @@ struct GGMLRunner { params_buffer_size / (1024.0 * 1024.0), ggml_backend_is_cpu(backend) ? "RAM" : "VRAM", num_tensors); + // printf("%s params backend buffer size = % 6.2f MB(%s) (%i tensors)\n", + // get_desc().c_str(), + // params_buffer_size / (1024.0 * 1024.0), + // ggml_backend_is_cpu(backend) ? "RAM" : "VRAM", + // num_tensors); return true; } @@ -1108,18 +1241,12 @@ struct GGMLRunner { ggml_backend_cpu_set_n_threads(backend, n_threads); } -#ifdef SD_USE_METAL - if (ggml_backend_is_metal(backend)) { - ggml_backend_metal_set_n_cb(backend, n_threads); - } -#endif ggml_backend_graph_compute(backend, gf); - #ifdef GGML_PERF ggml_graph_print(gf); #endif if (output != NULL) { - auto result = gf->nodes[gf->n_nodes - 1]; + auto result = ggml_graph_node(gf, -1); if (*output == NULL && output_ctx != NULL) { *output = ggml_dup_tensor(output_ctx, result); } @@ -1141,20 +1268,22 @@ class GGMLBlock { GGMLBlockMap blocks; ParameterMap params; - void init_blocks(struct ggml_context* ctx, ggml_type wtype) { + void init_blocks(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { for (auto& pair : blocks) { auto& block = pair.second; - - block->init(ctx, wtype); + block->init(ctx, tensor_types, prefix + pair.first); } } - virtual void init_params(struct ggml_context* ctx, ggml_type wtype) {} + virtual void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") {} public: - void init(struct ggml_context* ctx, ggml_type wtype) { - init_blocks(ctx, wtype); - init_params(ctx, wtype); + void init(struct ggml_context* ctx, std::map& tensor_types, std::string prefix = "") { + if (prefix.size() > 0) { + prefix = prefix + "."; + } + init_blocks(ctx, tensor_types, prefix); + init_params(ctx, tensor_types, prefix); } size_t get_params_num() { @@ -1210,13 +1339,15 @@ class Linear : public UnaryBlock { bool bias; bool force_f32; - void init_params(struct ggml_context* ctx, ggml_type wtype) { + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type wtype = (tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32; if (in_features % ggml_blck_size(wtype) != 0 || force_f32) { wtype = GGML_TYPE_F32; } params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features); if (bias) { - params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_features); + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.ypes.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32; + params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features); } } @@ -1244,9 +1375,9 @@ class Embedding : public UnaryBlock { protected: int64_t embedding_dim; int64_t num_embeddings; - - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings); + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type wtype = (tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32; + params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings); } public: @@ -1284,10 +1415,12 @@ class Conv2d : public UnaryBlock { std::pair dilation; bool bias; - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["weight"] = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kernel_size.second, kernel_size.first, in_channels, out_channels); + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type wtype = GGML_TYPE_F16; //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F16; + params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels, out_channels); if (bias) { - params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); + enum ggml_type wtype = GGML_TYPE_F32; // (tensor_types.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32; + params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_channels); } } @@ -1327,10 +1460,12 @@ class Conv3dnx1x1 : public UnaryBlock { int64_t dilation; bool bias; - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["weight"] = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, kernel_size, in_channels, out_channels); // 5d => 4d + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type wtype = GGML_TYPE_F16; //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F16; + params["weight"] = ggml_new_tensor_4d(ctx, wtype, 1, kernel_size, in_channels, out_channels); // 5d => 4d if (bias) { - params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32; + params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_channels); } } @@ -1369,11 +1504,13 @@ class LayerNorm : public UnaryBlock { bool elementwise_affine; bool bias; - void init_params(struct ggml_context* ctx, ggml_type wtype) { + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { if (elementwise_affine) { - params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, normalized_shape); + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.ypes.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32; + params["weight"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape); if (bias) { - params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, normalized_shape); + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.ypes.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32; + params["bias"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape); } } } @@ -1409,10 +1546,12 @@ class GroupNorm : public GGMLBlock { float eps; bool affine; - void init_params(struct ggml_context* ctx, ggml_type wtype) { + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { if (affine) { - params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_channels); - params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_channels); + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32; + enum ggml_type bias_wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32; + params["weight"] = ggml_new_tensor_1d(ctx, wtype, num_channels); + params["bias"] = ggml_new_tensor_1d(ctx, bias_wtype, num_channels); } } diff --git a/gits_noise.inl b/gits_noise.inl index fd4750267..7a10ff76f 100644 --- a/gits_noise.inl +++ b/gits_noise.inl @@ -329,21 +329,21 @@ const std::vector> GITS_NOISE_1_50 = { }; const std::vector>*> GITS_NOISE = { - { &GITS_NOISE_0_80 }, - { &GITS_NOISE_0_85 }, - { &GITS_NOISE_0_90 }, - { &GITS_NOISE_0_95 }, - { &GITS_NOISE_1_00 }, - { &GITS_NOISE_1_05 }, - { &GITS_NOISE_1_10 }, - { &GITS_NOISE_1_15 }, - { &GITS_NOISE_1_20 }, - { &GITS_NOISE_1_25 }, - { &GITS_NOISE_1_30 }, - { &GITS_NOISE_1_35 }, - { &GITS_NOISE_1_40 }, - { &GITS_NOISE_1_45 }, - { &GITS_NOISE_1_50 } + &GITS_NOISE_0_80, + &GITS_NOISE_0_85, + &GITS_NOISE_0_90, + &GITS_NOISE_0_95, + &GITS_NOISE_1_00, + &GITS_NOISE_1_05, + &GITS_NOISE_1_10, + &GITS_NOISE_1_15, + &GITS_NOISE_1_20, + &GITS_NOISE_1_25, + &GITS_NOISE_1_30, + &GITS_NOISE_1_35, + &GITS_NOISE_1_40, + &GITS_NOISE_1_45, + &GITS_NOISE_1_50 }; #endif // GITS_NOISE_INL diff --git a/lora.hpp b/lora.hpp index c44db7698..d38c7116f 100644 --- a/lora.hpp +++ b/lora.hpp @@ -6,6 +6,90 @@ #define LORA_GRAPH_SIZE 10240 struct LoraModel : public GGMLRunner { + enum lora_t { + REGULAR = 0, + DIFFUSERS = 1, + DIFFUSERS_2 = 2, + DIFFUSERS_3 = 3, + TRANSFORMERS = 4, + LORA_TYPE_COUNT + }; + + const std::string lora_ups[LORA_TYPE_COUNT] = { + ".lora_up", + "_lora.up", + ".lora_B", + ".lora.up", + ".lora_linear_layer.up", + }; + + const std::string lora_downs[LORA_TYPE_COUNT] = { + ".lora_down", + "_lora.down", + ".lora_A", + ".lora.down", + ".lora_linear_layer.down", + }; + + const std::string lora_pre[LORA_TYPE_COUNT] = { + "lora.", + "", + "", + "", + "", + }; + + const std::map alt_names = { + // mmdit + {"final_layer.adaLN_modulation.1", "norm_out.linear"}, + {"pos_embed", "pos_embed.proj"}, + {"final_layer.linear", "proj_out"}, + {"y_embedder.mlp.0", "time_text_embed.text_embedder.linear_1"}, + {"y_embedder.mlp.2", "time_text_embed.text_embedder.linear_2"}, + {"t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1"}, + {"t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2"}, + {"x_block.mlp.fc1", "ff.net.0.proj"}, + {"x_block.mlp.fc2", "ff.net.2"}, + {"context_block.mlp.fc1", "ff_context.net.0.proj"}, + {"context_block.mlp.fc2", "ff_context.net.2"}, + {"x_block.adaLN_modulation.1", "norm1.linear"}, + {"context_block.adaLN_modulation.1", "norm1_context.linear"}, + {"context_block.attn.proj", "attn.to_add_out"}, + {"x_block.attn.proj", "attn.to_out.0"}, + {"x_block.attn2.proj", "attn2.to_out.0"}, + // flux + // singlestream + {"linear2", "proj_out"}, + {"modulation.lin", "norm.linear"}, + // doublestream + {"txt_attn.proj", "attn.to_add_out"}, + {"img_attn.proj", "attn.to_out.0"}, + {"txt_mlp.0", "ff_context.net.0.proj"}, + {"txt_mlp.2", "ff_context.net.2"}, + {"img_mlp.0", "ff.net.0.proj"}, + {"img_mlp.2", "ff.net.2"}, + {"txt_mod.lin", "norm1_context.linear"}, + {"img_mod.lin", "norm1.linear"}, + }; + + const std::map qkv_prefixes = { + // mmdit + {"context_block.attn.qkv", "attn.add_"}, // suffix "_proj" + {"x_block.attn.qkv", "attn.to_"}, + {"x_block.attn2.qkv", "attn2.to_"}, + // flux + // doublestream + {"txt_attn.qkv", "attn.add_"}, // suffix "_proj" + {"img_attn.qkv", "attn.to_"}, + }; + const std::map qkvm_prefixes = { + // flux + // singlestream + {"linear1", ""}, + }; + + const std::string* type_fingerprints = lora_ups; + float multiplier = 1.0f; std::map lora_tensors; std::string file_path; @@ -14,12 +98,12 @@ struct LoraModel : public GGMLRunner { bool applied = false; std::vector zero_index_vec = {0}; ggml_tensor* zero_index = NULL; + enum lora_t type = REGULAR; LoraModel(ggml_backend_t backend, - ggml_type wtype, const std::string& file_path = "", - const std::string& prefix = "") - : file_path(file_path), GGMLRunner(backend, wtype) { + const std::string prefix = "") + : file_path(file_path), GGMLRunner(backend) { if (!model_loader.init_from_file(file_path, prefix)) { load_failed = true; } @@ -45,6 +129,13 @@ struct LoraModel : public GGMLRunner { // LOG_INFO("skipping LoRA tesnor '%s'", name.c_str()); return true; } + // LOG_INFO("%s", name.c_str()); + for (int i = 0; i < LORA_TYPE_COUNT; i++) { + if (name.find(type_fingerprints[i]) != std::string::npos) { + type = (lora_t)i; + break; + } + } if (dry_run) { struct ggml_tensor* real = ggml_new_tensor(params_ctx, @@ -62,10 +153,12 @@ struct LoraModel : public GGMLRunner { model_loader.load_tensors(on_new_tensor_cb, backend); alloc_params_buffer(); - + // exit(0); dry_run = false; model_loader.load_tensors(on_new_tensor_cb, backend); + LOG_DEBUG("lora type: \"%s\"/\"%s\"", lora_downs[type].c_str(), lora_ups[type].c_str()); + LOG_DEBUG("finished loaded lora"); return true; } @@ -77,103 +170,653 @@ struct LoraModel : public GGMLRunner { return out; } - struct ggml_cgraph* build_lora_graph(std::map model_tensors) { - struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, LORA_GRAPH_SIZE, false); + std::vector to_lora_keys(std::string blk_name, SDVersion version) { + std::vector keys; + // if (!sd_version_is_sd3(version) || blk_name != "model.diffusion_model.pos_embed") { + size_t k_pos = blk_name.find(".weight"); + if (k_pos == std::string::npos) { + return keys; + } + blk_name = blk_name.substr(0, k_pos); + // } + keys.push_back(blk_name); + keys.push_back("lora." + blk_name); + if (sd_version_is_dit(version)) { + if (blk_name.find("model.diffusion_model") != std::string::npos) { + blk_name.replace(blk_name.find("model.diffusion_model"), sizeof("model.diffusion_model") - 1, "transformer"); + } - zero_index = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, 1); - set_backend_tensor_data(zero_index, zero_index_vec.data()); - ggml_build_forward_expand(gf, zero_index); + if (blk_name.find(".single_blocks") != std::string::npos) { + blk_name.replace(blk_name.find(".single_blocks"), sizeof(".single_blocks") - 1, ".single_transformer_blocks"); + } + if (blk_name.find(".double_blocks") != std::string::npos) { + blk_name.replace(blk_name.find(".double_blocks"), sizeof(".double_blocks") - 1, ".transformer_blocks"); + } - std::set applied_lora_tensors; - for (auto it : model_tensors) { - std::string k_tensor = it.first; - struct ggml_tensor* weight = model_tensors[it.first]; + if (blk_name.find(".joint_blocks") != std::string::npos) { + blk_name.replace(blk_name.find(".joint_blocks"), sizeof(".joint_blocks") - 1, ".transformer_blocks"); + } - size_t k_pos = k_tensor.find(".weight"); - if (k_pos == std::string::npos) { - continue; + if (blk_name.find("text_encoders.clip_l") != std::string::npos) { + blk_name.replace(blk_name.find("text_encoders.clip_l"), sizeof("text_encoders.clip_l") - 1, "cond_stage_model"); + } + + for (const auto& item : alt_names) { + size_t match = blk_name.find(item.first); + if (match != std::string::npos) { + blk_name = blk_name.substr(0, match) + item.second; + } + } + for (const auto& prefix : qkv_prefixes) { + size_t match = blk_name.find(prefix.first); + if (match != std::string::npos) { + std::string split_blk = "SPLIT|" + blk_name.substr(0, match) + prefix.second; + keys.push_back(split_blk); + } } - k_tensor = k_tensor.substr(0, k_pos); - replace_all_chars(k_tensor, '.', '_'); - // LOG_DEBUG("k_tensor %s", k_tensor.c_str()); - std::string lora_up_name = "lora." + k_tensor + ".lora_up.weight"; - if (lora_tensors.find(lora_up_name) == lora_tensors.end()) { - if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") { - // fix for some sdxl lora, like lcm-lora-xl - k_tensor = "model_diffusion_model_output_blocks_2_1_conv"; - lora_up_name = "lora." + k_tensor + ".lora_up.weight"; + for (const auto& prefix : qkvm_prefixes) { + size_t match = blk_name.find(prefix.first); + if (match != std::string::npos) { + std::string split_blk = "SPLIT_L|" + blk_name.substr(0, match) + prefix.second; + keys.push_back(split_blk); } } + keys.push_back(blk_name); + } - std::string lora_down_name = "lora." + k_tensor + ".lora_down.weight"; - std::string alpha_name = "lora." + k_tensor + ".alpha"; - std::string scale_name = "lora." + k_tensor + ".scale"; + std::vector ret; + for (std::string& key : keys) { + ret.push_back(key); + replace_all_chars(key, '.', '_'); + // fix for some sdxl lora, like lcm-lora-xl + if (key == "model_diffusion_model_output_blocks_2_2_conv") { + ret.push_back("model_diffusion_model_output_blocks_2_1_conv"); + } + ret.push_back(key); + } + return ret; + } - ggml_tensor* lora_up = NULL; - ggml_tensor* lora_down = NULL; + struct ggml_cgraph* build_lora_graph(std::map model_tensors, SDVersion version) { + struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, LORA_GRAPH_SIZE, false); - if (lora_tensors.find(lora_up_name) != lora_tensors.end()) { - lora_up = lora_tensors[lora_up_name]; - } + zero_index = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, 1); + set_backend_tensor_data(zero_index, zero_index_vec.data()); + ggml_build_forward_expand(gf, zero_index); - if (lora_tensors.find(lora_down_name) != lora_tensors.end()) { - lora_down = lora_tensors[lora_down_name]; - } + std::set applied_lora_tensors; + for (auto it : model_tensors) { + std::string k_tensor = it.first; + struct ggml_tensor* weight = model_tensors[it.first]; - if (lora_up == NULL || lora_down == NULL) { + std::vector keys = to_lora_keys(k_tensor, version); + if (keys.size() == 0) continue; - } - applied_lora_tensors.insert(lora_up_name); - applied_lora_tensors.insert(lora_down_name); - applied_lora_tensors.insert(alpha_name); - applied_lora_tensors.insert(scale_name); - - // calc_cale - int64_t dim = lora_down->ne[ggml_n_dims(lora_down) - 1]; - float scale_value = 1.0f; - if (lora_tensors.find(scale_name) != lora_tensors.end()) { - scale_value = ggml_backend_tensor_get_f32(lora_tensors[scale_name]); - } else if (lora_tensors.find(alpha_name) != lora_tensors.end()) { - float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); - scale_value = alpha / dim; - } - scale_value *= multiplier; - - // flat lora tensors to multiply it - int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1]; - lora_up = ggml_reshape_2d(compute_ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows); - int64_t lora_down_rows = lora_down->ne[ggml_n_dims(lora_down) - 1]; - lora_down = ggml_reshape_2d(compute_ctx, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows); - - // ggml_mul_mat requires tensor b transposed - lora_down = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, lora_down)); - struct ggml_tensor* updown = ggml_mul_mat(compute_ctx, lora_up, lora_down); - updown = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, updown)); - updown = ggml_reshape(compute_ctx, updown, weight); - GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight)); - updown = ggml_scale_inplace(compute_ctx, updown, scale_value); - ggml_tensor* final_weight; - if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) { - // final_weight = ggml_new_tensor(compute_ctx, GGML_TYPE_F32, ggml_n_dims(weight), weight->ne); - // final_weight = ggml_cpy(compute_ctx, weight, final_weight); - final_weight = to_f32(compute_ctx, weight); - final_weight = ggml_add_inplace(compute_ctx, final_weight, updown); - final_weight = ggml_cpy(compute_ctx, final_weight, weight); - } else { - final_weight = ggml_add_inplace(compute_ctx, weight, updown); + for (auto& key : keys) { + bool is_qkv_split = starts_with(key, "SPLIT|"); + if (is_qkv_split) { + key = key.substr(sizeof("SPLIT|") - 1); + } + bool is_qkvm_split = starts_with(key, "SPLIT_L|"); + if (is_qkvm_split) { + key = key.substr(sizeof("SPLIT_L|") - 1); + } + struct ggml_tensor* updown = NULL; + float scale_value = 1.0f; + std::string fk = lora_pre[type] + key; + if (lora_tensors.find(fk + ".hada_w1_a") != lora_tensors.end()) { + // LoHa mode + + // TODO: split qkv convention for LoHas (is it ever used?) + if (is_qkv_split || is_qkvm_split) { + LOG_ERROR("Split qkv isn't supported for LoHa models."); + break; + } + std::string alpha_name = ""; + + ggml_tensor* hada_1_mid = NULL; // tau for tucker decomposition + ggml_tensor* hada_1_up = NULL; + ggml_tensor* hada_1_down = NULL; + + ggml_tensor* hada_2_mid = NULL; // tau for tucker decomposition + ggml_tensor* hada_2_up = NULL; + ggml_tensor* hada_2_down = NULL; + + std::string hada_1_mid_name = ""; + std::string hada_1_down_name = ""; + std::string hada_1_up_name = ""; + + std::string hada_2_mid_name = ""; + std::string hada_2_down_name = ""; + std::string hada_2_up_name = ""; + + + hada_1_down_name = fk + ".hada_w1_b"; + hada_1_up_name = fk + ".hada_w1_a"; + hada_1_mid_name = fk + ".hada_t1"; + if (lora_tensors.find(hada_1_down_name) != lora_tensors.end()) { + hada_1_down = to_f32(compute_ctx, lora_tensors[hada_1_down_name]); + } + if (lora_tensors.find(hada_1_up_name) != lora_tensors.end()) { + hada_1_up = to_f32(compute_ctx, lora_tensors[hada_1_up_name]); + } + if (lora_tensors.find(hada_1_mid_name) != lora_tensors.end()) { + hada_1_mid = to_f32(compute_ctx, lora_tensors[hada_1_mid_name]); + applied_lora_tensors.insert(hada_1_mid_name); + hada_1_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_1_up)); + } + + hada_2_down_name = fk + ".hada_w2_b"; + hada_2_up_name = fk + ".hada_w2_a"; + hada_2_mid_name = fk + ".hada_t2"; + if (lora_tensors.find(hada_2_down_name) != lora_tensors.end()) { + hada_2_down = to_f32(compute_ctx, lora_tensors[hada_2_down_name]); + } + if (lora_tensors.find(hada_2_up_name) != lora_tensors.end()) { + hada_2_up = to_f32(compute_ctx, lora_tensors[hada_2_up_name]); + } + if (lora_tensors.find(hada_2_mid_name) != lora_tensors.end()) { + hada_2_mid = to_f32(compute_ctx, lora_tensors[hada_2_mid_name]); + applied_lora_tensors.insert(hada_2_mid_name); + hada_2_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_2_up)); + } + + alpha_name = fk + ".alpha"; + + applied_lora_tensors.insert(hada_1_down_name); + applied_lora_tensors.insert(hada_1_up_name); + applied_lora_tensors.insert(hada_2_down_name); + applied_lora_tensors.insert(hada_2_up_name); + + applied_lora_tensors.insert(alpha_name); + if (hada_1_up == NULL || hada_1_down == NULL || hada_2_up == NULL || hada_2_down == NULL) { + continue; + } + + struct ggml_tensor* updown_1 = ggml_merge_lora(compute_ctx, hada_1_down, hada_1_up, hada_1_mid); + struct ggml_tensor* updown_2 = ggml_merge_lora(compute_ctx, hada_2_down, hada_2_up, hada_2_mid); + updown = ggml_mul_inplace(compute_ctx, updown_1, updown_2); + + // calc_scale + // TODO: .dora_scale? + int64_t rank = hada_1_down->ne[ggml_n_dims(hada_1_down) - 1]; + if (lora_tensors.find(alpha_name) != lora_tensors.end()) { + float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); + scale_value = alpha / rank; + } + } else if (lora_tensors.find(fk + ".lokr_w1") != lora_tensors.end() || lora_tensors.find(fk + ".lokr_w1_a") != lora_tensors.end()) { + // LoKr mode + + // TODO: split qkv convention for LoKrs (is it ever used?) + if (is_qkv_split || is_qkvm_split) { + LOG_ERROR("Split qkv isn't supported for LoKr models."); + break; + } + + std::string alpha_name = fk + ".alpha"; + + ggml_tensor* lokr_w1 = NULL; + ggml_tensor* lokr_w2 = NULL; + + std::string lokr_w1_name = ""; + std::string lokr_w2_name = ""; + + lokr_w1_name = fk + ".lokr_w1"; + lokr_w2_name = fk + ".lokr_w2"; + + if (lora_tensors.find(lokr_w1_name) != lora_tensors.end()) { + lokr_w1 = to_f32(compute_ctx, lora_tensors[lokr_w1_name]); + applied_lora_tensors.insert(lokr_w1_name); + } else { + ggml_tensor* down = NULL; + ggml_tensor* up = NULL; + std::string down_name = lokr_w1_name + "_b"; + std::string up_name = lokr_w1_name + "_a"; + if (lora_tensors.find(down_name) != lora_tensors.end()) { + // w1 should not be low rank normally, sometimes w1 and w2 are swapped + down = to_f32(compute_ctx, lora_tensors[down_name]); + applied_lora_tensors.insert(down_name); + + int64_t rank = down->ne[ggml_n_dims(down) - 1]; + if (lora_tensors.find(alpha_name) != lora_tensors.end()) { + float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); + scale_value = alpha / rank; + } + } + if (lora_tensors.find(up_name) != lora_tensors.end()) { + up = to_f32(compute_ctx, lora_tensors[up_name]); + applied_lora_tensors.insert(up_name); + } + lokr_w1 = ggml_merge_lora(compute_ctx, down, up); + } + if (lora_tensors.find(lokr_w2_name) != lora_tensors.end()) { + lokr_w2 = to_f32(compute_ctx, lora_tensors[lokr_w2_name]); + applied_lora_tensors.insert(lokr_w2_name); + } else { + ggml_tensor* down = NULL; + ggml_tensor* up = NULL; + std::string down_name = lokr_w2_name + "_b"; + std::string up_name = lokr_w2_name + "_a"; + if (lora_tensors.find(down_name) != lora_tensors.end()) { + down = to_f32(compute_ctx, lora_tensors[down_name]); + applied_lora_tensors.insert(down_name); + + int64_t rank = down->ne[ggml_n_dims(down) - 1]; + if (lora_tensors.find(alpha_name) != lora_tensors.end()) { + float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); + scale_value = alpha / rank; + } + } + if (lora_tensors.find(up_name) != lora_tensors.end()) { + up = to_f32(compute_ctx, lora_tensors[up_name]); + applied_lora_tensors.insert(up_name); + } + lokr_w2 = ggml_merge_lora(compute_ctx, down, up); + } + + // Technically it might be unused, but I believe it's the expected behavior + applied_lora_tensors.insert(alpha_name); + + updown = ggml_kronecker(compute_ctx, lokr_w1, lokr_w2); + + } else { + // LoRA mode + ggml_tensor* lora_mid = NULL; // tau for tucker decomposition + ggml_tensor* lora_up = NULL; + ggml_tensor* lora_down = NULL; + + std::string alpha_name = ""; + std::string scale_name = ""; + std::string split_q_scale_name = ""; + std::string lora_mid_name = ""; + std::string lora_down_name = ""; + std::string lora_up_name = ""; + + if (is_qkv_split) { + std::string suffix = ""; + auto split_q_d_name = fk + "q" + suffix + lora_downs[type] + ".weight"; + + if (lora_tensors.find(split_q_d_name) == lora_tensors.end()) { + suffix = "_proj"; + split_q_d_name = fk + "q" + suffix + lora_downs[type] + ".weight"; + } + if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) { + // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1] + // find qkv and mlp up parts in LoRA model + auto split_k_d_name = fk + "k" + suffix + lora_downs[type] + ".weight"; + auto split_v_d_name = fk + "v" + suffix + lora_downs[type] + ".weight"; + + auto split_q_u_name = fk + "q" + suffix + lora_ups[type] + ".weight"; + auto split_k_u_name = fk + "k" + suffix + lora_ups[type] + ".weight"; + auto split_v_u_name = fk + "v" + suffix + lora_ups[type] + ".weight"; + + auto split_q_scale_name = fk + "q" + suffix + ".scale"; + auto split_k_scale_name = fk + "k" + suffix + ".scale"; + auto split_v_scale_name = fk + "v" + suffix + ".scale"; + + auto split_q_alpha_name = fk + "q" + suffix + ".alpha"; + auto split_k_alpha_name = fk + "k" + suffix + ".alpha"; + auto split_v_alpha_name = fk + "v" + suffix + ".alpha"; + + ggml_tensor* lora_q_down = NULL; + ggml_tensor* lora_q_up = NULL; + ggml_tensor* lora_k_down = NULL; + ggml_tensor* lora_k_up = NULL; + ggml_tensor* lora_v_down = NULL; + ggml_tensor* lora_v_up = NULL; + + lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]); + + if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) { + lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]); + } + + if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) { + lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]); + } + + if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) { + lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]); + } + + if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) { + lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]); + } + + if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) { + lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]); + } + + float q_rank = lora_q_up->ne[0]; + float k_rank = lora_k_up->ne[0]; + float v_rank = lora_v_up->ne[0]; + + float lora_q_scale = 1; + float lora_k_scale = 1; + float lora_v_scale = 1; + + if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) { + lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]); + applied_lora_tensors.insert(split_q_scale_name); + } + if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) { + lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]); + applied_lora_tensors.insert(split_k_scale_name); + } + if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) { + lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]); + applied_lora_tensors.insert(split_v_scale_name); + } + + if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) { + float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]); + applied_lora_tensors.insert(split_q_alpha_name); + lora_q_scale = lora_q_alpha / q_rank; + } + if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) { + float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]); + applied_lora_tensors.insert(split_k_alpha_name); + lora_k_scale = lora_k_alpha / k_rank; + } + if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) { + float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]); + applied_lora_tensors.insert(split_v_alpha_name); + lora_v_scale = lora_v_alpha / v_rank; + } + + ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale); + ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale); + ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale); + + // print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1] + // print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1] + // print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1] + + // these need to be stitched together this way: + // |q_up,0 ,0 | + // |0 ,k_up,0 | + // |0 ,0 ,v_up| + // (q_down,k_down,v_down) . (q ,k ,v) + + // up_concat will be [9216, R*3, 1, 1] + // down_concat will be [R*3, 3072, 1, 1] + ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), lora_v_down, 1); + + ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up); + ggml_scale(compute_ctx, z, 0); + ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1); + + ggml_tensor* q_up = ggml_concat(compute_ctx, lora_q_up, zz, 1); + ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), z, 1); + ggml_tensor* v_up = ggml_concat(compute_ctx, zz, lora_v_up, 1); + // print_ggml_tensor(q_up, true); //[R, 9216, 1, 1] + // print_ggml_tensor(k_up, true); //[R, 9216, 1, 1] + // print_ggml_tensor(v_up, true); //[R, 9216, 1, 1] + ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), v_up, 0); + // print_ggml_tensor(lora_up_concat, true); //[R*3, 9216, 1, 1] + + lora_down = ggml_cont(compute_ctx, lora_down_concat); + lora_up = ggml_cont(compute_ctx, lora_up_concat); + + applied_lora_tensors.insert(split_q_u_name); + applied_lora_tensors.insert(split_k_u_name); + applied_lora_tensors.insert(split_v_u_name); + + applied_lora_tensors.insert(split_q_d_name); + applied_lora_tensors.insert(split_k_d_name); + applied_lora_tensors.insert(split_v_d_name); + } + } else if (is_qkvm_split) { + auto split_q_d_name = fk + "attn.to_q" + lora_downs[type] + ".weight"; + if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) { + // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1] + // find qkv and mlp up parts in LoRA model + auto split_k_d_name = fk + "attn.to_k" + lora_downs[type] + ".weight"; + auto split_v_d_name = fk + "attn.to_v" + lora_downs[type] + ".weight"; + + auto split_q_u_name = fk + "attn.to_q" + lora_ups[type] + ".weight"; + auto split_k_u_name = fk + "attn.to_k" + lora_ups[type] + ".weight"; + auto split_v_u_name = fk + "attn.to_v" + lora_ups[type] + ".weight"; + + auto split_m_d_name = fk + "proj_mlp" + lora_downs[type] + ".weight"; + auto split_m_u_name = fk + "proj_mlp" + lora_ups[type] + ".weight"; + + auto split_q_scale_name = fk + "attn.to_q" + ".scale"; + auto split_k_scale_name = fk + "attn.to_k" + ".scale"; + auto split_v_scale_name = fk + "attn.to_v" + ".scale"; + auto split_m_scale_name = fk + "proj_mlp" + ".scale"; + + auto split_q_alpha_name = fk + "attn.to_q" + ".alpha"; + auto split_k_alpha_name = fk + "attn.to_k" + ".alpha"; + auto split_v_alpha_name = fk + "attn.to_v" + ".alpha"; + auto split_m_alpha_name = fk + "proj_mlp" + ".alpha"; + + ggml_tensor* lora_q_down = NULL; + ggml_tensor* lora_q_up = NULL; + ggml_tensor* lora_k_down = NULL; + ggml_tensor* lora_k_up = NULL; + ggml_tensor* lora_v_down = NULL; + ggml_tensor* lora_v_up = NULL; + + ggml_tensor* lora_m_down = NULL; + ggml_tensor* lora_m_up = NULL; + + lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]); + + if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) { + lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]); + } + + if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) { + lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]); + } + + if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) { + lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]); + } + + if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) { + lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]); + } + + if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) { + lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]); + } + + if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) { + lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]); + } + + if (lora_tensors.find(split_m_d_name) != lora_tensors.end()) { + lora_m_down = to_f32(compute_ctx, lora_tensors[split_m_d_name]); + } + + if (lora_tensors.find(split_m_u_name) != lora_tensors.end()) { + lora_m_up = to_f32(compute_ctx, lora_tensors[split_m_u_name]); + } + + float q_rank = lora_q_up->ne[0]; + float k_rank = lora_k_up->ne[0]; + float v_rank = lora_v_up->ne[0]; + float m_rank = lora_v_up->ne[0]; + + float lora_q_scale = 1; + float lora_k_scale = 1; + float lora_v_scale = 1; + float lora_m_scale = 1; + + if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) { + lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]); + applied_lora_tensors.insert(split_q_scale_name); + } + if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) { + lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]); + applied_lora_tensors.insert(split_k_scale_name); + } + if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) { + lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]); + applied_lora_tensors.insert(split_v_scale_name); + } + if (lora_tensors.find(split_m_scale_name) != lora_tensors.end()) { + lora_m_scale = ggml_backend_tensor_get_f32(lora_tensors[split_m_scale_name]); + applied_lora_tensors.insert(split_m_scale_name); + } + + if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) { + float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]); + applied_lora_tensors.insert(split_q_alpha_name); + lora_q_scale = lora_q_alpha / q_rank; + } + if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) { + float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]); + applied_lora_tensors.insert(split_k_alpha_name); + lora_k_scale = lora_k_alpha / k_rank; + } + if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) { + float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]); + applied_lora_tensors.insert(split_v_alpha_name); + lora_v_scale = lora_v_alpha / v_rank; + } + if (lora_tensors.find(split_m_alpha_name) != lora_tensors.end()) { + float lora_m_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_m_alpha_name]); + applied_lora_tensors.insert(split_m_alpha_name); + lora_m_scale = lora_m_alpha / m_rank; + } + + ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale); + ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale); + ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale); + ggml_scale_inplace(compute_ctx, lora_m_down, lora_m_scale); + + // print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_m_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1] + // print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1] + // print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1] + // print_ggml_tensor(lora_m_up, true); //[R, 12288, 1, 1] + + // these need to be stitched together this way: + // |q_up,0 ,0 ,0 | + // |0 ,k_up,0 ,0 | + // |0 ,0 ,v_up,0 | + // |0 ,0 ,0 ,m_up| + // (q_down,k_down,v_down,m_down) . (q ,k ,v ,m) + + // up_concat will be [21504, R*4, 1, 1] + // down_concat will be [R*4, 3072, 1, 1] + + ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), ggml_concat(compute_ctx, lora_v_down, lora_m_down, 1), 1); + // print_ggml_tensor(lora_down_concat, true); //[3072, R*4, 1, 1] + + // this also means that if rank is bigger than 672, it is less memory efficient to do it this way (should be fine) + // print_ggml_tensor(lora_q_up, true); //[3072, R, 1, 1] + ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up); + ggml_tensor* mlp_z = ggml_dup_tensor(compute_ctx, lora_m_up); + ggml_scale(compute_ctx, z, 0); + ggml_scale(compute_ctx, mlp_z, 0); + ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1); + + ggml_tensor* q_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_up, zz, 1), mlp_z, 1); + ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), ggml_concat(compute_ctx, z, mlp_z, 1), 1); + ggml_tensor* v_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, lora_v_up, 1), mlp_z, 1); + ggml_tensor* m_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, z, 1), lora_m_up, 1); + // print_ggml_tensor(q_up, true); //[R, 21504, 1, 1] + // print_ggml_tensor(k_up, true); //[R, 21504, 1, 1] + // print_ggml_tensor(v_up, true); //[R, 21504, 1, 1] + // print_ggml_tensor(m_up, true); //[R, 21504, 1, 1] + + ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), ggml_concat(compute_ctx, v_up, m_up, 0), 0); + // print_ggml_tensor(lora_up_concat, true); //[R*4, 21504, 1, 1] + + lora_down = ggml_cont(compute_ctx, lora_down_concat); + lora_up = ggml_cont(compute_ctx, lora_up_concat); + + applied_lora_tensors.insert(split_q_u_name); + applied_lora_tensors.insert(split_k_u_name); + applied_lora_tensors.insert(split_v_u_name); + applied_lora_tensors.insert(split_m_u_name); + + applied_lora_tensors.insert(split_q_d_name); + applied_lora_tensors.insert(split_k_d_name); + applied_lora_tensors.insert(split_v_d_name); + applied_lora_tensors.insert(split_m_d_name); + } + } else { + lora_up_name = fk + lora_ups[type] + ".weight"; + lora_down_name = fk + lora_downs[type] + ".weight"; + lora_mid_name = fk + ".lora_mid.weight"; + + alpha_name = fk + ".alpha"; + scale_name = fk + ".scale"; + + if (lora_tensors.find(lora_up_name) != lora_tensors.end()) { + lora_up = to_f32(compute_ctx, lora_tensors[lora_up_name]); + } + + if (lora_tensors.find(lora_down_name) != lora_tensors.end()) { + lora_down = to_f32(compute_ctx, lora_tensors[lora_down_name]); + } + + if (lora_tensors.find(lora_mid_name) != lora_tensors.end()) { + lora_mid = to_f32(compute_ctx, lora_tensors[lora_mid_name]); + applied_lora_tensors.insert(lora_mid_name); + } + + applied_lora_tensors.insert(lora_up_name); + applied_lora_tensors.insert(lora_down_name); + applied_lora_tensors.insert(alpha_name); + applied_lora_tensors.insert(scale_name); + } + + if (lora_up == NULL || lora_down == NULL) { + continue; + } + // calc_scale + // TODO: .dora_scale? + int64_t rank = lora_down->ne[ggml_n_dims(lora_down) - 1]; + if (lora_tensors.find(scale_name) != lora_tensors.end()) { + scale_value = ggml_backend_tensor_get_f32(lora_tensors[scale_name]); + } else if (lora_tensors.find(alpha_name) != lora_tensors.end()) { + float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); + scale_value = alpha / rank; + } + + updown = ggml_merge_lora(compute_ctx, lora_down, lora_up, lora_mid); + } + scale_value *= multiplier; + updown = ggml_reshape(compute_ctx, updown, weight); + GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight)); + updown = ggml_scale_inplace(compute_ctx, updown, scale_value); + ggml_tensor* final_weight; + if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) { + // final_weight = ggml_new_tensor(compute_ctx, GGML_TYPE_F32, ggml_n_dims(weight), weight->ne); + // final_weight = ggml_cpy(compute_ctx, weight, final_weight); + final_weight = to_f32(compute_ctx, weight); + final_weight = ggml_add_inplace(compute_ctx, final_weight, updown); + final_weight = ggml_cpy(compute_ctx, final_weight, weight); + } else { + final_weight = ggml_add_inplace(compute_ctx, weight, updown); + } + // final_weight = ggml_add_inplace(compute_ctx, weight, updown); // apply directly + ggml_build_forward_expand(gf, final_weight); + break; } - // final_weight = ggml_add_inplace(compute_ctx, weight, updown); // apply directly - ggml_build_forward_expand(gf, final_weight); } - size_t total_lora_tensors_count = 0; size_t applied_lora_tensors_count = 0; for (auto& kv : lora_tensors) { total_lora_tensors_count++; if (applied_lora_tensors.find(kv.first) == applied_lora_tensors.end()) { - LOG_WARN("unused lora tensor %s", kv.first.c_str()); + LOG_WARN("unused lora tensor |%s|", kv.first.c_str()); + print_ggml_tensor(kv.second, true); + // exit(0); } else { applied_lora_tensors_count++; } @@ -192,9 +835,9 @@ struct LoraModel : public GGMLRunner { return gf; } - void apply(std::map model_tensors, int n_threads) { + void apply(std::map model_tensors, SDVersion version, int n_threads) { auto get_graph = [&]() -> struct ggml_cgraph* { - return build_lora_graph(model_tensors); + return build_lora_graph(model_tensors, version); }; GGMLRunner::compute(get_graph, n_threads, true); } diff --git a/mmdit.hpp b/mmdit.hpp index 6f3a8a068..dee7b1c49 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -142,29 +142,78 @@ struct VectorEmbedder : public GGMLBlock { } }; +class RMSNorm : public UnaryBlock { +protected: + int64_t hidden_size; + float eps; + + void init_params(struct ggml_context* ctx, std::map& tensor_types, std::string prefix = "") { + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32; + params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size); + } + +public: + RMSNorm(int64_t hidden_size, + float eps = 1e-06f) + : hidden_size(hidden_size), + eps(eps) {} + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* w = params["weight"]; + x = ggml_rms_norm(ctx, x, eps); + x = ggml_mul(ctx, x, w); + return x; + } +}; + class SelfAttention : public GGMLBlock { public: int64_t num_heads; bool pre_only; + std::string qk_norm; public: SelfAttention(int64_t dim, - int64_t num_heads = 8, - bool qkv_bias = false, - bool pre_only = false) - : num_heads(num_heads), pre_only(pre_only) { - // qk_norm is always None - blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); + int64_t num_heads = 8, + std::string qk_norm = "", + bool qkv_bias = false, + bool pre_only = false) + : num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) { + int64_t d_head = dim / num_heads; + blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); if (!pre_only) { blocks["proj"] = std::shared_ptr(new Linear(dim, dim)); } + if (qk_norm == "rms") { + blocks["ln_q"] = std::shared_ptr(new RMSNorm(d_head, 1.0e-6)); + blocks["ln_k"] = std::shared_ptr(new RMSNorm(d_head, 1.0e-6)); + } else if (qk_norm == "ln") { + blocks["ln_q"] = std::shared_ptr(new LayerNorm(d_head, 1.0e-6)); + blocks["ln_k"] = std::shared_ptr(new LayerNorm(d_head, 1.0e-6)); + } } std::vector pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) { auto qkv_proj = std::dynamic_pointer_cast(blocks["qkv"]); - auto qkv = qkv_proj->forward(ctx, x); - return split_qkv(ctx, qkv); + auto qkv = qkv_proj->forward(ctx, x); + auto qkv_vec = split_qkv(ctx, qkv); + int64_t head_dim = qkv_vec[0]->ne[0] / num_heads; + auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head] + auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head] + auto v = qkv_vec[2]; // [N, n_token, n_head*d_head] + + if (qk_norm == "rms" || qk_norm == "ln") { + auto ln_q = std::dynamic_pointer_cast(blocks["ln_q"]); + auto ln_k = std::dynamic_pointer_cast(blocks["ln_k"]); + q = ln_q->forward(ctx, q); + k = ln_k->forward(ctx, k); + } + + q = ggml_reshape_3d(ctx, q, q->ne[0] * q->ne[1], q->ne[2], q->ne[3]); // [N, n_token, n_head*d_head] + k = ggml_reshape_3d(ctx, k, k->ne[0] * k->ne[1], k->ne[2], k->ne[3]); // [N, n_token, n_head*d_head] + + return {q, k, v}; } struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) { @@ -204,20 +253,26 @@ struct DismantledBlock : public GGMLBlock { public: int64_t num_heads; bool pre_only; + bool self_attn; public: DismantledBlock(int64_t hidden_size, int64_t num_heads, - float mlp_ratio = 4.0, - bool qkv_bias = false, - bool pre_only = false) - : num_heads(num_heads), pre_only(pre_only) { + float mlp_ratio = 4.0, + std::string qk_norm = "", + bool qkv_bias = false, + bool pre_only = false, + bool self_attn = false) + : num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) { // rmsnorm is always Flase // scale_mod_only is always Flase // swiglu is always Flase - // qk_norm is always Flase blocks["norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); - blocks["attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, pre_only)); + blocks["attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only)); + + if (self_attn) { + blocks["attn2"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false)); + } if (!pre_only) { blocks["norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); @@ -229,9 +284,52 @@ struct DismantledBlock : public GGMLBlock { if (pre_only) { n_mods = 2; } + if (self_attn) { + n_mods = 9; + } blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, n_mods * hidden_size)); } + std::tuple, std::vector, std::vector> pre_attention_x(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* c) { + GGML_ASSERT(self_attn); + // x: [N, n_token, hidden_size] + // c: [N, hidden_size] + auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); + auto attn = std::dynamic_pointer_cast(blocks["attn"]); + auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); + auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); + + int64_t n_mods = 9; + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, n_mods * hidden_size] + m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size] + m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size] + + int64_t offset = m->nb[1] * m->ne[1]; + auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] + auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + auto gate_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size] + + auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size] + auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size] + auto gate_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size] + + auto shift_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size] + auto scale_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size] + auto gate_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size] + + auto x_norm = norm1->forward(ctx, x); + + auto attn_in = modulate(ctx, x_norm, shift_msa, scale_msa); + auto qkv = attn->pre_attention(ctx, attn_in); + + auto attn2_in = modulate(ctx, x_norm, shift_msa2, scale_msa2); + auto qkv2 = attn2->pre_attention(ctx, attn2_in); + + return {qkv, qkv2, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2}}; + } + std::pair, std::vector> pre_attention(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* c) { @@ -271,6 +369,44 @@ struct DismantledBlock : public GGMLBlock { } } + struct ggml_tensor* post_attention_x(struct ggml_context* ctx, + struct ggml_tensor* attn_out, + struct ggml_tensor* attn2_out, + struct ggml_tensor* x, + struct ggml_tensor* gate_msa, + struct ggml_tensor* shift_mlp, + struct ggml_tensor* scale_mlp, + struct ggml_tensor* gate_mlp, + struct ggml_tensor* gate_msa2) { + // attn_out: [N, n_token, hidden_size] + // x: [N, n_token, hidden_size] + // gate_msa: [N, hidden_size] + // shift_mlp: [N, hidden_size] + // scale_mlp: [N, hidden_size] + // gate_mlp: [N, hidden_size] + // return: [N, n_token, hidden_size] + GGML_ASSERT(!pre_only); + + auto attn = std::dynamic_pointer_cast(blocks["attn"]); + auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); + auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); + auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); + + gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size] + gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size] + gate_msa2 = ggml_reshape_3d(ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]); // [N, 1, hidden_size] + + attn_out = attn->post_attention(ctx, attn_out); + attn2_out = attn2->post_attention(ctx, attn2_out); + + x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa)); + x = ggml_add(ctx, x, ggml_mul(ctx, attn2_out, gate_msa2)); + auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp)); + x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp)); + + return x; + } + struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* attn_out, struct ggml_tensor* x, @@ -309,29 +445,52 @@ struct DismantledBlock : public GGMLBlock { // return: [N, n_token, hidden_size] auto attn = std::dynamic_pointer_cast(blocks["attn"]); - - auto qkv_intermediates = pre_attention(ctx, x, c); - auto qkv = qkv_intermediates.first; - auto intermediates = qkv_intermediates.second; - - auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim] - x = post_attention(ctx, - attn_out, - intermediates[0], - intermediates[1], - intermediates[2], - intermediates[3], - intermediates[4]); - return x; // [N, n_token, dim] + if (self_attn) { + auto qkv_intermediates = pre_attention_x(ctx, x, c); + // auto qkv = qkv_intermediates.first; + // auto intermediates = qkv_intermediates.second; + // no longer a pair, but a tuple + auto qkv = std::get<0>(qkv_intermediates); + auto qkv2 = std::get<1>(qkv_intermediates); + auto intermediates = std::get<2>(qkv_intermediates); + + auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim] + auto attn2_out = ggml_nn_attention_ext(ctx, qkv2[0], qkv2[1], qkv2[2], num_heads); // [N, n_token, dim] + x = post_attention_x(ctx, + attn_out, + attn2_out, + intermediates[0], + intermediates[1], + intermediates[2], + intermediates[3], + intermediates[4], + intermediates[5]); + return x; // [N, n_token, dim] + } else { + auto qkv_intermediates = pre_attention(ctx, x, c); + auto qkv = qkv_intermediates.first; + auto intermediates = qkv_intermediates.second; + + auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim] + x = post_attention(ctx, + attn_out, + intermediates[0], + intermediates[1], + intermediates[2], + intermediates[3], + intermediates[4]); + return x; // [N, n_token, dim] + } } }; -__STATIC_INLINE__ std::pair block_mixing(struct ggml_context* ctx, - struct ggml_tensor* context, - struct ggml_tensor* x, - struct ggml_tensor* c, - std::shared_ptr context_block, - std::shared_ptr x_block) { +__STATIC_INLINE__ std::pair +block_mixing(struct ggml_context* ctx, + struct ggml_tensor* context, + struct ggml_tensor* x, + struct ggml_tensor* c, + std::shared_ptr context_block, + std::shared_ptr x_block) { // context: [N, n_context, hidden_size] // x: [N, n_token, hidden_size] // c: [N, hidden_size] @@ -339,10 +498,18 @@ __STATIC_INLINE__ std::pair block_mixi auto context_qkv = context_qkv_intermediates.first; auto context_intermediates = context_qkv_intermediates.second; - auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c); - auto x_qkv = x_qkv_intermediates.first; - auto x_intermediates = x_qkv_intermediates.second; + std::vector x_qkv, x_qkv2, x_intermediates; + if (x_block->self_attn) { + auto x_qkv_intermediates = x_block->pre_attention_x(ctx, x, c); + x_qkv = std::get<0>(x_qkv_intermediates); + x_qkv2 = std::get<1>(x_qkv_intermediates); + x_intermediates = std::get<2>(x_qkv_intermediates); + } else { + auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c); + x_qkv = x_qkv_intermediates.first; + x_intermediates = x_qkv_intermediates.second; + } std::vector qkv; for (int i = 0; i < 3; i++) { qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1)); @@ -381,13 +548,27 @@ __STATIC_INLINE__ std::pair block_mixi context = NULL; } - x = x_block->post_attention(ctx, - x_attn, - x_intermediates[0], - x_intermediates[1], - x_intermediates[2], - x_intermediates[3], - x_intermediates[4]); + if (x_block->self_attn) { + auto attn2 = ggml_nn_attention_ext(ctx, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads); // [N, n_token, hidden_size] + + x = x_block->post_attention_x(ctx, + x_attn, + attn2, + x_intermediates[0], + x_intermediates[1], + x_intermediates[2], + x_intermediates[3], + x_intermediates[4], + x_intermediates[5]); + } else { + x = x_block->post_attention(ctx, + x_attn, + x_intermediates[0], + x_intermediates[1], + x_intermediates[2], + x_intermediates[3], + x_intermediates[4]); + } return {context, x}; } @@ -396,12 +577,13 @@ struct JointBlock : public GGMLBlock { public: JointBlock(int64_t hidden_size, int64_t num_heads, - float mlp_ratio = 4.0, - bool qkv_bias = false, - bool pre_only = false) { - // qk_norm is always Flase - blocks["context_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qkv_bias, pre_only)); - blocks["x_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qkv_bias, false)); + float mlp_ratio = 4.0, + std::string qk_norm = "", + bool qkv_bias = false, + bool pre_only = false, + bool self_attn_x = false) { + blocks["context_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only)); + blocks["x_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x)); } std::pair forward(struct ggml_context* ctx, @@ -455,52 +637,77 @@ struct FinalLayer : public GGMLBlock { struct MMDiT : public GGMLBlock { // Diffusion model with a Transformer backbone. protected: - SDVersion version = VERSION_SD3_2B; - int64_t input_size = -1; - int64_t patch_size = 2; - int64_t in_channels = 16; - int64_t depth = 24; - float mlp_ratio = 4.0f; - int64_t adm_in_channels = 2048; - int64_t out_channels = 16; - int64_t pos_embed_max_size = 192; - int64_t num_patchs = 36864; // 192 * 192 - int64_t context_size = 4096; + int64_t input_size = -1; + int64_t patch_size = 2; + int64_t in_channels = 16; + int64_t d_self = -1; // >=0 for MMdiT-X + int64_t depth = 24; + float mlp_ratio = 4.0f; + int64_t adm_in_channels = 2048; + int64_t out_channels = 16; + int64_t pos_embed_max_size = 192; + int64_t num_patchs = 36864; // 192 * 192 + int64_t context_size = 4096; + int64_t context_embedder_out_dim = 1536; int64_t hidden_size; + std::string qk_norm; - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["pos_embed"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hidden_size, num_patchs, 1); + void init_params(struct ggml_context* ctx, std::map& tensor_types, std::string prefix = "") { + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "pos_embed") != tensor_types.end()) ? tensor_types[prefix + "pos_embed"] : GGML_TYPE_F32; + params["pos_embed"] = ggml_new_tensor_3d(ctx, wtype, hidden_size, num_patchs, 1); } public: - MMDiT(SDVersion version = VERSION_SD3_2B) - : version(version) { + MMDiT(std::map& tensor_types) { // input_size is always None // learn_sigma is always False // register_length is alwalys 0 // rmsnorm is alwalys False // scale_mod_only is alwalys False // swiglu is alwalys False - // qk_norm is always None // qkv_bias is always True // context_processor_layers is always None // pos_embed_scaling_factor is not used // pos_embed_offset is not used // context_embedder_config is always {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}} - if (version == VERSION_SD3_2B) { - input_size = -1; - patch_size = 2; - in_channels = 16; - depth = 24; - mlp_ratio = 4.0f; - adm_in_channels = 2048; - out_channels = 16; - pos_embed_max_size = 192; - num_patchs = 36864; // 192 * 192 - context_size = 4096; + + // read tensors from tensor_types + for (auto pair : tensor_types) { + std::string tensor_name = pair.first; + if (tensor_name.find("model.diffusion_model.") == std::string::npos) + continue; + size_t jb = tensor_name.find("joint_blocks."); + if (jb != std::string::npos) { + tensor_name = tensor_name.substr(jb); // remove prefix + int block_depth = atoi(tensor_name.substr(13, tensor_name.find(".", 13)).c_str()); + if (block_depth + 1 > depth) { + depth = block_depth + 1; + } + if (tensor_name.find("attn.ln") != std::string::npos) { + if (tensor_name.find(".bias") != std::string::npos) { + qk_norm = "ln"; + } else { + qk_norm = "rms"; + } + } + if (tensor_name.find("attn2") != std::string::npos) { + if (block_depth > d_self) { + d_self = block_depth; + } + } + } + } + + if (d_self >= 0) { + pos_embed_max_size *= 2; + num_patchs *= 4; } + + LOG_INFO("MMDiT layers: %d (including %d MMDiT-x layers)", depth, d_self + 1); + int64_t default_out_channels = in_channels; hidden_size = 64 * depth; + context_embedder_out_dim = 64 * depth; int64_t num_heads = depth; blocks["x_embedder"] = std::shared_ptr(new PatchEmbed(input_size, patch_size, in_channels, hidden_size, true)); @@ -510,22 +717,25 @@ struct MMDiT : public GGMLBlock { blocks["y_embedder"] = std::shared_ptr(new VectorEmbedder(adm_in_channels, hidden_size)); } - blocks["context_embedder"] = std::shared_ptr(new Linear(4096, 1536, true, true)); + blocks["context_embedder"] = std::shared_ptr(new Linear(4096, context_embedder_out_dim, true, true)); for (int i = 0; i < depth; i++) { blocks["joint_blocks." + std::to_string(i)] = std::shared_ptr(new JointBlock(hidden_size, num_heads, mlp_ratio, + qk_norm, true, - i == depth - 1)); + i == depth - 1, + i <= d_self)); } blocks["final_layer"] = std::shared_ptr(new FinalLayer(hidden_size, patch_size, out_channels)); } - struct ggml_tensor* cropped_pos_embed(struct ggml_context* ctx, - int64_t h, - int64_t w) { + struct ggml_tensor* + cropped_pos_embed(struct ggml_context* ctx, + int64_t h, + int64_t w) { auto pos_embed = params["pos_embed"]; h = (h + 1) / patch_size; @@ -587,7 +797,8 @@ struct MMDiT : public GGMLBlock { struct ggml_tensor* forward_core_with_concat(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* c_mod, - struct ggml_tensor* context) { + struct ggml_tensor* context, + std::vector skip_layers = std::vector()) { // x: [N, H*W, hidden_size] // context: [N, n_context, d_context] // c: [N, hidden_size] @@ -595,6 +806,11 @@ struct MMDiT : public GGMLBlock { auto final_layer = std::dynamic_pointer_cast(blocks["final_layer"]); for (int i = 0; i < depth; i++) { + // skip iteration if i is in skip_layers + if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) { + continue; + } + auto block = std::dynamic_pointer_cast(blocks["joint_blocks." + std::to_string(i)]); auto context_x = block->forward(ctx, context, x, c_mod); @@ -610,8 +826,9 @@ struct MMDiT : public GGMLBlock { struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* t, - struct ggml_tensor* y = NULL, - struct ggml_tensor* context = NULL) { + struct ggml_tensor* y = NULL, + struct ggml_tensor* context = NULL, + std::vector skip_layers = std::vector()) { // Forward pass of DiT. // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) // t: (N,) tensor of diffusion timesteps @@ -642,22 +859,23 @@ struct MMDiT : public GGMLBlock { context = context_embedder->forward(ctx, context); // [N, L, D] aka [N, L, 1536] } - x = forward_core_with_concat(ctx, x, c, context); // (N, H*W, patch_size ** 2 * out_channels) + x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels) x = unpatchify(ctx, x, h, w); // [N, C, H, W] return x; } }; - struct MMDiTRunner : public GGMLRunner { MMDiT mmdit; + static std::map empty_tensor_types; + MMDiTRunner(ggml_backend_t backend, - ggml_type wtype, - SDVersion version = VERSION_SD3_2B) - : GGMLRunner(backend, wtype), mmdit(version) { - mmdit.init(params_ctx, wtype); + std::map& tensor_types = empty_tensor_types, + const std::string prefix = "") + : GGMLRunner(backend), mmdit(tensor_types) { + mmdit.init(params_ctx, tensor_types, prefix); } std::string get_desc() { @@ -671,7 +889,8 @@ struct MMDiTRunner : public GGMLRunner { struct ggml_cgraph* build_graph(struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, - struct ggml_tensor* y) { + struct ggml_tensor* y, + std::vector skip_layers = std::vector()) { struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, MMDIT_GRAPH_SIZE, false); x = to_backend(x); @@ -683,7 +902,8 @@ struct MMDiTRunner : public GGMLRunner { x, timesteps, y, - context); + context, + skip_layers); ggml_build_forward_expand(gf, out); @@ -696,13 +916,14 @@ struct MMDiTRunner : public GGMLRunner { struct ggml_tensor* context, struct ggml_tensor* y, struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { // x: [N, in_channels, h, w] // timesteps: [N, ] // context: [N, max_position, hidden_size]([N, 154, 4096]) or [1, max_position, hidden_size] // y: [N, adm_in_channels] or [1, adm_in_channels] auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context, y); + return build_graph(x, timesteps, context, y, skip_layers); }; GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); @@ -751,7 +972,7 @@ struct MMDiTRunner : public GGMLRunner { // ggml_backend_t backend = ggml_backend_cuda_init(0); ggml_backend_t backend = ggml_backend_cpu_init(); ggml_type model_data_type = GGML_TYPE_F16; - std::shared_ptr mmdit = std::shared_ptr(new MMDiTRunner(backend, model_data_type)); + std::shared_ptr mmdit = std::shared_ptr(new MMDiTRunner(backend)); { LOG_INFO("loading from '%s'", file_path.c_str()); diff --git a/model.cpp b/model.cpp index b74a735f8..24da39f6d 100644 --- a/model.cpp +++ b/model.cpp @@ -13,6 +13,7 @@ #include "ggml-alloc.h" #include "ggml-backend.h" +#include "ggml-cpu.h" #include "ggml.h" #include "stable-diffusion.h" @@ -146,6 +147,33 @@ std::unordered_map vae_decoder_name_map = { {"first_stage_model.decoder.mid.attn_1.to_v.weight", "first_stage_model.decoder.mid.attn_1.v.weight"}, }; +std::unordered_map pmid_v2_name_map = { + {"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc1.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.3.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc2.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc1.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.3.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc2.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc1.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.3.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc2.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc1.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.3.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc2.weight"}, + {"pmid.qformer_perceiver.token_proj.0.bias", + "pmid.qformer_perceiver.token_proj.fc1.bias"}, + {"pmid.qformer_perceiver.token_proj.2.bias", + "pmid.qformer_perceiver.token_proj.fc2.bias"}, + {"pmid.qformer_perceiver.token_proj.0.weight", + "pmid.qformer_perceiver.token_proj.fc1.weight"}, + {"pmid.qformer_perceiver.token_proj.2.weight", + "pmid.qformer_perceiver.token_proj.fc2.weight"}, +}; + std::string convert_open_clip_to_hf_clip(const std::string& name) { std::string new_name = name; std::string prefix; @@ -212,6 +240,13 @@ std::string convert_vae_decoder_name(const std::string& name) { return name; } +std::string convert_pmid_v2_name(const std::string& name) { + if (pmid_v2_name_map.find(name) != pmid_v2_name_map.end()) { + return pmid_v2_name_map[name]; + } + return name; +} + /* If not a SDXL LoRA the unet" prefix will have already been replaced by this * point and "te2" and "te1" don't seem to appear in non-SDXL only "te_" */ std::string convert_sdxl_lora_name(std::string tensor_name) { @@ -430,11 +465,21 @@ std::string convert_tensor_name(std::string name) { if (starts_with(name, "diffusion_model")) { name = "model." + name; } + // size_t pos = name.find("lora_A"); + // if (pos != std::string::npos) { + // name.replace(pos, strlen("lora_A"), "lora_up"); + // } + // pos = name.find("lora_B"); + // if (pos != std::string::npos) { + // name.replace(pos, strlen("lora_B"), "lora_down"); + // } std::string new_name = name; if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || starts_with(name, "text_encoders.") || ends_with(name, ".vision_model.visual_projection.weight")) { new_name = convert_open_clip_to_hf_clip(name); } else if (starts_with(name, "first_stage_model.decoder")) { new_name = convert_vae_decoder_name(name); + } else if (starts_with(name, "pmid.qformer_perceiver")) { + new_name = convert_pmid_v2_name(name); } else if (starts_with(name, "control_model.")) { // for controlnet pth models size_t pos = name.find('.'); if (pos != std::string::npos) { @@ -466,6 +511,9 @@ std::string convert_tensor_name(std::string name) { if (pos != std::string::npos) { new_name.replace(pos, strlen(".processor"), ""); } + // if (starts_with(new_name, "transformer.transformer_blocks") || starts_with(new_name, "transformer.single_transformer_blocks")) { + // new_name = "model.diffusion_model." + new_name; + // } pos = new_name.rfind("lora"); if (pos != std::string::npos) { std::string name_without_network_parts = new_name.substr(0, pos - 1); @@ -510,6 +558,26 @@ std::string convert_tensor_name(std::string name) { return new_name; } +void add_preprocess_tensor_storage_types(std::map& tensor_storages_types, std::string name, enum ggml_type type) { + std::string new_name = convert_tensor_name(name); + + if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_weight")) { + size_t prefix_size = new_name.find("attn.in_proj_weight"); + std::string prefix = new_name.substr(0, prefix_size); + tensor_storages_types[prefix + "self_attn.q_proj.weight"] = type; + tensor_storages_types[prefix + "self_attn.k_proj.weight"] = type; + tensor_storages_types[prefix + "self_attn.v_proj.weight"] = type; + } else if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_bias")) { + size_t prefix_size = new_name.find("attn.in_proj_bias"); + std::string prefix = new_name.substr(0, prefix_size); + tensor_storages_types[prefix + "self_attn.q_proj.bias"] = type; + tensor_storages_types[prefix + "self_attn.k_proj.bias"] = type; + tensor_storages_types[prefix + "self_attn.v_proj.bias"] = type; + } else { + tensor_storages_types[new_name] = type; + } +} + void preprocess_tensor(TensorStorage tensor_storage, std::vector& processed_tensor_storages) { std::vector result; @@ -603,6 +671,47 @@ uint16_t f8_e4m3_to_f16(uint8_t f8) { return ggml_fp32_to_fp16(*reinterpret_cast(&result)); } +uint16_t f8_e5m2_to_f16(uint8_t fp8) { + uint8_t sign = (fp8 >> 7) & 0x1; + uint8_t exponent = (fp8 >> 2) & 0x1F; + uint8_t mantissa = fp8 & 0x3; + + uint16_t fp16_sign = sign << 15; + uint16_t fp16_exponent; + uint16_t fp16_mantissa; + + if (exponent == 0 && mantissa == 0) { // zero + return fp16_sign; + } + + if (exponent == 0x1F) { // NAN and INF + fp16_exponent = 0x1F; + fp16_mantissa = mantissa ? (mantissa << 8) : 0; + return fp16_sign | (fp16_exponent << 10) | fp16_mantissa; + } + + if (exponent == 0) { // subnormal numbers + fp16_exponent = 0; + fp16_mantissa = (mantissa << 8); + return fp16_sign | fp16_mantissa; + } + + // normal numbers + int16_t true_exponent = (int16_t)exponent - 15 + 15; + if (true_exponent <= 0) { + fp16_exponent = 0; + fp16_mantissa = (mantissa << 8); + } else if (true_exponent >= 0x1F) { + fp16_exponent = 0x1F; + fp16_mantissa = 0; + } else { + fp16_exponent = (uint16_t)true_exponent; + fp16_mantissa = mantissa << 8; + } + + return fp16_sign | (fp16_exponent << 10) | fp16_mantissa; +} + void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) { // support inplace op for (int64_t i = n - 1; i >= 0; i--) { @@ -616,6 +725,12 @@ void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) { dst[i] = f8_e4m3_to_f16(src[i]); } } +void f8_e5m2_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) { + // support inplace op + for (int64_t i = n - 1; i >= 0; i--) { + dst[i] = f8_e5m2_to_f16(src[i]); + } +} void convert_tensor(void* src, ggml_type src_type, @@ -639,25 +754,25 @@ void convert_tensor(void* src, if (src_type == GGML_TYPE_F16) { ggml_fp16_to_fp32_row((ggml_fp16_t*)src, (float*)dst, n); } else { - auto qtype = ggml_internal_get_type_traits(src_type); - if (qtype.to_float == NULL) { + auto qtype = ggml_get_type_traits(src_type); + if (qtype->to_float == NULL) { throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(src_type))); } - qtype.to_float(src, (float*)dst, n); + qtype->to_float(src, (float*)dst, n); } } else { // src_type == GGML_TYPE_F16 => dst_type is quantized // src_type is quantized => dst_type == GGML_TYPE_F16 or dst_type is quantized - auto qtype = ggml_internal_get_type_traits(src_type); - if (qtype.to_float == NULL) { + auto qtype = ggml_get_type_traits(src_type); + if (qtype->to_float == NULL) { throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(src_type))); } std::vector buf; buf.resize(sizeof(float) * n); char* src_data_f32 = buf.data(); - qtype.to_float(src, (float*)src_data_f32, n); + qtype->to_float(src, (float*)src_data_f32, n); if (dst_type == GGML_TYPE_F16) { ggml_fp32_to_fp16_row((float*)src_data_f32, (ggml_fp16_t*)dst, n); } else { @@ -832,6 +947,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes()); tensor_storages.push_back(tensor_storage); + add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type); } gguf_free(ctx_gguf_); @@ -852,6 +968,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) { ttype = GGML_TYPE_F32; } else if (dtype == "F8_E4M3") { ttype = GGML_TYPE_F16; + } else if (dtype == "F8_E5M2") { + ttype = GGML_TYPE_F16; } return ttype; } @@ -965,11 +1083,16 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const tensor_storage.is_f8_e4m3 = true; // f8 -> f16 GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2); + } else if (dtype == "F8_E5M2") { + tensor_storage.is_f8_e5m2 = true; + // f8 -> f16 + GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2); } else { GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size); } tensor_storages.push_back(tensor_storage); + add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type); // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str()); } @@ -1195,7 +1318,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer, zip_t* zip, std::string dir, size_t file_index, - const std::string& prefix) { + const std::string prefix) { uint8_t* buffer_end = buffer + buffer_size; if (buffer[0] == 0x80) { // proto if (buffer[1] != 2) { @@ -1297,9 +1420,11 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer, reader.tensor_storage.reverse_ne(); reader.tensor_storage.file_index = file_index; // if(strcmp(prefix.c_str(), "scarlett") == 0) - // printf(" got tensor %s \n ", reader.tensor_storage.name.c_str()); + // printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str()); reader.tensor_storage.name = prefix + reader.tensor_storage.name; tensor_storages.push_back(reader.tensor_storage); + add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type); + // LOG_DEBUG("%s", reader.tensor_storage.name.c_str()); // reset reader = PickleTensorReader(); @@ -1334,7 +1459,8 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s size_t pos = name.find("data.pkl"); if (pos != std::string::npos) { std::string dir = name.substr(0, pos); - void* pkl_data = NULL; + printf("ZIP %d, name = %s, dir = %s \n", i, name.c_str(), dir.c_str()); + void* pkl_data = NULL; size_t pkl_size; zip_entry_read(zip, &pkl_data, &pkl_size); @@ -1352,28 +1478,49 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s } SDVersion ModelLoader::get_sd_version() { - TensorStorage token_embedding_weight; + TensorStorage token_embedding_weight, input_block_weight; + bool input_block_checked = false; + + bool has_multiple_encoders = false; + bool is_unet = false; + + bool is_xl = false; bool is_flux = false; + +#define found_family (is_xl || is_flux) for (auto& tensor_storage : tensor_storages) { - if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) { - return VERSION_FLUX_DEV; - } - if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { - is_flux = true; - } - if (tensor_storage.name.find("model.diffusion_model.joint_blocks.23.") != std::string::npos) { - return VERSION_SD3_2B; - } - if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) { - return VERSION_SDXL; - } - if (tensor_storage.name.find("cond_stage_model.1") != std::string::npos) { - return VERSION_SDXL; - } - if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) { - return VERSION_SVD; + if (!found_family) { + if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { + is_flux = true; + if (input_block_checked) { + break; + } + } + if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) { + return VERSION_SD3; + } + if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) { + is_unet = true; + if (has_multiple_encoders) { + is_xl = true; + if (input_block_checked) { + break; + } + } + } + if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) { + has_multiple_encoders = true; + if (is_unet) { + is_xl = true; + if (input_block_checked) { + break; + } + } + } + if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) { + return VERSION_SVD; + } } - if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" || tensor_storage.name == "cond_stage_model.model.token_embedding.weight" || tensor_storage.name == "text_model.embeddings.token_embedding.weight" || @@ -1383,13 +1530,39 @@ SDVersion ModelLoader::get_sd_version() { token_embedding_weight = tensor_storage; // break; } + if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight") { + input_block_weight = tensor_storage; + input_block_checked = true; + if (found_family) { + break; + } + } + } + bool is_inpaint = input_block_weight.ne[2] == 9; + if (is_xl) { + if (is_inpaint) { + return VERSION_SDXL_INPAINT; + } + return VERSION_SDXL; } + if (is_flux) { - return VERSION_FLUX_SCHNELL; + is_inpaint = input_block_weight.ne[0] == 384; + if (is_inpaint) { + return VERSION_FLUX_FILL; + } + return VERSION_FLUX; } + if (token_embedding_weight.ne[0] == 768) { + if (is_inpaint) { + return VERSION_SD1_INPAINT; + } return VERSION_SD1; } else if (token_embedding_weight.ne[0] == 1024) { + if (is_inpaint) { + return VERSION_SD2_INPAINT; + } return VERSION_SD2; } return VERSION_COUNT; @@ -1479,6 +1652,30 @@ ggml_type ModelLoader::get_vae_wtype() { return GGML_TYPE_COUNT; } +void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) { + for (auto& pair : tensor_storages_types) { + if (prefix.size() < 1 || pair.first.substr(0, prefix.size()) == prefix) { + bool found = false; + for (auto& tensor_storage : tensor_storages) { + std::map temp; + add_preprocess_tensor_storage_types(temp, tensor_storage.name, tensor_storage.type); + for (auto& preprocessed_name : temp) { + if (preprocessed_name.first == pair.first) { + if (tensor_should_be_converted(tensor_storage, wtype)) { + pair.second = wtype; + } + found = true; + break; + } + } + if (found) { + break; + } + } + } + } +} + std::string ModelLoader::load_merges() { std::string merges_utf8_str(reinterpret_cast(merges_utf8_c_str), sizeof(merges_utf8_c_str)); return merges_utf8_str; @@ -1580,9 +1777,11 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend } return true; }; - + int tensor_count = 0; + int64_t t1 = ggml_time_ms(); for (auto& tensor_storage : processed_tensor_storages) { if (tensor_storage.file_index != file_index) { + ++tensor_count; continue; } ggml_tensor* dst_tensor = NULL; @@ -1594,6 +1793,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend } if (dst_tensor == NULL) { + ++tensor_count; continue; } @@ -1611,6 +1811,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend } else if (tensor_storage.is_f8_e4m3) { // inplace op f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e5m2) { + // inplace op + f8_e5m2_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements()); } } else { read_buffer.resize(tensor_storage.nbytes()); @@ -1622,6 +1825,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend } else if (tensor_storage.is_f8_e4m3) { // inplace op f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e5m2) { + // inplace op + f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); } convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data, @@ -1637,6 +1843,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend } else if (tensor_storage.is_f8_e4m3) { // inplace op f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e5m2) { + // inplace op + f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); } if (tensor_storage.type == dst_tensor->type) { @@ -1651,6 +1860,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor)); } } + int64_t t2 = ggml_time_ms(); + pretty_progress(++tensor_count, processed_tensor_storages.size(), (t2 - t1) / 1000.0f); + t1 = t2; } if (zip != NULL) { @@ -1717,9 +1929,6 @@ bool ModelLoader::load_tensors(std::map& tenso if (pair.first.find("cond_stage_model.transformer.text_model.encoder.layers.23") != std::string::npos) { continue; } - if (pair.first.find("alphas_cumprod") != std::string::npos) { - continue; - } if (pair.first.find("alphas_cumprod") != std::string::npos) { continue; diff --git a/model.h b/model.h index 33f3fbcd3..d7f976533 100644 --- a/model.h +++ b/model.h @@ -14,25 +14,84 @@ #include "ggml.h" #include "json.hpp" #include "zip.h" +#include "gguf.h" #define SD_MAX_DIMS 5 enum SDVersion { VERSION_SD1, + VERSION_SD1_INPAINT, VERSION_SD2, + VERSION_SD2_INPAINT, VERSION_SDXL, + VERSION_SDXL_INPAINT, VERSION_SVD, - VERSION_SD3_2B, - VERSION_FLUX_DEV, - VERSION_FLUX_SCHNELL, + VERSION_SD3, + VERSION_FLUX, + VERSION_FLUX_FILL, VERSION_COUNT, }; +static inline bool sd_version_is_flux(SDVersion version) { + if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) { + return true; + } + return false; +} + +static inline bool sd_version_is_sd3(SDVersion version) { + if (version == VERSION_SD3) { + return true; + } + return false; +} + +static inline bool sd_version_is_sd1(SDVersion version) { + if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT) { + return true; + } + return false; +} + +static inline bool sd_version_is_sd2(SDVersion version) { + if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT) { + return true; + } + return false; +} + +static inline bool sd_version_is_sdxl(SDVersion version) { + if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT) { + return true; + } + return false; +} + +static inline bool sd_version_is_inpaint(SDVersion version) { + if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) { + return true; + } + return false; +} + +static inline bool sd_version_is_dit(SDVersion version) { + if (sd_version_is_flux(version) || sd_version_is_sd3(version)) { + return true; + } + return false; +} + +enum PMVersion { + PM_VERSION_1, + PM_VERSION_2, +}; + struct TensorStorage { std::string name; ggml_type type = GGML_TYPE_F32; bool is_bf16 = false; bool is_f8_e4m3 = false; + bool is_f8_e5m2 = false; int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; int n_dims = 0; @@ -62,7 +121,7 @@ struct TensorStorage { } int64_t nbytes_to_read() const { - if (is_bf16 || is_f8_e4m3) { + if (is_bf16 || is_f8_e4m3 || is_f8_e5m2) { return nbytes() / 2; } else { return nbytes(); @@ -112,6 +171,8 @@ struct TensorStorage { type_name = "bf16"; } else if (is_f8_e4m3) { type_name = "f8_e4m3"; + } else if (is_f8_e5m2) { + type_name = "f8_e5m2"; } ss << name << " | " << type_name << " | "; ss << n_dims << " ["; @@ -138,7 +199,7 @@ class ModelLoader { zip_t* zip, std::string dir, size_t file_index, - const std::string& prefix); + const std::string prefix); bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = ""); @@ -146,16 +207,20 @@ class ModelLoader { bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = ""); public: + std::map tensor_storages_types; + bool init_from_file(const std::string& file_path, const std::string& prefix = ""); SDVersion get_sd_version(); ggml_type get_sd_wtype(); ggml_type get_conditioner_wtype(); ggml_type get_diffusion_model_wtype(); ggml_type get_vae_wtype(); + void set_wtype_override(ggml_type wtype, std::string prefix = ""); bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend); bool load_tensors(std::map& tensors, ggml_backend_t backend, std::set ignore_tensors = {}); + bool save_to_gguf_file(const std::string& file_path, ggml_type type); bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type); int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT); diff --git a/pmid.hpp b/pmid.hpp index 381050fef..ea9f02eb6 100644 --- a/pmid.hpp +++ b/pmid.hpp @@ -42,6 +42,370 @@ struct FuseBlock : public GGMLBlock { } }; +/* +class QFormerPerceiver(nn.Module): + def __init__(self, id_embeddings_dim, cross_attention_dim, num_tokens, embedding_dim=1024, use_residual=True, ratio=4): + super().__init__() + + self.num_tokens = num_tokens + self.cross_attention_dim = cross_attention_dim + self.use_residual = use_residual + print(cross_attention_dim*num_tokens) + self.token_proj = nn.Sequential( + nn.Linear(id_embeddings_dim, id_embeddings_dim*ratio), + nn.GELU(), + nn.Linear(id_embeddings_dim*ratio, cross_attention_dim*num_tokens), + ) + self.token_norm = nn.LayerNorm(cross_attention_dim) + self.perceiver_resampler = FacePerceiverResampler( + dim=cross_attention_dim, + depth=4, + dim_head=128, + heads=cross_attention_dim // 128, + embedding_dim=embedding_dim, + output_dim=cross_attention_dim, + ff_mult=4, + ) + + def forward(self, x, last_hidden_state): + x = self.token_proj(x) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.token_norm(x) # cls token + out = self.perceiver_resampler(x, last_hidden_state) # retrieve from patch tokens + if self.use_residual: # TODO: if use_residual is not true + out = x + 1.0 * out + return out +*/ + +struct PMFeedForward : public GGMLBlock { + // network hparams + int dim; + +public: + PMFeedForward(int d, int multi = 4) + : dim(d) { + int inner_dim = dim * multi; + blocks["0"] = std::shared_ptr(new LayerNorm(dim)); + blocks["1"] = std::shared_ptr(new Mlp(dim, inner_dim, dim, false)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x) { + auto norm = std::dynamic_pointer_cast(blocks["0"]); + auto ff = std::dynamic_pointer_cast(blocks["1"]); + + x = norm->forward(ctx, x); + x = ff->forward(ctx, x); + return x; + } +}; + +struct PerceiverAttention : public GGMLBlock { + // network hparams + float scale; // = dim_head**-0.5 + int dim_head; // = dim_head + int heads; // = heads +public: + PerceiverAttention(int dim, int dim_h = 64, int h = 8) + : scale(powf(dim_h, -0.5)), dim_head(dim_h), heads(h) { + int inner_dim = dim_head * heads; + blocks["norm1"] = std::shared_ptr(new LayerNorm(dim)); + blocks["norm2"] = std::shared_ptr(new LayerNorm(dim)); + blocks["to_q"] = std::shared_ptr(new Linear(dim, inner_dim, false)); + blocks["to_kv"] = std::shared_ptr(new Linear(dim, inner_dim * 2, false)); + blocks["to_out"] = std::shared_ptr(new Linear(inner_dim, dim, false)); + } + + struct ggml_tensor* reshape_tensor(struct ggml_context* ctx, + struct ggml_tensor* x, + int heads) { + int64_t ne[4]; + for (int i = 0; i < 4; ++i) + ne[i] = x->ne[i]; + // print_ggml_tensor(x, true, "PerceiverAttention reshape x 0: "); + // printf("heads = %d \n", heads); + // x = ggml_view_4d(ctx, x, x->ne[0], x->ne[1], heads, x->ne[2]/heads, + // x->nb[1], x->nb[2], x->nb[3], 0); + x = ggml_reshape_4d(ctx, x, x->ne[0] / heads, heads, x->ne[1], x->ne[2]); + // x = ggml_view_4d(ctx, x, x->ne[0]/heads, heads, x->ne[1], x->ne[2], + // x->nb[1], x->nb[2], x->nb[3], 0); + // x = ggml_cont(ctx, x); + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); + // print_ggml_tensor(x, true, "PerceiverAttention reshape x 1: "); + // x = ggml_reshape_4d(ctx, x, ne[0], heads, ne[1], ne[2]/heads); + return x; + } + + std::vector chunk_half(struct ggml_context* ctx, + struct ggml_tensor* x) { + auto tlo = ggml_view_4d(ctx, x, x->ne[0] / 2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0); + auto tli = ggml_view_4d(ctx, x, x->ne[0] / 2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], x->nb[0] * x->ne[0] / 2); + return {ggml_cont(ctx, tlo), + ggml_cont(ctx, tli)}; + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* latents) { + // x (torch.Tensor): image features + // shape (b, n1, D) + // latent (torch.Tensor): latent features + // shape (b, n2, D) + int64_t ne[4]; + for (int i = 0; i < 4; ++i) + ne[i] = latents->ne[i]; + + auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); + auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); + x = norm1->forward(ctx, x); + latents = norm2->forward(ctx, latents); + auto to_q = std::dynamic_pointer_cast(blocks["to_q"]); + auto q = to_q->forward(ctx, latents); + + auto kv_input = ggml_concat(ctx, x, latents, 1); + auto to_kv = std::dynamic_pointer_cast(blocks["to_kv"]); + auto kv = to_kv->forward(ctx, kv_input); + auto k = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, 0); + auto v = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, kv->nb[0] * (kv->ne[0] / 2)); + k = ggml_cont(ctx, k); + v = ggml_cont(ctx, v); + q = reshape_tensor(ctx, q, heads); + k = reshape_tensor(ctx, k, heads); + v = reshape_tensor(ctx, v, heads); + scale = 1.f / sqrt(sqrt((float)dim_head)); + k = ggml_scale_inplace(ctx, k, scale); + q = ggml_scale_inplace(ctx, q, scale); + // auto weight = ggml_mul_mat(ctx, q, k); + auto weight = ggml_mul_mat(ctx, k, q); // NOTE order of mul is opposite to pytorch + + // GGML's softmax() is equivalent to pytorch's softmax(x, dim=-1) + // in this case, dimension along which Softmax will be computed is the last dim + // in torch and the first dim in GGML, consistent with the convention that pytorch's + // last dimension (varying most rapidly) corresponds to GGML's first (varying most rapidly). + // weight = ggml_soft_max(ctx, weight); + weight = ggml_soft_max_inplace(ctx, weight); + v = ggml_cont(ctx, ggml_transpose(ctx, v)); + // auto out = ggml_mul_mat(ctx, weight, v); + auto out = ggml_mul_mat(ctx, v, weight); // NOTE order of mul is opposite to pytorch + out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); + out = ggml_reshape_3d(ctx, out, ne[0], ne[1], ggml_nelements(out) / (ne[0] * ne[1])); + auto to_out = std::dynamic_pointer_cast(blocks["to_out"]); + out = to_out->forward(ctx, out); + return out; + } +}; + +struct FacePerceiverResampler : public GGMLBlock { + // network hparams + int depth; + +public: + FacePerceiverResampler(int dim = 768, + int d = 4, + int dim_head = 64, + int heads = 16, + int embedding_dim = 1280, + int output_dim = 768, + int ff_mult = 4) + : depth(d) { + blocks["proj_in"] = std::shared_ptr(new Linear(embedding_dim, dim, true)); + blocks["proj_out"] = std::shared_ptr(new Linear(dim, output_dim, true)); + blocks["norm_out"] = std::shared_ptr(new LayerNorm(output_dim)); + + for (int i = 0; i < depth; i++) { + std::string name = "layers." + std::to_string(i) + ".0"; + blocks[name] = std::shared_ptr(new PerceiverAttention(dim, dim_head, heads)); + name = "layers." + std::to_string(i) + ".1"; + blocks[name] = std::shared_ptr(new PMFeedForward(dim, ff_mult)); + } + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* latents, + struct ggml_tensor* x) { + // x: [N, channels, h, w] + auto proj_in = std::dynamic_pointer_cast(blocks["proj_in"]); + auto proj_out = std::dynamic_pointer_cast(blocks["proj_out"]); + auto norm_out = std::dynamic_pointer_cast(blocks["norm_out"]); + + x = proj_in->forward(ctx, x); + for (int i = 0; i < depth; i++) { + std::string name = "layers." + std::to_string(i) + ".0"; + auto attn = std::dynamic_pointer_cast(blocks[name]); + name = "layers." + std::to_string(i) + ".1"; + auto ff = std::dynamic_pointer_cast(blocks[name]); + auto t = attn->forward(ctx, x, latents); + latents = ggml_add(ctx, t, latents); + t = ff->forward(ctx, latents); + latents = ggml_add(ctx, t, latents); + } + latents = proj_out->forward(ctx, latents); + latents = norm_out->forward(ctx, latents); + return latents; + } +}; + +struct QFormerPerceiver : public GGMLBlock { + // network hparams + int num_tokens; + int cross_attention_dim; + bool use_residul; + +public: + QFormerPerceiver(int id_embeddings_dim, int cross_attention_d, int num_t, int embedding_dim = 1024, bool use_r = true, int ratio = 4) + : cross_attention_dim(cross_attention_d), num_tokens(num_t), use_residul(use_r) { + blocks["token_proj"] = std::shared_ptr(new Mlp(id_embeddings_dim, + id_embeddings_dim * ratio, + cross_attention_dim * num_tokens, + true)); + blocks["token_norm"] = std::shared_ptr(new LayerNorm(cross_attention_d)); + blocks["perceiver_resampler"] = std::shared_ptr(new FacePerceiverResampler( + cross_attention_dim, + 4, + 128, + cross_attention_dim / 128, + embedding_dim, + cross_attention_dim, + 4)); + } + + /* + def forward(self, x, last_hidden_state): + x = self.token_proj(x) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.token_norm(x) # cls token + out = self.perceiver_resampler(x, last_hidden_state) # retrieve from patch tokens + if self.use_residual: # TODO: if use_residual is not true + out = x + 1.0 * out + return out + */ + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* last_hidden_state) { + // x: [N, channels, h, w] + auto token_proj = std::dynamic_pointer_cast(blocks["token_proj"]); + auto token_norm = std::dynamic_pointer_cast(blocks["token_norm"]); + auto perceiver_resampler = std::dynamic_pointer_cast(blocks["perceiver_resampler"]); + + x = token_proj->forward(ctx, x); + int64_t nel = ggml_nelements(x); + x = ggml_reshape_3d(ctx, x, cross_attention_dim, num_tokens, nel / (cross_attention_dim * num_tokens)); + x = token_norm->forward(ctx, x); + struct ggml_tensor* out = perceiver_resampler->forward(ctx, x, last_hidden_state); + if (use_residul) + out = ggml_add(ctx, x, out); + return out; + } +}; + +/* +class FacePerceiverResampler(torch.nn.Module): + def __init__( + self, + *, + dim=768, + depth=4, + dim_head=64, + heads=16, + embedding_dim=1280, + output_dim=768, + ff_mult=4, + ): + super().__init__() + + self.proj_in = torch.nn.Linear(embedding_dim, dim) + self.proj_out = torch.nn.Linear(dim, output_dim) + self.norm_out = torch.nn.LayerNorm(output_dim) + self.layers = torch.nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + torch.nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, latents, x): + x = self.proj_in(x) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + latents = self.proj_out(latents) + return self.norm_out(latents) +*/ + +/* + +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + +def reshape_tensor(x, heads): + bs, length, width = x.shape + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + +*/ + struct FuseModule : public GGMLBlock { // network hparams int embed_dim; @@ -61,12 +425,19 @@ struct FuseModule : public GGMLBlock { auto mlp2 = std::dynamic_pointer_cast(blocks["mlp2"]); auto layer_norm = std::dynamic_pointer_cast(blocks["layer_norm"]); - auto prompt_embeds0 = ggml_cont(ctx, ggml_permute(ctx, prompt_embeds, 2, 0, 1, 3)); - auto id_embeds0 = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3)); - // concat is along dim 2 - auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds0, id_embeds0, 2); - stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 1, 2, 0, 3)); + // print_ggml_tensor(id_embeds, true, "Fuseblock id_embeds: "); + // print_ggml_tensor(prompt_embeds, true, "Fuseblock prompt_embeds: "); + // auto prompt_embeds0 = ggml_cont(ctx, ggml_permute(ctx, prompt_embeds, 2, 0, 1, 3)); + // auto id_embeds0 = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3)); + // print_ggml_tensor(id_embeds0, true, "Fuseblock id_embeds0: "); + // print_ggml_tensor(prompt_embeds0, true, "Fuseblock prompt_embeds0: "); + // concat is along dim 2 + // auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds0, id_embeds0, 2); + auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds, id_embeds, 0); + // print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 0: "); + // stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 1, 2, 0, 3)); + // print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 1: "); // stacked_id_embeds = mlp1.forward(ctx, stacked_id_embeds); // stacked_id_embeds = ggml_add(ctx, stacked_id_embeds, prompt_embeds); // stacked_id_embeds = mlp2.forward(ctx, stacked_id_embeds); @@ -77,6 +448,8 @@ struct FuseModule : public GGMLBlock { stacked_id_embeds = mlp2->forward(ctx, stacked_id_embeds); stacked_id_embeds = layer_norm->forward(ctx, stacked_id_embeds); + // print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 1: "); + return stacked_id_embeds; } @@ -98,23 +471,31 @@ struct FuseModule : public GGMLBlock { // print_ggml_tensor(class_tokens_mask_pos, true, "class_tokens_mask_pos"); struct ggml_tensor* image_token_embeds = ggml_get_rows(ctx, prompt_embeds, class_tokens_mask_pos); ggml_set_name(image_token_embeds, "image_token_embeds"); + valid_id_embeds = ggml_reshape_2d(ctx, valid_id_embeds, valid_id_embeds->ne[0], + ggml_nelements(valid_id_embeds) / valid_id_embeds->ne[0]); struct ggml_tensor* stacked_id_embeds = fuse_fn(ctx, image_token_embeds, valid_id_embeds); - stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3)); + // stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3)); + // print_ggml_tensor(stacked_id_embeds, true, "AA stacked_id_embeds"); + // print_ggml_tensor(left, true, "AA left"); + // print_ggml_tensor(right, true, "AA right"); if (left && right) { - stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 2); - stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 2); + stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 1); + stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1); } else if (left) { - stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 2); + stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 1); } else if (right) { - stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 2); + stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1); } - stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3)); + // print_ggml_tensor(stacked_id_embeds, true, "BB stacked_id_embeds"); + // stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3)); + // print_ggml_tensor(stacked_id_embeds, true, "CC stacked_id_embeds"); class_tokens_mask = ggml_cont(ctx, ggml_transpose(ctx, class_tokens_mask)); class_tokens_mask = ggml_repeat(ctx, class_tokens_mask, prompt_embeds); prompt_embeds = ggml_mul(ctx, prompt_embeds, class_tokens_mask); struct ggml_tensor* updated_prompt_embeds = ggml_add(ctx, prompt_embeds, stacked_id_embeds); ggml_set_name(updated_prompt_embeds, "updated_prompt_embeds"); + // print_ggml_tensor(updated_prompt_embeds, true, "updated_prompt_embeds: "); return updated_prompt_embeds; } }; @@ -159,10 +540,77 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection { } }; +struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionModelProjection { + int cross_attention_dim; + int num_tokens; + + PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock(int id_embeddings_dim = 512) + : CLIPVisionModelProjection(OPENAI_CLIP_VIT_L_14), + cross_attention_dim(2048), + num_tokens(2) { + blocks["visual_projection_2"] = std::shared_ptr(new Linear(1024, 1280, false)); + blocks["fuse_module"] = std::shared_ptr(new FuseModule(2048)); + /* + cross_attention_dim = 2048 + # projection + self.num_tokens = 2 + self.cross_attention_dim = cross_attention_dim + self.qformer_perceiver = QFormerPerceiver( + id_embeddings_dim, + cross_attention_dim, + self.num_tokens, + )*/ + blocks["qformer_perceiver"] = std::shared_ptr(new QFormerPerceiver(id_embeddings_dim, + cross_attention_dim, + num_tokens)); + } + + /* + def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds): + b, num_inputs, c, h, w = id_pixel_values.shape + id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) + + last_hidden_state = self.vision_model(id_pixel_values)[0] + id_embeds = id_embeds.view(b * num_inputs, -1) + + id_embeds = self.qformer_perceiver(id_embeds, last_hidden_state) + id_embeds = id_embeds.view(b, num_inputs, self.num_tokens, -1) + updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask) + */ + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* id_pixel_values, + struct ggml_tensor* prompt_embeds, + struct ggml_tensor* class_tokens_mask, + struct ggml_tensor* class_tokens_mask_pos, + struct ggml_tensor* id_embeds, + struct ggml_tensor* left, + struct ggml_tensor* right) { + // x: [N, channels, h, w] + auto vision_model = std::dynamic_pointer_cast(blocks["vision_model"]); + auto fuse_module = std::dynamic_pointer_cast(blocks["fuse_module"]); + auto qformer_perceiver = std::dynamic_pointer_cast(blocks["qformer_perceiver"]); + + // struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values); // [N, hidden_size] + struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values, false); // [N, hidden_size] + id_embeds = qformer_perceiver->forward(ctx, id_embeds, last_hidden_state); + + struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx, + prompt_embeds, + id_embeds, + class_tokens_mask, + class_tokens_mask_pos, + left, right); + return updated_prompt_embeds; + } +}; + struct PhotoMakerIDEncoder : public GGMLRunner { public: - SDVersion version = VERSION_SDXL; + SDVersion version = VERSION_SDXL; + PMVersion pm_version = PM_VERSION_1; PhotoMakerIDEncoderBlock id_encoder; + PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock id_encoder2; float style_strength; std::vector ctm; @@ -175,25 +623,38 @@ struct PhotoMakerIDEncoder : public GGMLRunner { std::vector zeros_right; public: - PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, float sty = 20.f) - : GGMLRunner(backend, wtype), + PhotoMakerIDEncoder(ggml_backend_t backend, std::map& tensor_types, const std::string prefix, SDVersion version = VERSION_SDXL, PMVersion pm_v = PM_VERSION_1, float sty = 20.f) + : GGMLRunner(backend), version(version), + pm_version(pm_v), style_strength(sty) { - id_encoder.init(params_ctx, wtype); + if (pm_version == PM_VERSION_1) { + id_encoder.init(params_ctx, tensor_types, prefix); + } else if (pm_version == PM_VERSION_2) { + id_encoder2.init(params_ctx, tensor_types, prefix); + } } std::string get_desc() { return "pmid"; } + PMVersion get_version() const { + return pm_version; + } + void get_param_tensors(std::map& tensors, const std::string prefix) { - id_encoder.get_param_tensors(tensors, prefix); + if (pm_version == PM_VERSION_1) + id_encoder.get_param_tensors(tensors, prefix); + else if (pm_version == PM_VERSION_2) + id_encoder2.get_param_tensors(tensors, prefix); } struct ggml_cgraph* build_graph( // struct ggml_allocr* allocr, struct ggml_tensor* id_pixel_values, struct ggml_tensor* prompt_embeds, - std::vector& class_tokens_mask) { + std::vector& class_tokens_mask, + struct ggml_tensor* id_embeds) { ctm.clear(); ctmf16.clear(); ctmpos.clear(); @@ -214,25 +675,32 @@ struct PhotoMakerIDEncoder : public GGMLRunner { struct ggml_tensor* id_pixel_values_d = to_backend(id_pixel_values); struct ggml_tensor* prompt_embeds_d = to_backend(prompt_embeds); + struct ggml_tensor* id_embeds_d = to_backend(id_embeds); struct ggml_tensor* left = NULL; struct ggml_tensor* right = NULL; for (int i = 0; i < class_tokens_mask.size(); i++) { if (class_tokens_mask[i]) { + // printf(" 1,"); ctm.push_back(0.f); // here use 0.f instead of 1.f to make a scale mask ctmf16.push_back(ggml_fp32_to_fp16(0.f)); // here use 0.f instead of 1.f to make a scale mask ctmpos.push_back(i); } else { + // printf(" 0,"); ctm.push_back(1.f); // here use 1.f instead of 0.f to make a scale mask ctmf16.push_back(ggml_fp32_to_fp16(1.f)); // here use 0.f instead of 1.f to make a scale mask } } + // printf("\n"); if (ctmpos[0] > 0) { - left = ggml_new_tensor_3d(ctx0, type, hidden_size, 1, ctmpos[0]); + // left = ggml_new_tensor_3d(ctx0, type, hidden_size, 1, ctmpos[0]); + left = ggml_new_tensor_3d(ctx0, type, hidden_size, ctmpos[0], 1); } if (ctmpos[ctmpos.size() - 1] < seq_length - 1) { + // right = ggml_new_tensor_3d(ctx0, type, + // hidden_size, 1, seq_length - ctmpos[ctmpos.size() - 1] - 1); right = ggml_new_tensor_3d(ctx0, type, - hidden_size, 1, seq_length - ctmpos[ctmpos.size() - 1] - 1); + hidden_size, seq_length - ctmpos[ctmpos.size() - 1] - 1, 1); } struct ggml_tensor* class_tokens_mask_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctmpos.size()); @@ -265,12 +733,23 @@ struct PhotoMakerIDEncoder : public GGMLRunner { } } } - struct ggml_tensor* updated_prompt_embeds = id_encoder.forward(ctx0, - id_pixel_values_d, - prompt_embeds_d, - class_tokens_mask_d, - class_tokens_mask_pos, - left, right); + struct ggml_tensor* updated_prompt_embeds = NULL; + if (pm_version == PM_VERSION_1) + updated_prompt_embeds = id_encoder.forward(ctx0, + id_pixel_values_d, + prompt_embeds_d, + class_tokens_mask_d, + class_tokens_mask_pos, + left, right); + else if (pm_version == PM_VERSION_2) + updated_prompt_embeds = id_encoder2.forward(ctx0, + id_pixel_values_d, + prompt_embeds_d, + class_tokens_mask_d, + class_tokens_mask_pos, + id_embeds_d, + left, right); + ggml_build_forward_expand(gf, updated_prompt_embeds); return gf; @@ -279,12 +758,13 @@ struct PhotoMakerIDEncoder : public GGMLRunner { void compute(const int n_threads, struct ggml_tensor* id_pixel_values, struct ggml_tensor* prompt_embeds, + struct ggml_tensor* id_embeds, std::vector& class_tokens_mask, struct ggml_tensor** updated_prompt_embeds, ggml_context* output_ctx) { auto get_graph = [&]() -> struct ggml_cgraph* { // return build_graph(compute_allocr, id_pixel_values, prompt_embeds, class_tokens_mask); - return build_graph(id_pixel_values, prompt_embeds, class_tokens_mask); + return build_graph(id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds); }; // GGMLRunner::compute(get_graph, n_threads, updated_prompt_embeds); @@ -292,4 +772,74 @@ struct PhotoMakerIDEncoder : public GGMLRunner { } }; +struct PhotoMakerIDEmbed : public GGMLRunner { + std::map tensors; + std::string file_path; + ModelLoader* model_loader; + bool load_failed = false; + bool applied = false; + + PhotoMakerIDEmbed(ggml_backend_t backend, + ModelLoader* ml, + const std::string& file_path = "", + const std::string& prefix = "") + : file_path(file_path), GGMLRunner(backend), model_loader(ml) { + if (!model_loader->init_from_file(file_path, prefix)) { + load_failed = true; + } + } + + std::string get_desc() { + return "id_embeds"; + } + + bool load_from_file(bool filter_tensor = false) { + LOG_INFO("loading PhotoMaker ID Embeds from '%s'", file_path.c_str()); + + if (load_failed) { + LOG_ERROR("init photomaker id embed from file failed: '%s'", file_path.c_str()); + return false; + } + + bool dry_run = true; + auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { + const std::string& name = tensor_storage.name; + + if (filter_tensor && !contains(name, "pmid.id_embeds")) { + // LOG_INFO("skipping LoRA tesnor '%s'", name.c_str()); + return true; + } + if (dry_run) { + struct ggml_tensor* real = ggml_new_tensor(params_ctx, + tensor_storage.type, + tensor_storage.n_dims, + tensor_storage.ne); + tensors[name] = real; + } else { + auto real = tensors[name]; + *dst_tensor = real; + } + + return true; + }; + + model_loader->load_tensors(on_new_tensor_cb, backend); + alloc_params_buffer(); + + dry_run = false; + model_loader->load_tensors(on_new_tensor_cb, backend); + + LOG_DEBUG("finished loading PhotoMaker ID Embeds "); + return true; + } + + struct ggml_tensor* get() { + std::map::iterator pos; + pos = tensors.find("pmid.id_embeds"); + if (pos != tensors.end()) + return pos->second; + return NULL; + } +}; + #endif // __PMI_HPP__ diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 07b59bb8a..e38a6101f 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -26,12 +26,15 @@ const char* model_version_to_str[] = { "SD 1.x", + "SD 1.x Inpaint", "SD 2.x", + "SD 2.x Inpaint", "SDXL", + "SDXL Inpaint", "SVD", - "SD3 2B", - "Flux Dev", - "Flux Schnell"}; + "SD3.x", + "Flux", + "Flux Fill"}; const char* sampling_methods_str[] = { "Euler A", @@ -44,6 +47,8 @@ const char* sampling_methods_str[] = { "iPNDM", "iPNDM_v", "LCM", + "DDIM \"trailing\"", + "TCD" }; /*================================================== Helper Functions ================================================*/ @@ -92,6 +97,7 @@ class StableDiffusionGGML { std::shared_ptr control_net; std::shared_ptr pmid_model; std::shared_ptr pmid_lora; + std::shared_ptr pmid_id_embeds; std::string taesd_path; bool use_tiny_autoencoder = false; @@ -139,6 +145,7 @@ class StableDiffusionGGML { bool load_from_file(const std::string& model_path, const std::string& clip_l_path, + const std::string& clip_g_path, const std::string& t5xxl_path, const std::string& diffusion_model_path, const std::string& vae_path, @@ -151,15 +158,16 @@ class StableDiffusionGGML { schedule_t schedule, bool clip_on_cpu, bool control_net_cpu, - bool vae_on_cpu) { + bool vae_on_cpu, + bool diffusion_flash_attn) { use_tiny_autoencoder = taesd_path.size() > 0; -#ifdef SD_USE_CUBLAS +#ifdef SD_USE_CUDA LOG_DEBUG("Using CUDA backend"); backend = ggml_backend_cuda_init(0); #endif #ifdef SD_USE_METAL LOG_DEBUG("Using Metal backend"); - ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr); + ggml_log_set(ggml_log_callback_default, nullptr); backend = ggml_backend_metal_init(); #endif #ifdef SD_USE_VULKAN @@ -167,7 +175,7 @@ class StableDiffusionGGML { for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) { backend = ggml_backend_vk_init(device); } - if(!backend) { + if (!backend) { LOG_WARN("Failed to initialize Vulkan backend"); } #endif @@ -180,13 +188,7 @@ class StableDiffusionGGML { LOG_DEBUG("Using CPU backend"); backend = ggml_backend_cpu_init(); } -#ifdef SD_USE_FLASH_ATTENTION -#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined (SD_USE_SYCL) || defined(SD_USE_VULKAN) - LOG_WARN("Flash Attention not supported with GPU Backend"); -#else - LOG_INFO("Flash Attention enabled"); -#endif -#endif + ModelLoader model_loader; vae_tiling = vae_tiling_; @@ -200,14 +202,21 @@ class StableDiffusionGGML { if (clip_l_path.size() > 0) { LOG_INFO("loading clip_l from '%s'", clip_l_path.c_str()); - if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.")) { + if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.transformer.")) { LOG_WARN("loading clip_l from '%s' failed", clip_l_path.c_str()); } } + if (clip_g_path.size() > 0) { + LOG_INFO("loading clip_g from '%s'", clip_g_path.c_str()); + if (!model_loader.init_from_file(clip_g_path, "text_encoders.clip_g.transformer.")) { + LOG_WARN("loading clip_g from '%s' failed", clip_g_path.c_str()); + } + } + if (t5xxl_path.size() > 0) { LOG_INFO("loading t5xxl from '%s'", t5xxl_path.c_str()); - if (!model_loader.init_from_file(t5xxl_path, "text_encoders.t5xxl.")) { + if (!model_loader.init_from_file(t5xxl_path, "text_encoders.t5xxl.transformer.")) { LOG_WARN("loading t5xxl from '%s' failed", t5xxl_path.c_str()); } } @@ -257,20 +266,22 @@ class StableDiffusionGGML { conditioner_wtype = wtype; diffusion_model_wtype = wtype; vae_wtype = wtype; + model_loader.set_wtype_override(wtype); } - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { vae_wtype = GGML_TYPE_F32; + model_loader.set_wtype_override(GGML_TYPE_F32, "vae."); } - LOG_INFO("Weight type: %s", ggml_type_name(model_wtype)); - LOG_INFO("Conditioner weight type: %s", ggml_type_name(conditioner_wtype)); - LOG_INFO("Diffusion model weight type: %s", ggml_type_name(diffusion_model_wtype)); - LOG_INFO("VAE weight type: %s", ggml_type_name(vae_wtype)); + LOG_INFO("Weight type: %s", model_wtype != SD_TYPE_COUNT ? ggml_type_name(model_wtype) : "??"); + LOG_INFO("Conditioner weight type: %s", conditioner_wtype != SD_TYPE_COUNT ? ggml_type_name(conditioner_wtype) : "??"); + LOG_INFO("Diffusion model weight type: %s", diffusion_model_wtype != SD_TYPE_COUNT ? ggml_type_name(diffusion_model_wtype) : "??"); + LOG_INFO("VAE weight type: %s", vae_wtype != SD_TYPE_COUNT ? ggml_type_name(vae_wtype) : "??"); LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor)); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { scale_factor = 0.13025f; if (vae_path.size() == 0 && taesd_path.size() == 0) { LOG_WARN( @@ -279,30 +290,30 @@ class StableDiffusionGGML { "try specifying SDXL VAE FP16 Fix with the --vae parameter. " "You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors"); } - } else if (version == VERSION_SD3_2B) { + } else if (sd_version_is_sd3(version)) { scale_factor = 1.5305f; - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + } else if (sd_version_is_flux(version)) { scale_factor = 0.3611; // TODO: shift_factor } if (version == VERSION_SVD) { - clip_vision = std::make_shared(backend, conditioner_wtype); + clip_vision = std::make_shared(backend, model_loader.tensor_storages_types); clip_vision->alloc_params_buffer(); clip_vision->get_param_tensors(tensors); - diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); + diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types, version); diffusion_model->alloc_params_buffer(); diffusion_model->get_param_tensors(tensors); - first_stage_model = std::make_shared(backend, vae_wtype, vae_decode_only, true, version); + first_stage_model = std::make_shared(backend, model_loader.tensor_storages_types, "first_stage_model", vae_decode_only, true, version); LOG_DEBUG("vae_decode_only %d", vae_decode_only); first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); } else { clip_backend = backend; bool use_t5xxl = false; - if (version == VERSION_SD3_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + if (sd_version_is_dit(version)) { use_t5xxl = true; } if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) { @@ -313,16 +324,27 @@ class StableDiffusionGGML { LOG_INFO("CLIP: Using CPU backend"); clip_backend = ggml_backend_cpu_init(); } - if (version == VERSION_SD3_2B) { - cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); - diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { - cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); - diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); + if (diffusion_flash_attn) { + LOG_INFO("Using flash attention in the diffusion model"); + } + if (sd_version_is_sd3(version)) { + if (diffusion_flash_attn) { + LOG_WARN("flash attention in this diffusion model is currently unsupported!"); + } + cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); + diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types); + } else if (sd_version_is_flux(version)) { + cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); + diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn); } else { - cond_stage_model = std::make_shared(clip_backend, conditioner_wtype, embeddings_path, version); - diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); + if (id_embeddings_path.find("v2") != std::string::npos) { + cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2); + } else { + cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types, embeddings_path, version); + } + diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn); } + cond_stage_model->alloc_params_buffer(); cond_stage_model->get_param_tensors(tensors); @@ -336,11 +358,11 @@ class StableDiffusionGGML { } else { vae_backend = backend; } - first_stage_model = std::make_shared(vae_backend, vae_wtype, vae_decode_only, false, version); + first_stage_model = std::make_shared(vae_backend, model_loader.tensor_storages_types, "first_stage_model", vae_decode_only, false, version); first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); } else { - tae_first_stage = std::make_shared(backend, vae_wtype, vae_decode_only); + tae_first_stage = std::make_shared(backend, model_loader.tensor_storages_types, "decoder.layers", vae_decode_only, version); } // first_stage_model->get_param_tensors(tensors, "first_stage_model."); @@ -352,12 +374,17 @@ class StableDiffusionGGML { } else { controlnet_backend = backend; } - control_net = std::make_shared(controlnet_backend, diffusion_model_wtype, version); + control_net = std::make_shared(controlnet_backend, model_loader.tensor_storages_types, version); } - pmid_model = std::make_shared(clip_backend, model_wtype, version); + if (id_embeddings_path.find("v2") != std::string::npos) { + pmid_model = std::make_shared(backend, model_loader.tensor_storages_types, "pmid", version, PM_VERSION_2); + LOG_INFO("using PhotoMaker Version 2"); + } else { + pmid_model = std::make_shared(backend, model_loader.tensor_storages_types, "pmid", version); + } if (id_embeddings_path.size() > 0) { - pmid_lora = std::make_shared(backend, model_wtype, id_embeddings_path, ""); + pmid_lora = std::make_shared(backend, id_embeddings_path, ""); if (!pmid_lora->load_from_file(true)) { LOG_WARN("load photomaker lora tensors from %s failed", id_embeddings_path.c_str()); return false; @@ -374,14 +401,8 @@ class StableDiffusionGGML { LOG_ERROR(" pmid model params buffer allocation failed"); return false; } - // LOG_INFO("pmid param memory buffer size = %.2fMB ", - // pmid_model->params_buffer_size / 1024.0 / 1024.0); pmid_model->get_param_tensors(tensors, "pmid"); } - // if(stacked_id){ - // pmid_model.init_params(GGML_TYPE_F32); - // pmid_model.map_by_name(tensors, "pmid."); - // } } struct ggml_init_params params; @@ -502,8 +523,12 @@ class StableDiffusionGGML { // check is_using_v_parameterization_for_sd2 bool is_using_v_parameterization = false; - if (version == VERSION_SD2) { - if (is_using_v_parameterization_for_sd2(ctx)) { + if (sd_version_is_sd2(version)) { + if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) { + is_using_v_parameterization = true; + } + } else if (sd_version_is_sdxl(version)) { + if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) { is_using_v_parameterization = true; } } else if (version == VERSION_SVD) { @@ -511,14 +536,17 @@ class StableDiffusionGGML { is_using_v_parameterization = true; } - if (version == VERSION_SD3_2B) { + if (sd_version_is_sd3(version)) { LOG_INFO("running in FLOW mode"); denoiser = std::make_shared(); - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + } else if (sd_version_is_flux(version)) { LOG_INFO("running in Flux FLOW mode"); - float shift = 1.15f; - if (version == VERSION_FLUX_SCHNELL) { - shift = 1.0f; // TODO: validate + float shift = 1.0f; // TODO: validate + for (auto pair : model_loader.tensor_storages_types) { + if (pair.first.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) { + shift = 1.15f; + break; + } } denoiser = std::make_shared(shift); } else if (is_using_v_parameterization) { @@ -574,7 +602,7 @@ class StableDiffusionGGML { return true; } - bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx) { + bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx, bool is_inpaint = false) { struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1); ggml_set_f32(x_t, 0.5); struct ggml_tensor* c = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 1024, 2, 1, 1); @@ -582,9 +610,15 @@ class StableDiffusionGGML { struct ggml_tensor* timesteps = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1); ggml_set_f32(timesteps, 999); + + struct ggml_tensor* concat = is_inpaint ? ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 5, 1) : NULL; + if (concat != NULL) { + ggml_set_f32(concat, 0); + } + int64_t t0 = ggml_time_ms(); struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t); - diffusion_model->compute(n_threads, x_t, timesteps, c, NULL, NULL, NULL, -1, {}, 0.f, &out); + diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, -1, {}, 0.f, &out); diffusion_model->free_compute_buffer(); double result = 0.f; @@ -617,14 +651,15 @@ class StableDiffusionGGML { LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str()); return; } - LoraModel lora(backend, model_wtype, file_path); + LoraModel lora(backend, file_path); if (!lora.load_from_file()) { LOG_WARN("load lora tensors from %s failed", file_path.c_str()); return; } lora.multiplier = multiplier; - lora.apply(tensors, n_threads); + // TODO: send version? + lora.apply(tensors, version, n_threads); lora.free_params_buffer(); int64_t t1 = ggml_time_ms(); @@ -640,19 +675,20 @@ class StableDiffusionGGML { for (auto& kv : lora_state) { const std::string& lora_name = kv.first; float multiplier = kv.second; - - if (curr_lora_state.find(lora_name) != curr_lora_state.end()) { - float curr_multiplier = curr_lora_state[lora_name]; - float multiplier_diff = multiplier - curr_multiplier; - if (multiplier_diff != 0.f) { - lora_state_diff[lora_name] = multiplier_diff; - } - } else { - lora_state_diff[lora_name] = multiplier; - } + lora_state_diff[lora_name] += multiplier; + } + for (auto& kv : curr_lora_state) { + const std::string& lora_name = kv.first; + float curr_multiplier = kv.second; + lora_state_diff[lora_name] -= curr_multiplier; + } + + size_t rm = lora_state_diff.size() - lora_state.size(); + if (rm != 0) { + LOG_INFO("Attempting to apply %lu LoRAs (removing %lu applied LoRAs)", lora_state.size(), rm); + } else { + LOG_INFO("Attempting to apply %lu LoRAs", lora_state.size()); } - - LOG_INFO("Attempting to apply %lu LoRAs", lora_state.size()); for (auto& kv : lora_state_diff) { apply_lora(kv.first, kv.second); @@ -664,10 +700,10 @@ class StableDiffusionGGML { ggml_tensor* id_encoder(ggml_context* work_ctx, ggml_tensor* init_img, ggml_tensor* prompts_embeds, + ggml_tensor* id_embeds, std::vector& class_tokens_mask) { ggml_tensor* res = NULL; - pmid_model->compute(n_threads, init_img, prompts_embeds, class_tokens_mask, &res, work_ctx); - + pmid_model->compute(n_threads, init_img, prompts_embeds, id_embeds, class_tokens_mask, &res, work_ctx); return res; } @@ -759,10 +795,28 @@ class StableDiffusionGGML { float min_cfg, float cfg_scale, float guidance, + float eta, sample_method_t method, const std::vector& sigmas, int start_merge_step, - SDCondition id_cond) { + SDCondition id_cond, + std::vector skip_layers = {}, + float slg_scale = 0, + float skip_layer_start = 0.01, + float skip_layer_end = 0.2, + ggml_tensor* noise_mask = nullptr) { + LOG_DEBUG("Sample"); + struct ggml_init_params params; + size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]); + for (int i = 1; i < 4; i++) { + data_size *= init_latent->ne[i]; + } + data_size += 1024; + params.mem_size = data_size * 3; + params.mem_buffer = NULL; + params.no_alloc = false; + ggml_context* tmp_ctx = ggml_init(params); + size_t steps = sigmas.size() - 1; // noise = load_tensor_from_file(work_ctx, "./rand0.bin"); // print_ggml_tensor(noise); @@ -773,13 +827,24 @@ class StableDiffusionGGML { struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, noise); bool has_unconditioned = cfg_scale != 1.0 && uncond.c_crossattn != NULL; + bool has_skiplayer = slg_scale != 0.0 && skip_layers.size() > 0; // denoise wrapper struct ggml_tensor* out_cond = ggml_dup_tensor(work_ctx, x); struct ggml_tensor* out_uncond = NULL; + struct ggml_tensor* out_skip = NULL; + if (has_unconditioned) { out_uncond = ggml_dup_tensor(work_ctx, x); } + if (has_skiplayer) { + if (sd_version_is_dit(version)) { + out_skip = ggml_dup_tensor(work_ctx, x); + } else { + has_skiplayer = false; + LOG_WARN("SLG is incompatible with %s models", model_version_to_str[version]); + } + } struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x); auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* { @@ -860,6 +925,28 @@ class StableDiffusionGGML { &out_uncond); negative_data = (float*)out_uncond->data; } + + int step_count = sigmas.size(); + bool is_skiplayer_step = has_skiplayer && step > (int)(skip_layer_start * step_count) && step < (int)(skip_layer_end * step_count); + float* skip_layer_data = NULL; + if (is_skiplayer_step) { + LOG_DEBUG("Skipping layers at step %d\n", step); + // skip layer (same as conditionned) + diffusion_model->compute(n_threads, + noised_input, + timesteps, + cond.c_crossattn, + cond.c_concat, + cond.c_vector, + guidance_tensor, + -1, + controls, + control_strength, + &out_skip, + NULL, + skip_layers); + skip_layer_data = (float*)out_skip->data; + } float* vec_denoised = (float*)denoised->data; float* vec_input = (float*)input->data; float* positive_data = (float*)out_cond->data; @@ -876,6 +963,9 @@ class StableDiffusionGGML { latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]); } } + if (is_skiplayer_step) { + latent_result = latent_result + (positive_data[i] - skip_layer_data[i]) * slg_scale; + } // v = latent_result, eps = latent_result // denoised = (v * c_out + input * c_skip) or (input + eps * c_out) vec_denoised[i] = latent_result * c_out + vec_input[i] * c_skip; @@ -885,10 +975,23 @@ class StableDiffusionGGML { pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f); // LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000); } + if (noise_mask != nullptr) { + for (int64_t x = 0; x < denoised->ne[0]; x++) { + for (int64_t y = 0; y < denoised->ne[1]; y++) { + float mask = ggml_tensor_get_f32(noise_mask, x, y); + for (int64_t k = 0; k < denoised->ne[2]; k++) { + float init = ggml_tensor_get_f32(init_latent, x, y, k); + float den = ggml_tensor_get_f32(denoised, x, y, k); + ggml_tensor_set_f32(denoised, init + mask * (den - init), x, y, k); + } + } + } + } + return denoised; }; - sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng); + sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta); x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x); @@ -939,9 +1042,9 @@ class StableDiffusionGGML { if (use_tiny_autoencoder) { C = 4; } else { - if (version == VERSION_SD3_2B) { + if (sd_version_is_sd3(version)) { C = 32; - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + } else if (sd_version_is_flux(version)) { C = 32; } } @@ -1008,6 +1111,7 @@ struct sd_ctx_t { sd_ctx_t* new_sd_ctx(const char* model_path_c_str, const char* clip_l_path_c_str, + const char* clip_g_path_c_str, const char* t5xxl_path_c_str, const char* diffusion_model_path_c_str, const char* vae_path_c_str, @@ -1025,13 +1129,15 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, enum schedule_t s, bool keep_clip_on_cpu, bool keep_control_net_cpu, - bool keep_vae_on_cpu) { + bool keep_vae_on_cpu, + bool diffusion_flash_attn) { sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t)); if (sd_ctx == NULL) { return NULL; } std::string model_path(model_path_c_str); std::string clip_l_path(clip_l_path_c_str); + std::string clip_g_path(clip_g_path_c_str); std::string t5xxl_path(t5xxl_path_c_str); std::string diffusion_model_path(diffusion_model_path_c_str); std::string vae_path(vae_path_c_str); @@ -1052,6 +1158,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, if (!sd_ctx->sd->load_from_file(model_path, clip_l_path, + clip_g_path, t5xxl_path_c_str, diffusion_model_path, vae_path, @@ -1064,7 +1171,8 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, s, keep_clip_on_cpu, keep_control_net_cpu, - keep_vae_on_cpu)) { + keep_vae_on_cpu, + diffusion_flash_attn)) { delete sd_ctx->sd; sd_ctx->sd = NULL; free(sd_ctx); @@ -1089,6 +1197,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, int clip_skip, float cfg_scale, float guidance, + float eta, int width, int height, enum sample_method_t sample_method, @@ -1099,7 +1208,12 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, float control_strength, float style_ratio, bool normalize_input, - std::string input_id_images_path) { + std::string input_id_images_path, + std::vector skip_layers = {}, + float slg_scale = 0, + float skip_layer_start = 0.01, + float skip_layer_end = 0.2, + ggml_tensor* masked_image = NULL) { if (seed < 0) { // Generally, when using the provided command line, the seed is always >0. // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library @@ -1139,7 +1253,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, if (sd_ctx->sd->stacked_id) { if (!sd_ctx->sd->pmid_lora->applied) { t0 = ggml_time_ms(); - sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->n_threads); + sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->version, sd_ctx->sd->n_threads); t1 = ggml_time_ms(); sd_ctx->sd->pmid_lora->applied = true; LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); @@ -1149,11 +1263,15 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, } // preprocess input id images std::vector input_id_images; + bool pmv2 = sd_ctx->sd->pmid_model->get_version() == PM_VERSION_2; if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) { std::vector img_files = get_files_from_dir(input_id_images_path); for (std::string img_file : img_files) { int c = 0; int width, height; + if (ends_with(img_file, "safetensors")) { + continue; + } uint8_t* input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3); if (input_image_buffer == NULL) { LOG_ERROR("PhotoMaker load image from '%s' failed", img_file.c_str()); @@ -1191,18 +1309,23 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, else sd_mul_images_to_tensor(init_image->data, init_img, i, NULL, NULL); } - t0 = ggml_time_ms(); - auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx, - sd_ctx->sd->n_threads, prompt, - clip_skip, - width, - height, - num_input_images, - sd_ctx->sd->diffusion_model->get_adm_in_channels()); - id_cond = std::get<0>(cond_tup); - class_tokens_mask = std::get<1>(cond_tup); // - - id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, class_tokens_mask); + t0 = ggml_time_ms(); + auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx, + sd_ctx->sd->n_threads, prompt, + clip_skip, + width, + height, + num_input_images, + sd_ctx->sd->diffusion_model->get_adm_in_channels()); + id_cond = std::get<0>(cond_tup); + class_tokens_mask = std::get<1>(cond_tup); // + struct ggml_tensor* id_embeds = NULL; + if (pmv2) { + // id_embeds = sd_ctx->sd->pmid_id_embeds->get(); + id_embeds = load_tensor_from_file(work_ctx, path_join(input_id_images_path, "id_embeds.bin")); + // print_ggml_tensor(id_embeds, true, "id_embeds:"); + } + id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, id_embeds, class_tokens_mask); t1 = ggml_time_ms(); LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0); if (sd_ctx->sd->free_params_immediately) { @@ -1240,7 +1363,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, SDCondition uncond; if (cfg_scale != 1.0) { bool force_zero_embeddings = false; - if (sd_ctx->sd->version == VERSION_SDXL && negative_prompt.size() == 0) { + if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0) { force_zero_embeddings = true; } uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, @@ -1269,14 +1392,47 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, // Sample std::vector final_latents; // collect latents to decode int C = 4; - if (sd_ctx->sd->version == VERSION_SD3_2B) { + if (sd_version_is_sd3(sd_ctx->sd->version)) { C = 16; - } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + } else if (sd_version_is_flux(sd_ctx->sd->version)) { C = 16; } int W = width / 8; int H = height / 8; LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]); + ggml_tensor* noise_mask = nullptr; + if (sd_version_is_inpaint(sd_ctx->sd->version)) { + if (masked_image == NULL) { + int64_t mask_channels = 1; + if (sd_ctx->sd->version == VERSION_FLUX_FILL) { + mask_channels = 8 * 8; // flatten the whole mask + } + // no mask, set the whole image as masked + masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1); + for (int64_t x = 0; x < masked_image->ne[0]; x++) { + for (int64_t y = 0; y < masked_image->ne[1]; y++) { + if (sd_ctx->sd->version == VERSION_FLUX_FILL) { + // TODO: this might be wrong + for (int64_t c = 0; c < init_latent->ne[2]; c++) { + ggml_tensor_set_f32(masked_image, 0, x, y, c); + } + for (int64_t c = init_latent->ne[2]; c < masked_image->ne[2]; c++) { + ggml_tensor_set_f32(masked_image, 1, x, y, c); + } + } else { + ggml_tensor_set_f32(masked_image, 1, x, y, 0); + for (int64_t c = 1; c < masked_image->ne[2]; c++) { + ggml_tensor_set_f32(masked_image, 0, x, y, c); + } + } + } + } + } + cond.c_concat = masked_image; + uncond.c_concat = masked_image; + } else { + noise_mask = masked_image; + } for (int b = 0; b < batch_count; b++) { int64_t sampling_start = ggml_time_ms(); int64_t cur_seed = seed + b; @@ -1305,10 +1461,17 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, cfg_scale, cfg_scale, guidance, + eta, sample_method, sigmas, start_merge_step, - id_cond); + id_cond, + skip_layers, + slg_scale, + skip_layer_start, + skip_layer_end, + noise_mask); + // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin"); // print_ggml_tensor(x_0); int64_t sampling_end = ggml_time_ms(); @@ -1364,6 +1527,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, int clip_skip, float cfg_scale, float guidance, + float eta, int width, int height, enum sample_method_t sample_method, @@ -1374,7 +1538,13 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, float control_strength, float style_ratio, bool normalize_input, - const char* input_id_images_path_c_str) { + const char* input_id_images_path_c_str, + int* skip_layers = NULL, + size_t skip_layers_count = 0, + float slg_scale = 0, + float skip_layer_start = 0.01, + float skip_layer_end = 0.2) { + std::vector skip_layers_vec(skip_layers, skip_layers + skip_layers_count); LOG_DEBUG("txt2img %dx%d", width, height); if (sd_ctx == NULL) { return NULL; @@ -1382,10 +1552,10 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, struct ggml_init_params params; params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB - if (sd_ctx->sd->version == VERSION_SD3_2B) { + if (sd_version_is_sd3(sd_ctx->sd->version)) { params.mem_size *= 3; } - if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + if (sd_version_is_flux(sd_ctx->sd->version)) { params.mem_size *= 4; } if (sd_ctx->sd->stacked_id) { @@ -1408,22 +1578,26 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); int C = 4; - if (sd_ctx->sd->version == VERSION_SD3_2B) { + if (sd_version_is_sd3(sd_ctx->sd->version)) { C = 16; - } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + } else if (sd_version_is_flux(sd_ctx->sd->version)) { C = 16; } int W = width / 8; int H = height / 8; ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); - if (sd_ctx->sd->version == VERSION_SD3_2B) { + if (sd_version_is_sd3(sd_ctx->sd->version)) { ggml_set_f32(init_latent, 0.0609f); - } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + } else if (sd_version_is_flux(sd_ctx->sd->version)) { ggml_set_f32(init_latent, 0.1159f); } else { ggml_set_f32(init_latent, 0.f); } + if (sd_version_is_inpaint(sd_ctx->sd->version)) { + LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask"); + } + sd_image_t* result_images = generate_image(sd_ctx, work_ctx, init_latent, @@ -1432,6 +1606,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, clip_skip, cfg_scale, guidance, + eta, width, height, sample_method, @@ -1442,7 +1617,11 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, control_strength, style_ratio, normalize_input, - input_id_images_path_c_str); + input_id_images_path_c_str, + skip_layers_vec, + slg_scale, + skip_layer_start, + skip_layer_end); size_t t1 = ggml_time_ms(); @@ -1453,11 +1632,13 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, sd_image_t* img2img(sd_ctx_t* sd_ctx, sd_image_t init_image, + sd_image_t mask, const char* prompt_c_str, const char* negative_prompt_c_str, int clip_skip, float cfg_scale, float guidance, + float eta, int width, int height, sample_method_t sample_method, @@ -1469,7 +1650,13 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, float control_strength, float style_ratio, bool normalize_input, - const char* input_id_images_path_c_str) { + const char* input_id_images_path_c_str, + int* skip_layers = NULL, + size_t skip_layers_count = 0, + float slg_scale = 0, + float skip_layer_start = 0.01, + float skip_layer_end = 0.2) { + std::vector skip_layers_vec(skip_layers, skip_layers + skip_layers_count); LOG_DEBUG("img2img %dx%d", width, height); if (sd_ctx == NULL) { return NULL; @@ -1477,16 +1664,16 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, struct ggml_init_params params; params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB - if (sd_ctx->sd->version == VERSION_SD3_2B) { + if (sd_version_is_sd3(sd_ctx->sd->version)) { params.mem_size *= 2; } - if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + if (sd_version_is_flux(sd_ctx->sd->version)) { params.mem_size *= 3; } if (sd_ctx->sd->stacked_id) { params.mem_size += static_cast(10 * 1024 * 1024); // 10 MB } - params.mem_size += width * height * 3 * sizeof(float) * 2; + params.mem_size += width * height * 3 * sizeof(float) * 3; params.mem_size *= batch_count; params.mem_buffer = NULL; params.no_alloc = false; @@ -1507,7 +1694,70 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, sd_ctx->sd->rng->manual_seed(seed); ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); + ggml_tensor* mask_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 1, 1); + + sd_mask_to_tensor(mask.data, mask_img); + sd_image_to_tensor(init_image.data, init_img); + + ggml_tensor* masked_image; + + if (sd_version_is_inpaint(sd_ctx->sd->version)) { + int64_t mask_channels = 1; + if (sd_ctx->sd->version == VERSION_FLUX_FILL) { + mask_channels = 8 * 8; // flatten the whole mask + } + ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); + sd_apply_mask(init_img, mask_img, masked_img); + ggml_tensor* masked_image_0 = NULL; + if (!sd_ctx->sd->use_tiny_autoencoder) { + ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img); + masked_image_0 = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments); + } else { + masked_image_0 = sd_ctx->sd->encode_first_stage(work_ctx, masked_img); + } + masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_image_0->ne[0], masked_image_0->ne[1], mask_channels + masked_image_0->ne[2], 1); + for (int ix = 0; ix < masked_image_0->ne[0]; ix++) { + for (int iy = 0; iy < masked_image_0->ne[1]; iy++) { + int mx = ix * 8; + int my = iy * 8; + if (sd_ctx->sd->version == VERSION_FLUX_FILL) { + for (int k = 0; k < masked_image_0->ne[2]; k++) { + float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k); + ggml_tensor_set_f32(masked_image, v, ix, iy, k); + } + // "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image + for (int x = 0; x < 8; x++) { + for (int y = 0; y < 8; y++) { + float m = ggml_tensor_get_f32(mask_img, mx + x, my + y); + // TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?) + // python code was using "b (h 8) (w 8) -> b (8 8) h w" + ggml_tensor_set_f32(masked_image, m, ix, iy, masked_image_0->ne[2] + x * 8 + y); + } + } + } else { + float m = ggml_tensor_get_f32(mask_img, mx, my); + ggml_tensor_set_f32(masked_image, m, ix, iy, 0); + for (int k = 0; k < masked_image_0->ne[2]; k++) { + float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k); + ggml_tensor_set_f32(masked_image, v, ix, iy, k + mask_channels); + } + } + } + } + } else { + // LOG_WARN("Inpainting with a base model is not great"); + masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1); + for (int ix = 0; ix < masked_image->ne[0]; ix++) { + for (int iy = 0; iy < masked_image->ne[1]; iy++) { + int mx = ix * 8; + int my = iy * 8; + float m = ggml_tensor_get_f32(mask_img, mx, my); + ggml_tensor_set_f32(masked_image, m, ix, iy); + } + } + } + ggml_tensor* init_latent = NULL; if (!sd_ctx->sd->use_tiny_autoencoder) { ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img); @@ -1515,12 +1765,15 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, } else { init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); } + print_ggml_tensor(init_latent, true); size_t t1 = ggml_time_ms(); LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); size_t t_enc = static_cast(sample_steps * strength); + if (t_enc == sample_steps) + t_enc--; LOG_INFO("target t_enc is %zu steps", t_enc); std::vector sigma_sched; sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end()); @@ -1533,6 +1786,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, clip_skip, cfg_scale, guidance, + eta, width, height, sample_method, @@ -1543,11 +1797,16 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, control_strength, style_ratio, normalize_input, - input_id_images_path_c_str); + input_id_images_path_c_str, + skip_layers_vec, + slg_scale, + skip_layer_start, + skip_layer_end, + masked_image); size_t t2 = ggml_time_ms(); - LOG_INFO("img2img completed in %.2fs", (t1 - t0) * 1.0f / 1000); + LOG_INFO("img2img completed in %.2fs", (t2 - t0) * 1.0f / 1000); return result_images; } @@ -1641,6 +1900,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, min_cfg, cfg_scale, 0.f, + 0.f, sample_method, sigmas, -1, diff --git a/stable-diffusion.h b/stable-diffusion.h index 0d4cc1fda..52dcc848a 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -44,6 +44,8 @@ enum sample_method_t { IPNDM, IPNDM_V, LCM, + DDIM_TRAILING, + TCD, N_SAMPLE_METHODS }; @@ -59,41 +61,46 @@ enum schedule_t { // same as enum ggml_type enum sd_type_t { - SD_TYPE_F32 = 0, - SD_TYPE_F16 = 1, - SD_TYPE_Q4_0 = 2, - SD_TYPE_Q4_1 = 3, + SD_TYPE_F32 = 0, + SD_TYPE_F16 = 1, + SD_TYPE_Q4_0 = 2, + SD_TYPE_Q4_1 = 3, // SD_TYPE_Q4_2 = 4, support has been removed // SD_TYPE_Q4_3 = 5, support has been removed - SD_TYPE_Q5_0 = 6, - SD_TYPE_Q5_1 = 7, - SD_TYPE_Q8_0 = 8, - SD_TYPE_Q8_1 = 9, - SD_TYPE_Q2_K = 10, - SD_TYPE_Q3_K = 11, - SD_TYPE_Q4_K = 12, - SD_TYPE_Q5_K = 13, - SD_TYPE_Q6_K = 14, - SD_TYPE_Q8_K = 15, - SD_TYPE_IQ2_XXS = 16, - SD_TYPE_IQ2_XS = 17, - SD_TYPE_IQ3_XXS = 18, - SD_TYPE_IQ1_S = 19, - SD_TYPE_IQ4_NL = 20, - SD_TYPE_IQ3_S = 21, - SD_TYPE_IQ2_S = 22, - SD_TYPE_IQ4_XS = 23, - SD_TYPE_I8 = 24, - SD_TYPE_I16 = 25, - SD_TYPE_I32 = 26, - SD_TYPE_I64 = 27, - SD_TYPE_F64 = 28, - SD_TYPE_IQ1_M = 29, - SD_TYPE_BF16 = 30, - SD_TYPE_Q4_0_4_4 = 31, - SD_TYPE_Q4_0_4_8 = 32, - SD_TYPE_Q4_0_8_8 = 33, - SD_TYPE_COUNT, + SD_TYPE_Q5_0 = 6, + SD_TYPE_Q5_1 = 7, + SD_TYPE_Q8_0 = 8, + SD_TYPE_Q8_1 = 9, + SD_TYPE_Q2_K = 10, + SD_TYPE_Q3_K = 11, + SD_TYPE_Q4_K = 12, + SD_TYPE_Q5_K = 13, + SD_TYPE_Q6_K = 14, + SD_TYPE_Q8_K = 15, + SD_TYPE_IQ2_XXS = 16, + SD_TYPE_IQ2_XS = 17, + SD_TYPE_IQ3_XXS = 18, + SD_TYPE_IQ1_S = 19, + SD_TYPE_IQ4_NL = 20, + SD_TYPE_IQ3_S = 21, + SD_TYPE_IQ2_S = 22, + SD_TYPE_IQ4_XS = 23, + SD_TYPE_I8 = 24, + SD_TYPE_I16 = 25, + SD_TYPE_I32 = 26, + SD_TYPE_I64 = 27, + SD_TYPE_F64 = 28, + SD_TYPE_IQ1_M = 29, + SD_TYPE_BF16 = 30, + // SD_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files + // SD_TYPE_Q4_0_4_8 = 32, + // SD_TYPE_Q4_0_8_8 = 33, + SD_TYPE_TQ1_0 = 34, + SD_TYPE_TQ2_0 = 35, + // SD_TYPE_IQ4_NL_4_4 = 36, + // SD_TYPE_IQ4_NL_4_8 = 37, + // SD_TYPE_IQ4_NL_8_8 = 38, + SD_TYPE_COUNT = 39, }; SD_API const char* sd_type_name(enum sd_type_t type); @@ -124,6 +131,7 @@ typedef struct sd_ctx_t sd_ctx_t; SD_API sd_ctx_t* new_sd_ctx(const char* model_path, const char* clip_l_path, + const char* clip_g_path, const char* t5xxl_path, const char* diffusion_model_path, const char* vae_path, @@ -141,7 +149,8 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path, enum schedule_t s, bool keep_clip_on_cpu, bool keep_control_net_cpu, - bool keep_vae_on_cpu); + bool keep_vae_on_cpu, + bool diffusion_flash_attn); SD_API void free_sd_ctx(sd_ctx_t* sd_ctx); @@ -151,6 +160,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx, int clip_skip, float cfg_scale, float guidance, + float eta, int width, int height, enum sample_method_t sample_method, @@ -161,15 +171,22 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx, float control_strength, float style_strength, bool normalize_input, - const char* input_id_images_path); + const char* input_id_images_path, + int* skip_layers, + size_t skip_layers_count, + float slg_scale, + float skip_layer_start, + float skip_layer_end); SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, sd_image_t init_image, + sd_image_t mask_image, const char* prompt, const char* negative_prompt, int clip_skip, float cfg_scale, float guidance, + float eta, int width, int height, enum sample_method_t sample_method, @@ -181,7 +198,12 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, float control_strength, float style_strength, bool normalize_input, - const char* input_id_images_path); + const char* input_id_images_path, + int* skip_layers, + size_t skip_layers_count, + float slg_scale, + float skip_layer_start, + float skip_layer_end); SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, sd_image_t init_image, @@ -201,8 +223,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, typedef struct upscaler_ctx_t upscaler_ctx_t; SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path, - int n_threads, - enum sd_type_t wtype); + int n_threads); SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx); SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor); diff --git a/t5.hpp b/t5.hpp index 79109e34b..2a53e2743 100644 --- a/t5.hpp +++ b/t5.hpp @@ -441,8 +441,9 @@ class T5LayerNorm : public UnaryBlock { int64_t hidden_size; float eps; - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32; + params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size); } public: @@ -717,14 +718,15 @@ struct T5Runner : public GGMLRunner { std::vector relative_position_bucket_vec; T5Runner(ggml_backend_t backend, - ggml_type wtype, + std::map& tensor_types, + const std::string prefix, int64_t num_layers = 24, int64_t model_dim = 4096, int64_t ff_dim = 10240, int64_t num_heads = 64, int64_t vocab_size = 32128) - : GGMLRunner(backend, wtype), model(num_layers, model_dim, ff_dim, num_heads, vocab_size) { - model.init(params_ctx, wtype); + : GGMLRunner(backend), model(num_layers, model_dim, ff_dim, num_heads, vocab_size) { + model.init(params_ctx, tensor_types, prefix); } std::string get_desc() { @@ -854,14 +856,17 @@ struct T5Embedder { T5UniGramTokenizer tokenizer; T5Runner model; + static std::map empty_tensor_types; + T5Embedder(ggml_backend_t backend, - ggml_type wtype, - int64_t num_layers = 24, - int64_t model_dim = 4096, - int64_t ff_dim = 10240, - int64_t num_heads = 64, - int64_t vocab_size = 32128) - : model(backend, wtype, num_layers, model_dim, ff_dim, num_heads, vocab_size) { + std::map& tensor_types = empty_tensor_types, + const std::string prefix = "", + int64_t num_layers = 24, + int64_t model_dim = 4096, + int64_t ff_dim = 10240, + int64_t num_heads = 64, + int64_t vocab_size = 32128) + : model(backend, tensor_types, prefix, num_layers, model_dim, ff_dim, num_heads, vocab_size) { } void get_param_tensors(std::map& tensors, const std::string prefix) { @@ -951,7 +956,7 @@ struct T5Embedder { // ggml_backend_t backend = ggml_backend_cuda_init(0); ggml_backend_t backend = ggml_backend_cpu_init(); ggml_type model_data_type = GGML_TYPE_F32; - std::shared_ptr t5 = std::shared_ptr(new T5Embedder(backend, model_data_type)); + std::shared_ptr t5 = std::shared_ptr(new T5Embedder(backend)); { LOG_INFO("loading from '%s'", file_path.c_str()); diff --git a/tae.hpp b/tae.hpp index 0e03b884e..c458b87d2 100644 --- a/tae.hpp +++ b/tae.hpp @@ -62,7 +62,8 @@ class TinyEncoder : public UnaryBlock { int num_blocks = 3; public: - TinyEncoder() { + TinyEncoder(int z_channels = 4) + : z_channels(z_channels) { int index = 0; blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1})); blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels)); @@ -106,7 +107,10 @@ class TinyDecoder : public UnaryBlock { int num_blocks = 3; public: - TinyDecoder(int index = 0) { + TinyDecoder(int z_channels = 4) + : z_channels(z_channels) { + int index = 0; + blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(z_channels, channels, {3, 3}, {1, 1}, {1, 1})); index++; // nn.ReLU() @@ -163,12 +167,16 @@ class TAESD : public GGMLBlock { bool decode_only; public: - TAESD(bool decode_only = true) + TAESD(bool decode_only = true, SDVersion version = VERSION_SD1) : decode_only(decode_only) { - blocks["decoder.layers"] = std::shared_ptr(new TinyDecoder()); + int z_channels = 4; + if (sd_version_is_dit(version)) { + z_channels = 16; + } + blocks["decoder.layers"] = std::shared_ptr(new TinyDecoder(z_channels)); if (!decode_only) { - blocks["encoder.layers"] = std::shared_ptr(new TinyEncoder()); + blocks["encoder.layers"] = std::shared_ptr(new TinyEncoder(z_channels)); } } @@ -188,12 +196,14 @@ struct TinyAutoEncoder : public GGMLRunner { bool decode_only = false; TinyAutoEncoder(ggml_backend_t backend, - ggml_type wtype, - bool decoder_only = true) + std::map& tensor_types, + const std::string prefix, + bool decoder_only = true, + SDVersion version = VERSION_SD1) : decode_only(decoder_only), - taesd(decode_only), - GGMLRunner(backend, wtype) { - taesd.init(params_ctx, wtype); + taesd(decoder_only, version), + GGMLRunner(backend) { + taesd.init(params_ctx, tensor_types, prefix); } std::string get_desc() { diff --git a/thirdparty/stb_image_write.h b/thirdparty/stb_image_write.h index 5589a7ec2..55118853e 100644 --- a/thirdparty/stb_image_write.h +++ b/thirdparty/stb_image_write.h @@ -177,7 +177,7 @@ STBIWDEF int stbi_write_png(char const *filename, int w, int h, int comp, const STBIWDEF int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data); STBIWDEF int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data); STBIWDEF int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data); -STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality); +STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality, const char* parameters = NULL); #ifdef STBIW_WINDOWS_UTF8 STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input); @@ -1412,7 +1412,7 @@ static int stbiw__jpg_processDU(stbi__write_context *s, int *bitBuf, int *bitCnt return DU[0]; } -static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, int comp, const void* data, int quality) { +static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, int comp, const void* data, int quality, const char* parameters) { // Constants that don't pollute global namespace static const unsigned char std_dc_luminance_nrcodes[] = {0,0,1,5,1,1,1,1,1,1,0,0,0,0,0,0,0}; static const unsigned char std_dc_luminance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11}; @@ -1521,6 +1521,20 @@ static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, in s->func(s->context, (void*)YTable, sizeof(YTable)); stbiw__putc(s, 1); s->func(s->context, UVTable, sizeof(UVTable)); + + // comment block with parameters of generation + if(parameters != NULL) { + stbiw__putc(s, 0xFF /* comnent */ ); + stbiw__putc(s, 0xFE /* marker */ ); + size_t param_length = std::min(2 + strlen("parameters") + 1 + strlen(parameters) + 1, (size_t) 0xFFFF); + stbiw__putc(s, param_length >> 8); // no need to mask, length < 65536 + stbiw__putc(s, param_length & 0xFF); + s->func(s->context, (void*)"parameters", strlen("parameters") + 1); // std::string is zero-terminated + s->func(s->context, (void*)parameters, std::min(param_length, (size_t) 65534) - 2 - strlen("parameters") - 1); + if(param_length > 65534) stbiw__putc(s, 0); // always zero-terminate for safety + if(param_length & 1) stbiw__putc(s, 0xFF); // pad to even length + } + s->func(s->context, (void*)head1, sizeof(head1)); s->func(s->context, (void*)(std_dc_luminance_nrcodes+1), sizeof(std_dc_luminance_nrcodes)-1); s->func(s->context, (void*)std_dc_luminance_values, sizeof(std_dc_luminance_values)); @@ -1625,16 +1639,16 @@ STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, { stbi__write_context s = { 0 }; stbi__start_write_callbacks(&s, func, context); - return stbi_write_jpg_core(&s, x, y, comp, (void *) data, quality); + return stbi_write_jpg_core(&s, x, y, comp, (void *) data, quality, NULL); } #ifndef STBI_WRITE_NO_STDIO -STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality) +STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality, const char* parameters) { stbi__write_context s = { 0 }; if (stbi__start_write_file(&s,filename)) { - int r = stbi_write_jpg_core(&s, x, y, comp, data, quality); + int r = stbi_write_jpg_core(&s, x, y, comp, data, quality, parameters); stbi__end_write_file(&s); return r; } else diff --git a/unet.hpp b/unet.hpp index 94a8ba46a..31b7fe986 100644 --- a/unet.hpp +++ b/unet.hpp @@ -166,6 +166,7 @@ class SpatialVideoTransformer : public SpatialTransformer { // ldm.modules.diffusionmodules.openaimodel.UNetModel class UnetModelBlock : public GGMLBlock { protected: + static std::map empty_tensor_types; SDVersion version = VERSION_SD1; // network hparams int in_channels = 4; @@ -183,13 +184,13 @@ class UnetModelBlock : public GGMLBlock { int model_channels = 320; int adm_in_channels = 2816; // only for VERSION_SDXL/SVD - UnetModelBlock(SDVersion version = VERSION_SD1) + UnetModelBlock(SDVersion version = VERSION_SD1, std::map& tensor_types = empty_tensor_types, bool flash_attn = false) : version(version) { - if (version == VERSION_SD2) { + if (sd_version_is_sd2(version)) { context_dim = 1024; num_head_channels = 64; num_heads = -1; - } else if (version == VERSION_SDXL) { + } else if (sd_version_is_sdxl(version)) { context_dim = 2048; attention_resolutions = {4, 2}; channel_mult = {1, 2, 4}; @@ -204,6 +205,10 @@ class UnetModelBlock : public GGMLBlock { num_head_channels = 64; num_heads = -1; } + if (sd_version_is_inpaint(version)) { + in_channels = 9; + } + // dims is always 2 // use_temporal_attention is always True for SVD @@ -211,7 +216,7 @@ class UnetModelBlock : public GGMLBlock { // time_embed_1 is nn.SiLU() blocks["time_embed.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); - if (version == VERSION_SDXL || version == VERSION_SVD) { + if (sd_version_is_sdxl(version) || version == VERSION_SVD) { blocks["label_emb.0.0"] = std::shared_ptr(new Linear(adm_in_channels, time_embed_dim)); // label_emb_1 is nn.SiLU() blocks["label_emb.0.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); @@ -242,7 +247,7 @@ class UnetModelBlock : public GGMLBlock { if (version == VERSION_SVD) { return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim); } else { - return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim); + return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, flash_attn); } }; @@ -532,10 +537,12 @@ struct UNetModelRunner : public GGMLRunner { UnetModelBlock unet; UNetModelRunner(ggml_backend_t backend, - ggml_type wtype, - SDVersion version = VERSION_SD1) - : GGMLRunner(backend, wtype), unet(version) { - unet.init(params_ctx, wtype); + std::map& tensor_types, + const std::string prefix, + SDVersion version = VERSION_SD1, + bool flash_attn = false) + : GGMLRunner(backend), unet(version, tensor_types, flash_attn) { + unet.init(params_ctx, tensor_types, prefix); } std::string get_desc() { @@ -564,6 +571,7 @@ struct UNetModelRunner : public GGMLRunner { context = to_backend(context); y = to_backend(y); timesteps = to_backend(timesteps); + c_concat = to_backend(c_concat); for (int i = 0; i < controls.size(); i++) { controls[i] = to_backend(controls[i]); @@ -649,4 +657,4 @@ struct UNetModelRunner : public GGMLRunner { } }; -#endif // __UNET_HPP__ \ No newline at end of file +#endif // __UNET_HPP__ diff --git a/upscaler.cpp b/upscaler.cpp index 096352993..0c11b666e 100644 --- a/upscaler.cpp +++ b/upscaler.cpp @@ -15,13 +15,13 @@ struct UpscalerGGML { } bool load_from_file(const std::string& esrgan_path) { -#ifdef SD_USE_CUBLAS +#ifdef SD_USE_CUDA LOG_DEBUG("Using CUDA backend"); backend = ggml_backend_cuda_init(0); #endif #ifdef SD_USE_METAL LOG_DEBUG("Using Metal backend"); - ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr); + ggml_log_set(ggml_log_callback_default, nullptr); backend = ggml_backend_metal_init(); #endif #ifdef SD_USE_VULKAN @@ -32,13 +32,17 @@ struct UpscalerGGML { LOG_DEBUG("Using SYCL backend"); backend = ggml_backend_sycl_init(0); #endif - + ModelLoader model_loader; + if (!model_loader.init_from_file(esrgan_path)) { + LOG_ERROR("init model loader from file failed: '%s'", esrgan_path.c_str()); + } + model_loader.set_wtype_override(model_data_type); if (!backend) { LOG_DEBUG("Using CPU backend"); backend = ggml_backend_cpu_init(); } LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type)); - esrgan_upscaler = std::make_shared(backend, model_data_type); + esrgan_upscaler = std::make_shared(backend, model_loader.tensor_storages_types); if (!esrgan_upscaler->load_from_file(esrgan_path)) { return false; } @@ -96,8 +100,7 @@ struct upscaler_ctx_t { }; upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str, - int n_threads, - enum sd_type_t wtype) { + int n_threads) { upscaler_ctx_t* upscaler_ctx = (upscaler_ctx_t*)malloc(sizeof(upscaler_ctx_t)); if (upscaler_ctx == NULL) { return NULL; diff --git a/util.cpp b/util.cpp index 5de5ce26e..da11a14d6 100644 --- a/util.cpp +++ b/util.cpp @@ -22,6 +22,7 @@ #include #endif +#include "ggml-cpu.h" #include "ggml.h" #include "stable-diffusion.h" @@ -112,18 +113,31 @@ std::vector get_files_from_dir(const std::string& dir) { // Find the first file in the directory hFind = FindFirstFile(directoryPath, &findFileData); - + bool isAbsolutePath = false; // Check if the directory was found if (hFind == INVALID_HANDLE_VALUE) { - printf("Unable to find directory.\n"); - return files; + printf("Unable to find directory. Try with original path \n"); + + char directoryPathAbsolute[MAX_PATH]; + sprintf(directoryPathAbsolute, "%s*", dir.c_str()); + + hFind = FindFirstFile(directoryPathAbsolute, &findFileData); + isAbsolutePath = true; + if (hFind == INVALID_HANDLE_VALUE) { + printf("Absolute path was also wrong.\n"); + return files; + } } // Loop through all files in the directory do { // Check if the found file is a regular file (not a directory) if (!(findFileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)) { - files.push_back(std::string(currentDirectory) + "\\" + dir + "\\" + std::string(findFileData.cFileName)); + if (isAbsolutePath) { + files.push_back(dir + "\\" + std::string(findFileData.cFileName)); + } else { + files.push_back(std::string(currentDirectory) + "\\" + dir + "\\" + std::string(findFileData.cFileName)); + } } } while (FindNextFile(hFind, &findFileData) != 0); @@ -276,6 +290,23 @@ std::string path_join(const std::string& p1, const std::string& p2) { return p1 + "/" + p2; } +std::vector splitString(const std::string& str, char delimiter) { + std::vector result; + size_t start = 0; + size_t end = str.find(delimiter); + + while (end != std::string::npos) { + result.push_back(str.substr(start, end - start)); + start = end + 1; + end = str.find(delimiter, start); + } + + // Add the last segment after the last delimiter + result.push_back(str.substr(start)); + + return result; +} + sd_image_t* preprocess_id_image(sd_image_t* img) { int shortest_edge = 224; int size = shortest_edge; @@ -330,7 +361,7 @@ void pretty_progress(int step, int steps, float time) { } } progress += "|"; - printf(time > 1.0f ? "\r%s %i/%i - %.2fs/it" : "\r%s %i/%i - %.2fit/s", + printf(time > 1.0f ? "\r%s %i/%i - %.2fs/it" : "\r%s %i/%i - %.2fit/s\033[K", progress.c_str(), step, steps, time > 1.0f || time == 0 ? time : (1.0f / time)); fflush(stdout); // for linux @@ -393,7 +424,6 @@ const char* sd_get_system_info() { static char buffer[1024]; std::stringstream ss; ss << "System Info: \n"; - ss << " BLAS = " << ggml_cpu_has_blas() << std::endl; ss << " SSE3 = " << ggml_cpu_has_sse3() << std::endl; ss << " AVX = " << ggml_cpu_has_avx() << std::endl; ss << " AVX2 = " << ggml_cpu_has_avx2() << std::endl; diff --git a/util.h b/util.h index 9b1e6734f..14fa812e5 100644 --- a/util.h +++ b/util.h @@ -45,7 +45,7 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size); std::string path_join(const std::string& p1, const std::string& p2); - +std::vector splitString(const std::string& str, char delimiter); void pretty_progress(int step, int steps, float time); void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...); diff --git a/vae.hpp b/vae.hpp index 85319fdee..4add881f6 100644 --- a/vae.hpp +++ b/vae.hpp @@ -163,8 +163,9 @@ class AE3DConv : public Conv2d { class VideoResnetBlock : public ResnetBlock { protected: - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["mix_factor"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type wtype = (tensor_types.find(prefix + "mix_factor") != tensor_types.end()) ? tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32; + params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1); } float get_alpha() { @@ -457,7 +458,7 @@ class AutoencodingEngine : public GGMLBlock { bool use_video_decoder = false, SDVersion version = VERSION_SD1) : decode_only(decode_only), use_video_decoder(use_video_decoder) { - if (version == VERSION_SD3_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + if (sd_version_is_dit(version)) { dd_config.z_channels = 16; use_quant = false; } @@ -524,12 +525,13 @@ struct AutoEncoderKL : public GGMLRunner { AutoencodingEngine ae; AutoEncoderKL(ggml_backend_t backend, - ggml_type wtype, + std::map& tensor_types, + const std::string prefix, bool decode_only = false, bool use_video_decoder = false, SDVersion version = VERSION_SD1) - : decode_only(decode_only), ae(decode_only, use_video_decoder, version), GGMLRunner(backend, wtype) { - ae.init(params_ctx, wtype); + : decode_only(decode_only), ae(decode_only, use_video_decoder, version), GGMLRunner(backend) { + ae.init(params_ctx, tensor_types, prefix); } std::string get_desc() { @@ -612,4 +614,4 @@ struct AutoEncoderKL : public GGMLRunner { }; }; -#endif \ No newline at end of file +#endif