diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a16e692ec..4112ae9bb 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -146,7 +146,7 @@ jobs: sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}.zip windows-latest-cmake: - runs-on: windows-2019 + runs-on: windows-2025 env: VULKAN_VERSION: 1.3.261.1 diff --git a/CMakeLists.txt b/CMakeLists.txt index 782a893e4..06de0d58b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,6 +28,7 @@ option(SD_CUDA "sd: cuda backend" OFF) option(SD_HIPBLAS "sd: rocm backend" OFF) option(SD_METAL "sd: metal backend" OFF) option(SD_VULKAN "sd: vulkan backend" OFF) +option(SD_OPENCL "sd: opencl backend" OFF) option(SD_SYCL "sd: sycl backend" OFF) option(SD_MUSA "sd: musa backend" OFF) option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF) @@ -52,6 +53,12 @@ if (SD_VULKAN) add_definitions(-DSD_USE_VULKAN) endif () +if (SD_OPENCL) + message("-- Use OpenCL as backend stable-diffusion") + set(GGML_OPENCL ON) + add_definitions(-DSD_USE_OPENCL) +endif () + if (SD_HIPBLAS) message("-- Use HIPBLAS as backend stable-diffusion") set(GGML_HIP ON) diff --git a/Dockerfile.musa b/Dockerfile.musa index 0adcb7ee5..c7f5f2e83 100644 --- a/Dockerfile.musa +++ b/Dockerfile.musa @@ -2,14 +2,17 @@ ARG MUSA_VERSION=rc3.1.1 FROM mthreads/musa:${MUSA_VERSION}-devel-ubuntu22.04 as build -RUN apt-get update && apt-get install -y cmake +RUN apt-get update && apt-get install -y ccache cmake git WORKDIR /sd.cpp COPY . . RUN mkdir build && cd build && \ - cmake .. -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_MUSA=ON -DCMAKE_BUILD_TYPE=Release && \ + cmake .. -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_FLAGS="${CMAKE_C_FLAGS} -fopenmp -I/usr/lib/llvm-14/lib/clang/14.0.0/include -L/usr/lib/llvm-14/lib" \ + -DCMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS} -fopenmp -I/usr/lib/llvm-14/lib/clang/14.0.0/include -L/usr/lib/llvm-14/lib" \ + -DSD_MUSA=ON -DCMAKE_BUILD_TYPE=Release && \ cmake --build . --config Release FROM mthreads/musa:${MUSA_VERSION}-runtime-ubuntu22.04 as runtime diff --git a/README.md b/README.md index 553fb7f8f..4720dc29c 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,8 @@ Inference of Stable Diffusion and Flux in pure C/C++ - SD1.x, SD2.x, SDXL and [SD3/SD3.5](./docs/sd3.md) support - !!!The VAE in SDXL encounters NaN issues under FP16, but unfortunately, the ggml_conv_2d only operates under FP16. Hence, a parameter is needed to specify the VAE that has fixed the FP16 NaN issue. You can find it here: [SDXL VAE FP16 Fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors). - [Flux-dev/Flux-schnell Support](./docs/flux.md) - +- [FLUX.1-Kontext-dev](./docs/kontext.md) +- [Chroma](./docs/chroma.md) - [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) and [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) support - [PhotoMaker](https://github.com/TencentARC/PhotoMaker) support. - 16-bit, 32-bit float support @@ -21,7 +22,7 @@ Inference of Stable Diffusion and Flux in pure C/C++ - Accelerated memory-efficient CPU inference - Only requires ~2.3GB when using txt2img with fp16 precision to generate a 512x512 image, enabling Flash Attention just requires ~1.8GB. 
 - AVX, AVX2 and AVX512 support for x86 architectures
-- Full CUDA, Metal, Vulkan and SYCL backend for GPU acceleration.
+- Full CUDA, Metal, Vulkan, OpenCL and SYCL backends for GPU acceleration.
 - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAE models
 - No need to convert to `.ggml` or `.gguf` anymore!
 - Flash Attention for memory usage optimization
@@ -49,7 +50,7 @@ Inference of Stable Diffusion and Flux in pure C/C++
     - Linux
     - Mac OS
     - Windows
-    - Android (via Termux)
+    - Android (via Termux, [Local Diffusion](https://github.com/rmatif/Local-Diffusion))
 
 ### TODO
 
@@ -159,6 +160,73 @@ cmake .. -DSD_VULKAN=ON
 cmake --build . --config Release
 ```
 
+##### Using OpenCL (for Adreno GPU)
+
+Currently, it supports only Adreno GPUs and is primarily optimized for the Q4_0 type.
+
+To build for Windows ARM, please refer to [Windows 11 Arm64](https://github.com/ggml-org/llama.cpp/blob/master/docs/backend/OPENCL.md#windows-11-arm64).
+
+Building for Android:
+
+Android NDK:
+Download and install the Android NDK from the [official Android developer site](https://developer.android.com/ndk/downloads).
+
+Set up the OpenCL dependencies for the NDK:
+
+You need to provide OpenCL headers and the ICD loader library to your NDK sysroot.
+
+* OpenCL Headers:
+  ```bash
+  # In a temporary working directory
+  git clone https://github.com/KhronosGroup/OpenCL-Headers
+  cd OpenCL-Headers
+  # Replace <ndk_path> with your actual NDK installation path
+  # e.g., cp -r CL /path/to/android-ndk-r26c/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
+  sudo cp -r CL <ndk_path>/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
+  cd ..
+  ```
+
+* OpenCL ICD Loader:
+  ```bash
+  # In the same temporary working directory
+  git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader
+  cd OpenCL-ICD-Loader
+  mkdir build_ndk && cd build_ndk
+
+  # Replace <ndk_path> in CMAKE_TOOLCHAIN_FILE and OPENCL_ICD_LOADER_HEADERS_DIR
+  cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release \
+    -DCMAKE_TOOLCHAIN_FILE=<ndk_path>/build/cmake/android.toolchain.cmake \
+    -DOPENCL_ICD_LOADER_HEADERS_DIR=<ndk_path>/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include \
+    -DANDROID_ABI=arm64-v8a \
+    -DANDROID_PLATFORM=24 \
+    -DANDROID_STL=c++_shared
+
+  ninja
+  # Replace <ndk_path>
+  # e.g., cp libOpenCL.so /path/to/android-ndk-r26c/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android
+  sudo cp libOpenCL.so <ndk_path>/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android
+  cd ../..
+  ```
+
+Build `stable-diffusion.cpp` for Android with OpenCL:
+
+```bash
+mkdir build-android && cd build-android
+
+# Replace <ndk_path> with your actual NDK installation path
+# e.g., -DCMAKE_TOOLCHAIN_FILE=/path/to/android-ndk-r26c/build/cmake/android.toolchain.cmake
+cmake .. -G Ninja \
+  -DCMAKE_TOOLCHAIN_FILE=<ndk_path>/build/cmake/android.toolchain.cmake \
+  -DANDROID_ABI=arm64-v8a \
+  -DANDROID_PLATFORM=android-28 \
+  -DGGML_OPENMP=OFF \
+  -DSD_OPENCL=ON
+
+ninja
+```
+*(Note: don't forget to include `LD_LIBRARY_PATH=/vendor/lib64` in your command line before running the binary.)*
+
 ##### Using SYCL
 
 Using SYCL makes the computation run on the Intel GPU. Please make sure you have installed the related driver and [Intel® oneAPI Base toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) before you start. More details and steps can be found in the [llama.cpp SYCL backend](https://github.com/ggerganov/llama.cpp/blob/master/docs/backend/SYCL.md#linux) docs.
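Since `-DSD_OPENCL=ON` adds the `SD_USE_OPENCL` compile definition (see the `CMakeLists.txt` hunk above, which mirrors the existing `SD_USE_VULKAN` pattern), a quick way to confirm which backend a build was configured with is a compile-time check. A minimal sketch, not part of the repo; the macros are the ones defined by this patch's build flags, and the program itself is purely illustrative:

```cpp
// backend_check.cpp — illustrative only. Compile with the same definitions
// as the main build, e.g.: g++ -DSD_USE_OPENCL backend_check.cpp
#include <cstdio>

int main() {
#if defined(SD_USE_OPENCL)
    std::printf("built with the OpenCL backend enabled\n");  // set by -DSD_OPENCL=ON
#elif defined(SD_USE_VULKAN)
    std::printf("built with the Vulkan backend enabled\n");  // set by -DSD_VULKAN=ON
#else
    std::printf("no GPU backend definition set (CPU build)\n");
#endif
    return 0;
}
```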
@@ -220,7 +288,7 @@ arguments:
   -m, --model [MODEL]                path to full model
   --diffusion-model                  path to the standalone diffusion model
   --clip_l                           path to the clip-l text encoder
-  --clip_g                           path to the clip-l text encoder
+  --clip_g                           path to the clip-g text encoder
   --t5xxl                            path to the t5xxl text encoder
   --vae [VAE]                        path to vae
   --taesd [TAESD_PATH]               path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
@@ -231,26 +299,32 @@ arguments:
   --normalize-input                  normalize PHOTOMAKER input id images
   --upscale-model [ESRGAN_PATH]      path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
   --upscale-repeats                  Run the ESRGAN upscaler this many times (default 1)
-  --type [TYPE]                      weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k)
+  --type [TYPE]                      weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)
                                      If not specified, the default is the type of the weight file
   --lora-model-dir [DIR]             lora model directory
   -i, --init-img [IMAGE]             path to the input image, required by img2img
+  --mask [MASK]                      path to the mask image, required by img2img with mask
   --control-image [IMAGE]            path to image condition, control net
+  -r, --ref-image [PATH]             reference image for Flux Kontext models (can be used multiple times)
   -o, --output OUTPUT                path to write result image to (default: ./output.png)
   -p, --prompt [PROMPT]              the prompt to render
   -n, --negative-prompt PROMPT       the negative prompt (default: "")
   --cfg-scale SCALE                  unconditional guidance scale: (default: 7.0)
+  --guidance SCALE                   guidance scale for img2img (default: 3.5)
+  --slg-scale SCALE                  skip layer guidance (SLG) scale, only for DiT models: (default: 0)
+                                     0 means disabled, a value of 2.5 is nice for sd3.5 medium
+  --eta SCALE                        eta in DDIM, only for DDIM and TCD: (default: 0)
   --skip-layers LAYERS               Layers to skip for SLG steps: (default: [7,8,9])
   --skip-layer-start START           SLG enabling point: (default: 0.01)
   --skip-layer-end END               SLG disabling point: (default: 0.2)
-                                     SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])
+                                     SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])
   --strength STRENGTH                strength for noising/unnoising (default: 0.75)
   --style-ratio STYLE-RATIO          strength for keeping input identity (default: 20%)
   --control-strength STRENGTH        strength to apply Control Net (default: 0.9)
                                      1.0 corresponds to full destruction of information in init image
   -H, --height H                     image height, in pixel space (default: 512)
   -W, --width W                      image width, in pixel space (default: 512)
-  --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm}
+  --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}
                                      sampling method (default: "euler_a")
   --steps STEPS                      number of sample steps (default: 20)
   --rng {std_default, cuda}          RNG (default: cuda)
@@ -267,7 +341,10 @@ arguments:
                                      This might crash if it is not supported by the backend.
   --control-net-cpu                  keep controlnet in cpu (for low vram)
   --canny                            apply canny preprocessor (edge detection)
-  --color                            Colors the logging tags according to level
+  --color                            colors the logging tags according to level
+  --chroma-disable-dit-mask          disable dit mask for chroma
+  --chroma-enable-t5-mask            enable t5 mask for chroma
+  --chroma-t5-mask-pad PAD_SIZE      t5 mask pad size of chroma
   -v, --verbose                      print extra info
 ```
@@ -315,10 +392,12 @@ Using formats of different precisions will yield results of varying quality.
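The SLG window in the help text above is expressed as fractions of the total step count, with `int()` truncation. A small sketch of that arithmetic, using the defaults listed above (the variable names are illustrative, not from the repo):

```cpp
// slg_window.cpp — sketch of the SLG enable/disable window per the help text:
// enabled at step int(steps * start), disabled at step int(steps * end).
#include <cstdio>

int main() {
    int steps   = 20;     // --steps
    float start = 0.01f;  // --skip-layer-start
    float end   = 0.2f;   // --skip-layer-end

    int enable_step  = (int)(steps * start);  // int(20 * 0.01) = 0
    int disable_step = (int)(steps * end);    // int(20 * 0.2)  = 4

    std::printf("SLG active for steps [%d, %d)\n", enable_step, disable_step);
    return 0;
}
```

So with the defaults, the skipped-layer guidance is only applied during the first fifth of the sampling run.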
These projects wrap `stable-diffusion.cpp` for easier use in other languages/frameworks. -* Golang: [seasonjs/stable-diffusion](https://github.com/seasonjs/stable-diffusion) +* Golang (non-cgo): [seasonjs/stable-diffusion](https://github.com/seasonjs/stable-diffusion) +* Golang (cgo): [Binozo/GoStableDiffusion](https://github.com/Binozo/GoStableDiffusion) * C#: [DarthAffe/StableDiffusion.NET](https://github.com/DarthAffe/StableDiffusion.NET) * Python: [william-murray1204/stable-diffusion-cpp-python](https://github.com/william-murray1204/stable-diffusion-cpp-python) * Rust: [newfla/diffusion-rs](https://github.com/newfla/diffusion-rs) +* Flutter/Dart: [rmatif/Local-Diffusion](https://github.com/rmatif/Local-Diffusion) ## UIs @@ -327,6 +406,7 @@ These projects use `stable-diffusion.cpp` as a backend for their image generatio - [Jellybox](https://jellybox.com) - [Stable Diffusion GUI](https://github.com/fszontagh/sd.cpp.gui.wx) - [Stable Diffusion CLI-GUI](https://github.com/piallai/stable-diffusion.cpp) +- [Local Diffusion](https://github.com/rmatif/Local-Diffusion) ## Contributors diff --git a/assets/flux/chroma_v40.png b/assets/flux/chroma_v40.png new file mode 100644 index 000000000..4217009dc Binary files /dev/null and b/assets/flux/chroma_v40.png differ diff --git a/assets/flux/kontext1_dev_output.png b/assets/flux/kontext1_dev_output.png new file mode 100644 index 000000000..4fa5e38dd Binary files /dev/null and b/assets/flux/kontext1_dev_output.png differ diff --git a/clip.hpp b/clip.hpp index 2307ee3c5..d359f61cd 100644 --- a/clip.hpp +++ b/clip.hpp @@ -678,8 +678,8 @@ class CLIPTextModel : public GGMLBlock { bool with_final_ln = true; CLIPTextModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14, - int clip_skip_value = -1, - bool with_final_ln = true) + bool with_final_ln = true, + int clip_skip_value = -1) : version(version), with_final_ln(with_final_ln) { if (version == OPEN_CLIP_VIT_H_14) { hidden_size = 1024; @@ -701,7 +701,7 @@ class CLIPTextModel : public GGMLBlock { void set_clip_skip(int skip) { if (skip <= 0) { - return; + skip = -1; } clip_skip = skip; } @@ -871,9 +871,9 @@ struct CLIPTextModelRunner : public GGMLRunner { std::map& tensor_types, const std::string prefix, CLIPVersion version = OPENAI_CLIP_VIT_L_14, - int clip_skip_value = 1, - bool with_final_ln = true) - : GGMLRunner(backend), model(version, clip_skip_value, with_final_ln) { + bool with_final_ln = true, + int clip_skip_value = -1) + : GGMLRunner(backend), model(version, with_final_ln, clip_skip_value) { model.init(params_ctx, tensor_types, prefix); } diff --git a/common.hpp b/common.hpp index 337b4a0c4..9b5cc53be 100644 --- a/common.hpp +++ b/common.hpp @@ -56,8 +56,8 @@ class UpSampleBlock : public GGMLBlock { // x: [N, channels, h, w] auto conv = std::dynamic_pointer_cast(blocks["conv"]); - x = ggml_upscale(ctx, x, 2); // [N, channels, h*2, w*2] - x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2] + x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST); // [N, channels, h*2, w*2] + x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2] return x; } }; diff --git a/conditioner.hpp b/conditioner.hpp index 6e9acdb19..3f89d5263 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -63,23 +63,24 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { PMVersion pv = PM_VERSION_1, int clip_skip = -1) : version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 
0 : 49407), embd_dir(embd_dir) { - if (clip_skip <= 0) { - clip_skip = 1; - if (sd_version_is_sd2(version) || sd_version_is_sdxl(version)) { - clip_skip = 2; - } - } if (sd_version_is_sd1(version)) { - text_model = std::make_shared(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip); + text_model = std::make_shared(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14); } else if (sd_version_is_sd2(version)) { - text_model = std::make_shared(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, clip_skip); + text_model = std::make_shared(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14); } else if (sd_version_is_sdxl(version)) { - text_model = std::make_shared(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false); - text_model2 = std::make_shared(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false); + text_model = std::make_shared(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, false); + text_model2 = std::make_shared(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false); } + set_clip_skip(clip_skip); } void set_clip_skip(int clip_skip) { + if (clip_skip <= 0) { + clip_skip = 1; + if (sd_version_is_sd2(version) || sd_version_is_sdxl(version)) { + clip_skip = 2; + } + } text_model->set_clip_skip(clip_skip); if (sd_version_is_sdxl(version)) { text_model2->set_clip_skip(clip_skip); @@ -458,8 +459,8 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { if (sd_version_is_sdxl(version)) { text_model2->compute(n_threads, input_ids2, - 0, - NULL, + num_custom_embeddings, + token_embed_custom.data(), max_token_idx, false, &chunk_hidden_states2, work_ctx); @@ -469,8 +470,8 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { if (chunk_idx == 0) { text_model2->compute(n_threads, input_ids2, - 0, - NULL, + num_custom_embeddings, + token_embed_custom.data(), max_token_idx, true, &pooled, @@ -665,15 +666,16 @@ struct SD3CLIPEmbedder : public Conditioner { std::map& tensor_types, int clip_skip = -1) : clip_g_tokenizer(0) { - if (clip_skip <= 0) { - clip_skip = 2; - } - clip_l = std::make_shared(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false); - clip_g = std::make_shared(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false); + clip_l = std::make_shared(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false); + clip_g = std::make_shared(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false); t5 = std::make_shared(backend, tensor_types, "text_encoders.t5xxl.transformer"); + set_clip_skip(clip_skip); } void set_clip_skip(int clip_skip) { + if (clip_skip <= 0) { + clip_skip = 2; + } clip_l->set_clip_skip(clip_skip); clip_g->set_clip_skip(clip_skip); } @@ -747,7 +749,7 @@ struct SD3CLIPEmbedder : public Conditioner { clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding); clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding); - t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding); + t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding); // for (int i = 0; i < 
clip_l_tokens.size(); i++) { // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", "; @@ -902,6 +904,7 @@ struct SD3CLIPEmbedder : public Conditioner { t5->compute(n_threads, input_ids, + NULL, &chunk_hidden_states_t5, work_ctx); { @@ -1004,18 +1007,20 @@ struct FluxCLIPEmbedder : public Conditioner { T5UniGramTokenizer t5_tokenizer; std::shared_ptr clip_l; std::shared_ptr t5; + size_t chunk_len = 256; FluxCLIPEmbedder(ggml_backend_t backend, std::map& tensor_types, int clip_skip = -1) { - if (clip_skip <= 0) { - clip_skip = 2; - } - clip_l = std::make_shared(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, true); + clip_l = std::make_shared(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true); t5 = std::make_shared(backend, tensor_types, "text_encoders.t5xxl.transformer"); + set_clip_skip(clip_skip); } void set_clip_skip(int clip_skip) { + if (clip_skip <= 0) { + clip_skip = 2; + } clip_l->set_clip_skip(clip_skip); } @@ -1077,7 +1082,7 @@ struct FluxCLIPEmbedder : public Conditioner { } clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding); - t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding); + t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding); // for (int i = 0; i < clip_l_tokens.size(); i++) { // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", "; @@ -1109,7 +1114,6 @@ struct FluxCLIPEmbedder : public Conditioner { struct ggml_tensor* pooled = NULL; // [768,] std::vector hidden_states_vec; - size_t chunk_len = 256; size_t chunk_count = t5_tokens.size() / chunk_len; for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) { // clip_l @@ -1147,6 +1151,7 @@ struct FluxCLIPEmbedder : public Conditioner { t5->compute(n_threads, input_ids, + NULL, &chunk_hidden_states, work_ctx); { @@ -1196,7 +1201,209 @@ struct FluxCLIPEmbedder : public Conditioner { int height, int adm_in_channels = -1, bool force_zero_embeddings = false) { - auto tokens_and_weights = tokenize(text, 256, true); + auto tokens_and_weights = tokenize(text, chunk_len, true); + return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings); + } + + std::tuple> get_learned_condition_with_trigger(ggml_context* work_ctx, + int n_threads, + const std::string& text, + int clip_skip, + int width, + int height, + int num_input_imgs, + int adm_in_channels = -1, + bool force_zero_embeddings = false) { + GGML_ASSERT(0 && "Not implemented yet!"); + } + + std::string remove_trigger_from_prompt(ggml_context* work_ctx, + const std::string& prompt) { + GGML_ASSERT(0 && "Not implemented yet!"); + } +}; + +struct PixArtCLIPEmbedder : public Conditioner { + T5UniGramTokenizer t5_tokenizer; + std::shared_ptr t5; + size_t chunk_len = 512; + bool use_mask = false; + int mask_pad = 1; + + PixArtCLIPEmbedder(ggml_backend_t backend, + std::map& tensor_types, + int clip_skip = -1, + bool use_mask = false, + int mask_pad = 1) + : use_mask(use_mask), mask_pad(mask_pad) { + t5 = std::make_shared(backend, tensor_types, "text_encoders.t5xxl.transformer"); + } + + void set_clip_skip(int clip_skip) { + } + + void get_param_tensors(std::map& tensors) { + t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer"); + } + + void alloc_params_buffer() { + t5->alloc_params_buffer(); + } + + void free_params_buffer() { + t5->free_params_buffer(); + } + + size_t get_params_buffer_size() { + size_t buffer_size = 0; + + 
buffer_size += t5->get_params_buffer_size(); + + return buffer_size; + } + + std::tuple, std::vector, std::vector> tokenize(std::string text, + size_t max_length = 0, + bool padding = false) { + auto parsed_attention = parse_prompt_attention(text); + + { + std::stringstream ss; + ss << "["; + for (const auto& item : parsed_attention) { + ss << "['" << item.first << "', " << item.second << "], "; + } + ss << "]"; + LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); + } + + auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool { + return false; + }; + + std::vector t5_tokens; + std::vector t5_weights; + std::vector t5_mask; + for (const auto& item : parsed_attention) { + const std::string& curr_text = item.first; + float curr_weight = item.second; + + std::vector curr_tokens = t5_tokenizer.Encode(curr_text, true); + t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end()); + t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight); + } + + t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding); + + return {t5_tokens, t5_weights, t5_mask}; + } + + void modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) { + float* mask_data = (float*)mask->data; + int num_pad = 0; + for (int64_t i = 0; i < max_seq_length; i++) { + if (num_pad >= num_extra_padding) { + break; + } + if (std::isinf(mask_data[i])) { + mask_data[i] = 0; + ++num_pad; + } + } + // LOG_DEBUG("PAD: %d", num_pad); + } + + SDCondition get_learned_condition_common(ggml_context* work_ctx, + int n_threads, + std::tuple, std::vector, std::vector> token_and_weights, + int clip_skip, + bool force_zero_embeddings = false) { + auto& t5_tokens = std::get<0>(token_and_weights); + auto& t5_weights = std::get<1>(token_and_weights); + auto& t5_attn_mask_vec = std::get<2>(token_and_weights); + + int64_t t0 = ggml_time_ms(); + struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096] + struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096] + struct ggml_tensor* pooled = NULL; // [768,] + struct ggml_tensor* t5_attn_mask = vector_to_ggml_tensor(work_ctx, t5_attn_mask_vec); // [768,] + + std::vector hidden_states_vec; + + size_t chunk_count = t5_tokens.size() / chunk_len; + + for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) { + // t5 + std::vector chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len, + t5_tokens.begin() + (chunk_idx + 1) * chunk_len); + std::vector chunk_weights(t5_weights.begin() + chunk_idx * chunk_len, + t5_weights.begin() + (chunk_idx + 1) * chunk_len); + std::vector chunk_mask(t5_attn_mask_vec.begin() + chunk_idx * chunk_len, + t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len); + + auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); + auto t5_attn_mask_chunk = use_mask ? 
vector_to_ggml_tensor(work_ctx, chunk_mask) : NULL; + + t5->compute(n_threads, + input_ids, + t5_attn_mask_chunk, + &chunk_hidden_states, + work_ctx); + { + auto tensor = chunk_hidden_states; + float original_mean = ggml_tensor_mean(tensor); + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float value = ggml_tensor_get_f32(tensor, i0, i1, i2); + value *= chunk_weights[i1]; + ggml_tensor_set_f32(tensor, value, i0, i1, i2); + } + } + } + float new_mean = ggml_tensor_mean(tensor); + ggml_tensor_scale(tensor, (original_mean / new_mean)); + } + + int64_t t1 = ggml_time_ms(); + LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); + if (force_zero_embeddings) { + float* vec = (float*)chunk_hidden_states->data; + for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) { + vec[i] = 0; + } + } + + hidden_states_vec.insert(hidden_states_vec.end(), + (float*)chunk_hidden_states->data, + ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states)); + } + + if (hidden_states_vec.size() > 0) { + hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec); + hidden_states = ggml_reshape_2d(work_ctx, + hidden_states, + chunk_hidden_states->ne[0], + ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]); + } else { + hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256); + ggml_set_f32(hidden_states, 0.f); + } + + modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad); + + return SDCondition(hidden_states, t5_attn_mask, NULL); + } + + SDCondition get_learned_condition(ggml_context* work_ctx, + int n_threads, + const std::string& text, + int clip_skip, + int width, + int height, + int adm_in_channels = -1, + bool force_zero_embeddings = false) { + auto tokens_and_weights = tokenize(text, chunk_len, true); return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings); } @@ -1218,4 +1425,4 @@ struct FluxCLIPEmbedder : public Conditioner { } }; -#endif \ No newline at end of file +#endif diff --git a/denoiser.hpp b/denoiser.hpp index 66799109d..d4bcec590 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -168,24 +168,21 @@ struct AYSSchedule : SigmaSchedule { std::vector inputs; std::vector results(n + 1); - switch (version) { - case VERSION_SD2: /* fallthrough */ - LOG_WARN("AYS not designed for SD2.X models"); - case VERSION_SD1: - LOG_INFO("AYS using SD1.5 noise levels"); - inputs = noise_levels[0]; - break; - case VERSION_SDXL: - LOG_INFO("AYS using SDXL noise levels"); - inputs = noise_levels[1]; - break; - case VERSION_SVD: - LOG_INFO("AYS using SVD noise levels"); - inputs = noise_levels[2]; - break; - default: - LOG_ERROR("Version not compatable with AYS scheduler"); - return results; + if (sd_version_is_sd2((SDVersion)version)) { + LOG_WARN("AYS not designed for SD2.X models"); + } /* fallthrough */ + else if (sd_version_is_sd1((SDVersion)version)) { + LOG_INFO("AYS using SD1.5 noise levels"); + inputs = noise_levels[0]; + } else if (sd_version_is_sdxl((SDVersion)version)) { + LOG_INFO("AYS using SDXL noise levels"); + inputs = noise_levels[1]; + } else if (version == VERSION_SVD) { + LOG_INFO("AYS using SVD noise levels"); + inputs = noise_levels[2]; + } else { + LOG_ERROR("Version not compatible with AYS scheduler"); + return results; } /* Stretches those pre-calculated reference levels out to the desired @@ -346,6 +343,32 @@ struct CompVisVDenoiser : public 
CompVisDenoiser { } }; +struct EDMVDenoiser : public CompVisVDenoiser { + float min_sigma = 0.002; + float max_sigma = 120.0; + + EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0) + : min_sigma(min_sigma), max_sigma(max_sigma) { + schedule = std::make_shared(); + } + + float t_to_sigma(float t) { + return std::exp(t * 4 / (float)TIMESTEPS); + } + + float sigma_to_t(float s) { + return 0.25 * std::log(s); + } + + float sigma_min() { + return min_sigma; + } + + float sigma_max() { + return max_sigma; + } +}; + float time_snr_shift(float alpha, float t) { if (alpha == 1.0f) { return t; @@ -1019,7 +1042,7 @@ static void sample_k_diffusion(sample_method_t method, // also needed to invert the behavior of CompVisDenoiser // (k-diffusion's LMSDiscreteScheduler) float beta_start = 0.00085f; - float beta_end = 0.0120f; + float beta_end = 0.0120f; std::vector alphas_cumprod; std::vector compvis_sigmas; @@ -1030,8 +1053,9 @@ static void sample_k_diffusion(sample_method_t method, (i == 0 ? 1.0f : alphas_cumprod[i - 1]) * (1.0f - std::pow(sqrtf(beta_start) + - (sqrtf(beta_end) - sqrtf(beta_start)) * - ((float)i / (TIMESTEPS - 1)), 2)); + (sqrtf(beta_end) - sqrtf(beta_start)) * + ((float)i / (TIMESTEPS - 1)), + 2)); compvis_sigmas[i] = std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]); @@ -1061,7 +1085,8 @@ static void sample_k_diffusion(sample_method_t method, // - pred_prev_sample -> "x_t-1" int timestep = roundf(TIMESTEPS - - i * ((float)TIMESTEPS / steps)) - 1; + i * ((float)TIMESTEPS / steps)) - + 1; // 1. get previous step value (=t-1) int prev_timestep = timestep - TIMESTEPS / steps; // The sigma here is chosen to cause the @@ -1086,10 +1111,9 @@ static void sample_k_diffusion(sample_method_t method, float* vec_x = (float*)x->data; for (int j = 0; j < ggml_nelements(x); j++) { vec_x[j] *= std::sqrt(sigma * sigma + 1) / - sigma; + sigma; } - } - else { + } else { // For the subsequent steps after the first one, // at this point x = latents or x = sample, and // needs to be prescaled with x <- sample / c_in @@ -1127,9 +1151,8 @@ static void sample_k_diffusion(sample_method_t method, float alpha_prod_t = alphas_cumprod[timestep]; // Note final_alpha_cumprod = alphas_cumprod[0] due to // trailing timestep spacing - float alpha_prod_t_prev = prev_timestep >= 0 ? - alphas_cumprod[prev_timestep] : alphas_cumprod[0]; - float beta_prod_t = 1 - alpha_prod_t; + float alpha_prod_t_prev = prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]; + float beta_prod_t = 1 - alpha_prod_t; // 3. compute predicted original sample from predicted // noise also called "predicted x_0" of formula (12) // from https://arxiv.org/pdf/2010.02502.pdf @@ -1145,7 +1168,7 @@ static void sample_k_diffusion(sample_method_t method, vec_pred_original_sample[j] = (vec_x[j] / std::sqrt(sigma * sigma + 1) - std::sqrt(beta_prod_t) * - vec_model_output[j]) * + vec_model_output[j]) * (1 / std::sqrt(alpha_prod_t)); } } @@ -1159,8 +1182,8 @@ static void sample_k_diffusion(sample_method_t method, // sigma_t = sqrt((1 - alpha_t-1)/(1 - alpha_t)) * // sqrt(1 - alpha_t/alpha_t-1) float beta_prod_t_prev = 1 - alpha_prod_t_prev; - float variance = (beta_prod_t_prev / beta_prod_t) * - (1 - alpha_prod_t / alpha_prod_t_prev); + float variance = (beta_prod_t_prev / beta_prod_t) * + (1 - alpha_prod_t / alpha_prod_t_prev); float std_dev_t = eta * std::sqrt(variance); // 6. 
compute "direction pointing to x_t" of formula // (12) from https://arxiv.org/pdf/2010.02502.pdf @@ -1179,8 +1202,8 @@ static void sample_k_diffusion(sample_method_t method, std::pow(std_dev_t, 2)) * vec_model_output[j]; vec_x[j] = std::sqrt(alpha_prod_t_prev) * - vec_pred_original_sample[j] + - pred_sample_direction; + vec_pred_original_sample[j] + + pred_sample_direction; } } if (eta > 0) { @@ -1208,7 +1231,7 @@ static void sample_k_diffusion(sample_method_t method, // by Semi-Linear Consistency Function with Trajectory // Mapping", arXiv:2402.19159 [cs.CV] float beta_start = 0.00085f; - float beta_end = 0.0120f; + float beta_end = 0.0120f; std::vector alphas_cumprod; std::vector compvis_sigmas; @@ -1219,8 +1242,9 @@ static void sample_k_diffusion(sample_method_t method, (i == 0 ? 1.0f : alphas_cumprod[i - 1]) * (1.0f - std::pow(sqrtf(beta_start) + - (sqrtf(beta_end) - sqrtf(beta_start)) * - ((float)i / (TIMESTEPS - 1)), 2)); + (sqrtf(beta_end) - sqrtf(beta_start)) * + ((float)i / (TIMESTEPS - 1)), + 2)); compvis_sigmas[i] = std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]); @@ -1235,13 +1259,10 @@ static void sample_k_diffusion(sample_method_t method, for (int i = 0; i < steps; i++) { // Analytic form for TCD timesteps int timestep = TIMESTEPS - 1 - - (TIMESTEPS / original_steps) * - (int)floor(i * ((float)original_steps / steps)); + (TIMESTEPS / original_steps) * + (int)floor(i * ((float)original_steps / steps)); // 1. get previous step value - int prev_timestep = i >= steps - 1 ? 0 : - TIMESTEPS - 1 - (TIMESTEPS / original_steps) * - (int)floor((i + 1) * - ((float)original_steps / steps)); + int prev_timestep = i >= steps - 1 ? 0 : TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor((i + 1) * ((float)original_steps / steps)); // Here timestep_s is tau_n' in Algorithm 4. The _s // notation appears to be that from C. Lu, // "DPM-Solver: A Fast ODE Solver for Diffusion @@ -1258,10 +1279,9 @@ static void sample_k_diffusion(sample_method_t method, float* vec_x = (float*)x->data; for (int j = 0; j < ggml_nelements(x); j++) { vec_x[j] *= std::sqrt(sigma * sigma + 1) / - sigma; + sigma; } - } - else { + } else { float* vec_x = (float*)x->data; for (int j = 0; j < ggml_nelements(x); j++) { vec_x[j] *= std::sqrt(sigma * sigma + 1); @@ -1294,15 +1314,14 @@ static void sample_k_diffusion(sample_method_t method, // DPM-Solver. In fact, we have alpha_{t_n} = // \sqrt{\hat{alpha_n}}, [...]" float alpha_prod_t = alphas_cumprod[timestep]; - float beta_prod_t = 1 - alpha_prod_t; + float beta_prod_t = 1 - alpha_prod_t; // Note final_alpha_cumprod = alphas_cumprod[0] since // TCD is always "trailing" - float alpha_prod_t_prev = prev_timestep >= 0 ? - alphas_cumprod[prev_timestep] : alphas_cumprod[0]; + float alpha_prod_t_prev = prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]; // The subscript _s are the only portion in this // section (2) unique to TCD float alpha_prod_s = alphas_cumprod[timestep_s]; - float beta_prod_s = 1 - alpha_prod_s; + float beta_prod_s = 1 - alpha_prod_s; // 3. 
Compute the predicted noised sample x_s based on // the model parameterization // @@ -1317,7 +1336,7 @@ static void sample_k_diffusion(sample_method_t method, vec_pred_original_sample[j] = (vec_x[j] / std::sqrt(sigma * sigma + 1) - std::sqrt(beta_prod_t) * - vec_model_output[j]) * + vec_model_output[j]) * (1 / std::sqrt(alpha_prod_t)); } } @@ -1339,9 +1358,9 @@ static void sample_k_diffusion(sample_method_t method, // pred_epsilon = model_output vec_x[j] = std::sqrt(alpha_prod_s) * - vec_pred_original_sample[j] + + vec_pred_original_sample[j] + std::sqrt(beta_prod_s) * - vec_model_output[j]; + vec_model_output[j]; } } // 4. Sample and inject noise z ~ N(0, I) for @@ -1357,7 +1376,7 @@ static void sample_k_diffusion(sample_method_t method, // In this case, x is still pred_noised_sample, // continue in-place ggml_tensor_set_f32_randn(noise, rng); - float* vec_x = (float*)x->data; + float* vec_x = (float*)x->data; float* vec_noise = (float*)noise->data; for (int j = 0; j < ggml_nelements(x); j++) { // Corresponding to (35) in Zheng et @@ -1366,10 +1385,10 @@ static void sample_k_diffusion(sample_method_t method, vec_x[j] = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * - vec_x[j] + + vec_x[j] + std::sqrt(1 - alpha_prod_t_prev / - alpha_prod_s) * - vec_noise[j]; + alpha_prod_s) * + vec_noise[j]; } } } diff --git a/diffusion_model.hpp b/diffusion_model.hpp index ee4d88f0c..5c349439d 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -13,6 +13,7 @@ struct DiffusionModel { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, + std::vector ref_latents = {}, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, @@ -68,6 +69,7 @@ struct UNetModel : public DiffusionModel { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, + std::vector ref_latents = {}, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, @@ -118,6 +120,7 @@ struct MMDiTModel : public DiffusionModel { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, + std::vector ref_latents = {}, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, @@ -134,8 +137,9 @@ struct FluxModel : public DiffusionModel { FluxModel(ggml_backend_t backend, std::map& tensor_types, SDVersion version = VERSION_FLUX, - bool flash_attn = false) - : flux(backend, tensor_types, "model.diffusion_model", version, flash_attn) { + bool flash_attn = false, + bool use_mask = false) + : flux(backend, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) { } void alloc_params_buffer() { @@ -169,13 +173,14 @@ struct FluxModel : public DiffusionModel { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, + std::vector ref_latents = {}, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, struct ggml_tensor** output = NULL, struct ggml_context* output_ctx = NULL, std::vector skip_layers = std::vector()) { - return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers); + return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, ref_latents, output, output_ctx, skip_layers); } }; diff --git a/docs/chroma.md b/docs/chroma.md new file mode 100644 index 000000000..d013a43c8 --- /dev/null +++ b/docs/chroma.md @@ -0,0 +1,33 @@ +# How to Use + +You can run Chroma using stable-diffusion.cpp with a GPU that has 6GB or even 4GB of VRAM, without needing 
to offload to RAM.
+
+## Download weights
+
+- Download Chroma
+    - If you don't want to do the conversion yourself, download the preconverted gguf model from [silveroxides/Chroma-GGUF](https://huggingface.co/silveroxides/Chroma-GGUF)
+    - Otherwise, download chroma's safetensors from [lodestones/Chroma](https://huggingface.co/lodestones/Chroma)
+- Download vae from https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/ae.safetensors
+- Download t5xxl from https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp16.safetensors
+
+## Convert Chroma weights
+
+You can download the preconverted gguf weights from [silveroxides/Chroma-GGUF](https://huggingface.co/silveroxides/Chroma-GGUF), so you don't have to do the conversion yourself.
+
+```
+.\bin\Release\sd.exe -M convert -m ..\..\ComfyUI\models\unet\chroma-unlocked-v40.safetensors -o ..\models\chroma-unlocked-v40-q8_0.gguf -v --type q8_0
+```
+
+## Run
+
+### Example
+For example:
+
+```
+ .\bin\Release\sd.exe --diffusion-model ..\models\chroma-unlocked-v40-q8_0.gguf --vae ..\models\ae.sft --t5xxl ..\models\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'chroma.cpp'" --cfg-scale 4.0 --sampling-method euler -v --chroma-disable-dit-mask
+```
+
+![](../assets/flux/chroma_v40.png)
+
+
diff --git a/docs/kontext.md b/docs/kontext.md
new file mode 100644
index 000000000..519752553
--- /dev/null
+++ b/docs/kontext.md
@@ -0,0 +1,39 @@
+# How to Use
+
+You can run Kontext using stable-diffusion.cpp with a GPU that has 6GB or even 4GB of VRAM, without needing to offload to RAM.
+
+## Download weights
+
+- Download Kontext
+    - If you don't want to do the conversion yourself, download the preconverted gguf model from [FLUX.1-Kontext-dev-GGUF](https://huggingface.co/QuantStack/FLUX.1-Kontext-dev-GGUF)
+    - Otherwise, download FLUX.1-Kontext-dev from https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/blob/main/flux1-kontext-dev.safetensors
+- Download vae from https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/ae.safetensors
+- Download clip_l from https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/clip_l.safetensors
+- Download t5xxl from https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp16.safetensors
+
+## Convert Kontext weights
+
+You can download the preconverted gguf weights from [FLUX.1-Kontext-dev-GGUF](https://huggingface.co/QuantStack/FLUX.1-Kontext-dev-GGUF), so you don't have to do the conversion yourself.
+
+```
+.\bin\Release\sd.exe -M convert -m ..\..\ComfyUI\models\unet\flux1-kontext-dev.safetensors -o ..\models\flux1-kontext-dev-q8_0.gguf -v --type q8_0
+```
+
+## Run
+
+- `--cfg-scale` is recommended to be set to 1.
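The CLI conversion commands in these docs map onto the library's `convert()` entry point, which this patch extends with a tensor-type-rules argument (see the `examples/cli/main.cpp` hunk further down, where the CLI calls it with five arguments). A minimal sketch of calling it directly, assuming that post-patch signature and that `SD_TYPE_Q8_0` is the `sd_type_t` value behind `--type q8_0`; the paths are placeholders mirroring the docs' example:

```cpp
// convert_sketch.cpp — illustrative use of the convert() API as invoked by
// the CLI after this patch; link against the stable-diffusion.cpp library.
#include <cstdio>
#include "stable-diffusion.h"

int main() {
    // Placeholder paths, mirroring the CLI example above.
    const char* model_path  = "models/flux1-kontext-dev.safetensors";
    const char* vae_path    = "";  // no standalone VAE to merge in
    const char* output_path = "models/flux1-kontext-dev-q8_0.gguf";

    // Empty rules string: quantize every tensor with the same output type.
    // A rules string like "^vae\\.=f16,model\\.=q8_0" would override per pattern.
    bool ok = convert(model_path, vae_path, output_path, SD_TYPE_Q8_0, "");
    std::printf("convert %s\n", ok ? "succeeded" : "failed");
    return ok ? 0 : 1;
}
```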
+ +### Example +For example: + +``` + .\bin\Release\sd.exe -M edit -r .\flux1-dev-q8_0.png --diffusion-model ..\models\flux1-kontext-dev-q8_0.gguf --vae ..\models\ae.sft --clip_l ..\models\clip_l.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -p "change 'flux.cpp' to 'kontext.cpp'" --cfg-scale 1.0 --sampling-method euler -v +``` + + +| ref_image | prompt | output | +| ---- | ---- |---- | +| ![](../assets/flux/flux1-dev-q8_0.png) | change 'flux.cpp' to 'kontext.cpp' |![](../assets/flux/kontext1_dev_output.png) | + + + diff --git a/esrgan.hpp b/esrgan.hpp index 989d15fee..5cbb4ad8f 100644 --- a/esrgan.hpp +++ b/esrgan.hpp @@ -130,8 +130,8 @@ class RRDBNet : public GGMLBlock { body_feat = conv_body->forward(ctx, body_feat); feat = ggml_add(ctx, feat, body_feat); // upsample - feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2))); - feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2))); + feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST))); + feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST))); auto out = conv_last->forward(ctx, lrelu(ctx, conv_hr->forward(ctx, feat))); return out; } diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index af6b2bbdb..bb695c3bb 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -57,13 +57,16 @@ const char* modes_str[] = { "txt2img", "img2img", "img2vid", + "edit", "convert", }; +#define SD_ALL_MODES_STR "txt2img, img2img, edit, convert" enum SDMode { TXT2IMG, IMG2IMG, IMG2VID, + EDIT, CONVERT, MODE_COUNT }; @@ -84,11 +87,13 @@ struct SDParams { std::string stacked_id_embeddings_path; std::string input_id_images_path; sd_type_t wtype = SD_TYPE_COUNT; + std::string tensor_type_rules; std::string lora_model_dir; std::string output_path = "output.png"; std::string input_path; std::string mask_path; std::string control_image_path; + std::vector ref_image_paths; std::string prompt; std::string negative_prompt; @@ -129,6 +134,10 @@ struct SDParams { float slg_scale = 0.f; float skip_layer_start = 0.01f; float skip_layer_end = 0.2f; + + bool chroma_use_dit_mask = true; + bool chroma_use_t5_mask = false; + int chroma_t5_mask_pad = 1; }; void print_params(SDParams params) { @@ -154,6 +163,10 @@ void print_params(SDParams params) { printf(" init_img: %s\n", params.input_path.c_str()); printf(" mask_img: %s\n", params.mask_path.c_str()); printf(" control_image: %s\n", params.control_image_path.c_str()); + printf(" ref_images_paths:\n"); + for (auto& path : params.ref_image_paths) { + printf(" %s\n", path.c_str()); + }; printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false"); printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false"); printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false"); @@ -178,6 +191,9 @@ void print_params(SDParams params) { printf(" batch_count: %d\n", params.batch_count); printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false"); printf(" upscale_repeats: %d\n", params.upscale_repeats); + printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false"); + printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? 
"true" : "false"); + printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad); } void print_usage(int argc, const char* argv[]) { @@ -185,14 +201,18 @@ void print_usage(int argc, const char* argv[]) { printf("\n"); printf("arguments:\n"); printf(" -h, --help show this help message and exit\n"); - printf(" -M, --mode [MODEL] run mode (txt2img or img2img or convert, default: txt2img)\n"); + printf(" -M, --mode [MODE] run mode, one of:\n"); + printf(" txt2img: generate an image from a text prompt (default)\n"); + printf(" img2img: generate an image from a text prompt and an initial image (--init-img)\n"); + printf(" edit: modify an image (--ref-image) based on text instructions\n"); + printf(" convert: convert a model file to gguf format, optionally with quantization\n"); printf(" -t, --threads N number of threads to use during computation (default: -1)\n"); printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n"); printf(" -m, --model [MODEL] path to full model\n"); printf(" --diffusion-model path to the standalone diffusion model\n"); printf(" --clip_l path to the clip-l text encoder\n"); printf(" --clip_g path to the clip-g text encoder\n"); - printf(" --t5xxl path to the the t5xxl text encoder\n"); + printf(" --t5xxl path to the t5xxl text encoder\n"); printf(" --vae [VAE] path to vae\n"); printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n"); printf(" --control-net [CONTROL_PATH] path to control net model\n"); @@ -204,10 +224,12 @@ void print_usage(int argc, const char* argv[]) { printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n"); printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n"); printf(" If not specified, the default is the type of the weight file\n"); + printf(" --tensor-type-rules [EXPRESSION] weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")\n"); printf(" --lora-model-dir [DIR] lora model directory\n"); printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n"); printf(" --mask [MASK] path to the mask image, required by img2img with mask\n"); printf(" --control-image [IMAGE] path to image condition, control net\n"); + printf(" -r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times) \n"); printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n"); printf(" -p, --prompt [PROMPT] the prompt to render\n"); printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n"); @@ -243,7 +265,10 @@ void print_usage(int argc, const char* argv[]) { printf(" This might crash if it is not supported by the backend.\n"); printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n"); printf(" --canny apply canny preprocessor (edge detection)\n"); - printf(" --color Colors the logging tags according to level\n"); + printf(" --color colors the logging tags according to level\n"); + printf(" --chroma-disable-dit-mask disable dit mask for chroma\n"); + printf(" --chroma-enable-t5-mask enable t5 mask for chroma\n"); + printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n"); printf(" -v, --verbose print extra info\n"); } @@ -273,8 +298,8 @@ void parse_args(int argc, const char** argv, SDParams& params) { } if (mode_found == -1) { fprintf(stderr, - "error: invalid mode %s, must be one of [txt2img, img2img, img2vid, convert]\n", - mode_selected); + "error: invalid mode %s, must be one of 
[%s]\n", + mode_selected, SD_ALL_MODES_STR); exit(1); } params.mode = (SDMode)mode_found; @@ -381,6 +406,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { valid_types.c_str()); exit(1); } + } else if (arg == "--tensor-type-rules") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.tensor_type_rules = argv[i]; } else if (arg == "--lora-model-dir") { if (++i >= argc) { invalid_arg = true; @@ -629,6 +660,22 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.skip_layer_end = std::stof(argv[i]); + } else if (arg == "-r" || arg == "--ref-image") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.ref_image_paths.push_back(argv[i]); + } else if (arg == "--chroma-disable-dit-mask") { + params.chroma_use_dit_mask = false; + } else if (arg == "--chroma-enable-t5-mask") { + params.chroma_use_t5_mask = true; + } else if (arg == "--chroma-t5-mask-pad") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.chroma_t5_mask_pad = std::stoi(argv[i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); print_usage(argc, argv); @@ -657,7 +704,13 @@ void parse_args(int argc, const char** argv, SDParams& params) { } if ((params.mode == IMG2IMG || params.mode == IMG2VID) && params.input_path.length() == 0) { - fprintf(stderr, "error: when using the img2img mode, the following arguments are required: init-img\n"); + fprintf(stderr, "error: when using the img2img/img2vid mode, the following arguments are required: init-img\n"); + print_usage(argc, argv); + exit(1); + } + + if (params.mode == EDIT && params.ref_image_paths.size() == 0) { + fprintf(stderr, "error: when using the edit mode, the following arguments are required: ref-image\n"); print_usage(argc, argv); exit(1); } @@ -688,6 +741,10 @@ void parse_args(int argc, const char** argv, SDParams& params) { exit(1); } + if (params.mode != CONVERT && params.tensor_type_rules.size() > 0) { + fprintf(stderr, "warning: --tensor-type-rules is currently supported only for conversion\n"); + } + if (params.seed < 0) { srand((int)time(NULL)); params.seed = rand(); @@ -800,7 +857,7 @@ int main(int argc, const char* argv[]) { } if (params.mode == CONVERT) { - bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype); + bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype, params.tensor_type_rules.c_str()); if (!success) { fprintf(stderr, "convert '%s'/'%s' to '%s' failed\n", @@ -826,6 +883,7 @@ int main(int argc, const char* argv[]) { uint8_t* input_image_buffer = NULL; uint8_t* control_image_buffer = NULL; uint8_t* mask_image_buffer = NULL; + std::vector ref_images; if (params.mode == IMG2IMG || params.mode == IMG2VID) { vae_decode_only = false; @@ -877,6 +935,37 @@ int main(int argc, const char* argv[]) { free(input_image_buffer); input_image_buffer = resized_image_buffer; } + } else if (params.mode == EDIT) { + vae_decode_only = false; + for (auto& path : params.ref_image_paths) { + int c = 0; + int width = 0; + int height = 0; + uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3); + if (image_buffer == NULL) { + fprintf(stderr, "load image from '%s' failed\n", path.c_str()); + return 1; + } + if (c < 3) { + fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c); + free(image_buffer); + return 1; + } + if (width <= 0) { + fprintf(stderr, "error: the width of image must 
be greater than 0\n"); + free(image_buffer); + return 1; + } + if (height <= 0) { + fprintf(stderr, "error: the height of image must be greater than 0\n"); + free(image_buffer); + return 1; + } + ref_images.push_back({(uint32_t)width, + (uint32_t)height, + 3, + image_buffer}); + } } sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(), @@ -900,7 +989,10 @@ int main(int argc, const char* argv[]) { params.clip_on_cpu, params.control_net_cpu, params.vae_on_cpu, - params.diffusion_flash_attn); + params.diffusion_flash_attn, + params.chroma_use_dit_mask, + params.chroma_use_t5_mask, + params.chroma_t5_mask_pad); if (sd_ctx == NULL) { printf("new_sd_ctx_t failed\n"); @@ -968,7 +1060,7 @@ int main(int argc, const char* argv[]) { params.slg_scale, params.skip_layer_start, params.skip_layer_end); - } else { + } else if (params.mode == IMG2IMG || params.mode == IMG2VID) { sd_image_t input_image = {(uint32_t)params.width, (uint32_t)params.height, 3, @@ -1038,6 +1130,32 @@ int main(int argc, const char* argv[]) { params.skip_layer_start, params.skip_layer_end); } + } else { // EDIT + results = edit(sd_ctx, + ref_images.data(), + ref_images.size(), + params.prompt.c_str(), + params.negative_prompt.c_str(), + params.clip_skip, + params.cfg_scale, + params.guidance, + params.eta, + params.width, + params.height, + params.sample_method, + params.sample_steps, + params.strength, + params.seed, + params.batch_count, + control_image, + params.control_strength, + params.style_ratio, + params.normalize_input, + params.skip_layers.data(), + params.skip_layers.size(), + params.slg_scale, + params.skip_layer_start, + params.skip_layer_end); } if (results == NULL) { @@ -1075,11 +1193,11 @@ int main(int argc, const char* argv[]) { std::string dummy_name, ext, lc_ext; bool is_jpg; - size_t last = params.output_path.find_last_of("."); + size_t last = params.output_path.find_last_of("."); size_t last_path = std::min(params.output_path.find_last_of("/"), params.output_path.find_last_of("\\")); - if (last != std::string::npos // filename has extension - && (last_path == std::string::npos || last > last_path)) { + if (last != std::string::npos // filename has extension + && (last_path == std::string::npos || last > last_path)) { dummy_name = params.output_path.substr(0, last); ext = lc_ext = params.output_path.substr(last); std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower); @@ -1087,7 +1205,7 @@ int main(int argc, const char* argv[]) { } else { dummy_name = params.output_path; ext = lc_ext = ""; - is_jpg = false; + is_jpg = false; } // appending ".png" to absent or unknown extension if (!is_jpg && lc_ext != ".png") { @@ -1099,7 +1217,7 @@ int main(int argc, const char* argv[]) { continue; } std::string final_image_path = i > 0 ? 
dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext; - if(is_jpg) { + if (is_jpg) { stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, results[i].data, 90, get_image_params(params, params.seed + i).c_str()); printf("save result JPEG image to '%s'\n", final_image_path.c_str()); diff --git a/flux.hpp b/flux.hpp index 20ff41096..11045918f 100644 --- a/flux.hpp +++ b/flux.hpp @@ -117,6 +117,7 @@ namespace Flux { struct ggml_tensor* k, struct ggml_tensor* v, struct ggml_tensor* pe, + struct ggml_tensor* mask, bool flash_attn) { // q,k,v: [N, L, n_head, d_head] // pe: [L, d_head/2, 2, 2] @@ -124,7 +125,7 @@ namespace Flux { q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head] k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head] - auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true, flash_attn); // [N, L, n_head*d_head] + auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], mask, false, true, flash_attn); // [N, L, n_head*d_head] return x; } @@ -167,13 +168,13 @@ namespace Flux { return x; } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe) { + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe, struct ggml_tensor* mask) { // x: [N, n_token, dim] // pe: [n_token, d_head/2, 2, 2] // return [N, n_token, dim] - auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] - x = attention(ctx, qkv[0], qkv[1], qkv[2], pe, flash_attn); // [N, n_token, dim] - x = post_attention(ctx, x); // [N, n_token, dim] + auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] + x = attention(ctx, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim] + x = post_attention(ctx, x); // [N, n_token, dim] return x; } }; @@ -185,6 +186,13 @@ namespace Flux { ModulationOut(ggml_tensor* shift = NULL, ggml_tensor* scale = NULL, ggml_tensor* gate = NULL) : shift(shift), scale(scale), gate(gate) {} + + ModulationOut(struct ggml_context* ctx, ggml_tensor* vec, int64_t offset) { + int64_t stride = vec->nb[1] * vec->ne[1]; + shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim] + scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim] + gate = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 2)); // [N, dim] + } }; struct Modulation : public GGMLBlock { @@ -210,19 +218,12 @@ namespace Flux { auto m = ggml_reshape_3d(ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim] m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [multiplier, N, dim] - int64_t offset = m->nb[1] * m->ne[1]; - auto shift_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, dim] - auto scale_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, dim] - auto gate_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, dim] - + ModulationOut m_0 = ModulationOut(ctx, m, 0); if (is_double) { - auto shift_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, dim] - auto scale_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, dim] - auto gate_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, dim] - return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut(shift_1, scale_1, gate_1)}; + return {m_0, ModulationOut(ctx, m, 3)}; } - return {ModulationOut(shift_0, 
scale_0, gate_0), ModulationOut()}; + return {m_0, ModulationOut()}; } }; @@ -242,25 +243,33 @@ namespace Flux { struct DoubleStreamBlock : public GGMLBlock { bool flash_attn; + bool prune_mod; + int idx = 0; public: DoubleStreamBlock(int64_t hidden_size, int64_t num_heads, float mlp_ratio, + int idx = 0, bool qkv_bias = false, - bool flash_attn = false) - : flash_attn(flash_attn) { + bool flash_attn = false, + bool prune_mod = false) + : idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) { int64_t mlp_hidden_dim = hidden_size * mlp_ratio; - blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); - blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn)); + if (!prune_mod) { + blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + } + blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn)); blocks["img_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); blocks["img_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim)); // img_mlp.1 is nn.GELU(approximate="tanh") blocks["img_mlp.2"] = std::shared_ptr(new Linear(mlp_hidden_dim, hidden_size)); - blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + if (!prune_mod) { + blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + } blocks["txt_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn)); @@ -270,17 +279,34 @@ namespace Flux { blocks["txt_mlp.2"] = std::shared_ptr(new Linear(mlp_hidden_dim, hidden_size)); } + std::vector get_distil_img_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + // TODO: not hardcoded? + const int single_blocks_count = 38; + const int double_blocks_count = 19; + + int64_t offset = 6 * idx + 3 * single_blocks_count; + return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)}; + } + + std::vector get_distil_txt_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + // TODO: not hardcoded? 
+ const int single_blocks_count = 38; + const int double_blocks_count = 19; + + int64_t offset = 6 * idx + 6 * double_blocks_count + 3 * single_blocks_count; + return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)}; + } + std::pair forward(struct ggml_context* ctx, struct ggml_tensor* img, struct ggml_tensor* txt, struct ggml_tensor* vec, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + struct ggml_tensor* mask = NULL) { // img: [N, n_img_token, hidden_size] // txt: [N, n_txt_token, hidden_size] // pe: [n_img_token + n_txt_token, d_head/2, 2, 2] // return: ([N, n_img_token, hidden_size], [N, n_txt_token, hidden_size]) - - auto img_mod = std::dynamic_pointer_cast(blocks["img_mod"]); auto img_norm1 = std::dynamic_pointer_cast(blocks["img_norm1"]); auto img_attn = std::dynamic_pointer_cast(blocks["img_attn"]); @@ -288,7 +314,6 @@ namespace Flux { auto img_mlp_0 = std::dynamic_pointer_cast(blocks["img_mlp.0"]); auto img_mlp_2 = std::dynamic_pointer_cast(blocks["img_mlp.2"]); - auto txt_mod = std::dynamic_pointer_cast(blocks["txt_mod"]); auto txt_norm1 = std::dynamic_pointer_cast(blocks["txt_norm1"]); auto txt_attn = std::dynamic_pointer_cast(blocks["txt_attn"]); @@ -296,10 +321,22 @@ namespace Flux { auto txt_mlp_0 = std::dynamic_pointer_cast(blocks["txt_mlp.0"]); auto txt_mlp_2 = std::dynamic_pointer_cast(blocks["txt_mlp.2"]); - auto img_mods = img_mod->forward(ctx, vec); + std::vector img_mods; + if (prune_mod) { + img_mods = get_distil_img_mod(ctx, vec); + } else { + auto img_mod = std::dynamic_pointer_cast(blocks["img_mod"]); + img_mods = img_mod->forward(ctx, vec); + } ModulationOut img_mod1 = img_mods[0]; ModulationOut img_mod2 = img_mods[1]; - auto txt_mods = txt_mod->forward(ctx, vec); + std::vector txt_mods; + if (prune_mod) { + txt_mods = get_distil_txt_mod(ctx, vec); + } else { + auto txt_mod = std::dynamic_pointer_cast(blocks["txt_mod"]); + txt_mods = txt_mod->forward(ctx, vec); + } ModulationOut txt_mod1 = txt_mods[0]; ModulationOut txt_mod2 = txt_mods[1]; @@ -324,7 +361,7 @@ namespace Flux { auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head] + auto attn = attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head] attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] auto txt_attn_out = ggml_view_3d(ctx, attn, @@ -373,14 +410,18 @@ namespace Flux { int64_t hidden_size; int64_t mlp_hidden_dim; bool flash_attn; + bool prune_mod; + int idx = 0; public: SingleStreamBlock(int64_t hidden_size, int64_t num_heads, float mlp_ratio = 4.0f, + int idx = 0, float qk_scale = 0.f, - bool flash_attn = false) - : hidden_size(hidden_size), num_heads(num_heads), flash_attn(flash_attn) { + bool flash_attn = false, + bool prune_mod = false) + : hidden_size(hidden_size), num_heads(num_heads), idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) { int64_t head_dim = hidden_size / num_heads; float scale = qk_scale; if (scale <= 0.f) { @@ -393,26 +434,37 @@ namespace Flux { blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); blocks["pre_norm"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); // mlp_act is nn.GELU(approximate="tanh") - blocks["modulation"] = std::shared_ptr(new Modulation(hidden_size, false)); + if (!prune_mod) { + 
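// Distilled (Chroma) checkpoints ship no per-block modulation MLPs, so the
+ // Modulation block is only registered for regular Flux; with prune_mod set,
+ // shift/scale/gate are sliced out of the shared approximator output instead
+ // (see get_distil_mod below).
+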
blocks["modulation"] = std::shared_ptr(new Modulation(hidden_size, false)); + } + } + + ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + int64_t offset = 3 * idx; + return ModulationOut(ctx, vec, offset); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* vec, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + struct ggml_tensor* mask = NULL) { // x: [N, n_token, hidden_size] // pe: [n_token, d_head/2, 2, 2] // return: [N, n_token, hidden_size] - auto linear1 = std::dynamic_pointer_cast(blocks["linear1"]); - auto linear2 = std::dynamic_pointer_cast(blocks["linear2"]); - auto norm = std::dynamic_pointer_cast(blocks["norm"]); - auto pre_norm = std::dynamic_pointer_cast(blocks["pre_norm"]); - auto modulation = std::dynamic_pointer_cast(blocks["modulation"]); - - auto mods = modulation->forward(ctx, vec); - ModulationOut mod = mods[0]; - + auto linear1 = std::dynamic_pointer_cast(blocks["linear1"]); + auto linear2 = std::dynamic_pointer_cast(blocks["linear2"]); + auto norm = std::dynamic_pointer_cast(blocks["norm"]); + auto pre_norm = std::dynamic_pointer_cast(blocks["pre_norm"]); + ModulationOut mod; + if (prune_mod) { + mod = get_distil_mod(ctx, vec); + } else { + auto modulation = std::dynamic_pointer_cast(blocks["modulation"]); + + mod = modulation->forward(ctx, vec)[0]; + } auto x_mod = Flux::modulate(ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale); auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim] qkv_mlp = ggml_cont(ctx, ggml_permute(ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token] @@ -443,7 +495,7 @@ namespace Flux { auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] q = norm->query_norm(ctx, q); k = norm->key_norm(ctx, k); - auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_token, hidden_size] + auto attn = attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size] auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim] auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size] @@ -454,13 +506,28 @@ namespace Flux { }; struct LastLayer : public GGMLBlock { + bool prune_mod; + public: LastLayer(int64_t hidden_size, int64_t patch_size, - int64_t out_channels) { - blocks["norm_final"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); - blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels)); - blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, 2 * hidden_size)); + int64_t out_channels, + bool prune_mod = false) + : prune_mod(prune_mod) { + blocks["norm_final"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); + blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels)); + if (!prune_mod) { + blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, 2 * hidden_size)); + } + } + + ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + int64_t offset = vec->ne[2] - 2; + int64_t stride = vec->nb[1] * vec->ne[1]; + auto shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim] + auto scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim] + // No gate + return ModulationOut(shift, 
scale, NULL); } struct ggml_tensor* forward(struct ggml_context* ctx, @@ -469,17 +536,24 @@ namespace Flux { // x: [N, n_token, hidden_size] // c: [N, hidden_size] // return: [N, n_token, patch_size * patch_size * out_channels] - auto norm_final = std::dynamic_pointer_cast(blocks["norm_final"]); - auto linear = std::dynamic_pointer_cast(blocks["linear"]); - auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); - - auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size] - m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] - m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] - - int64_t offset = m->nb[1] * m->ne[1]; - auto shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] - auto scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + auto norm_final = std::dynamic_pointer_cast(blocks["norm_final"]); + auto linear = std::dynamic_pointer_cast(blocks["linear"]); + struct ggml_tensor *shift, *scale; + if (prune_mod) { + auto mod = get_distil_mod(ctx, c); + shift = mod.shift; + scale = mod.scale; + } else { + auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); + + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size] + m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] + m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] + + int64_t offset = m->nb[1] * m->ne[1]; + shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] + scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + } x = Flux::modulate(ctx, norm_final->forward(ctx, x), shift, scale); x = linear->forward(ctx, x); @@ -488,6 +562,34 @@ namespace Flux { } }; + struct ChromaApproximator : public GGMLBlock { + int64_t inner_size = 5120; + int64_t n_layers = 5; + ChromaApproximator(int64_t in_channels = 64, int64_t hidden_size = 3072) { + blocks["in_proj"] = std::shared_ptr(new Linear(in_channels, inner_size, true)); + for (int i = 0; i < n_layers; i++) { + blocks["norms." + std::to_string(i)] = std::shared_ptr(new RMSNorm(inner_size)); + blocks["layers." + std::to_string(i)] = std::shared_ptr(new MLPEmbedder(inner_size, inner_size)); + } + blocks["out_proj"] = std::shared_ptr(new Linear(inner_size, hidden_size, true)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + auto in_proj = std::dynamic_pointer_cast(blocks["in_proj"]); + auto out_proj = std::dynamic_pointer_cast(blocks["out_proj"]); + + x = in_proj->forward(ctx, x); + for (int i = 0; i < n_layers; i++) { + auto norm = std::dynamic_pointer_cast(blocks["norms." + std::to_string(i)]); + auto embed = std::dynamic_pointer_cast(blocks["layers." 
+ std::to_string(i)]); + x = ggml_add_inplace(ctx, x, embed->forward(ctx, norm->forward(ctx, x))); + } + x = out_proj->forward(ctx, x); + + return x; + } + }; + struct FluxParams { int64_t in_channels = 64; int64_t out_channels = 64; @@ -504,6 +606,7 @@ namespace Flux { bool qkv_bias = true; bool guidance_embed = true; bool flash_attn = true; + bool is_chroma = false; }; struct Flux : public GGMLBlock { @@ -570,17 +673,22 @@ namespace Flux { } // Generate IDs for image patches and text - std::vector> gen_ids(int h, int w, int patch_size, int bs, int context_len) { + std::vector> gen_txt_ids(int bs, int context_len) { + return std::vector>(bs * context_len, std::vector(3, 0.0)); + } + + std::vector> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) { int h_len = (h + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size; std::vector> img_ids(h_len * w_len, std::vector(3, 0.0)); - std::vector row_ids = linspace(0, h_len - 1, h_len); - std::vector col_ids = linspace(0, w_len - 1, w_len); + std::vector row_ids = linspace(h_offset, h_len - 1 + h_offset, h_len); + std::vector col_ids = linspace(w_offset, w_len - 1 + w_offset, w_len); for (int i = 0; i < h_len; ++i) { for (int j = 0; j < w_len; ++j) { + img_ids[i * w_len + j][0] = index; img_ids[i * w_len + j][1] = row_ids[i]; img_ids[i * w_len + j][2] = col_ids[j]; } @@ -592,24 +700,54 @@ namespace Flux { img_ids_repeated[i * img_ids.size() + j] = img_ids[j]; } } + return img_ids_repeated; + } - std::vector> txt_ids(bs * context_len, std::vector(3, 0.0)); - std::vector> ids(bs * (context_len + img_ids.size()), std::vector(3)); + std::vector> concat_ids(const std::vector>& a, + const std::vector>& b, + int bs) { + size_t a_len = a.size() / bs; + size_t b_len = b.size() / bs; + std::vector> ids(a.size() + b.size(), std::vector(3)); for (int i = 0; i < bs; ++i) { - for (int j = 0; j < context_len; ++j) { - ids[i * (context_len + img_ids.size()) + j] = txt_ids[j]; + for (int j = 0; j < a_len; ++j) { + ids[i * (a_len + b_len) + j] = a[i * a_len + j]; } - for (int j = 0; j < img_ids.size(); ++j) { - ids[i * (context_len + img_ids.size()) + context_len + j] = img_ids_repeated[i * img_ids.size() + j]; + for (int j = 0; j < b_len; ++j) { + ids[i * (a_len + b_len) + a_len + j] = b[i * b_len + j]; } } + return ids; + } + std::vector> gen_ids(int h, int w, int patch_size, int bs, int context_len, std::vector ref_latents) { + auto txt_ids = gen_txt_ids(bs, context_len); + auto img_ids = gen_img_ids(h, w, patch_size, bs); + + auto ids = concat_ids(txt_ids, img_ids, bs); + uint64_t curr_h_offset = 0; + uint64_t curr_w_offset = 0; + for (ggml_tensor* ref : ref_latents) { + uint64_t h_offset = 0; + uint64_t w_offset = 0; + if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) { + w_offset = curr_w_offset; + } else { + h_offset = curr_h_offset; + } + + auto ref_ids = gen_img_ids(ref->ne[1], ref->ne[0], patch_size, bs, 1, h_offset, w_offset); + ids = concat_ids(ids, ref_ids, bs); + + curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset); + curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset); + } return ids; } // Generate positional embeddings - std::vector gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector& axes_dim) { - std::vector> ids = gen_ids(h, w, patch_size, bs, context_len); + std::vector gen_pe(int h, int w, int patch_size, int bs, int context_len, std::vector ref_latents, int theta, const std::vector& axes_dim) { 
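+ // Every token gets a 3-axis id (index, row, col) that is embedded into RoPE
+ // frequencies along axes_dim. Text ids stay all-zero, the main image uses
+ // index 0 with its own row/col grid, and each reference latent (Kontext)
+ // uses index 1, shifted in row or col by gen_ids so stacked references
+ // never overlap each other's coordinates.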
+ std::vector<std::vector<float>> ids = gen_ids(h, w, patch_size, bs, context_len, ref_latents); std::vector<std::vector<float>> trans_ids = transpose(ids); size_t pos_len = ids.size(); int num_axes = axes_dim.size(); @@ -645,11 +783,15 @@ : params(params) { int64_t pe_dim = params.hidden_size / params.num_heads; - blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true)); - blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size)); - blocks["vector_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(params.vec_in_dim, params.hidden_size)); - if (params.guidance_embed) { - blocks["guidance_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size)); + blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true)); + if (params.is_chroma) { + blocks["distilled_guidance_layer"] = std::shared_ptr<GGMLBlock>(new ChromaApproximator(params.in_channels, params.hidden_size)); + } else { + blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size)); + blocks["vector_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(params.vec_in_dim, params.hidden_size)); + if (params.guidance_embed) { + blocks["guidance_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size)); + } } blocks["txt_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.context_in_dim, params.hidden_size, true)); @@ -657,19 +799,23 @@ blocks["double_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new DoubleStreamBlock(params.hidden_size, params.num_heads, params.mlp_ratio, + i, params.qkv_bias, - params.flash_attn)); + params.flash_attn, + params.is_chroma)); } for (int i = 0; i < params.depth_single_blocks; i++) { blocks["single_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new SingleStreamBlock(params.hidden_size, params.num_heads, params.mlp_ratio, + i, 0.f, - params.flash_attn)); + params.flash_attn, + params.is_chroma)); } - blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels)); + blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels, params.is_chroma)); } struct ggml_tensor* patchify(struct ggml_context* ctx, @@ -726,25 +872,55 @@ struct ggml_tensor* y, struct ggml_tensor* guidance, struct ggml_tensor* pe, - std::vector<int> skip_layers = std::vector<int>()) { + struct ggml_tensor* mod_index_arange = NULL, + std::vector<int> skip_layers = {}) { auto img_in = std::dynamic_pointer_cast<Linear>(blocks["img_in"]); - auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]); - auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]); auto txt_in = std::dynamic_pointer_cast<Linear>(blocks["txt_in"]); auto final_layer = std::dynamic_pointer_cast<LastLayer>(blocks["final_layer"]); - img = img_in->forward(ctx, img); - auto vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f)); + img = img_in->forward(ctx, img); + struct ggml_tensor* vec; + struct ggml_tensor* txt_img_mask = NULL; + if (params.is_chroma) { + int64_t mod_index_length = 344; + auto approx = std::dynamic_pointer_cast<ChromaApproximator>(blocks["distilled_guidance_layer"]); + auto distill_timestep = ggml_nn_timestep_embedding(ctx, timesteps, 16, 10000, 1000.f); + auto distill_guidance = ggml_nn_timestep_embedding(ctx, guidance, 16, 10000, 1000.f); + + // auto mod_index_arange = ggml_arange(ctx, 0, (float)mod_index_length, 1); + // ggml_arange not working on a lot of backends, precomputing it on CPU instead + GGML_ASSERT(mod_index_arange != NULL); + auto modulation_index = ggml_nn_timestep_embedding(ctx, mod_index_arange, 32, 10000, 1000.f); // [1, 344, 32] + + // Batch broadcast (will it ever be useful?) + modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2])); // [N, 344, 32] + + auto timestep_guidance = ggml_concat(ctx, distill_timestep, distill_guidance, 0); // [N, 1, 32] + timestep_guidance = ggml_repeat(ctx, timestep_guidance, modulation_index); // [N, 344, 32] + + vec = ggml_concat(ctx, timestep_guidance, modulation_index, 0); // [N, 344, 64] + // Permute for consistency with non-distilled modulation implementation + vec = ggml_cont(ctx, ggml_permute(ctx, vec, 0, 2, 1, 3)); // [344, N, 64] + vec = approx->forward(ctx, vec); // [344, N, hidden_size] + + if (y != NULL) { + txt_img_mask = ggml_pad(ctx, y, img->ne[1], 0, 0, 0); + } + } else { + auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]); + auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]); + vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f)); + if (params.guidance_embed) { + GGML_ASSERT(guidance != NULL); + auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]); + // bf16 and fp16 result is different + auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f); + vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in)); + } - if (params.guidance_embed) { - GGML_ASSERT(guidance != NULL); - auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]); - // bf16 and fp16 result is different - auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f); - vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in)); + vec = ggml_add(ctx, vec, vector_in->forward(ctx, y)); } - vec = ggml_add(ctx, vec, vector_in->forward(ctx, y)); txt = txt_in->forward(ctx, txt); for (int i = 0; i < params.depth; i++) { @@ -754,7 +930,7 @@ auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks["double_blocks." + std::to_string(i)]); - auto img_txt = block->forward(ctx, img, txt, vec, pe); + auto img_txt = block->forward(ctx, img, txt, vec, pe, txt_img_mask); img = img_txt.first; // [N, n_img_token, hidden_size] txt = img_txt.second; // [N, n_txt_token, hidden_size] }
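// txt_img_mask threads the T5 padding mask through both block types here.
// For Chroma, `y` carries one additive value per text token (0 = attend,
// large negative = padded), and the ggml_pad above appended zeros for all
// image positions, so only padded text tokens are suppressed. Sketch of the
// assumed non-flash path in ggml_nn_attention_ext (see ggml_extend.hpp):
//   kq = ggml_scale_inplace(ctx, kq, scale);
//   kq = ggml_add_inplace(ctx, kq, mask); // additive mask knocks out padding
//   kq = ggml_soft_max_inplace(ctx, kq);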
@@ -766,7 +942,7 @@ } auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]); - txt_img = block->forward(ctx, txt_img, vec, pe); + txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask); } txt_img = ggml_cont(ctx, ggml_permute(ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] @@ -781,7 +957,20 @@ img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels) + return img; + } + struct ggml_tensor* process_img(struct ggml_context* ctx, + struct ggml_tensor* x) { + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + int64_t patch_size = 2; + int pad_h = (patch_size - H % patch_size) % patch_size; + int pad_w = (patch_size - W % patch_size) % patch_size; + x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] + + // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size] return img; } @@ -793,7 +982,9 @@ struct ggml_tensor* y, struct ggml_tensor* guidance, struct ggml_tensor* pe, - std::vector<int> skip_layers = std::vector<int>()) { + struct ggml_tensor* mod_index_arange = NULL, + std::vector<ggml_tensor*> ref_latents = {}, + std::vector<int> skip_layers = {}) { // Forward pass of DiT. // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) // timestep: (N,) tensor of diffusion timesteps @@ -812,25 +1003,33 @@ int64_t patch_size = 2; int pad_h = (patch_size - H % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size; - x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] - // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) - auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size] + auto img = process_img(ctx, x); + uint64_t img_tokens = img->ne[1]; if (c_concat != NULL) { ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0); ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); - masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0); - mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0); - masked = patchify(ctx, masked, patch_size); - mask = patchify(ctx, mask, patch_size); + masked = process_img(ctx, masked); + mask = process_img(ctx, mask); img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0); } - auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size] + if (ref_latents.size() > 0) { + for (ggml_tensor* ref : ref_latents) { + ref = process_img(ctx, ref); + img = ggml_concat(ctx, img, ref, 1); + } + } + + auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size] + if (out->ne[1] > img_tokens) { + out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size] + out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0); + out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size] + } // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) out = unpatchify(ctx, out, (H + pad_h) / patch_size,
(W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w] @@ -845,14 +1044,18 @@ namespace Flux { public: FluxParams flux_params; Flux flux; - std::vector pe_vec; // for cache + std::vector pe_vec; + std::vector mod_index_arange_vec; // for cache + SDVersion version; + bool use_mask = false; FluxRunner(ggml_backend_t backend, std::map& tensor_types = empty_tensor_types, const std::string prefix = "", SDVersion version = VERSION_FLUX, - bool flash_attn = false) - : GGMLRunner(backend) { + bool flash_attn = false, + bool use_mask = false) + : GGMLRunner(backend), use_mask(use_mask) { flux_params.flash_attn = flash_attn; flux_params.guidance_embed = false; flux_params.depth = 0; @@ -868,6 +1071,10 @@ namespace Flux { // not schnell flux_params.guidance_embed = true; } + if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { + // Chroma + flux_params.is_chroma = true; + } size_t db = tensor_name.find("double_blocks."); if (db != std::string::npos) { tensor_name = tensor_name.substr(db); // remove prefix @@ -887,7 +1094,9 @@ namespace Flux { } LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks); - if (!flux_params.guidance_embed) { + if (flux_params.is_chroma) { + LOG_INFO("Using pruned modulation (Chroma)"); + } else if (!flux_params.guidance_embed) { LOG_INFO("Flux guidance is disabled (Schnell mode)"); } @@ -909,22 +1118,41 @@ namespace Flux { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, - std::vector skip_layers = std::vector()) { + std::vector ref_latents = {}, + std::vector skip_layers = {}) { GGML_ASSERT(x->ne[3] == 1); struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false); + struct ggml_tensor* mod_index_arange = NULL; + x = to_backend(x); context = to_backend(context); if (c_concat != NULL) { c_concat = to_backend(c_concat); } - y = to_backend(y); + if (flux_params.is_chroma) { + guidance = ggml_set_f32(guidance, 0); + + if (!use_mask) { + y = NULL; + } + + // ggml_arange is not working on some backends, precompute it + mod_index_arange_vec = arange(0, 344); + mod_index_arange = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, mod_index_arange_vec.size()); + set_backend_tensor_data(mod_index_arange, mod_index_arange_vec.data()); + } + y = to_backend(y); + timesteps = to_backend(timesteps); - if (flux_params.guidance_embed) { + if (flux_params.guidance_embed || flux_params.is_chroma) { guidance = to_backend(guidance); } + for (int i = 0; i < ref_latents.size(); i++) { + ref_latents[i] = to_backend(ref_latents[i]); + } - pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], flux_params.theta, flux_params.axes_dim); + pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], ref_latents, flux_params.theta, flux_params.axes_dim); int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2; // LOG_DEBUG("pos_len %d", pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); @@ -941,6 +1169,8 @@ namespace Flux { y, guidance, pe, + mod_index_arange, + ref_latents, skip_layers); ggml_build_forward_expand(gf, out); @@ -955,16 +1185,17 @@ namespace Flux { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL, - std::vector skip_layers = std::vector()) { + std::vector ref_latents = {}, + struct ggml_tensor** output = NULL, + struct ggml_context* 
output_ctx = NULL, + std::vector skip_layers = std::vector()) { // x: [N, in_channels, h, w] // timesteps: [N, ] // context: [N, max_position, hidden_size] // y: [N, adm_in_channels] or [1, adm_in_channels] // guidance: [N, ] auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers); + return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, skip_layers); }; GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); @@ -1004,7 +1235,7 @@ namespace Flux { struct ggml_tensor* out = NULL; int t0 = ggml_time_ms(); - compute(8, x, timesteps, context, NULL, y, guidance, &out, work_ctx); + compute(8, x, timesteps, context, NULL, y, guidance, {}, &out, work_ctx); int t1 = ggml_time_ms(); print_ggml_tensor(out); diff --git a/ggml b/ggml index ff9052988..9e4bee1c5 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit ff9052988b76e137bcf92bb335733933ca196ac0 +Subproject commit 9e4bee1c5afc2d677a5b32ecb90cbdb483e81fff diff --git a/ggml_extend.hpp b/ggml_extend.hpp index c5913be4d..9f6a4fef6 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -39,6 +39,10 @@ #include "ggml-vulkan.h" #endif +#ifdef SD_USE_OPENCL +#include "ggml-opencl.h" +#endif + #ifdef SD_USE_SYCL #include "ggml-sycl.h" #endif @@ -113,7 +117,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_kronecker(ggml_context* ctx, struct g a->ne[0] * b->ne[0], a->ne[1] * b->ne[1], a->ne[2] * b->ne[2], - a->ne[3] * b->ne[3]), + a->ne[3] * b->ne[3], + GGML_SCALE_MODE_NEAREST), b); } @@ -597,6 +602,8 @@ typedef std::function on_tile_process; // Tiling __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) { + output = ggml_set_f32(output, 0); + int input_width = (int)input->ne[0]; int input_height = (int)input->ne[1]; int output_width = (int)output->ne[0]; @@ -864,6 +871,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] v = ggml_cast(ctx, v, GGML_TYPE_F16); + if (mask != nullptr) { + mask = ggml_transpose(ctx, mask); + + if (mask->ne[1] < GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)) { + LOG_DEBUG("mask dims %ld, %ld, %ld, %ld\n", mask->ne[0], mask->ne[1], mask->ne[2], mask->ne[3]); + LOG_DEBUG("needs padding, padding from %ld to %ld\n", mask->ne[1], GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)); + mask = ggml_pad(ctx, mask, 0, GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) - mask->ne[1], 0, 0); + } + + mask = ggml_cast(ctx, mask, GGML_TYPE_F16); + } + kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0); ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32); @@ -876,7 +895,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k] kq = ggml_scale_inplace(ctx, kq, scale); if (mask) { - kq = ggml_add(ctx, kq, mask); + kq = ggml_add_inplace(ctx, kq, mask); } if (diag_mask_inf) { kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); diff --git a/lora.hpp b/lora.hpp index d38c7116f..35f5aacd1 100644 --- a/lora.hpp +++ b/lora.hpp @@ -3,7 +3,7 @@ #include "ggml_extend.hpp" -#define LORA_GRAPH_SIZE 10240 +#define LORA_GRAPH_BASE_SIZE 10240 struct LoraModel : public GGMLRunner { enum lora_t { @@ -238,7 +238,8 @@ struct LoraModel : public GGMLRunner { } struct ggml_cgraph* build_lora_graph(std::map model_tensors, SDVersion version) { - struct ggml_cgraph* gf = 
ggml_new_graph_custom(compute_ctx, LORA_GRAPH_SIZE, false); + size_t lora_graph_size = LORA_GRAPH_BASE_SIZE + lora_tensors.size() * 10; + struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, lora_graph_size, false); zero_index = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, 1); set_backend_tensor_data(zero_index, zero_index_vec.data()); @@ -291,7 +292,6 @@ struct LoraModel : public GGMLRunner { std::string hada_2_down_name = ""; std::string hada_2_up_name = ""; - hada_1_down_name = fk + ".hada_w1_b"; hada_1_up_name = fk + ".hada_w1_a"; hada_1_mid_name = fk + ".hada_t1"; @@ -414,7 +414,7 @@ struct LoraModel : public GGMLRunner { } lokr_w2 = ggml_merge_lora(compute_ctx, down, up); } - + // Technically it might be unused, but I believe it's the expected behavior applied_lora_tensors.insert(alpha_name); diff --git a/model.cpp b/model.cpp index 24da39f6d..559c876c6 100644 --- a/model.cpp +++ b/model.cpp @@ -26,6 +26,10 @@ #include "ggml-vulkan.h" #endif +#ifdef SD_USE_OPENCL +#include "ggml-opencl.h" +#endif + #define ST_HEADER_SIZE_LEN 8 uint64_t read_u64(uint8_t* buffer) { @@ -96,6 +100,7 @@ const char* unused_tensors[] = { "model_ema.diffusion_model", "embedding_manager", "denoiser.sigmas", + "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training }; bool is_unused_tensor(std::string name) { @@ -177,6 +182,64 @@ std::unordered_map pmid_v2_name_map = { std::string convert_open_clip_to_hf_clip(const std::string& name) { std::string new_name = name; std::string prefix; + if (contains(new_name, ".enc.")) { + // llama.cpp naming convention for T5 + size_t pos = new_name.find(".enc."); + if (pos != std::string::npos) { + new_name.replace(pos, 5, ".encoder."); + } + pos = new_name.find("blk."); + if (pos != std::string::npos) { + new_name.replace(pos, 4, "block."); + } + pos = new_name.find("output_norm."); + if (pos != std::string::npos) { + new_name.replace(pos, 12, "final_layer_norm."); + } + pos = new_name.find("attn_k."); + if (pos != std::string::npos) { + new_name.replace(pos, 7, "layer.0.SelfAttention.k."); + } + pos = new_name.find("attn_v."); + if (pos != std::string::npos) { + new_name.replace(pos, 7, "layer.0.SelfAttention.v."); + } + pos = new_name.find("attn_o."); + if (pos != std::string::npos) { + new_name.replace(pos, 7, "layer.0.SelfAttention.o."); + } + pos = new_name.find("attn_q."); + if (pos != std::string::npos) { + new_name.replace(pos, 7, "layer.0.SelfAttention.q."); + } + pos = new_name.find("attn_norm."); + if (pos != std::string::npos) { + new_name.replace(pos, 10, "layer.0.layer_norm."); + } + pos = new_name.find("ffn_norm."); + if (pos != std::string::npos) { + new_name.replace(pos, 9, "layer.1.layer_norm."); + } + pos = new_name.find("ffn_up."); + if (pos != std::string::npos) { + new_name.replace(pos, 7, "layer.1.DenseReluDense.wi_1."); + } + pos = new_name.find("ffn_down."); + if (pos != std::string::npos) { + new_name.replace(pos, 9, "layer.1.DenseReluDense.wo."); + } + pos = new_name.find("ffn_gate."); + if (pos != std::string::npos) { + new_name.replace(pos, 9, "layer.1.DenseReluDense.wi_0."); + } + pos = new_name.find("attn_rel_b."); + if (pos != std::string::npos) { + new_name.replace(pos, 11, "layer.0.SelfAttention.relative_attention_bias."); + } + } else if (name == "text_encoders.t5xxl.transformer.token_embd.weight") { + new_name = "text_encoders.t5xxl.transformer.shared.weight"; + } + if (starts_with(new_name, "conditioner.embedders.0.open_clip.")) { prefix = "cond_stage_model."; new_name = 
new_name.substr(strlen("conditioner.embedders.0.open_clip.")); @@ -275,6 +338,10 @@ std::unordered_map> su {"to_v", "v"}, {"to_out_0", "proj_out"}, {"group_norm", "norm"}, + {"key", "k"}, + {"query", "q"}, + {"value", "v"}, + {"proj_attn", "proj_out"}, }, }, { @@ -299,6 +366,10 @@ std::unordered_map> su {"to_v", "v"}, {"to_out.0", "proj_out"}, {"group_norm", "norm"}, + {"key", "k"}, + {"query", "q"}, + {"value", "v"}, + {"proj_attn", "proj_out"}, }, }, { @@ -370,6 +441,10 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) { return format("model%cdiffusion_model%ctime_embed%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1]; } + if (match(m, std::regex(format("unet%cadd_embedding%clinear_(\\d+)(.*)", seq, seq)), key)) { + return format("model%cdiffusion_model%clabel_emb%c0%c", seq, seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1]; + } + if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) { std::string suffix = get_converted_suffix(m[1], m[3]); // LOG_DEBUG("%s %s %s %s", m[0].c_str(), m[1].c_str(), m[2].c_str(), m[3].c_str()); @@ -407,6 +482,19 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) { return format("cond_stage_model%ctransformer%ctext_model", seq, seq) + m[0]; } + // clip-g + if (match(m, std::regex(format("te%c1%ctext_model%cencoder%clayers%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) { + return format("cond_stage_model%c1%ctransformer%ctext_model%cencoder%clayers%c", seq, seq, seq, seq, seq, seq) + m[0] + seq + m[1]; + } + + if (match(m, std::regex(format("te%c1%ctext_model(.*)", seq, seq)), key)) { + return format("cond_stage_model%c1%ctransformer%ctext_model", seq, seq, seq) + m[0]; + } + + if (match(m, std::regex(format("te%c1%ctext_projection", seq, seq)), key)) { + return format("cond_stage_model%c1%ctransformer%ctext_model%ctext_projection", seq, seq, seq, seq); + } + // vae if (match(m, std::regex(format("vae%c(.*)%cconv_norm_out(.*)", seq, seq)), key)) { return format("first_stage_model%c%s%cnorm_out%s", seq, m[0].c_str(), seq, m[1].c_str()); @@ -543,6 +631,8 @@ std::string convert_tensor_name(std::string name) { std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.'); if (new_key.empty()) { new_name = name; + } else if (new_key == "cond_stage_model.1.transformer.text_model.text_projection") { + new_name = new_key; } else { new_name = new_key + "." 
+ network_part; } @@ -966,10 +1056,14 @@ ggml_type str_to_ggml_type(const std::string& dtype) { ttype = GGML_TYPE_F32; } else if (dtype == "F32") { ttype = GGML_TYPE_F32; + } else if (dtype == "F64") { + ttype = GGML_TYPE_F64; } else if (dtype == "F8_E4M3") { ttype = GGML_TYPE_F16; } else if (dtype == "F8_E5M2") { ttype = GGML_TYPE_F16; + } else if (dtype == "I64") { + ttype = GGML_TYPE_I64; } return ttype; } @@ -982,6 +1076,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::ifstream file(file_path, std::ios::binary); if (!file.is_open()) { LOG_ERROR("failed to open '%s'", file_path.c_str()); + file_paths_.pop_back(); return false; } @@ -993,6 +1088,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const // read header size if (file_size_ <= ST_HEADER_SIZE_LEN) { LOG_ERROR("invalid safetensor file '%s'", file_path.c_str()); + file_paths_.pop_back(); return false; } @@ -1006,6 +1102,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const size_t header_size_ = read_u64(header_size_buf); if (header_size_ >= file_size_) { LOG_ERROR("invalid safetensor file '%s'", file_path.c_str()); + file_paths_.pop_back(); return false; } @@ -1016,6 +1113,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const file.read(header_buf.data(), header_size_); if (!file) { LOG_ERROR("read safetensors header failed: '%s'", file_path.c_str()); + file_paths_.pop_back(); return false; } @@ -1103,18 +1201,45 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const /*================================================= DiffusersModelLoader ==================================================*/ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) { - std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors"); - std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors"); - std::string clip_path = path_join(file_path, "text_encoder/model.safetensors"); + std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors"); + std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors"); + std::string clip_path = path_join(file_path, "text_encoder/model.safetensors"); + std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors"); if (!init_from_safetensors_file(unet_path, "unet.")) { return false; } + for (auto ts : tensor_storages) { + if (ts.name.find("add_embedding") != std::string::npos || ts.name.find("label_emb") != std::string::npos) { + // probably SDXL + LOG_DEBUG("Fixing name for SDXL output blocks.2.2"); + for (auto& tensor_storage : tensor_storages) { + int len = 34; + auto pos = tensor_storage.name.find("unet.up_blocks.0.upsamplers.0.conv"); + if (pos == std::string::npos) { + len = 44; + pos = tensor_storage.name.find("model.diffusion_model.output_blocks.2.1.conv"); + } + if (pos != std::string::npos) { + tensor_storage.name = "model.diffusion_model.output_blocks.2.2.conv" + tensor_storage.name.substr(len); + LOG_DEBUG("NEW NAME: %s", tensor_storage.name.c_str()); + add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type); + } + } + break; + } + } + if (!init_from_safetensors_file(vae_path, "vae.")) { - return false; + LOG_WARN("Couldn't find working VAE in %s", file_path.c_str()); + // return false; } if (!init_from_safetensors_file(clip_path, 
"te.")) { - return false; + LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str()); + // return false; + } + if (!init_from_safetensors_file(clip_g_path, "te.1.")) { + LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str()); } return true; } @@ -1477,6 +1602,15 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s return true; } +bool ModelLoader::model_is_unet() { + for (auto& tensor_storage : tensor_storages) { + if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) { + return true; + } + } + return false; +} + SDVersion ModelLoader::get_sd_version() { TensorStorage token_embedding_weight, input_block_weight; bool input_block_checked = false; @@ -1499,7 +1633,7 @@ SDVersion ModelLoader::get_sd_version() { if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) { return VERSION_SD3; } - if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) { + if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) { is_unet = true; if (has_multiple_encoders) { is_xl = true; @@ -1508,7 +1642,7 @@ SDVersion ModelLoader::get_sd_version() { } } } - if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) { + if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos || tensor_storage.name.find("te.1") != std::string::npos) { has_multiple_encoders = true; if (is_unet) { is_xl = true; @@ -1530,7 +1664,7 @@ SDVersion ModelLoader::get_sd_version() { token_embedding_weight = tensor_storage; // break; } - if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight") { + if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") { input_block_weight = tensor_storage; input_block_checked = true; if (found_family) { @@ -1615,7 +1749,7 @@ ggml_type ModelLoader::get_diffusion_model_wtype() { continue; } - if (tensor_storage.name.find("model.diffusion_model.") == std::string::npos) { + if (tensor_storage.name.find("model.diffusion_model.") == std::string::npos && tensor_storage.name.find("unet.") == std::string::npos) { continue; } @@ -1779,6 +1913,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend }; int tensor_count = 0; int64_t t1 = ggml_time_ms(); + bool partial = false; for (auto& tensor_storage : processed_tensor_storages) { if (tensor_storage.file_index != file_index) { ++tensor_count; @@ -1860,15 +1995,21 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor)); } } - int64_t t2 = ggml_time_ms(); - pretty_progress(++tensor_count, processed_tensor_storages.size(), (t2 - t1) / 1000.0f); - t1 = t2; + size_t tensor_max = processed_tensor_storages.size(); + int64_t t2 = ggml_time_ms(); + pretty_progress(++tensor_count, tensor_max, (t2 - t1) / 1000.0f); + t1 = t2; + partial = tensor_count != tensor_max; } if (zip != NULL) { zip_close(zip); } + if (partial) { + printf("\n"); + } + if (!success) { break; } @@ -1946,6 
+2087,41 @@ bool ModelLoader::load_tensors(std::map& tenso return true; } +std::vector> parse_tensor_type_rules(const std::string& tensor_type_rules) { + std::vector> result; + for (const auto& item : splitString(tensor_type_rules, ',')) { + if (item.size() == 0) + continue; + std::string::size_type pos = item.find('='); + if (pos == std::string::npos) { + LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str()); + continue; + } + std::string tensor_pattern = item.substr(0, pos); + std::string type_name = item.substr(pos + 1); + + ggml_type tensor_type = GGML_TYPE_COUNT; + + if (type_name == "f32") { + tensor_type = GGML_TYPE_F32; + } else { + for (size_t i = 0; i < SD_TYPE_COUNT; i++) { + auto trait = ggml_get_type_traits((ggml_type)i); + if (trait->to_float && trait->type_size && type_name == trait->type_name) { + tensor_type = (ggml_type)i; + } + } + } + + if (tensor_type != GGML_TYPE_COUNT) { + result.emplace_back(tensor_pattern, tensor_type); + } else { + LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str()); + } + } + return result; +} + bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) { const std::string& name = tensor_storage.name; if (type != GGML_TYPE_COUNT) { @@ -1977,7 +2153,7 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage return false; } -bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type) { +bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) { auto backend = ggml_backend_cpu_init(); size_t mem_size = 1 * 1024 * 1024; // for padding mem_size += tensor_storages.size() * ggml_tensor_overhead(); @@ -1987,12 +2163,23 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type gguf_context* gguf_ctx = gguf_init_empty(); + auto tensor_type_rules = parse_tensor_type_rules(tensor_type_rules_str); + auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { const std::string& name = tensor_storage.name; + ggml_type tensor_type = tensor_storage.type; + ggml_type dst_type = type; - ggml_type tensor_type = tensor_storage.type; - if (tensor_should_be_converted(tensor_storage, type)) { - tensor_type = type; + for (const auto& tensor_type_rule : tensor_type_rules) { + std::regex pattern(tensor_type_rule.first); + if (std::regex_search(name, pattern)) { + dst_type = tensor_type_rule.second; + break; + } + } + + if (tensor_should_be_converted(tensor_storage, dst_type)) { + tensor_type = dst_type; } ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne); @@ -2051,7 +2238,7 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type) return mem_size; } -bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type) { +bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type, const char* tensor_type_rules) { ModelLoader model_loader; if (!model_loader.init_from_file(input_path)) { @@ -2065,6 +2252,6 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa return false; } } - bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type); + bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules); return success; } diff --git a/model.h b/model.h index d7f976533..95c66319d 
100644 --- a/model.h +++ b/model.h @@ -12,9 +12,9 @@ #include "ggml-backend.h" #include "ggml.h" +#include "gguf.h" #include "json.hpp" #include "zip.h" -#include "gguf.h" #define SD_MAX_DIMS 5 @@ -210,6 +210,7 @@ class ModelLoader { std::map tensor_storages_types; bool init_from_file(const std::string& file_path, const std::string& prefix = ""); + bool model_is_unet(); SDVersion get_sd_version(); ggml_type get_sd_wtype(); ggml_type get_conditioner_wtype(); @@ -221,7 +222,7 @@ class ModelLoader { ggml_backend_t backend, std::set ignore_tensors = {}); - bool save_to_gguf_file(const std::string& file_path, ggml_type type); + bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules); bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type); int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT); ~ModelLoader() = default; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index e38a6101f..9c8265727 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -48,8 +48,7 @@ const char* sampling_methods_str[] = { "iPNDM_v", "LCM", "DDIM \"trailing\"", - "TCD" -}; + "TCD"}; /*================================================== Helper Functions ================================================*/ @@ -104,6 +103,9 @@ class StableDiffusionGGML { bool vae_tiling = false; bool stacked_id = false; + bool is_using_v_parameterization = false; + bool is_using_edm_v_parameterization = false; + std::map tensors; std::string lora_model_dir; @@ -159,7 +161,10 @@ class StableDiffusionGGML { bool clip_on_cpu, bool control_net_cpu, bool vae_on_cpu, - bool diffusion_flash_attn) { + bool diffusion_flash_attn, + bool chroma_use_dit_mask, + bool chroma_use_t5_mask, + int chroma_t5_mask_pad) { use_tiny_autoencoder = taesd_path.size() > 0; #ifdef SD_USE_CUDA LOG_DEBUG("Using CUDA backend"); @@ -179,6 +184,14 @@ class StableDiffusionGGML { LOG_WARN("Failed to initialize Vulkan backend"); } #endif +#ifdef SD_USE_OPENCL + LOG_DEBUG("Using OpenCL backend"); + // ggml_log_set(ggml_log_callback_default, nullptr); // Optional ggml logs + backend = ggml_backend_opencl_init(); + if (!backend) { + LOG_WARN("Failed to initialize OpenCL backend"); + } +#endif #ifdef SD_USE_SYCL LOG_DEBUG("Using SYCL backend"); backend = ggml_backend_sycl_init(0); @@ -200,16 +213,25 @@ class StableDiffusionGGML { } } + if (diffusion_model_path.size() > 0) { + LOG_INFO("loading diffusion model from '%s'", diffusion_model_path.c_str()); + if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) { + LOG_WARN("loading diffusion model from '%s' failed", diffusion_model_path.c_str()); + } + } + + bool is_unet = model_loader.model_is_unet(); + if (clip_l_path.size() > 0) { LOG_INFO("loading clip_l from '%s'", clip_l_path.c_str()); - if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.transformer.")) { + if (!model_loader.init_from_file(clip_l_path, is_unet ? "cond_stage_model.transformer." : "text_encoders.clip_l.transformer.")) { LOG_WARN("loading clip_l from '%s' failed", clip_l_path.c_str()); } } if (clip_g_path.size() > 0) { LOG_INFO("loading clip_g from '%s'", clip_g_path.c_str()); - if (!model_loader.init_from_file(clip_g_path, "text_encoders.clip_g.transformer.")) { + if (!model_loader.init_from_file(clip_g_path, is_unet ? "cond_stage_model.1.transformer." 
: "text_encoders.clip_g.transformer.")) { LOG_WARN("loading clip_g from '%s' failed", clip_g_path.c_str()); } } @@ -221,13 +243,6 @@ class StableDiffusionGGML { } } - if (diffusion_model_path.size() > 0) { - LOG_INFO("loading diffusion model from '%s'", diffusion_model_path.c_str()); - if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) { - LOG_WARN("loading diffusion model from '%s' failed", diffusion_model_path.c_str()); - } - } - if (vae_path.size() > 0) { LOG_INFO("loading vae from '%s'", vae_path.c_str()); if (!model_loader.init_from_file(vae_path, "vae.")) { @@ -274,10 +289,10 @@ class StableDiffusionGGML { model_loader.set_wtype_override(GGML_TYPE_F32, "vae."); } - LOG_INFO("Weight type: %s", model_wtype != SD_TYPE_COUNT ? ggml_type_name(model_wtype) : "??"); - LOG_INFO("Conditioner weight type: %s", conditioner_wtype != SD_TYPE_COUNT ? ggml_type_name(conditioner_wtype) : "??"); - LOG_INFO("Diffusion model weight type: %s", diffusion_model_wtype != SD_TYPE_COUNT ? ggml_type_name(diffusion_model_wtype) : "??"); - LOG_INFO("VAE weight type: %s", vae_wtype != SD_TYPE_COUNT ? ggml_type_name(vae_wtype) : "??"); + LOG_INFO("Weight type: %s", model_wtype != GGML_TYPE_COUNT ? ggml_type_name(model_wtype) : "??"); + LOG_INFO("Conditioner weight type: %s", conditioner_wtype != GGML_TYPE_COUNT ? ggml_type_name(conditioner_wtype) : "??"); + LOG_INFO("Diffusion model weight type: %s", diffusion_model_wtype != GGML_TYPE_COUNT ? ggml_type_name(diffusion_model_wtype) : "??"); + LOG_INFO("VAE weight type: %s", vae_wtype != GGML_TYPE_COUNT ? ggml_type_name(vae_wtype) : "??"); LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor)); @@ -334,8 +349,19 @@ class StableDiffusionGGML { cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types); } else if (sd_version_is_flux(version)) { - cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); - diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn); + bool is_chroma = false; + for (auto pair : model_loader.tensor_storages_types) { + if (pair.first.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { + is_chroma = true; + break; + } + } + if (is_chroma) { + cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types, -1, chroma_use_t5_mask, chroma_t5_mask_pad); + } else { + cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); + } + diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn, chroma_use_dit_mask); } else { if (id_embeddings_path.find("v2") != std::string::npos) { cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2); @@ -522,12 +548,17 @@ class StableDiffusionGGML { LOG_INFO("loading model from '%s' completed, taking %.2fs", model_path.c_str(), (t1 - t0) * 1.0f / 1000); // check is_using_v_parameterization_for_sd2 - bool is_using_v_parameterization = false; + if (sd_version_is_sd2(version)) { if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) { is_using_v_parameterization = true; } } else if (sd_version_is_sdxl(version)) { + if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) { + // CosXL models + // TODO: get sigma_min and 
sigma_max values from file + is_using_edm_v_parameterization = true; + } if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) { is_using_v_parameterization = true; } @@ -552,6 +583,9 @@ class StableDiffusionGGML { } else if (is_using_v_parameterization) { LOG_INFO("running in v-prediction mode"); denoiser = std::make_shared(); + } else if (is_using_edm_v_parameterization) { + LOG_INFO("running in v-prediction EDM mode"); + denoiser = std::make_shared(); } else { LOG_INFO("running in eps-prediction mode"); } @@ -618,7 +652,7 @@ class StableDiffusionGGML { int64_t t0 = ggml_time_ms(); struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t); - diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, -1, {}, 0.f, &out); + diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, {}, -1, {}, 0.f, &out); diffusion_model->free_compute_buffer(); double result = 0.f; @@ -682,7 +716,7 @@ class StableDiffusionGGML { float curr_multiplier = kv.second; lora_state_diff[lora_name] -= curr_multiplier; } - + size_t rm = lora_state_diff.size() - lora_state.size(); if (rm != 0) { LOG_INFO("Attempting to apply %lu LoRAs (removing %lu applied LoRAs)", lora_state.size(), rm); @@ -800,11 +834,12 @@ class StableDiffusionGGML { const std::vector& sigmas, int start_merge_step, SDCondition id_cond, - std::vector skip_layers = {}, - float slg_scale = 0, - float skip_layer_start = 0.01, - float skip_layer_end = 0.2, - ggml_tensor* noise_mask = nullptr) { + std::vector ref_latents = {}, + std::vector skip_layers = {}, + float slg_scale = 0, + float skip_layer_start = 0.01, + float skip_layer_end = 0.2, + ggml_tensor* noise_mask = nullptr) { LOG_DEBUG("Sample"); struct ggml_init_params params; size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]); @@ -887,6 +922,7 @@ class StableDiffusionGGML { cond.c_concat, cond.c_vector, guidance_tensor, + ref_latents, -1, controls, control_strength, @@ -899,6 +935,7 @@ class StableDiffusionGGML { cond.c_concat, id_cond.c_vector, guidance_tensor, + ref_latents, -1, controls, control_strength, @@ -919,6 +956,7 @@ class StableDiffusionGGML { uncond.c_concat, uncond.c_vector, guidance_tensor, + ref_latents, -1, controls, control_strength, @@ -939,6 +977,7 @@ class StableDiffusionGGML { cond.c_concat, cond.c_vector, guidance_tensor, + ref_latents, -1, controls, control_strength, @@ -1130,7 +1169,10 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, bool keep_clip_on_cpu, bool keep_control_net_cpu, bool keep_vae_on_cpu, - bool diffusion_flash_attn) { + bool diffusion_flash_attn, + bool chroma_use_dit_mask, + bool chroma_use_t5_mask, + int chroma_t5_mask_pad) { sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t)); if (sd_ctx == NULL) { return NULL; @@ -1172,7 +1214,10 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, keep_clip_on_cpu, keep_control_net_cpu, keep_vae_on_cpu, - diffusion_flash_attn)) { + diffusion_flash_attn, + chroma_use_dit_mask, + chroma_use_t5_mask, + chroma_t5_mask_pad)) { delete sd_ctx->sd; sd_ctx->sd = NULL; free(sd_ctx); @@ -1209,6 +1254,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, float style_ratio, bool normalize_input, std::string input_id_images_path, + std::vector ref_latents, std::vector skip_layers = {}, float slg_scale = 0, float skip_layer_start = 0.01, @@ -1363,7 +1409,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, SDCondition uncond; if (cfg_scale != 1.0) { bool force_zero_embeddings = false; - if (sd_version_is_sdxl(sd_ctx->sd->version) && 
         struct ggml_init_params params;
         size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]);
@@ -887,6 +922,7 @@
                                                      cond.c_concat,
                                                      cond.c_vector,
                                                      guidance_tensor,
+                                                     ref_latents,
                                                      -1,
                                                      controls,
                                                      control_strength,
@@ -899,6 +935,7 @@
                                                      cond.c_concat,
                                                      id_cond.c_vector,
                                                      guidance_tensor,
+                                                     ref_latents,
                                                      -1,
                                                      controls,
                                                      control_strength,
@@ -919,6 +956,7 @@
                                                          uncond.c_concat,
                                                          uncond.c_vector,
                                                          guidance_tensor,
+                                                         ref_latents,
                                                          -1,
                                                          controls,
                                                          control_strength,
@@ -939,6 +977,7 @@
                                                      cond.c_concat,
                                                      cond.c_vector,
                                                      guidance_tensor,
+                                                     ref_latents,
                                                      -1,
                                                      controls,
                                                      control_strength,
@@ -1130,7 +1169,10 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
                      bool keep_clip_on_cpu,
                      bool keep_control_net_cpu,
                      bool keep_vae_on_cpu,
-                     bool diffusion_flash_attn) {
+                     bool diffusion_flash_attn,
+                     bool chroma_use_dit_mask,
+                     bool chroma_use_t5_mask,
+                     int chroma_t5_mask_pad) {
     sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t));
     if (sd_ctx == NULL) {
         return NULL;
     }
@@ -1172,7 +1214,10 @@
                              keep_clip_on_cpu,
                              keep_control_net_cpu,
                              keep_vae_on_cpu,
-                             diffusion_flash_attn)) {
+                             diffusion_flash_attn,
+                             chroma_use_dit_mask,
+                             chroma_use_t5_mask,
+                             chroma_t5_mask_pad)) {
         delete sd_ctx->sd;
         sd_ctx->sd = NULL;
         free(sd_ctx);
@@ -1209,6 +1254,7 @@
                            float style_ratio,
                            bool normalize_input,
                            std::string input_id_images_path,
+                           std::vector<ggml_tensor*> ref_latents,
                            std::vector<int> skip_layers = {},
                            float slg_scale              = 0,
                            float skip_layer_start       = 0.01,
@@ -1363,7 +1409,7 @@
     SDCondition uncond;
     if (cfg_scale != 1.0) {
         bool force_zero_embeddings = false;
-        if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0) {
+        if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) {
             force_zero_embeddings = true;
         }
         uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx,
@@ -1466,6 +1512,7 @@
                                                      sigmas,
                                                      start_merge_step,
                                                      id_cond,
+                                                     ref_latents,
                                                      skip_layers,
                                                      slg_scale,
                                                      skip_layer_start,
@@ -1521,6 +1568,29 @@
     return result_images;
 }
 
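+// Factored out of txt2img so edit() below can reuse it: builds a constant
+// "empty" latent (16 channels for SD3/Flux, 4 otherwise). The fill values
+// match the corresponding VAEs' latent shift factors (0.0609 for SD3,
+// 0.1159 for Flux).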
+ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
+                                  ggml_context* work_ctx,
+                                  int width,
+                                  int height) {
+    int C = 4;
+    if (sd_version_is_sd3(sd_ctx->sd->version)) {
+        C = 16;
+    } else if (sd_version_is_flux(sd_ctx->sd->version)) {
+        C = 16;
+    }
+    int W = width / 8;
+    int H = height / 8;
+    ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
+    if (sd_version_is_sd3(sd_ctx->sd->version)) {
+        ggml_set_f32(init_latent, 0.0609f);
+    } else if (sd_version_is_flux(sd_ctx->sd->version)) {
+        ggml_set_f32(init_latent, 0.1159f);
+    } else {
+        ggml_set_f32(init_latent, 0.f);
+    }
+    return init_latent;
+}
+
 sd_image_t* txt2img(sd_ctx_t* sd_ctx,
                     const char* prompt_c_str,
                     const char* negative_prompt_c_str,
@@ -1577,27 +1647,12 @@
 
     std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
 
-    int C = 4;
-    if (sd_version_is_sd3(sd_ctx->sd->version)) {
-        C = 16;
-    } else if (sd_version_is_flux(sd_ctx->sd->version)) {
-        C = 16;
-    }
-    int W = width / 8;
-    int H = height / 8;
-    ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
-    if (sd_version_is_sd3(sd_ctx->sd->version)) {
-        ggml_set_f32(init_latent, 0.0609f);
-    } else if (sd_version_is_flux(sd_ctx->sd->version)) {
-        ggml_set_f32(init_latent, 0.1159f);
-    } else {
-        ggml_set_f32(init_latent, 0.f);
-    }
-
     if (sd_version_is_inpaint(sd_ctx->sd->version)) {
         LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask");
     }
 
+    ggml_tensor* init_latent = generate_init_latent(sd_ctx, work_ctx, width, height);
+
     sd_image_t* result_images = generate_image(sd_ctx,
                                                work_ctx,
                                                init_latent,
@@ -1618,6 +1673,7 @@
                                                style_ratio,
                                                normalize_input,
                                                input_id_images_path_c_str,
+                                               {},
                                                skip_layers_vec,
                                                slg_scale,
                                                skip_layer_start,
@@ -1798,6 +1854,7 @@
                                                style_ratio,
                                                normalize_input,
                                                input_id_images_path_c_str,
+                                               {},
                                                skip_layers_vec,
                                                slg_scale,
                                                skip_layer_start,
@@ -1943,3 +2000,116 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
 
     return result_images;
 }
+
+sd_image_t* edit(sd_ctx_t* sd_ctx,
+                 sd_image_t* ref_images,
+                 int ref_images_count,
+                 const char* prompt_c_str,
+                 const char* negative_prompt_c_str,
+                 int clip_skip,
+                 float cfg_scale,
+                 float guidance,
+                 float eta,
+                 int width,
+                 int height,
+                 sample_method_t sample_method,
+                 int sample_steps,
+                 float strength,
+                 int64_t seed,
+                 int batch_count,
+                 const sd_image_t* control_cond,
+                 float control_strength,
+                 float style_ratio,
+                 bool normalize_input,
+                 int* skip_layers         = NULL,
+                 size_t skip_layers_count = 0,
+                 float slg_scale          = 0,
+                 float skip_layer_start   = 0.01,
+                 float skip_layer_end     = 0.2) {
+    std::vector<int> skip_layers_vec(skip_layers, skip_layers + skip_layers_count);
+    LOG_DEBUG("edit %dx%d", width, height);
+    if (sd_ctx == NULL) {
+        return NULL;
+    }
+    if (ref_images_count <= 0) {
+        LOG_ERROR("ref images count should be > 0");
+        return NULL;
+    }
+
+    struct ggml_init_params params;
+    params.mem_size = static_cast<size_t>(30 * 1024 * 1024);  // 30 MB
+    params.mem_size += width * height * 3 * sizeof(float) * 3 * ref_images_count;
+    params.mem_size *= batch_count;
+    params.mem_buffer = NULL;
+    params.no_alloc   = false;
+    // LOG_DEBUG("mem_size %u ", params.mem_size);
+
+    struct ggml_context* work_ctx = ggml_init(params);
+    if (!work_ctx) {
+        LOG_ERROR("ggml_init() failed");
+        return NULL;
+    }
+
+    if (seed < 0) {
+        srand((int)time(NULL));
+        seed = rand();
+    }
+    sd_ctx->sd->rng->manual_seed(seed);
+
+    size_t t0 = ggml_time_ms();
+
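+    // Encode every reference image to a latent. The full VAE returns
+    // distribution moments that still need get_first_stage_encoding(); the
+    // tiny autoencoder yields the latent directly.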
+    std::vector<ggml_tensor*> ref_latents;
+    for (int i = 0; i < ref_images_count; i++) {
+        ggml_tensor* img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, ref_images[i].width, ref_images[i].height, 3, 1);
+        sd_image_to_tensor(ref_images[i].data, img);
+
+        ggml_tensor* latent = NULL;
+        if (!sd_ctx->sd->use_tiny_autoencoder) {
+            ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, img);
+            latent               = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
+        } else {
+            latent = sd_ctx->sd->encode_first_stage(work_ctx, img);
+        }
+        ref_latents.push_back(latent);
+    }
+
+    size_t t1 = ggml_time_ms();
+    LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
+
+    std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
+
+    ggml_tensor* init_latent = generate_init_latent(sd_ctx, work_ctx, width, height);
+
+    sd_image_t* result_images = generate_image(sd_ctx,
+                                               work_ctx,
+                                               init_latent,
+                                               prompt_c_str,
+                                               negative_prompt_c_str,
+                                               clip_skip,
+                                               cfg_scale,
+                                               guidance,
+                                               eta,
+                                               width,
+                                               height,
+                                               sample_method,
+                                               sigmas,
+                                               seed,
+                                               batch_count,
+                                               control_cond,
+                                               control_strength,
+                                               style_ratio,
+                                               normalize_input,
+                                               "",
+                                               ref_latents,
+                                               skip_layers_vec,
+                                               slg_scale,
+                                               skip_layer_start,
+                                               skip_layer_end,
+                                               NULL);
+
+    size_t t2 = ggml_time_ms();
+
+    LOG_INFO("edit completed in %.2fs", (t2 - t0) * 1.0f / 1000);
+
+    return result_images;
+}
\ No newline at end of file
diff --git a/stable-diffusion.h b/stable-diffusion.h
index 52dcc848a..212e1c918 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -61,10 +61,10 @@ enum schedule_t {
 
 // same as enum ggml_type
 enum sd_type_t {
-    SD_TYPE_F32  = 0,
-    SD_TYPE_F16  = 1,
-    SD_TYPE_Q4_0 = 2,
-    SD_TYPE_Q4_1 = 3,
+    SD_TYPE_F32     = 0,
+    SD_TYPE_F16     = 1,
+    SD_TYPE_Q4_0    = 2,
+    SD_TYPE_Q4_1    = 3,
     // SD_TYPE_Q4_2 = 4, support has been removed
     // SD_TYPE_Q4_3 = 5, support has been removed
     SD_TYPE_Q5_0 = 6,
@@ -95,12 +95,12 @@ enum sd_type_t {
     // SD_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files
     // SD_TYPE_Q4_0_4_8 = 32,
     // SD_TYPE_Q4_0_8_8 = 33,
-    SD_TYPE_TQ1_0 = 34,
-    SD_TYPE_TQ2_0 = 35,
+    SD_TYPE_TQ1_0   = 34,
+    SD_TYPE_TQ2_0   = 35,
     // SD_TYPE_IQ4_NL_4_4 = 36,
     // SD_TYPE_IQ4_NL_4_8 = 37,
     // SD_TYPE_IQ4_NL_8_8 = 38,
-    SD_TYPE_COUNT = 39,
+    SD_TYPE_COUNT   = 39,
 };
 
 SD_API const char* sd_type_name(enum sd_type_t type);
@@ -150,7 +153,10 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
                             bool keep_clip_on_cpu,
                             bool keep_control_net_cpu,
                             bool keep_vae_on_cpu,
-                            bool diffusion_flash_attn);
+                            bool diffusion_flash_attn,
+                            bool chroma_use_dit_mask,
+                            bool chroma_use_t5_mask,
+                            int chroma_t5_mask_pad);
 
 SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
 
@@ -220,6 +223,32 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
                            float strength,
                            int64_t seed);
 
+SD_API sd_image_t* edit(sd_ctx_t* sd_ctx,
+                        sd_image_t* ref_images,
+                        int ref_images_count,
+                        const char* prompt,
+                        const char* negative_prompt,
+                        int clip_skip,
+                        float cfg_scale,
+                        float guidance,
+                        float eta,
+                        int width,
+                        int height,
+                        enum sample_method_t sample_method,
+                        int sample_steps,
+                        float strength,
+                        int64_t seed,
+                        int batch_count,
+                        const sd_image_t* control_cond,
+                        float control_strength,
+                        float style_strength,
+                        bool normalize_input,
+                        int* skip_layers,
+                        size_t skip_layers_count,
+                        float slg_scale,
+                        float skip_layer_start,
+                        float skip_layer_end);
+
 typedef struct upscaler_ctx_t upscaler_ctx_t;
 
 SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
@@ -228,7 +257,7 @@
 SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
 
 SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);
 
-SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum sd_type_t output_type);
+SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum sd_type_t output_type, const char* tensor_type_rules);
 
 SD_API uint8_t* preprocess_canny(uint8_t* img,
                                  int width,
diff --git a/t5.hpp b/t5.hpp
index 2a53e2743..d511ef24b 100644
--- a/t5.hpp
+++ b/t5.hpp
@@ -385,6 +385,7 @@ class T5UniGramTokenizer {
 
     void pad_tokens(std::vector<int>& tokens,
                     std::vector<float>& weights,
+                    std::vector<float>* attention_mask,
                     size_t max_length = 0,
                     bool padding      = false) {
         if (max_length > 0 && padding) {
@@ -397,11 +398,15 @@
             LOG_DEBUG("token length: %llu", length);
             std::vector<int> new_tokens;
             std::vector<float> new_weights;
+            std::vector<float> new_attention_mask;
             int token_idx = 0;
             for (int i = 0; i < length; i++) {
                 if (token_idx >= orig_token_num) {
                     break;
                 }
+                if (attention_mask != nullptr) {
+                    new_attention_mask.push_back(0.0);
+                }
                 if (i % max_length == max_length - 1) {
                     new_tokens.push_back(eos_id_);
                     new_weights.push_back(1.0);
@@ -414,13 +419,24 @@
             new_tokens.push_back(eos_id_);
             new_weights.push_back(1.0);
+            if (attention_mask != nullptr) {
+                new_attention_mask.push_back(0.0);
+            }
+
             tokens  = new_tokens;
             weights = new_weights;
+            if (attention_mask != nullptr) {
+                *attention_mask = new_attention_mask;
+            }
 
             if (padding) {
                 int pad_token_id = pad_id_;
                 tokens.insert(tokens.end(), length - tokens.size(), pad_token_id);
                 weights.insert(weights.end(), length - weights.size(), 1.0);
+                if (attention_mask != nullptr) {
+                    // maybe keep some padding tokens unmasked?
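+                    // -HUGE_VALF flags padded positions in this additive
+                    // attention mask: added to the logits before softmax it
+                    // zeroes their attention weight, while 0.0 leaves real
+                    // tokens untouched.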
+                    attention_mask->insert(attention_mask->end(), length - attention_mask->size(), -HUGE_VALF);
+                }
             }
         }
     }
@@ -579,6 +595,7 @@ class T5Attention : public GGMLBlock {
         }
         if (past_bias != NULL) {
             if (mask != NULL) {
+                mask = ggml_repeat(ctx, mask, past_bias);
                 mask = ggml_add(ctx, mask, past_bias);
             } else {
                 mask = past_bias;
             }
@@ -739,15 +756,17 @@ struct T5Runner : public GGMLRunner {
 
     struct ggml_tensor* forward(struct ggml_context* ctx,
                                 struct ggml_tensor* input_ids,
-                                struct ggml_tensor* relative_position_bucket) {
+                                struct ggml_tensor* relative_position_bucket,
+                                struct ggml_tensor* attention_mask = NULL) {
         size_t N       = input_ids->ne[1];
         size_t n_token = input_ids->ne[0];
 
-        auto hidden_states = model.forward(ctx, input_ids, NULL, NULL, relative_position_bucket);  // [N, n_token, model_dim]
+        auto hidden_states = model.forward(ctx, input_ids, NULL, attention_mask, relative_position_bucket);  // [N, n_token, model_dim]
         return hidden_states;
     }
 
-    struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids) {
+    struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
+                                    struct ggml_tensor* attention_mask = NULL) {
         struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
 
         input_ids = to_backend(input_ids);
@@ -767,7 +786,7 @@ struct T5Runner : public GGMLRunner {
                                                            input_ids->ne[0]);
         set_backend_tensor_data(relative_position_bucket, relative_position_bucket_vec.data());
 
-        struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, relative_position_bucket);
+        struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, relative_position_bucket, attention_mask);
 
         ggml_build_forward_expand(gf, hidden_states);
 
@@ -776,10 +795,11 @@
     void compute(const int n_threads,
                  struct ggml_tensor* input_ids,
+                 struct ggml_tensor* attention_mask,
                  ggml_tensor** output,
                  ggml_context* output_ctx = NULL) {
         auto get_graph = [&]() -> struct ggml_cgraph* {
-            return build_graph(input_ids);
+            return build_graph(input_ids, attention_mask);
         };
         GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
     }
@@ -877,9 +897,9 @@ struct T5Embedder {
         model.alloc_params_buffer();
     }
 
-    std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
-                                                             size_t max_length = 0,
-                                                             bool padding      = false) {
+    std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
+                                                                                  size_t max_length = 0,
+                                                                                  bool padding      = false) {
         auto parsed_attention = parse_prompt_attention(text);
 
         {
@@ -906,14 +926,16 @@
         tokens.push_back(EOS_TOKEN_ID);
         weights.push_back(1.0);
 
-        tokenizer.pad_tokens(tokens, weights, max_length, padding);
+        std::vector<float> attention_mask;
+
+        tokenizer.pad_tokens(tokens, weights, &attention_mask, max_length, padding);
 
         // for (int i = 0; i < tokens.size(); i++) {
         //     std::cout << tokens[i] << ":" << weights[i] << ", ";
         // }
         // std::cout << std::endl;
 
-        return {tokens, weights};
+        return {tokens, weights, attention_mask};
     }
 
     void test() {
@@ -934,8 +956,8 @@
         // TODO: fix cuda nan
         std::string text("a lovely cat");
         auto tokens_and_weights = tokenize(text, 77, true);
-        std::vector<int>& tokens    = tokens_and_weights.first;
-        std::vector<float>& weights = tokens_and_weights.second;
+        std::vector<int>& tokens    = std::get<0>(tokens_and_weights);
+        std::vector<float>& weights = std::get<1>(tokens_and_weights);
         for (auto token : tokens) {
             printf("%d ", token);
         }
         printf("\n");
 
         struct ggml_tensor* out = NULL;
         int t0                  = ggml_time_ms();
-        model.compute(8, input_ids, &out, work_ctx);
+        model.compute(8, input_ids, NULL, &out, work_ctx);
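+        // Passing NULL as attention_mask preserves the previous unmasked
+        // behavior for this self-test.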
         int t1 = ggml_time_ms();
         print_ggml_tensor(out);
diff --git a/tae.hpp b/tae.hpp
index c458b87d2..678c44c57 100644
--- a/tae.hpp
+++ b/tae.hpp
@@ -149,7 +149,7 @@ class TinyDecoder : public UnaryBlock {
             if (i == 1) {
                 h = ggml_relu_inplace(ctx, h);
             } else {
-                h = ggml_upscale(ctx, h, 2);
+                h = ggml_upscale(ctx, h, 2, GGML_SCALE_MODE_NEAREST);
             }
             continue;
         }
diff --git a/upscaler.cpp b/upscaler.cpp
index 0c11b666e..137213496 100644
--- a/upscaler.cpp
+++ b/upscaler.cpp
@@ -28,6 +28,10 @@ struct UpscalerGGML {
         LOG_DEBUG("Using Vulkan backend");
         backend = ggml_backend_vk_init(0);
 #endif
+#ifdef SD_USE_OPENCL
+        LOG_DEBUG("Using OpenCL backend");
+        backend = ggml_backend_opencl_init();
+#endif
 #ifdef SD_USE_SYCL
         LOG_DEBUG("Using SYCL backend");
         backend = ggml_backend_sycl_init(0);
diff --git a/util.cpp b/util.cpp
index da11a14d6..631c12066 100644
--- a/util.cpp
+++ b/util.cpp
@@ -112,7 +112,7 @@ std::vector<std::string> get_files_from_dir(const std::string& dir) {
     sprintf(directoryPath, "%s\\%s\\*", currentDirectory, dir.c_str());
 
     // Find the first file in the directory
-    hFind = FindFirstFile(directoryPath, &findFileData);
+    hFind               = FindFirstFile(directoryPath, &findFileData);
     bool isAbsolutePath = false;
     // Check if the directory was found
     if (hFind == INVALID_HANDLE_VALUE) {
@@ -121,7 +121,7 @@ std::vector<std::string> get_files_from_dir(const std::string& dir) {
         char directoryPathAbsolute[MAX_PATH];
         sprintf(directoryPathAbsolute, "%s*", dir.c_str());
 
-        hFind = FindFirstFile(directoryPathAbsolute, &findFileData);
+        hFind          = FindFirstFile(directoryPathAbsolute, &findFileData);
         isAbsolutePath = true;
         if (hFind == INVALID_HANDLE_VALUE) {
             printf("Absolute path was also wrong.\n");