diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml index 7cfc82ba4e277..04ad187d35c09 100644 --- a/.github/workflows/build-linux-cross.yml +++ b/.github/workflows/build-linux-cross.yml @@ -48,98 +48,98 @@ jobs: cmake --build build --config Release -j $(nproc) - ubuntu-24-riscv64-vulkan-cross: - runs-on: ubuntu-24.04 - - steps: - - uses: actions/checkout@v4 - - name: Setup Riscv - run: | - sudo dpkg --add-architecture riscv64 - - # Add arch-specific repositories for non-amd64 architectures - cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list - deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe - deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe - deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe - deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe - EOF - - sudo apt-get update || true ;# Prevent failure due to missing URLs. - - sudo apt-get install -y --no-install-recommends \ - build-essential \ - glslc \ - gcc-14-riscv64-linux-gnu \ - g++-14-riscv64-linux-gnu \ - libvulkan-dev:riscv64 - - - name: Build - run: | - cmake -B build -DLLAMA_CURL=OFF \ - -DCMAKE_BUILD_TYPE=Release \ - -DGGML_VULKAN=ON \ - -DGGML_OPENMP=OFF \ - -DLLAMA_BUILD_EXAMPLES=ON \ - -DLLAMA_BUILD_TOOLS=ON \ - -DLLAMA_BUILD_TESTS=OFF \ - -DCMAKE_SYSTEM_NAME=Linux \ - -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ - -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ - -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ - -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ - -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ - -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ - -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ - -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH - - cmake --build build --config Release -j $(nproc) - - ubuntu-24-arm64-vulkan-cross: - runs-on: ubuntu-24.04 - - steps: - - uses: actions/checkout@v4 - - name: Setup Arm64 - run: | - sudo dpkg --add-architecture arm64 - - # Add arch-specific repositories for non-amd64 architectures - cat << EOF | sudo tee /etc/apt/sources.list.d/arm64-ports.list - deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe - deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe - deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe - deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe - EOF - - sudo apt-get update || true ;# Prevent failure due to missing URLs. 
- - sudo apt-get install -y --no-install-recommends \ - build-essential \ - glslc \ - crossbuild-essential-arm64 \ - libvulkan-dev:arm64 - - - name: Build - run: | - cmake -B build -DLLAMA_CURL=OFF \ - -DCMAKE_BUILD_TYPE=Release \ - -DGGML_VULKAN=ON \ - -DGGML_OPENMP=OFF \ - -DLLAMA_BUILD_EXAMPLES=ON \ - -DLLAMA_BUILD_TOOLS=ON \ - -DLLAMA_BUILD_TESTS=OFF \ - -DCMAKE_SYSTEM_NAME=Linux \ - -DCMAKE_SYSTEM_PROCESSOR=aarch64 \ - -DCMAKE_C_COMPILER=aarch64-linux-gnu-gcc \ - -DCMAKE_CXX_COMPILER=aarch64-linux-gnu-g++ \ - -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ - -DCMAKE_FIND_ROOT_PATH=/usr/lib/aarch64-linux-gnu \ - -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ - -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ - -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH - - cmake --build build --config Release -j $(nproc) + # ubuntu-24-riscv64-vulkan-cross: + # runs-on: ubuntu-24.04 + + # steps: + # - uses: actions/checkout@v4 + # - name: Setup Riscv + # run: | + # sudo dpkg --add-architecture riscv64 + + # # Add arch-specific repositories for non-amd64 architectures + # cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list + # deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe + # deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe + # deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe + # deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe + # EOF + + # sudo apt-get update || true ;# Prevent failure due to missing URLs. + + # sudo apt-get install -y --no-install-recommends \ + # build-essential \ + # glslc \ + # gcc-14-riscv64-linux-gnu \ + # g++-14-riscv64-linux-gnu \ + # libvulkan-dev:riscv64 + + # - name: Build + # run: | + # cmake -B build -DLLAMA_CURL=OFF \ + # -DCMAKE_BUILD_TYPE=Release \ + # -DGGML_VULKAN=ON \ + # -DGGML_OPENMP=OFF \ + # -DLLAMA_BUILD_EXAMPLES=ON \ + # -DLLAMA_BUILD_TOOLS=ON \ + # -DLLAMA_BUILD_TESTS=OFF \ + # -DCMAKE_SYSTEM_NAME=Linux \ + # -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ + # -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + # -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ + # -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + # -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ + # -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + # -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + # -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + + # cmake --build build --config Release -j $(nproc) + + # ubuntu-24-arm64-vulkan-cross: + # runs-on: ubuntu-24.04 + + # steps: + # - uses: actions/checkout@v4 + # - name: Setup Arm64 + # run: | + # sudo dpkg --add-architecture arm64 + + # # Add arch-specific repositories for non-amd64 architectures + # cat << EOF | sudo tee /etc/apt/sources.list.d/arm64-ports.list + # deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe + # deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe + # deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe + # deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe + # EOF + + # sudo apt-get update || true ;# Prevent failure due to missing URLs. 
+ + # sudo apt-get install -y --no-install-recommends \ + # build-essential \ + # glslc \ + # crossbuild-essential-arm64 \ + # libvulkan-dev:arm64 + + # - name: Build + # run: | + # cmake -B build -DLLAMA_CURL=OFF \ + # -DCMAKE_BUILD_TYPE=Release \ + # -DGGML_VULKAN=ON \ + # -DGGML_OPENMP=OFF \ + # -DLLAMA_BUILD_EXAMPLES=ON \ + # -DLLAMA_BUILD_TOOLS=ON \ + # -DLLAMA_BUILD_TESTS=OFF \ + # -DCMAKE_SYSTEM_NAME=Linux \ + # -DCMAKE_SYSTEM_PROCESSOR=aarch64 \ + # -DCMAKE_C_COMPILER=aarch64-linux-gnu-gcc \ + # -DCMAKE_CXX_COMPILER=aarch64-linux-gnu-g++ \ + # -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + # -DCMAKE_FIND_ROOT_PATH=/usr/lib/aarch64-linux-gnu \ + # -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + # -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + # -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + + # cmake --build build --config Release -j $(nproc) ubuntu-24-ppc64el-cpu-cross: runs-on: ubuntu-24.04 @@ -185,52 +185,52 @@ jobs: cmake --build build --config Release -j $(nproc) - ubuntu-24-ppc64el-vulkan-cross: - runs-on: ubuntu-24.04 - - steps: - - uses: actions/checkout@v4 - - name: Setup PowerPC64le - run: | - sudo dpkg --add-architecture ppc64el - - # Add arch-specific repositories for non-amd64 architectures - cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list - deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe - deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe - deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe - deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe - EOF - - sudo apt-get update || true ;# Prevent failure due to missing URLs. - - sudo apt-get install -y --no-install-recommends \ - build-essential \ - glslc \ - gcc-14-powerpc64le-linux-gnu \ - g++-14-powerpc64le-linux-gnu \ - libvulkan-dev:ppc64el - - - name: Build - run: | - cmake -B build -DLLAMA_CURL=OFF \ - -DCMAKE_BUILD_TYPE=Release \ - -DGGML_VULKAN=ON \ - -DGGML_OPENMP=OFF \ - -DLLAMA_BUILD_EXAMPLES=ON \ - -DLLAMA_BUILD_TOOLS=ON \ - -DLLAMA_BUILD_TESTS=OFF \ - -DCMAKE_SYSTEM_NAME=Linux \ - -DCMAKE_SYSTEM_PROCESSOR=ppc64 \ - -DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \ - -DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \ - -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ - -DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \ - -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ - -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ - -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH - - cmake --build build --config Release -j $(nproc) + # ubuntu-24-ppc64el-vulkan-cross: + # runs-on: ubuntu-24.04 + + # steps: + # - uses: actions/checkout@v4 + # - name: Setup PowerPC64le + # run: | + # sudo dpkg --add-architecture ppc64el + + # # Add arch-specific repositories for non-amd64 architectures + # cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list + # deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe + # deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe + # deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe + # deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe + # EOF + + # sudo apt-get update || true ;# Prevent failure due to missing URLs. 
+ + # sudo apt-get install -y --no-install-recommends \ + # build-essential \ + # glslc \ + # gcc-14-powerpc64le-linux-gnu \ + # g++-14-powerpc64le-linux-gnu \ + # libvulkan-dev:ppc64el + + # - name: Build + # run: | + # cmake -B build -DLLAMA_CURL=OFF \ + # -DCMAKE_BUILD_TYPE=Release \ + # -DGGML_VULKAN=ON \ + # -DGGML_OPENMP=OFF \ + # -DLLAMA_BUILD_EXAMPLES=ON \ + # -DLLAMA_BUILD_TOOLS=ON \ + # -DLLAMA_BUILD_TESTS=OFF \ + # -DCMAKE_SYSTEM_NAME=Linux \ + # -DCMAKE_SYSTEM_PROCESSOR=ppc64 \ + # -DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \ + # -DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \ + # -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + # -DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \ + # -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + # -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + # -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + + # cmake --build build --config Release -j $(nproc) debian-13-loongarch64-cpu-cross: runs-on: ubuntu-24.04 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 788d7a1d10bd0..5bd988b7f7ce3 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -135,6 +135,69 @@ jobs: cd build ctest -L main --verbose --timeout 900 + macOS-latest-cmake-arm64-webgpu: + runs-on: macos-14 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: macOS-latest-cmake-arm64-webgpu + evict-old-files: 1d + + - name: Dependencies + id: depends + continue-on-error: true + run: | + brew update + brew install curl + + - name: Dawn Dependency + id: dawn-depends + run: | + ARTIFACTS_JSON=$(curl -s -L \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + "https://api.github.com/repos/google/dawn/actions/artifacts") + echo "Finding latest macos-latest-Release artifact..." + DOWNLOAD_URL=$(echo "$ARTIFACTS_JSON" | jq -r '.artifacts + | sort_by(.created_at) + | reverse + | map(select(.name | test("macos-latest-Release$"))) + | .[0].archive_download_url') + if [ "$DOWNLOAD_URL" = "null" ] || [ -z "$DOWNLOAD_URL" ]; then + echo "No suitable Dawn artifact found!" + exit 1 + fi + echo "Downloading from: $DOWNLOAD_URL" + curl -L \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \ + -o artifact.zip "$DOWNLOAD_URL" + unzip artifact.zip + mkdir dawn + tar_file=$(find . 
-name '*.tar.gz' | head -n 1) + echo "Extracting: $tar_file" + tar -xvf "$tar_file" -C dawn --strip-components=1 + + - name: Build + id: cmake_build + run: | + export CMAKE_PREFIX_PATH=dawn + cmake -B build -DGGML_WEBGPU=ON -DGGML_METAL=OFF -DGGML_BLAS=OFF + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) + + - name: Test + id: cmake_test + run: | + cd build + ctest -L main --verbose --timeout 900 + ubuntu-cpu-cmake: strategy: matrix: @@ -344,6 +407,72 @@ jobs: # This is using llvmpipe and runs slower than other backends ctest -L main --verbose --timeout 4200 + ubuntu-22-cmake-webgpu: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-22-cmake-webgpu + evict-old-files: 1d + + - name: Vulkan SDK Dependencies + id: vulkan-depends + run: | + wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add - + sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list + sudo apt-get update -y + sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk libcurl4-openssl-dev + + - name: Dawn Dependency + id: dawn-depends + run: | + sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev + ARTIFACTS_JSON=$(curl -s -L \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + "https://api.github.com/repos/google/dawn/actions/artifacts") + echo "Finding latest ubuntu-latest-Release artifact..." + DOWNLOAD_URL=$(echo "$ARTIFACTS_JSON" | jq -r '.artifacts + | sort_by(.created_at) + | reverse + | map(select(.name | test("ubuntu-latest-Release$"))) + | .[0].archive_download_url') + if [ "$DOWNLOAD_URL" = "null" ] || [ -z "$DOWNLOAD_URL" ]; then + echo "No suitable Dawn artifact found!" + exit 1 + fi + echo "Downloading from: $DOWNLOAD_URL" + curl -L \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \ + -o artifact.zip "$DOWNLOAD_URL" + unzip artifact.zip + mkdir dawn + tar_file=$(find . 
-name '*.tar.gz' | head -n 1) + echo "Extracting: $tar_file" + tar -xvf "$tar_file" -C dawn --strip-components=1 + + - name: Build + id: cmake_build + run: | + export Dawn_DIR=dawn/lib64/cmake/Dawn + cmake -B build -DGGML_WEBGPU=ON + cmake --build build --config Release -j $(nproc) + + - name: Test + id: cmake_test + run: | + cd build + # This is using llvmpipe and runs slower than other backends + ctest -L main --verbose --timeout 3600 + ubuntu-22-cmake-hip: runs-on: ubuntu-22.04 container: rocm/dev-ubuntu-22.04:6.0.2 diff --git a/CMakePresets.json b/CMakePresets.json index e9844701304fc..b5afeb3c0f2f9 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -55,6 +55,17 @@ "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-apple-clang.cmake" } }, + { + "name": "x64-linux-gcc", "hidden": true, + "cacheVariables": { + "CMAKE_C_COMPILER": "gcc", + "CMAKE_CXX_COMPILER": "g++" + } + }, + { "name": "x64-linux-gcc-debug", "inherits": [ "base", "x64-linux-gcc", "debug" ] }, + { "name": "x64-linux-gcc-release", "inherits": [ "base", "x64-linux-gcc", "release" ] }, + { "name": "x64-linux-gcc-reldbg", "inherits": [ "base", "x64-linux-gcc", "reldbg" ] }, + { "name": "x64-linux-gcc+static-release", "inherits": [ "base", "x64-linux-gcc", "release", "static" ] }, { "name": "arm64-windows-llvm-debug", "inherits": [ "base", "arm64-windows-llvm", "debug" ] }, { "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] }, diff --git a/README.md b/README.md index 3bac4288ffd71..edde61238cb5f 100644 --- a/README.md +++ b/README.md @@ -269,6 +269,8 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo | [Vulkan](docs/build.md#vulkan) | GPU | | [CANN](docs/build.md#cann) | Ascend NPU | | [OpenCL](docs/backend/OPENCL.md) | Adreno GPU | +| [WebGPU [In Progress]](docs/build.md#webgpu) | All | + | [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All | ## Obtaining and quantizing models diff --git a/ci/run.sh b/ci/run.sh index 1146f86b64e27..4d3abf9232212 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -16,6 +16,9 @@ # # with VULKAN support # GG_BUILD_VULKAN=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt # +# # with WebGPU support +# GG_BUILD_WEBGPU=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt +# # # with MUSA support # GG_BUILD_MUSA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt # @@ -81,6 +84,10 @@ if [ ! -z ${GG_BUILD_VULKAN} ]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_VULKAN=1" fi +if [ ! -z ${GG_BUILD_WEBGPU} ]; then + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_WEBGPU=1" +fi + if [ ! -z ${GG_BUILD_MUSA} ]; then # Use qy1 by default (MTT S80) MUSA_ARCH=${MUSA_ARCH:-21} diff --git a/common/arg.cpp b/common/arg.cpp index 56827a65908be..c1151f51da17b 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1464,6 +1464,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.swa_full = true; } ).set_env("LLAMA_ARG_SWA_FULL")); + add_opt(common_arg( + {"--kv-unified", "-kvu"}, + string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n" + "[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.kv_unified ? "true" : "false"), + [](common_params & params) { + params.kv_unified = true; + } + ).set_env("LLAMA_ARG_KV_SPLIT")); add_opt(common_arg( {"--no-context-shift"}, string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? 
"disabled" : "enabled"), @@ -3423,5 +3431,34 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_examples({LLAMA_EXAMPLE_SERVER})); + // diffusion parameters + add_opt(common_arg( + { "--diffusion-steps" }, "N", + string_format("number of diffusion steps (default: %d)", params.diffusion.steps), + [](common_params & params, int value) { params.diffusion.steps = value; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + { "--diffusion-eps" }, "F", + string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps), + [](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + { "--diffusion-algorithm" }, "N", + string_format("diffusion algorithm: 0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY (default: %d)", + params.diffusion.algorithm), + [](common_params & params, int value) { params.diffusion.algorithm = value; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + { "--diffusion-alg-temp" }, "F", + string_format("algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp), + [](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + add_opt(common_arg( + { "--diffusion-visual" }, + string_format("enable visual diffusion mode (show progressive generation) (default: %s)", + params.diffusion.visual_mode ? "true" : "false"), + [](common_params & params) { params.diffusion.visual_mode = true; } + ).set_examples({ LLAMA_EXAMPLE_DIFFUSION })); + return ctx_arg; } diff --git a/common/common.cpp b/common/common.cpp index e4e71ad13fb59..466271be61c63 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1005,15 +1005,21 @@ struct common_init_result common_init_from_params(common_params & params) { params.sampling.ignore_eos = false; } - if (params.sampling.ignore_eos) { - for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { - if (llama_vocab_is_eog(vocab, i)) { - LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); - params.sampling.logit_bias.push_back({i, -INFINITY}); - } + // initialize once + for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { + if (llama_vocab_is_eog(vocab, i)) { + LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); + params.sampling.logit_bias_eog.push_back({i, -INFINITY}); } } + if (params.sampling.ignore_eos) { + // add EOG biases to the active set of logit biases + params.sampling.logit_bias.insert( + params.sampling.logit_bias.end(), + params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end()); + } + if (params.sampling.penalty_last_n == -1) { LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); params.sampling.penalty_last_n = llama_n_ctx(lctx); @@ -1157,6 +1163,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.no_perf = params.no_perf; cparams.op_offload = !params.no_op_offload; cparams.swa_full = params.swa_full; + cparams.kv_unified = params.kv_unified; cparams.type_k = params.cache_type_k; cparams.type_v = params.cache_type_v; diff --git a/common/common.h b/common/common.h index a5abe32859fdd..27adf552465e7 100644 --- a/common/common.h +++ b/common/common.h @@ -81,6 +81,7 @@ enum llama_example { LLAMA_EXAMPLE_LOOKUP, 
LLAMA_EXAMPLE_PARALLEL, LLAMA_EXAMPLE_TTS, + LLAMA_EXAMPLE_DIFFUSION, LLAMA_EXAMPLE_COUNT, }; @@ -177,7 +178,8 @@ struct common_params_sampling { std::vector grammar_triggers; // optional triggers (for lazy grammars) std::set preserved_tokens; - std::vector logit_bias; // logit biases to apply + std::vector logit_bias; // logit biases to apply + std::vector logit_bias_eog; // pre-calculated logit biases for EOG tokens // print the parameters into a string std::string print() const; @@ -217,6 +219,14 @@ struct common_params_vocoder { bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT }; +struct common_params_diffusion { + int32_t steps = 64; // number of diffusion steps + float eps = 1e-3f; // epsilon for timesteps + int32_t algorithm = 0; // diffusion algorithm (0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY) + float alg_temp = 0.0f; // algorithm temperature + bool visual_mode = false; // show progressive diffusion on screen +}; + enum common_reasoning_format { COMMON_REASONING_FORMAT_NONE, COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in tags in stream mode @@ -268,6 +278,7 @@ struct common_params { struct common_params_sampling sampling; struct common_params_speculative speculative; struct common_params_vocoder vocoder; + struct common_params_diffusion diffusion; struct common_params_model model; @@ -330,6 +341,7 @@ struct common_params { bool no_perf = false; // disable performance metrics bool ctx_shift = true; // context shift on inifinite text generation bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) + bool kv_unified = false; // enable unified KV cache bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool use_mmap = true; // use mmap for faster loads diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 8afb425b156f2..d802524bba4a0 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -669,6 +669,36 @@ def get_vocab_base_pre(self, tokenizer) -> str: # NOTE: if you get an error here, you need to update the convert_hf_to_gguf_update.py script # or pull the latest version of the model from Huggingface # don't edit the hashes manually! 
+ if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b": + # ref: https://huggingface.co/THUDM/glm-4-9b-chat + res = "chatglm-bpe" + if chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516": + # ref: https://huggingface.co/THUDM/glm-4-9b-chat + res = "chatglm-bpe" + if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2": + # ref: https://huggingface.co/THUDM/glm-4-9b-hf + res = "glm4" + if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": + # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 + res = "minerva-7b" + if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664": + # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct + res = "hunyuan" + if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6": + # ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base + res = "falcon-h1" + if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86": + # ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base + res = "falcon-h1" + if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896": + # ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base + res = "falcon-h1" + if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b": + # ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base + res = "falcon-h1" + if chkhsh == "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890": + # ref: https://huggingface.co/moonshotai/Kimi-K2-Base + res = "kimi-k2" if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B res = "llama-bpe" @@ -804,36 +834,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "d5f1dd6f980fec569fb218a81a7658ac45fc56b38c5a0adeb1c232fbe04ef5ec": # ref: https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base res = "seed-coder" - if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b": - # ref: https://huggingface.co/THUDM/glm-4-9b-chat - res = "chatglm-bpe" - if chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516": - # ref: https://huggingface.co/THUDM/glm-4-9b-chat - res = "chatglm-bpe" - if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2": - # ref: https://huggingface.co/THUDM/glm-4-9b-hf - res = "glm4" - if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": - # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 - res = "minerva-7b" - if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664": - # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct - res = "hunyuan" if chkhsh == "b0a6b1c0bd5998ebd9df08611efde34a4ff03faed45ae09c43e6b31ebd4b94cf": # ref: https://huggingface.co/skt/A.X-4.0 res = "a.x-4.0" - if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6": - # ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base - res = "falcon-h1" - if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86": - # ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base - res = "falcon-h1" - if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896": - # ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base - res = "falcon-h1" - if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b": - # ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base - res = 
"falcon-h1" if chkhsh == "f6791d196f87ce6b56a7d234be618e0d58f8cda3549416635b2bebcd22cd95c4": # ref: https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct res = "midm-2.0" @@ -1082,7 +1085,14 @@ def _set_vocab_rwkv_world(self): self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False) - special_vocab.chat_template = "rwkv-world" + if special_vocab.chat_template is None: + template_path = Path(__file__).parent / "models" / "templates" / "llama-cpp-rwkv-world.jinja" + if template_path.is_file(): + with open(template_path, "r", encoding="utf-8") as f: + template = f.read() + else: + template = "rwkv-world" + special_vocab.chat_template = template # hack: Add '\n\n' as the EOT token to make it chat normally special_vocab._set_special_token("eot", 261) # hack: Override these as they have already been set (incorrectly) @@ -2768,6 +2778,76 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("DreamModel") +class DreamModel(TextModel): + model_arch = gguf.MODEL_ARCH.DREAM + + def get_vocab_base(self) -> tuple[list[str], list[int], str]: + tokens: list[str] = [] + toktypes: list[int] = [] + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + + vocab_dict = tokenizer.get_vocab() + vocab_size = self.hparams.get("vocab_size", len(vocab_dict)) + assert max(vocab_dict.values()) < vocab_size + + tokpre = self.get_vocab_base_pre(tokenizer) + + reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab_dict.items()} + added_vocab = tokenizer.get_added_vocab() + + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.UNUSED) + elif reverse_vocab[i] in added_vocab: + tokens.append(reverse_vocab[i]) + # Check if it's a special token - treat special tokens as CONTROL tokens + if hasattr(tokenizer, 'added_tokens_decoder') and i in tokenizer.added_tokens_decoder: + if tokenizer.added_tokens_decoder[i].special: + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.USER_DEFINED) + else: + # Fallback: treat all added vocab as control tokens for special tokens like <|im_start|> + toktypes.append(gguf.TokenType.CONTROL) + else: + tokens.append(reverse_vocab[i]) + toktypes.append(gguf.TokenType.NORMAL) + + return tokens, toktypes, tokpre + + def set_vocab(self): + try: + self._set_vocab_sentencepiece() + except FileNotFoundError: + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self._try_set_pooling_type() + + # Dream models use non-causal attention for diffusion + self.gguf_writer.add_causal_attention(False) + # Handle RoPE scaling similar to Qwen2 + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + + # Add Dream-specific parameters + mask_token_id = self.hparams.get("mask_token_id") + if mask_token_id is not None: + self.gguf_writer.add_mask_token_id(mask_token_id) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> 
Iterable[tuple[str, Tensor]]: + # Dream model tensors should be mapped directly since it's the base model + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("Ernie4_5_ForCausalLM") class Ernie4_5Model(TextModel): model_arch = gguf.MODEL_ARCH.ERNIE4_5 @@ -3501,6 +3581,175 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(new_name, data_torch)] +@ModelBase.register("Plamo2ForCausalLM", "PLaMo2ForCausalLM") +class Plamo2Model(TextModel): + model_arch = gguf.MODEL_ARCH.PLAMO2 + + def set_vocab(self): + # PLaMo 2 uses a custom tokenizer with a .jsonl file + # We need to handle this specially + tokenizer_jsonl_path = self.dir_model / "tokenizer.jsonl" + tokenizer_config_path = self.dir_model / "tokenizer_config.json" + + if not tokenizer_jsonl_path.is_file(): + raise FileNotFoundError(f"PLaMo 2 tokenizer file not found: {tokenizer_jsonl_path}") + + # Load tokenizer config + with open(tokenizer_config_path, 'r', encoding='utf-8') as f: + tokenizer_config = json.load(f) + + # Load tokens from JSONL file (actually a list format) + tokens = [] + scores = [] + toktypes = [] + + with open(tokenizer_jsonl_path, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f): + if line.strip(): + token_data = json.loads(line) + # Format: [token, score, type, ?, ?, ?, ?] + token = token_data[0].encode("utf-8") + score = float(token_data[1]) + token_type_str = token_data[2] if len(token_data) > 2 else "NORMAL" + + tokens.append(token) + scores.append(score) + + # Map token type strings to GGUF token types + if token_type_str == "UNKNOWN": + toktypes.append(gguf.TokenType.UNKNOWN) + elif token_type_str == "CONTROL": + toktypes.append(gguf.TokenType.CONTROL) + elif token_type_str == "BYTE": + toktypes.append(gguf.TokenType.BYTE) + else: + # Check for PLaMo-2 special tokens + token_str = token_data[0] + if token_str.startswith("<|plamo:") and token_str.endswith("|>"): + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.NORMAL) + + vocab_size = self.hparams["vocab_size"] + if vocab_size > len(tokens): + pad_count = vocab_size - len(tokens) + logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") + for i in range(1, pad_count + 1): + tokens.append(bytes(f"[PAD{i}]", encoding="utf-8")) + scores.append(-1000.0) + toktypes.append(gguf.TokenType.UNUSED) + + # Use "plamo2" tokenizer type for PLaMo-2's custom Aho-Corasick tokenizer + self.gguf_writer.add_tokenizer_model("plamo2") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + # Add special tokens from config + if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] is not None: + token_id = tokens.index(tokenizer_config["bos_token"].encode("utf-8")) + self.gguf_writer.add_bos_token_id(token_id) + if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] is not None: + token_id = tokens.index(tokenizer_config["eos_token"].encode("utf-8")) + self.gguf_writer.add_eos_token_id(token_id) + if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] is not None: + token_id = tokens.index(tokenizer_config["pad_token"].encode("utf-8")) + self.gguf_writer.add_pad_token_id(token_id) + if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] is not None: + token_id = tokens.index(tokenizer_config["sep_token"].encode("utf-8")) + 
self.gguf_writer.add_sep_token_id(token_id) + if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] is not None: + token_id = tokens.index(tokenizer_config["unk_token"].encode("utf-8")) + self.gguf_writer.add_unk_token_id(token_id) + + # Add <|plamo:op|> as EOT to ensure appropriate end of generation + self.gguf_writer.add_eot_token_id(4) + + self.gguf_writer.add_add_space_prefix(False) + + def set_gguf_parameters(self): + hparams = self.hparams + block_count = hparams["num_hidden_layers"] + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + + # Which layers are Mamba layers + # PLaMo 2 uses mamba_step to indicate the pattern (e.g., 2 means every other layer) + # This logic matches modeling_plamo.py's is_mamba function + mamba_step = hparams.get("mamba_step", 2) + mamba_enabled = hparams.get("mamba_enabled", True) + mamba_layers = [] + + if mamba_enabled: + for i in range(block_count): + if block_count <= (mamba_step // 2): + # use attention in last layer + is_mamba = (i != block_count - 1) + else: + is_mamba = (i % mamba_step) != (mamba_step // 2) + if is_mamba: + mamba_layers.append(0) + else: + mamba_layers.append(hparams.get("num_key_value_heads", 4)) + + if mamba_layers: + self.gguf_writer.add_head_count_kv(mamba_layers) + + self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048)) + self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096)) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32)) + self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06)) + self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1000000.0)) + + # Mamba parameters + self.gguf_writer.add_ssm_state_size(hparams.get("mamba_d_state", 64)) + self.gguf_writer.add_ssm_conv_kernel(hparams.get("mamba_d_conv", 4)) + self.gguf_writer.add_ssm_time_step_rank(hparams.get("mamba_num_heads", 64)) + intermediate_size = hparams.get("mamba_num_heads", 64) * hparams.get("hidden_size_per_head", 128) + self.gguf_writer.add_ssm_inner_size(intermediate_size) + self.gguf_writer.add_ssm_group_count(0) + + # MLP feed forward parameters (for attention layers) + self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 16384)) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.endswith(".A_log"): + data_torch = -torch.exp(data_torch) + elif name.endswith(".dt_bias"): + name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" + elif name.endswith(".dt_norm_weight"): + name = name.rpartition(".dt_norm_weight")[0] + ".dt_norm.weight" + elif name.endswith(".B_norm_weight"): + name = name.rpartition(".B_norm_weight")[0] + ".B_norm.weight" + elif name.endswith(".C_norm_weight"): + name = name.rpartition(".C_norm_weight")[0] + ".C_norm.weight" + elif name.endswith(".k_weight"): + name = name.rpartition(".k_weight")[0] + ".k.weight" + elif name.endswith(".q_weight"): + name = name.rpartition(".q_weight")[0] + ".q.weight" + elif name.endswith(".conv1d.weight"): + data_torch = torch.squeeze(data_torch) # remove (, 1, ) + assert data_torch.ndim == 2 + elif name.endswith(".pre_mixer_norm.weight"): + data_torch += 1.0 + elif name.endswith(".post_mixer_norm.weight"): + data_torch += 1.0 / 5 + elif name.endswith(".pre_mlp_norm.weight"): + data_torch += 1.0 + elif name.endswith(".post_mlp_norm.weight"): + data_torch += 1.0 / (5**1.5) + elif 
name.endswith(".norm.weight"): + data_torch += 1.0 + + new_name = self.map_tensor_name(name) + + return [(new_name, data_torch)] + + @ModelBase.register("CodeShellForCausalLM") class CodeShellModel(TextModel): model_arch = gguf.MODEL_ARCH.CODESHELL @@ -5563,7 +5812,58 @@ class DeepseekV2Model(TextModel): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 def set_vocab(self): - self._set_vocab_gpt2() + try: + self._set_vocab_gpt2() + return + except Exception: + pass + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + tokpre = self.get_vocab_base_pre(tokenizer) + + if tokpre == "kimi-k2": + # Build merges list using the approach similar to HunYuanMoE + merges = [] + vocab = {} + mergeable_ranks = tokenizer.model._mergeable_ranks + for token, rank in mergeable_ranks.items(): + vocab[QwenModel.token_bytes_to_string(token)] = rank + if len(token) == 1: + continue + merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank) + if len(merged) == 2: + merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged))) + + # Build token list + vocab_size = self.hparams["vocab_size"] + special_tokens = tokenizer.special_tokens + reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()} + tokens: list[str] = [] + toktypes: list[int] = [] + + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.UNUSED) + else: + token = reverse_vocab[i] + tokens.append(token) + if i in special_tokens.values(): + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.NORMAL) + + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_token_merges(merges) + + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False) + special_vocab.add_to_gguf(self.gguf_writer) + else: + raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!") def set_gguf_parameters(self): diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 16f4acfe7834f..f7b6d97b19c8b 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -7,7 +7,6 @@ import re import requests -import sys import json import shutil import argparse @@ -69,8 +68,7 @@ class TOKENIZER_TYPE(IntEnum): hf_token = args.hf_token if args.hf_token is not None else hf_token if hf_token is None: - logger.error("HF token is required. Please provide it as an argument or set it in ~/.cache/huggingface/token") - sys.exit(1) + logger.warning("HF token not found. 
You can provide it as an argument or set it in ~/.cache/huggingface/token") # TODO: this string has to exercise as much pre-tokenizer functionality as possible # will be updated with time - contributions welcome @@ -146,11 +144,12 @@ class TOKENIZER_TYPE(IntEnum): {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-1B-Base", "chkhsh": "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86"}, {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", "chkhsh": "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896"}, {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"}, + {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"}, ] def download_file_with_auth(url, token, save_path): - headers = {"Authorization": f"Bearer {token}"} + headers = {"Authorization": f"Bearer {token}"} if token else None response = sess.get(url, headers=headers) response.raise_for_status() os.makedirs(os.path.dirname(save_path), exist_ok=True) @@ -231,7 +230,7 @@ def get_existing_models(convert_py): # generate the source code for the convert_hf_to_gguf.py:get_vocab_base_pre() function: src_ifs = "" -for model in [*all_models, *pre_computed_hashes]: +for model in [*pre_computed_hashes, *all_models]: name = model["name"] tokt = model["tokt"] chkhsh = model.get("chkhsh") @@ -239,11 +238,6 @@ def get_existing_models(convert_py): if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM: continue - # Skip if the tokenizer folder does not exist or there are other download issues previously - if not os.path.exists(f"models/tokenizers/{name}"): - logger.warning(f"Directory for tokenizer {name} not found. Skipping...") - continue - # create the tokenizer if chkhsh is not None: # if the model has a pre-computed hash, use it @@ -253,15 +247,19 @@ def get_existing_models(convert_py): chkhsh = existing_models[name] else: # otherwise, compute the hash of the tokenizer + + # Fail if the tokenizer folder with config does not exist or there are other download issues previously + if not os.path.isfile(f"models/tokenizers/{name}/tokenizer_config.json"): + raise OSError(f"Config for tokenizer {name} not found. The model may not exist or is not accessible with the provided token.") + try: logger.info(f"Loading tokenizer from {f'models/tokenizers/{name}'}...") if name == "t5": tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False) else: tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") - except OSError as e: - logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. 
Error: {e}") - continue # Skip to the next model if the tokenizer can't be loaded + except Exception as e: + raise OSError(f"Error loading tokenizer for model {name}.") from e chktok = tokenizer.encode(CHK_TXT) chkhsh = sha256(str(chktok).encode()).hexdigest() diff --git a/docs/build.md b/docs/build.md index 2e0b5d970c91a..70767ad91c056 100644 --- a/docs/build.md +++ b/docs/build.md @@ -557,6 +557,23 @@ ninja To read documentation for how to build on Android, [click here](./android.md) +## WebGPU [In Progress] + +The WebGPU backend relies on [Dawn](https://dawn.googlesource.com/dawn). Follow the instructions [here](https://dawn.googlesource.com/dawn/+/refs/heads/main/docs/quickstart-cmake.md) to install Dawn locally so that llama.cpp can find it using CMake. The currrent implementation is up-to-date with Dawn commit `bed1a61`. + +In the llama.cpp directory, build with CMake: + +``` +cmake -B build -DGGML_WEBGPU=ON +cmake --build build --config Release +``` + +### Browser Support + +WebGPU allows cross-platform access to the GPU from supported browsers. We utilize [Emscripten](https://emscripten.org/) to compile ggml's WebGPU backend to WebAssembly. Emscripten does not officially support WebGPU bindings yet, but Dawn currently maintains its own WebGPU bindings called emdawnwebgpu. + +Follow the instructions [here](https://dawn.googlesource.com/dawn/+/refs/heads/main/src/emdawnwebgpu/) to download or build the emdawnwebgpu package (Note that it might be safer to build the emdawbwebgpu package locally, so that it stays in sync with the version of Dawn you have installed above). When building using CMake, the path to the emdawnwebgpu port file needs to be set with the flag `EMDAWNWEBGPU_DIR`. + ## IBM Z & LinuxONE To read documentation for how to build on IBM Z & LinuxONE, [click here](./build-s390x.md) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 49e4d2cf8c198..11ff38762b848 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -33,6 +33,7 @@ else() add_subdirectory(speculative-simple) add_subdirectory(gen-docs) add_subdirectory(training) + add_subdirectory(diffusion) if (NOT GGML_BACKEND_DL) add_subdirectory(convert-llama2c-to-ggml) # these examples use the backends directly and cannot be built with dynamic loading diff --git a/examples/diffusion/CMakeLists.txt b/examples/diffusion/CMakeLists.txt new file mode 100644 index 0000000000000..396549c8029d9 --- /dev/null +++ b/examples/diffusion/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-diffusion-cli) +add_executable(${TARGET} diffusion-cli.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/diffusion/diffusion-cli.cpp b/examples/diffusion/diffusion-cli.cpp new file mode 100644 index 0000000000000..3e11ce1160b05 --- /dev/null +++ b/examples/diffusion/diffusion-cli.cpp @@ -0,0 +1,507 @@ +#include "arg.h" +#include "chat.h" +#include "common.h" +#include "llama.h" +#include "log.h" + +#include +#include +#include +#include +#include +#include +#include + +typedef bool (*diffusion_step_callback_t)(int32_t step, + int32_t total_steps, + const llama_token * tokens, + int32_t n_tokens, + void * user_data); + +enum diffusion_alg { + DIFFUSION_ALG_ORIGIN = 0, + DIFFUSION_ALG_MASKGIT_PLUS = 1, + DIFFUSION_ALG_TOPK_MARGIN = 2, + DIFFUSION_ALG_ENTROPY = 3, +}; + +struct diffusion_params { + int32_t steps; + float eps; + float temperature; + float top_p; + int32_t 
top_k; + llama_token mask_token_id; + enum diffusion_alg algorithm; + float alg_temp; + diffusion_step_callback_t step_callback; + void * step_callback_user_data; + int32_t seed; +}; + + +static diffusion_params diffusion_default_params() { + diffusion_params params = {}; + params.steps = 64; + params.eps = 1e-3f; + params.temperature = 0.2f; + params.top_p = 0.95f; + params.top_k = 0; + params.mask_token_id = LLAMA_TOKEN_NULL; + params.algorithm = DIFFUSION_ALG_ORIGIN; + params.alg_temp = 0.0f; + params.step_callback = nullptr; + params.step_callback_user_data = nullptr; + params.seed = 0; + return params; +} + +static void diffusion_generate(llama_context * ctx, + const llama_token * input_tokens, + llama_token * output_tokens, + int32_t n_input, + int32_t max_length, + struct diffusion_params params, + int32_t & n_generated) { + + n_generated = 0; + if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || max_length <= n_input) { + return; + } + + const llama_model * model = llama_get_model(ctx); + + // Initialize with input and pad with mask tokens + std::copy(input_tokens, input_tokens + n_input, output_tokens); + std::fill(output_tokens + n_input, output_tokens + max_length, params.mask_token_id); + + std::mt19937 rng(params.seed); + + std::vector timesteps(params.steps + 1); + for (int32_t i = 0; i <= params.steps; i++) { + timesteps[i] = 1.0f - (float) i / params.steps * (1.0f - params.eps); + } + + llama_set_causal_attn(ctx, false); + + int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model)); + + std::vector candidates(n_vocab); + + std::vector conf_candidates; + conf_candidates.reserve(max_length); + + std::vector mask_positions; + mask_positions.reserve(max_length); + + struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()); + if (params.top_k > 0) { + llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k)); + } + if (params.top_p < 1.0f) { + llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1)); + } + if (params.temperature > 0.0f) { + llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature)); + } + llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed)); + + struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed); + + llama_batch batch = llama_batch_init(max_length, 0, 1); + batch.n_tokens = max_length; + + int64_t total_sampling_time = 0; + int64_t total_time = 0; + + int64_t time_start = ggml_time_us(); + for (int32_t step = 0; step < params.steps; step++) { + if (params.step_callback) { + if (!params.step_callback(step, params.steps, output_tokens, max_length, params.step_callback_user_data)) { + break; + } + } + + for (int32_t i = 0; i < max_length; i++) { + batch.token[i] = output_tokens[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = 1; + } + + int ret = llama_decode(ctx, batch); + if (ret != 0) { + LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, step, ret); + break; + } + + float * raw_logits = llama_get_logits(ctx); + if (!raw_logits) { + LOG_ERR("%s: failed to get logits at step %d\n", __func__, step); + break; + } + + auto get_logits_for_pos = [&](int32_t pos) -> const float * { + return pos == 0 ? 
raw_logits : raw_logits + (pos - 1) * n_vocab; + }; + + int64_t time_start_sampling = ggml_time_us(); + + mask_positions.clear(); + for (int32_t i = 0; i < max_length; i++) { + if (output_tokens[i] == params.mask_token_id) { + mask_positions.push_back(i); + } + } + + if (mask_positions.empty()) { + break; + } + + float t = timesteps[step]; + float s = timesteps[step + 1]; + + if (params.algorithm == DIFFUSION_ALG_ORIGIN) { + float p_transfer = (step < params.steps - 1) ? (1.0f - s / t) : 1.0f; + + for (int32_t pos : mask_positions) { + if (std::uniform_real_distribution(0.0f, 1.0f)(rng) < p_transfer) { + const float * pos_logits = get_logits_for_pos(pos); + for (int32_t token_id = 0; token_id < n_vocab; token_id++) { + candidates[token_id].id = token_id; + candidates[token_id].logit = pos_logits[token_id]; + candidates[token_id].p = 0.0f; + } + + llama_token_data_array cur_p = { + /* .data = */ candidates.data(), + /* .size = */ (size_t) n_vocab, // Reset size to full vocab + /* .selected = */ -1, + /* .sorted = */ false, + }; + + llama_sampler_apply(sampler, &cur_p); + output_tokens[pos] = cur_p.data[cur_p.selected].id; + } + } + } else { + std::vector> confidences; + std::vector sampled_tokens(mask_positions.size()); + + for (size_t i = 0; i < mask_positions.size(); i++) { + int32_t pos = mask_positions[i]; + const float * pos_logits = get_logits_for_pos(pos); + + for (int32_t token_id = 0; token_id < n_vocab; token_id++) { + candidates[token_id].logit = pos_logits[token_id]; + candidates[token_id].p = 0.0f; + candidates[token_id].id = token_id; + } + + llama_token_data_array cur_p = { + /* .data = */ candidates.data(), + /* .size = */ candidates.size(), + /* .selected = */ -1, + /* .sorted = */ false, + }; + + llama_sampler_apply(sampler, &cur_p); + + llama_token sampled_token = cur_p.data[cur_p.selected].id; + + float confidence = 0.0f; + if (params.algorithm == DIFFUSION_ALG_ENTROPY) { + const float epsilon = 1e-10f; + for (size_t j = 0; j < cur_p.size; j++) { + float prob = cur_p.data[j].p; + confidence += prob * logf(prob + epsilon); + } + } else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) { + confidence = cur_p.data[0].p - cur_p.data[1].p; + } else { + confidence = cur_p.data[cur_p.selected].p; + } + + sampled_tokens[i] = sampled_token; + confidences.emplace_back(confidence, i); + } + + int32_t num_transfer = + (step < params.steps - 1) ? 
(int32_t) (mask_positions.size() * (1.0f - s / t)) : mask_positions.size(); + + if (num_transfer > 0) { + if (params.alg_temp == 0.0f) { + std::partial_sort(confidences.begin(), confidences.begin() + num_transfer, confidences.end(), + [](const std::pair & a, const std::pair & b) { + if (a.first != b.first) { + return a.first > b.first; + } + return a.second < b.second; + }); + } else { + conf_candidates.clear(); + + for (int32_t pos = 0; pos < max_length; pos++) { + float conf_logit = -std::numeric_limits::infinity(); + + auto it = std::find(mask_positions.begin(), mask_positions.end(), pos); + if (it != mask_positions.end()) { + size_t mask_idx = std::distance(mask_positions.begin(), it); + conf_logit = confidences[mask_idx].first / params.alg_temp; // Apply temperature scaling + } + + conf_candidates.emplace_back(llama_token_data{ pos, conf_logit, 0.0f }); + } + + llama_token_data_array conf_array = { + /* .data = */ conf_candidates.data(), + /* .size = */ conf_candidates.size(), + /* .selected = */ -1, + /* .sorted = */ false, + }; + + for (int32_t i = 0; i < num_transfer; i++) { + // Apply distribution sampler to get selected index + llama_sampler_apply(dist_sampler, &conf_array); + int selected_idx = conf_array.selected; + confidences[i].second = conf_candidates[selected_idx].id; + + conf_candidates[selected_idx].p = 0.0f; + conf_array.selected = -1; + } + } + + if (params.alg_temp == 0.0f) { + // Deterministic - use confidence order + for (int32_t i = 0; i < num_transfer; i++) { + int32_t mask_idx = confidences[i].second; + int32_t pos = mask_positions[mask_idx]; + llama_token token = sampled_tokens[mask_idx]; + output_tokens[pos] = token; + } + } else { + for (int32_t i = 0; i < num_transfer; i++) { + int32_t pos = confidences[i].second; + auto it = std::find(mask_positions.begin(), mask_positions.end(), pos); + if (it != mask_positions.end()) { + int32_t mask_idx = std::distance(mask_positions.begin(), it); + output_tokens[pos] = sampled_tokens[mask_idx]; + } + } + } + } + } + int64_t time_end_sampling = ggml_time_us(); + total_sampling_time += time_end_sampling - time_start_sampling; + } + int64_t time_end = ggml_time_us(); + total_time += time_end - time_start; + + LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n", + total_time / 1000.0, total_time / 1000.0 / params.steps, total_sampling_time / 1000.0 / params.steps); + + + llama_batch_free(batch); + llama_sampler_free(sampler); + llama_sampler_free(dist_sampler); + + n_generated = max_length; +} + + + + +static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) { + if (!use_chat_template) { + return prompt; + } + + auto chat_templates = common_chat_templates_init(model, ""); + + common_chat_templates_inputs inputs; + common_chat_msg user_msg; + user_msg.role = "user"; + user_msg.content = prompt; + inputs.add_generation_prompt = true; + inputs.messages.push_back(user_msg); + + auto result = common_chat_templates_apply(chat_templates.get(), inputs); + + return result.prompt; +} + +struct callback_data { + const common_params_diffusion * diff_params; + const llama_vocab * vocab; + int32_t n_input; +}; + +static bool diffusion_step_callback(int32_t step, + int32_t total_steps, + const llama_token * tokens, + int32_t n_tokens, + void * user_data) { + (void)user_data; + + callback_data * data = static_cast(user_data); + + auto print_progress_bar = [](int32_t step, int32_t total_steps) { + int progress_percent = (step * 100) / total_steps; + 
int progress_bars = (step * 50) / total_steps; + LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%", + step, + total_steps, + std::string(progress_bars, '=').c_str(), + std::string(50 - progress_bars, ' ').c_str(), + progress_percent); + }; + + if (data->diff_params->visual_mode) { + // Visual mode: clear + LOG_INF("\033[2J\033[H"); // Clear screen and move cursor to top-left + + print_progress_bar(step, total_steps); + + LOG_INF("\n"); + + std::string current_text = " "; + + for (int32_t i = data->n_input; i < n_tokens; i++) { + std::string token_str; + if (tokens[i] != llama_vocab_mask(data->vocab)) { + char piece[256]; + int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false); + if (n_chars > 0) { + piece[n_chars] = '\0'; + token_str = piece; + } + } else { + token_str = " "; + } + + current_text += token_str; + } + + LOG_INF("%s\n", current_text.c_str()); + } else { + print_progress_bar(step, total_steps); + } + + return true; +} + +int main(int argc, char ** argv) { + ggml_time_init(); + + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) { + return 1; + } + + const char * alg_names[] = { "ORIGIN", "MASKGIT_PLUS", "TOPK_MARGIN", "ENTROPY" }; + const char * alg_name = (params.diffusion.algorithm >= 0 && params.diffusion.algorithm <= 3) ? + alg_names[params.diffusion.algorithm] : + "UNKNOWN"; + + common_init(); + llama_backend_init(); + + llama_model_params model_params = llama_model_default_params(); + model_params.n_gpu_layers = params.n_gpu_layers; + model_params.devices = params.devices.data(); + model_params.use_mmap = params.use_mmap; + model_params.use_mlock = params.use_mlock; + model_params.check_tensors = params.check_tensors; + + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params); + if (!model) { + LOG_ERR("error: failed to load model '%s'\n", params.model.path.c_str()); + return 1; + } + + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.n_ctx = params.n_ctx; + ctx_params.n_batch = params.n_batch; + ctx_params.n_ubatch = params.n_ubatch; + ctx_params.flash_attn = params.flash_attn; + ctx_params.no_perf = params.no_perf; + ctx_params.type_k = params.cache_type_k; + ctx_params.type_v = params.cache_type_v; + + llama_context * ctx = llama_init_from_model(model, ctx_params); + if (!ctx) { + LOG_ERR("error: failed to create context\n"); + llama_model_free(model); + return 1; + } + + llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads); + + const llama_vocab * vocab = llama_model_get_vocab(model); + std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model); + + std::vector input_tokens = common_tokenize(vocab, formatted_prompt, + /*add special tokens*/ true, + /*parse special*/ true); + int n_input = input_tokens.size(); + + if (n_input >= params.n_ctx) { + LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx); + llama_free(ctx); + llama_model_free(model); + return 1; + } + + struct diffusion_params ldiff_params = diffusion_default_params(); + ldiff_params.steps = params.diffusion.steps; + ldiff_params.eps = params.diffusion.eps; + ldiff_params.temperature = params.sampling.temp; + ldiff_params.top_p = params.sampling.top_p; + ldiff_params.top_k = params.sampling.top_k; + ldiff_params.algorithm = static_cast(params.diffusion.algorithm); + ldiff_params.alg_temp = params.diffusion.alg_temp; + ldiff_params.seed = 
params.sampling.seed; + + llama_token mask_token_id = llama_vocab_mask(vocab); + GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL); + + LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id); + LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", params.diffusion.steps); + LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", params.diffusion.eps); + LOG_INF("diffusion_params: - %-25s u32 = %d (%s)\n", "algorithm", params.diffusion.algorithm, + alg_name); + LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", params.diffusion.alg_temp); + + ldiff_params.mask_token_id = mask_token_id; + + callback_data cb_data = { &params.diffusion, vocab, n_input }; + + ldiff_params.step_callback = diffusion_step_callback; + ldiff_params.step_callback_user_data = &cb_data; + + int32_t n_generated = 0; + + std::vector<llama_token> output_tokens(params.n_ubatch); + diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, params.n_ubatch, + ldiff_params, n_generated); + + if (n_generated > 0) { + if (params.diffusion.visual_mode) { + //clear screen and move cursor to top-left + LOG_INF("\033[2J\033[H"); + } + output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input); + std::string output_data = common_detokenize(vocab, output_tokens, false); + LOG_INF("\n%s\n", output_data.c_str()); + } else { + LOG_INF("Error: diffusion generation failed\n"); + } + + llama_free(ctx); + llama_model_free(model); + llama_backend_free(); + + return 0; +} diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 0ec2999a0c8e9..40ff6483807ee 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -107,7 +107,7 @@ int main(int argc, char ** argv) { const llama_vocab * vocab = llama_model_get_vocab(model); const int n_ctx_train = llama_model_n_ctx_train(model); - const int n_ctx = llama_n_ctx(ctx); + const int n_ctx = llama_n_ctx(ctx); const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index d53e089a4cbc2..46fb451baa712 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -224,6 +224,7 @@ int main(int argc, char ** argv) { auto & client = clients[i]; client.id = i; client.smpl = common_sampler_init(model, params.sampling); + //params.sampling.seed++; } std::vector<llama_token> tokens_system; @@ -345,7 +346,7 @@ int main(int argc, char ** argv) { client.n_decoded = 0; client.i_batch = batch.n_tokens - 1; - LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur); + LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, prompt = %d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur, client.n_prompt); g_seq_id += 1; diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index eaba9c70469ef..de6d789c98a03 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -181,6 +181,8 @@ option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug ou option(GGML_VULKAN_SHADER_DEBUG_INFO "ggml: enable Vulkan shader debug info" OFF) option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF) option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF) +option(GGML_WEBGPU "ggml: use WebGPU" OFF) +option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF) option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF) option(GGML_METAL_NDEBUG "ggml: disable
Metal debugging" OFF) @@ -270,6 +272,7 @@ set(GGML_PUBLIC_HEADERS include/ggml-rpc.h include/ggml-sycl.h include/ggml-vulkan.h + include/ggml-webgpu.h include/gguf.h) set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}") diff --git a/ggml/include/ggml-webgpu.h b/ggml/include/ggml-webgpu.h new file mode 100644 index 0000000000000..65b8ed9bb6644 --- /dev/null +++ b/ggml/include/ggml-webgpu.h @@ -0,0 +1,19 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define GGML_WEBGPU_NAME "WebGPU" + +// Needed for examples in ggml +GGML_BACKEND_API ggml_backend_t ggml_backend_webgpu_init(void); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_webgpu_reg(void); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 8760c2d35eca4..0425fd60a9412 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -370,6 +370,7 @@ ggml_add_backend(MUSA) ggml_add_backend(RPC) ggml_add_backend(SYCL) ggml_add_backend(Vulkan) +ggml_add_backend(WebGPU) ggml_add_backend(OpenCL) foreach (target ggml-base ggml) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 042ea77aca721..f0cdac31eae9a 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -45,6 +45,10 @@ #include "ggml-vulkan.h" #endif +#ifdef GGML_USE_WEBGPU +#include "ggml-webgpu.h" +#endif + #ifdef GGML_USE_OPENCL #include "ggml-opencl.h" #endif @@ -173,6 +177,9 @@ struct ggml_backend_registry { #ifdef GGML_USE_VULKAN register_backend(ggml_backend_vk_reg()); #endif +#ifdef GGML_USE_WEBGPU + register_backend(ggml_backend_webgpu_reg()); +#endif #ifdef GGML_USE_OPENCL register_backend(ggml_backend_opencl_reg()); #endif diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index ccb17eb072eb2..e5e11d4cdced9 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2090,6 +2090,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, { // TODO: add support // ref: https://github.com/ggml-org/llama.cpp/pull/14274 +#pragma message("TODO: implement F32, F16, BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") return false; } break; case GGML_OP_CPY: { diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index ed61869a5508a..2be54c31b5f3e 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -1541,7 +1541,7 @@ class tinyBLAS_BF16_PPC { } else if constexpr(RM == 8 && RN == 4) { KERNEL_8x4(ii,jj); } else { - static_assert(false, "RN/RM values not supported"); + assert(false && "RN/RM values not supported"); } } @@ -1573,13 +1573,13 @@ class tinyBLAS_BF16_PPC { const int nth; }; -template +template class tinyBLAS_Q0_PPC { public: tinyBLAS_Q0_PPC(int64_t k, const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, + const block_q8_0 *B, int64_t ldb, + float *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } @@ -1590,8 +1590,7 @@ class tinyBLAS_Q0_PPC { private: - template - inline void save_res(int ii, int jj, int idx, vector float* fin_res) { + inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) { for (int I = 0; I < RM; I++) { for (int J = 0; J < RN; J++) { *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J); @@ -1611,29 +1610,67 @@ class 
tinyBLAS_Q0_PPC { fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]); } } - - template - void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array& comparray) { - int64_t i, j; - TA *aoffset = NULL; - VA *vecOffset = NULL; - TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; - VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; - VB t1, t2, t3, t4, t5, t6, t7, t8; + /* This function processes quantized data from block_q4_0 elements. + * First the we try to extract the two int4 values stored in single int8_t into two signed int8. + * And then we subtract each of the resultant element with 8, to convert signed int8 to unsigned int8. + * Also compute the rowsum which is required to compensate the above conversion. */ + inline void process_q4_elements(vector signed char (&c)[2], int* ca) { const vector signed char lowMask = vec_splats((signed char)0xF); const vector unsigned char v4 = vec_splats((unsigned char)0x4); const vector signed char v8 = vec_splats((signed char)0x8); - aoffset = const_cast(a); - vecOffset = vec; + vector signed int vsum = {0}; + vector signed int vsum2 = {0}; + c[0] = vec_and(c[1], lowMask); + c[1] = vec_sr(c[1], v4); + c[0] = vec_sub(c[0], v8); + c[1] = vec_sub(c[1], v8); + vsum = vec_sum4s(c[0], vsum); + vsum2 = vec_sum4s(c[1], vsum2); + vsum = vec_add(vsum, vsum2); + *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + } + + template + inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) { vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; - vector signed int vsum = {0}; - vector signed int vsum2 = {0}; + V2 t1, t2, t3, t4, t5, t6, t7, t8; + vector unsigned char xor_vector; + uint8_t flip_vec = 0x80; + xor_vector = vec_splats(flip_vec); + t1 = vec_perm(s1, s2, swiz1); + t2 = vec_perm(s1, s2, swiz2); + t3 = vec_perm(s3, s4, swiz1); + t4 = vec_perm(s3, s4, swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset+16); + vec_xst(t7, 0, vecOffset+32); + vec_xst(t8, 0, vecOffset+48); + } + template + void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array& comparray) { + int64_t i, j; + TA *aoffset = NULL; + int8_t *vecOffset = NULL; + TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; + TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; + vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; + vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; + aoffset = const_cast(a); + vecOffset = vec; j = (rows >> 3); if (j > 0) { do { @@ -1646,159 +1683,30 @@ class tinyBLAS_Q0_PPC { aoffset7 = aoffset6 + lda; aoffset8 = aoffset7 + lda; aoffset += 8 * lda; - i = (cols >> 2); if (i > 0) { do { - c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); - 
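// --- illustrative sketch (not part of the patch): scalar view of process_q4_elements() ---
// Each block_q4_0 byte packs two unsigned 4-bit quants; masking the low nibble
// and shifting the high nibble, then subtracting 8, recenters them to the
// signed range -8..7. The per-row sum of those signed values (one comparray
// entry) is kept so the integer kernel can cancel the bias that appears when
// an operand is handled as unsigned, e.g. dot(a, b + 128) = dot(a, b) + 128 * sum(a).
// Names below are illustrative only.
#include <cstdint>

struct unpacked_q4_row {
    int8_t  q[32];   // 32 signed quants per block
    int32_t rowsum;  // sum of the signed quants (one comparray slot)
};

static unpacked_q4_row unpack_q4_0_row(const uint8_t qs[16]) {
    unpacked_q4_row out = {};
    for (int i = 0; i < 16; i++) {
        const int8_t lo = (int8_t) ((qs[i] & 0x0F) - 8); // low nibble
        const int8_t hi = (int8_t) ((qs[i] >>   4) - 8); // high nibble
        out.q[i]      = lo;
        out.q[i + 16] = hi;
        out.rowsum   += lo + hi;
    }
    return out;
}
// --- end of sketch ---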
c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); - c5[1] = reinterpret_cast(vec_xl(0, aoffset5->qs)); - c6[1] = reinterpret_cast(vec_xl(0, aoffset6->qs)); - c7[1] = reinterpret_cast(vec_xl(0, aoffset7->qs)); - c8[1] = reinterpret_cast(vec_xl(0, aoffset8->qs)); - - c1[0] = vec_and(c1[1], lowMask); - c1[1] = vec_sr(c1[1], v4); - c1[0] = vec_sub(c1[0], v8); - c1[1] = vec_sub(c1[1], v8); - vsum = vec_sum4s(c1[0], vsum); - vsum2 = vec_sum4s(c1[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c2[0] = vec_and(c2[1], lowMask); - c2[1] = vec_sr(c2[1], v4); - c2[0] = vec_sub(c2[0], v8); - c2[1] = vec_sub(c2[1], v8); - vsum = vec_sum4s(c2[0], vsum); - vsum2 = vec_sum4s(c2[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c3[0] = vec_and(c3[1], lowMask); - c3[1] = vec_sr(c3[1], v4); - c3[0] = vec_sub(c3[0], v8); - c3[1] = vec_sub(c3[1], v8); - vsum = vec_sum4s(c3[0], vsum); - vsum2 = vec_sum4s(c3[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c4[0] = vec_and(c4[1], lowMask); - c4[1] = vec_sr(c4[1], v4); - c4[0] = vec_sub(c4[0], v8); - c4[1] = vec_sub(c4[1], v8); - vsum = vec_sum4s(c4[0], vsum); - vsum2 = vec_sum4s(c4[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c5[0] = vec_and(c5[1], lowMask); - c5[1] = vec_sr(c5[1], v4); - c5[0] = vec_sub(c5[0], v8); - c5[1] = vec_sub(c5[1], v8); - vsum = vec_sum4s(c5[0], vsum); - vsum2 = vec_sum4s(c5[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c6[0] = vec_and(c6[1], lowMask); - c6[1] = vec_sr(c6[1], v4); - c6[0] = vec_sub(c6[0], v8); - c6[1] = vec_sub(c6[1], v8); - vsum = vec_sum4s(c6[0], vsum); - vsum2 = vec_sum4s(c6[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c7[0] = vec_and(c7[1], lowMask); - c7[1] = vec_sr(c7[1], v4); - c7[0] = vec_sub(c7[0], v8); - c7[1] = vec_sub(c7[1], v8); - vsum = vec_sum4s(c7[0], vsum); - vsum2 = vec_sum4s(c7[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c8[0] = vec_and(c8[1], lowMask); - c8[1] = vec_sr(c8[1], v4); - c8[0] = vec_sub(c8[0], v8); - c8[1] = vec_sub(c8[1], v8); - vsum = vec_sum4s(c8[0], vsum); - vsum2 = vec_sum4s(c8[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], 
c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - - t1 = vec_perm(c5[0], c6[0], swiz1); - t2 = vec_perm(c5[0], c6[0], swiz2); - t3 = vec_perm(c7[0], c8[0], swiz1); - t4 = vec_perm(c7[0], c8[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+128); - vec_xst(t6, 0, vecOffset+144); - vec_xst(t7, 0, vecOffset+160); - vec_xst(t8, 0, vecOffset+176); - - t1 = vec_perm(c5[1], c6[1], swiz1); - t2 = vec_perm(c5[1], c6[1], swiz2); - t3 = vec_perm(c7[1], c8[1], swiz1); - t4 = vec_perm(c7[1], c8[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+192); - vec_xst(t6, 0, vecOffset+208); - vec_xst(t7, 0, vecOffset+224); - vec_xst(t8, 0, vecOffset+240); - + c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); + c5[1] = reinterpret_cast(vec_xl(0, aoffset5->qs)); + c6[1] = reinterpret_cast(vec_xl(0, aoffset6->qs)); + c7[1] = reinterpret_cast(vec_xl(0, aoffset7->qs)); + c8[1] = reinterpret_cast(vec_xl(0, aoffset8->qs)); + + process_q4_elements(c1, &comparray[0]); + process_q4_elements(c2, &comparray[1]); + process_q4_elements(c3, &comparray[2]); + process_q4_elements(c4, &comparray[3]); + process_q4_elements(c5, &comparray[4]); + process_q4_elements(c6, &comparray[5]); + process_q4_elements(c7, &comparray[6]); + process_q4_elements(c8, &comparray[7]); + vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); + vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false); + vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -1821,85 +1729,20 @@ class tinyBLAS_Q0_PPC { aoffset3 = aoffset2 + lda; aoffset4 = aoffset3 + lda; aoffset += 4 * lda; - i = (cols >> 2); if (i > 0) { do { - c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); - c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); - - c1[0] = vec_and(c1[1], lowMask); - c1[1] = vec_sr(c1[1], v4); - c1[0] = vec_sub(c1[0], v8); - c1[1] = vec_sub(c1[1], v8); - vsum = vec_sum4s(c1[0], vsum); - vsum2 = vec_sum4s(c1[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c2[0] = vec_and(c2[1], lowMask); - c2[1] = vec_sr(c2[1], v4); - c2[0] = vec_sub(c2[0], v8); - c2[1] = vec_sub(c2[1], v8); - vsum = vec_sum4s(c2[0], vsum); - vsum2 = vec_sum4s(c2[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c3[0] = vec_and(c3[1], lowMask); - c3[1] = vec_sr(c3[1], v4); - c3[0] = vec_sub(c3[0], v8); - c3[1] = vec_sub(c3[1], v8); - vsum = vec_sum4s(c3[0], vsum); - vsum2 = vec_sum4s(c3[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = 
vec_splats(0); - - c4[0] = vec_and(c4[1], lowMask); - c4[1] = vec_sr(c4[1], v4); - c4[0] = vec_sub(c4[0], v8); - c4[1] = vec_sub(c4[1], v8); - vsum = vec_sum4s(c4[0], vsum); - vsum2 = vec_sum4s(c4[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats( 0); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - + c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); + + process_q4_elements(c1, &comparray[0]); + process_q4_elements(c2, &comparray[1]); + process_q4_elements(c3, &comparray[2]); + process_q4_elements(c4, &comparray[3]); + vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -1918,80 +1761,17 @@ class tinyBLAS_Q0_PPC { if (i > 0) { do { switch(rows) { - case 3: c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - case 2: c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - case 1: c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + case 3: c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + case 2: c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + case 1: c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); break; } - c1[0] = vec_and(c1[1], lowMask); - c1[1] = vec_sr(c1[1], v4); - c1[0] = vec_sub(c1[0], v8); - c1[1] = vec_sub(c1[1], v8); - vsum = vec_sum4s(c1[0], vsum); - vsum2 = vec_sum4s(c1[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c2[0] = vec_and(c2[1], lowMask); - c2[1] = vec_sr(c2[1], v4); - c2[0] = vec_sub(c2[0], v8); - c2[1] = vec_sub(c2[1], v8); - vsum = vec_sum4s(c2[0], vsum); - vsum2 = vec_sum4s(c2[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c3[0] = vec_and(c3[1], lowMask); - c3[1] = vec_sr(c3[1], v4); - c3[0] = vec_sub(c3[0], v8); - c3[1] = vec_sub(c3[1], v8); - vsum = vec_sum4s(c3[0], vsum); - vsum2 = vec_sum4s(c3[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c4[0] = vec_and(c4[1], lowMask); - c4[1] = vec_sr(c4[1], v4); - c4[0] = vec_sub(c4[0], v8); - c4[1] = vec_sub(c4[1], v8); - vsum = vec_sum4s(c4[0], vsum); - vsum2 = vec_sum4s(c4[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - t1 = vec_perm(c1[0], c2[0], 
swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); + process_q4_elements(c1, &comparray[0]); + process_q4_elements(c2, &comparray[1]); + process_q4_elements(c3, &comparray[2]); + process_q4_elements(c4, &comparray[3]); + vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -2001,146 +1781,40 @@ class tinyBLAS_Q0_PPC { } } } - template - void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { + void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { int64_t i, j; - TB *aoffset = NULL; + block_q8_0 *aoffset = NULL; VA *vecOffset = NULL; - TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; - VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0}; - VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0}; - VB t1, t2, t3, t4, t5, t6, t7, t8; - vector unsigned char xor_vector; - uint8_t flip_vec = 0x80; - xor_vector = vec_splats(flip_vec); - vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; - vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; - vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; - vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; - - aoffset = const_cast(a); + block_q8_0* aoffsets[8]; + __vector_pair arr[8]; + VB c[8][2] = {0}; + VB c1[8] = {0}; VB c2[8] = {0}; + aoffset = const_cast(a); vecOffset = vec; j = (rows >> 3); if (j > 0) { do { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - aoffset8 = aoffset7 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it < 8; it++) + aoffsets[it] = aoffsets[it-1] + lda; aoffset += 8 * lda; i = (cols >> 3); if (i > 0) { do { - C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); - C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs); - C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs); - C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs); - C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs); - C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs); - - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - 
__builtin_vsx_disassemble_pair(c5, &C5); - __builtin_vsx_disassemble_pair(c6, &C6); - __builtin_vsx_disassemble_pair(c7, &C7); - __builtin_vsx_disassemble_pair(c8, &C8); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - - t1 = vec_perm(c5[0], c6[0], swiz1); - t2 = vec_perm(c5[0], c6[0], swiz2); - t3 = vec_perm(c7[0], c8[0], swiz1); - t4 = vec_perm(c7[0], c8[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset+128); - vec_xst(t6, 0, vecOffset+144); - vec_xst(t7, 0, vecOffset+160); - vec_xst(t8, 0, vecOffset+176); - - t1 = vec_perm(c5[1], c6[1], swiz1); - t2 = vec_perm(c5[1], c6[1], swiz2); - t3 = vec_perm(c7[1], c8[1], swiz1); - t4 = vec_perm(c7[1], c8[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + for (int it = 0; it < 8; it++) { + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs); + __builtin_vsx_disassemble_pair(c[it], &arr[it]); + c1[it] = c[it][0]; + c2[it] = c[it][1]; } - vec_xst(t5, 0, vecOffset+192); - vec_xst(t6, 0, vecOffset+208); - vec_xst(t7, 0, vecOffset+224); - vec_xst(t8, 0, vecOffset+240); - - aoffset1 += lda; - aoffset2 += lda; - aoffset3 += lda; - aoffset4 += lda; - aoffset5 += lda; - aoffset6 += lda; - aoffset7 += lda; - aoffset8 += lda; + vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip); + vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip); + for (int it = 0; it < 8; it++) + aoffsets[it] += lda; vecOffset += 256; i--; } while(i > 0); @@ -2150,129 +1824,53 @@ class tinyBLAS_Q0_PPC { } if (rows & 4) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset += 4 * lda; - + aoffsets[0] = aoffset; + for (int it = 1; it < 4; it++ ) + aoffsets[it] = aoffsets[it-1] + lda; + aoffset += 4 * lda; i = (cols >> 3); if (i > 0) { do { - C1 = __builtin_vsx_lxvp(0, 
(__vector_pair*)aoffset1->qs); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); - C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs); - - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + for (int it = 0; it < 4; it++) { + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs); + __builtin_vsx_disassemble_pair(c[it], &arr[it]); + c1[it] = c[it][0]; + c2[it] = c[it][1]; } - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + for (int it = 0; it < 4; it++) { + aoffsets[it] += lda; } - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - - aoffset1 += lda; - aoffset2 += lda; - aoffset3 += lda; - aoffset4 += lda; vecOffset += 128; i--; } while(i > 0); } } + if (rows & 3) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it < 3; it++ ) + aoffsets[it] = aoffsets[it-1] + lda; i = (cols >> 3); if (i > 0) { do { switch(rows) { - case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); - __builtin_vsx_disassemble_pair(c3, &C3); - case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); - __builtin_vsx_disassemble_pair(c2, &C2); - case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); - __builtin_vsx_disassemble_pair(c1, &C1); + case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs); + __builtin_vsx_disassemble_pair(c[2], &arr[2]); + c1[2] = c[2][0]; c2[2] = c[2][1]; + case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs); + __builtin_vsx_disassemble_pair(c[1], &arr[1]); + c1[1] = c[1][0]; c2[1] = c[1][1]; + case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs); + __builtin_vsx_disassemble_pair(c[0], &arr[0]); + c1[0] = c[0][0]; c2[0] = c[0][1]; break; } - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset); 
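// --- illustrative sketch (not part of the patch): what the removed vec_perm/vec_xst
// ladder (and the new vector_permute_store() helper) computes ---
// Four 16-byte rows are regrouped so that each 16-byte output chunk holds the
// same 4-byte slice of all four rows, i.e. a 4x4 transpose at 4-byte granularity,
// which is the operand layout fed to the __builtin_mma_* kernels below. When
// flip is set, every stored byte is additionally XORed with 0x80 to bias signed
// int8 values into the unsigned range. Scalar model, names illustrative only:
#include <cstdint>
#include <cstring>

static void permute_store_scalar(const uint8_t src[4][16], uint8_t dst[64], bool flip) {
    for (int chunk = 0; chunk < 4; chunk++) {       // 4-byte slice within a row
        for (int row = 0; row < 4; row++) {         // source row
            std::memcpy(dst + 16 * chunk + 4 * row, src[row] + 4 * chunk, 4);
        }
    }
    if (flip) {
        for (int i = 0; i < 64; i++) {
            dst[i] ^= 0x80;
        }
    }
}
// --- end of sketch ---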
- vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - - aoffset1 += lda; - aoffset2 += lda; - aoffset3 += lda; + vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + for (int it = 0; it < 3; it++) + aoffsets[it] += lda; vecOffset += 128; i--; } while(i > 0); @@ -2281,159 +1879,42 @@ class tinyBLAS_Q0_PPC { } void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; - int m_rem = MIN(m - m0, 8); - int n_rem = MIN(n - n0, 8); - // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance - // issues. After resolving them, below code will be enabled. - /*if (m_rem >= 16 && n_rem >= 8) { - mc = 16; - nc = 8; - gemm<16,8>(m0, m, n0, n); - } else if(m_rem >= 8 && n_rem >= 16) { - mc = 8; - nc = 16; - gemm<8,16>(m0, m, n0, n); - }*/ + int m_rem = MIN(m - m0, 16); + int n_rem = MIN(n - n0, 16); + + int mc = 0, nc = 0; + if (m_rem >= 8 && n_rem >= 8) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); + mc = 8; + nc = 8; + gemm<8, 8>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 8) { mc = 4; nc = 8; - gemm<4,8>(m0, m, n0, n); + gemm<4, 8>(m0, m, n0, n); } else if (m_rem >= 8 && n_rem >= 4) { mc = 8; nc = 4; - gemm<8,4>(m0, m, n0, n); + gemm<8, 4>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 4) { mc = 4; nc = 4; - gemm_small<4, 4>(m0, m, n0, n); - } else if ((m_rem < 4) && (n_rem > 4)) { - nc = 4; - switch(m_rem) { - case 1: - mc = 1; - gemm_small<1, 4>(m0, m, n0, n); - break; - case 2: - mc = 2; - gemm_small<2, 4>(m0, m, n0, n); - break; - case 3: - mc = 3; - gemm_small<3, 4>(m0, m, n0, n); - break; - default: - return; - } - } else if ((m_rem > 4) && (n_rem < 4)) { - mc = 4; - switch(n_rem) { - case 1: - nc = 1; - gemm_small<4, 1>(m0, m, n0, n); - break; - case 2: - nc = 2; - gemm_small<4, 2>(m0, m, n0, n); - break; - case 3: - nc = 3; - gemm_small<4, 3>(m0, m, n0, n); - break; - default: - return; - } + gemm_small(m0, m, n0, n, mc, nc); } else { - switch((m_rem << 4) | n_rem) { - case 0x43: - mc = 4; - nc = 3; - gemm_small<4, 3>(m0, m, n0, n); - break; - case 0x42: - mc = 4; - nc = 2; - gemm_small<4, 2>(m0, m, n0, n); - break; - case 0x41: - mc = 4; - nc = 1; - gemm_small<4, 1>(m0, m, n0, n); - break; - case 0x34: - mc = 3; - nc = 4; - gemm_small<3, 4>(m0, m, n0, n); - break; - case 0x33: - mc = 3; - nc = 3; - gemm_small<3, 3>(m0, m, n0, n); - break; - case 0x32: - mc = 3; - nc = 2; - gemm_small<3, 2>(m0, m, n0, n); - break; - case 0x31: - mc = 3; - nc = 1; - gemm_small<3, 1>(m0, m, n0, n); - break; - case 0x24: - mc = 2; - nc = 4; - gemm_small<2, 4>(m0, m, n0, n); - break; - case 0x23: - mc = 2; - nc = 3; - gemm_small<2, 3>(m0, m, n0, n); - break; - case 0x22: - mc = 2; - nc = 2; - gemm_small<2, 2>(m0, m, n0, n); - break; - case 0x21: - mc = 2; - nc = 1; - gemm_small<2, 1>(m0, m, n0, n); - break; - case 0x14: - mc = 1; - nc = 4; - gemm_small<1, 4>(m0, m, n0, n); - break; - 
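// --- illustrative sketch (not part of the patch): tile selection after the refactor ---
// The removed switch enumerated every small (mc, nc) pair via the encoding
// (m_rem << 4) | n_rem (so case 0x43 meant m_rem == 4, n_rem == 3). The new
// mnpack() instead clamps each remainder to at most 4 and forwards it to the
// runtime-parameterized gemm_small(m0, m, n0, n, mc, nc). Helper names are
// illustrative only:
#include <algorithm>
#include <cstdint>
#include <utility>

static std::pair<int, int> pick_small_tile(int64_t m_rem, int64_t n_rem) {
    const int mc = (int) std::min<int64_t>(m_rem, 4);
    const int nc = (int) std::min<int64_t>(n_rem, 4);
    return { mc, nc }; // a zero in either slot means nothing is left to tile
}
// --- end of sketch ---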
case 0x13: - mc = 1; - nc = 3; - gemm_small<1, 3>(m0, m, n0, n); - break; - case 0x12: - mc = 1; - nc = 2; - gemm_small<1, 2>(m0, m, n0, n); - break; - case 0x11: - mc = 1; - nc = 1; - gemm_small<1, 1>(m0, m, n0, n); - break; - default: - return; - } + mc = (m_rem >= 4) ? 4 : m_rem; + nc = (n_rem >= 4) ? 4 : n_rem; + if (mc == 0 || nc == 0) + return; + gemm_small(m0, m, n0, n, mc, nc); } - mp = m0 + (m - m0) / mc * mc; - np = n0 + (n - n0) / nc * nc; + + int64_t mp = m0 + ((m - m0) / mc) * mc; + int64_t np = n0 + ((n - n0) / nc) * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); } + void KERNEL_4x8(int64_t ii, int64_t jj) { vec_t vec_A[8], vec_B[16] = {0}; acc_t acc_0, acc_1; @@ -2445,9 +1926,9 @@ class tinyBLAS_Q0_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); if (std::is_same_v) { - packNormalInt4((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray); + packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray); } else { - packNormal((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); } packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { @@ -2475,8 +1956,8 @@ class tinyBLAS_Q0_PPC { compute<4>(&acc_0, 0, 0, comparray, vs, fin_res); compute<4>(&acc_1, 0, 4, comparray, vs, fin_res); } - save_res<4, 4>(ii, jj, 0, fin_res); - save_res<4, 4>(ii, jj+4, 4, fin_res); + save_res(ii, jj, 0, fin_res); + save_res(ii, jj+4, 4, fin_res); } void KERNEL_8x4(int64_t ii, int64_t jj) { @@ -2490,9 +1971,9 @@ class tinyBLAS_Q0_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); if (std::is_same_v) { - packNormalInt4((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); } else { - packNormal((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); } packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { @@ -2519,8 +2000,8 @@ class tinyBLAS_Q0_PPC { compute<8>(&acc_0, 0, 0, comparray, vs, fin_res); compute<8>(&acc_1, 4, 4, comparray, vs, fin_res); } - save_res<4, 4>(ii, jj, 0, fin_res); - save_res<4, 4>(ii+4, jj, 4, fin_res); + save_res(ii, jj, 0, fin_res); + save_res(ii+4, jj, 4, fin_res); } void KERNEL_8x8(int64_t ii, int64_t jj) { @@ -2536,9 +2017,9 @@ class tinyBLAS_Q0_PPC { __builtin_mma_xxsetaccz(&acc_2); __builtin_mma_xxsetaccz(&acc_3); if (std::is_same_v) { - packNormalInt4((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); } else { - packNormal((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); } packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { @@ -2570,14 +2051,13 @@ class tinyBLAS_Q0_PPC { compute<8>(&acc_2, 0, 8, comparray, vs, fin_res); compute<8>(&acc_3, 4, 12, comparray, vs, fin_res); } - save_res<4, 4>(ii, jj, 0, fin_res); - save_res<4, 4>(ii+4, jj, 4, fin_res); - save_res<4, 4>(ii, jj+4, 8, fin_res); - save_res<4, 4>(ii+4, jj+4, 12, fin_res); + save_res(ii, jj, 0, fin_res); + save_res(ii+4, jj, 4, fin_res); + save_res(ii, jj+4, 8, fin_res); + save_res(ii+4, jj+4, 12, fin_res); } - template - void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) { + void gemm_small(int64_t m0, int64_t m, 
int64_t n0, int64_t n, int RM, int RN) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; @@ -2606,9 +2086,9 @@ class tinyBLAS_Q0_PPC { __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead __builtin_mma_xxsetaccz(&acc_0); if (isAblock_q4) { - packNormalInt4((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray); + packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray); } else { - packNormal((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); } packNormal((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x+=4) { @@ -2641,7 +2121,7 @@ class tinyBLAS_Q0_PPC { fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]); } } - save_res(ii, jj, 0, fin_res); + save_res(ii, jj, 0, fin_res, RM, RN); } } @@ -2654,7 +2134,7 @@ class tinyBLAS_Q0_PPC { } else if constexpr(RM == 8 && RN == 8) { KERNEL_8x8(ii,jj); } else { - static_assert(false, "RN/RM values not supported"); + assert(false && "RN/RM values not supported"); } } @@ -2676,10 +2156,8 @@ class tinyBLAS_Q0_PPC { } const TA *const A; - const TB *const B; - TC *C; - TA *At; - TB *Bt; + const block_q8_0 *const B; + float *C; const int64_t k; const int64_t lda; const int64_t ldb; @@ -2688,13 +2166,12 @@ class tinyBLAS_Q0_PPC { const int nth; }; -template class tinyBLAS_PPC { public: tinyBLAS_PPC(int64_t k, - const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, + const float *A, int64_t lda, + const float *B, int64_t ldb, + float *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } @@ -2707,247 +2184,139 @@ class tinyBLAS_PPC { void (tinyBLAS_PPC::*kernel)(int64_t, int64_t); - template - void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) { + inline void vector_permute_store_4(vector float *src, float *vecOffset) { + vector float t1, t2, t3, t4, t5, t6, t7, t8; + t1 = vec_mergeh(src[0], src[1]); + t2 = vec_mergeh(src[2], src[3]); + t3 = vec_mergel(src[0], src[1]); + t4 = vec_mergel(src[2], src[3]); + + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t1, t2, 3); + t7 = vec_xxpermdi(t3, t4, 0); + t8 = vec_xxpermdi(t3, t4, 3); + + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset + 4); + vec_xst(t7, 0, vecOffset + 8); + vec_xst(t8, 0, vecOffset + 12); + } + + inline void vector_permute_store_8(vector float *src, float *vecOffset) { + vector float t1, t2, t3, t4, t5, t6, t7, t8; + t1 = vec_mergeh(src[0], src[1]); + t2 = vec_mergeh(src[2], src[3]); + t3 = vec_mergeh(src[4], src[5]); + t4 = vec_mergeh(src[6], src[7]); + + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset + 4); + vec_xst(t7, 0, vecOffset + 8); + vec_xst(t8, 0, vecOffset + 12); + + t1 = vec_mergel(src[0], src[1]); + t2 = vec_mergel(src[2], src[3]); + t3 = vec_mergel(src[4], src[5]); + t4 = vec_mergel(src[6], src[7]); + + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + + vec_xst(t5, 0, vecOffset + 16); + vec_xst(t6, 0, vecOffset + 20); + vec_xst(t7, 0, vecOffset + 24); + vec_xst(t8, 0, vecOffset + 28); + } + + void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) { int64_t i, j; - TA *aoffset = NULL, *boffset = NULL; - TA *aoffset1 = NULL, *aoffset2 
= NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; - VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; - VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; - VA t1, t2, t3, t4, t5, t6, t7, t8; - aoffset = const_cast(a); + float * aoffsets[8]; + float *aoffset = NULL, *boffset = NULL; + __vector_pair arr[8]; + vector float c[8][2] = {0}; + vector float c1[8] = {0}; + vector float c2[8] = {0}; + aoffset = const_cast(a); boffset = vec; j = (rows >> 3); if (j > 0) { do { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - aoffset8 = aoffset7 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it< 8; it++) + aoffsets[it] = aoffsets[it-1] + lda; aoffset += 8 * lda; i = (cols >> 3); if (i > 0) { do { - C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); - C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4); - C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5); - C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6); - C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7); - C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8); - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - __builtin_vsx_disassemble_pair(c5, &C5); - __builtin_vsx_disassemble_pair(c6, &C6); - __builtin_vsx_disassemble_pair(c7, &C7); - __builtin_vsx_disassemble_pair(c8, &C8); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_mergeh(c5[0], c6[0]); - t4 = vec_mergeh(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset); - vec_xst(t6, 0, boffset+4); - vec_xst(t7, 0, boffset+8); - vec_xst(t8, 0, boffset+12); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_mergel(c5[0], c6[0]); - t4 = vec_mergel(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+16); - vec_xst(t6, 0, boffset+20); - vec_xst(t7, 0, boffset+24); - vec_xst(t8, 0, boffset+28); - - t1 = vec_mergeh(c1[1], c2[1]); - t2 = vec_mergeh(c3[1], c4[1]); - t3 = vec_mergeh(c5[1], c6[1]); - t4 = vec_mergeh(c7[1], c8[1]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+32); - vec_xst(t6, 0, boffset+36); - vec_xst(t7, 0, boffset+40); - vec_xst(t8, 0, boffset+44); - - t1 = vec_mergel(c1[1], c2[1]); - t2 = vec_mergel(c3[1], c4[1]); - t3 = vec_mergel(c5[1], c6[1]); - t4 = vec_mergel(c7[1], c8[1]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+48); - vec_xst(t6, 0, boffset+52); - vec_xst(t7, 0, boffset+56); - vec_xst(t8, 0, boffset+60); - - aoffset1 += 8*lda; - aoffset2 += 8*lda; - aoffset3 += 8*lda; - aoffset4 += 8*lda; + for (int it = 0; it< 8; it++) { + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]); + __builtin_vsx_disassemble_pair(c[it], &arr[it]); + c1[it] = c[it][0]; + 
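// --- illustrative sketch (not part of the patch): scalar view of vector_permute_store_4() ---
// The vec_mergeh/vec_mergel + vec_xxpermdi sequence transposes a 4x4 block of
// floats: row r of the output holds element r of each input row, which is how
// packTranspose() lays tiles out for the float MMA kernels.
// vector_permute_store_8() applies the same idea to an 8x4 block.
static void transpose_4x4(const float src[4][4], float dst[16]) {
    for (int col = 0; col < 4; col++) {
        for (int row = 0; row < 4; row++) {
            dst[4 * col + row] = src[row][col];
        }
    }
}
// --- end of sketch ---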
c2[it] = c[it][1]; + } + + vector_permute_store_8(c1, boffset); + vector_permute_store_8(c2, boffset+32); + for (int it = 0; it < 4; it++) + aoffsets[it] = aoffsets[it] + 8*lda; boffset += 64; i--; } while(i > 0); } if (cols & 4) { - c1[0] = vec_xl(0, aoffset1); - c2[0] = vec_xl(0, aoffset2); - c3[0] = vec_xl(0, aoffset3); - c4[0] = vec_xl(0, aoffset4); - c5[0] = vec_xl(0, aoffset5); - c6[0] = vec_xl(0, aoffset6); - c7[0] = vec_xl(0, aoffset7); - c8[0] = vec_xl(0, aoffset8); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_mergeh(c5[0], c6[0]); - t4 = vec_mergeh(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset); - vec_xst(t6, 0, boffset+4); - vec_xst(t7, 0, boffset+8); - vec_xst(t8, 0, boffset+12); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_mergel(c5[0], c6[0]); - t4 = vec_mergel(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+16); - vec_xst(t6, 0, boffset+20); - vec_xst(t7, 0, boffset+24); - vec_xst(t8, 0, boffset+28); + for (int it = 0; it < 8 ; it++) + c1[it] = vec_xl(0, aoffsets[it]); + vector_permute_store_8(c1, boffset); } j--; } while(j > 0); } if (rows & 4) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it < 4; it++) + aoffsets[it] = aoffsets[it-1] + lda; aoffset += 4 * lda; i = (cols >> 3); if (i > 0) { do { - C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); - C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4); - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_mergel(c1[0], c2[0]); - t4 = vec_mergel(c3[0], c4[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t1, t2, 3); - t7 = vec_xxpermdi(t3, t4, 0); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset); - vec_xst(t6, 0, boffset+4); - vec_xst(t7, 0, boffset+8); - vec_xst(t8, 0, boffset+12); - - t1 = vec_mergeh(c1[1], c2[1]); - t2 = vec_mergeh(c3[1], c4[1]); - t3 = vec_mergel(c1[1], c2[1]); - t4 = vec_mergel(c3[1], c4[1]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t1, t2, 3); - t7 = vec_xxpermdi(t3, t4, 0); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+16); - vec_xst(t6, 0, boffset+20); - vec_xst(t7, 0, boffset+24); - vec_xst(t8, 0, boffset+28); - - aoffset1 += 8*lda; - aoffset2 += 8*lda; - aoffset3 += 8*lda; - aoffset4 += 8*lda; + for (int it = 0; it < 4; it++) { + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]); + __builtin_vsx_disassemble_pair(c[it], &arr[it]); + c1[it] = c[it][0]; + c2[it] = c[it][1]; + } + vector_permute_store_4(c1, boffset); + vector_permute_store_4(c2, boffset+16); + for (int it = 0; it < 4; it++) + aoffsets[it] += 8*lda; boffset += 32; i--; } while(i > 0); } if (cols & 4) { - c1[0] = vec_xl(0, aoffset1); - c2[0] = vec_xl(0, aoffset2); - c3[0] = vec_xl(0, aoffset3); - c4[0] = vec_xl(0, aoffset4); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, 
boffset); - vec_xst(t4, 0, boffset+4); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset+8); - vec_xst(t4, 0, boffset+12); + for (int it = 0; it < 4; it++) + c1[it] = vec_xl(0, aoffsets[it]); + vector_permute_store_4(c1, boffset); } } if (rows & 3) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it < 3; it++) + aoffsets[it] = aoffsets[it-1] + lda; if (cols & 4) { - c1[0] = vec_xl(0, aoffset1); - c2[0] = vec_xl(0, aoffset2); - c3[0] = vec_xl(0, aoffset3); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset); - vec_xst(t4, 0, boffset+4); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset+8); - vec_xst(t4, 0, boffset+12); + for (int it = 0; it < 3; it++) + c1[it] = vec_xl(0, aoffsets[it]); + vector_permute_store_4(c1, boffset); } } } @@ -2957,8 +2326,8 @@ class tinyBLAS_PPC { acc_t acc_0; __builtin_mma_xxsetaccz(&acc_0); for (int l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]); @@ -2973,8 +2342,8 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); for (int64_t l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B); __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]); __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]); @@ -2994,8 +2363,8 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); for (int64_t l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]); @@ -3017,8 +2386,8 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_2); __builtin_mma_xxsetaccz(&acc_3); for (int l = 0; l < k; l+=8) { - packTranspose(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B); for(int x = 0; x < 16; x+=2) { __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]); __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]); @@ -3033,155 +2402,37 @@ class tinyBLAS_PPC { } void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; - int m_rem = MIN(m - m0, 16); - int n_rem = MIN(n - n0, 16); - if (m_rem >= 16 && n_rem >= 8) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); - } else if(m_rem >= 
8 && n_rem >= 16) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); - } else if (m_rem >= 8 && n_rem >= 8) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); + int m_rem = MIN(m - m0, 8); + int n_rem = MIN(n - n0, 8); + int mc = 0, nc = 0; + if (m_rem >= 8 && n_rem >= 8) { + mc = 8; + nc = 8; + gemm<8, 8>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 8) { - mc = 4; - nc = 8; - gemm<4,8>(m0, m, n0, n); + mc = 4; + nc = 8; + gemm<4, 8>(m0, m, n0, n); } else if (m_rem >= 8 && n_rem >= 4) { - mc = 8; - nc = 4; - gemm<8,4>(m0, m, n0, n); + mc = 8; + nc = 4; + gemm<8, 4>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 4) { - mc = 4; - nc = 4; - gemm<4,4>(m0, m, n0, n); - } else if ((m_rem < 4) && (n_rem > 4)) { - nc = 4; - switch(m_rem) { - case 1: - mc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 2: - mc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 3: - mc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - default: - return; - } - } else if ((m_rem > 4) && (n_rem < 4)) { - mc = 4; - switch(n_rem) { - case 1: - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 2: - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 3: - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - default: - return; - } + mc = 4; + nc = 4; + gemm<4, 4>(m0, m, n0, n); } else { - switch((m_rem << 4) | n_rem) { - case 0x43: - mc = 4; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x42: - mc = 4; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x41: - mc = 4; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x34: - mc = 3; - nc = 4; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x33: - mc = 3; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x32: - mc = 3; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x31: - mc = 3; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x24: - mc = 2; - nc = 4; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x23: - mc = 2; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x22: - mc = 2; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x21: - mc = 2; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x14: - mc = 1; - nc = 4; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x13: - mc = 1; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x12: - mc = 1; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x11: - mc = 1; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - default: - return; - } + mc = (m_rem >= 4) ? 4 : m_rem; + nc = (n_rem >= 4) ? 4 : n_rem; + if (mc == 0 || nc == 0) + return; + gemm_small(m0, m, n0, n, mc, nc); } - mp = m0 + (m - m0) / mc * mc; - np = n0 + (n - n0) / nc * nc; + int64_t mp = m0 + ((m - m0) / mc) * mc; + int64_t np = n0 + ((n - n0) / nc) * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); - } + } void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { int64_t ytiles = (m - m0) / RM; @@ -3206,22 +2457,22 @@ class tinyBLAS_PPC { * matrix elements. 
*/ if (RM == 1) { - TA* a = const_cast(A+(ii)*lda+l); - packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B); + float* a = const_cast(A+(ii)*lda+l); + packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B); vec_A[0] = (vec_t)vec_xl(0,a); - vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1)); - vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2)); - vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3)); + vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1)); + vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2)); + vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3)); } else if (RN == 1) { - packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A); - TB* b = const_cast(B+(jj)*ldb+l); + packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A); + float* b = const_cast(B+(jj)*ldb+l); vec_B[0] = (vec_t)vec_xl(0,b); - vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1)); - vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2)); - vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3)); + vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1)); + vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2)); + vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3)); } else { - packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B); } __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); @@ -3231,7 +2482,7 @@ class tinyBLAS_PPC { __builtin_mma_disassemble_acc(vec_C, &acc_0); for (int I = 0; I < RM; I++) { for (int J = 0; J < RN; J++) { - *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J); + *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); } } } @@ -3263,11 +2514,9 @@ class tinyBLAS_PPC { } } - const TA *const A; - const TB *const B; - TC *C; - TA *At; - TB *Bt; + const float *const A; + const float *const B; + float *C; const int64_t k; const int64_t lda; const int64_t ldb; @@ -3366,7 +2615,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 #elif defined(__MMA__) if (k % 8) return false; - tinyBLAS_PPC tb{ + tinyBLAS_PPC tb{ k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, @@ -3493,7 +2742,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 return false; if (m < 8 && m != 4) return false; - tinyBLAS_Q0_PPC tb{ + tinyBLAS_Q0_PPC tb{ k, (const block_q8_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, @@ -3530,7 +2779,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 return false; if (m < 8 && m != 4) return false; - tinyBLAS_Q0_PPC tb{ + tinyBLAS_Q0_PPC tb{ k, (const block_q4_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index fd77e9a6abad5..6581d27adde2e 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4015,6 +4015,9 @@ static void ggml_compute_forward_rms_norm_f32( const float scale = 1.0f/sqrtf(mean + eps); + // if you hit this, likely you got an inf somewhere earlier + assert(scale > 0.0f); + ggml_vec_scale_f32(ne00, y, scale); } } diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index a8156011eba2d..07b377bdd82a7 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -221,6 +221,9 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G for (int i = np; i < n; ++i) { sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i])); } + + // 
if you hit this, you are likely running outside the FP range + assert(!isnan(sumf) && !isinf(sumf)); #else for (int i = 0; i < n; ++i) { sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i])); diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 075f14a49e9ac..9122fca6cf99f 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( - float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) { constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup( const int iter_k = ne11 / FATTN_KQ_STRIDE; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; - const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; const bool did_not_have_any_data = kbc0 == kbc0_stop; const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; @@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup( return; } - const int channel = kbc0 / (iter_k*iter_j); - const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k; + const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2)); + const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); + const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile. if (jt*ncols1 + j >= ne01) { return; } - dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid; + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid; // Load the partial result that needs a fixup: float dst_val = 0.0f; @@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup( int bidx = bidx0 - 1; int kbc_stop = kbc0; while(true) { - const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; if (kbc == kbc_stop) { // Did not have any data. 
bidx--; kbc_stop = kbc; @@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results( const float2 * __restrict__ VKQ_meta, float * __restrict__ dst, const int parallel_blocks) { - VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x; - VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x; - dst += D * gridDim.z*blockIdx.x; + // Dimension 0: threadIdx.x + // Dimension 1: blockIdx.x + // Dimension 2: blockIdx.y + // Dimension 3: blockIdx.z + // Memory layout is permuted with [0, 2, 1, 3] + + const int ne01 = gridDim.x; + const int ne02 = gridDim.y; + + const int col = blockIdx.x; + const int head = blockIdx.y; + const int sequence = blockIdx.z; + + const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head; + + VKQ_parts += j_dst_unrolled * parallel_blocks*D; + VKQ_meta += j_dst_unrolled * parallel_blocks; + dst += j_dst_unrolled * D; const int tid = threadIdx.x; __builtin_assume(tid < D); extern __shared__ float2 meta[]; for (int i = tid; i < 2*parallel_blocks; i += D) { - ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i]; + ((float *) meta)[i] = ((const float *)VKQ_meta) [i]; } __syncthreads(); @@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results( const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); *((uint32_t *) &KQ_max_scale) &= ftz_mask; - VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid]; + VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid]; VKQ_denominator += KQ_max_scale * meta[l].y; } - dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator; + dst[tid] = VKQ_numerator / VKQ_denominator; } [[noreturn]] @@ -705,8 +723,6 @@ void launch_fattn( GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); - GGML_ASSERT(Q->ne[3] == 1); - ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); const int id = ggml_cuda_get_device(); @@ -853,8 +869,8 @@ void launch_fattn( scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, - mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, + mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, + mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0, Q->nb[1], Q->nb[2], Q->nb[3], nb11, nb12, nb13, nb21, nb22, nb23, @@ -869,11 +885,11 @@ void launch_fattn( flash_attn_stream_k_fixup <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]); } } else if (parallel_blocks > 1) { const dim3 block_dim_combine(DV, 1, 1); - const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z); + const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]); const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); flash_attn_combine_results diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 709589854f0af..6fa2e77299eb0 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -1274,8 +1276,8 @@ static __global__ void flash_attn_ext_f16( constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice. 
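    // ------------------------------------------------------------------------------------------
    // Editor's sketch, not part of the patch: the hunks below replace the old per-"channel" work
    // split with a flat index that also runs over sequences, i.e. kbc now ranges over
    // iter_k*iter_j*(ne02/ncols2)*ne03. The inert lambda only restates that decomposition with
    // the kernel's own names (iter_k, iter_j, ne02, ncols2); it is never called.
    auto sketch_decompose_kbc = [&](int kbc_flat) {
        const int per_head = iter_k*iter_j;          // KQ tiles per group of ncols2 heads
        const int per_seq  = per_head*(ne02/ncols2); // KQ tiles per sequence
        const int sequence = kbc_flat / per_seq;
        const int head     = (kbc_flat - per_seq*sequence) / per_head;
        const int jt       = (kbc_flat - per_seq*sequence - per_head*head) / iter_k;
        return make_int3(sequence, head, jt);        // same sequence/head/jt math as below
    };
    (void) sketch_decompose_kbc; // illustration only
    // ------------------------------------------------------------------------------------------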
// kbc == k block continuous, current index in continuous ijk space. - int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; - const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; + int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). @@ -1285,18 +1287,19 @@ static __global__ void flash_attn_ext_f16( int kb0_start = kbc % iter_k; int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); while (kbc < kbc_stop && kb0_stop == iter_k) { - const int channel = kbc / (iter_k*iter_j); - const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); + const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile. - const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); - const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2)); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : - (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); + float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio)); - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; const int kb0_stop_kernel = kb0_stop * kb_niter; @@ -1325,18 +1328,19 @@ static __global__ void flash_attn_ext_f16( return; } - const int channel = kbc / (iter_k*iter_j); - const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile. + const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); + const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile. - const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); - const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2)); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : - (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); + float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2); - const half2 * V_h2 = mla ? 
K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio)); - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; const int kb0_stop_kernel = kb0_stop * kb_niter; diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 0c967f178e7b1..1f141328845a4 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -62,15 +64,17 @@ static __global__ void flash_attn_tile_ext_f16( const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); + const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); const int stride_KV2 = nb11 / sizeof(half2); - const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); + const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); @@ -255,6 +259,8 @@ static __global__ void flash_attn_tile_ext_f16( __syncthreads(); } + float2 * dst2 = (float2 *) dst; + #pragma unroll for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { const int j_VKQ = j_VKQ_0 + threadIdx.y; @@ -266,21 +272,21 @@ static __global__ void flash_attn_tile_ext_f16( half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]); kqsum_j = warp_reduce_sum((float)kqsum_j); + const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + #pragma unroll - for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) { - const int i0 = i00 + 2*threadIdx.x; + for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) { + const int i0 = i00 + threadIdx.x; - half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)]; + half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE]; if (gridDim.y == 1) { dst_val /= __half2half2(kqsum_j); } - const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; - dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val); - dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val); + dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val); } if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[((ic0 + 
j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); + dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); } } #else @@ -290,8 +296,8 @@ static __global__ void flash_attn_tile_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 908c76dbdd270..a4965583cef1c 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -74,15 +76,17 @@ static __global__ void flash_attn_tile_ext_f32( const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); + const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); const int stride_KV2 = nb11 / sizeof(half2); - const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); + const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); @@ -265,6 +269,8 @@ static __global__ void flash_attn_tile_ext_f32( __syncthreads(); } + float2 * dst2 = (float2 *) dst; + #pragma unroll for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { const int j_VKQ = j_VKQ_0 + threadIdx.y; @@ -276,22 +282,22 @@ static __global__ void flash_attn_tile_ext_f32( float kqsum_j = kqsum[j_VKQ_0/nwarps]; kqsum_j = warp_reduce_sum(kqsum_j); + const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + #pragma unroll - for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) { - const int i0 = i00 + 2*threadIdx.x; + for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) { + const int i0 = i00 + threadIdx.x; - float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)]; + float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE]; if (gridDim.y == 1) { dst_val.x /= 
kqsum_j; dst_val.y /= kqsum_j; } - const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; - dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x; - dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y; + dst2[j_dst_unrolled*(D/2) + i0] = dst_val; } if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); + dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); } } #else diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index e78fb181919fd..b2d469938abf2 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f16( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -65,14 +67,16 @@ static __global__ void flash_attn_vec_ext_f16( const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - Q += nb02* blockIdx.z + nb01*ic0; - K += nb12*(blockIdx.z / gqa_ratio); - V += nb22*(blockIdx.z / gqa_ratio); + Q += nb03*sequence + nb02* head + nb01*ic0; + K += nb13*sequence + nb12*(head / gqa_ratio); + V += nb23*sequence + nb22*(head / gqa_ratio); - const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); + const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); @@ -330,12 +334,11 @@ static __global__ void flash_attn_vec_ext_f16( if (gridDim.y == 1) { dst_val /= kqsum[j_VKQ]; } - const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; - dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val; + dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val; } if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { - dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); + dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); } #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); @@ -344,8 +347,8 @@ static __global__ void flash_attn_vec_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh 
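The fattn-vec-f16 hunk above, and the fattn-vec-f32 file that follows, stop spreading the output across gridDim.z and instead write through a single unrolled (sequence, column, head, parallel-block) index. The helper below is an editor's sketch of that offset and is not part of the patch; every name mirrors a kernel argument, with `col` standing in for ic0 + j_VKQ, `parallel_block` for blockIdx.y, and `n_parallel` for gridDim.y.
// Editor's sketch: mirrors
//   dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid]
// from the hunk above.
static inline size_t fattn_vec_dst_index(int sequence, int col, int head, int parallel_block,
                                         int ne01, int ne02, int n_parallel, int D, int tid) {
    const size_t j_dst_unrolled =
        ((size_t)(sequence*ne01 + col)*ne02 + head)*n_parallel + parallel_block;
    return j_dst_unrolled*(size_t)D + tid;
}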
index b2f1724c95588..405b6f5106ea0 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f32( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -53,8 +55,8 @@ static __global__ void flash_attn_vec_ext_f32( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); @@ -77,14 +79,16 @@ static __global__ void flash_attn_vec_ext_f32( const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - Q += nb02* blockIdx.z + nb01*ic0; - K += nb12*(blockIdx.z / gqa_ratio); - V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape + Q += nb03*sequence + nb02* head + nb01*ic0; + K += nb13*sequence + nb12*(head / gqa_ratio); + V += nb23*sequence + nb22*(head / gqa_ratio); - const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); + const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); constexpr int nwarps = D / WARP_SIZE; @@ -326,12 +330,11 @@ static __global__ void flash_attn_vec_ext_f32( if (gridDim.y == 1) { dst_val /= kqsum[j_VKQ]; } - const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; - dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val; + dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val; } if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { - dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); + dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); } #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); @@ -340,8 +343,8 @@ static __global__ void flash_attn_vec_ext_f32( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); diff --git 
a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index c95ca7b1f285f..741b8781d29f5 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -47,8 +47,10 @@ static __global__ void flash_attn_ext_f16( const int ne13, const int ne31, const int ne32, + const int ne33, const int nb31, const int nb32, + const int nb33, const int nb01, const int nb02, const int nb03, @@ -95,17 +97,19 @@ static __global__ void flash_attn_ext_f16( constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0); - const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0); + const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio)); + const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); const half2 * mask2 = (const half2 *) maskh; const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); - const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); + const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); const half2 slope2 = make_half2(slopef, slopef); @@ -400,7 +404,6 @@ static __global__ void flash_attn_ext_f16( if (ic0 + j_VKQ >= ne01) { return; } - const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y; float KQ_rowsum_j; if (std::is_same::value) { @@ -409,6 +412,8 @@ static __global__ void flash_attn_ext_f16( KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); } + const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + #pragma unroll for (int i0 = 0; i0 < D; i0 += warp_size) { const int i = i0 + threadIdx.x; @@ -419,7 +424,7 @@ static __global__ void flash_attn_ext_f16( if (gridDim.y == 1) { dst_val /= KQ_rowsum_j; } - dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val; + dst[j_dst_unrolled*D + i] = dst_val; } if (gridDim.y == 1 || threadIdx.x != 0) { @@ -433,7 +438,7 @@ static __global__ void flash_attn_ext_f16( dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); } dst_meta_val.y = KQ_rowsum_j; - dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val; + dst_meta[j_dst_unrolled] = dst_meta_val; } #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); @@ -442,7 +447,8 @@ static __global__ void flash_attn_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31); + GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); 
GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 88b17dd682c95..778d5a48bd9f8 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2303,6 +2303,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_EXP: ggml_cuda_op_exp(ctx, dst); break; + case GGML_UNARY_OP_ELU: + ggml_cuda_op_elu(ctx, dst); + break; default: return false; } @@ -3116,6 +3119,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_ELU: return ggml_is_contiguous(op->src[0]); default: return false; @@ -3222,7 +3226,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } break; case GGML_OP_SET_ROWS: { - return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && +#pragma message("TODO: implement Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") + return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16) && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64; } break; @@ -3408,12 +3413,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (op->src[0]->ne[0] == 192) { return false; } - // TODO: support broadcast - // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but - // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505 - if (op->src[0]->ne[3] != 1) { - return false; - } if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) { return false; } @@ -3426,6 +3425,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) { return true; } + if (op->src[3] && op->src[3]->ne[2] != 1) { + return false; + } return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; } diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index d8b3e63e1aa57..58cee9244018f 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -3,13 +3,21 @@ typedef void (*set_rows_kernel_t)(const char * src, char * dst); template -__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {} +__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) { + GGML_UNUSED(src_f); + GGML_UNUSED(dst_f); +} template<> __device__ __forceinline__ void set_rows_1(const float * src_f, half * dst_h) { *dst_h = __float2half(*src_f); } +template<> +__device__ __forceinline__ void set_rows_1(const float * src_f, nv_bfloat16 * dst_b) { + *dst_b = *src_f; +} + template<> __device__ __forceinline__ void set_rows_1(const float * src_f, float * dst_f) { *dst_f = *src_f; @@ -48,6 +56,9 @@ static __global__ void k_set_rows( const src_t* src_elem = src0_row + i00; dst_t* dst_elem = dst_row_ptr + i00; set_rows_1(src_elem, dst_elem); + + GGML_UNUSED(ne10); + GGML_UNUSED(ne13); } template @@ -124,6 +135,16 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, 
ggml_tensor * dst) { nb1, nb2, nb3, stream ); + } else if (dst->type == GGML_TYPE_BF16) { + set_rows_cuda( + src0_d, src1_d, (nv_bfloat16*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); } else { GGML_ABORT("unsupported type"); } diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index f9c7b83c40d1b..91c830c4dacc3 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -83,6 +83,10 @@ static __device__ __forceinline__ float op_log(float x) { return logf(x); } +static __device__ __forceinline__ float op_elu(float x) { + return (x > 0.f) ? x : expm1f(x); +} + template static __global__ void unary_op_kernel(const T * x, T * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -196,6 +200,9 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } +void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} /* gated ops */ template diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 289d690e5cff6..cb14d16f8f3f5 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -59,6 +59,8 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 83a0739809a6e..44ddc69d08f1c 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -173,6 +173,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_SILU, GGML_METAL_KERNEL_TYPE_SILU_4, GGML_METAL_KERNEL_TYPE_ELU, + GGML_METAL_KERNEL_TYPE_ABS, + GGML_METAL_KERNEL_TYPE_SGN, + GGML_METAL_KERNEL_TYPE_STEP, + GGML_METAL_KERNEL_TYPE_HARDSWISH, + GGML_METAL_KERNEL_TYPE_HARDSIGMOID, + GGML_METAL_KERNEL_TYPE_EXP, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, @@ -1155,6 +1161,12 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, sgn, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP, step, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP, exp, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction); @@ -1688,6 +1700,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_ABS: + case 
GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_HARDSWISH: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_EXP: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; @@ -2439,6 +2457,78 @@ static bool ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_UNARY_OP_ABS: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SGN: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_STEP: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_HARDSWISH: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_HARDSIGMOID: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_EXP: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; default: { GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 239ec31fbcb58..13235e2885241 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1199,6 +1199,51 @@ kernel void kernel_neg( dst[tpig] = -src0[tpig]; } +kernel void kernel_abs( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = fabs(src0[tpig]); +} + +kernel void kernel_sgn( + device const float * src0, + device float * dst, + uint 
tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = (x > 0.0f) ? 1.0f : ((x < 0.0f) ? -1.0f : 0.0f); +} + +kernel void kernel_step( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] > 0.0f ? 1.0f : 0.0f; +} + +kernel void kernel_hardswish( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); +} + +kernel void kernel_hardsigmoid( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); +} + +kernel void kernel_exp( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = exp(src0[tpig]); +} + kernel void kernel_reglu( device const char * src0, device const char * src1, diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 58830b733a8af..3388259152b46 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2280,6 +2280,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te { // TODO: add support // ref: https://github.com/ggml-org/llama.cpp/pull/14274 +#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") if (op->src[0]->type != GGML_TYPE_F32) { return false; } diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp index 5efe03d364b1b..dcf6c7aeeb4ad 100644 --- a/ggml/src/ggml-sycl/gemm.hpp +++ b/ggml/src/ggml-sycl/gemm.hpp @@ -32,39 +32,28 @@ class DnnlGemmWrapper { else static_assert(0); } - // matrix A has m rows, k columns - // matrix B has k rows, n columns - // nra - number of elements to skip when moving into next row in A - // nrb - number of elements to skip when moving into next row in B - // nca - number of elements to skip when moving into next column in A - // ncb - number of elements to skip when moving into next column in B - // stride_a - number of elements to skip when moving to next A matrix - // stride_b - number of elements to skip when moving to next B matrix - // batches_a - number of A matrices - // batches_b - number of B matrices static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k, - const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a, - const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b, + const void * a, dt at, dnnl_dim_t stra0, dnnl_dim_t stra1, dnnl_dim_t stra2, + const void * b, dt bt, dnnl_dim_t strb0, dnnl_dim_t strb1, dnnl_dim_t strb2, void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) { auto stream = ctx.stream_dnnl(q); auto eng = ctx.engine_dnnl(q); - // { # strides, # rows, # columns } - dnnl::memory::dims a_dims = { batches_a, m, k }; - dnnl::memory::dims b_dims = { batches_b, k, n }; - dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n }; - - // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column } - dnnl::memory::dims a_strides = { stride_a, nra, nca }; - dnnl::memory::dims b_strides = { stride_b, nrb, ncb }; - + dnnl::memory::dims a_dims = {batches_a, m, k }; + dnnl::memory::dims a_strides = {stra2, stra1, stra0}; const auto a_in_md = dnnl::memory::desc(a_dims, at, 
a_strides); + + dnnl::memory::dims b_dims = {batches_b, k, n }; + dnnl::memory::dims b_strides = {strb2, strb0, strb1}; const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides); - const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc); + dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n}; + dnnl::memory::dims c_strides = {m*n, 1, m }; + const auto c_md = dnnl::memory::desc(c_dims, ct, c_strides); dnnl::primitive_attr primitive_attr; primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + #ifdef GGML_SYCL_F16 primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16); #endif @@ -76,24 +65,23 @@ class DnnlGemmWrapper { auto scratchpad_md = matmul_pd.scratchpad_desc(); auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q); + auto matmul_prim = dnnl::matmul(matmul_pd); std::unordered_map matmul_args; matmul_args.insert({ DNNL_ARG_SRC, a_mem }); matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem }); + matmul_args.insert({ DNNL_ARG_DST, c_mem }); matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem }); matmul_prim.execute(stream, matmul_args); } - // matrices A and B are column major, both having k rows - // matrix A has m column, matrix B has n columns - // output: column major matrix C = A transposed * B static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k, const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) { - gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1); + gemm(ctx, m, n, k, a, at, 1, k, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1); } }; diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 65b26fd02766e..a6f9af0c86e11 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1546,7 +1546,7 @@ static void mul_mat_p021_f16_f32( static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, - const int row_stride_x, const int channel_stride_x, const int channel_x_divisor, + const int row_stride_x, const int channel_stride_x,const int channel_stride_y, const int channel_x_divisor, const sycl::nd_item<3> &item_ct1) { const sycl::half *x = (const sycl::half *)vx; @@ -1557,7 +1557,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous item_ct1.get_local_id(0); const int channel_x = channel / channel_x_divisor; - const int nrows_y = ncols_x; const int nrows_dst = nrows_x; const int row_dst = row_x; @@ -1576,7 +1575,7 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous const int row_y = col_x; const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; - const int iy = channel*nrows_y + row_y; + const int iy = channel * channel_stride_y + row_y; const float xi = sycl::vec(x[ix]) @@ -1823,7 +1822,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y, static void ggml_mul_mat_vec_nc_f16_f32_sycl( const void *vx, const float *y, float *dst, const int ncols_x, const int nrows_x, const int row_stride_x, const int nchannels_x, - const int nchannels_y, const int channel_stride_x, queue_ptr stream) { + const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) { const sycl::range<3> block_nums(nchannels_y, nrows_x, 1); const sycl::range<3> block_dims(1, 1, WARP_SIZE); @@ -1835,7 +1834,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl( sycl::nd_range<3>(block_nums * block_dims, block_dims), 
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x, - row_stride_x, channel_stride_x, + row_stride_x, channel_stride_x, channel_stride_y, nchannels_y / nchannels_x, item_ct1); }); } @@ -2124,8 +2123,8 @@ inline void ggml_sycl_op_mul_mat_sycl( #if GGML_SYCL_DNNL if (!g_ggml_sycl_disable_dnn) { - DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr, - DnnlGemmWrapper::to_dt(), src0_ptr, DnnlGemmWrapper::to_dt(), + DnnlGemmWrapper::row_gemm(ctx,row_diff, src1_ncols , ne10, src0_ptr, + DnnlGemmWrapper::to_dt(), src1_ptr, DnnlGemmWrapper::to_dt(), dst_dd_i, DnnlGemmWrapper::to_dt(), stream); } else @@ -2171,8 +2170,8 @@ inline void ggml_sycl_op_mul_mat_sycl( #if GGML_SYCL_DNNL if (!g_ggml_sycl_disable_dnn) { - DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i, - DnnlGemmWrapper::to_dt(), src0_ddf_i, DnnlGemmWrapper::to_dt(), + DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i, + DnnlGemmWrapper::to_dt(), src1_ddf1_i, DnnlGemmWrapper::to_dt(), dst_dd_i, DnnlGemmWrapper::to_dt(), stream); } else @@ -2776,6 +2775,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml const int64_t nb02 = src0->nb[2]; const int64_t ne12 = src1->ne[2]; + const int64_t nb11 = src1->nb[1]; SYCL_CHECK(ggml_sycl_set_device(ctx.device)); queue_ptr main_stream = ctx.stream(); @@ -2786,8 +2786,9 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml const int64_t row_stride_x = nb01 / sizeof(sycl::half); const int64_t channel_stride_x = nb02 / sizeof(sycl::half); + const int64_t channel_stride_y = nb11 / sizeof(float); - ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); + ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x,channel_stride_y, main_stream); } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ @@ -2841,8 +2842,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons float * dst_ddf = static_cast(dst->data); const sycl::half * src1_f16 = static_cast(src1->data); + const size_t type_size_src0 = ggml_type_size(src0->type); const size_t type_size_src1 = ggml_type_size(src1->type); - GGML_ASSERT(nb10 == type_size_src1); // SRC1 strides int64_t s11 = nb11 / type_size_src1; @@ -2854,11 +2855,40 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons if (src1->type != GGML_TYPE_F16) { scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2, " : converting src1 to fp16"); - const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type); - GGML_ASSERT(to_fp16_nc_sycl != nullptr); + + // iterate tensor dims and find the slowest moving dim and stride + int64_t last_dim=0; + int64_t last_str=0; + int64_t largest_str=0; + for(int i = 0; i< 4; i++){ + // last stride is always the largest + if(src1->nb[i] == largest_str){ + if(src1->ne[last_dim] == 1){ + last_str = i; + last_dim = i; + } + } + if(src1->nb[i] > largest_str){ + largest_str = src1->nb[i]; + last_str = i; + last_dim = i; + } + + } +#if GGML_SYCL_DNNL + // oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl + const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1; + src1_f16_alloc.alloc(ne_src1); + const to_fp16_sycl_t to_fp16_sycl = 
ggml_get_to_fp16_sycl(src1->type, dst); + GGML_ASSERT(to_fp16_sycl != nullptr); + to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue); +# else const int64_t ne_src1 = ggml_nelements(src1); src1_f16_alloc.alloc(ne_src1); + const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type); + GGML_ASSERT(to_fp16_nc_sycl != nullptr); to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue); +#endif src1_f16 = src1_f16_alloc.get(); s11 = ne10; @@ -2892,38 +2922,89 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons #if GGML_SYCL_DNNL if (!g_ggml_sycl_disable_dnn) { - auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12] - (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) { - - DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10, - src1, DnnlGemmWrapper::to_dt(), s11, 1, s12, - src0, DnnlGemmWrapper::to_dt(), 1, nb01/nb00, nb02/nb00, - dst, DnnlGemmWrapper::to_dt(), queue, batches_a, batches_b); - }; - - if (r2 == 1 && r3 == 1) { - if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { - dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03); - } - else { - for (int64_t ie03 = 0; ie03 < ne03; ++ie03) { - const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes - const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13; - float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float)); - dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02); + int64_t str_a0 = nb00 / type_size_src0; + int64_t str_a1 = nb01 / type_size_src0; + int64_t str_a2 = nb02 / type_size_src0; + + int64_t str_b0 = nb10 / type_size_src1; + int64_t str_b1 = nb11 / type_size_src1; + int64_t str_b2 = nb12 / type_size_src1; + + auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0, + const sycl::half *src1, float *dst, + int64_t a0, int64_t a1, int64_t batcha, + int64_t b0, int64_t b1, int64_t batchb, + int64_t sa0, int64_t sa1, int64_t sa2, + int64_t sb0, int64_t sb1, int64_t sb2, + int64_t sd2) { + bool supported_broadcast = batchb == batcha ? true + : batchb == 1 || batcha == 1 ? 
true + : false; + if (supported_broadcast) { + DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0, + DnnlGemmWrapper::to_dt(), sa0, sa1, sa2, src1, + DnnlGemmWrapper::to_dt(), sb0, sb1, sb2, dst, + DnnlGemmWrapper::to_dt(), queue, batcha, batchb); + } else { + // iterate over batches from smaller set of matrices (matrix 0) + int64_t batches0 = batcha; + int64_t batches1 = batchb; + + if (batches0 > batches1) { + int64_t num_mul_mats = batches1; + int64_t sub_batch = batches0 / num_mul_mats; + // src0 is batched and bigger, shift and multiply with src1 + for (int64_t i0 = 0; i0 < num_mul_mats; i0++) { + const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch); + const sycl::half *src1_shifted = src1 + (sb2 * i0); + float *dst_shifted = dst + (sd2 * i0 * sub_batch); + DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted, + DnnlGemmWrapper::to_dt(), sa0, sa1, sa2, + src1_shifted, DnnlGemmWrapper::to_dt(), sb0, + sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt(), + queue, sub_batch, 1); + } + } else { + int64_t num_mul_mats = batches0; + int64_t sub_batch = batches1 / num_mul_mats; + // src1 is batched and bigger, shift and multiply with src0 + for (int64_t i1 = 0; i1 < num_mul_mats; i1++) { + const sycl::half *src0_shifted = src0 + (sa2 * i1); + const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch); + float *dst_shifted = dst + (sd2 * i1 * sub_batch); + DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted, + DnnlGemmWrapper::to_dt(), sa0, sa1, sa2, + src1_shifted, DnnlGemmWrapper::to_dt(), sb0, + sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt(), + queue, 1, sub_batch); + } + } } - } - } else { - // iterate over batches from smaller set of matrices (matrix 0) - for (int64_t ie02 = 0; ie02 < ne02; ++ie02) { - for (int64_t ie03 = 0; ie03 < ne03; ++ie03) { - const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half)); - const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3; - float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float)); - dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1); + }; + + bool cont_batches_a = nb02 * ne02 == nb03; + bool cont_batches_b = nb12 * ne12 == nb13; + if (cont_batches_a && cont_batches_b) { + int64_t batches0 = ne02 * ne03; + int64_t batches1 = ne12 * ne13; + launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0, + ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1, + str_b2, nb2 / sizeof(float)); + } else { + for (int64_t b_a = 0; b_a < ne03; b_a++) { + const sycl::half *src0_f16_shifted + = src0_f16 + (nb03 * b_a / type_size_src0); + const sycl::half *src1_f16_shifted + = src1_f16 + (nb13 * b_a / type_size_src1); + float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float)); + int64_t batches0 = ne02; + int64_t batches1 = ne12; + launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted, + ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1, + str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float)); } } - } + } else #endif @@ -3263,10 +3344,10 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor // The kernel from the if path is faster for that specific case, but does not support all mul mats. 
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); } - } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst); - } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) { // KQ + KQV multi-batch ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { @@ -4303,6 +4384,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g { // TODO: add support // ref: https://github.com/ggml-org/llama.cpp/pull/14274 +#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") return (op->type == GGML_TYPE_F32 || (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64)); } break; case GGML_OP_CPY: diff --git a/ggml/src/ggml-sycl/set_rows.cpp b/ggml/src/ggml-sycl/set_rows.cpp index 4a76a63d3545d..3091fab39958d 100644 --- a/ggml/src/ggml-sycl/set_rows.cpp +++ b/ggml/src/ggml-sycl/set_rows.cpp @@ -6,46 +6,49 @@ static constexpr bool is_arithmetic_v() { return std::is_arithmetic_v || std::is_same_v || std::is_same_v; } } + template static inline std::enable_if_t() && utils::is_arithmetic_v(), void> convert (const char* src, char* dst) { auto src_val = *reinterpret_cast(src); auto dst_val = sycl::vec(src_val).template convert()[0]; - *reinterpret_cast(dst) = dst_val;; + *reinterpret_cast(dst) = dst_val; } template static void k_set_rows( const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst, - const int64_t ne00, const int64_t ne01, const int64_t ne11, const int64_t ne12, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t ne11, const int64_t ne12, const size_t nb01, const size_t nb02, const size_t nb03, const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb1, const size_t nb2, const size_t nb3, const size_t src_type_size, const size_t dst_type_size, - const sycl::nd_item<3> & item_ct1) { - - const int i03 = item_ct1.get_group(0); - const int i02 = item_ct1.get_group(1); - const int i01 = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); // Row index + const int64_t total_elements, + const sycl::nd_item<1> & item_ct1) { - if (i01 >= ne01) { + const int64_t i = item_ct1.get_global_linear_id(); + if (i >= total_elements) { return; } - const int i12 = i03 % ne12; - const int i11 = i02 % ne11; - const int i10 = i01; + const int64_t i03 = i / (ne00 * ne01 * ne02); + const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00; + const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00; + + const int64_t i12 = i03 % ne12; + const int64_t i11 = i02 % ne11; + const int64_t i10 = i01; const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset<3>({nb10, nb11, nb12}, {i10, i11, i12})); const char * src0_row = src0 + calculate_offset<3>({nb01, nb02, nb03}, {i01, i02, 
i03}); - char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3; + const char * src_elem = src0_row + i00 * src_type_size; + char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3; + char * dst_elem = dst_row_ptr + i00 * dst_type_size; - for (int col = item_ct1.get_local_id(0); col < ne00; col += item_ct1.get_local_range(0)) { - const char * src_elem = src0_row + col * src_type_size; - char * dst_elem = dst_row_ptr + col * dst_type_size; - convert(src_elem, dst_elem); - } + convert(src_elem, dst_elem); } template @@ -58,32 +61,29 @@ static void set_rows_sycl( const size_t src_type_size, const size_t dst_type_size, queue_ptr stream) { - constexpr int max_threads_per_row = 64; // KEEPING 64 for now - const int threads_per_row = std::min((int)ne00, max_threads_per_row); - - constexpr int max_threads_per_block = 64; - const int rows_per_block = std::max(1, max_threads_per_block / threads_per_row); - - const sycl::range<3> block_size(1, rows_per_block, threads_per_row); - const sycl::range<3> grid_size(ne03, ne02, (ne01 + rows_per_block - 1) / rows_per_block); - - sycl_parallel_for( - stream, - sycl::nd_range<3>(grid_size * block_size, block_size), - [=](sycl::nd_item<3> item_ct1) { - k_set_rows( - src0_d, src1_d, dst_d, - ne00, ne01, ne11, ne12, - nb01, nb02, nb03, - nb10, nb11, nb12, - nb1, nb2, nb3, - src_type_size, dst_type_size, - item_ct1 - ); - } - ); -} + const int64_t total_elements = ne00 * ne01 * ne02 * ne03; + constexpr int block_size = 64; + const int64_t grid_size = ceil_div(total_elements, block_size); + + sycl_parallel_for( + stream, + sycl::nd_range<1>(grid_size * block_size, block_size), + [=](sycl::nd_item<1> item_ct1) { + k_set_rows( + src0_d, src1_d, dst_d, + ne00, ne01, ne02, + ne11, ne12, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + src_type_size, dst_type_size, + total_elements, + item_ct1 + ); + } + ); +} void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); @@ -122,7 +122,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { nb1, nb2, nb3, sizeof(float), sizeof(sycl::half), stream - ); + ); break; default: GGML_ABORT("Unsupported tensor type!"); diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 416ee3bd3f70a..3019a545d58ed 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2835,10 +2835,11 @@ static void ggml_vk_load_shaders(vk_device& device) { return s; }; + bool rte = device->float_controls_rte_fp16; #define CREATE_BINARY(name, namemod, spec) \ for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ - #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \ + #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \ "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); CREATE_BINARY(add, , {0}) @@ -2890,8 +2891,13 @@ static void ggml_vk_load_shaders(vk_device& device) { #undef CREATE_UNARY #define CREATE_GLU(name) \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, 
sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); + if (device->float_controls_rte_fp16) { \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + } else { \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + } CREATE_GLU(geglu) CREATE_GLU(reglu) @@ -4916,7 +4922,7 @@ static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) { return tensor->nb[0] == ggml_type_size(tensor->type) && tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; + (tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]); } static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) { @@ -10350,10 +10356,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm // If there's not enough shared memory for row_ids and the result tile, fallback to CPU return false; } - // Check against size of shared memory variable - if (op->src[2]->ne[0] > 4096) { - return false; - } } switch (src0_type) { case GGML_TYPE_F32: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index e06547e48f7fe..27d6b7464f62c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -1,10 +1,6 @@ #version 450 -#if RTE16 -#extension GL_EXT_spirv_intrinsics : enable -spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits -#endif // RTE16 - +#include "rte.comp" #include "types.comp" #if defined(SET_ROWS) && QUANT_K == 1 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp index 157154af3a328..d4e4e6bae63df 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint i = gl_WorkGroupID.x * 256 + wgy; - if (i >= p.M * p.K / QUANT_K) { + if (i >= p.nel / QUANT_K) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp index c17dd0d999116..3661f771c745f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint i = uint(gl_WorkGroupID.x * 256 + wgy); - if (i >= p.M * p.K / QUANT_K) { + if (i >= p.nel / QUANT_K) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp 
b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp index 987f113a35ad0..1370db3654dd7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint ib = gl_WorkGroupID.x * 256 + wgy; - if (ib >= p.M * p.K / QUANT_K) { + if (ib >= p.nel / QUANT_K) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp index 6db5403b6613e..3f3b839e11832 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint ib = gl_WorkGroupID.x * 256 + wgy; - if (ib >= p.M * p.K / QUANT_K) { + if (ib >= p.nel / QUANT_K) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp index 0b91317550f97..9cf34256e8c80 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { const uint i = gl_WorkGroupID.x * 256 + wgy; - if (i >= p.M * p.K / QUANT_K) { + if (i >= p.nel / QUANT_K) { return; } const uint tid = gl_LocalInvocationID.x; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp index 062e2a4cdf2d8..4b4316cf3d9f2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp @@ -1,6 +1,8 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_control_flow_attributes : require +#include "rte.comp" + layout (push_constant) uniform parameter { uint ne; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp index 41a29889075f6..004a61fc16254 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp @@ -1,5 +1,7 @@ #extension GL_EXT_shader_16bit_storage : require +#include "rte.comp" + layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index 09aa849e8815c..17c7ccb90d001 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -1,12 +1,9 @@ #version 450 #extension GL_EXT_shader_16bit_storage : require -#extension GL_EXT_spirv_intrinsics: enable #extension GL_EXT_control_flow_attributes : require -#if RTE16 -spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits -#endif +#include "rte.comp" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp index 96c9c4cbd307c..00e203e73bd1b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp @@ -1,11 +1,8 @@ #include "types.comp" #extension 
GL_EXT_shader_16bit_storage : require -#extension GL_EXT_spirv_intrinsics: enable -#if RTE16 -spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits -#endif +#include "rte.comp" layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp new file mode 100644 index 0000000000000..ad51c1e80b856 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp @@ -0,0 +1,5 @@ + +#if RTE16 +#extension GL_EXT_spirv_intrinsics : enable +spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits +#endif // RTE16 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index d4a4e4c5290d8..809c0bd9bd305 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -537,8 +537,10 @@ void process_shaders() { for (auto src0_f16 : {false, true}) { for (auto src1_f16 : {false, true}) { for (auto dst_f16 : {false, true}) { - auto name = op + get_suffix(src0_f16, src1_f16, dst_f16); - string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}}); + for (auto rte : {false, true}) { + auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : ""); + string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + } } } } @@ -592,16 +594,19 @@ void process_shaders() { string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + for (auto rte : {false, true}) { + std::string suffix = rte ? "_rte" : ""; + string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? 
"1" : "0"}}); + string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + } string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -709,11 +714,59 @@ void write_output_files() { std::remove(path.c_str()); } } + + std::string suffixes[2] = {"_f32", "_f16"}; for (const char *op : {"add", "sub", "mul", "div"}) { - fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op); - fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op); - fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op); - fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op); + fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op); + fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op); + std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = "; + std::string len = "uint64_t " + std::string(op) + "_len[2][2][2][2] = "; + for (uint32_t t0 = 0; t0 < 2; ++t0) { + if (t0 == 0) { + data += "{"; + len += "{"; + } + for (uint32_t t1 = 0; t1 < 2; ++t1) { + if (t1 == 0) { + data += "{"; + len += "{"; + } + for (uint32_t t2 = 0; t2 < 2; ++t2) { + if (t2 == 0) { + data += "{"; + len += "{"; + } + for (uint32_t rte = 0; rte < 2; ++rte) { + if (rte == 0) { + data += "{"; + len += "{"; + } + data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : ""); + len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? 
"_rte" : ""); + data += "_data,"; + len += "_len,"; + if (rte == 1) { + data += "}, "; + len += "}, "; + } + } + if (t2 == 1) { + data += "}, "; + len += "}, "; + } + } + if (t1 == 1) { + data += "}, "; + len += "}, "; + } + } + if (t0 == 1) { + data += "};\n"; + len += "};\n"; + } + } + fprintf(src, data.c_str()); + fprintf(src, len.c_str()); } fclose(hdr); fclose(src); diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt new file mode 100644 index 0000000000000..79ef68b85a477 --- /dev/null +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -0,0 +1,54 @@ +cmake_minimum_required(VERSION 3.13) + +find_package(Python3 REQUIRED) + +# Shader locations +set(SHADER_DIR "${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders") +set(SHADER_OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated") +set(SHADER_HEADER "${SHADER_OUTPUT_DIR}/ggml-wgsl-shaders.hpp") +file(MAKE_DIRECTORY ${SHADER_OUTPUT_DIR}) + +message(STATUS "Shader output dir: ${SHADER_OUTPUT_DIR}") + +# Find all WGSL files +file(GLOB WGSL_SHADER_FILES "${SHADER_DIR}/*.wgsl") + +# Generate the header using a Python script +add_custom_command( + OUTPUT ${SHADER_HEADER} + COMMAND ${CMAKE_COMMAND} -E echo "Embedding WGSL shaders to ggml-wgsl-shaders.hpp" + COMMAND ${CMAKE_COMMAND} -E make_directory ${SHADER_OUTPUT_DIR} + COMMAND ${CMAKE_COMMAND} -E env PYTHONIOENCODING=utf-8 + ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py + --input "${SHADER_DIR}" + --output "${SHADER_HEADER}" + DEPENDS ${WGSL_SHADER_FILES} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py + VERBATIM +) + +add_custom_target(generate_shaders DEPENDS ${SHADER_HEADER}) + +ggml_add_backend_library(ggml-webgpu + ggml-webgpu.cpp + ${SHADER_HEADER} + ../../include/ggml-webgpu.h +) + +add_dependencies(ggml-webgpu generate_shaders) + +if(EMSCRIPTEN) + set(EMDAWNWEBGPU_DIR "" CACHE PATH "Path to emdawnwebgpu_pkg") + + target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + target_link_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") +else() + find_package(Dawn REQUIRED) + set(DawnWebGPU_TARGET dawn::webgpu_dawn) +endif() + +if (GGML_WEBGPU_DEBUG) + target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1) +endif() + +target_include_directories(ggml-webgpu PRIVATE ${SHADER_OUTPUT_DIR}) +target_link_libraries(ggml-webgpu PRIVATE ${DawnWebGPU_TARGET}) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp new file mode 100644 index 0000000000000..c5abc69343357 --- /dev/null +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -0,0 +1,907 @@ +#include "ggml-webgpu.h" + +#include + +#include "ggml-impl.h" +#include "ggml-backend-impl.h" + +#include "ggml-wgsl-shaders.hpp" + +#include +#include +#include +#include + +#ifdef GGML_WEBGPU_DEBUG +#define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl +#else +#define WEBGPU_LOG_DEBUG(msg) ((void) 0) +#endif // GGML_WEBGPU_DEBUG + +/* Constants */ + +#define WEBGPU_MUL_MAT_WG_SIZE 64 +#define WEBGPU_MUL_MAT_PARAMS_SIZE (13 * sizeof(uint32_t)) // M, N, K, batch sizes, broadcasts +#define WEBGPU_CPY_PARAMS_SIZE (15 * sizeof(uint32_t)) // strides and offsets +#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 + +/* End Constants */ + +// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. 
+static void * const webgpu_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT + +// Always returns the base offset of a tensor, regardless of views. +static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) { + if (tensor->view_src) { + return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base; + } + return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base; +} + +/* Struct definitions */ + +// All the base objects needed to run operations on a WebGPU device +struct webgpu_context_struct { + wgpu::Instance instance; + wgpu::Adapter adapter; + wgpu::Device device; + wgpu::Queue queue; + wgpu::Limits limits; + wgpu::SupportedFeatures features; + + std::mutex mutex; + bool device_initialized = false; + + // pipelines and parameter buffers + // TODO: reuse params buffers for different pipelines when possible + wgpu::ComputePipeline memset_pipeline; + wgpu::Buffer memset_params_dev_buf; + wgpu::Buffer memset_params_host_buf; + wgpu::ComputePipeline mul_mat_pipeline; + wgpu::Buffer mul_mat_params_dev_buf; + wgpu::Buffer mul_mat_params_host_buf; + wgpu::ComputePipeline cpy_pipeline; + wgpu::Buffer cpy_params_dev_buf; + wgpu::Buffer cpy_params_host_buf; + + size_t memset_bytes_per_thread; + + // Staging buffer for reading data from the GPU + wgpu::Buffer get_tensor_staging_buf; +}; + +typedef std::shared_ptr webgpu_context; + +struct ggml_backend_webgpu_reg_context { + webgpu_context webgpu_ctx; + + size_t device_count; + const char * name; +}; + +struct ggml_backend_webgpu_device_context { + webgpu_context webgpu_ctx; + + std::string device_name; + std::string device_desc; +}; + +struct ggml_backend_webgpu_context { + webgpu_context webgpu_ctx; + + std::string name; +}; + +struct ggml_backend_webgpu_buffer_context { + webgpu_context webgpu_ctx; + + wgpu::Buffer buffer; + + ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) : + webgpu_ctx(ctx), buffer(buf) { + } +}; + +/* End struct definitions */ + +/* WebGPU object initializations */ + +static void ggml_webgpu_create_pipeline(wgpu::Device &device, wgpu::ComputePipeline &pipeline, const char * shader_code, const char * label, const std::vector &constants = {}) { + WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()"); + wgpu::ShaderSourceWGSL shader_source; + shader_source.code = shader_code; + wgpu::ShaderModuleDescriptor shader_desc; + shader_desc.nextInChain = &shader_source; + wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc); + + wgpu::ComputePipelineDescriptor pipeline_desc; + pipeline_desc.label = label; + pipeline_desc.compute.module = shader_module; + pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code + pipeline_desc.layout = nullptr; // nullptr means auto layout + if (constants.size() > 0) { + pipeline_desc.compute.constants = constants.data(); + pipeline_desc.compute.constantCount = constants.size(); + } + pipeline = device.CreateComputePipeline(&pipeline_desc); +} + +static void ggml_webgpu_create_buffer(wgpu::Device &device, wgpu::Buffer &buffer, size_t size, wgpu::BufferUsage usage, const char* label) { + WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()"); + + wgpu::BufferDescriptor buffer_desc; + buffer_desc.size = size; + buffer_desc.usage = usage; + buffer_desc.label = label; + buffer_desc.mappedAtCreation = false; + // TODO: error handling + buffer = device.CreateBuffer(&buffer_desc); +} + +/** End WebGPU object initializations */ + +/** WebGPU Actions */ + +static void ggml_backend_webgpu_map_buffer(webgpu_context ctx, wgpu::Buffer buffer, 
wgpu::MapMode mode, size_t offset, size_t size) { + ctx->instance.WaitAny(buffer.MapAsync( + mode, offset, size, wgpu::CallbackMode::WaitAnyOnly, + [](wgpu::MapAsyncStatus status, wgpu::StringView message) { + if (status != wgpu::MapAsyncStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", message.data); + } + }), + UINT64_MAX + ); +} + +static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer buf, uint32_t value, size_t offset, size_t size) { + std::lock_guard lock(ctx->mutex); + wgpu::Device device = ctx->device; + + // map the host parameters buffer + ggml_backend_webgpu_map_buffer(ctx, ctx->memset_params_host_buf, wgpu::MapMode::Write, 0, ctx->memset_params_host_buf.GetSize()); + uint32_t * params = (uint32_t *) ctx->memset_params_host_buf.GetMappedRange(); + + params[0] = (uint32_t)offset; + params[1] = (uint32_t)size; + params[2] = value; + ctx->memset_params_host_buf.Unmap(); + + wgpu::BindGroupEntry entries[2]; + entries[0].binding = 0; // binding for the buffer to memset + entries[0].buffer = buf; + entries[0].offset = 0; + entries[0].size = buf.GetSize(); + entries[1].binding = 1; // binding for the parameters + entries[1].buffer = ctx->memset_params_dev_buf; + entries[1].offset = 0; + entries[1].size = ctx->memset_params_dev_buf.GetSize(); + + wgpu::BindGroupDescriptor bind_group_desc; + bind_group_desc.layout = ctx->memset_pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = 2; + bind_group_desc.label = "ggml_memset"; + bind_group_desc.entries = entries; + wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc); + + wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer( + ctx->memset_params_host_buf, 0, + ctx->memset_params_dev_buf, 0, + ctx->memset_params_dev_buf.GetSize() + ); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(ctx->memset_pipeline); + pass.SetBindGroup(0, bind_group); + size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread; + pass.DispatchWorkgroups(((size + 3) + bytes_per_wg - 1) / bytes_per_wg, 1, 1); + pass.End(); + wgpu::CommandBuffer commands = encoder.Finish(); + + ctx->queue.Submit(1, &commands); +} + +static void ggml_backend_webgpu_wait_on_submission(webgpu_context ctx) { + // Wait for the queue to finish processing all commands + ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::WaitAnyOnly, + [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + if (status != wgpu::QueueWorkDoneStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data); + } + }), + UINT64_MAX + ); +} + +/** End WebGPU Actions */ + +/** GGML Backend Interface */ + +static const char * ggml_backend_webgpu_name(ggml_backend_t backend) { + ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context; + return ctx->name.c_str(); +} + +static void ggml_backend_webgpu_free(ggml_backend_t backend) { + ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context; + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")"); + + // TODO: cleanup + GGML_UNUSED(ctx); +} + +// Returns true if node has enqueued work into the queue, false otherwise +static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){ + if (ggml_is_empty(node)) { + return false; + } + + WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")"); + + + switch (node->op) { + // 
no-ops + case GGML_OP_NONE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + return false; + + case GGML_OP_CPY: { + std::lock_guard lock(ctx->mutex); + const ggml_tensor * src = node->src[0]; + ggml_backend_webgpu_buffer_context * src_ctx = (ggml_backend_webgpu_buffer_context *) src->buffer->context; + size_t src_offset = webgpu_tensor_offset(src) + src->view_offs; + // assumes power of 2 offset alignment + size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); + // align to minimum offset alignment + src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); + ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context; + size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs; + size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); + dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); + + wgpu::Device device = ctx->device; + ggml_backend_webgpu_map_buffer(ctx, ctx->cpy_params_host_buf, + wgpu::MapMode::Write, 0, ctx->cpy_params_host_buf.GetSize()); + uint32_t * params = (uint32_t *) ctx->cpy_params_host_buf.GetMappedRange(); + uint32_t ne = (uint32_t)ggml_nelements(node); + params[0] = ne; + params[1] = src_misalignment/ggml_type_size(src->type); + params[2] = dst_misalignment/ggml_type_size(node->type); + + // Convert byte-strides to element-strides + params[3] = (uint32_t)src->nb[0]/ggml_type_size(src->type); + params[4] = (uint32_t)src->nb[1]/ggml_type_size(src->type); + params[5] = (uint32_t)src->nb[2]/ggml_type_size(src->type); + params[6] = (uint32_t)src->nb[3]/ggml_type_size(src->type); + params[7] = (uint32_t)node->nb[0]/ggml_type_size(node->type); + params[8] = (uint32_t)node->nb[1]/ggml_type_size(node->type); + params[9] = (uint32_t)node->nb[2]/ggml_type_size(node->type); + params[10] = (uint32_t)node->nb[3]/ggml_type_size(node->type); + // Logical shape — same for both tensors even if permuted + params[11] = (uint32_t)(src->ne[0]); + params[12] = (uint32_t)(src->ne[1]); + params[13] = (uint32_t)(src->ne[2]); + params[14] = (uint32_t)(src->ne[3]); + + ctx->cpy_params_host_buf.Unmap(); + + wgpu::BindGroupEntry entries[3]; + entries[0].binding = 0; + entries[0].buffer = src_ctx->buffer; + entries[0].offset = src_offset; + entries[0].size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1); + + entries[1].binding = 1; + entries[1].buffer = dst_ctx->buffer; + entries[1].offset = dst_offset; + entries[1].size = (ggml_nbytes(node) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1); + + entries[2].binding = 2; + entries[2].buffer = ctx->cpy_params_dev_buf; + entries[2].offset = 0; + entries[2].size = ctx->cpy_params_dev_buf.GetSize(); + + wgpu::BindGroupDescriptor bind_group_desc; + bind_group_desc.layout = ctx->cpy_pipeline.GetBindGroupLayout(0); + bind_group_desc.label = "ggml_op_cpy"; + bind_group_desc.entryCount = 3; + bind_group_desc.entries = entries; + wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc); + + wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer( + ctx->cpy_params_host_buf, 0, + ctx->cpy_params_dev_buf, 0, + ctx->cpy_params_dev_buf.GetSize() + ); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(ctx->cpy_pipeline); + pass.SetBindGroup(0, bind_group); + size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX; + 
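// One thread per logical element: the dispatch below launches
// ceil(ne / max_wg_size) workgroups, i.e. (ne + max_wg_size - 1) / max_wg_size,
// and the cpy shader maps the flat index back to 4-D coordinates by div/mod,
// roughly:
//   i3 = i / (ne2 * ne1 * ne0);  i %= ne2 * ne1 * ne0;
//   i2 = i / (ne1 * ne0);        i %= ne1 * ne0;
//   i1 = i / ne0;                i0 = i % ne0;
// Threads whose id is >= ne (the last, partially filled workgroup) just return.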
pass.DispatchWorkgroups((ne + max_wg_size - 1) / max_wg_size); + pass.End(); + wgpu::CommandBuffer commands = encoder.Finish(); + + // TODO, don't submit here, batch submissions + ctx->queue.Submit(1, &commands); + // TODO, don't wait on submission here + ggml_backend_webgpu_wait_on_submission(ctx); + return true; + } + + case GGML_OP_MUL_MAT: + { + const ggml_tensor * src0 = node->src[0]; + ggml_backend_webgpu_buffer_context * src0_ctx = (ggml_backend_webgpu_buffer_context *) src0->buffer->context; + size_t src0_offset = webgpu_tensor_offset(src0) + src0->view_offs; + const ggml_tensor * src1 = node->src[1]; + ggml_backend_webgpu_buffer_context * src1_ctx = (ggml_backend_webgpu_buffer_context *) src1->buffer->context; + size_t src1_offset = webgpu_tensor_offset(src1) + src1->view_offs; + ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context; + + size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs; + + wgpu::Device device = ctx->device; + + // map the host parameters buffer + ggml_backend_webgpu_map_buffer(ctx, ctx->mul_mat_params_host_buf, + wgpu::MapMode::Write, 0, ctx->mul_mat_params_host_buf.GetSize()); + uint32_t * params = (uint32_t *) ctx->mul_mat_params_host_buf.GetMappedRange(); + + params[0] = (uint32_t)node->ne[1]; // number of rows in result (M) + params[1] = (uint32_t)node->ne[0]; // number of columns in result (N) + params[2] = (uint32_t)src0->ne[0]; // number of columns in src0/src1 (K) + + params[3] = (uint32_t)src0->nb[1]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 1 + params[4] = (uint32_t)src1->nb[1]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 1 + params[5] = (uint32_t)src0->nb[2]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 2 + params[6] = (uint32_t)src1->nb[2]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 2 + params[7] = (uint32_t)src0->nb[3]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 3 + params[8] = (uint32_t)src1->nb[3]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 3 + + params[9] = (uint32_t)src0->ne[2]; // batch size in dimension 2 + params[10] = (uint32_t)src0->ne[3]; // batch size in dimension 3 + params[11] = (uint32_t)(src1->ne[2]/src0->ne[2]); // broadcast in dimension 2 + params[12] = (uint32_t)(src1->ne[3]/src0->ne[3]); // broadcast in dimension 3 + + ctx->mul_mat_params_host_buf.Unmap(); + + wgpu::BindGroupEntry entries[4]; + entries[0].binding = 0; + entries[0].buffer = src0_ctx->buffer; + entries[0].offset = src0_offset; + entries[0].size = ggml_nbytes(src0); + + entries[1].binding = 1; + entries[1].buffer = src1_ctx->buffer; + entries[1].offset = src1_offset; + entries[1].size = ggml_nbytes(src1); + + entries[2].binding = 2; + entries[2].buffer = dst_ctx->buffer; + entries[2].offset = dst_offset; + entries[2].size = ggml_nbytes(node); + + entries[3].binding = 3; + entries[3].buffer = ctx->mul_mat_params_dev_buf; + entries[3].offset = 0; + entries[3].size = ctx->mul_mat_params_dev_buf.GetSize(); + + wgpu::BindGroupDescriptor bind_group_desc; + bind_group_desc.layout = ctx->mul_mat_pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = 4; + bind_group_desc.label = "ggml_op_mul_mat"; + bind_group_desc.entries = entries; + wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc); + + wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer( + ctx->mul_mat_params_host_buf, 0, + 
ctx->mul_mat_params_dev_buf, 0, + ctx->mul_mat_params_dev_buf.GetSize() + ); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(ctx->mul_mat_pipeline); + pass.SetBindGroup(0, bind_group); + pass.DispatchWorkgroups((node->ne[0] * node->ne[1] * node->ne[2] * node->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE); + pass.End(); + wgpu::CommandBuffer commands = encoder.Finish(); + + // TODO, don't submit here, batch submissions + ctx->queue.Submit(1, &commands); + // TODO, don't wait on submission here + ggml_backend_webgpu_wait_on_submission(ctx); + return true; + } + + default: + return false; + } +} + +static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)"); + + ggml_backend_webgpu_context * backend_ctx = static_cast(backend->context); + webgpu_context ctx = backend_ctx->webgpu_ctx; + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_webgpu_encode_node(ctx, cgraph->nodes[i]); + } + + return GGML_STATUS_SUCCESS; +} + +static ggml_backend_i ggml_backend_webgpu_i = { + /* .get_name = */ ggml_backend_webgpu_name, + /* .free = */ ggml_backend_webgpu_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_webgpu_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, +}; + +/* End GGML Backend Interface */ + +/* GGML Backend Buffer Interface */ + +static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_free_buffer()"); + ggml_backend_webgpu_buffer_context * ctx = static_cast(buffer->context); + ctx->buffer.Destroy(); +} + +// Returns the "fake" base pointer. +static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer) { + GGML_UNUSED(buffer); + return webgpu_ptr_base; +} + +static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + if (size == 0) { + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do."); + return; + } + + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); + + ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; + size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + // This is a trick to set all bytes of a u32 to the same 1 byte value. 
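// For example, value = 0xAB gives 0xAB * 0x01010101 == 0xABABABAB: multiplying
// the byte by 0x01010101 replicates it into all four byte lanes of the u32
// word that the memset shader writes.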
+ uint32_t val32 = (uint32_t)value * 0x01010101; + ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size); +} + +static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); + ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; + webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; + + size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + + webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size/4)*4); + + if (size % 4 != 0) { + // If size is not a multiple of 4, we need to memset the remaining bytes + size_t remaining_size = size % 4; + // pack the remaining bytes into a uint32_t + uint32_t val32 = 0; + for (size_t i = 0; i < remaining_size; i++) { + ((uint8_t *)&val32)[i] = ((const uint8_t *)data)[size - remaining_size + i]; + } + // memset the remaining bytes + ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size); + } +} + +static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); + + ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; + webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; + wgpu::Device device = webgpu_ctx->device; + + size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + + size_t final_size = size; + if (size % 4 != 0) { + // If size is not a multiple of 4, we need to round it up to the next multiple of 4 + final_size = size + (4 - (size % 4)); + } + + std::lock_guard lock(webgpu_ctx->mutex); + + if (webgpu_ctx->get_tensor_staging_buf == nullptr || + webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) { + // Create a new staging buffer if it doesn't exist or is too small + if (webgpu_ctx->get_tensor_staging_buf) { + webgpu_ctx->get_tensor_staging_buf.Destroy(); + } + ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf"); + } + + // Copy the data from the buffer to the staging buffer + wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size); + wgpu::CommandBuffer commands = encoder.Finish(); + // Submit the command buffer to the queue + webgpu_ctx->queue.Submit(1, &commands); + + // Map the staging buffer to read the data + ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size); + // Must specify size here since the staging buffer might be larger than the tensor size + const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size); + + // Copy the data from the mapped range to the output buffer + std::memcpy(data, mapped_range, size); + webgpu_ctx->get_tensor_staging_buf.Unmap(); +} + +static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + 
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")"); + + ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; + ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size); +} + +static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { + /* .free_buffer = */ ggml_backend_webgpu_buffer_free_buffer, + /* .get_base = */ ggml_backend_webgpu_buffer_get_base, + /* .init_tensor = */ NULL, // TODO: optional, needed? + /* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor, + /* .cpy_tensor = */ NULL, // TODO: optional, implement this + /* .clear = */ ggml_backend_webgpu_buffer_clear, + /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor +}; + +/* End GGML Backend Buffer Interface */ + +/* GGML Backend Buffer Type Interface */ + +static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); + return ctx->device_name.c_str(); +} + +static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")"); + ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); + + wgpu::Buffer buf; + ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, size, + wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, "allocated_buffer"); + + ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf); + + return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size); +} + +static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); + return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment; +} + +// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding. +static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); + return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize; +} + +/* End GGML Backend Buffer Type Interface */ + +/* GGML Backend Device Interface */ + +static const char * ggml_backend_webgpu_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_webgpu_device_context * ctx = static_cast(dev->context); + return ctx->device_name.c_str(); +} + +static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_webgpu_device_context * ctx = static_cast(dev->context); + return ctx->device_desc.c_str(); +} + +static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_backend_webgpu_device_context * ctx = static_cast(dev->context); + // TODO: what do we actually want to return here? maxBufferSize might not be the full available memory. 
+ *free = ctx->webgpu_ctx->limits.maxBufferSize; + *total = ctx->webgpu_ctx->limits.maxBufferSize; +} + +static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return GGML_BACKEND_DEVICE_TYPE_GPU; +} + +static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_webgpu_device_get_name(dev); + props->description = ggml_backend_webgpu_device_get_description(dev); + props->type = ggml_backend_webgpu_device_get_type(dev); + ggml_backend_webgpu_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; +} + +static ggml_guid_t ggml_backend_webgpu_guid(void) { + static const char * guid_str = "__ggml_webgpu :)"; + return reinterpret_cast((void *)guid_str); +} + +static void ggml_webgpu_init_memset_pipeline(webgpu_context webgpu_ctx) { + // we use the maximum workgroup size for the memset pipeline + size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX; + size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension; + // Size the bytes_per_thread so that the largest buffer size can be handled + webgpu_ctx->memset_bytes_per_thread = (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads; + std::vector constants(2); + constants[0].key = "wg_size"; + constants[0].value = max_wg_size; + constants[1].key = "bytes_per_thread"; + constants[1].value = webgpu_ctx->memset_bytes_per_thread; + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants); + ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_dev_buf, + 3 * sizeof(uint32_t), // 3 parameters: buffer size, offset, value + wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "memset_params_dev_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_host_buf, + 3 * sizeof(uint32_t), wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "memset_params_host_buf"); +} + +static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context webgpu_ctx) { + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat"); + ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_dev_buf, WEBGPU_MUL_MAT_PARAMS_SIZE, + wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "mul_mat_params_dev_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_host_buf, WEBGPU_MUL_MAT_PARAMS_SIZE, + wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "mul_mat_params_host_buf"); +} + +static void ggml_webgpu_init_cpy_pipeline(webgpu_context webgpu_ctx) { + std::vector constants(1); + constants[0].key = "wg_size"; + constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX; + + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants); + ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_dev_buf, WEBGPU_CPY_PARAMS_SIZE, + wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "cpy_params_dev_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_host_buf, WEBGPU_CPY_PARAMS_SIZE, + wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "cpy_params_host_buf"); +} + +// TODO: Make thread safe if multiple devices are used +static ggml_backend_t 
ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { + GGML_UNUSED(params); + + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()"); + + ggml_backend_webgpu_device_context * dev_ctx = static_cast(dev->context); + webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx; + + std::lock_guard lock(webgpu_ctx->mutex); + + if (!webgpu_ctx->device_initialized) { + // Initialize device + wgpu::DeviceDescriptor dev_desc; + dev_desc.requiredLimits = &webgpu_ctx->limits; + dev_desc.requiredFeatures = webgpu_ctx->features.features; + dev_desc.requiredFeatureCount = webgpu_ctx->features.featureCount; + dev_desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous, + [](const wgpu::Device& device, wgpu::DeviceLostReason reason, wgpu::StringView message) { + GGML_UNUSED(device); + GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), message.data); + }); + dev_desc.SetUncapturedErrorCallback( + [](const wgpu::Device& device, wgpu::ErrorType reason, wgpu::StringView message) { + GGML_UNUSED(device); + GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), message.data); + }); + webgpu_ctx->instance.WaitAny(webgpu_ctx->adapter.RequestDevice(&dev_desc, wgpu::CallbackMode::WaitAnyOnly, + [webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { + if (status != wgpu::RequestDeviceStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data); + return; + } + webgpu_ctx->device = device; + }), + UINT64_MAX + ); + GGML_ASSERT(webgpu_ctx->device != nullptr); + + // Initialize (compute) queue + webgpu_ctx->queue = webgpu_ctx->device.GetQueue(); + + ggml_webgpu_init_memset_pipeline(webgpu_ctx); + ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx); + ggml_webgpu_init_cpy_pipeline(webgpu_ctx); + webgpu_ctx->device_initialized = true; + } + + static ggml_backend_webgpu_context backend_ctx; + backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name; + backend_ctx.webgpu_ctx = webgpu_ctx; + + // See GGML Backend Interface section + static ggml_backend backend = { + /* .guid = */ ggml_backend_webgpu_guid(), + /* .interface = */ ggml_backend_webgpu_i, + /* .device = */ dev, + /* .context = */ &backend_ctx, + }; + + return &backend; +} + +static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) { + // See GGML Backend Buffer Type Interface section + static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = { + /* .iface = */ { + /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size, + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .is_host = */ NULL, // defaults to false + }, + /* .device = */ dev, + /* .context = */ NULL, + }; + + return &ggml_backend_webgpu_buffer_type; +} + +static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + GGML_UNUSED(dev); + return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name; +} + +static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + GGML_UNUSED(dev); + + switch (op->op) { + case GGML_OP_NONE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + return true; + case GGML_OP_CPY: + return op->type == 
GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_MUL_MAT: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; + default: + return false; + } +} + +static struct ggml_backend_device_i ggml_backend_webgpu_device_i = { + /* .get_name = */ ggml_backend_webgpu_device_get_name, + /* .get_description = */ ggml_backend_webgpu_device_get_description, + /* .get_memory = */ ggml_backend_webgpu_device_get_memory, + /* .get_type = */ ggml_backend_webgpu_device_get_type, + /* .get_props = */ ggml_backend_webgpu_device_get_props, + /* .init_backend = */ ggml_backend_webgpu_device_init, + /* .get_buffer_type = */ ggml_backend_webgpu_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ NULL, + /* .supports_op = */ ggml_backend_webgpu_device_supports_op, + /* .supports_buft = */ ggml_backend_webgpu_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +/* End GGML Backend Device Interface */ + +/* GGML Backend Registration Interface */ + +static const char * ggml_backend_webgpu_reg_get_name(ggml_backend_reg_t reg) { + ggml_backend_webgpu_reg_context * ctx = static_cast(reg->context); + return ctx->name; +} + +static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) { + ggml_backend_webgpu_reg_context * ctx = static_cast(reg->context); + return ctx->device_count; +} + +// TODO: Does this need to be thread safe? Is it only called once? +// Only one device is supported for now +static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) { + GGML_ASSERT(index == 0); + WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()"); + + ggml_backend_webgpu_reg_context * reg_ctx = static_cast(reg->context); + + webgpu_context ctx = reg_ctx->webgpu_ctx; + + wgpu::RequestAdapterOptions options = {}; + auto callback = [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char *message, void *userdata) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + *static_cast(userdata) = adapter; + }; + void *userdata = &ctx->adapter; + ctx->instance.WaitAny(ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::WaitAnyOnly, callback, userdata), UINT64_MAX); + GGML_ASSERT(ctx->adapter != nullptr); + + ctx->adapter.GetLimits(&ctx->limits); + ctx->adapter.GetFeatures(&ctx->features); + + wgpu::AdapterInfo info{}; + ctx->adapter.GetInfo(&info); + + static ggml_backend_webgpu_device_context device_ctx; + device_ctx.webgpu_ctx = ctx; + device_ctx.device_name = GGML_WEBGPU_NAME; + device_ctx.device_desc = std::string(info.description.data); + + GGML_LOG_INFO("ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | device_desc: %s\n", + info.vendorID, info.vendor.data, info.architecture.data, info.deviceID, info.device.data, info.description.data); + + // See GGML Backend Device Interface section + static ggml_backend_device device = { + /* .iface = */ ggml_backend_webgpu_device_i, + /* .reg = */ reg, + /* .context = */ &device_ctx, + }; + return &device; +} + + +static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { + /* .get_name = */ ggml_backend_webgpu_reg_get_name, + /* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count, + /* .get_device = */ ggml_backend_webgpu_reg_get_device, + /* .get_proc_address = */ NULL, +}; + +/* End 
GGML Backend Registration Interface */ + +// TODO: Does this need to be thread safe? Is it only called once? +ggml_backend_reg_t ggml_backend_webgpu_reg() { + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); + + webgpu_context webgpu_ctx = std::make_shared(); + webgpu_ctx->device_initialized = false; + + static ggml_backend_webgpu_reg_context ctx; + ctx.webgpu_ctx = webgpu_ctx; + ctx.name = GGML_WEBGPU_NAME; + ctx.device_count = 1; + + wgpu::InstanceDescriptor instance_descriptor{}; + std::vector instance_features = {wgpu::InstanceFeatureName::TimedWaitAny}; + instance_descriptor.requiredFeatures = instance_features.data(); + instance_descriptor.requiredFeatureCount = instance_features.size(); + webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor); + GGML_ASSERT(webgpu_ctx->instance != nullptr); + + static ggml_backend_reg reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_webgpu_reg_i, + /* .context = */ &ctx, + }; + return ® +} + +ggml_backend_t ggml_backend_webgpu_init(void) { + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0); + + return ggml_backend_webgpu_device_init(dev, nullptr); +} + +GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl new file mode 100644 index 0000000000000..6fe924c554cc3 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl @@ -0,0 +1,60 @@ +enable f16; + +@group(0) @binding(0) +var src: array; + +@group(0) @binding(1) +var dst: array; + +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) — may be permuted + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Logical shape (same for both tensors) + ne0: u32, + ne1: u32, + ne2: u32, + ne3: u32, +}; + +@group(0) @binding(2) +var params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + + i2 * params.stride_src2 + i3 * params.stride_src3; + + let dst_idx = i0 * params.stride_dst0 + i1 * params.stride_dst1 + + i2 * params.stride_dst2 + i3 * params.stride_dst3; + + dst[params.offset_dst + dst_idx] = f16(src[params.offset_src + src_idx]); +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py new file mode 100755 index 0000000000000..962dcd6b170ed --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -0,0 +1,35 @@ +import os +import argparse + + +def escape_triple_quotes(wgsl): + # Simple defense in case of embedded """ + return wgsl.replace('"""', '\\"""') + + +def to_cpp_string_literal(varname, content): + return f'const char* wgsl_{varname} = R"({content})";\n' + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--input', required=True) + parser.add_argument('--output', required=True) + args = parser.parse_args() + + with open(args.output, 'w', encoding='utf-8') as out: + out.write("// 
Auto-generated shader embedding \n\n") + for fname in sorted(os.listdir(args.input)): + if not fname.endswith('.wgsl'): + continue + shader_path = os.path.join(args.input, fname) + varname = os.path.splitext(fname)[0] + with open(shader_path, 'r', encoding='utf-8') as f: + content = f.read() + content = escape_triple_quotes(content) + out.write(to_cpp_string_literal(varname, content)) + out.write('\n') + + +if __name__ == '__main__': + main() diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl new file mode 100644 index 0000000000000..cb7c8c3e09e91 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl @@ -0,0 +1,40 @@ +@group(0) @binding(0) +var output_buffer: array; + +struct Params { + offset: u32, // in bytes + size: u32, // in bytes + value: u32, // 4 8-bit values, which are either repeating (memset_tensor) or may be separate (cleaning up unaligned set_tensor operations) +}; + +@group(0) @binding(1) +var params: Params; + +override wg_size: u32; +override bytes_per_thread: u32; + +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + let i = gid.x * bytes_per_thread; + let start = params.offset; + let end = params.offset + params.size; + + for (var j: u32 = 0u; j < bytes_per_thread; j = j + 1u) { + let byte_index = start + i + j; + if (byte_index + 4u <= end) { + output_buffer[(byte_index >> 2u)] = params.value; + } else { + // Handle tail (unaligned) + for (var k: u32 = 0u; k < 4u; k = k + 1u) { + let idx = byte_index + k; + if (idx < end) { + let word_idx = idx >> 2u; + let byte_offset = (idx & 3u) * 8u; + let mask = ~(0xffu << byte_offset); + let existing = output_buffer[word_idx]; + output_buffer[word_idx] = (existing & mask) | ((params.value & 0xffu) << byte_offset); + } + } + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl new file mode 100644 index 0000000000000..054aab566f96b --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl @@ -0,0 +1,56 @@ +struct MulMatParams { + m: u32, + n: u32, + k: u32, + // all strides are in elements + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array; // N rows, K columns +@group(0) @binding(1) var src1: array; // M rows, K columns (transposed) +@group(0) @binding(2) var dst: array; // M rows, N columns + +@group(0) @binding(3) var params: MulMatParams; + +@compute @workgroup_size(64) +fn main(@builtin(global_invocation_id) global_id: vec3) { + let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + if (global_id.x >= total) { + return; + } + + let dst2_stride = params.m * params.n; + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + + let dst3_idx = global_id.x / dst3_stride; + let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension + let src13_idx = dst3_idx; // src1 is not broadcast + let dst3_rem = global_id.x % dst3_stride; + + let dst2_idx = dst3_rem / dst2_stride; + let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension + let src12_idx = dst2_idx; // src1 is not broadcast + + let dst2_rem = dst3_rem % dst2_stride; + + let row = dst2_rem / params.n; // output row + let col = dst2_rem % params.n; // output column + + var sum = 0.0; + for (var i: 
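The unaligned tail in memset.wgsl has to read-modify-write individual bytes inside a u32 word, since the storage buffer is addressed as 32-bit words. Below is a small Python sketch of that shift-and-mask update; the helper name and sample values are made up for illustration.

def set_byte_in_word(word, byte_in_word, value):
    # mirrors the mask logic in memset.wgsl's tail handling
    shift = byte_in_word * 8
    mask = ~(0xFF << shift) & 0xFFFFFFFF   # clear the target byte
    return (word & mask) | ((value & 0xFF) << shift)

assert set_byte_in_word(0xAABBCCDD, 0, 0x11) == 0xAABBCC11
assert set_byte_in_word(0xAABBCCDD, 3, 0x11) == 0x11BBCCDD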
u32 = 0u; i < params.k; i = i + 1u) { + let src0_idx = src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01 + i; + let src1_idx = src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11 + i; + sum = sum + src0[src0_idx] * src1[src1_idx]; + } + dst[dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum; +} diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 4e2b878e189c6..d8afe7696d243 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -317,6 +317,7 @@ class MODEL_ARCH(IntEnum): PHI3 = auto() PHIMOE = auto() PLAMO = auto() + PLAMO2 = auto() CODESHELL = auto() ORION = auto() INTERNLM2 = auto() @@ -366,6 +367,7 @@ class MODEL_ARCH(IntEnum): HUNYUAN_MOE = auto() SMOLLM3 = auto() LFM2 = auto() + DREAM = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -631,6 +633,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.PHI3: "phi3", MODEL_ARCH.PHIMOE: "phimoe", MODEL_ARCH.PLAMO: "plamo", + MODEL_ARCH.PLAMO2: "plamo2", MODEL_ARCH.CODESHELL: "codeshell", MODEL_ARCH.ORION: "orion", MODEL_ARCH.INTERNLM2: "internlm2", @@ -681,6 +684,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe", MODEL_ARCH.SMOLLM3: "smollm3", MODEL_ARCH.LFM2: "lfm2", + MODEL_ARCH.DREAM: "dream", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -1287,6 +1291,21 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.DREAM: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.QWEN2VL: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -1369,6 +1388,36 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.PLAMO2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_POST_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_X, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.SSM_DT_NORM, + MODEL_TENSOR.SSM_B_NORM, + MODEL_TENSOR.SSM_C_NORM, + ], MODEL_ARCH.GPT2: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.POS_EMBD, diff --git a/gguf-py/gguf/scripts/gguf_dump.py b/gguf-py/gguf/scripts/gguf_dump.py index e282892d645c7..8177dff386c7e 100755 --- a/gguf-py/gguf/scripts/gguf_dump.py +++ b/gguf-py/gguf/scripts/gguf_dump.py @@ -234,6 +234,8 @@ def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None markdown_content += '## Key Value Metadata Store\n\n' markdown_content += f'There are {len(reader.fields)} key-value pairs in this file\n' markdown_content += '\n' + total_model_bytes = 0 + total_model_elements = 0 kv_dump_table: list[dict[str, str | int]] = [] for n, field in enumerate(reader.fields.values(), 1): @@ -377,6 +379,8 @@ def escape_markdown_inline_code(value_string): tensors = tensor_groups[group] group_elements = 
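mul_mat.wgsl assigns one output element per invocation and recovers the batch indices, output row and output column from the flat id, dividing the batch indices by broadcast2/broadcast3 for src0 only (src1 is not broadcast). The Python sketch below mirrors that decomposition; the function and its example arguments are illustrative, not part of the shader.

def mul_mat_coords(gid, m, n, bs02, broadcast2, broadcast3):
    # names mirror the MulMatParams fields in the shader above
    dst2_stride = m * n
    dst3_stride = dst2_stride * bs02 * broadcast2

    dst3_idx = gid // dst3_stride
    src03_idx = dst3_idx // broadcast3      # src0 batch index (dim 3)
    src13_idx = dst3_idx                    # src1 is not broadcast
    rem = gid % dst3_stride

    dst2_idx = rem // dst2_stride
    src02_idx = dst2_idx // broadcast2      # src0 batch index (dim 2)
    src12_idx = dst2_idx
    rem %= dst2_stride

    row, col = rem // n, rem % n            # output row / column
    return (src03_idx, src02_idx, src13_idx, src12_idx, row, col)

# one src0 batch shared by two src1 batches (broadcast2 = 2)
print(mul_mat_coords(gid=7, m=2, n=3, bs02=1, broadcast2=2, broadcast3=1))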
sum(tensor.n_elements for tensor in tensors) group_percentage = group_elements / total_elements * 100 + total_group_bytes = 0 + total_group_elements = 0 markdown_content += f"### {translate_tensor_name(group)} Tensor Group : {element_count_rounded_notation(group_elements)} Elements\n\n" # Precalculate column sizing for visual consistency @@ -397,7 +401,13 @@ def escape_markdown_inline_code(value_string): element_count_est = f"({element_count_rounded_notation(tensor.n_elements):>{prettify_element_est_count_size}})" element_count_string = f"{element_count_est} {tensor.n_elements:>{prettify_element_count_size}}" type_name_string = f"{tensor.tensor_type.name}" - tensor_dump_table.append({"t_id":tensor_name_to_key[tensor.name], "layer_name":tensor.name, "human_layer_name":human_friendly_name, "element_count":element_count_string, "pretty_dimension":pretty_dimension, "tensor_type":type_name_string}) + if tensor.n_elements > 0: + bpw = (tensor.n_bytes * 8) / tensor.n_elements + else: + bpw = float('nan') + tensor_dump_table.append({"t_id":tensor_name_to_key[tensor.name], "layer_name":tensor.name, "human_layer_name":human_friendly_name, "element_count":element_count_string, "pretty_dimension":pretty_dimension, "tensor_type":type_name_string, "bpw": f"{bpw:.4f}"}) + total_group_bytes += tensor.n_bytes + total_group_elements += tensor.n_elements tensor_dump_table_header_map = [ {'key_name':'t_id', 'header_name':'T_ID', 'align':'right'}, @@ -406,6 +416,7 @@ def escape_markdown_inline_code(value_string): {'key_name':'element_count', 'header_name':'Elements', 'align':'left'}, {'key_name':'pretty_dimension', 'header_name':'Shape', 'align':'left'}, {'key_name':'tensor_type', 'header_name':'Type', 'align':'left'}, + {'key_name':'bpw', 'header_name':'BPW', 'align':'right'}, ] markdown_content += markdown_table_with_alignment_support(tensor_dump_table_header_map, tensor_dump_table) @@ -413,8 +424,20 @@ def escape_markdown_inline_code(value_string): markdown_content += "\n" markdown_content += f"- Total elements in {group}: ({element_count_rounded_notation(group_elements):>4}) {group_elements}\n" markdown_content += f"- Percentage of total elements: {group_percentage:.2f}%\n" + if total_group_elements > 0: + total_group_bpw = (total_group_bytes * 8) / total_group_elements + markdown_content += f"- Bits per Weight (BPW) for {group}: {total_group_bpw:.4f} bits\n" + else: + markdown_content += f"- Bits per Weight (BPW) for {group}: undefined (no elements)\n" markdown_content += "\n\n" + total_model_bytes += total_group_bytes + total_model_elements += total_group_elements + if total_model_elements > 0: + total_model_bpw = (total_model_bytes * 8) / total_model_elements + markdown_content += f"Total BPW for {os.path.basename(args.model)}: {total_model_bpw:.4f} bits" + else: + markdown_content += f"Total BPW for {os.path.basename(args.model)}: undefined (no elements)" print(markdown_content) # noqa: NP100 diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 75855eba52c3c..2a675044f9d99 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -13,7 +13,7 @@ class TensorNameMap: "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone "transformer.word_embeddings", # falcon "word_embeddings", # bloom - "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 granite-hybrid + "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 plamo2 granite-hybrid "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert 
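The gguf_dump.py change adds a bits-per-weight column plus per-group and per-model totals; the arithmetic is simply total bits over total elements, guarded against empty tensors. A minimal sketch, with hypothetical helper names and a made-up two-tensor example:

from collections import namedtuple

def bpw(n_bytes, n_elements):
    # bits per weight for a single tensor (nan when there are no elements)
    return (n_bytes * 8) / n_elements if n_elements > 0 else float("nan")

def group_bpw(tensors):
    # aggregate over a tensor group, as the markdown dump now does
    return bpw(sum(t.n_bytes for t in tensors), sum(t.n_elements for t in tensors))

T = namedtuple("T", "n_bytes n_elements")
print(f"{bpw(144, 256):.4f}")                       # Q4_K-like block: 4.5000
print(f"{group_bpw([T(144, 256), T(1024, 256)]):.4f}")  # mixed with an F32-like tensor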
nomic-bert "language_model.embedding.word_embeddings", # persimmon @@ -63,7 +63,7 @@ class TensorNameMap: # Output MODEL_TENSOR.OUTPUT: ( "embed_out", # gptneox - "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe + "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe plamo2 "output", # llama-pth bloom internlm2 "word_embeddings_for_head", # persimmon "lm_head.linear", # phi2 @@ -77,7 +77,7 @@ class TensorNameMap: MODEL_TENSOR.OUTPUT_NORM: ( "gpt_neox.final_layer_norm", # gptneox "transformer.ln_f", # gpt2 gpt-j falcon jais exaone - "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe + "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe plamo2 "norm", # llama-pth "transformer.norm_f", # mpt dbrx "ln_f", # refact bloom qwen gpt2 @@ -126,6 +126,7 @@ class TensorNameMap: "h.{bid}.ln_1", # gpt2 "transformer.h.{bid}.ln", # phi2 "model.layers.layers.{bid}.norm", # plamo + "model.layers.layers.{bid}.pre_mixer_norm", # plamo2 "model.layers.{bid}.attention_norm", # internlm2 "model.layers.{bid}.norm", # mamba-qbert "backbone.layers.{bid}.norm", # mamba @@ -163,6 +164,7 @@ class TensorNameMap: "encoder.layers.{bid}.attn.Wqkv", # nomic-bert "encoder.layers.{bid}.mixer.Wqkv", # jina "model.layers.{bid}.self_attn.qkv_proj", # phi3 + "model.layers.layers.{bid}.mixer.qkv_proj", # plamo2 "encoder.layers.{bid}.self_attention.query_key_value", # chatglm "transformer.layers.{bid}.attn.qkv_proj", # openelm "transformer_encoder.{bid}.qkv", # neobert @@ -233,6 +235,7 @@ class TensorNameMap: "h.{bid}.attn.c_proj", # gpt2 "transformer.h.{bid}.mixer.out_proj", # phi2 "model.layers.layers.{bid}.self_attn.o_proj", # plamo + "model.layers.layers.{bid}.mixer.o_proj", # plamo2 "model.layers.{bid}.attention.wo", # internlm2 "encoder.layers.{bid}.attn.out_proj", # nomic-bert "encoder.layers.{bid}.mixer.out_proj", # jina @@ -255,8 +258,9 @@ class TensorNameMap: ), MODEL_TENSOR.ATTN_POST_NORM: ( - "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge - "model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414 + "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge + "model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414 + "model.layers.layers.{bid}.post_mixer_norm.weight", # plamo2 ), # Rotary embeddings @@ -286,6 +290,7 @@ class TensorNameMap: "model.layers.{bid}.pre_moe_layernorm", # mini-jamba "model.layers.{bid}.post_attention_layernorm", # llama4 "transformer_encoder.{bid}.ffn_norm", # neobert + "model.layers.layers.{bid}.pre_mlp_norm", # plamo2 ), # Post feed-forward norm @@ -298,6 +303,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_POST_NORM: ( "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2 "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414 + "model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2 "model.layers.{bid}.feed_forward.up_proj", ), @@ -342,6 +348,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.fc1", # phi2 "model.layers.{bid}.mlp.gate_up_proj", # phi3 glm-4-0414 "model.layers.layers.{bid}.mlp.up_proj", # plamo + "model.layers.layers.{bid}.mlp.gate_up_proj", # plamo2 "model.layers.{bid}.feed_forward.w3", # internlm2 "encoder.layers.{bid}.mlp.fc11", # nomic-bert "encoder.layers.{bid}.mlp.fc1", # nomic-bert-moe @@ -469,6 +476,7 @@ class TensorNameMap: "transformer.blocks.{bid}.attn.q_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2 "transformer.layers.{bid}.attn.q_norm", # openelm + 
"model.layers.layers.{bid}.mixer.q", # plamo2 ), MODEL_TENSOR.ATTN_K_NORM: ( @@ -479,6 +487,7 @@ class TensorNameMap: "transformer.blocks.{bid}.attn.k_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2 "transformer.layers.{bid}.attn.k_norm", # openelm + "model.layers.layers.{bid}.mixer.k", # plamo2 ), MODEL_TENSOR.ROPE_FREQS: ( @@ -559,27 +568,31 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_IN: ( - "model.layers.{bid}.in_proj", # mamba-hf - "backbone.layers.{bid}.mixer.in_proj", # mamba - "model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid + "model.layers.{bid}.in_proj", # mamba-hf + "backbone.layers.{bid}.mixer.in_proj", # mamba + "model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid + "model.layers.layers.{bid}.mixer.in_proj", # plamo2 ), MODEL_TENSOR.SSM_CONV1D: ( - "model.layers.{bid}.conv1d", # mamba-hf - "backbone.layers.{bid}.mixer.conv1d", # mamba - "model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 granite-hybrid + "model.layers.{bid}.conv1d", # mamba-hf + "backbone.layers.{bid}.mixer.conv1d", # mamba + "model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 granite-hybrid + "model.layers.layers.{bid}.mixer.conv1d", # plamo2 ), MODEL_TENSOR.SSM_X: ( - "model.layers.{bid}.x_proj", # mamba-hf - "backbone.layers.{bid}.mixer.x_proj", # mamba - "model.layers.{bid}.mamba.x_proj", # jamba + "model.layers.{bid}.x_proj", # mamba-hf + "backbone.layers.{bid}.mixer.x_proj", # mamba + "model.layers.{bid}.mamba.x_proj", # jamba + "model.layers.layers.{bid}.mixer.bcdt_proj", # plamo2 ), MODEL_TENSOR.SSM_DT: ( - "model.layers.{bid}.dt_proj", # mamba-hf - "backbone.layers.{bid}.mixer.dt_proj", # mamba - "model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid + "model.layers.{bid}.dt_proj", # mamba-hf + "backbone.layers.{bid}.mixer.dt_proj", # mamba + "model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid + "model.layers.layers.{bid}.mixer.dt_proj", # plamo2 ), MODEL_TENSOR.SSM_DT_NORM: ( @@ -587,25 +600,33 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_A: ( - "model.layers.{bid}.A_log", # mamba-hf - "backbone.layers.{bid}.mixer.A_log", # mamba - "model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid + "model.layers.{bid}.A_log", # mamba-hf + "backbone.layers.{bid}.mixer.A_log", # mamba + "model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid + "model.layers.layers.{bid}.mixer.A_log", # plamo2 ), MODEL_TENSOR.SSM_B_NORM: ( - "model.layers.{bid}.mamba.b_layernorm", # jamba - "model.layers.{bid}.mamba.B_layernorm", # mini-jamba + "model.layers.{bid}.mamba.b_layernorm", # jamba + "model.layers.{bid}.mamba.B_layernorm", # mini-jamba + "model.layers.layers.{bid}.mixer.B_norm.weight", # plamo2 ), MODEL_TENSOR.SSM_C_NORM: ( - "model.layers.{bid}.mamba.c_layernorm", # jamba - "model.layers.{bid}.mamba.C_layernorm", # mini-jamba + "model.layers.{bid}.mamba.c_layernorm", # jamba + "model.layers.{bid}.mamba.C_layernorm", # mini-jamba + "model.layers.layers.{bid}.mixer.C_norm.weight", # plamo2 ), MODEL_TENSOR.SSM_D: ( - "model.layers.{bid}.D", # mamba-hf - "backbone.layers.{bid}.mixer.D", # mamba - "model.layers.{bid}.mamba.D", # jamba falcon-h1 granite-hybrid + "model.layers.{bid}.D", # mamba-hf + "backbone.layers.{bid}.mixer.D", # mamba + "model.layers.{bid}.mamba.D", # jamba falcon-h1 granite-hybrid + "model.layers.layers.{bid}.mixer.D", # plamo2 + ), + + MODEL_TENSOR.SSM_DT_NORM: ( + "model.layers.layers.{bid}.mixer.dt_norm.weight", # plamo2 ), MODEL_TENSOR.SSM_NORM: ( @@ -614,9 +635,10 @@ class 
TensorNameMap: ), MODEL_TENSOR.SSM_OUT: ( - "model.layers.{bid}.out_proj", # mamba-hf - "backbone.layers.{bid}.mixer.out_proj", # mamba - "model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 granite-hybrid + "model.layers.{bid}.out_proj", # mamba-hf + "backbone.layers.{bid}.mixer.out_proj", # mamba + "model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 granite-hybrid + "model.layers.layers.{bid}.mixer.out_proj", # plamo2 ), MODEL_TENSOR.TIME_MIX_W0: ( diff --git a/include/llama.h b/include/llama.h index f73b1ab65fe6f..db6a5337b02a7 100644 --- a/include/llama.h +++ b/include/llama.h @@ -71,12 +71,13 @@ extern "C" { typedef int32_t llama_seq_id; enum llama_vocab_type { - LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab - LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback - LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE - LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece - LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram - LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization + LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab + LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback + LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE + LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece + LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram + LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization + LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming }; enum llama_rope_type { @@ -334,6 +335,9 @@ extern "C" { bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573 + bool kv_unified; // use a unified buffer across the input sequences when computing the attention + // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix + // ref: https://github.com/ggml-org/llama.cpp/pull/14363 }; // model quantization parameters @@ -724,7 +728,7 @@ extern "C" { // - lazily on next llama_decode() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - DEPRECATED(void llama_kv_self_seq_div( + DEPRECATED(LLAMA_API void llama_kv_self_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -1004,6 +1008,7 @@ extern "C" { LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding + LLAMA_API llama_token llama_vocab_mask(const struct llama_vocab * vocab); // mask LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab); LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab); diff --git a/models/templates/llama-cpp-rwkv-world.jinja b/models/templates/llama-cpp-rwkv-world.jinja new file mode 100644 index 0000000000000..690223f1b03fe --- /dev/null +++ b/models/templates/llama-cpp-rwkv-world.jinja @@ -0,0 +1,34 @@ +{%- if not add_generation_prompt is defined -%} + {%- set add_generation_prompt = true -%} +{%- endif -%} +{%- set ns = namespace(system_prompt='') -%} +{%- for message in messages -%} + {%- if message['role'] == 'system' -%} + {%- set 
ns.system_prompt = message['content'] -%} + {%- endif -%} +{%- endfor -%} +{{bos_token}} +{%- if ns.system_prompt != '' -%} +{{- 'System: ' + ns.system_prompt + '\n\n' -}} +{%- endif -%} +{%- for message in messages -%} + {%- if message['role'] == 'user' -%} + {{- 'User: ' + message['content']|trim + '\n\n' -}} + {%- endif -%} + {%- if message['role'] == 'assistant' and message['content'] is not none -%} + {%- set content = message['content'] -%} + {%- if '' in content -%} + {%- set content = content.split('')[-1] -%} + {%- endif -%} + {{- 'Assistant: ' + content|trim + '\n\n' -}} + {%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{- 'Assistant:' -}} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- ' \n' }} + {%- endif %} + {%- if enable_thinking is defined and enable_thinking is true %} + {{- ' ' }} + {%- endif %} +{%- endif -%} \ No newline at end of file diff --git a/models/templates/moonshotai-Kimi-K2.jinja b/models/templates/moonshotai-Kimi-K2.jinja new file mode 100644 index 0000000000000..ecb49a210852c --- /dev/null +++ b/models/templates/moonshotai-Kimi-K2.jinja @@ -0,0 +1,43 @@ +{%- if tools -%} + <|im_system|>tool_declare<|im_middle|>{{ tools | tojson }}<|im_end|> +{%- endif -%} +{%- for message in messages -%} + {%- if loop.first and messages[0]['role'] != 'system' -%} + <|im_system|>system<|im_middle|>You are a helpful assistant<|im_end|> + {%- endif -%} + {%- if message['role'] == 'system' -%} + <|im_system|>system<|im_middle|> + {%- elif message['role'] == 'user' -%} + <|im_user|>user<|im_middle|> + {%- elif message['role'] == 'assistant' -%} + <|im_assistant|>assistant<|im_middle|> + {%- elif message['role'] == 'tool' -%} + <|im_system|>tool<|im_middle|> + {%- endif -%} + {%- if message['role'] == 'assistant' and message.get('tool_calls') -%} + {%- if message['content'] -%}{{ message['content'] }}{%- endif -%} + <|tool_calls_section_begin|> + {%- for tool_call in message['tool_calls'] -%} + {%- set func_name = tool_call['function']['name'] -%} + {%- set formatted_id = 'functions.' 
+ func_name + ':' + loop.index0|string -%} + <|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{{ tool_call['function']['arguments'] | tojson}}<|tool_call_end|> + {%- endfor -%} + <|tool_calls_section_end|> + {%- elif message['role'] == 'tool' -%} + ## Return of {{ message.tool_call_id }}\n{{ message['content'] }} + {%- elif message['content'] is string -%} + {{ message['content'] }} + {%- elif message['content'] is not none -%} + {% for content in message['content'] -%} + {% if content['type'] == 'image' or 'image' in content or 'image_url' in content -%} + <|media_start|>image<|media_content|><|media_pad|><|media_end|> + {% else -%} + {{ content['text'] }} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + <|im_end|> +{%- endfor -%} +{%- if add_generation_prompt -%} + <|im_assistant|>assistant<|im_middle|> +{%- endif -%} diff --git a/requirements/requirements-all.txt b/requirements/requirements-all.txt index 9fa7d4d0abdec..56b6752ac0645 100644 --- a/requirements/requirements-all.txt +++ b/requirements/requirements-all.txt @@ -3,6 +3,7 @@ -r ../tools/server/tests/requirements.txt -r ./requirements-compare-llama-bench.txt +-r ./requirements-server-bench.txt -r ./requirements-pydantic.txt -r ./requirements-test-tokenizer-random.txt diff --git a/requirements/requirements-server-bench.txt b/requirements/requirements-server-bench.txt new file mode 100644 index 0000000000000..ea5849fa104ef --- /dev/null +++ b/requirements/requirements-server-bench.txt @@ -0,0 +1,5 @@ +datasets~=3.2.0 +matplotlib~=3.10.0 +numpy~=1.26.4 +requests~=2.32.3 +tqdm~=4.67.1 diff --git a/scripts/server-bench.py b/scripts/server-bench.py new file mode 100755 index 0000000000000..3afad66ced47b --- /dev/null +++ b/scripts/server-bench.py @@ -0,0 +1,265 @@ +#!/usr/bin/env python3 + +import argparse +import json +import os +import random +import subprocess +from time import sleep, time +from typing import Optional, Union + +import datasets +import logging +import matplotlib.pyplot as plt +import numpy as np +import requests +from tqdm.contrib.concurrent import thread_map + + +logging.basicConfig(level=logging.INFO, format='%(message)s') +logger = logging.getLogger("server-bench") + + +def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]: + ret = [] + if dataset_name.lower() == "mmlu": + logger.info("Loading MMLU dataset...") + ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"] # type: ignore + else: + return None + if n_prompts >= 0: + ret = ret[:n_prompts] + return ret + + +def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]: + assert n_prompts >= 0 + ret: list[int] = [] + for i in range(n_prompts): + random.seed(13 * i + 0) + ret.append(random.randint(prompt_length_min, prompt_length_max)) + return ret + + +def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]: + return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths] + + +def get_server(path_server: str, path_log: Optional[str]) -> dict: + logger.info("Starting the llama.cpp server...") + hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1") + port: str = os.environ.get("LLAMA_ARG_PORT", "8080") + address: str = f"http://{hostname}:{port}" + + fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL + process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT) + + n_failures: int = 0 + while True: + try: + sleep(1.0) + exit_code = process.poll() + if exit_code is not None: + raise 
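Both the Kimi-K2 Jinja template above and the LLM_CHAT_TEMPLATE_KIMI_K2 branch added to llama-chat.cpp further down produce the same role framing. A text-only Python sketch of that framing (tool calls and media parts omitted; the helper name is made up):

ROLE_HEADERS = {
    "system":    "<|im_system|>system<|im_middle|>",
    "user":      "<|im_user|>user<|im_middle|>",
    "assistant": "<|im_assistant|>assistant<|im_middle|>",
    "tool":      "<|im_system|>tool<|im_middle|>",
}

def render_kimi_k2(messages, add_generation_prompt=True):
    out = []
    if not messages or messages[0]["role"] != "system":
        # default system message, as in the template's loop.first branch
        out.append(ROLE_HEADERS["system"] + "You are a helpful assistant" + "<|im_end|>")
    for msg in messages:
        out.append(ROLE_HEADERS[msg["role"]] + msg["content"] + "<|im_end|>")
    if add_generation_prompt:
        out.append("<|im_assistant|>assistant<|im_middle|>")
    return "".join(out)

print(render_kimi_k2([{"role": "user", "content": "Hi"}]))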
RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}") + response = requests.get(f"{address}/health") + if response.status_code == 200: + break + except requests.ConnectionError: + n_failures += 1 + if n_failures >= 10: + raise RuntimeError("llama.cpp server is not healthy after 10 seconds") + + return {"process": process, "address": address, "fout": fout} + + +def get_prompt_length(data: dict) -> int: + session = data["session"] + server_address: str = data["server_address"] + + response = session.post( + f"{server_address}/apply-template", + json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]} + ) + if response.status_code != 200: + raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}") + prompt: str = json.loads(response.text)["prompt"] + response = session.post( + f"{server_address}/tokenize", + json={"content": prompt, "add_special": True} + ) + if response.status_code != 200: + raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}") + tokens: list[str] = json.loads(response.text)["tokens"] + return len(tokens) + + +def send_prompt(data: dict) -> tuple[float, list[float]]: + session = data["session"] + server_address: str = data["server_address"] + + t_submit = time() + if data["synthetic_prompt"]: + json_data: dict = { + "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False, + "seed": data["seed"], "n_predict": data["n_predict"], "stream": True} + response = session.post(f"{server_address}/completion", json=json_data, stream=True) + else: + response = session.post( + f"{server_address}/apply-template", + json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]} + ) + if response.status_code != 200: + raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}") + prompt: str = json.loads(response.text)["prompt"] + + json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True} + response = session.post(f"{server_address}/completion", json=json_data, stream=True) + + token_arrival_times: list[float] = [] + for line in response.iter_lines(decode_unicode=False): + if not line.startswith(b"data: "): + continue + token_arrival_times.append(time()) + token_arrival_times = token_arrival_times[:-1] + + if response.status_code != 200: + raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}") + + return (t_submit, token_arrival_times) + + +def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int): + if os.environ.get("LLAMA_ARG_N_PARALLEL") is None: + logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32") + os.environ["LLAMA_ARG_N_PARALLEL"] = "32" + if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None: + logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999") + os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999" + if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None: + logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'") + os.environ["LLAMA_ARG_FLASH_ATTN"] = "true" + + parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1)) + prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts) + synthetic_prompts: bool = prompts is None + prompt_n = [] + + if synthetic_prompts: + prompt_source_split: list[str] = prompt_source.split("-") + assert len(prompt_source_split) == 3 + 
assert prompt_source_split[0].lower() == "rng" + prompt_length_min: int = int(prompt_source_split[1]) + prompt_length_max: int = int(prompt_source_split[2]) + logger.info("Generating random prompts...") + prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max) + prompts = get_prompts_rng(prompt_n) + else: + n_predict_min = n_predict + + if os.environ.get("LLAMA_ARG_CTX_SIZE") is None: + context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048))) + context_total: int = context_per_slot * parallel + os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total) + logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).") + + server: Optional[dict] = None + session = None + try: + server = get_server(path_server, path_log) + server_address: str = server["address"] + + adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel) # type: ignore + session = requests.Session() + session.mount("http://", adapter) + session.mount("https://", adapter) + + data: list[dict] = [] + + for i, p in enumerate(prompts): + random.seed(13 * i + 1) + data.append({ + "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts, + "n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2}) + + if not synthetic_prompts: + logger.info("Getting the prompt lengths...") + prompt_n = [get_prompt_length(d) for d in data] + + logger.info("Starting the benchmark...\n") + t0 = time() + results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1) + finally: + if server is not None: + server["process"].terminate() + server["process"].wait() + if session is not None: + session.close() + + prompt_t = [] + token_t = [] + depth_sum: int = 0 + for pn, (t_submit, tat) in zip(prompt_n, results): + prompt_t.append(tat[0] - t_submit) + token_t += tat + n_tokens: int = len(tat) + depth_sum += n_tokens * pn + depth_sum += n_tokens * (n_tokens + 1) // 2 + assert len(token_t) > 0 + prompt_n = np.array(prompt_n, dtype=np.int64) + prompt_t = np.array(prompt_t, dtype=np.float64) + token_t = np.array(token_t, dtype=np.float64) + + token_t -= t0 + token_t_last = np.max(token_t) + + logger.info("") + logger.info(f"Benchmark duration: {token_t_last:.2f} s") + logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min") + logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens") + logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens") + logger.info(f"Average prompt latency: {1e3 * np.mean(prompt_t):.2f} ms") + logger.info(f"Average prompt speed: {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s") + logger.info(f"Total generated tokens: {token_t.shape[0]}") + logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens") + logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s") + logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot") + logger.info("") + logger.info( + "The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, " + "particularly when the server is fast vs. the network or Python script (e.g. 
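Every number the benchmark prints is derived from two recorded quantities per request: the submit time and the arrival time of each streamed token. A minimal sketch of that derivation, assuming the (t_submit, token_arrival_times) tuples returned by send_prompt; the helper name and the fake timings are for illustration only.

import numpy as np

def summarize(results, t0):
    ttft = np.array([tat[0] - t_submit for t_submit, tat in results])  # time to first token
    token_t = np.concatenate([np.array(tat) - t0 for _, tat in results])
    duration = token_t.max()
    return {
        "duration_s": duration,
        "requests_per_s": len(results) / duration,
        "avg_ttft_ms": 1e3 * ttft.mean(),
        "gen_tokens_per_s": token_t.size / duration,
    }

# two fake requests, both submitted shortly after t0 = 0.0
print(summarize([(0.0, [0.5, 0.6, 0.7]), (0.1, [0.8, 0.9])], t0=0.0))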
when serving a very small model).") + + plt.figure() + plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25) + plt.xlim(0, 1.05e0 * np.max(prompt_n)) + plt.ylim(0, 1.05e3 * np.max(prompt_t)) + plt.xlabel("Prompt length [tokens]") + plt.ylabel("Time to first token [ms]") + plt.savefig("prompt_time.png", dpi=240) + + bin_max = np.ceil(token_t_last) + 1 + plt.figure() + plt.hist(token_t, np.arange(0, bin_max)) + plt.xlim(0, bin_max + 1) + plt.xlabel("Time [s]") + plt.ylabel("Num. tokens generated per second") + plt.savefig("gen_rate.png", dpi=240) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Tool for benchmarking the throughput of the llama.cpp HTTP server. " + "Results are printed to console and visualized as plots (saved to current working directory). " + "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).") + parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary") + parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the model to use for the benchmark") + parser.add_argument( + "--prompt_source", type=str, default="rng-1024-2048", + help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or " + "rng-MIN-MAX for synthetic prompts with random lengths in the interval [MIN, MAX]") + parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate") + parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt") + parser.add_argument( + "--n_predict_min", type=int, default=1024, + help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)") + args = parser.parse_args() + benchmark(**vars(args)) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index e63ab284bc3b5..9454d04e53801 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -34,6 +34,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_PLAMO2, "plamo2" }, { LLM_ARCH_CODESHELL, "codeshell" }, { LLM_ARCH_ORION, "orion" }, { LLM_ARCH_INTERNLM2, "internlm2" }, @@ -84,6 +85,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_LFM2, "lfm2" }, + { LLM_ARCH_DREAM, "dream" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -784,6 +786,36 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_PLAMO2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" 
}, + { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, + { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, + { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_CODESHELL, { @@ -1860,6 +1892,23 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, }, }, + { + LLM_ARCH_DREAM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, }; static const std::map LLM_TENSOR_INFOS = { @@ -2094,6 +2143,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { switch (arch) { case LLM_ARCH_JAMBA: case LLM_ARCH_FALCON_H1: + case LLM_ARCH_PLAMO2: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_LFM2: return true; @@ -2101,3 +2151,12 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { return false; } } + +bool llm_arch_is_diffusion(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_DREAM: + return true; + default: + return false; + } +} diff --git a/src/llama-arch.h b/src/llama-arch.h index 1f97325952411..0ead0d6cdb11b 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -38,6 +38,7 @@ enum llm_arch { LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, LLM_ARCH_PLAMO, + LLM_ARCH_PLAMO2, LLM_ARCH_CODESHELL, LLM_ARCH_ORION, LLM_ARCH_INTERNLM2, @@ -88,6 +89,7 @@ enum llm_arch { LLM_ARCH_HUNYUAN_MOE, LLM_ARCH_SMOLLM3, LLM_ARCH_LFM2, + LLM_ARCH_DREAM, LLM_ARCH_UNKNOWN, }; @@ -478,3 +480,4 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); bool llm_arch_is_recurrent(const llm_arch & arch); bool llm_arch_is_hybrid (const llm_arch & arch); +bool llm_arch_is_diffusion(const llm_arch & arch); diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 3bc8554e51ccf..f8227777f19de 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -27,6 +27,7 @@ bool llama_batch_allocr::init( const llama_vocab & vocab, const llama_memory_i * memory, uint32_t n_embd, + uint32_t n_seq_max, bool output_all) { clear(); @@ -40,6 +41,11 @@ bool llama_batch_allocr::init( // validate input batch // + if (n_seq_max > LLAMA_MAX_SEQ) { + LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ); + return false; + } + if (batch.token) { for (int32_t i = 0; i < batch.n_tokens; ++i) { if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) { @@ -52,8 +58,8 @@ bool llama_batch_allocr::init( if (batch.seq_id) { for (int32_t i = 0; i < batch.n_tokens; ++i) { for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) { - if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) { - LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ); + if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) { + LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max); return false; } } @@ -86,7 +92,7 @@ bool llama_batch_allocr::init( // initialize the starting position for each sequence based on the positions in 
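The llama_batch_allocr change above validates sequence ids against the runtime n_seq_max instead of the compile-time LLAMA_MAX_SEQ. Conceptually, every seq_id in the batch must fall in [0, n_seq_max); the tiny sketch below restates that check with a hypothetical helper, it is not llama.cpp API.

def validate_seq_ids(batch_seq_ids, n_seq_max):
    # batch_seq_ids: one list of sequence ids per token in the batch
    for i, seq_ids in enumerate(batch_seq_ids):
        for s in seq_ids:
            if s < 0 or s >= n_seq_max:
                print(f"invalid seq_id[{i}] = {s} (n_seq_max = {n_seq_max})")
                return False
    return True

assert validate_seq_ids([[0], [1, 2]], n_seq_max=4)
assert not validate_seq_ids([[0], [7]], n_seq_max=4)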
the memory llama_pos p0[LLAMA_MAX_SEQ]; - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { if (!memory) { // if no memory -> start from 0 p0[s] = 0; @@ -143,7 +149,8 @@ bool llama_batch_allocr::init( // compute stats // - this->n_embd = n_embd; + this->n_embd = n_embd; + this->n_seq_max = n_seq_max; // count the outputs in this batch for (int32_t i = 0; i < batch.n_tokens; ++i) { @@ -189,7 +196,7 @@ bool llama_batch_allocr::init( seq_set_map[cur].push_back(i); } - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { if (seq_set_unq.test(s)) { seq_idx[s] = seq_id_unq.size(); seq_id_unq.push_back(s); @@ -241,7 +248,7 @@ bool llama_batch_allocr::init( // consistency checks // - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { if (seq_pos[s].empty()) { continue; } @@ -284,8 +291,8 @@ bool llama_batch_allocr::init( } if (memory) { - for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) { - for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) { + for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) { + for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) { if (seq_cpl[s0][s1]) { if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) || memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) { @@ -316,12 +323,12 @@ bool llama_batch_allocr::init( // { seq_set_t cur_seq_set[LLAMA_MAX_SEQ]; - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { cur_seq_set[s].set(); } llama_pos cur_seq_pos[LLAMA_MAX_SEQ]; - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { cur_seq_pos[s] = -1; } @@ -692,7 +699,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u } } - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { if (seq_set_unq.test(s)) { ubatch.seq_idx[s] = ubatch.seq_id_unq.size(); ubatch.seq_id_unq.push_back(s); diff --git a/src/llama-batch.h b/src/llama-batch.h index 3420803ff9469..1a24440ba7562 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -48,6 +48,7 @@ class llama_batch_allocr { const llama_vocab & vocab, const llama_memory_i * memory, uint32_t n_embd, + uint32_t n_seq_max, bool output_all); const llama_batch & get_batch() const; @@ -100,6 +101,7 @@ class llama_batch_allocr { const uint32_t n_pos_per_embd; uint32_t n_embd; + uint32_t n_seq_max; uint32_t n_outputs; std::array seq_id_0 = { 0 }; // default sequence id diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index cbc19d3c40c30..240937eceee9d 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -65,6 +65,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, + { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -170,7 +171,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb // EXAONE-3.0-7.8B-Instruct return LLM_CHAT_TEMPLATE_EXAONE_3; - } else if (tmpl_contains("rwkv-world")) { + } else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) { return LLM_CHAT_TEMPLATE_RWKV_WORLD; } else if (tmpl_contains("<|start_of_role|>")) { return LLM_CHAT_TEMPLATE_GRANITE; @@ -188,6 +189,8 @@ llm_chat_template llm_chat_detect_template(const 
std::string & tmpl) { return LLM_CHAT_TEMPLATE_DOTS1; } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_MOE; + } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { + return LLM_CHAT_TEMPLATE_KIMI_K2; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -680,6 +683,26 @@ int32_t llm_chat_apply_template( ss << "<|startoftext|>" << message->content << "<|extra_0|>"; } } + } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) { + // moonshotai/Kimi-K2-Instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|im_system|>system<|im_middle|>"; + } else if (role == "user") { + ss << "<|im_user|>user<|im_middle|>"; + } else if (role == "assistant") { + ss << "<|im_assistant|>assistant<|im_middle|>"; + } else if (role == "tool") { + ss << "<|im_system|>tool<|im_middle|>"; + } + + ss << message->content << "<|im_end|>"; + + if (add_ass) { + ss << "<|im_assistant|>assistant<|im_middle|>"; + } + } } else { // template not supported return -1; diff --git a/src/llama-chat.h b/src/llama-chat.h index b621fda281669..cab0533485652 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -45,6 +45,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_SMOLVLM, LLM_CHAT_TEMPLATE_DOTS1, LLM_CHAT_TEMPLATE_HUNYUAN_MOE, + LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 06e93b19cbf40..840ec9a9aaca1 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -98,10 +98,20 @@ llama_context::llama_context( LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); cparams.n_batch = GGML_KQ_MASK_PAD; } - cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); cparams.op_offload = params.op_offload; + cparams.kv_unified = params.kv_unified; + + { + const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); + const bool supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0; + + if (!supports_set_rows && !cparams.kv_unified) { + LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__); + cparams.kv_unified = true; + } + } const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; @@ -112,6 +122,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); + LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -267,7 +278,7 @@ llama_context::llama_context( // reserve worst-case graph if (!hparams.vocab_only && memory) { - const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_seqs = cparams.kv_unified ? 
1 : cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); @@ -300,7 +311,7 @@ llama_context::llama_context( // reserve with tg graph to get the number of splits and nodes { - auto * gf = graph_reserve(1, 1, 1, mctx.get()); + auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get()); if (!gf) { throw std::runtime_error("failed to allocate compute tg buffers"); } @@ -311,6 +322,10 @@ llama_context::llama_context( // reserve again with pp graph to avoid ggml-alloc reallocations during inference { + // TODO: not sure if the following graph would be worster case for multi-stream KV caches: + // + // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); + // auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); @@ -475,7 +490,7 @@ bool llama_context::kv_self_update(bool optimize) { throw std::runtime_error("failed to initialize memory context"); } - const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); @@ -731,16 +746,19 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; - const int64_t n_embd = hparams.n_embd; + const int64_t n_embd = hparams.n_embd; + const int32_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 - if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) { + if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return -1; } const uint32_t n_tokens = balloc->get_n_tokens(); + // [TAG_NO_CACHE_PAD] + // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true const llama_ubatch ubatch = balloc->split_simple(n_tokens); // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot @@ -791,10 +809,20 @@ int llama_context::encode(const llama_batch & batch_inp) { } } + auto * t_logits = res->get_logits(); auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + // extract logits + if (logits && t_logits) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float)); + } + // extract embeddings - if (t_embd) { + if (embd && t_embd) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -899,7 +927,7 @@ int llama_context::decode(const llama_batch & batch_inp) { // when computing embeddings, all tokens are output const bool output_all = cparams.embeddings; - if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) { + if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? 
LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return -1; } @@ -2028,7 +2056,7 @@ void llama_context::opt_epoch_iter( batch.logits [pos_batch] = true; } - if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) { + if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return; } @@ -2187,6 +2215,7 @@ llama_context_params llama_context_default_params() { /*.no_perf =*/ true, /*.op_offload =*/ true, /*.swa_full =*/ true, + /*.kv_unified =*/ false, }; return result; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 118615d5bd2d5..38750affc500b 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -11,8 +11,8 @@ struct llama_cparams { uint32_t n_batch; uint32_t n_ubatch; uint32_t n_seq_max; - int n_threads; // number of threads to use for generation - int n_threads_batch; // number of threads to use for batch processing + int32_t n_threads; // number of threads to use for generation + int32_t n_threads_batch; // number of threads to use for batch processing float rope_freq_base; float rope_freq_scale; @@ -33,6 +33,7 @@ struct llama_cparams { bool no_perf; bool warmup; bool op_offload; + bool kv_unified; enum llama_pooling_type pooling_type; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index a248a7ec22350..1a6355e85d11e 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -982,13 +982,16 @@ ggml_tensor * llm_graph_context::build_attn_mha( float kq_scale) const { const bool v_trans = v->nb[1] > v->nb[2]; + // split the batch into streams if needed + const auto n_stream = k->ne[3]; + + q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream); + q = ggml_permute(ctx0, q, 0, 2, 1, 3); k = ggml_permute(ctx0, k, 0, 2, 1, 3); v = ggml_permute(ctx0, v, 0, 2, 1, 3); - const auto n_tokens = q->ne[1]; - const auto n_head = q->ne[2]; - const auto n_kv = k->ne[1]; + const auto n_kv = k->ne[1]; ggml_tensor * cur; @@ -1030,7 +1033,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( #endif } - cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); } else { ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); @@ -1075,7 +1078,8 @@ ggml_tensor * llm_graph_context::build_attn_mha( cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + // recombine streams + cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); if (!cparams.offload_kqv) { // all nodes between the KV store and the attention output are run on the CPU @@ -1122,6 +1126,10 @@ ggml_tensor * llm_graph_context::build_attn( const auto & kq_mask = inp->get_kq_mask(); + // [TAG_NO_CACHE_PAD] + // TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams + assert(ubatch.equal_seqs == false); + ggml_tensor * q = q_cur; ggml_tensor * k = k_cur; ggml_tensor * v = v_cur; @@ -1156,13 +1164,14 @@ static std::unique_ptr build_attn_inp_kv_unifie { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); - const auto n_kv = mctx_cur->get_n_kv(); + const auto n_kv = mctx_cur->get_n_kv(); const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 
1 : ubatch.n_seqs_unq; inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1362,13 +1371,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif auto inp = std::make_unique(hparams, cparams, mctx_cur); + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + { const auto n_kv = mctx_cur->get_base()->get_n_kv(); inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1382,7 +1393,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); ggml_set_input(inp->self_kq_mask_swa); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? 
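For a non-unified KV cache the graph inputs above are built per stream: the KQ mask becomes a 4-D tensor of shape [n_kv, pad(n_tokens/n_stream), 1, n_stream], with n_stream equal to 1 when kv_unified and to ubatch.n_seqs_unq otherwise. A quick sketch of that shape selection; mask_pad stands in for GGML_KQ_MASK_PAD and the example values are arbitrary.

def pad(x, p):
    # GGML_PAD: round x up to a multiple of p
    return ((x + p - 1) // p) * p

def kq_mask_shape(n_kv, n_tokens, n_seqs_unq, kv_unified, mask_pad):
    n_stream = 1 if kv_unified else n_seqs_unq
    return (n_kv, pad(n_tokens // n_stream, mask_pad), 1, n_stream)

# 4 sequences, 512 tokens, per-sequence (non-unified) cache
print(kq_mask_shape(n_kv=1024, n_tokens=512, n_seqs_unq=4, kv_unified=False, mask_pad=64))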
ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; diff --git a/src/llama-graph.h b/src/llama-graph.h index fbf8e2889564d..84a5b0b3f9c40 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -255,10 +255,10 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] - ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] const llama_hparams & hparams; const llama_cparams & cparams; @@ -289,14 +289,14 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] - ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch] - ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] const llama_hparams & hparams; const llama_cparams & cparams; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 7aa736e2f39db..c6c67d26f9392 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { return n_embd_head_v * n_head_kv; } +bool llama_hparams::is_n_embd_k_gqa_variable() const { + const uint32_t val = n_embd_k_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + if (val != n_embd_k_gqa(il)) { + return true; + } + } + + return false; +} + +bool llama_hparams::is_n_embd_v_gqa_variable() const { + const uint32_t val = n_embd_v_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + if (val != n_embd_v_gqa(il)) { + return true; + } + } + + return false; +} + +uint32_t llama_hparams::n_embd_k_gqa_max() const { + uint32_t val = n_embd_k_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + val = std::max(val, n_embd_k_gqa(il)); + } + + return val; +} + +uint32_t llama_hparams::n_embd_v_gqa_max() const { + uint32_t val = n_embd_v_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + val = std::max(val, n_embd_v_gqa(il)); + } + + return val; +} + uint32_t llama_hparams::n_embd_r() const { if (wkv_head_size != 0) { // for RWKV models diff --git a/src/llama-hparams.h b/src/llama-hparams.h index d0500e4d0fd77..c422cd7be827a 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -6,7 +6,7 @@ // bump 
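The new llama_hparams helpers above just scan the per-layer K/V widths, once to detect variation across layers and once to find the maximum. Restated in a few lines of Python with hypothetical names and made-up per-layer values:

def is_variable(per_layer):
    # true if any layer differs from the default (layer 0) value
    return any(v != per_layer[0] for v in per_layer)

def gqa_max(per_layer):
    return max(per_layer)

layers = [1024, 1024, 512, 1024]   # hypothetical n_embd_v_gqa per layer
print(is_variable(layers), gqa_max(layers))   # True 1024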
if necessary #define LLAMA_MAX_LAYERS 512 -#define LLAMA_MAX_EXPERTS 256 // DeepSeekV3 +#define LLAMA_MAX_EXPERTS 384 // Kimi-K2 enum llama_expert_gating_func_type { LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0, @@ -191,6 +191,14 @@ struct llama_hparams { // dimension of value embeddings across all k-v heads uint32_t n_embd_v_gqa(uint32_t il = 0) const; + // true if any layer has a different n_embd_k_gqa/n_embd_v_gqa + bool is_n_embd_k_gqa_variable() const; + bool is_n_embd_v_gqa_variable() const; + + // return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers + uint32_t n_embd_k_gqa_max() const; + uint32_t n_embd_v_gqa_max() const; + // dimension of the rolling state embeddings // corresponds to Mamba's conv_states size or RWKV's token_shift states size uint32_t n_embd_r() const; diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index fe207ad536032..01d27fb4db9b1 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -18,16 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( bool v_trans, bool offload, bool swa_full, + bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_ubatch, - uint32_t n_pad) : hparams(model.hparams) { + uint32_t n_pad) : hparams(model.hparams), unified(unified) { llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; const uint32_t size_base = kv_size; - uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad)); + uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad)); // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size if (swa_full) { @@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( kv_base = std::make_unique( model, std::move(filter_base), type_k, type_v, - v_trans, offload, size_base, n_seq_max, n_pad, + v_trans, offload, unified, size_base, n_seq_max, n_pad, 0, LLAMA_SWA_TYPE_NONE); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); kv_swa = std::make_unique( model, std::move(filter_swa), type_k, type_v, - v_trans, offload, size_swa, n_seq_max, n_pad, + v_trans, offload, unified, size_swa, n_seq_max, n_pad, hparams.n_swa, hparams.swa_type); } @@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all // first try simple split do { + if (!unified) { + // requires equal splits, so we skip the simple split + break; + } + balloc.split_reset(); std::vector ubatches; @@ -140,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all std::vector ubatches; while (true) { - auto ubatch = balloc.split_equal(n_ubatch, false); + auto ubatch = balloc.split_equal(n_ubatch, !unified); if (ubatch.n_tokens == 0) { break; diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h index 23205d826b23b..d2650dadd3595 100644 --- a/src/llama-kv-cache-unified-iswa.h +++ b/src/llama-kv-cache-unified-iswa.h @@ -20,6 +20,7 @@ class llama_kv_cache_unified_iswa : public llama_memory_i { bool v_trans, bool offload, bool swa_full, + bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_ubatch, @@ -68,6 +69,8 @@ class llama_kv_cache_unified_iswa : public llama_memory_i { private: const llama_hparams & hparams; + const bool unified; + 
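// A minimal standalone sketch (not part of the patch) of the SWA sizing rule introduced above:
// with a unified buffer (unified == true) the sliding window has to be reserved once for every
// sequence, while with per-sequence streams (unified == false) each stream holds a single
// sequence, so one window plus one ubatch suffices. pad_to() is a hypothetical stand-in for
// ggml's GGML_PAD rounding; all names here are illustrative, not part of the actual API.
#include <algorithm>
#include <cstdint>

static uint32_t pad_to(uint32_t x, uint32_t n) {
    return ((x + n - 1) / n) * n; // round x up to a multiple of n
}

static uint32_t swa_cache_size(uint32_t size_base, uint32_t n_swa, uint32_t n_seq_max,
                               uint32_t n_ubatch, uint32_t n_pad, bool unified, bool swa_full) {
    if (swa_full) {
        return size_base; // full-size SWA cache matches the base cache
    }
    const uint32_t n_win = unified ? n_seq_max : 1; // windows that must coexist in one buffer
    return std::min(size_base, pad_to(n_swa*n_win + n_ubatch, n_pad));
}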
std::unique_ptr kv_base; std::unique_ptr kv_swa; }; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index d3129cc53281e..7e92e6b4df9d4 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -23,13 +23,14 @@ llama_kv_cache_unified::llama_kv_cache_unified( ggml_type type_v, bool v_trans, bool offload, + bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), - n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { + n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { GGML_ASSERT(kv_size % n_pad == 0); @@ -45,7 +46,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - /*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -64,9 +65,33 @@ llama_kv_cache_unified::llama_kv_cache_unified( return it->second; }; - head = 0; + GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max); - cells.resize(kv_size); + v_heads.resize(n_stream); + for (uint32_t s = 0; s < n_stream; ++s) { + v_heads[s] = 0; + } + + v_cells.resize(n_stream); + for (uint32_t s = 0; s < n_stream; ++s) { + v_cells[s].resize(kv_size); + } + + // by default, all sequence ids are mapped to the 0th stream + seq_to_stream.resize(LLAMA_MAX_SEQ, 0); + + if (n_stream > 1) { + seq_to_stream.resize(n_stream, 0); + for (uint32_t s = 0; s < n_stream; ++s) { + seq_to_stream[s] = s; + } + } + + // [TAG_V_CACHE_VARIABLE] + if (v_trans && hparams.is_n_embd_v_gqa_variable()) { + LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n", + __func__, hparams.n_embd_v_gqa_max()); + } for (uint32_t il = 0; il < n_layer_cache; il++) { if (filter && !filter(il)) { @@ -74,8 +99,9 @@ llama_kv_cache_unified::llama_kv_cache_unified( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + // [TAG_V_CACHE_VARIABLE] + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const uint32_t n_embd_v_gqa = !v_trans ? 
hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max(); const char * dev_name = "CPU"; @@ -98,14 +124,23 @@ llama_kv_cache_unified::llama_kv_cache_unified( ggml_tensor * k; ggml_tensor * v; - k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size); - v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size); + k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); + v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); ggml_format_name(k, "cache_k_l%d", il); ggml_format_name(v, "cache_v_l%d", il); + std::vector k_stream; + std::vector v_stream; + + for (uint32_t s = 0; s < n_stream; ++s) { + k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2])); + v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2])); + } + map_layer_ids[il] = layers.size(); - layers.push_back({ il, k, v }); + + layers.push_back({ il, k, v, k_stream, v_stream, }); } // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE] @@ -148,8 +183,8 @@ llama_kv_cache_unified::llama_kv_cache_unified( const size_t memory_size_k = size_k_bytes(); const size_t memory_size_v = size_v_bytes(); - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream, ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } @@ -160,15 +195,21 @@ llama_kv_cache_unified::llama_kv_cache_unified( const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); supports_set_rows = LLAMA_SET_ROWS ? 
atoi(LLAMA_SET_ROWS) : 0; + if (!supports_set_rows) { + // ref: https://github.com/ggml-org/llama.cpp/pull/14363 + GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support"); + } + if (!supports_set_rows) { LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__); } } void llama_kv_cache_unified::clear(bool data) { - cells.reset(); - - head = 0; + for (uint32_t s = 0; s < n_stream; ++s) { + v_cells[s].reset(); + v_heads[s] = 0; + } if (data) { for (auto & buf : bufs) { @@ -178,6 +219,11 @@ void llama_kv_cache_unified::clear(bool data) { } bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[seq_id]]; + auto & head = v_heads[seq_to_stream[seq_id]]; + uint32_t new_head = cells.size(); if (p0 < 0) { @@ -224,30 +270,94 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos } void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - if (seq_id_src == seq_id_dst) { + GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size()); + GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size()); + + const auto s0 = seq_to_stream[seq_id_src]; + const auto s1 = seq_to_stream[seq_id_dst]; + + if (s0 == s1) { + // since both sequences are in the same stream, no data copy is necessary + // we just have to update the cells meta data + + auto & cells = v_cells[s0]; + + if (seq_id_src == seq_id_dst) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id_src)) { + cells.seq_add(i, seq_id_dst); + } + } + return; } - if (p0 < 0) { - p0 = 0; + // cross-stream sequence copies require to copy the actual buffer data + + bool is_full = true; + + if (p0 > 0 && p0 + 1 < (int) get_size()) { + is_full = false; } - if (p1 < 0) { - p1 = std::numeric_limits::max(); + if (p1 > 0 && p1 + 1 < (int) get_size()) { + is_full = false; } - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.pos_in(i, p0, p1)) { - continue; - } + GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers"); + + // enqueue the copy operation - the buffer copy will be performed during the next update + sc_info.ssrc.push_back(s0); + sc_info.sdst.push_back(s1); + + v_cells[s1].reset(); + for (uint32_t i = 0; i < v_cells[s0].size(); ++i) { + if (v_cells[s0].seq_has(i, seq_id_src)) { + llama_pos pos = v_cells[s0].pos_get(i); + llama_pos shift = v_cells[s0].get_shift(i); + + if (shift != 0) { + pos -= shift; + assert(pos >= 0); + } + + v_cells[s1].pos_set(i, pos); + v_cells[s1].seq_add(i, seq_id_dst); - if (cells.seq_has(i, seq_id_src)) { - cells.seq_add(i, seq_id_dst); + if (shift != 0) { + v_cells[s1].pos_add(i, shift); + } } } + + v_heads[s1] = v_heads[s0]; + + //for (uint32_t s = 0; s < n_stream; ++s) { + // LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s)); + //} } void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[seq_id]]; + auto & head = v_heads[seq_to_stream[seq_id]]; + uint32_t new_head = cells.size(); for (uint32_t i = 0; i 
< cells.size(); ++i) { @@ -265,6 +375,11 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { } void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[seq_id]]; + auto & head = v_heads[seq_to_stream[seq_id]]; + if (shift == 0) { return; } @@ -304,6 +419,10 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po } void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[seq_id]]; + if (d == 1) { return; } @@ -333,10 +452,18 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po } llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + const auto & cells = v_cells[seq_to_stream[seq_id]]; + return cells.seq_pos_min(seq_id); } llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + const auto & cells = v_cells[seq_to_stream[seq_id]]; + return cells.seq_pos_max(seq_id); } @@ -351,7 +478,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch( std::vector ubatches; while (true) { - auto ubatch = balloc.split_simple(n_ubatch); + auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true); if (ubatch.n_tokens == 0) { break; @@ -387,7 +514,10 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct defrag_info dinfo; // see if we need to defrag - { + if (n_stream == 1) { + // note : for now do not consider defrag for n_stream > 1 + const auto & cells = v_cells[seq_to_stream[0]]; + bool do_defrag = optimize; const auto thold = lctx->get_cparams().defrag_thold; @@ -411,22 +541,22 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct } } - return std::make_unique(this, lctx, do_shift, std::move(dinfo)); + return std::make_unique(this, lctx, do_shift, std::move(dinfo), std::move(sc_info)); } llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector & ubatches) { llama_kv_cache_unified::slot_info_vec_t res; - struct state { - uint32_t head_old; // old position of the head, before placing the ubatch - + struct state_t { slot_info sinfo; // slot info for the ubatch - llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch + std::vector v_heads_old; // old positions of the heads, before placing the ubatch + + std::vector v_cells; // copy of the old cells, before placing the ubatch }; // remember the old state of the cells so we can restore it in the end - std::vector states; + std::vector states; bool success = true; @@ -445,16 +575,35 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st res.push_back(sinfo_new); // store the old state of the cells in the recovery stack - states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)}); + { + state_t state = { sinfo_new, v_heads, {} }; + + for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) { + auto & cells = v_cells[sinfo_new.strm[s]]; + + state.v_cells.push_back(cells.cp(sinfo_new.idxs[s])); + } + + states.push_back(std::move(state)); + } // now emplace the ubatch apply_ubatch(sinfo_new, ubatch); } + 
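// The recovery stack above follows a simple checkpoint/rollback pattern: before each ubatch is
// placed, the per-stream heads and a copy of only the cells the slot touches are pushed, and
// afterwards everything is undone in reverse order, so prepare() leaves the cache untouched and
// only records the slot infos - the placements are re-applied later when the batch is actually
// processed. A hypothetical, self-contained sketch of the same idea follows (types and names
// are illustrative, not the actual llama.cpp ones).
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

struct snapshot {
    std::vector<uint32_t> heads_old; // per-stream head positions before the placement
    std::vector<int>      cells_old; // stand-in for the copied per-stream cell metadata
};

// tentatively place n_ubatches, then roll every placement back in reverse order;
// returns whether all placements succeeded
template <typename TryPlace, typename Undo>
static bool place_all_then_rollback(size_t n_ubatches, TryPlace try_place, Undo undo) {
    std::vector<snapshot> stack;

    bool success = true;
    for (size_t i = 0; i < n_ubatches; ++i) {
        snapshot s;
        if (!try_place(i, s)) { // on success, s holds the pre-placement state
            success = false;
            break;
        }
        stack.push_back(std::move(s));
    }

    for (auto it = stack.rbegin(); it != stack.rend(); ++it) {
        undo(*it); // restore the heads and cells captured in the snapshot
    }

    return success;
}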
GGML_ASSERT(!states.empty() || !success); + // iterate backwards and restore the cells to their original state for (auto it = states.rbegin(); it != states.rend(); ++it) { - cells.set(it->sinfo.idxs, it->cells); - head = it->head_old; + const auto & sinfo = it->sinfo; + + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + auto & cells = v_cells[sinfo.strm[s]]; + auto & head = v_heads[sinfo.strm[s]]; + + cells.set(sinfo.idxs[s], it->v_cells[s]); + head = it->v_heads_old[s]; + } } if (!success) { @@ -464,11 +613,38 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st return res; } -bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) { +bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) { bool updated = false; auto * sched = lctx->get_sched(); + if (!sc_info.empty()) { + assert(n_stream > 1 && "stream copy should never happen with a single stream"); + + llama_synchronize(lctx); + + const size_t n_copy = sc_info.ssrc.size(); + + for (size_t i = 0; i < n_copy; ++i) { + const auto ssrc = sc_info.ssrc[i]; + const auto sdst = sc_info.sdst[i]; + + assert(ssrc < n_stream); + assert(sdst < n_stream); + + LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst); + + assert(ssrc != sdst); + + for (uint32_t il = 0; il < layers.size(); ++il) { + const auto & layer = layers[il]; + + ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]); + ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]); + } + } + } + if (do_shift) { if (!get_can_shift()) { GGML_ABORT("The current KV cache / model configuration does not support K-shift"); @@ -503,12 +679,20 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d updated = true; } - cells.reset_shift(); + for (uint32_t s = 0; s < n_stream; ++s) { + auto & cells = v_cells[s]; + + cells.reset_shift(); + } } if (!dinfo.empty()) { LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); + // note: for now do not consider defrag for n_stream > 1 + auto & cells = v_cells[seq_to_stream[0]]; + auto & head = v_heads[seq_to_stream[0]]; + // apply moves: { const auto n_kv = dinfo.ids.size(); @@ -556,23 +740,13 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d } llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const { - const uint32_t n_tokens = ubatch.n_tokens; - - uint32_t head_cur = this->head; - - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (head_cur > cells.get_used() + 2*ubatch.n_tokens) { - head_cur = 0; - } + if (debug > 0) { + const auto & cells = v_cells[seq_to_stream[1]]; - if (n_tokens > cells.size()) { - LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); - return { }; - } + const uint32_t head_cur = v_heads[1]; - if (debug > 0) { - LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa); + LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", + __func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa); if ((debug == 2 && n_swa > 0) || debug > 2) { std::string ss; @@ -629,86 +803,133 @@ llama_kv_cache_unified::slot_info 
llama_kv_cache_unified::find_slot(const llama_ } } - uint32_t n_tested = 0; + uint32_t n_tokens = ubatch.n_tokens; + uint32_t n_seqs = 1; - // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head - // for non-continuous slots, we test the tokens one by one - const uint32_t n_test = cont ? n_tokens : 1; + if (n_stream > 1) { + GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0); - slot_info res; + n_seqs = ubatch.n_seqs_unq; + n_tokens = n_tokens / n_seqs; + } + + slot_info res = { + /*.s0 =*/ LLAMA_MAX_SEQ, + /*.s1 =*/ 0, + /*.strm =*/ { }, + /*.idxs =*/ { }, + }; + + res.resize(n_seqs); + + for (uint32_t s = 0; s < n_seqs; ++s) { + const auto seq_id = ubatch.seq_id_unq[s]; + + if (n_stream > 1) { + GGML_ASSERT(ubatch.n_seq_id[s*n_tokens] == 1); + GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id); + } + + res.s0 = std::min(res.s0, seq_to_stream[seq_id]); + res.s1 = std::max(res.s1, seq_to_stream[seq_id]); + + res.strm[s] = seq_to_stream[seq_id]; + res.idxs[s].reserve(n_tokens); - auto & idxs = res.idxs; + const auto & cells = v_cells[seq_to_stream[seq_id]]; - idxs.reserve(n_tokens); + uint32_t head_cur = v_heads[seq_to_stream[seq_id]]; - while (true) { - if (head_cur + n_test > cells.size()) { - n_tested += cells.size() - head_cur; + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head_cur > cells.get_used() + 2*n_tokens) { head_cur = 0; - continue; } - for (uint32_t i = 0; i < n_test; i++) { - const auto idx = head_cur; + if (n_tokens > cells.size()) { + LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); + return { }; + } + + uint32_t n_tested = 0; + + // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head + // for non-continuous slots, we test the tokens one by one + const uint32_t n_test = cont ? n_tokens : 1; + + while (true) { + if (head_cur + n_test > cells.size()) { + n_tested += cells.size() - head_cur; + head_cur = 0; + continue; + } + + for (uint32_t i = 0; i < n_test; i++) { + const auto idx = head_cur; + + head_cur++; + n_tested++; - //const llama_pos pos = ubatch.pos[i]; - //const llama_seq_id seq_id = ubatch.seq_id[i][0]; + //const llama_pos pos = ubatch.pos[i]; + //const llama_seq_id seq_id = ubatch.seq_id[i][0]; - // can we use this cell? either: - // - the cell is empty - // - the cell is occupied only by one sequence: - // - (disabled) mask causally, if the sequence is the same as the one we are inserting - // - mask SWA, using current max pos for that sequence in the cache - // always insert in the cell with minimum pos - bool can_use = cells.is_empty(idx); + // can we use this cell? 
either: + // - the cell is empty + // - the cell is occupied only by one sequence: + // - (disabled) mask causally, if the sequence is the same as the one we are inserting + // - mask SWA, using current max pos for that sequence in the cache + // always insert in the cell with minimum pos + bool can_use = cells.is_empty(idx); - if (!can_use && cells.seq_count(idx) == 1) { - const llama_pos pos_cell = cells.pos_get(idx); + if (!can_use && cells.seq_count(idx) == 1) { + const llama_pos pos_cell = cells.pos_get(idx); - // (disabled) causal mask - // note: it's better to purge any "future" tokens beforehand - //if (cells.seq_has(idx, seq_id)) { - // can_use = pos_cell >= pos; - //} + // (disabled) causal mask + // note: it's better to purge any "future" tokens beforehand + //if (cells.seq_has(idx, seq_id)) { + // can_use = pos_cell >= pos; + //} - if (!can_use) { - const llama_seq_id seq_id_cell = cells.seq_get(idx); + if (!can_use) { + const llama_seq_id seq_id_cell = cells.seq_get(idx); - // SWA mask - if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { - can_use = true; + // SWA mask + if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { + can_use = true; + } } } - } - head_cur++; - n_tested++; + if (can_use) { + res.idxs[s].push_back(idx); + } else { + if (cont) { + break; + } + } + } - if (can_use) { - idxs.push_back(idx); - } else { + if (res.idxs[s].size() == n_tokens) { break; } - } - if (idxs.size() == n_tokens) { - break; - } + if (cont) { + res.idxs[s].clear(); + } - if (cont) { - idxs.clear(); + if (n_tested >= cells.size()) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return { }; + } } - if (n_tested >= cells.size()) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + // we didn't find a suitable slot - return empty result + if (res.idxs[s].size() < n_tokens) { return { }; } } - // we didn't find a suitable slot - return empty result - if (idxs.size() < n_tokens) { - res.clear(); - } + assert(res.s1 >= res.s0); return res; } @@ -717,41 +938,51 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u // keep track of the max sequence position that we would overwrite with this ubatch // for non-SWA cache, this would be always empty llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; - for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { seq_pos_max_rm[s] = -1; } - assert(ubatch.n_tokens == sinfo.idxs.size()); + assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size()); - for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { - const auto idx = sinfo.idxs.at(i); + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + for (uint32_t ii = 0; ii < sinfo.size(); ++ii) { + const uint32_t i = s*sinfo.size() + ii; - if (!cells.is_empty(idx)) { - assert(cells.seq_count(idx) == 1); + auto & cells = v_cells[sinfo.strm[s]]; - const llama_seq_id seq_id = cells.seq_get(idx); - const llama_pos pos = cells.pos_get(idx); + const auto idx = sinfo.idxs[s][ii]; - seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); + if (!cells.is_empty(idx)) { + assert(cells.seq_count(idx) == 1); - cells.rm(idx); - } + const llama_seq_id seq_id = cells.seq_get(idx); + const llama_pos pos = cells.pos_get(idx); - cells.pos_set(idx, ubatch.pos[i]); + seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); - for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { - cells.seq_add(idx, ubatch.seq_id[i][s]); + cells.rm(idx); + } + + cells.pos_set(idx, 
ubatch.pos[i]); + + for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { + cells.seq_add(idx, ubatch.seq_id[i][s]); + } } } // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence // will be present in the cache. so we have to purge any position which is less than those we would overwrite // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 - for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { if (seq_pos_max_rm[s] == -1) { continue; } + GGML_ASSERT(s < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[s]]; + if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n", __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s); @@ -761,7 +992,11 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u } // move the head at the end of the slot - head = sinfo.idxs.back() + 1; + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + auto & head = v_heads[sinfo.strm[s]]; + + head = sinfo.idxs[s].back() + 1; + } } bool llama_kv_cache_unified::get_can_shift() const { @@ -769,49 +1004,87 @@ bool llama_kv_cache_unified::get_can_shift() const { } uint32_t llama_kv_cache_unified::get_size() const { + const auto & cells = v_cells[seq_to_stream[0]]; + return cells.size(); } +uint32_t llama_kv_cache_unified::get_n_stream() const { + return n_stream; +} + bool llama_kv_cache_unified::get_has_shift() const { - return cells.get_has_shift(); + bool result = false; + + for (uint32_t s = 0; s < n_stream; ++s) { + result |= v_cells[s].get_has_shift(); + } + + return result; } uint32_t llama_kv_cache_unified::get_n_kv() const { - return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); + uint32_t result = 0; + + for (uint32_t s = 0; s < n_stream; ++s) { + const auto & cells = v_cells[s]; + + result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result); + } + + return result; } -ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const { +ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { const int32_t ikv = map_layer_ids.at(il); auto * k = layers[ikv].k; - return ggml_view_3d(ctx, k, - hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, + const uint64_t kv_size = get_size(); + const uint64_t n_embd_k_gqa = k->ne[0]; + + assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il)); + + const uint32_t ns = sinfo.s1 - sinfo.s0 + 1; + + return ggml_view_4d(ctx, k, + hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns, ggml_row_size(k->type, hparams.n_embd_head_k), - ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), - 0); + ggml_row_size(k->type, n_embd_k_gqa), + ggml_row_size(k->type, n_embd_k_gqa*kv_size), + ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0); } -ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const { +ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { const int32_t ikv = map_layer_ids.at(il); auto * v = layers[ikv].v; + const uint64_t kv_size = get_size(); + const uint64_t n_embd_v_gqa = v->ne[0]; + + // [TAG_V_CACHE_VARIABLE] + assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il)); + + const uint32_t ns = sinfo.s1 - sinfo.s0 + 1; + if (!v_trans) { // note: v->nb[1] <= v->nb[2] - return 
ggml_view_3d(ctx, v, - hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, - ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2] - 0); + return ggml_view_4d(ctx, v, + hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns, + ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2] + ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3] + ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0); } // note: v->nb[1] > v->nb[2] - return ggml_view_3d(ctx, v, - n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, - ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, v->ne[1]), // v->nb[2] - 0); + return ggml_view_4d(ctx, v, + n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns, + ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, kv_size), // v->nb[2] + ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3] + ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0); } ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const { @@ -825,12 +1098,18 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_ k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens); if (k_idxs && supports_set_rows) { + if (k->ne[2] > 1) { + k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]); + } + return ggml_set_rows(ctx, k, k_cur, k_idxs); } // TODO: fallback to old ggml_cpy() method for backwards compatibility // will be removed when ggml_set_rows() is adopted by all backends + GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS"); + ggml_tensor * k_view = ggml_view_1d(ctx, k, n_tokens*n_embd_k_gqa, ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head()); @@ -843,37 +1122,38 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_ auto * v = layers[ikv].v; - const int64_t n_embd_v_gqa = v->ne[0]; - const int64_t n_tokens = v_cur->ne[2]; + const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1]; + const int64_t n_tokens = v_cur->ne[2]; v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens); if (v_idxs && supports_set_rows) { if (!v_trans) { + if (v->ne[2] > 1) { + v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]); + } + return ggml_set_rows(ctx, v, v_cur, v_idxs); } - // the row becomes a single element - ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]); + // [TAG_V_CACHE_VARIABLE] + if (n_embd_v_gqa < v->ne[0]) { + v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0); + } - // note: the V cache is transposed when not using flash attention - v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3); + // the row becomes a single element + ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]); - // note: we can be more explicit here at the cost of extra cont - // however, above we take advantage that a row of single element is always continuous regardless of the row stride - //v_cur = ggml_transpose(ctx, v_cur); - //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]); + v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]); - // we broadcast the KV indices n_embd_v_gqa times - // v [1, n_kv, n_embd_v_gqa] - // v_cur [1, n_tokens, n_embd_v_gqa] - // v_idxs [n_tokens, 1, 1] return ggml_set_rows(ctx, v_view, 
v_cur, v_idxs); } // TODO: fallback to old ggml_cpy() method for backwards compatibility // will be removed when ggml_set_rows() is adopted by all backends + GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS"); + ggml_tensor * v_view = nullptr; if (!v_trans) { @@ -904,7 +1184,13 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { const uint32_t n_tokens = ubatch.n_tokens; - ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens); + ggml_tensor * v_idxs; + + if (!v_trans) { + v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens); + } else { + v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max()); + } ggml_set_input(v_idxs); @@ -917,12 +1203,17 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba } const uint32_t n_tokens = ubatch->n_tokens; + GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); int64_t * data = (int64_t *) dst->data; - for (int64_t i = 0; i < n_tokens; ++i) { - data[i] = sinfo.idxs.at(i); + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + const int64_t offs = sinfo.strm[s]*get_size(); + + for (uint32_t i = 0; i < sinfo.size(); ++i) { + data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i]; + } } } @@ -932,12 +1223,48 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba } const uint32_t n_tokens = ubatch->n_tokens; + GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); int64_t * data = (int64_t *) dst->data; - for (int64_t i = 0; i < n_tokens; ++i) { - data[i] = sinfo.idxs.at(i); + if (!v_trans) { + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + const int64_t offs = sinfo.strm[s]*get_size(); + + for (uint32_t i = 0; i < sinfo.size(); ++i) { + data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i]; + } + } + } else { + // note: the V cache is transposed when not using flash attention + const int64_t kv_size = get_size(); + + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max(); + + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa; + + for (uint32_t i = 0; i < sinfo.size(); ++i) { + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i]; + } + } + } + } +} + +void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + int32_t * data = (int32_t *) dst->data; + + for (uint32_t s = 0; s < n_stream; ++s) { + const auto & cells = v_cells[s]; + + for (uint32_t i = 0; i < cells.size(); ++i) { + data[i] = cells.is_empty(i) ? 
0 : cells.get_shift(i); + } } } @@ -947,7 +1274,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); float * data = (float *) dst->data; - const int64_t n_kv = dst->ne[0]; + const int64_t n_kv = dst->ne[0]; + const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch + + GGML_ASSERT(n_tokens%n_stream == 0); + + // n_tps == n_tokens_per_stream + const int64_t n_tps = n_tokens/n_stream; + const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD); // Use only the previous KV cells of the correct sequence for each token of the ubatch. // It's assumed that if a token in the batch has multiple sequences, they are equivalent. @@ -962,67 +1296,66 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub // xxxxx----- // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 for (uint32_t h = 0; h < 1; ++h) { - for (uint32_t i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = ubatch->seq_id[i][0]; + for (uint32_t s = 0; s < n_stream; ++s) { + for (uint32_t ii = 0; ii < n_tps; ++ii) { + const uint32_t i = s*n_tps + ii; - const llama_pos p1 = ubatch->pos[i]; + const llama_seq_id seq_id = ubatch->seq_id[i][0]; - for (uint32_t j = 0; j < n_kv; ++j) { - float f = 0.0f; + const auto & cells = v_cells[seq_to_stream[seq_id]]; - bool masked = false; + const llama_pos p1 = ubatch->pos[i]; - if (cells.is_empty(j)) { - masked = true; - } else { - const llama_pos p0 = cells.pos_get(j); + for (uint32_t j = 0; j < n_kv; ++j) { + float f = 0.0f; - // mask the token if not the same sequence - masked = masked || (!cells.seq_has(j, seq_id)); + bool masked = false; - // mask future tokens - masked = masked || (causal_attn && p0 > p1); + if (cells.is_empty(j)) { + masked = true; + } else { + const llama_pos p0 = cells.pos_get(j); + + // mask the token if not the same sequence + masked = masked || (!cells.seq_has(j, seq_id)); + + // mask future tokens + masked = masked || (causal_attn && p0 > p1); - // apply SWA if any - masked = masked || (is_masked_swa(p0, p1)); + // apply SWA if any + masked = masked || (is_masked_swa(p0, p1)); - if (!masked && hparams.use_alibi) { - f = -std::abs(p0 - p1); + if (!masked && hparams.use_alibi) { + f = -std::abs(p0 - p1); + } } - } - if (masked) { - f = -INFINITY; - } + if (masked) { + f = -INFINITY; + } - data[h*(n_kv*n_tokens) + i*n_kv + j] = f; - } - } + data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = f; + } - // mask padded tokens - if (data) { - for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (uint32_t j = 0; j < n_kv; ++j) { - data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + // mask padded tokens + if (data) { + for (uint32_t ii = n_tps; ii < n_tps_pad; ++ii) { + for (uint32_t j = 0; j < n_kv; ++j) { + data[h*n_stream*n_tps_pad*n_kv + s*n_tps_pad*n_kv + ii*n_kv + j] = -INFINITY; + } + } } } } } } -void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - - int32_t * data = (int32_t *) dst->data; - - for (uint32_t i = 0; i < cells.size(); ++i) { - data[i] = cells.is_empty(i) ? 
0 : cells.get_shift(i); - } -} - void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { const int64_t n_tokens = ubatch->n_tokens; + GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams"); + const auto & cells = v_cells[0]; + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing @@ -1129,7 +1462,7 @@ class llm_graph_input_k_shift : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; - ggml_tensor * k_shift; // I32 [kv_size] + ggml_tensor * k_shift; // I32 [kv_size*n_stream] const llama_kv_cache_unified * kv_self; }; @@ -1153,7 +1486,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( auto inp = std::make_unique(this); - inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size()); + inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); ggml_set_input(inp->k_shift); for (const auto & layer : layers) { @@ -1169,7 +1502,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( ggml_tensor * k = ggml_view_3d(ctx, layer.k, - n_embd_head_k, n_head_kv, cells.size(), + n_embd_head_k, n_head_kv, get_size()*n_stream, ggml_row_size(layer.k->type, n_embd_head_k), ggml_row_size(layer.k->type, n_embd_k_gqa), 0); @@ -1191,6 +1524,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( const defrag_info & dinfo) const { auto res = std::make_unique(); + GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag"); + + const auto & cells = v_cells[0]; + const auto & ids = dinfo.ids; #if 0 @@ -1333,6 +1670,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( } llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const { + GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag"); + + const auto & cells = v_cells[0]; + const uint32_t n_layer = layers.size(); const uint32_t n_kv = cells.used_max_p1(); @@ -1478,64 +1819,94 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { } void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - std::vector> cell_ranges; // ranges, from inclusive, to exclusive - uint32_t cell_count = 0; + io.write(&n_stream, sizeof(n_stream)); - // Count the number of cells with the specified seq_id - // Find all the ranges of cells with this seq id (or all, when -1) - uint32_t cell_range_begin = cells.size(); + for (uint32_t s = 0; s < n_stream; ++s) { + cell_ranges_t cr { s, {} }; - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { - ++cell_count; - if (cell_range_begin == cells.size()) { - cell_range_begin = i; - } - } else { - if (cell_range_begin != cells.size()) { - cell_ranges.emplace_back(cell_range_begin, i); - cell_range_begin = cells.size(); + uint32_t cell_count = 0; + + const auto & cells = v_cells[s]; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = cells.size(); + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { + ++cell_count; + if (cell_range_begin == cells.size()) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != cells.size()) { + cr.data.emplace_back(cell_range_begin, i); + cell_range_begin = cells.size(); + } } } - 
} - if (cell_range_begin != cells.size()) { - cell_ranges.emplace_back(cell_range_begin, cells.size()); - } + if (cell_range_begin != cells.size()) { + cr.data.emplace_back(cell_range_begin, cells.size()); + } - // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count - uint32_t cell_count_check = 0; - for (const auto & range : cell_ranges) { - cell_count_check += range.second - range.first; - } - GGML_ASSERT(cell_count == cell_count_check); + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cr.data) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); - io.write(&cell_count, sizeof(cell_count)); + io.write(&cell_count, sizeof(cell_count)); - state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); + // skip empty streams + if (cell_count == 0) { + continue; + } + + state_write_meta(io, cr, seq_id); + state_write_data(io, cr); + } } void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) { - uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); + GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); - bool res = true; - res = res && state_read_meta(io, cell_count, seq_id); - res = res && state_read_data(io, cell_count); + uint32_t n_stream_cur; + io.read_to(&n_stream_cur, sizeof(n_stream_cur)); + if (n_stream_cur != n_stream) { + throw std::runtime_error("n_stream mismatch"); + } + + for (uint32_t s = 0; s < n_stream; ++s) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); + + if (cell_count == 0) { + continue; + } + + const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id]; - if (!res) { - if (seq_id == -1) { - clear(true); - } else { - seq_rm(seq_id, -1, -1); + bool res = true; + res = res && state_read_meta(io, strm, cell_count, seq_id); + res = res && state_read_data(io, strm, cell_count); + + if (!res) { + if (seq_id == -1) { + clear(true); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); } - throw std::runtime_error("failed to restore kv cache"); } } -void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { - for (const auto & range : cell_ranges) { +void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const { + const auto & cells = v_cells[cr.strm]; + + for (const auto & range : cr.data) { for (uint32_t i = range.first; i < range.second; ++i) { std::vector seq_ids; @@ -1560,7 +1931,9 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std:: } } -void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { +void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const { + const auto & cells = v_cells[cr.strm]; + const uint32_t v_trans = this->v_trans ? 
1 : 0; const uint32_t n_layer = layers.size(); @@ -1576,19 +1949,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + auto * k = layer.k_stream[cr.strm]; + // Write key type - const int32_t k_type_i = (int32_t)layer.k->type; + const int32_t k_type_i = (int32_t) k->type; io.write(&k_type_i, sizeof(k_type_i)); // Write row size of key - const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa); io.write(&k_size_row, sizeof(k_size_row)); // Read each range of cells of k_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { + for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * k_size_row; - io.write_tensor(layer.k, range.first * k_size_row, buf_size); + io.write_tensor(k, range.first * k_size_row, buf_size); } } @@ -1598,19 +1973,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + auto * v = layer.v_stream[cr.strm]; + // Write value type - const int32_t v_type_i = (int32_t)layer.v->type; + const int32_t v_type_i = (int32_t) v->type; io.write(&v_type_i, sizeof(v_type_i)); // Write row size of value - const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa); io.write(&v_size_row, sizeof(v_size_row)); // Read each range of cells of v_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { + for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * v_size_row; - io.write_tensor(layer.v, range.first * v_size_row, buf_size); + io.write_tensor(v, range.first * v_size_row, buf_size); } } } else { @@ -1622,12 +1999,14 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + auto * v = layer.v_stream[cr.strm]; + // Write value type - const int32_t v_type_i = (int32_t)layer.v->type; + const int32_t v_type_i = (int32_t) v->type; io.write(&v_type_i, sizeof(v_type_i)); // Write element size - const uint32_t v_size_el = ggml_type_size(layer.v->type); + const uint32_t v_size_el = ggml_type_size(v->type); io.write(&v_size_el, sizeof(v_size_el)); // Write GQA embedding size @@ -1636,27 +2015,31 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: // For each row, we get the element values of each cell for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { // Read each range of cells of v_size_el length each into tmp_buf and write out - for (const auto & range : cell_ranges) { + for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * kv_size) * v_size_el; const size_t buf_size = range_size * v_size_el; - io.write_tensor(layer.v, src_offset, buf_size); + io.write_tensor(v, src_offset, buf_size); } } } } } -bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { +bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) { + auto & cells = v_cells[strm]; + auto & head = v_heads[strm]; + if (dest_seq_id != -1) { // single sequence - seq_rm(dest_seq_id, -1, 
-1); llama_batch_allocr balloc(hparams.n_pos_per_embd()); llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1); + ubatch.seq_id_unq[0] = dest_seq_id; + for (uint32_t i = 0; i < cell_count; ++i) { llama_pos pos; uint32_t n_seq_id; @@ -1693,6 +2076,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell // keep the head at the old position because we will read the KV data into it in state_read_data() head = head_cur; + LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id); + // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values) // Assume that this is one contiguous block of cells GGML_ASSERT(head_cur + cell_count <= cells.size()); @@ -1738,7 +2123,10 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell return true; } -bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { +bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) { + auto & cells = v_cells[strm]; + auto & head = v_heads[strm]; + uint32_t v_trans; uint32_t n_layer; @@ -1766,10 +2154,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + auto * k = layer.k_stream[strm]; + // Read type of key int32_t k_type_i_ref; io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); - const int32_t k_type_i = (int32_t) layer.k->type; + const int32_t k_type_i = (int32_t) k->type; if (k_type_i != k_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); return false; @@ -1778,7 +2168,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of key uint64_t k_size_row_ref; io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa); if (k_size_row != k_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); return false; @@ -1786,7 +2176,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the keys for the whole cell range - ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); } } @@ -1796,10 +2186,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + auto * v = layer.v_stream[strm]; + // Read type of value int32_t v_type_i_ref; io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)layer.v->type; + const int32_t v_type_i = (int32_t) v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); return false; @@ -1808,7 +2200,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of value uint64_t v_size_row_ref; io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); - const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + const size_t v_size_row = 
ggml_row_size(v->type, n_embd_v_gqa); if (v_size_row != v_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); return false; @@ -1816,7 +2208,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the values for the whole cell range - ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); } } } else { @@ -1826,10 +2218,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + auto * v = layer.v_stream[strm]; + // Read type of value int32_t v_type_i_ref; io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)layer.v->type; + const int32_t v_type_i = (int32_t) v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); return false; @@ -1838,7 +2232,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // Read element size of value uint32_t v_size_el_ref; io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); - const size_t v_size_el = ggml_type_size(layer.v->type); + const size_t v_size_el = ggml_type_size(v->type); if (v_size_el != v_size_el_ref) { LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); return false; @@ -1856,7 +2250,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // For each row in the transposed matrix, read the values for the whole cell range for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { const size_t dst_offset = (head + j * cells.size()) * v_size_el; - ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); } } } @@ -1875,18 +2269,26 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { n_kv = kv->get_size(); + const uint32_t n_stream = kv->get_n_stream(); + // create a dummy slot info - the actual data is irrelevant. 
we just need to build the graph sinfos.resize(1); - sinfos[0].idxs.resize(1); - sinfos[0].idxs[0] = 0; + sinfos[0].s0 = 0; + sinfos[0].s1 = n_stream - 1; + sinfos[0].idxs.resize(n_stream); + for (uint32_t s = 0; s < n_stream; ++s) { + sinfos[0].strm.push_back(s); + sinfos[0].idxs[s].resize(1, 0); + } } llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv, llama_context * lctx, bool do_shift, - defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) { - if (!do_shift && this->dinfo.empty()) { + defrag_info dinfo, + stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) { + if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) { status = LLAMA_MEMORY_STATUS_NO_UPDATE; } } @@ -1914,7 +2316,7 @@ bool llama_kv_cache_unified_context::apply() { // no ubatches -> this is a KV cache update if (ubatches.empty()) { - kv->update(lctx, do_shift, dinfo); + kv->update(lctx, do_shift, dinfo, sc_info); return true; } @@ -1941,11 +2343,11 @@ uint32_t llama_kv_cache_unified_context::get_n_kv() const { } ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const { - return kv->get_k(ctx, il, n_kv); + return kv->get_k(ctx, il, n_kv, sinfos[i_cur]); } ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const { - return kv->get_v(ctx, il, n_kv); + return kv->get_v(ctx, il, n_kv, sinfos[i_cur]); } ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const { diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index b8b0356e830c8..3bfda4600d843 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -35,16 +35,50 @@ class llama_kv_cache_unified : public llama_memory_i { std::vector ids; }; + struct stream_copy_info { + bool empty() const { + assert(ssrc.size() == sdst.size()); + return ssrc.empty(); + } + + std::vector ssrc; + std::vector sdst; + }; + // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the // KV cells. 
for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]] struct slot_info { // data for ggml_set_rows using idx_vec_t = std::vector; - idx_vec_t idxs; + // number of streams: ns = s1 - s0 + 1 + llama_seq_id s0; + llama_seq_id s1; + + std::vector strm; // [ns] + std::vector idxs; // [ns] uint32_t head() const { - return idxs.at(0); + GGML_ASSERT(idxs.size() == 1); + GGML_ASSERT(!idxs[0].empty()); + + return idxs[0][0]; + } + + void resize(size_t n) { + strm.resize(n); + idxs.resize(n); + } + + size_t size() const { + GGML_ASSERT(idxs.size() == strm.size()); + GGML_ASSERT(!idxs.empty()); + + return idxs[0].size(); + } + + size_t n_stream() const { + return strm.size(); } bool empty() const { @@ -54,9 +88,6 @@ class llama_kv_cache_unified : public llama_memory_i { void clear() { idxs.clear(); } - - // TODO: implement - //std::vector seq_idxs; }; using slot_info_vec_t = std::vector; @@ -68,6 +99,7 @@ class llama_kv_cache_unified : public llama_memory_i { ggml_type type_v, bool v_trans, bool offload, + bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_pad, @@ -111,7 +143,8 @@ class llama_kv_cache_unified : public llama_memory_i { // llama_kv_cache_unified specific API // - uint32_t get_size() const; + uint32_t get_size() const; + uint32_t get_n_stream() const; bool get_has_shift() const; @@ -122,8 +155,8 @@ class llama_kv_cache_unified : public llama_memory_i { uint32_t get_n_kv() const; // get views of the current state of the cache - ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const; - ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const; + ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; // store k_cur and v_cur in the cache based on the provided head location ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const; @@ -137,7 +170,7 @@ class llama_kv_cache_unified : public llama_memory_i { // return empty vector on failure slot_info_vec_t prepare(const std::vector & ubatches); - bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo); + bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info); // find a slot of kv cells that can hold the ubatch // if cont == true, then the slot must be continuous @@ -157,8 +190,9 @@ class llama_kv_cache_unified : public llama_memory_i { void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; + void set_input_k_shift(ggml_tensor * dst) const; + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; - void set_input_k_shift (ggml_tensor * dst) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; private: @@ -172,15 +206,15 @@ class llama_kv_cache_unified : public llama_memory_i { ggml_tensor * k; ggml_tensor * v; + + std::vector k_stream; + std::vector v_stream; }; bool v_trans = true; // the value tensor is transposed - // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) - // note: this is not part of the KV state and it's only used to speed-up the find_slot() method - uint32_t head = 0; - const 
uint32_t n_seq_max = 1; + const uint32_t n_stream = 1; // required padding const uint32_t n_pad = 1; @@ -200,7 +234,17 @@ class llama_kv_cache_unified : public llama_memory_i { std::vector ctxs; std::vector bufs; - llama_kv_cells_unified cells; + // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) + // note: this is not part of the KV state and it's only used to speed-up the find_slot() method + std::vector v_heads; + + std::vector v_cells; + + // maps from a sequence id to a stream id + std::vector seq_to_stream; + + // pending stream copies that will be applied during the next update + stream_copy_info sc_info; std::vector layers; @@ -237,18 +281,25 @@ class llama_kv_cache_unified : public llama_memory_i { ggml_cgraph * gf, const defrag_info & dinfo) const; - void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; - void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; + struct cell_ranges_t { + uint32_t strm; - bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); - bool state_read_data(llama_io_read_i & io, uint32_t cell_count); + std::vector> data; // ranges, from inclusive, to exclusive + }; + + void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const; + + bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count); }; class llama_kv_cache_unified_context : public llama_memory_context_i { public: // some shorthands - using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; - using defrag_info = llama_kv_cache_unified::defrag_info; + using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; + using defrag_info = llama_kv_cache_unified::defrag_info; + using stream_copy_info = llama_kv_cache_unified::stream_copy_info; // used for errors llama_kv_cache_unified_context(llama_memory_status status); @@ -262,7 +313,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { llama_kv_cache_unified * kv, llama_context * lctx, bool do_shift, - defrag_info dinfo); + defrag_info dinfo, + stream_copy_info sc_info); // used to create a batch procesing context from a batch llama_kv_cache_unified_context( @@ -320,6 +372,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { defrag_info dinfo; + stream_copy_info sc_info; + // // batch processing context // diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 6cd10db06b775..d8e2086c87514 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -38,6 +38,7 @@ llama_memory_hybrid::llama_memory_hybrid( type_v, v_trans, offload, + 1, kv_size, n_seq_max, n_pad, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a322fc39352e7..9d8a686e0a571 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -849,6 +849,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_DREAM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // Dream models are primarily 7B with 28 layers + switch (hparams.n_layer) { + case 28: + type = LLM_TYPE_7B; + break; + default: + type = LLM_TYPE_UNKNOWN; + } + // Set non-causal attention 
for diffusion models + hparams.causal_attn = false; + } + break; case LLM_ARCH_QWEN2MOE: { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); @@ -935,6 +950,33 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_PLAMO2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Load Mamba SSM parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + } + + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_1B; break; + case 32: + if (hparams.n_embd == 2048) { + type = LLM_TYPE_2B; + } else if (hparams.n_embd == 4096) { + type = LLM_TYPE_8B; + } + break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_GPT2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -2643,12 +2685,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2VL: + case LLM_ARCH_DREAM: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (output == NULL) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); @@ -2938,6 +2982,73 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_PLAMO2: + { + const uint32_t d_conv = hparams.ssm_d_conv; + const uint32_t d_state = hparams.ssm_d_state; + const uint32_t num_heads = hparams.ssm_dt_rank; + const uint32_t intermediate_size = hparams.ssm_d_inner; + const uint32_t head_dim = intermediate_size / num_heads; + const uint32_t qk_dim = head_dim; + const uint32_t v_dim = head_dim; + const int64_t num_attention_heads = hparams.n_head(); + const int64_t q_num_heads = num_attention_heads; + const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + bool is_mamba_layer = hparams.is_recurrent(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (is_mamba_layer) { + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, 
"weight", i), {intermediate_size, dt_dim + 2*d_state}, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_dim, num_heads}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {num_heads}, 0); + + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {num_heads}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {num_heads}, 0); + + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {intermediate_size, n_embd}, 0); + + layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, i), {dt_dim}, 0); + layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); + layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); + } else { + const int64_t num_key_value_heads = hparams.n_head_kv(i); + const int64_t k_num_heads = num_key_value_heads; + const int64_t v_num_heads = num_key_value_heads; + const int64_t q_proj_dim = q_num_heads * qk_dim; + const int64_t k_proj_dim = k_num_heads * qk_dim; + const int64_t v_proj_dim = v_num_heads * v_dim; + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); + } + + // All layers have post-attention norm, FFN norm, and FFN tensors + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); + } + } break; case LLM_ARCH_GPT2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5209,6 +5320,7 @@ void llama_model::print_info() const { arch == LLM_ARCH_MAMBA2 || arch == LLM_ARCH_JAMBA || arch == LLM_ARCH_FALCON_H1 || + arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); @@ -7654,6 +7766,113 @@ struct llm_build_qwen2 : public llm_graph_context { // lm_head cur = build_lora_mm(model.output, cur); + if (model.output_b != nullptr) { + cur = ggml_add(ctx0, cur, model.output_b); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_dream : public llm_graph_context { + llm_build_dream(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : + llm_graph_context(params) { + //copied from qwen2 + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_no_cache(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, 
model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, + nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); res->t_logits = cur; @@ -15476,6 +15695,320 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba { } }; +struct llm_build_plamo2 : public llm_graph_context_mamba { + llm_build_plamo2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "embedding_output", -1); + + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_hybrid = build_inp_mem_hybrid(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * residual = inpL; + + // ggml_graph_add_node(gf, model.layers[il].attn_norm); + // cb(model.layers[il].attn_norm, "attn_norm", il); + + // pre_mixer_norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + + // check if this layer is Mamba or Attention + bool is_mamba_layer = hparams.is_recurrent(il); + + if (is_mamba_layer) { + // PLaMo-2 Mamba layer + cur = build_plamo2_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il); + } else { + // PLaMo-2 Attention layer + cur = build_plamo2_attn_layer(inp_hybrid->get_attn(), inp_pos, gf, cur, model, il); + } + + // post_mixer_norm + cur 
= build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + // residual connection + cur = ggml_add(ctx0, cur, residual); + cb(cur, "attn_residual", il); + residual = cur; + + // pre-ffn norm + cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_pre_norm", il); + + // feed-forward network + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + // post ffn norm + cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_post_norm", il); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + } + + // residual connection + cur = ggml_add(ctx0, cur, residual); + cb(cur, "ffn_residual", il); + + inpL = cur; + } + + cur = inpL; + + // final norm + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + + // Explicitly mark as output tensor to ensure proper backend assignment + ggml_set_output(cur); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + +private: + ggml_tensor * build_plamo2_attn_layer( + llm_graph_input_attn_kv_unified * inp, + ggml_tensor * inp_pos, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_model & model, + int il) { + + // self-attention + { + // PLaMo-2 uses combined QKV tensor + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "qkv", il); + + // split QKV tensor into Q, K, V + const int64_t n_embd_head_q = hparams.n_embd_head_k; + const int64_t n_embd_head_k = hparams.n_embd_head_k; + const int64_t n_embd_head_v = hparams.n_embd_head_v; + int32_t n_head_kv = hparams.n_head_kv(il); + + const int64_t q_offset = 0; + const int64_t k_offset = n_embd_head_q * n_head; + const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv; + + ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv)); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cur = build_attn(inp, gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f, il); + } + + cb(cur, "attn_out", il); + + return cur; + } + + ggml_tensor * build_plamo2_mamba_layer( + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * 
cur, + const llama_model & model, + const llama_ubatch & ubatch, + int il) { + + const auto * mctx_cur = inp->mctx; + + const auto kv_head = mctx_cur->get_head(); + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_heads = hparams.ssm_dt_rank; + const int64_t head_dim = d_inner / n_heads; + const int64_t n_group = hparams.ssm_n_group; + const int64_t n_seqs = ubatch.n_seqs; + + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); + conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} + ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur); + cb(zx, "mamba_in_proj", il); + // {8192, 5, 1, 1} -> {8192, 1, 5, 1} + zx = ggml_permute(ctx0, zx, 0, 2, 1, 3); + zx = ggml_cont(ctx0, zx); + zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs); + cb(zx, "mamba_in_proj_out", il); + + // split into z and x + // => {head_dim * n_heads, n_seq_tokens, n_seqs} + ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], head_dim*ggml_element_size(zx)); + x = ggml_cont(ctx0, x); + x = ggml_reshape_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs); + // x = ggml_permute(ctx0, x, 0, 2, 1, 3); + cb(x, "mamba_x_split", il); + + ggml_tensor * z = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], 0); + cb(z, "mamba_z_split", il); + + // conv1d + { + // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} + ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0); + cb(conv_x, "mamba_conv1d_input", il); + + // copy last (d_conv - 1) columns back into the state cache + ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, + conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv, + ggml_view_1d(ctx0, conv_states_all, + (d_conv - 1)*(d_inner)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + + // 1D convolution + x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + cb(x, "mamba_conv1d", il); + + x = ggml_silu(ctx0, x); + cb(x, "mamba_conv1d_silu", il); + } + + // SSM + { + // bcdt_proj: {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} + ggml_tensor * x_bcdt = build_lora_mm(model.layers[il].ssm_x, x); + cb(x_bcdt, "mamba_bcdt_proj", il); + + // split into dt, B, C + const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0); + ggml_tensor * C = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*d_state); + ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 
ggml_element_size(x_bcdt)*(2*d_state)); + cb(B, "mamba_B_raw", il); + cb(C, "mamba_C_raw", il); + cb(dt, "mamba_dt_raw", il); + + // Apply RMS norm to dt, B, C (PLaMo-2 specific) + B = build_norm(B, model.layers[il].ssm_b_norm, NULL, LLM_NORM_RMS, il); + C = build_norm(C, model.layers[il].ssm_c_norm, NULL, LLM_NORM_RMS, il); + dt = build_norm(dt, model.layers[il].ssm_dt_norm, NULL, LLM_NORM_RMS, il); + cb(B, "mamba_B_normed", il); + cb(C, "mamba_C_normed", il); + cb(dt, "mamba_dt_normed", il); + + // dt_proj: {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} + dt = build_lora_mm(model.layers[il].ssm_dt, dt); + dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); + cb(dt, "mamba_dt_proj", il); + + ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads); + cb(A, "mamba_A", il); + + x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0); + B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0); + C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0); + + // use the states and the indices provided by build_recurrent_state + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_heads, mctx_cur->get_size()); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. 
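+                // note (editorial sketch, assumed from the Mamba formulation; not taken from the patch):
+                // per head, this scan evaluates the selective-SSM recurrence roughly as
+                //     h_t = exp(dt_t * A) * h_{t-1} + dt_t * (B_t * x_t)
+                //     y_t = C_t . h_t
+                // associatively along the token dimension; the D skip connection and the
+                // SwiGLU gating with z are applied after the scan, further below.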
+ // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + cb(y_ssm, "mamba_ssm_scan", il); + + // store last states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]), + ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, + kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0); + cb(y, "mamba_y_view", il); + + // Add D parameter and apply gating with z + // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} + ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads); + y = ggml_add(ctx0, y, ggml_mul(ctx0, x, D)); + cb(y, "mamba_y_add_d", il); + + y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); + cb(y, "mamba_y_swiglu_z", il); + + // out_proj: {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + y = ggml_view_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs, y->nb[2], y->nb[3], 0); + cur = build_lora_mm(model.layers[il].ssm_out, y); + cb(cur, "mamba_out_proj", il); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + cb(cur, "mamba_out", il); + + return cur; + } +}; + struct llm_build_arcee : public llm_graph_context { llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -16078,6 +16611,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_NEO_BERT: case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_DREAM: { res = nullptr; } break; @@ -16118,7 +16652,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } else { const auto padding = llama_kv_cache_unified::get_padding(cparams); - cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + uint32_t n_ctx_per_stream = cparams.n_ctx; + + if (!cparams.kv_unified) { + n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max; + n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding); + + cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max; + } else { + n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding); + + cparams.n_ctx = n_ctx_per_stream; + } LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); @@ -16132,7 +16677,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, !cparams.flash_attn, cparams.offload_kqv, params.swa_full, - cparams.n_ctx, + cparams.kv_unified, + n_ctx_per_stream, cparams.n_seq_max, cparams.n_ubatch, padding); @@ -16146,7 +16692,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, params.type_v, !cparams.flash_attn, cparams.offload_kqv, - cparams.n_ctx, + cparams.kv_unified, + n_ctx_per_stream, cparams.n_seq_max, padding, hparams.n_swa, @@ -16229,6 +16776,11 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_DREAM: + { + llm = std::make_unique(*this, params, gf); + } + break; case 
LLM_ARCH_QWEN2VL: { llm = std::make_unique(*this, params, gf); @@ -16262,6 +16814,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_PLAMO2: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_GPT2: { llm = std::make_unique(*this, params, gf); @@ -16642,6 +17198,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_BITNET: case LLM_ARCH_QWEN: case LLM_ARCH_QWEN2: + case LLM_ARCH_DREAM: case LLM_ARCH_QWEN2MOE: case LLM_ARCH_QWEN3: case LLM_ARCH_QWEN3MOE: @@ -16651,6 +17208,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_PHI3: case LLM_ARCH_PHIMOE: case LLM_ARCH_PLAMO: + case LLM_ARCH_PLAMO2: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 4dbd1e309919a..a00af7a1d1758 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -884,8 +884,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) { if (qtype != new_type) { LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type)); - new_type = qtype; - break; // if two or more types are specified for the tensor, first match wins + new_type = qtype; // if two or more types are specified for the same tensor, the last match wins } } } diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index e0e578d6394d8..2181c01e31a87 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -404,6 +405,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_KIMI_K2: + regex_exprs = { + // K2 trigger pattern - this will activate the custom K2 handler in unicode.cpp + // The custom handler implements all K2 patterns with proper Han character exclusion + "\\p{Han}+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_SUPERBPE: regex_exprs = { "\\p{N}+", @@ -1196,6 +1204,284 @@ struct llm_tokenizer_rwkv_session { const llm_tokenizer_rwkv & tokenizer; }; +struct llm_tokenizer_plamo2 : llm_tokenizer { + llm_tokenizer_plamo2(const llama_vocab & vocab) { + build(vocab); + } + + void build(const llama_vocab & vocab) { + // Reset internal structures + tokens_.clear(); + bytes_.assign(256, 0); + to_suffix_id_.clear(); + table_.clear(); + + // Build token list and byte mapping + std::unordered_map suffix_to_score; + std::unordered_map token_to_id; + + for (size_t token_id = 0; token_id < vocab.n_tokens(); ++token_id) { + const auto & entry = vocab.get_token_data(token_id); + tokens_.push_back(entry.text); + token_to_id[entry.text] = static_cast(token_id); + + // Handle byte tokens + if (vocab.is_byte(token_id)) { + if (entry.text.length() == 6 && entry.text.substr(0, 3) == "<0x" && entry.text.back() == '>') { + std::string hex_str = entry.text.substr(3, 2); + int byte_val = std::stoi(hex_str, nullptr, 16); + bytes_[byte_val] = static_cast(token_id); + } + continue; + } + + // Add token and all its suffixes to suffix_to_score + suffix_to_score[entry.text] = entry.score; + + // Extract suffixes 
character by character (UTF-8 aware) + std::vector cpts = unicode_cpts_from_utf8(entry.text); + for (size_t i = 1; i < cpts.size(); ++i) { + std::string suffix; + for (size_t j = i; j < cpts.size(); ++j) { + suffix += unicode_cpt_to_utf8(cpts[j]); + } + if (suffix_to_score.find(suffix) == suffix_to_score.end()) { + suffix_to_score[suffix] = std::numeric_limits::quiet_NaN(); + } + } + } + + // Check that all byte tokens are set + for (int i = 0; i < 256; ++i) { + if (bytes_[i] == 0) { + throw std::runtime_error("Byte token for <0x" + std::to_string(i) + "> is not set"); + } + } + + // Build suffix list in lexicographical order of reversed strings + std::vector suffixes; + for (const auto & pair : suffix_to_score) { + suffixes.push_back(pair.first); + } + suffixes.push_back(""); // Empty suffix + + std::sort(suffixes.begin(), suffixes.end(), [](const std::string & a, const std::string & b) { + std::string rev_a(a.rbegin(), a.rend()); + std::string rev_b(b.rbegin(), b.rend()); + return rev_a < rev_b; + }); + + // Build suffix_to_id and to_suffix_id_ + std::unordered_map suffix_to_id; + int32_t num_pieces = 0; + + for (const auto & suffix : suffixes) { + suffix_to_id[suffix] = num_pieces; + if (!suffix.empty()) { + std::vector cpts = unicode_cpts_from_utf8(suffix); + + std::string remaining; + for (size_t i = 1; i < cpts.size(); ++i) { + remaining += unicode_cpt_to_utf8(cpts[i]); + } + + int64_t piece_code = (static_cast(cpts[0]) << 32) | suffix_to_id[remaining]; + to_suffix_id_[piece_code] = num_pieces; + + // Count number of pieces for this suffix + int32_t pieces_for_suffix = 1; // sentinel row + for (int32_t piece_length = static_cast(cpts.size()); piece_length > 0; --piece_length) { + std::string piece; + for (int32_t i = 0; i < piece_length; ++i) { + piece += unicode_cpt_to_utf8(cpts[i]); + } + if (suffix_to_score.find(piece) != suffix_to_score.end()) { + pieces_for_suffix++; + } + } + num_pieces += pieces_for_suffix; + } else { + num_pieces++; // Empty suffix contributes one piece (sentinel row) + } + } + + // Build flattened table + table_.resize(num_pieces, std::vector(4, 0)); + int32_t table_idx = 0; + + for (const auto & suffix : suffixes) { + // Add all prefixes of the suffix to the table (in decreasing order of length) + std::vector cpts = unicode_cpts_from_utf8(suffix); + for (int32_t piece_length = static_cast(cpts.size()); piece_length > 0; --piece_length) { + std::string piece; + for (int32_t i = 0; i < piece_length; ++i) { + piece += unicode_cpt_to_utf8(cpts[i]); + } + + auto score_it = suffix_to_score.find(piece); + if (score_it == suffix_to_score.end()) { + continue; + } + + table_[table_idx][TABLE_PIECE_LENGTH] = piece_length; + auto token_it = token_to_id.find(piece); + table_[table_idx][TABLE_TOKEN_ID] = (token_it != token_to_id.end()) ? token_it->second : -1; + + float score = score_it->second; + table_[table_idx][TABLE_SCORE] = std::isfinite(score) ? 
+ static_cast(std::round(score * 1e4)) : INVALID_SCORE; + table_[table_idx][TABLE_PIECE_ID] = suffix_to_id[piece]; + + table_idx++; + } + + // Add sentinel row + table_[table_idx][TABLE_PIECE_LENGTH] = 1; + table_[table_idx][TABLE_TOKEN_ID] = -1; + table_[table_idx][TABLE_SCORE] = UNKNOWN_SCORE; + table_idx++; + } + } + + std::vector encode(const std::string & text) const { + std::vector unicode_data = unicode_cpts_from_utf8(text); + // Skip the first code point if it is a BOM (Byte Order Mark) + if (!unicode_data.empty() && unicode_data[0] == 0xFEFF) { + unicode_data.erase(unicode_data.begin()); + } + + if (unicode_data.empty()) { + return {}; + } + + const size_t data_len = unicode_data.size(); + + // Initialize scores array (dynamic programming) + std::vector scores(data_len + 1, static_cast(1) << 60); + scores[data_len] = 0; + + // Path array to track best tokenization + std::vector> path(data_len + 1, std::vector(3, 0)); + + int32_t suffix_id = 0; + + // Process from end to beginning + for (int i = static_cast(data_len) - 1; i >= 0; --i) { + uint32_t c = unicode_data[i]; + + // Find next suffix ID + for (size_t p = suffix_id; p < table_.size(); ++p) { + int64_t piece_code = (static_cast(c) << 32) | table_[p][TABLE_PIECE_ID]; + auto it = to_suffix_id_.find(piece_code); + suffix_id = (it != to_suffix_id_.end()) ? it->second : 0; + + if (suffix_id > 0 || table_[p][TABLE_SCORE] == UNKNOWN_SCORE) { + break; + } + } + + // Update best path + for (size_t p = suffix_id; p < table_.size(); ++p) { + int32_t score = table_[p][TABLE_SCORE]; + if (score > INVALID_SCORE) { + int32_t piece_length = table_[p][TABLE_PIECE_LENGTH]; + int64_t s = scores[i + piece_length] - score; + + if (s < scores[i]) { + scores[i] = s; + path[i][PATH_TOKEN_LENGTH] = piece_length; + path[i][PATH_TOKEN_ID] = table_[p][TABLE_TOKEN_ID]; + path[i][PATH_NUM_TOKENS] = path[i + piece_length][PATH_NUM_TOKENS] + 1; + + if (score == UNKNOWN_SCORE) { + // Add UTF-8 byte count + path[i][PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000); + } + } + } + + if (score == UNKNOWN_SCORE) { + break; + } + } + } + + // Decode the best path + std::vector token_ids; + token_ids.reserve(path[0][PATH_NUM_TOKENS]); + + int pos = 0; + while (pos < static_cast(data_len)) { + if (path[pos][PATH_TOKEN_ID] >= 0) { + token_ids.push_back(path[pos][PATH_TOKEN_ID]); + } else { + // Fall back to byte tokens + uint32_t c = unicode_data[pos]; + int s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000); + + for (int i = 0; i < s; ++i) { + uint8_t b; + if (s == 1) { + b = c; + } else { + if (i == 0) { + b = (0xF00 >> s) & 0xFF; + } else { + b = 0x80; + } + } + token_ids.push_back(bytes_[b | ((c >> ((s - i - 1) * 6)) & 0x3F)]); + } + } + + assert(path[pos][PATH_TOKEN_LENGTH] > 0); + pos += path[pos][PATH_TOKEN_LENGTH]; + } + + return token_ids; + } +private: + // Constants for table structure + static constexpr int32_t TABLE_PIECE_LENGTH = 0; + static constexpr int32_t TABLE_TOKEN_ID = 1; + static constexpr int32_t TABLE_SCORE = 2; + static constexpr int32_t TABLE_PIECE_ID = 3; + + // Constants for path array + static constexpr int32_t PATH_TOKEN_LENGTH = 0; + static constexpr int32_t PATH_TOKEN_ID = 1; + static constexpr int32_t PATH_NUM_TOKENS = 2; + + // Score constants + static constexpr int32_t INVALID_SCORE = -20000000; + static constexpr int32_t UNKNOWN_SCORE = -10000000; + + // List of tokens in the vocabulary + std::vector tokens_; + + // Mapping from byte code point to token ID (for byte fallback) + std::vector bytes_; + + // Mapping from 
piece code to suffix ID + std::unordered_map to_suffix_id_; + + // Flattened table representing the Trie structure + // Each row contains: [piece_length, token_id, score, piece_id] + std::vector> table_; +}; + +struct llm_tokenizer_plamo2_session { + llm_tokenizer_plamo2_session(const llm_tokenizer_plamo2 & tokenizer) : tokenizer(tokenizer) {} + + void tokenize(const std::string & text, std::vector & output) { + std::vector tokens = tokenizer.encode(text); + output.insert(output.end(), tokens.begin(), tokens.end()); + } + +private: + const llm_tokenizer_plamo2 & tokenizer; +}; + // // impl // @@ -1499,6 +1785,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_unk_id = LLAMA_TOKEN_NULL; special_sep_id = LLAMA_TOKEN_NULL; special_pad_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "plamo2") { + type = LLAMA_VOCAB_TYPE_PLAMO2; + + // PLaMo-2 default special tokens (these will be overridden by model config) + special_bos_id = 1; // <|plamo:bos|> + special_eos_id = 2; // <|plamo:eos|> + special_unk_id = 0; // <|plamo:unk|> + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = 3; // <|plamo:pad|> + special_mask_id = LLAMA_TOKEN_NULL; } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } @@ -1665,6 +1961,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "hunyuan") { pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN; clean_spaces = false; + } else if ( + tokenizer_pre == "kimi-k2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2; + clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -2145,13 +2445,14 @@ enum llama_vocab_type llama_vocab::impl::get_type() const { std::string llama_vocab::impl::type_name() const{ switch (type) { - case LLAMA_VOCAB_TYPE_NONE: return "no vocab"; - case LLAMA_VOCAB_TYPE_SPM: return "SPM"; - case LLAMA_VOCAB_TYPE_BPE: return "BPE"; - case LLAMA_VOCAB_TYPE_WPM: return "WPM"; - case LLAMA_VOCAB_TYPE_UGM: return "UGM"; - case LLAMA_VOCAB_TYPE_RWKV: return "RWKV"; - default: return "unknown"; + case LLAMA_VOCAB_TYPE_NONE: return "no vocab"; + case LLAMA_VOCAB_TYPE_SPM: return "SPM"; + case LLAMA_VOCAB_TYPE_BPE: return "BPE"; + case LLAMA_VOCAB_TYPE_WPM: return "WPM"; + case LLAMA_VOCAB_TYPE_UGM: return "UGM"; + case LLAMA_VOCAB_TYPE_RWKV: return "RWKV"; + case LLAMA_VOCAB_TYPE_PLAMO2: return "PLaMo2"; + default: return "unknown"; } } @@ -2234,6 +2535,9 @@ void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) { case LLAMA_VOCAB_TYPE_RWKV: tokenizer = std::make_unique(vocab); break; + case LLAMA_VOCAB_TYPE_PLAMO2: + tokenizer = std::make_unique(vocab); + break; default: GGML_ABORT("unsupported vocab type"); } @@ -2566,6 +2870,23 @@ std::vector llama_vocab::impl::tokenize( if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { std::string text = fragment.raw_text.substr(fragment.offset, fragment.length); +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); +#endif + + session.tokenize(text, output); + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + output.push_back(fragment.token); + } + } + } break; + case LLAMA_VOCAB_TYPE_PLAMO2: + { + llm_tokenizer_plamo2_session session(*static_cast(tokenizer.get())); + for (const auto & fragment : fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + std::string text = 
fragment.raw_text.substr(fragment.offset, fragment.length); + #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); #endif @@ -2664,6 +2985,24 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t memcpy(buf, result.data(), result.size()); return (int)result.size(); } + case LLAMA_VOCAB_TYPE_PLAMO2: { + // PLaMo-2 uses similar token handling as BPE/SPM + if (vocab.is_byte(token)) { + // Handle byte tokens like <0xXX> + if (token_text.length() == 6 && token_text.substr(0, 3) == "<0x" && token_text.back() == '>') { + int hex_val = std::stoi(token_text.substr(3, 2), nullptr, 16); + if (length < 1) { + return -1; + } + buf[0] = static_cast(hex_val); + return 1; + } + } + + // Normal token - just copy the text + std::string result = token_text; + return _try_copy(result.data(), result.size()); + } default: GGML_ABORT("fatal error"); } @@ -2908,6 +3247,12 @@ llama_token llama_vocab::byte_to_token(uint8_t ch) const { case LLAMA_VOCAB_TYPE_BPE: { return pimpl->token_to_id.at(unicode_byte_to_utf8(ch)); } + case LLAMA_VOCAB_TYPE_PLAMO2: { + // PLaMo-2 uses byte tokens in format <0xXX> + char hex_str[8]; + snprintf(hex_str, sizeof(hex_str), "<0x%02X>", ch); + return pimpl->token_to_id.at(hex_str); + } default: GGML_ABORT("fatal error"); } @@ -3009,6 +3354,10 @@ llama_token llama_vocab::token_fim_sep() const { return pimpl->special_fim_sep_id; } +llama_token llama_vocab::token_mask() const { + return pimpl->special_mask_id; +} + bool llama_vocab::get_add_space_prefix() const { return pimpl->add_space_prefix; } @@ -3249,6 +3598,10 @@ llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) { return vocab->token_fim_sep(); } +llama_token llama_vocab_mask(const struct llama_vocab* vocab) { + return vocab->token_mask(); +} + // deprecated const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) { return llama_vocab_get_text(vocab, token); @@ -3385,4 +3738,3 @@ int32_t llama_detokenize( bool unparse_special) { return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special); } - diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 46a1ccecb51fc..842b129e86171 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -45,6 +45,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, + LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, }; struct LLM_KV; @@ -100,6 +101,7 @@ struct llama_vocab { llama_token token_sep() const; llama_token token_nl () const; llama_token token_pad() const; + llama_token token_mask() const; llama_token token_prefix() const; llama_token token_middle() const; diff --git a/src/unicode.cpp b/src/unicode.cpp index 43a4581b961fe..65f3665171582 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -557,6 +557,178 @@ static std::vector unicode_regex_split_stl(const std::string & text, con return bpe_offsets; } +// K2 system regex patterns (from tokenization_kimi.py): +// [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+ +static std::vector unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; + 
bpe_offsets.reserve(offsets.size()); + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; + auto _get_cpt = [&] (const size_t pos) -> uint32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; + }; + + auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{}; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&] (const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { + bpe_offsets.push_back(len); + } + _prev_end = end; + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { + const uint32_t cpt = _get_cpt(pos); + const auto flags = _get_flags(pos); + + // Pattern 1: [\p{Han}]+ (Chinese characters) + if (unicode_cpt_is_han(cpt)) { + while (unicode_cpt_is_han(_get_cpt(pos))) { + pos++; + } + _add_token(pos); + continue; + } + + // Pattern 2 & 3: Letter words excluding Han characters with optional contractions + // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?:'s|'t|'re|'ve|'m|'ll|'d)? + // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?:'s|'t|'re|'ve|'m|'ll|'d)? + // Check if current char is a letter OR if current char could be a leading char and next char is a letter + bool is_letter_pattern = (flags.is_letter && !unicode_cpt_is_han(cpt)) || + (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number) && + _get_flags(pos + 1).is_letter && !unicode_cpt_is_han(_get_cpt(pos + 1))); + + if (is_letter_pattern) { + // Handle optional leading non-letter/non-number character + bool has_leading_char = false; + if (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number)) { + has_leading_char = true; + pos++; + } + + // Match letter sequence (excluding Han characters) + bool has_letters = false; + while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) { + has_letters = true; + pos++; + } + + // Only proceed if we found letters (after potentially skipping leading char) + if (has_letters || (!has_leading_char && _get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos)))) { + if (!has_letters) pos++; // consume the first letter if we didn't already + + // Continue consuming letters + while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) { + pos++; + } + + // Check for optional contractions (?:'s|'t|'re|'ve|'m|'ll|'d) + if (_get_cpt(pos) == '\'' && pos + 1 < offset_end) { + uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1)); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += 2; + } else if (pos + 2 < offset_end) { + uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2)); + if ((cpt_next == 'r' && cpt_next_next == 'e') || + (cpt_next == 'v' && cpt_next_next == 'e') || + (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += 3; + } + } + } + + _add_token(pos); + continue; + } else if (has_leading_char) { + // We consumed a leading char but found no letters, backtrack + pos--; + } + } + + // Pattern 4: \p{N}{1,3} (numbers 1-3 digits) + if (flags.is_number) { + size_t ini = pos; + while 
(_get_flags(pos).is_number) { + if (++pos - ini >= 3) { + _add_token(pos); + ini = pos; + } + } + _add_token(pos); + continue; + } + + // Pattern 5: ?[^\s\p{L}\p{N}]+[\r\n]* (optional space + non-word chars + optional newlines) + auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags); + if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) { + pos += (cpt == ' '); + while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) { + flags2 = _get_flags(++pos); + } + // Match optional [\r\n]* + uint32_t cpt2 = _get_cpt(pos); + while (cpt2 == '\r' || cpt2 == '\n') { + cpt2 = _get_cpt(++pos); + } + _add_token(pos); + continue; + } + + // Count whitespace characters + size_t num_whitespaces = 0; + size_t last_end_r_or_n = 0; + while (_get_flags(pos + num_whitespaces).is_whitespace) { + uint32_t cpt2 = _get_cpt(pos + num_whitespaces); + if (cpt2 == '\r' || cpt2 == '\n') { + last_end_r_or_n = pos + num_whitespaces + 1; + } + num_whitespaces++; + } + + // Pattern 6: \s*[\r\n]+ (whitespace with newlines) + if (last_end_r_or_n > 0) { + pos = last_end_r_or_n; + _add_token(pos); + continue; + } + + // Pattern 7: \s+(?!\S) (trailing whitespace) + if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // Pattern 8: \s+ (general whitespace) + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // No matches - consume single character + _add_token(++pos); + } + } + + return bpe_offsets; +} + static std::vector unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector & offsets) { std::vector bpe_offsets; @@ -567,6 +739,9 @@ static std::vector unicode_regex_split_custom(const std::string & text, regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { bpe_offsets = unicode_regex_split_custom_llama3(text, offsets); + } else if (regex_expr == "\\p{Han}+") { + // K2's first pattern - handle all K2 patterns together + bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets); } return bpe_offsets; @@ -672,6 +847,38 @@ uint32_t unicode_tolower(uint32_t cpt) { return cpt; // Return the original code point if no lowercase mapping is found } +bool unicode_cpt_is_han(uint32_t cpt) { + // Han character ranges (Chinese/CJK characters) + // CJK Unified Ideographs (most common) + if (cpt >= 0x4E00 && cpt <= 0x9FFF) return true; + + // CJK Extension A + if (cpt >= 0x3400 && cpt <= 0x4DBF) return true; + + // CJK Extension B + if (cpt >= 0x20000 && cpt <= 0x2A6DF) return true; + + // CJK Extension C + if (cpt >= 0x2A700 && cpt <= 0x2B73F) return true; + + // CJK Extension D + if (cpt >= 0x2B740 && cpt <= 0x2B81F) return true; + + // CJK Extension E + if (cpt >= 0x2B820 && cpt <= 0x2CEAF) return true; + + // CJK Extension F + if (cpt >= 0x2CEB0 && cpt <= 0x2EBEF) return true; + + // CJK Compatibility Ideographs + if (cpt >= 0xF900 && cpt <= 0xFAFF) return true; + + // CJK Compatibility Ideographs Supplement + if (cpt >= 0x2F800 && cpt <= 0x2FA1F) return true; + + return false; +} + std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) { // unicode categories static const std::map k_ucat_enum = { diff --git a/src/unicode.h b/src/unicode.h index c27098df7d4be..0a5fa2a78ceff 100644 --- a/src/unicode.h +++ b/src/unicode.h 
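Before moving on to the unicode.h declaration below, a quick sketch of how the pieces above fit together: the KIMI_K2 vocab registers only the trigger expression "\\p{Han}+", unicode_regex_split_custom() recognizes that exact string and hands the whole chunk to unicode_regex_split_custom_kimi_k2(), which applies the full K2 pattern set using unicode_cpt_is_han() to keep Han runs separate from other letter sequences. The driver below is an editorial sketch (hypothetical test program, not part of the patch); only the unicode_regex_split() signature is taken from the header, and the expected output is approximate.

// editor's sketch (hypothetical driver, not part of the patch)
#include "unicode.h"

#include <cstdio>
#include <string>
#include <vector>

int main() {
    // the only expression the KIMI_K2 pre-tokenizer registers; with this patch it is
    // intercepted and the custom kimi-k2 splitter runs on the whole text
    const std::vector<std::string> exprs = { "\\p{Han}+" };

    const std::vector<std::string> chunks = unicode_regex_split("漢字hello world", exprs);
    for (const std::string & s : chunks) {
        printf("|%s", s.c_str()); // expected, roughly: |漢字|hello| world|
    }
    printf("|\n");
    return 0;
}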
@@ -63,4 +63,6 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8); uint32_t unicode_tolower(uint32_t cpt); +bool unicode_cpt_is_han(uint32_t cpt); + std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 81fe90b99323d..a3d68fba046cf 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4282,7 +4282,7 @@ struct test_flash_attn_ext : public test_case { ggml_tensor * m = nullptr; if (mask) { - m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[0], nr23[1]); + m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, nr23[1]); ggml_set_name(m, "m"); } diff --git a/tools/batched-bench/batched-bench.cpp b/tools/batched-bench/batched-bench.cpp index a0a2e5ac56ea9..03628f74b2880 100644 --- a/tools/batched-bench/batched-bench.cpp +++ b/tools/batched-bench/batched-bench.cpp @@ -127,10 +127,9 @@ int main(int argc, char ** argv) { for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) { for (int i = 0; i < pp; ++i) { - common_batch_add(batch, 0, i, { j }, false); + common_batch_add(batch, 0, i, { j }, i == pp - 1); } } - batch.logits[batch.n_tokens - 1] = true; const auto t_pp_start = ggml_time_us(); diff --git a/tools/server/README.md b/tools/server/README.md index 6f962664f6774..e29511cb1b457 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -7,7 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp. **Features:** * LLM inference of F16 and quantized models on GPU and CPU * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes - * Reranking endoint (https://github.com/ggml-org/llama.cpp/pull/9510) + * Reranking endpoint (https://github.com/ggml-org/llama.cpp/pull/9510) * Parallel decoding with multi-user support * Continuous batching * Multimodal ([documentation](../../docs/multimodal.md)) / with OpenAI-compatible API support diff --git a/tools/server/server.cpp b/tools/server/server.cpp index d4dffb39c8d16..0afe213af1e47 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -127,7 +127,6 @@ struct slot_params { std::vector response_fields; bool timings_per_token = false; bool post_sampling_probs = false; - bool ignore_eos = false; struct common_params_sampling sampling; struct common_params_speculative speculative; @@ -441,7 +440,6 @@ struct server_task { { params.sampling.logit_bias.clear(); - params.ignore_eos = json_value(data, "ignore_eos", false); const auto & logit_bias = data.find("logit_bias"); if (logit_bias != data.end() && logit_bias->is_array()) { @@ -472,6 +470,13 @@ struct server_task { } } } + + params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos); + if (params.sampling.ignore_eos) { + params.sampling.logit_bias.insert( + params.sampling.logit_bias.end(), + defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end()); + } } { @@ -1898,7 +1903,6 @@ struct server_context { bool clean_kv_cache = true; bool add_bos_token = true; - bool has_eos_token = false; int32_t n_ctx; // total context for all clients / slots @@ -1957,7 +1961,6 @@ struct server_context { n_ctx = llama_n_ctx(ctx); add_bos_token = llama_vocab_get_add_bos(vocab); - has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) { SRV_INF("loading draft model 
'%s'\n", params_base.speculative.model.path.c_str()); @@ -2217,10 +2220,6 @@ struct server_context { slot.params.n_predict = slot.n_predict; } - if (slot.params.ignore_eos && has_eos_token) { - slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY}); - } - { if (slot.smpl != nullptr) { common_sampler_free(slot.smpl); diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 6c2e91359a663..f3dfc8225da4d 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -11,6 +11,8 @@ // increase max payload length to allow use of larger context size #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576 +// increase backlog size to avoid connection resets for >> 1 slots +#define CPPHTTPLIB_LISTEN_BACKLOG 512 // disable Nagle's algorithm #define CPPHTTPLIB_TCP_NODELAY true #include