diff --git a/.devops/main-cuda.Dockerfile b/.devops/main-cuda.Dockerfile index b9f4873937b..c2bf0fbd1c6 100644 --- a/.devops/main-cuda.Dockerfile +++ b/.devops/main-cuda.Dockerfile @@ -1,6 +1,6 @@ ARG UBUNTU_VERSION=22.04 # This needs to generally match the container host's environment. -ARG CUDA_VERSION=12.3.1 +ARG CUDA_VERSION=13.0.0 # Target the CUDA build image ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} # Target the CUDA runtime image @@ -20,12 +20,12 @@ RUN apt-get update && \ && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* # Ref: https://stackoverflow.com/a/53464012 -ENV CUDA_MAIN_VERSION=12.3 +ENV CUDA_MAIN_VERSION=13.0 ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH COPY .. . # Enable cuBLAS -RUN make base.en CMAKE_ARGS="-DGGML_CUDA=1" +RUN make base.en CMAKE_ARGS="-DGGML_CUDA=1 -DCMAKE_CUDA_ARCHITECTURES='75;80;86;90'" RUN find /app/build -name "*.o" -delete && \ find /app/build -name "*.a" -delete && \ @@ -34,7 +34,7 @@ RUN find /app/build -name "*.o" -delete && \ rm -rf /app/build/_deps FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime -ENV CUDA_MAIN_VERSION=12.3 +ENV CUDA_MAIN_VERSION=13.0 ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH WORKDIR /app diff --git a/.devops/main-musa.Dockerfile b/.devops/main-musa.Dockerfile index 0c9be7c3e6e..026791e3f89 100644 --- a/.devops/main-musa.Dockerfile +++ b/.devops/main-musa.Dockerfile @@ -1,10 +1,10 @@ ARG UBUNTU_VERSION=22.04 # This needs to generally match the container host's environment. -ARG MUSA_VERSION=rc4.0.1 +ARG MUSA_VERSION=rc4.2.0 # Target the MUSA build image -ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-devel-ubuntu${UBUNTU_VERSION} +ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}-amd64 # Target the MUSA runtime image -ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-runtime-ubuntu${UBUNTU_VERSION} +ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}-amd64 FROM ${BASE_MUSA_DEV_CONTAINER} AS build WORKDIR /app diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0e80bdfae78..dd4eff2c7fb 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -6,6 +6,25 @@ on: - master tags: - 'v*' + paths: ['.github/workflows/build.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.cl', + '**/*.swift', + '**/*.m', + '**/*.mm', + '**/*.metal', + '**/*.comp', + '**/*.java'] pull_request: types: [opened, synchronize, reopened] workflow_dispatch: @@ -222,7 +241,8 @@ jobs: - name: Dependencies run: | brew update - brew install sdl2 cmake + cmake --version + brew install sdl2 - name: Build run: | diff --git a/README.md b/README.md index 6b81a54f7d8..e6c07bbcb00 100644 --- a/README.md +++ b/README.md @@ -386,7 +386,7 @@ Run the inference examples as usual, for example: ## Moore Threads GPU support With Moore Threads cards the processing of the models is done efficiently on the GPU via muBLAS and custom MUSA kernels. 
-First, make sure you have installed `MUSA SDK rc4.0.1`: https://developer.mthreads.com/sdk/download/musa?equipment=&os=&driverVersion=&version=4.0.1 +First, make sure you have installed `MUSA SDK rc4.2.0`: https://developer.mthreads.com/sdk/download/musa?equipment=&os=&driverVersion=&version=4.2.0 Now build `whisper.cpp` with MUSA support: diff --git a/bindings/go/Makefile b/bindings/go/Makefile index edcc0166b74..e4436a6a291 100644 --- a/bindings/go/Makefile +++ b/bindings/go/Makefile @@ -15,7 +15,7 @@ BUILD_DIR := build_go MODELS_DIR := models EXAMPLES_DIR := $(wildcard examples/*) INCLUDE_PATH := $(abspath ../../include):$(abspath ../../ggml/include) -LIBRARY_PATH := $(abspath ../../${BUILD_DIR}/src:$(abspath ../../${BUILD_DIR}/ggml/src)) +LIBRARY_PATH := $(abspath ../../${BUILD_DIR}/src):$(abspath ../../${BUILD_DIR}/ggml/src) ifeq ($(GGML_CUDA),1) LIBRARY_PATH := $(LIBRARY_PATH):$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib/ @@ -23,7 +23,8 @@ ifeq ($(GGML_CUDA),1) endif ifeq ($(UNAME_S),Darwin) - EXT_LDFLAGS := -framework Foundation -framework Metal -framework MetalKit + LIBRARY_PATH := $(LIBRARY_PATH):$(abspath ../../${BUILD_DIR}/ggml/src/ggml-blas):$(abspath ../../${BUILD_DIR}/ggml/src/ggml-metal) + EXT_LDFLAGS := -framework Foundation -framework Metal -framework MetalKit -lggml-metal -lggml-blas endif all: clean whisper examples diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index 525b72d2318..3ef73414d90 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -9,7 +9,9 @@ import ( // CGO /* -#cgo LDFLAGS: -lwhisper -lggml -lggml-base -lggml-cpu -lm -lstdc++ -fopenmp +#cgo LDFLAGS: -lwhisper -lggml -lggml-base -lggml-cpu -lm -lstdc++ +#cgo linux LDFLAGS: -fopenmp +#cgo darwin LDFLAGS: -lggml-metal -lggml-blas #cgo darwin LDFLAGS: -framework Accelerate -framework Metal -framework Foundation -framework CoreGraphics #include #include diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 71337c818c3..882c68d042f 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -26,7 +26,7 @@ rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \ rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1); -#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 35 +#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 36 extern VALUE cParams; extern VALUE cVADParams; @@ -49,6 +49,7 @@ static ID id_print_timestamps; static ID id_suppress_blank; static ID id_suppress_nst; static ID id_token_timestamps; +static ID id_max_len; static ID id_split_on_word; static ID id_initial_prompt; static ID id_diarize; @@ -514,6 +515,33 @@ ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) { BOOL_PARAMS_SETTER(self, token_timestamps, value) } + +/* + * max segment length in characters. + * + * call-seq: + * max_len -> Integer + */ +static VALUE +ruby_whisper_params_get_max_len(VALUE self) +{ + ruby_whisper_params *rwp; + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + return INT2NUM(rwp->params.max_len); +} +/* + * call-seq: + * max_len = length -> length + */ +static VALUE +ruby_whisper_params_set_max_len(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + rwp->params.max_len = NUM2INT(value); + return value; +} + /* * If true, split on word rather than on token (when used with max_len). 
* @@ -1137,6 +1165,7 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self) SET_PARAM_IF_SAME(suppress_blank) SET_PARAM_IF_SAME(suppress_nst) SET_PARAM_IF_SAME(token_timestamps) + SET_PARAM_IF_SAME(max_len) SET_PARAM_IF_SAME(split_on_word) SET_PARAM_IF_SAME(initial_prompt) SET_PARAM_IF_SAME(offset) @@ -1271,30 +1300,31 @@ init_ruby_whisper_params(VALUE *mWhisper) DEFINE_PARAM(suppress_blank, 8) DEFINE_PARAM(suppress_nst, 9) DEFINE_PARAM(token_timestamps, 10) - DEFINE_PARAM(split_on_word, 11) - DEFINE_PARAM(initial_prompt, 12) - DEFINE_PARAM(diarize, 13) - DEFINE_PARAM(offset, 14) - DEFINE_PARAM(duration, 15) - DEFINE_PARAM(max_text_tokens, 16) - DEFINE_PARAM(temperature, 17) - DEFINE_PARAM(max_initial_ts, 18) - DEFINE_PARAM(length_penalty, 19) - DEFINE_PARAM(temperature_inc, 20) - DEFINE_PARAM(entropy_thold, 21) - DEFINE_PARAM(logprob_thold, 22) - DEFINE_PARAM(no_speech_thold, 23) - DEFINE_PARAM(new_segment_callback, 24) - DEFINE_PARAM(new_segment_callback_user_data, 25) - DEFINE_PARAM(progress_callback, 26) - DEFINE_PARAM(progress_callback_user_data, 27) - DEFINE_PARAM(encoder_begin_callback, 28) - DEFINE_PARAM(encoder_begin_callback_user_data, 29) - DEFINE_PARAM(abort_callback, 30) - DEFINE_PARAM(abort_callback_user_data, 31) - DEFINE_PARAM(vad, 32) - DEFINE_PARAM(vad_model_path, 33) - DEFINE_PARAM(vad_params, 34) + DEFINE_PARAM(max_len, 11) + DEFINE_PARAM(split_on_word, 12) + DEFINE_PARAM(initial_prompt, 13) + DEFINE_PARAM(diarize, 14) + DEFINE_PARAM(offset, 15) + DEFINE_PARAM(duration, 16) + DEFINE_PARAM(max_text_tokens, 17) + DEFINE_PARAM(temperature, 18) + DEFINE_PARAM(max_initial_ts, 19) + DEFINE_PARAM(length_penalty, 20) + DEFINE_PARAM(temperature_inc, 21) + DEFINE_PARAM(entropy_thold, 22) + DEFINE_PARAM(logprob_thold, 23) + DEFINE_PARAM(no_speech_thold, 24) + DEFINE_PARAM(new_segment_callback, 25) + DEFINE_PARAM(new_segment_callback_user_data, 26) + DEFINE_PARAM(progress_callback, 27) + DEFINE_PARAM(progress_callback_user_data, 28) + DEFINE_PARAM(encoder_begin_callback, 29) + DEFINE_PARAM(encoder_begin_callback_user_data, 30) + DEFINE_PARAM(abort_callback, 31) + DEFINE_PARAM(abort_callback_user_data, 32) + DEFINE_PARAM(vad, 33) + DEFINE_PARAM(vad_model_path, 34) + DEFINE_PARAM(vad_params, 35) rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0); rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0); diff --git a/bindings/ruby/lib/whisper/model/uri.rb b/bindings/ruby/lib/whisper/model/uri.rb index 5bf5a098624..d8a98699f07 100644 --- a/bindings/ruby/lib/whisper/model/uri.rb +++ b/bindings/ruby/lib/whisper/model/uri.rb @@ -94,7 +94,7 @@ def download(response) end def show_progress(current, size) - progress_rate_available = size && $stderr.tty? + progress_rate_available = size && $stderr.tty? && $stderr.winsize[1] >= line.size unless @prev @prev = Time.now diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index 5966ce31592..0489432a249 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -135,6 +135,7 @@ module Whisper ?suppress_blank: boolish, ?suppress_nst: boolish, ?token_timestamps: boolish, + ?max_len: Integer, ?split_on_word: boolish, ?initial_prompt: string | nil, ?diarize: boolish, @@ -222,6 +223,12 @@ module Whisper # def token_timestamps: () -> (true | false) + def max_len=: (Integer) -> Integer + + # max segment length in characters. 
+ # + def max_len: () -> Integer + def split_on_word=: (boolish) -> boolish # If true, split on word rather than on token (when used with max_len). diff --git a/bindings/ruby/test/test_params.rb b/bindings/ruby/test/test_params.rb index 9a9535799b7..d5c5d140e8c 100644 --- a/bindings/ruby/test/test_params.rb +++ b/bindings/ruby/test/test_params.rb @@ -13,6 +13,7 @@ class TestParams < TestBase :suppress_blank, :suppress_nst, :token_timestamps, + :max_len, :split_on_word, :initial_prompt, :diarize, @@ -139,6 +140,13 @@ def test_token_timestamps assert !@params.token_timestamps end + def test_max_len + @params.max_len = 42 + assert_equal @params.max_len, 42 + @params.max_len = 0 + assert_equal @params.max_len, 0 + end + def test_split_on_word @params.split_on_word = true assert @params.split_on_word diff --git a/build-xcframework.sh b/build-xcframework.sh index 337e1936a56..bbf2764d729 100755 --- a/build-xcframework.sh +++ b/build-xcframework.sh @@ -15,6 +15,7 @@ GGML_METAL_EMBED_LIBRARY=ON GGML_BLAS_DEFAULT=ON GGML_METAL_USE_BF16=ON GGML_OPENMP=OFF +BUILD_STATIC_XCFRAMEWORK=${BUILD_STATIC_XCFRAMEWORK:-OFF} COMMON_C_FLAGS="-Wno-macro-redefined -Wno-shorten-64-to-32 -Wno-unused-command-line-argument -g" COMMON_CXX_FLAGS="-Wno-macro-redefined -Wno-shorten-64-to-32 -Wno-unused-command-line-argument -g" @@ -327,6 +328,15 @@ combine_static_libraries() { arch_flags+=" -arch $arch" done + + if [[ "${BUILD_STATIC_XCFRAMEWORK}" == "ON" ]]; then + echo "Packaging static framework for ${platform}." + mkdir -p "$(dirname "${base_dir}/${output_lib}")" + cp "${temp_dir}/combined.a" "${base_dir}/${output_lib}" + rm -rf "${temp_dir}" + return + fi + # Create dynamic library echo "Creating dynamic library for ${platform}." xcrun -sdk $sdk clang++ -dynamiclib \ @@ -529,6 +539,20 @@ combine_static_libraries "build-tvos-device" "Release-appletvos" "tvos" "false" # Create XCFramework with correct debug symbols paths echo "Creating XCFramework..." + +if [[ "${BUILD_STATIC_XCFRAMEWORK}" == "ON" ]]; then + xcodebuild -create-xcframework \ + -framework $(pwd)/build-ios-sim/framework/whisper.framework \ + -framework $(pwd)/build-ios-device/framework/whisper.framework \ + -framework $(pwd)/build-macos/framework/whisper.framework \ + -framework $(pwd)/build-visionos/framework/whisper.framework \ + -framework $(pwd)/build-visionos-sim/framework/whisper.framework \ + -framework $(pwd)/build-tvos-device/framework/whisper.framework \ + -framework $(pwd)/build-tvos-sim/framework/whisper.framework \ + -output $(pwd)/build-apple/whisper.xcframework + exit 0 +fi + xcodebuild -create-xcframework \ -framework $(pwd)/build-ios-sim/framework/whisper.framework \ -debug-symbols $(pwd)/build-ios-sim/dSYMs/whisper.dSYM \ diff --git a/examples/addon.node/index.js b/examples/addon.node/index.js index 9324d6fa548..2e4ed610b66 100644 --- a/examples/addon.node/index.js +++ b/examples/addon.node/index.js @@ -1,8 +1,10 @@ -const path = require("path"); -const { whisper } = require(path.join( - __dirname, - "../../build/Release/addon.node" -)); +const path = require('path'); +const os = require('os'); + +const isWindows = os.platform() === 'win32'; +const buildPath = isWindows ? 
"../../build/bin/Release/addon.node" : "../../build/Release/addon.node"; + +const { whisper } = require(path.join(__dirname, buildPath)); const { promisify } = require("util"); const whisperAsync = promisify(whisper); diff --git a/examples/bench.wasm/README.md b/examples/bench.wasm/README.md index 1a16a11018f..0ebf0a21f2d 100644 --- a/examples/bench.wasm/README.md +++ b/examples/bench.wasm/README.md @@ -32,6 +32,16 @@ cp bin/libbench.js /path/to/html/ cp bin/libbench.worker.js /path/to/html/ ``` +> 📝 **Note:** By default this example is built with `WHISPER_WASM_SINGLE_FILE=ON` +> which means that that a separate .wasm file will not be generated. Instead, the +> WASM module is embedded in the main JS file as a base64 encoded string. To +> generate a separate .wasm file, you need to disable this option by passing +> `-DWHISPER_WASM_SINGLE_FILE=OFF`: +> ```console +> emcmake cmake .. -DWHISPER_WASM_SINGLE_FILE=OFF +> ``` +> This will generate a `libbench.wasm` file in the build/bin directory. + > 📝 **Note:** As of Emscripten 3.1.58 (April 2024), separate worker.js files are no > longer generated and the worker is embedded in the main JS file. So the worker > file will not be geneated for versions later than `3.1.58`. diff --git a/examples/bench.wasm/index-tmpl.html b/examples/bench.wasm/index-tmpl.html index e9b49e07216..91589c35b3f 100644 --- a/examples/bench.wasm/index-tmpl.html +++ b/examples/bench.wasm/index-tmpl.html @@ -191,15 +191,15 @@ function loadWhisper(model) { let urls = { - 'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin', - 'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin', - 'small.en': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en.bin', - - 'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin', - 'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin', - 'small-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en-q5_1.bin', - 'medium-en-q5_0':'https://whisper.ggerganov.com/ggml-model-whisper-medium.en-q5_0.bin', - 'large-q5_0': 'https://whisper.ggerganov.com/ggml-model-whisper-large-q5_0.bin', + 'tiny.en': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin', + 'base.en': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin', + 'small.en': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en.bin', + + 'tiny-en-q5_1': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q5_1.bin', + 'base-en-q5_1': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en-q5_1.bin', + 'small-en-q5_1': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en-q5_1.bin', + 'medium-en-q5_0':'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.en-q5_0.bin', + 'large-q5_0': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-q5_0.bin', }; let sizes = { diff --git a/examples/command.wasm/README.md b/examples/command.wasm/README.md index b46b89e2043..c50c2b43ed5 100644 --- a/examples/command.wasm/README.md +++ b/examples/command.wasm/README.md @@ -32,6 +32,16 @@ cp bin/libcommand.js /path/to/html/ cp bin/libcommand.worker.js /path/to/html/ ``` +> 📝 **Note:** By default this example is built with `WHISPER_WASM_SINGLE_FILE=ON` +> which means that that a separate .wasm file will not be generated. Instead, the +> WASM module is embedded in the main JS file as a base64 encoded string. 
To +> generate a separate .wasm file, you need to disable this option by passing +> `-DWHISPER_WASM_SINGLE_FILE=OFF`: +> ```console +> emcmake cmake .. -DWHISPER_WASM_SINGLE_FILE=OFF +> ``` +> This will generate a `libcommand.wasm` file in the build/bin directory. + > 📝 **Note:** As of Emscripten 3.1.58 (April 2024), separate worker.js files are no > longer generated and the worker is embedded in the main JS file. So the worker > file will not be geneated for versions later than `3.1.58`. diff --git a/examples/command.wasm/index-tmpl.html b/examples/command.wasm/index-tmpl.html index 752d851e9cf..2221e9340ba 100644 --- a/examples/command.wasm/index-tmpl.html +++ b/examples/command.wasm/index-tmpl.html @@ -174,11 +174,11 @@ function loadWhisper(model) { let urls = { - 'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin', - 'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin', + 'tiny.en': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin', + 'base.en': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin', - 'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin', - 'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin', + 'tiny-en-q5_1': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q5_1.bin', + 'base-en-q5_1': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en-q5_1.bin', }; let sizes = { diff --git a/examples/common-ggml.cpp b/examples/common-ggml.cpp index d59b77fce6b..c42b644fedd 100644 --- a/examples/common-ggml.cpp +++ b/examples/common-ggml.cpp @@ -72,6 +72,7 @@ bool ggml_common_quantize_0( case GGML_FTYPE_MOSTLY_IQ4_XS: case GGML_FTYPE_MOSTLY_IQ1_M: case GGML_FTYPE_MOSTLY_BF16: + case GGML_FTYPE_MOSTLY_MXFP4: { fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype); return false; @@ -211,6 +212,7 @@ bool ggml_common_quantize_0( case GGML_TYPE_BF16: case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_MXFP4: case GGML_TYPE_COUNT: { fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype)); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 643d08a799a..901f65f6c35 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -104,6 +104,7 @@ struct whisper_params { bool flash_attn = false; bool suppress_nst = false; bool no_context = false; + bool no_language_probabilities = false; std::string language = "en"; std::string prompt = ""; @@ -178,6 +179,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -nc, --no-context [%-7s] do not use previous audio context\n", params.no_context ? "true" : "false"); fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true"); fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -nlp, --no-language-probabilities [%-7s] exclude language probabilities from verbose_json output\n", params.no_language_probabilities ? "true" : "false"); // Voice Activity Detection (VAD) parameters fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n"); fprintf(stderr, " --vad [%-7s] enable Voice Activity Detection (VAD)\n", params.vad ? 
"true" : "false"); @@ -237,6 +239,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); } else if (arg == "-nc" || arg == "--no-context") { params.no_context = true; } + else if (arg == "-nlp" || arg == "--no-language-probabilities") { params.no_language_probabilities = true; } // server params else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } @@ -599,6 +602,10 @@ void get_req_parameters(const Request & req, whisper_params & params) { params.vad_samples_overlap = std::stof(req.get_file_value("vad_samples_overlap").content); } + if (req.has_file("no_language_probabilities")) + { + params.no_language_probabilities = parse_str_to_bool(req.get_file_value("no_language_probabilities").content); + } } } // namespace @@ -1024,23 +1031,25 @@ int main(int argc, char ** argv) { } else if (params.response_format == vjson_format) { /* try to match openai/whisper's Python format */ std::string results = output_str(ctx, params, pcmf32s); - // Get language probabilities - std::vector lang_probs(whisper_lang_max_id() + 1, 0.0f); - const auto detected_lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, lang_probs.data()); json jres = json{ {"task", params.translate ? "translate" : "transcribe"}, {"language", whisper_lang_str_full(whisper_full_lang_id(ctx))}, {"duration", float(pcmf32.size())/WHISPER_SAMPLE_RATE}, {"text", results}, - {"segments", json::array()}, - {"detected_language", whisper_lang_str_full(detected_lang_id)}, - {"detected_language_probability", lang_probs[detected_lang_id]}, - {"language_probabilities", json::object()} + {"segments", json::array()} }; - // Add all language probabilities - for (int i = 0; i <= whisper_lang_max_id(); ++i) { - if (lang_probs[i] > 0.001f) { // Only include non-negligible probabilities - jres["language_probabilities"][whisper_lang_str(i)] = lang_probs[i]; + // Only compute language probabilities if requested (expensive operation) + if (!params.no_language_probabilities) { + std::vector lang_probs(whisper_lang_max_id() + 1, 0.0f); + const auto detected_lang_id = whisper_lang_auto_detect(ctx, 0, params.n_threads, lang_probs.data()); + jres["detected_language"] = whisper_lang_str_full(detected_lang_id); + jres["detected_language_probability"] = lang_probs[detected_lang_id]; + jres["language_probabilities"] = json::object(); + // Add all language probabilities + for (int i = 0; i <= whisper_lang_max_id(); ++i) { + if (lang_probs[i] > 0.001f) { // Only include non-negligible probabilities + jres["language_probabilities"][whisper_lang_str(i)] = lang_probs[i]; + } } } const int n_segments = whisper_full_n_segments(ctx); diff --git a/examples/stream.wasm/README.md b/examples/stream.wasm/README.md index 29ff982d617..431555655ac 100644 --- a/examples/stream.wasm/README.md +++ b/examples/stream.wasm/README.md @@ -30,6 +30,16 @@ cp bin/libstream.js /path/to/html/ cp bin/libstream.worker.js /path/to/html/ ``` +> 📝 **Note:** By default this example is built with `WHISPER_WASM_SINGLE_FILE=ON` +> which means that that a separate .wasm file will not be generated. Instead, the +> WASM module is embedded in the main JS file as a base64 encoded string. To +> generate a separate .wasm file, you need to disable this option by passing +> `-DWHISPER_WASM_SINGLE_FILE=OFF`: +> ```console +> emcmake cmake .. 
-DWHISPER_WASM_SINGLE_FILE=OFF +> ``` +> This will generate a `libstream.wasm` file in the build/bin directory. + > 📝 **Note:** As of Emscripten 3.1.58 (April 2024), separate worker.js files are no > longer generated and the worker is embedded in the main JS file. So the worker > file will not be geneated for versions later than `3.1.58`. diff --git a/examples/stream.wasm/emscripten.cpp b/examples/stream.wasm/emscripten.cpp index 43e71bf23f0..5dff24ad3bd 100644 --- a/examples/stream.wasm/emscripten.cpp +++ b/examples/stream.wasm/emscripten.cpp @@ -31,10 +31,11 @@ void stream_set_status(const std::string & status) { g_status = status; } -void stream_main(size_t index) { +void stream_main(size_t index, const std::string & lang) { stream_set_status("loading data ..."); struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY); + bool is_multilingual = whisper_is_multilingual(g_contexts[index]); wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency()); wparams.offset_ms = 0; @@ -52,7 +53,7 @@ void stream_main(size_t index) { // disable temperature fallback wparams.temperature_inc = -1.0f; - wparams.language = "en"; + wparams.language = is_multilingual ? lang.c_str() : "en"; printf("stream: using %d threads\n", wparams.n_threads); @@ -127,9 +128,8 @@ void stream_main(size_t index) { g_contexts[index] = nullptr; } } - EMSCRIPTEN_BINDINGS(stream) { - emscripten::function("init", emscripten::optional_override([](const std::string & path_model) { + emscripten::function("init", emscripten::optional_override([](const std::string & path_model, const std::string & lang) { for (size_t i = 0; i < g_contexts.size(); ++i) { if (g_contexts[i] == nullptr) { g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params()); @@ -138,8 +138,8 @@ EMSCRIPTEN_BINDINGS(stream) { if (g_worker.joinable()) { g_worker.join(); } - g_worker = std::thread([i]() { - stream_main(i); + g_worker = std::thread([i, lang]() { + stream_main(i, lang); }); return i + 1; diff --git a/examples/stream.wasm/index-tmpl.html b/examples/stream.wasm/index-tmpl.html index c831b2f52b7..941f45075c4 100644 --- a/examples/stream.wasm/index-tmpl.html +++ b/examples/stream.wasm/index-tmpl.html @@ -55,6 +55,7 @@ Whisper model: +

Quantized models:

@@ -66,6 +67,77 @@ --> + + + + +
+ Language: + +
+
@@ -174,16 +246,18 @@ function loadWhisper(model) { let urls = { - 'tiny.en': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin', - 'base.en': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin', + 'tiny.en': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin', + 'base.en': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin', + 'base' : 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.bin', - 'tiny-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin', - 'base-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin', + 'tiny-en-q5_1': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q5_1.bin', + 'base-en-q5_1': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en-q5_1.bin', }; let sizes = { 'tiny.en': 75, 'base.en': 142, + 'base': 142, 'tiny-en-q5_1': 31, 'base-en-q5_1': 57, @@ -197,6 +271,7 @@ document.getElementById('fetch-whisper-tiny-en').style.display = 'none'; document.getElementById('fetch-whisper-base-en').style.display = 'none'; + document.getElementById('fetch-whisper-base').style.display = 'none'; document.getElementById('fetch-whisper-tiny-en-q5_1').style.display = 'none'; document.getElementById('fetch-whisper-base-en-q5_1').style.display = 'none'; @@ -212,6 +287,7 @@ var el; el = document.getElementById('fetch-whisper-tiny-en'); if (el) el.style.display = 'inline-block'; el = document.getElementById('fetch-whisper-base-en'); if (el) el.style.display = 'inline-block'; + el = document.getElementById('fetch-whisper-base'); if (el) el.style.display = 'inline-block'; el = document.getElementById('fetch-whisper-tiny-en-q5_1'); if (el) el.style.display = 'inline-block'; el = document.getElementById('fetch-whisper-base-en-q5_1'); if (el) el.style.display = 'inline-block'; @@ -368,7 +444,7 @@ function onStart() { if (!instance) { - instance = Module.init('whisper.bin'); + instance = Module.init('whisper.bin', document.getElementById('language').value); if (instance) { printTextarea("js: whisper initialized, instance: " + instance); diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index e63ab284bc3..18dcc6ddfe5 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -34,6 +34,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_PLAMO2, "plamo2" }, { LLM_ARCH_CODESHELL, "codeshell" }, { LLM_ARCH_ORION, "orion" }, { LLM_ARCH_INTERNLM2, "internlm2" }, @@ -61,12 +62,14 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, + { LLM_ARCH_GLM4_MOE, "glm4moe" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_EXAONE, "exaone" }, + { LLM_ARCH_EXAONE4, "exaone4" }, { LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, { LLM_ARCH_RWKV7, "rwkv7" }, @@ -81,9 +84,15 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DOTS1, "dots1" }, { LLM_ARCH_ARCEE, "arcee" }, { LLM_ARCH_ERNIE4_5, "ernie4_5" }, + { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" }, { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, + { LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" }, { LLM_ARCH_SMOLLM3, "smollm3" }, + { LLM_ARCH_OPENAI_MOE, "gpt-oss" }, { LLM_ARCH_LFM2, "lfm2" }, + { LLM_ARCH_DREAM, 
"dream" }, + { LLM_ARCH_SMALLTHINKER, "smallthinker" }, + { LLM_ARCH_LLADA, "llada" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -120,6 +129,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" }, { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" }, { LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" }, + { LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, @@ -784,6 +794,36 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_PLAMO2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, + { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, + { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_CODESHELL, { @@ -1354,6 +1394,40 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, + { + LLM_ARCH_GLM4_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + // NextN/MTP tensors - preserved but unused (in final layer, dynamic layer number) + { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" }, + { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" }, + { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" }, + { LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" }, + { 
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" }, + }, + }, { LLM_ARCH_BITNET, { @@ -1477,6 +1551,26 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_EXAONE4, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + } + }, { LLM_ARCH_RWKV6, { @@ -1793,6 +1887,31 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_ERNIE4_5_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + }, + }, { LLM_ARCH_HUNYUAN_MOE, { @@ -1816,6 +1935,26 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_HUNYUAN_DENSE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + + }, + }, { LLM_ARCH_SMOLLM3, { @@ -1833,6 +1972,25 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_OPENAI_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_SINKS, "blk.%d.attn_sinks" }, + { LLM_TENSOR_FFN_GATE_INP, 
"blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_LFM2, { @@ -1854,6 +2012,61 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, } }, + { + LLM_ARCH_SMALLTHINKER, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" } + }, + }, + { + LLM_ARCH_DREAM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_LLADA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1893,6 +2106,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_SINKS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SCALE}}, {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -2024,6 +2238,14 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + // NextN/MTP tensors are currently ignored (reserved for future MTP support) + // These tensors only exist in the last layer(s) and are treated as output tensors + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_HNORM, 
{LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} @@ -2094,6 +2316,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { switch (arch) { case LLM_ARCH_JAMBA: case LLM_ARCH_FALCON_H1: + case LLM_ARCH_PLAMO2: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_LFM2: return true; @@ -2101,3 +2324,13 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { return false; } } + +bool llm_arch_is_diffusion(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_DREAM: + case LLM_ARCH_LLADA: + return true; + default: + return false; + } +} diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 1f973259524..7af587e7951 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -38,6 +38,7 @@ enum llm_arch { LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, LLM_ARCH_PLAMO, + LLM_ARCH_PLAMO2, LLM_ARCH_CODESHELL, LLM_ARCH_ORION, LLM_ARCH_INTERNLM2, @@ -65,12 +66,14 @@ enum llm_arch { LLM_ARCH_DEEPSEEK2, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, + LLM_ARCH_GLM4_MOE, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, LLM_ARCH_NEMOTRON, LLM_ARCH_EXAONE, + LLM_ARCH_EXAONE4, LLM_ARCH_RWKV6, LLM_ARCH_RWKV6QWEN2, LLM_ARCH_RWKV7, @@ -85,9 +88,15 @@ enum llm_arch { LLM_ARCH_DOTS1, LLM_ARCH_ARCEE, LLM_ARCH_ERNIE4_5, + LLM_ARCH_ERNIE4_5_MOE, LLM_ARCH_HUNYUAN_MOE, + LLM_ARCH_HUNYUAN_DENSE, LLM_ARCH_SMOLLM3, + LLM_ARCH_OPENAI_MOE, LLM_ARCH_LFM2, + LLM_ARCH_DREAM, + LLM_ARCH_SMALLTHINKER, + LLM_ARCH_LLADA, LLM_ARCH_UNKNOWN, }; @@ -124,6 +133,7 @@ enum llm_kv { LLM_KV_EXPERT_WEIGHTS_NORM, LLM_KV_EXPERT_GATING_FUNC, LLM_KV_MOE_EVERY_N_LAYERS, + LLM_KV_NEXTN_PREDICT_LAYERS, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, @@ -256,6 +266,7 @@ enum llm_tensor { LLM_TENSOR_ATTN_OUT_NORM, LLM_TENSOR_ATTN_POST_NORM, LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_ATTN_SINKS, LLM_TENSOR_FFN_GATE_INP, LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_NORM, @@ -402,6 +413,12 @@ enum llm_tensor { LLM_TENSOR_SHORTCONV_CONV, LLM_TENSOR_SHORTCONV_INPROJ, LLM_TENSOR_SHORTCONV_OUTPROJ, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; enum llm_tensor_layer { @@ -478,3 +495,4 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); bool llm_arch_is_recurrent(const llm_arch & arch); bool llm_arch_is_hybrid (const llm_arch & arch); +bool llm_arch_is_diffusion(const llm_arch & arch); diff --git a/examples/talk-llama/llama-batch.cpp b/examples/talk-llama/llama-batch.cpp index 3bc8554e51c..55d89eca0ad 100644 --- a/examples/talk-llama/llama-batch.cpp +++ b/examples/talk-llama/llama-batch.cpp @@ -27,6 +27,7 @@ bool llama_batch_allocr::init( const llama_vocab & vocab, const llama_memory_i * memory, uint32_t n_embd, + uint32_t n_seq_max, bool output_all) { clear(); @@ -40,6 +41,11 @@ bool llama_batch_allocr::init( // validate input batch // + if (n_seq_max > LLAMA_MAX_SEQ) { + LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ); + return false; + } + if (batch.token) { for (int32_t i = 0; i < batch.n_tokens; ++i) { if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) { @@ -52,8 +58,8 @@ bool llama_batch_allocr::init( if (batch.seq_id) { for (int32_t i = 0; i < 
batch.n_tokens; ++i) { for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) { - if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) { - LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ); + if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) { + LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d >= %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max); return false; } } @@ -86,7 +92,7 @@ bool llama_batch_allocr::init( // initialize the starting position for each sequence based on the positions in the memory llama_pos p0[LLAMA_MAX_SEQ]; - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { if (!memory) { // if no memory -> start from 0 p0[s] = 0; @@ -143,13 +149,16 @@ bool llama_batch_allocr::init( // compute stats // - this->n_embd = n_embd; + this->n_embd = n_embd; + this->n_seq_max = n_seq_max; // count the outputs in this batch for (int32_t i = 0; i < batch.n_tokens; ++i) { n_outputs += batch.logits[i] != 0; } + has_cpl = false; + // determine coupled sequences // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them for (int32_t i = 0; i < batch.n_tokens; ++i) { @@ -189,7 +198,7 @@ bool llama_batch_allocr::init( seq_set_map[cur].push_back(i); } - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { if (seq_set_unq.test(s)) { seq_idx[s] = seq_id_unq.size(); seq_id_unq.push_back(s); @@ -201,7 +210,7 @@ bool llama_batch_allocr::init( LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__); llama_ubatch ubatch { - /*.equal_seqs =*/ false, + /*.b_equal_seqs =*/ false, /*.n_tokens =*/ (uint32_t) batch.n_tokens, /*.n_seq_tokens =*/ (uint32_t) 1, /*.n_seqs =*/ (uint32_t) batch.n_tokens, @@ -214,6 +223,7 @@ bool llama_batch_allocr::init( /*.seq_id_unq =*/ this->seq_id_unq.data(), /*.seq_idx =*/ this->seq_idx.data(), /*.output =*/ batch.logits, + /*.data =*/ {}, }; ubatch_print(ubatch, debug); @@ -241,7 +251,7 @@ bool llama_batch_allocr::init( // consistency checks // - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { if (seq_pos[s].empty()) { continue; } @@ -284,8 +294,8 @@ bool llama_batch_allocr::init( } if (memory) { - for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) { - for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) { + for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) { + for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) { if (seq_cpl[s0][s1]) { if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) || memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) { @@ -316,12 +326,12 @@ bool llama_batch_allocr::init( // { seq_set_t cur_seq_set[LLAMA_MAX_SEQ]; - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { cur_seq_set[s].set(); } llama_pos cur_seq_pos[LLAMA_MAX_SEQ]; - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { cur_seq_pos[s] = -1; } @@ -357,39 +367,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t clear(); split_reset(); - ubatches.emplace_back(); + auto udata = std::make_shared(); - auto & ubatch = ubatches.back(); - - ubatch.token .resize(n_tokens); - ubatch.embd .clear(); - ubatch.pos .resize(n_tokens); - ubatch.n_seq_id .resize(n_tokens); - ubatch.seq_id .resize(n_tokens); - ubatch.seq_id_unq.resize(0); - ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1); - ubatch.output 
.resize(n_tokens); + udata->token .resize(n_tokens); + udata->embd .clear(); + udata->pos .resize(n_tokens); + udata->n_seq_id .resize(n_tokens); + udata->seq_id .resize(n_tokens); + udata->seq_id_unq.resize(0); + udata->seq_idx .resize(LLAMA_MAX_SEQ, -1); + udata->output .resize(n_tokens); for (uint32_t s = 0; s < n_seqs; ++s) { - ubatch.seq_idx[s] = s; - ubatch.seq_id_unq.push_back(s); + udata->seq_idx[s] = s; + udata->seq_id_unq.push_back(s); } llama_ubatch res { - /*.equal_seqs =*/ true, + /*.b_equal_seqs =*/ true, /*.n_tokens =*/ n_tokens, /*.n_seq_tokens =*/ n_seq_tokens, /*.n_seqs =*/ n_seqs, /*.n_seqs_unq =*/ n_seqs, - /*.token =*/ ubatch.token.data(), + /*.token =*/ udata->token.data(), /*.embd =*/ nullptr, - /*.pos =*/ ubatch.pos.data(), - /*.n_seq_id =*/ ubatch.n_seq_id.data(), - /*.seq_id =*/ ubatch.seq_id.data(), - /*.seq_id_unq =*/ ubatch.seq_id_unq.data(), - /*.seq_idx =*/ ubatch.seq_idx.data(), - /*.output =*/ ubatch.output.data(), + /*.pos =*/ udata->pos.data(), + /*.n_seq_id =*/ udata->n_seq_id.data(), + /*.seq_id =*/ udata->seq_id.data(), + /*.seq_id_unq =*/ udata->seq_id_unq.data(), + /*.seq_idx =*/ udata->seq_idx.data(), + /*.output =*/ udata->output.data(), + /*.data =*/ std::move(udata), }; return res; @@ -430,8 +439,6 @@ void llama_batch_allocr::split_reset() { used.clear(); used.resize(get_n_tokens(), false); - - ubatches.clear(); } llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) { @@ -470,7 +477,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) { llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) { if (sequential && has_cpl) { - LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__); + LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch (you may need to use the -kvu flag)\n", __func__); return {}; } @@ -646,78 +653,77 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u assert(n_tokens%n_seqs == 0); - ubatches.emplace_back(); - - auto & ubatch = ubatches.back(); + auto udata = std::make_shared(); const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1; const int64_t n_embd_all = batch.embd ? 
(int64_t) n_tokens*n_embd : 0; const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur; - ubatch.token .resize(n_tokens); - ubatch.embd .resize(n_embd_all); - ubatch.pos .resize(n_pos_all); - ubatch.n_seq_id .resize(n_tokens); - ubatch.seq_id .resize(n_tokens); - ubatch.seq_id_unq.resize(0); - ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1); - ubatch.output .resize(n_tokens); + udata->token .resize(n_tokens); + udata->embd .resize(n_embd_all); + udata->pos .resize(n_pos_all); + udata->n_seq_id .resize(n_tokens); + udata->seq_id .resize(n_tokens); + udata->seq_id_unq.resize(0); + udata->seq_idx .resize(LLAMA_MAX_SEQ, -1); + udata->output .resize(n_tokens); seq_set_t seq_set_unq; for (size_t i = 0; i < idxs.size(); ++i) { if (batch.token) { - ubatch.token[i] = batch.token[idxs[i]]; + udata->token[i] = batch.token[idxs[i]]; } if (batch.embd) { - memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float)); + memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float)); } for (int j = 0; j < n_pos_cur; ++j) { - ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]]; + udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]]; } - ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]]; - ubatch.seq_id[i] = batch.seq_id[idxs[i]]; - ubatch.output[i] = batch.logits[idxs[i]]; + udata->n_seq_id[i] = batch.n_seq_id[idxs[i]]; + udata->seq_id[i] = batch.seq_id[idxs[i]]; + udata->output[i] = batch.logits[idxs[i]]; - for (int s = 0; s < ubatch.n_seq_id[i]; ++s) { - seq_set_unq.set(ubatch.seq_id[i][s]); + for (int s = 0; s < udata->n_seq_id[i]; ++s) { + seq_set_unq.set(udata->seq_id[i][s]); } - if (ubatch.output[i]) { + if (udata->output[i]) { out_ids.push_back(idxs[i]); } } - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { if (seq_set_unq.test(s)) { - ubatch.seq_idx[s] = ubatch.seq_id_unq.size(); - ubatch.seq_id_unq.push_back(s); + udata->seq_idx[s] = udata->seq_id_unq.size(); + udata->seq_id_unq.push_back(s); } } llama_ubatch res { - /*.equal_seqs =*/ equal_seqs, + /*.b_equal_seqs =*/ equal_seqs, /*.n_tokens =*/ n_tokens, /*.n_seq_tokens =*/ n_tokens/n_seqs, /*.n_seqs =*/ n_seqs, - /*.n_seqs_unq =*/ (uint32_t) ubatch.seq_id_unq.size(), - - /*.token =*/ batch.token ? ubatch.token.data() : nullptr, - /*.embd =*/ batch.embd ? ubatch.embd.data() : nullptr, - /*.pos =*/ ubatch.pos.data(), - /*.n_seq_id =*/ ubatch.n_seq_id.data(), - /*.seq_id =*/ ubatch.seq_id.data(), - /*.seq_id_unq =*/ ubatch.seq_id_unq.data(), - /*.seq_idx =*/ ubatch.seq_idx.data(), - /*.output =*/ ubatch.output.data(), + /*.n_seqs_unq =*/ (uint32_t) udata->seq_id_unq.size(), + + /*.token =*/ batch.token ? udata->token.data() : nullptr, + /*.embd =*/ batch.embd ? 
udata->embd.data() : nullptr, + /*.pos =*/ udata->pos.data(), + /*.n_seq_id =*/ udata->n_seq_id.data(), + /*.seq_id =*/ udata->seq_id.data(), + /*.seq_id_unq =*/ udata->seq_id_unq.data(), + /*.seq_idx =*/ udata->seq_idx.data(), + /*.output =*/ udata->output.data(), + /*.data =*/ std::move(udata), }; if (debug > 0) { - LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1); + LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__); ubatch_print(res, debug); } @@ -727,7 +733,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) { if (debug > 0) { - LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs); + LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs()); LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens); LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens); LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs); diff --git a/examples/talk-llama/llama-batch.h b/examples/talk-llama/llama-batch.h index 3420803ff94..d563adc66aa 100644 --- a/examples/talk-llama/llama-batch.h +++ b/examples/talk-llama/llama-batch.h @@ -8,12 +8,17 @@ #include #include #include +#include #include // keep this struct lightweight -// it points to data in `llama_batch_allocr` struct llama_ubatch { - bool equal_seqs; + bool equal_seqs() const { + return b_equal_seqs != 0; + } + + uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment + // otherwise address sanitizer complains // TODO: whole_seqs for embeddings? uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) @@ -34,6 +39,20 @@ struct llama_ubatch { llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx int8_t * output; // [n_tokens] | i | - + + struct data_t { + std::vector token; + std::vector embd; + std::vector pos; + std::vector n_seq_id; + std::vector seq_id; + std::vector seq_id_unq; + std::vector seq_idx; + std::vector output; + }; + + // the llama_ubatch pointers above point to this data if set. 
otherwise - points to non-owning data + std::shared_ptr data; }; // a helper for sanitizing, fulfilling and splitting a batch @@ -48,6 +67,7 @@ class llama_batch_allocr { const llama_vocab & vocab, const llama_memory_i * memory, uint32_t n_embd, + uint32_t n_seq_max, bool output_all); const llama_batch & get_batch() const; @@ -100,6 +120,7 @@ class llama_batch_allocr { const uint32_t n_pos_per_embd; uint32_t n_embd; + uint32_t n_seq_max; uint32_t n_outputs; std::array seq_id_0 = { 0 }; // default sequence id @@ -115,7 +136,7 @@ class llama_batch_allocr { using seq_cpl_t = std::vector; // helper flag to quickly determine if there are any coupled sequences in the batch - bool has_cpl; + bool has_cpl = false; std::vector seq_pos; // seq_pos[s]: the set of positions in sequence s std::vector seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1 @@ -135,20 +156,5 @@ class llama_batch_allocr { // used[i] indicates if token i has already been used in a previous ubatch std::vector used; - // llama_ubatch points to this data: - struct ubatch { - std::vector token; - std::vector embd; - std::vector pos; - std::vector n_seq_id; - std::vector seq_id; - std::vector seq_id_unq; - std::vector seq_idx; - std::vector output; - }; - - // current splitting state: - std::vector ubatches; - int debug; }; diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index cbc19d3c40c..0a96a9a579e 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -56,6 +56,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE }, { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, + { "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 }, { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, { "granite", LLM_CHAT_TEMPLATE_GRANITE }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, @@ -65,6 +66,9 @@ static const std::map LLM_CHAT_TEMPLATES = { { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, + { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE }, + { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE }, + { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -167,10 +171,13 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) { return LLM_CHAT_TEMPLATE_DEEPSEEK_3; } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) { + if (tmpl_contains("[|tool|]")) { + return LLM_CHAT_TEMPLATE_EXAONE_4; + } // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb // EXAONE-3.0-7.8B-Instruct return LLM_CHAT_TEMPLATE_EXAONE_3; - } else if (tmpl_contains("rwkv-world")) { + } else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) { return LLM_CHAT_TEMPLATE_RWKV_WORLD; } else if (tmpl_contains("<|start_of_role|>")) { return LLM_CHAT_TEMPLATE_GRANITE; @@ -186,8 +193,14 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_LLAMA4; } else if (tmpl_contains("<|endofuserprompt|>")) { return LLM_CHAT_TEMPLATE_DOTS1; - } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) { + } else if (tmpl_contains("<|extra_0|>") && 
tmpl_contains("<|extra_4|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_MOE; + } else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) { + return LLM_CHAT_TEMPLATE_OPENAI_MOE; + } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) { + return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE; + } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { + return LLM_CHAT_TEMPLATE_KIMI_K2; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -529,6 +542,22 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "[|assistant|]"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_4) { + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n"; + } else if (role == "user") { + ss << "[|user|]" << trim(message->content) << "\n"; + } else if (role == "assistant") { + ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n"; + } else if (role == "tool") { + ss << "[|tool|]" << trim(message->content) << "[|endofturn|]\n"; + } + } + if (add_ass) { + ss << "[|assistant|]"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) { // this template requires the model to have "\n\n" as EOT token for (size_t i = 0; i < chat.size(); i++) { @@ -596,8 +625,6 @@ int32_t llm_chat_apply_template( } else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) { // Yandex template ("\n\n" is defined as EOT token) - ss << ""; - for (size_t i = 0; i < chat.size(); i++) { std::string role(chat[i]->role); if (role == "user") { @@ -675,11 +702,56 @@ int32_t llm_chat_apply_template( if (role == "system") { ss << "<|startoftext|>" << message->content << "<|extra_4|>"; } else if (role == "assistant") { - ss << "<|startoftext|>" << message->content << "<|eos|>"; + ss << message->content << "<|eos|>"; } else { ss << "<|startoftext|>" << message->content << "<|extra_0|>"; } } + } else if (tmpl == LLM_CHAT_TEMPLATE_OPENAI_MOE) { + // OpenAI MoE (based on Harmony chat template) + for (auto message : chat) { + std::string role(message->role); + ss << "<|start|>" << role << "<|message|>" << message->content; + ss << (role == "assistant" ? 
"<|return|>" : "<|end|>"); + } + if (add_ass) { + ss << "<|start|>assistant"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_DENSE) { + // tencent/Hunyuan-4B-Instruct + for (size_t i = 0; i < chat.size(); i++) { + std::string role(chat[i]->role); + if (i == 0) { + if (role == "system") { + ss << chat[i]->content << "<|hy_place▁holder▁no▁3|>"; + } + } + + if (role == "assistant") { + ss << "<|hy_Assistant|>" << chat[i]->content << "<|hy_place▁holder▁no▁2|>"; + } else if (role == "user") { + ss << "<|hy_User|>" << chat[i]->content << "<|hy_Assistant|>"; + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) { + // moonshotai/Kimi-K2-Instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|im_system|>system<|im_middle|>"; + } else if (role == "user") { + ss << "<|im_user|>user<|im_middle|>"; + } else if (role == "assistant") { + ss << "<|im_assistant|>assistant<|im_middle|>"; + } else if (role == "tool") { + ss << "<|im_system|>tool<|im_middle|>"; + } + + ss << message->content << "<|im_end|>"; + } + if (add_ass) { + ss << "<|im_assistant|>assistant<|im_middle|>"; + } } else { // template not supported return -1; diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index b621fda2816..35a943856fa 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -35,6 +35,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_GLMEDGE, LLM_CHAT_TEMPLATE_MINICPM, LLM_CHAT_TEMPLATE_EXAONE_3, + LLM_CHAT_TEMPLATE_EXAONE_4, LLM_CHAT_TEMPLATE_RWKV_WORLD, LLM_CHAT_TEMPLATE_GRANITE, LLM_CHAT_TEMPLATE_GIGACHAT, @@ -45,6 +46,9 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_SMOLVLM, LLM_CHAT_TEMPLATE_DOTS1, LLM_CHAT_TEMPLATE_HUNYUAN_MOE, + LLM_CHAT_TEMPLATE_OPENAI_MOE, + LLM_CHAT_TEMPLATE_HUNYUAN_DENSE, + LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 06e93b19cbf..7d7abad5d4a 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -98,10 +98,29 @@ llama_context::llama_context( LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); cparams.n_batch = GGML_KQ_MASK_PAD; } - cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); cparams.op_offload = params.op_offload; + cparams.kv_unified = params.kv_unified; + + { + const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); + supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows; + + if (!supports_set_rows && !cparams.kv_unified) { + LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__); + cparams.kv_unified = true; + } + } + + { + const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); + graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable; + + if (graph_reuse_disable) { + LLAMA_LOG_WARN("%s: graph reuse disabled\n", __func__); + } + } const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; @@ -112,6 +131,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); + LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? 
"true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -227,8 +247,8 @@ llama_context::llama_context( LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); - // buffer used to store the computation graph and the tensor meta data - buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false)); + gf_res_prev.reset(new llm_graph_result(max_nodes)); + gf_res_reserve.reset(new llm_graph_result(max_nodes)); // TODO: move these checks to ggml_backend_sched // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary @@ -267,7 +287,7 @@ llama_context::llama_context( // reserve worst-case graph if (!hparams.vocab_only && memory) { - const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); @@ -287,7 +307,7 @@ llama_context::llama_context( cross.v_embd.clear(); - // reserve pp graph first so that buffers are only allocated once + // reserve pp (prompt processing) graph first so that buffers are only allocated once { auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); if (!gf) { @@ -298,9 +318,9 @@ llama_context::llama_context( n_nodes_pp = ggml_graph_n_nodes(gf); } - // reserve with tg graph to get the number of splits and nodes + // reserve with tg (token generation) graph to get the number of splits and nodes { - auto * gf = graph_reserve(1, 1, 1, mctx.get()); + auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get()); if (!gf) { throw std::runtime_error("failed to allocate compute tg buffers"); } @@ -311,6 +331,10 @@ llama_context::llama_context( // reserve again with pp graph to avoid ggml-alloc reallocations during inference { + // TODO: not sure if the following graph would be worster case for multi-stream KV caches: + // + // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); + // auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); @@ -388,10 +412,6 @@ ggml_backend_sched_t llama_context::get_sched() const { return sched.get(); } -ggml_context * llama_context::get_ctx_compute() const { - return ctx_compute.get(); -} - uint32_t llama_context::n_ctx() const { return cparams.n_ctx; } @@ -463,6 +483,11 @@ bool llama_context::kv_self_update(bool optimize) { } } + // reset the previous graph result to make sure that it won't be reused + // TODO: change the mctx->apply() to return information if a graph reserve is needed + // reset the graph result only if the memory module did reset the scheduler + gf_res_prev->reset(); + if (!mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__); } @@ -475,7 +500,7 @@ bool llama_context::kv_self_update(bool optimize) { throw std::runtime_error("failed to initialize memory context"); } - const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_seqs = cparams.kv_unified ? 
1 : cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); @@ -492,12 +517,16 @@ enum llama_pooling_type llama_context::pooling_type() const { } float * llama_context::get_logits() { + output_reorder(); + return logits; } float * llama_context::get_logits_ith(int32_t i) { int64_t j = -1; + output_reorder(); + try { if (logits == nullptr) { throw std::runtime_error("no logits"); } @@ -534,12 +563,16 @@ float * llama_context::get_logits_ith(int32_t i) { } float * llama_context::get_embeddings() { + output_reorder(); + return embd; } float * llama_context::get_embeddings_ith(int32_t i) { int64_t j = -1; + output_reorder(); + try { if (embd == nullptr) { throw std::runtime_error("no embeddings"); } @@ -678,38 +711,59 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } -llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { +llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { if (mctx && !mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); ret = GGML_STATUS_FAILED; return nullptr; } - auto * gf = graph_init(); - if (!gf) { - LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); - ret = GGML_STATUS_FAILED; - return nullptr; - } + auto * res = gf_res_prev.get(); + auto * gf = res->get_gf(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx); - if (!res) { - LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__); - ret = GGML_STATUS_FAILED; - return nullptr; - } + // the new graph parameters + // in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters + const auto gparams = graph_params(res, ubatch, mctx, gtype); - // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + if (!graph_reuse_disable && res->can_reuse(gparams)) { + //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); - if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); - ret = GGML_STATUS_ALLOC_FAILED; - return nullptr; + n_reused++; + } else { + res->reset(); + + ggml_backend_sched_reset(sched.get()); + ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + + //const auto t_start_us = ggml_time_us(); + + gf = model.build_graph(gparams); + + //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); + + if (!gf) { + LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); + ret = GGML_STATUS_ALLOC_FAILED; + return nullptr; + } } - res->set_inputs(&ubatch); + // set the input data for the input tensors + { + //const auto t_start_us = ggml_time_us(); - const auto status = graph_compute(gf, ubatch.n_tokens > 1); + res->set_inputs(&ubatch); + + //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); + } + + const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1); if (status != GGML_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: failed to compute
graph, compute status: %d\n", __func__, status); ret = status; @@ -731,16 +785,19 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; - const int64_t n_embd = hparams.n_embd; + const int64_t n_embd = hparams.n_embd; + const int64_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 - if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) { + if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return -1; } const uint32_t n_tokens = balloc->get_n_tokens(); + // [TAG_NO_CACHE_PAD] + // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true const llama_ubatch ubatch = balloc->split_simple(n_tokens); // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot @@ -767,9 +824,6 @@ int llama_context::encode(const llama_batch & batch_inp) { n_outputs = n_tokens; - ggml_backend_sched_reset(sched.get()); - ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); - const auto causal_attn_org = cparams.causal_attn; // always use non-causal attention for encoder graphs @@ -778,7 +832,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cparams.causal_attn = false; ggml_status status; - const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); cparams.causal_attn = causal_attn_org; @@ -791,10 +845,20 @@ int llama_context::encode(const llama_batch & batch_inp) { } } + auto * t_logits = res->get_logits(); auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + // extract logits + if (logits && t_logits) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float)); + } + // extract embeddings - if (t_embd) { + if (embd && t_embd) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -844,9 +908,11 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - // Reset state for the next token before backend sync, to allow the CPU activities in the reset to - // overlap with device computation. - ggml_backend_sched_reset(sched.get()); + if (!supports_set_rows) { + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. + ggml_backend_sched_reset(sched.get()); + } // TODO: hacky solution if (model.arch == LLM_ARCH_T5 && t_embd) { @@ -893,13 +959,13 @@ int llama_context::decode(const llama_batch & batch_inp) { const auto & vocab = model.vocab; const auto & hparams = model.hparams; - const int32_t n_vocab = vocab.n_tokens(); + const int64_t n_vocab = vocab.n_tokens(); const int64_t n_embd = hparams.n_embd; // when computing embeddings, all tokens are output const bool output_all = cparams.embeddings; - if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) { + if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? 
LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return -1; } @@ -927,6 +993,7 @@ int llama_context::decode(const llama_batch & batch_inp) { // TODO: this clear of the buffer can easily be forgotten - need something better embd_seq.clear(); + output_swaps.clear(); bool did_optimize = false; @@ -1005,11 +1072,8 @@ int llama_context::decode(const llama_batch & batch_inp) { n_outputs = n_outputs_new; } - ggml_backend_sched_reset(sched.get()); - ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); - ggml_status status; - const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache @@ -1149,9 +1213,6 @@ int llama_context::decode(const llama_batch & batch_inp) { // make the outputs have the same order they had in the user-provided batch // note: this is mostly relevant for recurrent models atm if (!sorted_output) { - const uint32_t n_vocab = model.vocab.n_tokens(); - const uint64_t n_embd = model.hparams.n_embd; - GGML_ASSERT((size_t) n_outputs == out_ids.size()); // TODO: is there something more efficient which also minimizes swaps? @@ -1167,16 +1228,9 @@ int llama_context::decode(const llama_batch & batch_inp) { continue; } std::swap(out_ids[i], out_ids[j_min]); - if (logits_size > 0) { - for (uint32_t k = 0; k < n_vocab; k++) { - std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); - } - } - if (embd_size > 0) { - for (uint32_t k = 0; k < n_embd; k++) { - std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]); - } - } + + // remember the swaps and apply them lazily upon logits/embeddings access + output_swaps.push_back({ i, j_min }); } std::fill(output_ids.begin(), output_ids.end(), -1); @@ -1190,9 +1244,11 @@ int llama_context::decode(const llama_batch & batch_inp) { // wait for the computation to finish (automatically done when obtaining the model output) //synchronize(); - // Reset state for the next token before backend sync, to allow the CPU activities in the reset to - // overlap with device computation. - ggml_backend_sched_reset(sched.get()); + if (!supports_set_rows) { + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. 
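The decode path above now only records which output rows have to be exchanged to restore the order of the user-provided batch; the swaps are applied lazily the next time logits or embeddings are accessed (see output_reorder() below). A minimal standalone sketch of the same record-then-apply pattern, using hypothetical flat buffers instead of the actual context members:

    #include <cstdint>
    #include <utility>  // std::swap
    #include <vector>

    // record swaps while sorting, apply them once on first access, then forget them
    struct swap_info { uint32_t i0, i1; };

    static void apply_swaps(std::vector<float> & rows, size_t row_size, std::vector<swap_info> & swaps) {
        for (const auto & s : swaps) {
            for (size_t k = 0; k < row_size; ++k) {
                std::swap(rows[s.i0*row_size + k], rows[s.i1*row_size + k]);
            }
        }
        swaps.clear(); // subsequent reads see the already-reordered buffer at no extra cost
    }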
+ ggml_backend_sched_reset(sched.get()); + } return 0; } @@ -1271,24 +1327,40 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { return n_outputs_max; } +void llama_context::output_reorder() { + const uint64_t n_vocab = model.vocab.n_tokens(); + const uint64_t n_embd = model.hparams.n_embd; + + for (size_t s = 0; s < output_swaps.size(); ++s) { + const uint64_t i0 = output_swaps[s].i0; + const uint64_t i1 = output_swaps[s].i1; + + if (logits_size > 0) { + for (uint64_t k = 0; k < n_vocab; k++) { + std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]); + } + } + + if (embd_size > 0) { + for (uint64_t k = 0; k < n_embd; k++) { + std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); + } + } + } + + output_swaps.clear(); +} + // // graph // -int32_t llama_context::graph_max_nodes() const { - return std::max(65536, 5*model.n_tensors()); +uint32_t llama_context::graph_max_nodes() const { + return std::max(1024u, 8u*model.n_tensors()); } -ggml_cgraph * llama_context::graph_init() { - ggml_init_params params = { - /*.mem_size =*/ buf_compute_meta.size(), - /*.mem_buffer =*/ buf_compute_meta.data(), - /*.no_alloc =*/ true, - }; - - ctx_compute.reset(ggml_init(params)); - - return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false); +llm_graph_result * llama_context::get_gf_res_reserve() const { + return static_cast<llm_graph_result *>(gf_res_reserve.get()); } ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) { @@ -1301,6 +1373,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); } + ggml_backend_sched_reset(sched.get()); + + // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that + gf_res_prev->reset(); + // store the n_outputs as it is, and restore it afterwards // TODO: not sure if needed, might simplify in the future by removing this const auto save_n_outputs = this->n_outputs; @@ -1310,17 +1387,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u llama_batch_allocr balloc(model.hparams.n_pos_per_embd()); llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs); - auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx); + auto * res = gf_res_reserve.get(); - this->n_outputs = save_n_outputs; + const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); - if (!res) { - LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__); - return nullptr; - } + res->reset(); - ggml_backend_sched_reset(sched.get()); + auto * gf = model.build_graph(gparams); + + this->n_outputs = save_n_outputs; // initialize scheduler with the specified graph if (!ggml_backend_sched_reserve(sched.get(), gf)) { @@ -1331,28 +1406,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u return gf; } -llm_graph_result_ptr llama_context::graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype, - const llama_memory_context_i * mctx) { - return model.build_graph( - { - /*.ctx =*/ ctx, - /*.arch =*/ model.arch, - /*.hparams =*/ model.hparams, - /*.cparams =*/ cparams, - /*.ubatch =*/ ubatch, - /*.sched =*/ sched.get(), - /*.backend_cpu =*/ backend_cpu, - /*.cvec =*/ &cvec, - /*.loras =*/
&loras, - /*.mctx =*/ mctx, - /*.cross =*/ &cross, - /*.n_outputs =*/ n_outputs, - /*.cb =*/ graph_get_cb(), - }, gf, gtype); +llm_graph_params llama_context::graph_params( + llm_graph_result * res, + const llama_ubatch & ubatch, + const llama_memory_context_i * mctx, + llm_graph_type gtype) const { + return { + /*.arch =*/ model.arch, + /*.hparams =*/ model.hparams, + /*.cparams =*/ cparams, + /*.ubatch =*/ ubatch, + /*.gtype =*/ gtype, + /*.sched =*/ sched.get(), + /*.backend_cpu =*/ backend_cpu, + /*.cvec =*/ &cvec, + /*.loras =*/ &loras, + /*.mctx =*/ mctx, + /*.cross =*/ &cross, + /*.n_outputs =*/ n_outputs, + /*.cb =*/ graph_get_cb(), + /*.res =*/ res, + }; } ggml_status llama_context::graph_compute( @@ -1583,30 +1657,30 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) { } } -size_t llama_context::state_seq_get_size(llama_seq_id seq_id) { +size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) { llama_io_write_dummy io; try { - return state_seq_write_data(io, seq_id); + return state_seq_write_data(io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); return 0; } } -size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) { +size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) { llama_io_write_buffer io(dst, size); try { - return state_seq_write_data(io, seq_id); + return state_seq_write_data(io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); return 0; } } -size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) { +size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) { llama_io_read_buffer io(src, size); try { - return state_seq_read_data(io, seq_id); + return state_seq_read_data(io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); return 0; @@ -1704,7 +1778,7 @@ size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * file { const size_t state_size = file.size() - file.tell(); llama_io_read_file io(&file); - const size_t nread = state_seq_read_data(io, seq_id); + const size_t nread = state_seq_read_data(io, seq_id, 0); if (!nread) { LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); return 0; @@ -1728,7 +1802,7 @@ size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * file // save the context state using stream saving llama_io_write_file io(&file); - state_seq_write_data(io, seq_id); + state_seq_write_data(io, seq_id, 0); const size_t res = file.tell(); GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes()); @@ -1897,21 +1971,21 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { return io.n_bytes(); } -size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { +size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { GGML_UNUSED(seq_id); if (memory) { - memory->state_write(io, seq_id); + memory->state_write(io, seq_id, flags); } return io.n_bytes(); } -size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) { +size_t 
llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { GGML_UNUSED(seq_id); if (memory) { - memory->state_read(io, seq_id); + memory->state_read(io, seq_id, flags); } return io.n_bytes(); @@ -1930,6 +2004,7 @@ llama_perf_context_data llama_context::perf_get_data() const { data.t_eval_ms = 1e-3 * t_eval_us; data.n_p_eval = std::max(1, n_p_eval); data.n_eval = std::max(1, n_eval); + data.n_reused = std::max(0, n_reused); return data; } @@ -1938,6 +2013,7 @@ void llama_context::perf_reset() { t_start_us = ggml_time_us(); t_eval_us = n_eval = 0; t_p_eval_us = n_p_eval = 0; + n_reused = 0; } // @@ -1972,7 +2048,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params opt_params.opt_period = n_batch / n_ubatch; opt_params.get_opt_pars = lopt_params.get_opt_pars; opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud; - + opt_params.optimizer = lopt_params.optimizer_type; opt_ctx = ggml_opt_init(opt_params); llama_opt_param_filter param_filter = lopt_params.param_filter; @@ -2028,7 +2104,7 @@ void llama_context::opt_epoch_iter( batch.logits [pos_batch] = true; } - if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) { + if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return; } @@ -2064,8 +2140,13 @@ void llama_context::opt_epoch_iter( break; } - auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get()); + auto * res = gf_res_prev.get(); + + const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); + + res->reset(); + + auto * gf = model.build_graph(gparams); struct ggml_context * ctx_compute_opt; { @@ -2187,6 +2268,7 @@ llama_context_params llama_context_default_params() { /*.no_perf =*/ true, /*.op_offload =*/ true, /*.swa_full =*/ true, + /*.kv_unified =*/ false, }; return result; @@ -2719,19 +2801,31 @@ bool llama_state_save_file(llama_context * ctx, const char * path_session, const } size_t llama_state_seq_get_size(llama_context * ctx, llama_seq_id seq_id) { - return ctx->state_seq_get_size(seq_id); + return llama_state_seq_get_size_ext(ctx, seq_id, 0); } size_t llama_state_seq_get_data(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) { + return llama_state_seq_get_data_ext(ctx, dst, size, seq_id, 0); +} + +size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) { + return llama_state_seq_set_data_ext(ctx, src, size, seq_id, 0); +} + +size_t llama_state_seq_get_size_ext(llama_context * ctx, llama_seq_id seq_id, llama_state_seq_flags flags) { + return ctx->state_seq_get_size(seq_id, flags); +} + +size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { ctx->synchronize(); - return ctx->state_seq_get_data(seq_id, dst, size); + return ctx->state_seq_get_data(seq_id, dst, size, flags); } -size_t llama_state_seq_set_data(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) { +size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { ctx->synchronize(); - return ctx->state_seq_set_data(seq_id, src, size); + return ctx->state_seq_set_data(seq_id, src, size, flags); } size_t 
llama_state_seq_save_file(llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { @@ -2807,6 +2901,7 @@ void llama_perf_context_print(const llama_context * ctx) { LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval); LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval)); + LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused); } void llama_perf_context_reset(llama_context * ctx) { diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index 9ce05715a8c..230ef8962b8 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -35,8 +35,6 @@ struct llama_context { ggml_backend_sched_t get_sched() const; - ggml_context * get_ctx_compute() const; - uint32_t n_ctx() const; uint32_t n_ctx_per_seq() const; uint32_t n_batch() const; @@ -96,7 +94,7 @@ struct llama_context { // if memory_context is provided, it will be applied first to the context's memory // ret contains the status of the graph computation // returns nullptr only if ret != GGML_STATUS_SUCCESS - llm_graph_result_ptr process_ubatch( + llm_graph_result * process_ubatch( const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, @@ -113,9 +111,9 @@ struct llama_context { size_t state_get_data( uint8_t * dst, size_t size); size_t state_set_data(const uint8_t * src, size_t size); - size_t state_seq_get_size(llama_seq_id seq_id); - size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size); - size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size); + size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags); + size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags); + size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags); bool state_load_file( const char * filepath, @@ -154,6 +152,7 @@ struct llama_context { void opt_init(struct llama_model * model, struct llama_opt_params lopt_params); + // TODO: more flexible combinations of logical/physical batch size and context size void opt_epoch( ggml_opt_dataset_t dataset, ggml_opt_result_t result_train, @@ -183,15 +182,17 @@ struct llama_context { // Returns max number of outputs for which space was reserved. 
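The sequence-state entry points now carry a llama_state_seq_flags argument; the old names simply forward with flags = 0, so existing callers keep their behavior. A caller-side sketch of the new _ext functions, assuming a valid llama_context * ctx and llama_seq_id seq_id (the helper names are hypothetical):

    #include <cstdint>
    #include <vector>
    #include "llama.h"

    // snapshot the state of one sequence; flags = 0 keeps the previous semantics
    static std::vector<uint8_t> save_seq(llama_context * ctx, llama_seq_id seq_id) {
        std::vector<uint8_t> buf(llama_state_seq_get_size_ext(ctx, seq_id, 0));
        buf.resize(llama_state_seq_get_data_ext(ctx, buf.data(), buf.size(), seq_id, 0));
        return buf;
    }

    // restore it later, possibly into another context created from the same model
    static bool restore_seq(llama_context * ctx, llama_seq_id seq_id, const std::vector<uint8_t> & buf) {
        return llama_state_seq_set_data_ext(ctx, buf.data(), buf.size(), seq_id, 0) == buf.size();
    }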
uint32_t output_reserve(int32_t n_outputs); + void output_reorder(); + // // graph // public: - int32_t graph_max_nodes() const; + uint32_t graph_max_nodes() const; - // zero-out inputs and create the ctx_compute for the compute graph - ggml_cgraph * graph_init(); + // can reuse the llm_graph_result instance of the context (for example to update a memory module) + llm_graph_result * get_gf_res_reserve() const; // returns the result of ggml_backend_sched_graph_compute_async execution ggml_status graph_compute(ggml_cgraph * gf, bool batched); @@ -200,12 +201,11 @@ struct llama_context { ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx); private: - llm_graph_result_ptr graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype, - const llama_memory_context_i * mctx); + llm_graph_params graph_params( + llm_graph_result * res, + const llama_ubatch & ubatch, + const llama_memory_context_i * mctx, + llm_graph_type gtype) const; llm_graph_cb graph_get_cb() const; @@ -213,8 +213,8 @@ struct llama_context { size_t state_write_data(llama_io_write_i & io); size_t state_read_data (llama_io_read_i & io); - size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id); - size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id); + size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags); + size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags); // // members @@ -253,13 +253,18 @@ struct llama_context { std::vector output_ids; // map batch token positions to ids of the logits and embd buffers + struct swap_info { + uint32_t i0; + uint32_t i1; + }; + + std::vector output_swaps; + ggml_backend_sched_ptr sched; ggml_backend_t backend_cpu = nullptr; std::vector backends; - ggml_context_ptr ctx_compute; - // training ggml_opt_context_t opt_ctx = nullptr; @@ -275,14 +280,21 @@ struct llama_context { std::vector backend_ptrs; std::vector backend_buft; - // memory buffers used to evaluate the model - std::vector buf_compute_meta; + llm_graph_result_ptr gf_res_prev; + llm_graph_result_ptr gf_res_reserve; // host buffer for the model output (logits and embeddings) ggml_backend_buffer_ptr buf_output; bool has_evaluated_once = false; + // env: LLAMA_SET_ROWS (temporary) + // ref: https://github.com/ggml-org/llama.cpp/pull/14285 + bool supports_set_rows = true; + + // env: LLAMA_GRAPH_REUSE_DISABLE + bool graph_reuse_disable = false; + // perf mutable int64_t t_start_us = 0; mutable int64_t t_load_us = 0; @@ -294,4 +306,6 @@ struct llama_context { mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) mutable int32_t n_eval = 0; // number of eval calls + + mutable int32_t n_reused = 0; // number of times the previous graph was reused }; diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index 118615d5bd2..38750affc50 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -11,8 +11,8 @@ struct llama_cparams { uint32_t n_batch; uint32_t n_ubatch; uint32_t n_seq_max; - int n_threads; // number of threads to use for generation - int n_threads_batch; // number of threads to use for batch processing + int32_t n_threads; // number of threads to use for generation + int32_t n_threads_batch; // number of threads to use for batch processing float rope_freq_base; float 
rope_freq_scale; @@ -33,6 +33,7 @@ struct llama_cparams { bool no_perf; bool warmup; bool op_offload; + bool kv_unified; enum llama_pooling_type pooling_type; diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index a248a7ec223..053c72d6dc8 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -28,6 +28,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { } } +bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) { + bool res = true; + + res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); + res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens); + + return res; +} + void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) { if (ubatch->pos && pos) { const int64_t n_tokens = ubatch->n_tokens; @@ -50,6 +59,14 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) { } } +bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) { + bool res = true; + + res &= pos->ne[0] == params.ubatch.n_tokens; + + return res; +} + void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) { if (ubatch->pos && attn_scale) { const int64_t n_tokens = ubatch->n_tokens; @@ -71,7 +88,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { const int64_t n_tokens = ubatch->n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer)); - GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing + GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing int32_t * data = (int32_t *) pos_bucket->data; @@ -118,6 +135,14 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) { } } +bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) { + bool res = true; + + res &= n_outputs == params.n_outputs; + + return res; +} + void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { const int64_t n_tokens = ubatch->n_tokens; @@ -163,38 +188,23 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { const int64_t n_tokens = ubatch->n_tokens; - const int64_t n_seq_tokens = ubatch->n_seq_tokens; const int64_t n_seqs_unq = ubatch->n_seqs_unq; if (cparams.embeddings && ( - cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || - cparams.pooling_type == LLAMA_POOLING_TYPE_RANK - )) { + cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK || + cparams.pooling_type == LLAMA_POOLING_TYPE_LAST + )) { GGML_ASSERT(cls); GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer)); uint32_t * data = (uint32_t *) cls->data; memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls)); - for (int i = 0; i < n_tokens; i += n_seq_tokens) { - for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { - const llama_seq_id seq_id = ubatch->seq_id[i][s]; - const int32_t seq_idx = ubatch->seq_idx[seq_id]; - - data[seq_idx] = i; - } - } - } - - if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { - GGML_ASSERT(cls); - GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer)); + std::vector target_pos(n_seqs_unq, -1); + std::vector target_row(n_seqs_unq, -1); - uint32_t * data = (uint32_t *) cls->data; - memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls)); - - std::vector last_pos(n_seqs_unq, -1); 
- std::vector last_row(n_seqs_unq, -1); + bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST; for (int i = 0; i < n_tokens; ++i) { const llama_pos pos = ubatch->pos[i]; @@ -203,16 +213,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { const llama_seq_id seq_id = ubatch->seq_id[i][s]; const int32_t seq_idx = ubatch->seq_idx[seq_id]; - if (pos >= last_pos[seq_idx]) { - last_pos[seq_idx] = pos; - last_row[seq_idx] = i; + if ( + (target_pos[seq_idx] == -1) || + ( last && pos >= target_pos[seq_idx]) || + (!last && pos < target_pos[seq_idx]) + ) { + target_pos[seq_idx] = pos; + target_row[seq_idx] = i; } } } for (int s = 0; s < n_seqs_unq; ++s) { - if (last_row[s] >= 0) { - data[s] = last_row[s]; + if (target_row[s] >= 0) { + data[s] = target_row[s]; } } } @@ -287,6 +301,24 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } +bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= self_kq_mask->ne[0] == mctx->get_n_kv(); + res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + + res &= mctx->get_supports_set_rows(); // TODO: tmp + + return res; +} + void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); @@ -299,6 +331,30 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } +bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv(); + res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + + res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv(); + res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + + res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp + + return res; +} + void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { GGML_ASSERT(cross_kq_mask); @@ -306,7 +362,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { const int64_t n_tokens = ubatch->n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer)); - GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing + GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing float * data = (float *) cross_kq_mask->data; @@ -340,6 +396,91 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { inp_rs->set_input(ubatch); } +// +// llm_graph_result +// + 
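The graph-result machinery that follows reads its debug level from the environment, complementing the reuse toggles read in the llama_context constructor earlier in this patch. A hypothetical way to exercise these switches from a host program before the context is created (plain POSIX setenv, not part of the patch):

    #include <cstdlib>  // setenv (POSIX)

    // rebuild the compute graph for every ubatch instead of reusing the previous one
    setenv("LLAMA_GRAPH_REUSE_DISABLE", "1", 1);
    // make llm_graph_result::can_reuse() log its per-input decisions
    setenv("LLAMA_GRAPH_RESULT_DEBUG", "2", 1);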
+llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) { + reset(); + + const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG"); + debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0; +} + +int64_t llm_graph_result::get_max_nodes() const { + return max_nodes; +} + +void llm_graph_result::reset() { + t_tokens = nullptr; + t_logits = nullptr; + t_embd = nullptr; + t_embd_pooled = nullptr; + + params = {}; + + inputs.clear(); + + buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false)); + + ggml_init_params params = { + /*.mem_size =*/ buf_compute_meta.size(), + /*.mem_buffer =*/ buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + + ctx_compute.reset(ggml_init(params)); + + gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false); +} + +void llm_graph_result::set_inputs(const llama_ubatch * ubatch) { + for (auto & input : inputs) { + input->set_input(ubatch); + } +} + +bool llm_graph_result::can_reuse(const llm_graph_params & params) { + if (!this->params.allow_reuse(params)) { + if (debug > 1) { + LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__); + } + + return false; + } + + if (debug > 1) { + LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size()); + } + + bool res = true; + + for (auto & input : inputs) { + const bool cur = input->can_reuse(params); + + if (debug > 1) { + LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur); + } + + res = res && cur; + } + + if (debug > 0) { + LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res); + } + + return res; +} + +llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) { + inputs.emplace_back(std::move(input)); + return inputs.back().get(); +} + +void llm_graph_result::set_params(const llm_graph_params & params) { + this->params = params; +} + // // llm_graph_context // @@ -374,7 +515,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : n_ctx_orig (cparams.n_ctx_orig_yarn), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), - ctx0 (params.ctx), sched (params.sched), backend_cpu (params.backend_cpu), cvec (params.cvec), @@ -382,7 +522,10 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : mctx (params.mctx), cross (params.cross), cb_func (params.cb), - res (std::make_unique()) { + res (params.res), + ctx0 (res->get_ctx()), + gf (res->get_gf()) { + res->set_params(params); } void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const { @@ -597,6 +740,8 @@ ggml_tensor * llm_graph_context::build_ffn( cur = ggml_reglu(ctx0, cur); cb(cur, "ffn_reglu", il); } break; + default: + GGML_ABORT("fatal error"); } if (gate && type_gate == LLM_FFN_PAR) { @@ -606,8 +751,8 @@ ggml_tensor * llm_graph_context::build_ffn( if (down) { cur = build_lora_mm(down, cur); - if (arch == LLM_ARCH_GLM4) { - // GLM4 seems to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } @@ -642,13 +787,64 @@ ggml_tensor * llm_graph_context::build_moe_ffn( bool scale_w, float w_scale, llama_expert_gating_func_type gating_op, - int il) const { + int il, + ggml_tensor * probs_in) const { + return build_moe_ffn( + cur, + gate_inp, /* gate_inp_b */ nullptr, + up_exps, /* up_exps_b */ nullptr, 
+ gate_exps, /* gate_exps_b */ nullptr, + down_exps, /* down_exps_b */ nullptr, + exp_probs_b, + n_expert, + n_expert_used, + type_op, + norm_w, + scale_w, + w_scale, + gating_op, + il, + probs_in + ); +} + +ggml_tensor * llm_graph_context::build_moe_ffn( + ggml_tensor * cur, + ggml_tensor * gate_inp, + ggml_tensor * gate_inp_b, + ggml_tensor * up_exps, + ggml_tensor * up_exps_b, + ggml_tensor * gate_exps, + ggml_tensor * gate_exps_b, + ggml_tensor * down_exps, + ggml_tensor * down_exps_b, + ggml_tensor * exp_probs_b, + int64_t n_expert, + int64_t n_expert_used, + llm_ffn_op_type type_op, + bool norm_w, + bool scale_w, + float w_scale, + llama_expert_gating_func_type gating_op, + int il, + ggml_tensor * probs_in) const { const int64_t n_embd = cur->ne[0]; const int64_t n_tokens = cur->ne[1]; const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN - ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens] - cb(logits, "ffn_moe_logits", il); + ggml_tensor * logits = nullptr; + + if (probs_in == nullptr) { + logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens] + cb(logits, "ffn_moe_logits", il); + } else { + logits = probs_in; + } + + if (gate_inp_b) { + logits = ggml_add(ctx0, logits, gate_inp_b); + cb(logits, "ffn_moe_logits_biased", il); + } ggml_tensor * probs = nullptr; switch (gating_op) { @@ -660,6 +856,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn( { probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens] } break; + case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT: + { + probs = logits; // [n_expert, n_tokens] + } break; default: GGML_ABORT("fatal error"); } @@ -688,6 +888,13 @@ ggml_tensor * llm_graph_context::build_moe_ffn( ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] cb(weights, "ffn_moe_weights", il); + if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) { + weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); + weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens] + weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens); + cb(weights, "ffn_moe_weights_softmax", il); + } + if (norm_w) { weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); @@ -716,6 +923,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn( ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] cb(up, "ffn_moe_up", il); + if (up_exps_b) { + up = ggml_add_id(ctx0, up, up_exps_b, selected_experts); + cb(up, "ffn_moe_up_biased", il); + } + ggml_tensor * experts = nullptr; if (gate_exps) { cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] @@ -724,6 +936,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cur = up; } + if (gate_exps_b) { + cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts); + cb(cur, "ffn_moe_gate_biased", il); + } + switch (type_op) { case LLM_FFN_SILU: if (gate_exps) { @@ -741,6 +958,22 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cur = ggml_gelu(ctx0, cur); cb(cur, "ffn_moe_gelu", il); } break; + case LLM_FFN_SWIGLU_OAI_MOE: + { + // TODO: move to hparams? 
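The constants that follow match the gpt-oss release: 1.702 is the usual coefficient for approximating GELU with x*sigmoid(alpha*x), and the limit clamps the activations before the linear branch is applied. As a scalar illustration only (the real math lives in ggml_swiglu_oai; the argument roles here are an assumption based on the gpt-oss reference implementation):

    #include <algorithm>  // std::min, std::max
    #include <cmath>      // std::exp

    // reference-style clamped SwiGLU:
    //   glu = min(g, limit), lin = clamp(u, -limit, limit), out = glu * sigmoid(alpha*glu) * (lin + 1)
    static float swiglu_oai_ref(float g, float u, float alpha = 1.702f, float limit = 7.0f) {
        const float glu = std::min(g, limit);
        const float lin = std::max(std::min(u, limit), -limit);
        return glu / (1.0f + std::exp(-alpha*glu)) * (lin + 1.0f);
    }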
+ constexpr float alpha = 1.702f; + constexpr float limit = 7.0f; + cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit); + cb(cur, "ffn_moe_swiglu_oai", il); + } break; + case LLM_FFN_RELU: + if (gate_exps) { + cur = ggml_reglu_split(ctx0, cur, up); + cb(cur, "ffn_moe_reglu", il); + } else { + cur = ggml_relu(ctx0, cur); + cb(cur, "ffn_moe_relu", il); + } break; default: GGML_ABORT("fatal error"); } @@ -748,25 +981,38 @@ ggml_tensor * llm_graph_context::build_moe_ffn( experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); + if (down_exps_b) { + experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts); + cb(experts, "ffn_moe_down_biased", il); + } + if (!weight_before_ffn) { experts = ggml_mul(ctx0, experts, weights); cb(cur, "ffn_moe_weighted", il); } + ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr }; + + assert(n_expert_used > 0); + + // order the views before the adds + for (uint32_t i = 0; i < hparams.n_expert_used; ++i) { + cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]); + + ggml_build_forward_expand(gf, cur_experts[i]); + } + // aggregate experts - ggml_tensor * moe_out = nullptr; - for (int i = 0; i < n_expert_used; ++i) { - ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens, - experts->nb[2], i*experts->nb[1]); + // note: here we explicitly use hparams.n_expert_used instead of n_expert_used + // to avoid potentially a large number of add nodes during warmup + // ref: https://github.com/ggml-org/llama.cpp/pull/14753 + ggml_tensor * moe_out = cur_experts[0]; - if (i == 0) { - moe_out = cur_expert; - } else { - moe_out = ggml_add(ctx0, moe_out, cur_expert); - } + for (uint32_t i = 1; i < hparams.n_expert_used; ++i) { + moe_out = ggml_add(ctx0, moe_out, cur_experts[i]); } - if (n_expert_used == 1) { + if (hparams.n_expert_used == 1) { // avoid returning a non-contiguous tensor moe_out = ggml_cont(ctx0, moe_out); } @@ -972,23 +1218,26 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t } ggml_tensor * llm_graph_context::build_attn_mha( - ggml_cgraph * gf, ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, ggml_tensor * kq_b, ggml_tensor * kq_mask, ggml_tensor * v_mla, + ggml_tensor * sinks, float kq_scale) const { const bool v_trans = v->nb[1] > v->nb[2]; + // split the batch into streams if needed + const auto n_stream = k->ne[3]; + + q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream); + q = ggml_permute(ctx0, q, 0, 2, 1, 3); k = ggml_permute(ctx0, k, 0, 2, 1, 3); v = ggml_permute(ctx0, v, 0, 2, 1, 3); - const auto n_tokens = q->ne[1]; - const auto n_head = q->ne[2]; - const auto n_kv = k->ne[1]; + const auto n_kv = k->ne[1]; ggml_tensor * cur; @@ -1012,7 +1261,8 @@ ggml_tensor * llm_graph_context::build_attn_mha( cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? 
hparams.f_attn_logit_softcapping : 0.0f); - ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); + ggml_flash_attn_ext_add_sinks(cur, sinks); + ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32); if (v_mla) { #if 0 @@ -1030,7 +1280,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( #endif } - cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); } else { ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); @@ -1060,6 +1310,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( } kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + ggml_soft_max_add_sinks(kq, sinks); if (!v_trans) { // note: avoid this branch @@ -1075,7 +1326,8 @@ ggml_tensor * llm_graph_context::build_attn_mha( cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + // recombine streams + cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); if (!cparams.offload_kqv) { // all nodes between the KV store and the attention output are run on the CPU @@ -1102,7 +1354,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_no_cache * inp, - ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, @@ -1122,11 +1373,15 @@ ggml_tensor * llm_graph_context::build_attn( const auto & kq_mask = inp->get_kq_mask(); + // [TAG_NO_CACHE_PAD] + // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams + assert(!ubatch.equal_seqs()); + ggml_tensor * q = q_cur; ggml_tensor * k = k_cur; ggml_tensor * v = v_cur; - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1156,13 +1411,14 @@ static std::unique_ptr build_attn_inp_kv_unifie { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); - const auto n_kv = mctx_cur->get_n_kv(); + const auto n_kv = mctx_cur->get_n_kv(); const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1181,7 +1437,6 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv_unified * inp, - ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, @@ -1214,13 +1469,13 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = mctx_cur->get_k(ctx0, il); ggml_tensor * v = mctx_cur->get_v(ctx0, il); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale); cb(cur, "kqv_out", il); if (wo) { cur = build_lora_mm(wo, cur); - if (arch == LLM_ARCH_GLM4) { - // GLM4 seems to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } @@ -1234,7 +1489,6 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv_unified_iswa * inp, - ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, @@ -1244,6 +1498,32 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v_mla, float kq_scale, int il) const { + return build_attn_with_sinks( + inp, + wo, + wo_b, + q_cur, + k_cur, + v_cur, + kq_b, + v_mla, + nullptr, + kq_scale, + il); +} + +ggml_tensor * llm_graph_context::build_attn_with_sinks( + llm_graph_input_attn_kv_unified_iswa * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * v_mla, + ggml_tensor * sinks, + float kq_scale, + int il) const { // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced ggml_build_forward_expand(gf, q_cur); @@ -1281,7 +1561,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = mctx_cur->get_k(ctx0, il); ggml_tensor * v = mctx_cur->get_v(ctx0, il); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, sinks, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1314,7 +1594,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_cross * inp, - ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, @@ -1336,7 +1615,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = k_cur; ggml_tensor * v = v_cur; - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1362,13 +1641,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif auto inp = std::make_unique(hparams, cparams, mctx_cur); + const auto n_stream = cparams.kv_unified ? 
1 : ubatch.n_seqs_unq; + { const auto n_kv = mctx_cur->get_base()->get_n_kv(); inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1382,7 +1663,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); ggml_set_input(inp->self_kq_mask_swa); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; @@ -1392,18 +1673,18 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif } ggml_tensor * llm_graph_context::build_rs( - ggml_cgraph * gf, ggml_tensor * s, - ggml_tensor * state_copy, + ggml_tensor * state_copy_main, + ggml_tensor * state_copy_extra, int32_t state_size, int32_t n_seqs, - uint32_t n_kv, - uint32_t kv_head, - uint32_t kv_size, + uint32_t n_rs, + uint32_t rs_head, + uint32_t rs_size, int32_t rs_zero, const llm_graph_get_rows_fn & get_state_rows) const { - ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size); + ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size); // Clear a single state which will then be copied to the other cleared states. // Note that this is a no-op when the view is zero-sized. 
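The mask-shape change above is easier to read with concrete numbers. The following is a minimal standalone sketch, not part of the patch: the pad() and mask_ne() helpers and the GGML_KQ_MASK_PAD value of 64 are assumptions for illustration. With the unified KV cache there is a single stream covering the whole batch, while with kv_unified disabled each unique sequence gets its own stream and its own slice of the padded token dimension.

// Illustrative sketch only: mirrors the padding logic used for self_kq_mask above.
#include <array>
#include <cstdint>
#include <cstdio>

constexpr int64_t GGML_KQ_MASK_PAD = 64; // assumed value, for illustration

// same rounding as GGML_PAD: round x up to a multiple of n
constexpr int64_t pad(int64_t x, int64_t n) {
    return ((x + n - 1) / n) * n;
}

// ne[] of self_kq_mask: {n_kv, padded tokens per stream, 1, n_stream}
std::array<int64_t, 4> mask_ne(int64_t n_kv, int64_t n_tokens, int64_t n_stream) {
    return { n_kv, pad(n_tokens / n_stream, GGML_KQ_MASK_PAD), 1, n_stream };
}

int main() {
    const auto unified = mask_ne(1024, 8, 1); // kv_unified: one stream  -> {1024, 64, 1, 1}
    const auto split   = mask_ne(1024, 8, 4); // 4 sequences, 4 streams  -> {1024, 64, 1, 4}

    std::printf("unified: %lld x %lld x %lld x %lld\n",
                (long long) unified[0], (long long) unified[1], (long long) unified[2], (long long) unified[3]);
    std::printf("split  : %lld x %lld x %lld x %lld\n",
                (long long) split[0], (long long) split[1], (long long) split[2], (long long) split[3]);
    return 0;
}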
@@ -1411,60 +1692,65 @@ ggml_tensor * llm_graph_context::build_rs( ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0)); // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv - // {state_size, kv_size} -> {state_size, n_seqs} - ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); + // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs + // {state_size, rs_size} -> {state_size, n_seqs} + ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main); ggml_build_forward_expand(gf, output_states); - // copy extra states which won't be changed further (between n_seqs and n_kv) - ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); + // copy extra states which won't be changed further (between n_seqs and n_rs) + ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra); ggml_build_forward_expand(gf, ggml_cpy(ctx0, states_extra, - ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s)))); + ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s)))); return output_states; } static std::unique_ptr build_rs_inp_impl( ggml_context * ctx0, + const llama_ubatch & ubatch, const llama_memory_recurrent_context * mctx_cur) { auto inp = std::make_unique(mctx_cur); - const auto n_rs = mctx_cur->get_n_rs(); + const int64_t n_rs = mctx_cur->get_n_rs(); + const int64_t n_seqs = ubatch.n_seqs; inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); ggml_set_input(inp->s_copy); + inp->s_copy_main = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0); + inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]); + return inp; } llm_graph_input_rs * llm_graph_context::build_rs_inp() const { const auto * mctx_cur = static_cast(mctx); - auto inp = build_rs_inp_impl(ctx0, mctx_cur); + auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur); return (llm_graph_input_rs *) res->add_input(std::move(inp)); } ggml_tensor * llm_graph_context::build_rs( llm_graph_input_rs * inp, - ggml_cgraph * gf, ggml_tensor * s, int32_t state_size, int32_t n_seqs, const llm_graph_get_rows_fn & get_state_rows) const { const auto * kv_state = inp->mctx; - return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows); + return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs, + kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), + get_state_rows); } ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( llm_graph_input_rs * inp, - ggml_cgraph * gf, const llama_ubatch & ubatch, - int il) const { + int il) const { const auto * mctx_cur = static_cast(mctx); const auto token_shift_count = hparams.token_shift_count; @@ -1474,7 +1760,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_tensor * token_shift_all = mctx_cur->get_r_l(il); ggml_tensor * token_shift = build_rs( - inp, gf, token_shift_all, + inp, token_shift_all, hparams.n_embd_r(), n_seqs); token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs); @@ -1505,7 +1791,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( llm_graph_input_mem_hybrid * 
llm_graph_context::build_inp_mem_hybrid() const { const auto * mctx_cur = static_cast(mctx); - auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr()); + auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); auto inp = std::make_unique(std::move(inp_attn), std::move(inp_rs), mctx_cur); @@ -1514,7 +1800,6 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { } void llm_graph_context::build_pooling( - ggml_cgraph * gf, ggml_tensor * cls, ggml_tensor * cls_b, ggml_tensor * cls_out, diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index fbf8e288956..6ff49de3a1c 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -1,6 +1,7 @@ #pragma once #include "llama-arch.h" +#include "llama-batch.h" #include "llama-hparams.h" #include "llama-adapter.h" @@ -14,7 +15,6 @@ struct ggml_cgraph; struct ggml_context; struct ggml_tensor; -struct llama_ubatch; struct llama_cparams; struct llama_memory_context_i; @@ -39,6 +39,7 @@ enum llm_ffn_op_type { LLM_FFN_SWIGLU, LLM_FFN_GEGLU, LLM_FFN_REGLU, + LLM_FFN_SWIGLU_OAI_MOE, }; enum llm_ffn_gate_type { @@ -69,6 +70,8 @@ struct llama_cross { std::vector> seq_ids_enc; }; +struct llm_graph_params; + // // llm_graph_input // @@ -78,11 +81,19 @@ class llm_graph_input_i { virtual ~llm_graph_input_i() = default; virtual void set_input(const llama_ubatch * ubatch) = 0; + + // return true if the resulting input tensors using the provided graph parameters would be + // the same as the previous input tensors that we have currently stored in the object + virtual bool can_reuse(const llm_graph_params & params) { + // returning false here by default will prevent from reusing the graph if the check + // for the input type has not been implemented yet + GGML_UNUSED(params); + return false; + } }; using llm_graph_input_ptr = std::unique_ptr; - class llm_graph_input_embd : public llm_graph_input_i { public: llm_graph_input_embd() = default; @@ -90,6 +101,8 @@ class llm_graph_input_embd : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + ggml_tensor * tokens = nullptr; // I32 [n_batch] ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch] }; @@ -101,6 +114,8 @@ class llm_graph_input_pos : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + ggml_tensor * pos = nullptr; // I32 [n_batch] const uint32_t n_pos_per_embd = 1; @@ -130,7 +145,7 @@ class llm_graph_input_pos_bucket : public llm_graph_input_i { ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch] - const llama_hparams & hparams; + const llama_hparams hparams; }; class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { @@ -144,7 +159,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch] - const llama_hparams & hparams; + const llama_hparams hparams; const llama_kv_cache_unified_context * mctx; }; @@ -154,17 +169,19 @@ class llm_graph_input_out_ids : public llm_graph_input_i { llm_graph_input_out_ids( const llama_hparams & hparams, const llama_cparams & cparams, - int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {} + uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {} virtual 
~llm_graph_input_out_ids() = default; void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + ggml_tensor * out_ids; // I32 [n_outputs] - const llama_hparams & hparams; - const llama_cparams & cparams; + const llama_hparams hparams; + const llama_cparams cparams; - const int32_t n_outputs; + const uint32_t n_outputs; }; class llm_graph_input_mean : public llm_graph_input_i { @@ -176,7 +193,7 @@ class llm_graph_input_mean : public llm_graph_input_i { ggml_tensor * mean; // F32 [n_batch, n_batch] - const llama_cparams & cparams; + const llama_cparams cparams; }; class llm_graph_input_cls : public llm_graph_input_i { @@ -188,7 +205,7 @@ class llm_graph_input_cls : public llm_graph_input_i { ggml_tensor * cls; // I32 [n_batch] - const llama_cparams & cparams; + const llama_cparams cparams; }; class llm_graph_input_rs : public llm_graph_input_i { @@ -198,7 +215,12 @@ class llm_graph_input_rs : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; - ggml_tensor * s_copy; // I32 [kv_size] + ggml_tensor * s_copy; // I32 [n_rs] + + // views of s_copy, computed once per graph + // and shared across layers which use build_rs + ggml_tensor * s_copy_main; // I32 [n_seqs] + ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs] const llama_memory_recurrent_context * mctx; }; @@ -231,8 +253,8 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i { ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1] ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1] - const llama_hparams & hparams; - const llama_cparams & cparams; + const llama_hparams hparams; + const llama_cparams cparams; }; class llm_graph_input_attn_kv_unified : public llm_graph_input_i { @@ -249,19 +271,24 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + ggml_tensor * get_k_idxs() const { return self_k_idxs; } ggml_tensor * get_v_idxs() const { return self_v_idxs; } ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] - ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] - const llama_hparams & hparams; - const llama_cparams & cparams; + // note: these have to be copies because in order to be able to reuse a graph, its inputs + // need to carry these parameters with them. 
otherwise, they can point to freed + // llm_graph_params from a previous batch, causing stack-use-after-return + const llama_hparams hparams; + const llama_cparams cparams; const llama_kv_cache_unified_context * mctx; }; @@ -280,6 +307,8 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + ggml_tensor * get_k_idxs() const { return self_k_idxs; } ggml_tensor * get_v_idxs() const { return self_v_idxs; } ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; } @@ -289,17 +318,17 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] - ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch] - ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] - const llama_hparams & hparams; - const llama_cparams & cparams; + const llama_hparams hparams; + const llama_cparams cparams; const llama_kv_cache_unified_iswa_context * mctx; }; @@ -351,40 +380,110 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i { // along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc. // these are used by the llama_context to extact the relevant data, based on the compute parameters -class llm_graph_result_i { -public: - virtual ~llm_graph_result_i() = default; +// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.) 
+using llm_graph_cb = std::function; - virtual ggml_tensor * get_tokens() = 0; - virtual ggml_tensor * get_logits() = 0; - virtual ggml_tensor * get_embd() = 0; - virtual ggml_tensor * get_embd_pooled() = 0; +class llm_graph_result; - virtual void set_inputs(const llama_ubatch * ubatch) = 0; -}; +struct llm_graph_params { + llm_arch arch = LLM_ARCH_UNKNOWN; -using llm_graph_result_ptr = std::unique_ptr; + llama_hparams hparams; + llama_cparams cparams; + llama_ubatch ubatch; // note: intentionally make a copy -class llm_graph_result : public llm_graph_result_i { -public: - virtual ~llm_graph_result() = default; + llm_graph_type gtype; + + ggml_backend_sched_t sched; + ggml_backend_t backend_cpu; + + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_context_i * mctx; + const llama_cross * cross; - ggml_tensor * get_tokens() override { return t_tokens; } - ggml_tensor * get_logits() override { return t_logits; } - ggml_tensor * get_embd() override { return t_embd; } - ggml_tensor * get_embd_pooled() override { return t_embd_pooled; } + uint32_t n_outputs; - void set_inputs(const llama_ubatch * ubatch) override { - for (auto & input : inputs) { - input->set_input(ubatch); + llm_graph_cb cb; + + llm_graph_result * res; + + // return true if the "other" params would result in a graph with the same topology as with the current params + // having the same topology allows us to reuse the graph in some cases + bool allow_reuse(const llm_graph_params & other) const { + // first check the ubatch + bool can_reuse_ubatch = + ubatch.equal_seqs() == other.ubatch.equal_seqs() && + ubatch.n_tokens == other.ubatch.n_tokens && + ubatch.n_seq_tokens == other.ubatch.n_seq_tokens && + ubatch.n_seqs == other.ubatch.n_seqs && + ubatch.n_seqs_unq == other.ubatch.n_seqs_unq && + ( + (!ubatch.token && !other.ubatch.token) || + (!ubatch.embd && !other.ubatch.embd) + ); + + // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same + // the reason is that the set of attention streams would be different for different sequences + if (can_reuse_ubatch && ubatch.equal_seqs()) { + if (!ubatch.data) { + // if the old ubatch does not own its data, then we cannot guarantee that it is still alive, and + therefore we cannot perform the sequence id check.
normally should never happen + can_reuse_ubatch = false; + } else { + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s]; + } + } } - } - llm_graph_input_i * add_input(llm_graph_input_ptr input) { - inputs.emplace_back(std::move(input)); - return inputs.back().get(); + if (!can_reuse_ubatch) { + return false; + } + + return + cparams.embeddings == other.cparams.embeddings && + cparams.causal_attn == other.cparams.causal_attn && + arch == other.arch && + gtype == other.gtype && + cvec == other.cvec && + loras == other.loras && + cross == other.cross && + n_outputs == other.n_outputs; } +}; + +class llm_graph_result { +public: + llm_graph_result(int64_t max_nodes); + + virtual ~llm_graph_result() = default; + + ggml_tensor * get_tokens() const { return t_tokens; } + ggml_tensor * get_logits() const { return t_logits; } + ggml_tensor * get_embd() const { return t_embd; } + ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + + ggml_cgraph * get_gf() const { return gf; } + ggml_context * get_ctx() const { return ctx_compute.get(); } + + int64_t get_max_nodes() const; + + void reset(); + + void set_inputs(const llama_ubatch * ubatch); + + // try to update the existing graph result using the new graph parameters in order to reuse it + // this can only be done if we determine that the resulting graph using the new graph parameters + // would be identical to the existing graph. in that case, we simply have to update the memory + // contexts of the input tensors of the graph and we can reuse it for another computation + // return true if the graph was updated and can be reused + bool can_reuse(const llm_graph_params & params); + + llm_graph_input_i * add_input(llm_graph_input_ptr input); + + void set_params(const llm_graph_params & params); // important graph nodes ggml_tensor * t_tokens = nullptr; @@ -393,36 +492,31 @@ class llm_graph_result : public llm_graph_result_i { ggml_tensor * t_embd_pooled = nullptr; std::vector inputs; -}; -// -// llm_graph_context -// + ggml_context_ptr ctx_compute; -// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.) 
-using llm_graph_cb = std::function; - -struct llm_graph_params { - ggml_context * ctx; + // memory buffers used to evaluate the model + std::vector buf_compute_meta; - const llm_arch arch; + ggml_cgraph * gf; - const llama_hparams & hparams; - const llama_cparams & cparams; - const llama_ubatch & ubatch; + int64_t max_nodes; - ggml_backend_sched_t sched; - ggml_backend_t backend_cpu; +private: + // keep a copy of the previous graph parameters + // we will use this to determine whether the graph can be reused by comparing them with the new parameters + // note: these are updated after constructing the new graph + llm_graph_params params; - const llama_adapter_cvec * cvec; - const llama_adapter_loras * loras; - const llama_memory_context_i * mctx; - const llama_cross * cross; + // env: LLAMA_GRAPH_RESULT_DEBUG + int debug = 0; +}; - uint32_t n_outputs; +using llm_graph_result_ptr = std::unique_ptr; - const llm_graph_cb & cb; -}; +// +// llm_graph_context +// // used in build_rs to properly order writes and avoid unnecessary copies using llm_graph_get_rows_fn = std::function; @@ -463,8 +557,6 @@ struct llm_graph_context { const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; - ggml_context * ctx0 = nullptr; - ggml_backend_sched_t sched; ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? @@ -476,7 +568,10 @@ struct llm_graph_context { const llm_graph_cb & cb_func; - std::unique_ptr res; + llm_graph_result * res; + + ggml_context * ctx0 = nullptr; + ggml_cgraph * gf = nullptr; llm_graph_context(const llm_graph_params & params); virtual ~llm_graph_context() = default; @@ -525,6 +620,7 @@ struct llm_graph_context { llm_ffn_gate_type type_gate, int il) const; + // build MoE FFN without bias tensors ggml_tensor * build_moe_ffn( ggml_tensor * cur, ggml_tensor * gate_inp, @@ -539,7 +635,29 @@ struct llm_graph_context { bool scale_w, float w_scale, llama_expert_gating_func_type gating_op, - int il) const; + int il, + ggml_tensor * probs_in = nullptr) const; + + ggml_tensor * build_moe_ffn( + ggml_tensor * cur, + ggml_tensor * gate_inp, + ggml_tensor * gate_inp_b, + ggml_tensor * up_exps, + ggml_tensor * up_exps_b, + ggml_tensor * gate_exps, + ggml_tensor * gate_exps_b, + ggml_tensor * down_exps, + ggml_tensor * down_exps_b, + ggml_tensor * exp_probs_b, + int64_t n_expert, + int64_t n_expert_used, + llm_ffn_op_type type_op, + bool norm_w, + bool scale_w, + float w_scale, + llama_expert_gating_func_type gating_op, + int il, + ggml_tensor * probs_in = nullptr) const; // // inputs @@ -562,12 +680,12 @@ struct llm_graph_context { // ggml_tensor * build_attn_mha( - ggml_cgraph * gf, ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false) ggml_tensor * kq_b, ggml_tensor * kq_mask, + ggml_tensor * sinks, ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale) const; @@ -575,7 +693,6 @@ struct llm_graph_context { ggml_tensor * build_attn( llm_graph_input_attn_no_cache * inp, - ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] @@ -590,7 +707,6 @@ struct llm_graph_context { ggml_tensor * build_attn( llm_graph_input_attn_kv_unified * inp, - ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] @@ -606,7 +722,6 @@ struct llm_graph_context { // note: if k_cur or 
v_cur are not provided, they will not be stored in the memory ggml_tensor * build_attn( llm_graph_input_attn_kv_unified_iswa * inp, - ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] @@ -617,11 +732,24 @@ struct llm_graph_context { float kq_scale, int il) const; + // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this + ggml_tensor * build_attn_with_sinks( + llm_graph_input_attn_kv_unified_iswa * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional + ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + ggml_tensor * sinks, // [n_head_q] + float kq_scale, + int il) const; + llm_graph_input_attn_cross * build_attn_inp_cross() const; ggml_tensor * build_attn( llm_graph_input_attn_cross * inp, - ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] @@ -636,21 +764,20 @@ struct llm_graph_context { // recurrent // - // TODO: avoid notion of "kv" // TODO: move this implementation to llama_memory_recurrent. // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in // `llama_memory_recurrent` ggml_tensor * build_rs( - ggml_cgraph * gf, ggml_tensor * s, - ggml_tensor * state_copy, + ggml_tensor * state_copy_main, + ggml_tensor * state_copy_extra, int32_t state_size, int32_t n_seqs, - uint32_t n_kv, - uint32_t kv_head, - uint32_t kv_size, + uint32_t n_rs, + uint32_t rs_head, + uint32_t rs_size, int32_t rs_zero, const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const; @@ -658,7 +785,6 @@ struct llm_graph_context { ggml_tensor * build_rs( llm_graph_input_rs * inp, - ggml_cgraph * gf, ggml_tensor * s, int32_t state_size, int32_t n_seqs, @@ -666,9 +792,8 @@ struct llm_graph_context { ggml_tensor * build_rwkv_token_shift_load( llm_graph_input_rs * inp, - ggml_cgraph * gf, const llama_ubatch & ubatch, - int il) const; + int il) const; ggml_tensor * build_rwkv_token_shift_store( ggml_tensor * token_shift, @@ -685,7 +810,6 @@ struct llm_graph_context { // void build_pooling( - ggml_cgraph * gf, ggml_tensor * cls, ggml_tensor * cls_b, ggml_tensor * cls_out, diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index 7aa736e2f39..7a06368dcda 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -2,9 +2,15 @@ #include "ggml.h" -void llama_hparams::set_swa_pattern(uint32_t n_pattern) { - for (uint32_t il = 0; il < n_layer; ++il) { - swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); +void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) { + if (dense_first) { + for (uint32_t il = 0; il < n_layer; ++il) { + swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0); + } + } else { + for (uint32_t il = 0; il < n_layer; ++il) { + swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); + } } } @@ -65,6 +71,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { return n_embd_head_v * n_head_kv; } +bool llama_hparams::is_n_embd_k_gqa_variable() const { + 
const uint32_t val = n_embd_k_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + if (val != n_embd_k_gqa(il)) { + return true; + } + } + + return false; +} + +bool llama_hparams::is_n_embd_v_gqa_variable() const { + const uint32_t val = n_embd_v_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + if (val != n_embd_v_gqa(il)) { + return true; + } + } + + return false; +} + +uint32_t llama_hparams::n_embd_k_gqa_max() const { + uint32_t val = n_embd_k_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + val = std::max(val, n_embd_k_gqa(il)); + } + + return val; +} + +uint32_t llama_hparams::n_embd_v_gqa_max() const { + uint32_t val = n_embd_v_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + val = std::max(val, n_embd_v_gqa(il)); + } + + return val; +} + uint32_t llama_hparams::n_embd_r() const { if (wkv_head_size != 0) { // for RWKV models diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index d0500e4d0fd..bd231224432 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -6,12 +6,13 @@ // bump if necessary #define LLAMA_MAX_LAYERS 512 -#define LLAMA_MAX_EXPERTS 256 // DeepSeekV3 +#define LLAMA_MAX_EXPERTS 384 // Kimi-K2 enum llama_expert_gating_func_type { - LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0, - LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1, - LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2, + LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT = 3, // applied to the router weights instead of the logits }; enum llama_swa_type { @@ -73,6 +74,7 @@ struct llama_hparams { bool expert_weights_norm = false; uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; uint32_t moe_every_n_layers = 0; + uint32_t nextn_predict_layers = 0; float f_norm_eps; float f_norm_rms_eps; @@ -98,7 +100,7 @@ struct llama_hparams { float rope_freq_scale_train; float rope_freq_scale_train_swa; uint32_t n_ctx_orig_yarn; - float rope_yarn_log_mul; + float rope_yarn_log_mul = 0.0f; std::array rope_sections; @@ -140,7 +142,7 @@ struct llama_hparams { // for Classifiers uint32_t n_cls_out = 1; - // llama4 + // llama4 smallthinker uint32_t n_moe_layer_step = 0; uint32_t n_no_rope_layer_step = 4; uint32_t n_attn_temp_floor_scale = 8192; @@ -161,9 +163,10 @@ struct llama_hparams { enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; // this value n_pattern means that every nth layer is dense (i.e. non-SWA) + // dense_first means whether the pattern is start with a dense layer // note that if n_pattern == 0, all layers are SWA // if n_pattern == 1, all layers are dense - // example: n_pattern = 3 + // example 1: n_pattern = 3, dense_first = false // il == 0: swa // il == 1: swa // il == 2: dense @@ -172,7 +175,13 @@ struct llama_hparams { // il == 5: dense // il == 6: swa // etc ... - void set_swa_pattern(uint32_t n_pattern); + // example 2: n_pattern = 2, dense_first = true + // il == 0: dense + // il == 1: swa + // il == 2: dense + // il == 3: swa + // etc ... 
+ void set_swa_pattern(uint32_t n_pattern, bool dense_first = false); // return true if one of the layers is SWA bool is_swa_any() const; @@ -191,6 +200,14 @@ struct llama_hparams { // dimension of value embeddings across all k-v heads uint32_t n_embd_v_gqa(uint32_t il = 0) const; + // true if any layer has a different n_embd_k_gqa/n_embd_v_gqa + bool is_n_embd_k_gqa_variable() const; + bool is_n_embd_v_gqa_variable() const; + + // return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers + uint32_t n_embd_k_gqa_max() const; + uint32_t n_embd_v_gqa_max() const; + // dimension of the rolling state embeddings // corresponds to Mamba's conv_states size or RWKV's token_shift states size uint32_t n_embd_r() const; diff --git a/examples/talk-llama/llama-kv-cache-unified-iswa.cpp b/examples/talk-llama/llama-kv-cache-unified-iswa.cpp index fe207ad5360..1e363fff2a5 100644 --- a/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +++ b/examples/talk-llama/llama-kv-cache-unified-iswa.cpp @@ -18,16 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( bool v_trans, bool offload, bool swa_full, + bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_ubatch, - uint32_t n_pad) : hparams(model.hparams) { + uint32_t n_pad) : hparams(model.hparams), unified(unified) { llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; const uint32_t size_base = kv_size; - uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad)); + uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad)); // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size if (swa_full) { @@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( kv_base = std::make_unique( model, std::move(filter_base), type_k, type_v, - v_trans, offload, size_base, n_seq_max, n_pad, + v_trans, offload, unified, size_base, n_seq_max, n_pad, 0, LLAMA_SWA_TYPE_NONE); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); kv_swa = std::make_unique( model, std::move(filter_swa), type_k, type_v, - v_trans, offload, size_swa, n_seq_max, n_pad, + v_trans, offload, unified, size_swa, n_seq_max, n_pad, hparams.n_swa, hparams.swa_type); } @@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all // first try simple split do { + if (!unified) { + // requires equal splits, so we skip the simple split + break; + } + balloc.split_reset(); std::vector ubatches; @@ -140,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all std::vector ubatches; while (true) { - auto ubatch = balloc.split_equal(n_ubatch, false); + auto ubatch = balloc.split_equal(n_ubatch, !unified); if (ubatch.n_tokens == 0) { break; @@ -188,14 +194,20 @@ bool llama_kv_cache_unified_iswa::get_can_shift() const { return kv_base->get_size() == kv_swa->get_size(); } -void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - kv_base->state_write(io, seq_id); - kv_swa ->state_write(io, seq_id); +void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) { + kv_base->state_write(io, seq_id, flags); + } + + 
kv_swa->state_write(io, seq_id, flags); } -void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) { - kv_base->state_read(io, seq_id); - kv_swa ->state_read(io, seq_id); +void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) { + kv_base->state_read(io, seq_id, flags); + } + + kv_swa->state_read(io, seq_id, flags); } llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const { diff --git a/examples/talk-llama/llama-kv-cache-unified-iswa.h b/examples/talk-llama/llama-kv-cache-unified-iswa.h index 23205d826b2..7bc4df718d3 100644 --- a/examples/talk-llama/llama-kv-cache-unified-iswa.h +++ b/examples/talk-llama/llama-kv-cache-unified-iswa.h @@ -20,6 +20,7 @@ class llama_kv_cache_unified_iswa : public llama_memory_i { bool v_trans, bool offload, bool swa_full, + bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_ubatch, @@ -55,8 +56,8 @@ class llama_kv_cache_unified_iswa : public llama_memory_i { // state write/load - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; // // llama_kv_cache_unified_iswa specific API @@ -68,6 +69,8 @@ class llama_kv_cache_unified_iswa : public llama_memory_i { private: const llama_hparams & hparams; + const bool unified; + std::unique_ptr kv_base; std::unique_ptr kv_swa; }; diff --git a/examples/talk-llama/llama-kv-cache-unified.cpp b/examples/talk-llama/llama-kv-cache-unified.cpp index d3129cc5328..478ebffac0f 100644 --- a/examples/talk-llama/llama-kv-cache-unified.cpp +++ b/examples/talk-llama/llama-kv-cache-unified.cpp @@ -23,13 +23,14 @@ llama_kv_cache_unified::llama_kv_cache_unified( ggml_type type_v, bool v_trans, bool offload, + bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), - n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { + n_seq_max(n_seq_max), n_stream(unified ? 
1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { GGML_ASSERT(kv_size % n_pad == 0); @@ -38,6 +39,10 @@ llama_kv_cache_unified::llama_kv_cache_unified( if (model.arch == LLM_ARCH_GEMMA3N) { n_layer_cache = 20; } + if (model.arch == LLM_ARCH_GLM4_MOE) { + // GLM-4.5: Only process up to last layer, skip final NextN layer + n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers; + } // create a context for each buffer type std::map ctx_map; @@ -45,7 +50,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - /*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -64,9 +69,33 @@ llama_kv_cache_unified::llama_kv_cache_unified( return it->second; }; - head = 0; + GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max); + + v_heads.resize(n_stream); + for (uint32_t s = 0; s < n_stream; ++s) { + v_heads[s] = 0; + } + + v_cells.resize(n_stream); + for (uint32_t s = 0; s < n_stream; ++s) { + v_cells[s].resize(kv_size); + } - cells.resize(kv_size); + // by default, all sequence ids are mapped to the 0th stream + seq_to_stream.resize(LLAMA_MAX_SEQ, 0); + + if (n_stream > 1) { + seq_to_stream.resize(n_stream, 0); + for (uint32_t s = 0; s < n_stream; ++s) { + seq_to_stream[s] = s; + } + } + + // [TAG_V_CACHE_VARIABLE] + if (v_trans && hparams.is_n_embd_v_gqa_variable()) { + LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n", + __func__, hparams.n_embd_v_gqa_max()); + } for (uint32_t il = 0; il < n_layer_cache; il++) { if (filter && !filter(il)) { @@ -74,8 +103,9 @@ llama_kv_cache_unified::llama_kv_cache_unified( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + // [TAG_V_CACHE_VARIABLE] + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const uint32_t n_embd_v_gqa = !v_trans ? 
hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max(); const char * dev_name = "CPU"; @@ -98,14 +128,23 @@ llama_kv_cache_unified::llama_kv_cache_unified( ggml_tensor * k; ggml_tensor * v; - k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size); - v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size); + k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); + v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); ggml_format_name(k, "cache_k_l%d", il); ggml_format_name(v, "cache_v_l%d", il); + std::vector k_stream; + std::vector v_stream; + + for (uint32_t s = 0; s < n_stream; ++s) { + k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2])); + v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2])); + } + map_layer_ids[il] = layers.size(); - layers.push_back({ il, k, v }); + + layers.push_back({ il, k, v, k_stream, v_stream, }); } // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE] @@ -148,8 +187,8 @@ llama_kv_cache_unified::llama_kv_cache_unified( const size_t memory_size_k = size_k_bytes(); const size_t memory_size_v = size_v_bytes(); - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream, ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } @@ -158,7 +197,12 @@ llama_kv_cache_unified::llama_kv_cache_unified( debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); - supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0; + supports_set_rows = LLAMA_SET_ROWS ? 
atoi(LLAMA_SET_ROWS) != 0 : supports_set_rows; + + if (!supports_set_rows) { + // ref: https://github.com/ggml-org/llama.cpp/pull/14363 + GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support"); + } if (!supports_set_rows) { LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__); @@ -166,9 +210,10 @@ llama_kv_cache_unified::llama_kv_cache_unified( } void llama_kv_cache_unified::clear(bool data) { - cells.reset(); - - head = 0; + for (uint32_t s = 0; s < n_stream; ++s) { + v_cells[s].reset(); + v_heads[s] = 0; + } if (data) { for (auto & buf : bufs) { @@ -178,7 +223,7 @@ void llama_kv_cache_unified::clear(bool data) { } bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - uint32_t new_head = cells.size(); + GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); if (p0 < 0) { p0 = 0; @@ -189,6 +234,11 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos } if (seq_id >= 0) { + auto & cells = v_cells[seq_to_stream[seq_id]]; + auto & head = v_heads[seq_to_stream[seq_id]]; + + uint32_t new_head = cells.size(); + for (uint32_t i = 0; i < cells.size(); ++i) { if (!cells.pos_in(i, p0, p1)) { continue; @@ -200,54 +250,130 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos } } } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cells.size() && new_head < head) { + head = new_head; + } } else { // match any sequence - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.pos_in(i, p0, p1)) { - continue; - } + for (uint32_t s = 0; s < n_stream; ++s) { + auto & cells = v_cells[s]; + auto & head = v_heads[s]; - cells.rm(i); + uint32_t new_head = cells.size(); - if (new_head == cells.size()) { - new_head = i; + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + cells.rm(i); + + if (new_head == cells.size()) { + new_head = i; + } } - } - } - // If we freed up a slot, set head to it so searching can start there. - if (new_head != cells.size() && new_head < head) { - head = new_head; + // If we freed up a slot, set head to it so searching can start there. 
+ if (new_head != cells.size() && new_head < head) { + head = new_head; + } + } } return true; } void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - if (seq_id_src == seq_id_dst) { + GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size()); + GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size()); + + const auto s0 = seq_to_stream[seq_id_src]; + const auto s1 = seq_to_stream[seq_id_dst]; + + if (s0 == s1) { + // since both sequences are in the same stream, no data copy is necessary + // we just have to update the cells meta data + + auto & cells = v_cells[s0]; + + if (seq_id_src == seq_id_dst) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id_src)) { + cells.seq_add(i, seq_id_dst); + } + } + return; } - if (p0 < 0) { - p0 = 0; + // cross-stream sequence copies require to copy the actual buffer data + + bool is_full = true; + + if (p0 > 0 && p0 + 1 < (int) get_size()) { + is_full = false; } - if (p1 < 0) { - p1 = std::numeric_limits::max(); + if (p1 > 0 && p1 + 1 < (int) get_size()) { + is_full = false; } - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.pos_in(i, p0, p1)) { - continue; - } + GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers"); + + // enqueue the copy operation - the buffer copy will be performed during the next update + sc_info.ssrc.push_back(s0); + sc_info.sdst.push_back(s1); - if (cells.seq_has(i, seq_id_src)) { - cells.seq_add(i, seq_id_dst); + v_cells[s1].reset(); + for (uint32_t i = 0; i < v_cells[s0].size(); ++i) { + if (v_cells[s0].seq_has(i, seq_id_src)) { + llama_pos pos = v_cells[s0].pos_get(i); + llama_pos shift = v_cells[s0].get_shift(i); + + if (shift != 0) { + pos -= shift; + assert(pos >= 0); + } + + v_cells[s1].pos_set(i, pos); + v_cells[s1].seq_add(i, seq_id_dst); + + if (shift != 0) { + v_cells[s1].pos_add(i, shift); + } } } + + v_heads[s1] = v_heads[s0]; + + //for (uint32_t s = 0; s < n_stream; ++s) { + // LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s)); + //} } void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[seq_id]]; + auto & head = v_heads[seq_to_stream[seq_id]]; + uint32_t new_head = cells.size(); for (uint32_t i = 0; i < cells.size(); ++i) { @@ -265,6 +391,11 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { } void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[seq_id]]; + auto & head = v_heads[seq_to_stream[seq_id]]; + if (shift == 0) { return; } @@ -304,6 +435,10 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po } void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[seq_id]]; + if (d == 1) { return; } @@ -333,10 +468,18 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po } llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { + 
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + const auto & cells = v_cells[seq_to_stream[seq_id]]; + return cells.seq_pos_min(seq_id); } llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + const auto & cells = v_cells[seq_to_stream[seq_id]]; + return cells.seq_pos_max(seq_id); } @@ -351,7 +494,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch( std::vector ubatches; while (true) { - auto ubatch = balloc.split_simple(n_ubatch); + auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true); if (ubatch.n_tokens == 0) { break; @@ -387,7 +530,10 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct defrag_info dinfo; // see if we need to defrag - { + if (n_stream == 1) { + // note : for now do not consider defrag for n_stream > 1 + const auto & cells = v_cells[seq_to_stream[0]]; + bool do_defrag = optimize; const auto thold = lctx->get_cparams().defrag_thold; @@ -411,22 +557,22 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct } } - return std::make_unique(this, lctx, do_shift, std::move(dinfo)); + return std::make_unique(this, lctx, do_shift, std::move(dinfo), std::move(sc_info)); } llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector & ubatches) { llama_kv_cache_unified::slot_info_vec_t res; - struct state { - uint32_t head_old; // old position of the head, before placing the ubatch - + struct state_t { slot_info sinfo; // slot info for the ubatch - llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch + std::vector v_heads_old; // old positions of the heads, before placing the ubatch + + std::vector v_cells; // copy of the old cells, before placing the ubatch }; // remember the old state of the cells so we can restore it in the end - std::vector states; + std::vector states; bool success = true; @@ -445,16 +591,35 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st res.push_back(sinfo_new); // store the old state of the cells in the recovery stack - states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)}); + { + state_t state = { sinfo_new, v_heads, {} }; + + for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) { + auto & cells = v_cells[sinfo_new.strm[s]]; + + state.v_cells.push_back(cells.cp(sinfo_new.idxs[s])); + } + + states.push_back(std::move(state)); + } // now emplace the ubatch apply_ubatch(sinfo_new, ubatch); } + GGML_ASSERT(!states.empty() || !success); + // iterate backwards and restore the cells to their original state for (auto it = states.rbegin(); it != states.rend(); ++it) { - cells.set(it->sinfo.idxs, it->cells); - head = it->head_old; + const auto & sinfo = it->sinfo; + + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + auto & cells = v_cells[sinfo.strm[s]]; + auto & head = v_heads[sinfo.strm[s]]; + + cells.set(sinfo.idxs[s], it->v_cells[s]); + head = it->v_heads_old[s]; + } } if (!success) { @@ -464,11 +629,38 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st return res; } -bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) { +bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) { bool updated = false; auto * sched = lctx->get_sched(); + if (!sc_info.empty()) { + 
assert(n_stream > 1 && "stream copy should never happen with a single stream"); + + llama_synchronize(lctx); + + const size_t n_copy = sc_info.ssrc.size(); + + for (size_t i = 0; i < n_copy; ++i) { + const auto ssrc = sc_info.ssrc[i]; + const auto sdst = sc_info.sdst[i]; + + assert(ssrc < n_stream); + assert(sdst < n_stream); + + LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst); + + assert(ssrc != sdst); + + for (uint32_t il = 0; il < layers.size(); ++il) { + const auto & layer = layers[il]; + + ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]); + ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]); + } + } + } + if (do_shift) { if (!get_can_shift()) { GGML_ABORT("The current KV cache / model configuration does not support K-shift"); @@ -480,14 +672,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { ggml_backend_sched_reset(sched); - auto * gf = lctx->graph_init(); + auto * res = lctx->get_gf_res_reserve(); - auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf); - if (!res) { - LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__); - return updated; - } + res->reset(); + auto * gf = build_graph_shift(res, lctx); if (!ggml_backend_sched_alloc_graph(sched, gf)) { LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__); return updated; @@ -503,12 +692,20 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d updated = true; } - cells.reset_shift(); + for (uint32_t s = 0; s < n_stream; ++s) { + auto & cells = v_cells[s]; + + cells.reset_shift(); + } } if (!dinfo.empty()) { LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); + // note: for now do not consider defrag for n_stream > 1 + auto & cells = v_cells[seq_to_stream[0]]; + auto & head = v_heads[seq_to_stream[0]]; + // apply moves: { const auto n_kv = dinfo.ids.size(); @@ -529,14 +726,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d ggml_backend_sched_reset(sched); - auto * gf = lctx->graph_init(); + auto * res = lctx->get_gf_res_reserve(); - auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo); - if (!res) { - LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__); - return updated; - } + res->reset(); + auto * gf = build_graph_defrag(res, lctx, dinfo); if (!ggml_backend_sched_alloc_graph(sched, gf)) { LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__); return updated; @@ -556,159 +750,200 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d } llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const { - const uint32_t n_tokens = ubatch.n_tokens; - - uint32_t head_cur = this->head; - - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (head_cur > cells.get_used() + 2*ubatch.n_tokens) { - head_cur = 0; - } - - if (n_tokens > cells.size()) { - LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); - return { }; - } if (debug > 0) { - LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa); + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const auto seq_id = 
ubatch.seq_id_unq[s]; + const auto stream_id = seq_to_stream[seq_id]; + const auto & cells = v_cells[stream_id]; + const uint32_t head_cur = v_heads[stream_id]; + + LLAMA_LOG_DEBUG("%s: stream[%d], n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", + __func__, stream_id, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa); + + if ((debug == 2 && n_swa > 0) || debug > 2) { + std::string ss; + for (uint32_t i = 0; i < cells.size(); ++i) { + if (cells.is_empty(i)) { + ss += '.'; + } else { + assert(cells.seq_count(i) >= 1); - if ((debug == 2 && n_swa > 0) || debug > 2) { - std::string ss; - for (uint32_t i = 0; i < cells.size(); ++i) { - if (cells.is_empty(i)) { - ss += '.'; - } else { - assert(cells.seq_count(i) >= 1); + if (cells.seq_count(i) == 1) { + ss += std::to_string(cells.seq_get(i)); + } else { + ss += 'M'; + } + } + if (i%256 == 255) { + ss += " *"; + ss += '\n'; + } + } + LLAMA_LOG_DEBUG("\n%s\n", ss.c_str()); + } - if (cells.seq_count(i) == 1) { - ss += std::to_string(cells.seq_get(i)); + if ((debug == 2 && n_swa > 0) || debug > 2) { + std::string ss; + for (uint32_t i = 0; i < cells.size(); ++i) { + std::string cur; + if (cells.is_empty(i)) { + cur = '.'; } else { - ss += 'M'; + cur = std::to_string(cells.pos_get(i)); + } + const int n = cur.size(); + for (int j = 0; j < 5 - n; ++j) { + cur += ' '; + } + ss += cur; + if (i%256 == 255) { + ss += " *"; + } + if (i%64 == 63) { + ss += '\n'; } } - if (i%256 == 255) { - ss += " *"; - ss += '\n'; - } + LLAMA_LOG_DEBUG("\n%s\n", ss.c_str()); } - LLAMA_LOG_DEBUG("\n%s\n", ss.c_str()); - } - if ((debug == 2 && n_swa > 0) || debug > 2) { - std::string ss; - for (uint32_t i = 0; i < cells.size(); ++i) { - std::string cur; - if (cells.is_empty(i)) { - cur = '.'; - } else { - cur = std::to_string(cells.pos_get(i)); - } - const int n = cur.size(); - for (int j = 0; j < 5 - n; ++j) { - cur += ' '; - } - ss += cur; - if (i%256 == 255) { - ss += " *"; - } - if (i%64 == 63) { - ss += '\n'; + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (cells.seq_pos_min(s) < 0) { + continue; } + + LLAMA_LOG_DEBUG("%s: stream[%d] min[%d] = %5d, max[%d] = %5d\n", __func__, stream_id, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s)); } - LLAMA_LOG_DEBUG("\n%s\n", ss.c_str()); } + } - for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { - if (cells.seq_pos_min(s) < 0) { - continue; - } + uint32_t n_tokens = ubatch.n_tokens; + uint32_t n_seqs = 1; - LLAMA_LOG_DEBUG("%s: min[%d] = %5d, max[%d] = %5d\n", __func__, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s)); - } + if (n_stream > 1) { + GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0); + + n_seqs = ubatch.n_seqs_unq; + n_tokens = n_tokens / n_seqs; } - uint32_t n_tested = 0; + slot_info res = { + /*.s0 =*/ LLAMA_MAX_SEQ, + /*.s1 =*/ 0, + /*.strm =*/ { }, + /*.idxs =*/ { }, + }; - // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head - // for non-continuous slots, we test the tokens one by one - const uint32_t n_test = cont ? 
n_tokens : 1; + res.resize(n_seqs); - slot_info res; + for (uint32_t s = 0; s < n_seqs; ++s) { + const auto seq_id = ubatch.seq_id_unq[s]; + + if (n_stream > 1) { + GGML_ASSERT(ubatch.n_seq_id[s*n_tokens] == 1); + GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id); + } - auto & idxs = res.idxs; + res.s0 = std::min(res.s0, seq_to_stream[seq_id]); + res.s1 = std::max(res.s1, seq_to_stream[seq_id]); - idxs.reserve(n_tokens); + res.strm[s] = seq_to_stream[seq_id]; + res.idxs[s].reserve(n_tokens); - while (true) { - if (head_cur + n_test > cells.size()) { - n_tested += cells.size() - head_cur; + const auto & cells = v_cells[seq_to_stream[seq_id]]; + + uint32_t head_cur = v_heads[seq_to_stream[seq_id]]; + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head_cur > cells.get_used() + 2*n_tokens) { head_cur = 0; - continue; } - for (uint32_t i = 0; i < n_test; i++) { - const auto idx = head_cur; + if (n_tokens > cells.size()) { + LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); + return { }; + } + + uint32_t n_tested = 0; + + // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head + // for non-continuous slots, we test the tokens one by one + const uint32_t n_test = cont ? n_tokens : 1; - //const llama_pos pos = ubatch.pos[i]; - //const llama_seq_id seq_id = ubatch.seq_id[i][0]; + while (true) { + if (head_cur + n_test > cells.size()) { + n_tested += cells.size() - head_cur; + head_cur = 0; + continue; + } - // can we use this cell? either: - // - the cell is empty - // - the cell is occupied only by one sequence: - // - (disabled) mask causally, if the sequence is the same as the one we are inserting - // - mask SWA, using current max pos for that sequence in the cache - // always insert in the cell with minimum pos - bool can_use = cells.is_empty(idx); + for (uint32_t i = 0; i < n_test; i++) { + const auto idx = head_cur; - if (!can_use && cells.seq_count(idx) == 1) { - const llama_pos pos_cell = cells.pos_get(idx); + head_cur++; + n_tested++; - // (disabled) causal mask - // note: it's better to purge any "future" tokens beforehand - //if (cells.seq_has(idx, seq_id)) { - // can_use = pos_cell >= pos; - //} + //const llama_pos pos = ubatch.pos[i]; + //const llama_seq_id seq_id = ubatch.seq_id[i][0]; - if (!can_use) { - const llama_seq_id seq_id_cell = cells.seq_get(idx); + // can we use this cell? 
either: + // - the cell is empty + // - the cell is occupied only by one sequence: + // - (disabled) mask causally, if the sequence is the same as the one we are inserting + // - mask SWA, using current max pos for that sequence in the cache + // always insert in the cell with minimum pos + bool can_use = cells.is_empty(idx); - // SWA mask - if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { - can_use = true; + if (!can_use && cells.seq_count(idx) == 1) { + const llama_pos pos_cell = cells.pos_get(idx); + + // (disabled) causal mask + // note: it's better to purge any "future" tokens beforehand + //if (cells.seq_has(idx, seq_id)) { + // can_use = pos_cell >= pos; + //} + + if (!can_use) { + const llama_seq_id seq_id_cell = cells.seq_get(idx); + + // SWA mask + if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { + can_use = true; + } } } - } - head_cur++; - n_tested++; + if (can_use) { + res.idxs[s].push_back(idx); + } else { + if (cont) { + break; + } + } + } - if (can_use) { - idxs.push_back(idx); - } else { + if (res.idxs[s].size() == n_tokens) { break; } - } - if (idxs.size() == n_tokens) { - break; - } + if (cont) { + res.idxs[s].clear(); + } - if (cont) { - idxs.clear(); + if (n_tested >= cells.size()) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return { }; + } } - if (n_tested >= cells.size()) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + // we didn't find a suitable slot - return empty result + if (res.idxs[s].size() < n_tokens) { return { }; } } - // we didn't find a suitable slot - return empty result - if (idxs.size() < n_tokens) { - res.clear(); - } + assert(res.s1 >= res.s0); return res; } @@ -717,41 +952,51 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u // keep track of the max sequence position that we would overwrite with this ubatch // for non-SWA cache, this would be always empty llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; - for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { seq_pos_max_rm[s] = -1; } - assert(ubatch.n_tokens == sinfo.idxs.size()); + assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size()); - for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { - const auto idx = sinfo.idxs.at(i); + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + for (uint32_t ii = 0; ii < sinfo.size(); ++ii) { + const uint32_t i = s*sinfo.size() + ii; - if (!cells.is_empty(idx)) { - assert(cells.seq_count(idx) == 1); + auto & cells = v_cells[sinfo.strm[s]]; - const llama_seq_id seq_id = cells.seq_get(idx); - const llama_pos pos = cells.pos_get(idx); + const auto idx = sinfo.idxs[s][ii]; - seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); + if (!cells.is_empty(idx)) { + assert(cells.seq_count(idx) == 1); - cells.rm(idx); - } + const llama_seq_id seq_id = cells.seq_get(idx); + const llama_pos pos = cells.pos_get(idx); - cells.pos_set(idx, ubatch.pos[i]); + seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); - for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { - cells.seq_add(idx, ubatch.seq_id[i][s]); + cells.rm(idx); + } + + cells.pos_set(idx, ubatch.pos[i]); + + for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { + cells.seq_add(idx, ubatch.seq_id[i][s]); + } } } // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence // will be present in the cache. 
so we have to purge any position which is less than those we would overwrite // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 - for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { if (seq_pos_max_rm[s] == -1) { continue; } + GGML_ASSERT(s < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[s]]; + if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n", __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s); @@ -761,7 +1006,11 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u } // move the head at the end of the slot - head = sinfo.idxs.back() + 1; + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + auto & head = v_heads[sinfo.strm[s]]; + + head = sinfo.idxs[s].back() + 1; + } } bool llama_kv_cache_unified::get_can_shift() const { @@ -769,49 +1018,91 @@ bool llama_kv_cache_unified::get_can_shift() const { } uint32_t llama_kv_cache_unified::get_size() const { + const auto & cells = v_cells[seq_to_stream[0]]; + return cells.size(); } +uint32_t llama_kv_cache_unified::get_n_stream() const { + return n_stream; +} + bool llama_kv_cache_unified::get_has_shift() const { - return cells.get_has_shift(); + bool result = false; + + for (uint32_t s = 0; s < n_stream; ++s) { + result |= v_cells[s].get_has_shift(); + } + + return result; } uint32_t llama_kv_cache_unified::get_n_kv() const { - return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); + uint32_t result = 0; + + for (uint32_t s = 0; s < n_stream; ++s) { + const auto & cells = v_cells[s]; + + result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result); + } + + return result; } -ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const { +bool llama_kv_cache_unified::get_supports_set_rows() const { + return supports_set_rows; +} + +ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { const int32_t ikv = map_layer_ids.at(il); auto * k = layers[ikv].k; - return ggml_view_3d(ctx, k, - hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, + const uint64_t kv_size = get_size(); + const uint64_t n_embd_k_gqa = k->ne[0]; + + assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il)); + + const uint32_t ns = sinfo.s1 - sinfo.s0 + 1; + + return ggml_view_4d(ctx, k, + hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns, ggml_row_size(k->type, hparams.n_embd_head_k), - ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), - 0); + ggml_row_size(k->type, n_embd_k_gqa), + ggml_row_size(k->type, n_embd_k_gqa*kv_size), + ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0); } -ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const { +ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { const int32_t ikv = map_layer_ids.at(il); auto * v = layers[ikv].v; + const uint64_t kv_size = get_size(); + const uint64_t n_embd_v_gqa = v->ne[0]; + + // [TAG_V_CACHE_VARIABLE] + assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il)); + + const uint32_t ns = sinfo.s1 - sinfo.s0 + 1; + if (!v_trans) { // note: v->nb[1] <= v->nb[2] - return ggml_view_3d(ctx, v, - hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, - ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, 
hparams.n_embd_v_gqa(il)), // v->nb[2] - 0); + return ggml_view_4d(ctx, v, + hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns, + ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2] + ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3] + ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0); } // note: v->nb[1] > v->nb[2] - return ggml_view_3d(ctx, v, - n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, - ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, v->ne[1]), // v->nb[2] - 0); + return ggml_view_4d(ctx, v, + n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns, + ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, kv_size), // v->nb[2] + ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3] + ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0); } ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const { @@ -825,12 +1116,18 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_ k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens); if (k_idxs && supports_set_rows) { + if (k->ne[2] > 1) { + k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]); + } + return ggml_set_rows(ctx, k, k_cur, k_idxs); } // TODO: fallback to old ggml_cpy() method for backwards compatibility // will be removed when ggml_set_rows() is adopted by all backends + GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS"); + ggml_tensor * k_view = ggml_view_1d(ctx, k, n_tokens*n_embd_k_gqa, ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head()); @@ -843,37 +1140,38 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_ auto * v = layers[ikv].v; - const int64_t n_embd_v_gqa = v->ne[0]; - const int64_t n_tokens = v_cur->ne[2]; + const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1]; + const int64_t n_tokens = v_cur->ne[2]; v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens); if (v_idxs && supports_set_rows) { if (!v_trans) { + if (v->ne[2] > 1) { + v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]); + } + return ggml_set_rows(ctx, v, v_cur, v_idxs); } - // the row becomes a single element - ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]); + // [TAG_V_CACHE_VARIABLE] + if (n_embd_v_gqa < v->ne[0]) { + v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0); + } - // note: the V cache is transposed when not using flash attention - v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3); + // the row becomes a single element + ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]); - // note: we can be more explicit here at the cost of extra cont - // however, above we take advantage that a row of single element is always continuous regardless of the row stride - //v_cur = ggml_transpose(ctx, v_cur); - //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]); + v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]); - // we broadcast the KV indices n_embd_v_gqa times - // v [1, n_kv, n_embd_v_gqa] - // v_cur [1, n_tokens, n_embd_v_gqa] - // v_idxs [n_tokens, 1, 1] return ggml_set_rows(ctx, v_view, v_cur, v_idxs); } // TODO: fallback to old ggml_cpy() method for backwards compatibility // will be removed when ggml_set_rows() is adopted by all backends + 
GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS"); + ggml_tensor * v_view = nullptr; if (!v_trans) { @@ -904,7 +1202,13 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { const uint32_t n_tokens = ubatch.n_tokens; - ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens); + ggml_tensor * v_idxs; + + if (!v_trans) { + v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens); + } else { + v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max()); + } ggml_set_input(v_idxs); @@ -917,12 +1221,17 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba } const uint32_t n_tokens = ubatch->n_tokens; + GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); int64_t * data = (int64_t *) dst->data; - for (int64_t i = 0; i < n_tokens; ++i) { - data[i] = sinfo.idxs.at(i); + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + const int64_t offs = sinfo.strm[s]*get_size(); + + for (uint32_t i = 0; i < sinfo.size(); ++i) { + data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i]; + } } } @@ -932,12 +1241,48 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba } const uint32_t n_tokens = ubatch->n_tokens; + GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); int64_t * data = (int64_t *) dst->data; - for (int64_t i = 0; i < n_tokens; ++i) { - data[i] = sinfo.idxs.at(i); + if (!v_trans) { + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + const int64_t offs = sinfo.strm[s]*get_size(); + + for (uint32_t i = 0; i < sinfo.size(); ++i) { + data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i]; + } + } + } else { + // note: the V cache is transposed when not using flash attention + const int64_t kv_size = get_size(); + + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max(); + + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa; + + for (uint32_t i = 0; i < sinfo.size(); ++i) { + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i]; + } + } + } + } +} + +void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + int32_t * data = (int32_t *) dst->data; + + for (uint32_t s = 0; s < n_stream; ++s) { + const auto & cells = v_cells[s]; + + for (uint32_t i = 0; i < cells.size(); ++i) { + data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i); + } } } @@ -947,7 +1292,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); float * data = (float *) dst->data; - const int64_t n_kv = dst->ne[0]; + const int64_t n_kv = dst->ne[0]; + const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch + + GGML_ASSERT(n_tokens%n_stream == 0); + + // n_tps == n_tokens_per_stream + const int64_t n_tps = n_tokens/n_stream; + const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD); + + std::fill(data, data + ggml_nelements(dst), -INFINITY); // Use only the previous KV cells of the correct sequence for each token of the ubatch. 
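For reference, a minimal standalone sketch of how one entry of the padded KQ mask built in set_input_kq_mask above can be addressed and filled. It assumes h == 0, omits the empty-cell and SWA checks for brevity, and pad_to() is only a stand-in for GGML_PAD; none of these helper names come from the patch itself.

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // stand-in for GGML_PAD (assumption, not the real macro)
    static int64_t pad_to(int64_t x, int64_t p) { return ((x + p - 1) / p) * p; }

    // value stored at mask[idst + j]; -INFINITY means the cell is masked out
    static float mask_value(int64_t p0, int64_t p1, bool same_seq, bool causal, bool use_alibi) {
        if (!same_seq)         return -INFINITY; // cell belongs to another sequence
        if (causal && p0 > p1) return -INFINITY; // future token
        return use_alibi ? -std::abs(float(p0 - p1)) : 0.0f;
    }

    int main() {
        const int64_t n_kv = 8, n_stream = 2, n_tps = 3, kq_pad = 4;
        const int64_t n_tps_pad = pad_to(n_tps, kq_pad);

        // pre-fill with -INFINITY so the padding rows stay masked
        std::vector<float> mask(n_kv * n_stream * n_tps_pad, -INFINITY);

        // token ii of stream s looks at cell j (h == 0)
        const int64_t s = 1, ii = 2, j = 5;
        const int64_t idst = n_kv * (s * n_tps_pad + ii);
        mask[idst + j] = mask_value(/*p0=*/4, /*p1=*/6, /*same_seq=*/true, /*causal=*/true, /*use_alibi=*/false);
        return 0;
    }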
// It's assumed that if a token in the batch has multiple sequences, they are equivalent. @@ -961,70 +1315,57 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub // xxxxx----- // xxxxx----- // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 + // TODO: optimize this section for (uint32_t h = 0; h < 1; ++h) { - for (uint32_t i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = ubatch->seq_id[i][0]; + for (uint32_t s = 0; s < n_stream; ++s) { + for (uint32_t ii = 0; ii < n_tps; ++ii) { + const uint32_t i = s*n_tps + ii; - const llama_pos p1 = ubatch->pos[i]; + const llama_seq_id seq_id = ubatch->seq_id[i][0]; - for (uint32_t j = 0; j < n_kv; ++j) { - float f = 0.0f; + const auto & cells = v_cells[seq_to_stream[seq_id]]; - bool masked = false; + const llama_pos p1 = ubatch->pos[i]; - if (cells.is_empty(j)) { - masked = true; - } else { - const llama_pos p0 = cells.pos_get(j); + const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii); + + for (uint32_t j = 0; j < n_kv; ++j) { + if (cells.is_empty(j)) { + continue; + } // mask the token if not the same sequence - masked = masked || (!cells.seq_has(j, seq_id)); + if (!cells.seq_has(j, seq_id)) { + continue; + } + + const llama_pos p0 = cells.pos_get(j); // mask future tokens - masked = masked || (causal_attn && p0 > p1); + if (causal_attn && p0 > p1) { + continue; + } // apply SWA if any - masked = masked || (is_masked_swa(p0, p1)); - - if (!masked && hparams.use_alibi) { - f = -std::abs(p0 - p1); + if (is_masked_swa(p0, p1)) { + continue; } - } - - if (masked) { - f = -INFINITY; - } - - data[h*(n_kv*n_tokens) + i*n_kv + j] = f; - } - } - // mask padded tokens - if (data) { - for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (uint32_t j = 0; j < n_kv; ++j) { - data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; } } } } } -void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - - int32_t * data = (int32_t *) dst->data; - - for (uint32_t i = 0; i < cells.size(); ++i) { - data[i] = cells.is_empty(i) ? 
0 : cells.get_shift(i); - } -} - void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { const int64_t n_tokens = ubatch->n_tokens; + GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams"); + const auto & cells = v_cells[0]; + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing + GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing int32_t * data = (int32_t *) dst->data; @@ -1129,7 +1470,7 @@ class llm_graph_input_k_shift : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; - ggml_tensor * k_shift; // I32 [kv_size] + ggml_tensor * k_shift; // I32 [kv_size*n_stream] const llama_kv_cache_unified * kv_self; }; @@ -1142,20 +1483,20 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { } } -llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const { - auto res = std::make_unique(); +ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, llama_context * lctx) const { + auto * ctx = res->get_ctx(); + auto * gf = res->get_gf(); const auto & n_embd_head_k = hparams.n_embd_head_k; //const auto & n_embd_head_v = hparams.n_embd_head_v; auto inp = std::make_unique(this); - inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size()); + inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); ggml_set_input(inp->k_shift); + const auto & cparams = lctx->get_cparams(); + for (const auto & layer : layers) { const uint32_t il = layer.il; @@ -1169,7 +1510,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( ggml_tensor * k = ggml_view_3d(ctx, layer.k, - n_embd_head_k, n_head_kv, cells.size(), + n_embd_head_k, n_head_kv, get_size()*n_stream, ggml_row_size(layer.k->type, n_embd_head_k), ggml_row_size(layer.k->type, n_embd_k_gqa), 0); @@ -1181,18 +1522,24 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( res->add_input(std::move(inp)); - return res; + return gf; } -llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf, - const defrag_info & dinfo) const { - auto res = std::make_unique(); +ggml_cgraph * llama_kv_cache_unified::build_graph_defrag( + llm_graph_result * res, + llama_context * lctx, + const defrag_info & dinfo) const { + auto * ctx = res->get_ctx(); + auto * gf = res->get_gf(); + + GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag"); + + const auto & cells = v_cells[0]; const auto & ids = dinfo.ids; + const auto & cparams = lctx->get_cparams(); + #if 0 // CPU defrag // @@ -1329,10 +1676,14 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); #endif - return res; + return gf; } llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const { + GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag"); + + const auto & cells = v_cells[0]; + const uint32_t n_layer = layers.size(); const uint32_t n_kv = cells.used_max_p1(); @@ -1477,65 +1828,99 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { return false; } -void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - std::vector> cell_ranges; // ranges, from inclusive, to 
exclusive - uint32_t cell_count = 0; +void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + GGML_UNUSED(flags); - // Count the number of cells with the specified seq_id - // Find all the ranges of cells with this seq id (or all, when -1) - uint32_t cell_range_begin = cells.size(); + io.write(&n_stream, sizeof(n_stream)); - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { - ++cell_count; - if (cell_range_begin == cells.size()) { - cell_range_begin = i; - } - } else { - if (cell_range_begin != cells.size()) { - cell_ranges.emplace_back(cell_range_begin, i); - cell_range_begin = cells.size(); + for (uint32_t s = 0; s < n_stream; ++s) { + cell_ranges_t cr { s, {} }; + + uint32_t cell_count = 0; + + const auto & cells = v_cells[s]; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = cells.size(); + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { + ++cell_count; + if (cell_range_begin == cells.size()) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != cells.size()) { + cr.data.emplace_back(cell_range_begin, i); + cell_range_begin = cells.size(); + } } } - } - if (cell_range_begin != cells.size()) { - cell_ranges.emplace_back(cell_range_begin, cells.size()); + if (cell_range_begin != cells.size()) { + cr.data.emplace_back(cell_range_begin, cells.size()); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cr.data) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + + io.write(&cell_count, sizeof(cell_count)); + + // skip empty streams + if (cell_count == 0) { + continue; + } + + state_write_meta(io, cr, seq_id); + state_write_data(io, cr); } +} + +void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + GGML_UNUSED(flags); + + GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); - // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count - uint32_t cell_count_check = 0; - for (const auto & range : cell_ranges) { - cell_count_check += range.second - range.first; + uint32_t n_stream_cur; + io.read_to(&n_stream_cur, sizeof(n_stream_cur)); + if (n_stream_cur != n_stream) { + throw std::runtime_error("n_stream mismatch"); } - GGML_ASSERT(cell_count == cell_count_check); - io.write(&cell_count, sizeof(cell_count)); + for (uint32_t s = 0; s < n_stream; ++s) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); - state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); -} + if (cell_count == 0) { + continue; + } -void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) { - uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); + const uint32_t strm = seq_id == -1 ? 
s : seq_to_stream[seq_id]; - bool res = true; - res = res && state_read_meta(io, cell_count, seq_id); - res = res && state_read_data(io, cell_count); + bool res = true; + res = res && state_read_meta(io, strm, cell_count, seq_id); + res = res && state_read_data(io, strm, cell_count); - if (!res) { - if (seq_id == -1) { - clear(true); - } else { - seq_rm(seq_id, -1, -1); + if (!res) { + if (seq_id == -1) { + clear(true); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); } - throw std::runtime_error("failed to restore kv cache"); } } -void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { - for (const auto & range : cell_ranges) { +void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const { + const auto & cells = v_cells[cr.strm]; + + for (const auto & range : cr.data) { for (uint32_t i = range.first; i < range.second; ++i) { std::vector seq_ids; @@ -1560,7 +1945,9 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std:: } } -void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { +void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const { + const auto & cells = v_cells[cr.strm]; + const uint32_t v_trans = this->v_trans ? 1 : 0; const uint32_t n_layer = layers.size(); @@ -1576,19 +1963,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + auto * k = layer.k_stream[cr.strm]; + // Write key type - const int32_t k_type_i = (int32_t)layer.k->type; + const int32_t k_type_i = (int32_t) k->type; io.write(&k_type_i, sizeof(k_type_i)); // Write row size of key - const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa); io.write(&k_size_row, sizeof(k_size_row)); // Read each range of cells of k_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { + for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * k_size_row; - io.write_tensor(layer.k, range.first * k_size_row, buf_size); + io.write_tensor(k, range.first * k_size_row, buf_size); } } @@ -1598,19 +1987,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + auto * v = layer.v_stream[cr.strm]; + // Write value type - const int32_t v_type_i = (int32_t)layer.v->type; + const int32_t v_type_i = (int32_t) v->type; io.write(&v_type_i, sizeof(v_type_i)); // Write row size of value - const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa); io.write(&v_size_row, sizeof(v_size_row)); // Read each range of cells of v_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { + for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * v_size_row; - io.write_tensor(layer.v, range.first * v_size_row, buf_size); + io.write_tensor(v, range.first * v_size_row, buf_size); } } } else { @@ -1622,12 +2013,14 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_v_gqa = 
hparams.n_embd_v_gqa(il); + auto * v = layer.v_stream[cr.strm]; + // Write value type - const int32_t v_type_i = (int32_t)layer.v->type; + const int32_t v_type_i = (int32_t) v->type; io.write(&v_type_i, sizeof(v_type_i)); // Write element size - const uint32_t v_size_el = ggml_type_size(layer.v->type); + const uint32_t v_size_el = ggml_type_size(v->type); io.write(&v_size_el, sizeof(v_size_el)); // Write GQA embedding size @@ -1636,27 +2029,31 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: // For each row, we get the element values of each cell for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { // Read each range of cells of v_size_el length each into tmp_buf and write out - for (const auto & range : cell_ranges) { + for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * kv_size) * v_size_el; const size_t buf_size = range_size * v_size_el; - io.write_tensor(layer.v, src_offset, buf_size); + io.write_tensor(v, src_offset, buf_size); } } } } } -bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { +bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) { + auto & cells = v_cells[strm]; + auto & head = v_heads[strm]; + if (dest_seq_id != -1) { // single sequence - seq_rm(dest_seq_id, -1, -1); llama_batch_allocr balloc(hparams.n_pos_per_embd()); llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1); + ubatch.seq_id_unq[0] = dest_seq_id; + for (uint32_t i = 0; i < cell_count; ++i) { llama_pos pos; uint32_t n_seq_id; @@ -1693,6 +2090,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell // keep the head at the old position because we will read the KV data into it in state_read_data() head = head_cur; + LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id); + // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values) // Assume that this is one contiguous block of cells GGML_ASSERT(head_cur + cell_count <= cells.size()); @@ -1738,7 +2137,10 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell return true; } -bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { +bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) { + auto & cells = v_cells[strm]; + auto & head = v_heads[strm]; + uint32_t v_trans; uint32_t n_layer; @@ -1766,10 +2168,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + auto * k = layer.k_stream[strm]; + // Read type of key int32_t k_type_i_ref; io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); - const int32_t k_type_i = (int32_t) layer.k->type; + const int32_t k_type_i = (int32_t) k->type; if (k_type_i != k_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); return false; @@ -1778,7 +2182,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of key uint64_t k_size_row_ref; io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + const size_t k_size_row = 
ggml_row_size(k->type, n_embd_k_gqa); if (k_size_row != k_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); return false; @@ -1786,7 +2190,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the keys for the whole cell range - ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); } } @@ -1796,10 +2200,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + auto * v = layer.v_stream[strm]; + // Read type of value int32_t v_type_i_ref; io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)layer.v->type; + const int32_t v_type_i = (int32_t) v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); return false; @@ -1808,7 +2214,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of value uint64_t v_size_row_ref; io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); - const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa); if (v_size_row != v_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); return false; @@ -1816,7 +2222,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the values for the whole cell range - ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); } } } else { @@ -1826,10 +2232,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + auto * v = layer.v_stream[strm]; + // Read type of value int32_t v_type_i_ref; io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)layer.v->type; + const int32_t v_type_i = (int32_t) v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); return false; @@ -1838,7 +2246,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // Read element size of value uint32_t v_size_el_ref; io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); - const size_t v_size_el = ggml_type_size(layer.v->type); + const size_t v_size_el = ggml_type_size(v->type); if (v_size_el != v_size_el_ref) { LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); return false; @@ -1856,7 +2264,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // For each row in the transposed matrix, read the values for the whole cell range for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { const size_t dst_offset = (head + j * cells.size()) * v_size_el; - ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + ggml_backend_tensor_set(v, 
io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); } } } @@ -1875,18 +2283,26 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { n_kv = kv->get_size(); + const uint32_t n_stream = kv->get_n_stream(); + // create a dummy slot info - the actual data is irrelevant. we just need to build the graph sinfos.resize(1); - sinfos[0].idxs.resize(1); - sinfos[0].idxs[0] = 0; + sinfos[0].s0 = 0; + sinfos[0].s1 = n_stream - 1; + sinfos[0].idxs.resize(n_stream); + for (uint32_t s = 0; s < n_stream; ++s) { + sinfos[0].strm.push_back(s); + sinfos[0].idxs[s].resize(1, 0); + } } llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv, llama_context * lctx, bool do_shift, - defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) { - if (!do_shift && this->dinfo.empty()) { + defrag_info dinfo, + stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) { + if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) { status = LLAMA_MEMORY_STATUS_NO_UPDATE; } } @@ -1914,7 +2330,7 @@ bool llama_kv_cache_unified_context::apply() { // no ubatches -> this is a KV cache update if (ubatches.empty()) { - kv->update(lctx, do_shift, dinfo); + kv->update(lctx, do_shift, dinfo, sc_info); return true; } @@ -1940,12 +2356,16 @@ uint32_t llama_kv_cache_unified_context::get_n_kv() const { return n_kv; } +bool llama_kv_cache_unified_context::get_supports_set_rows() const { + return kv->get_supports_set_rows(); +} + ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const { - return kv->get_k(ctx, il, n_kv); + return kv->get_k(ctx, il, n_kv, sinfos[i_cur]); } ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const { - return kv->get_v(ctx, il, n_kv); + return kv->get_v(ctx, il, n_kv, sinfos[i_cur]); } ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const { diff --git a/examples/talk-llama/llama-kv-cache-unified.h b/examples/talk-llama/llama-kv-cache-unified.h index b8b0356e830..07a7c9e4e46 100644 --- a/examples/talk-llama/llama-kv-cache-unified.h +++ b/examples/talk-llama/llama-kv-cache-unified.h @@ -35,16 +35,50 @@ class llama_kv_cache_unified : public llama_memory_i { std::vector ids; }; + struct stream_copy_info { + bool empty() const { + assert(ssrc.size() == sdst.size()); + return ssrc.empty(); + } + + std::vector ssrc; + std::vector sdst; + }; + // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the // KV cells. 
for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]] struct slot_info { // data for ggml_set_rows using idx_vec_t = std::vector; - idx_vec_t idxs; + // number of streams: ns = s1 - s0 + 1 + llama_seq_id s0; + llama_seq_id s1; + + std::vector strm; // [ns] + std::vector idxs; // [ns] uint32_t head() const { - return idxs.at(0); + GGML_ASSERT(idxs.size() == 1); + GGML_ASSERT(!idxs[0].empty()); + + return idxs[0][0]; + } + + void resize(size_t n) { + strm.resize(n); + idxs.resize(n); + } + + size_t size() const { + GGML_ASSERT(idxs.size() == strm.size()); + GGML_ASSERT(!idxs.empty()); + + return idxs[0].size(); + } + + size_t n_stream() const { + return strm.size(); } bool empty() const { @@ -54,9 +88,6 @@ class llama_kv_cache_unified : public llama_memory_i { void clear() { idxs.clear(); } - - // TODO: implement - //std::vector seq_idxs; }; using slot_info_vec_t = std::vector; @@ -68,6 +99,7 @@ class llama_kv_cache_unified : public llama_memory_i { ggml_type type_v, bool v_trans, bool offload, + bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_pad, @@ -104,14 +136,15 @@ class llama_kv_cache_unified : public llama_memory_i { // state write/load - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; // // llama_kv_cache_unified specific API // - uint32_t get_size() const; + uint32_t get_size() const; + uint32_t get_n_stream() const; bool get_has_shift() const; @@ -121,9 +154,12 @@ class llama_kv_cache_unified : public llama_memory_i { uint32_t get_n_kv() const; + // TODO: temporary + bool get_supports_set_rows() const; + // get views of the current state of the cache - ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const; - ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const; + ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; // store k_cur and v_cur in the cache based on the provided head location ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const; @@ -137,7 +173,7 @@ class llama_kv_cache_unified : public llama_memory_i { // return empty vector on failure slot_info_vec_t prepare(const std::vector & ubatches); - bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo); + bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info); // find a slot of kv cells that can hold the ubatch // if cont == true, then the slot must be continuous @@ -157,8 +193,9 @@ class llama_kv_cache_unified : public llama_memory_i { void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; + void set_input_k_shift(ggml_tensor * dst) const; + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; - void set_input_k_shift (ggml_tensor * dst) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; 
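As a rough illustration of the per-stream layout used by slot_info above, here is a self-contained mini version of the idea: one row of cell indices per stream, with all rows of equal length. The mini_slot_info type is illustrative only and is not the struct from this header.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct mini_slot_info {
        std::vector<uint32_t>              strm; // stream id per row
        std::vector<std::vector<uint32_t>> idxs; // cell indices per row

        size_t n_stream() const { return strm.size(); }
        size_t size()     const { assert(!idxs.empty()); return idxs[0].size(); }
    };

    int main() {
        // a ubatch of 4 tokens split across 2 streams -> 2 tokens per stream
        mini_slot_info s;
        s.strm = {0, 1};
        s.idxs = {{10, 11},   // stream 0: tokens go to cells 10 and 11
                  {40, 41}};  // stream 1: tokens go to cells 40 and 41

        assert(s.n_stream() == 2 && s.size() == 2);
        // flattened token index i = row * size() + column, matching the
        // s*sinfo.size() + ii indexing used in apply_ubatch()
        return 0;
    }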
private: @@ -172,15 +209,15 @@ class llama_kv_cache_unified : public llama_memory_i { ggml_tensor * k; ggml_tensor * v; + + std::vector k_stream; + std::vector v_stream; }; bool v_trans = true; // the value tensor is transposed - // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) - // note: this is not part of the KV state and it's only used to speed-up the find_slot() method - uint32_t head = 0; - const uint32_t n_seq_max = 1; + const uint32_t n_stream = 1; // required padding const uint32_t n_pad = 1; @@ -193,14 +230,24 @@ class llama_kv_cache_unified : public llama_memory_i { // env: LLAMA_SET_ROWS (temporary) // ref: https://github.com/ggml-org/llama.cpp/pull/14285 - int supports_set_rows = false; + bool supports_set_rows = true; const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; std::vector ctxs; std::vector bufs; - llama_kv_cells_unified cells; + // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) + // note: this is not part of the KV state and it's only used to speed-up the find_slot() method + std::vector v_heads; + + std::vector v_cells; + + // maps from a sequence id to a stream id + std::vector seq_to_stream; + + // pending stream copies that will be applied during the next update + stream_copy_info sc_info; std::vector layers; @@ -226,29 +273,34 @@ class llama_kv_cache_unified : public llama_memory_i { float freq_base, float freq_scale) const; - llm_graph_result_ptr build_graph_shift( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const; + ggml_cgraph * build_graph_shift( + llm_graph_result * res, + llama_context * lctx) const; - llm_graph_result_ptr build_graph_defrag( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf, + ggml_cgraph * build_graph_defrag( + llm_graph_result * res, + llama_context * lctx, const defrag_info & dinfo) const; - void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; - void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; + struct cell_ranges_t { + uint32_t strm; - bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); - bool state_read_data(llama_io_read_i & io, uint32_t cell_count); + std::vector> data; // ranges, from inclusive, to exclusive + }; + + void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const; + + bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count); }; class llama_kv_cache_unified_context : public llama_memory_context_i { public: // some shorthands - using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; - using defrag_info = llama_kv_cache_unified::defrag_info; + using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; + using defrag_info = llama_kv_cache_unified::defrag_info; + using stream_copy_info = llama_kv_cache_unified::stream_copy_info; // used for errors llama_kv_cache_unified_context(llama_memory_status status); @@ -262,7 +314,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { llama_kv_cache_unified * kv, llama_context * lctx, bool do_shift, - defrag_info dinfo); + defrag_info dinfo, + stream_copy_info sc_info); 
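A small sketch of what the seq_to_stream mapping declared above could look like. The rule used here (a unified cache keeps every sequence on stream 0, while a split cache gives each sequence its own stream) is an assumption made for illustration and is not taken verbatim from the patch.

    #include <cstdint>
    #include <vector>

    // build a sequence-id -> stream-id table (illustrative helper, not upstream API)
    std::vector<uint32_t> make_seq_to_stream(uint32_t n_seq_max, bool unified) {
        const uint32_t n_stream = unified ? 1 : n_seq_max;

        std::vector<uint32_t> seq_to_stream(n_seq_max);
        for (uint32_t s = 0; s < n_seq_max; ++s) {
            seq_to_stream[s] = s % n_stream; // unified -> always 0, split -> identity
        }
        return seq_to_stream;
    }

    int main() {
        auto unified = make_seq_to_stream(4, true);  // {0, 0, 0, 0}
        auto split   = make_seq_to_stream(4, false); // {0, 1, 2, 3}
        return (int) (unified[3] + split[3]);        // 0 + 3
    }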
// used to create a batch procesing context from a batch llama_kv_cache_unified_context( @@ -288,6 +341,9 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { uint32_t get_n_kv() const; + // TODO: temporary + bool get_supports_set_rows() const; + // get views of the current state of the cache ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; @@ -320,6 +376,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { defrag_info dinfo; + stream_copy_info sc_info; + // // batch processing context // diff --git a/examples/talk-llama/llama-memory-hybrid.cpp b/examples/talk-llama/llama-memory-hybrid.cpp index 6cd10db06b7..cbeeb21344e 100644 --- a/examples/talk-llama/llama-memory-hybrid.cpp +++ b/examples/talk-llama/llama-memory-hybrid.cpp @@ -25,6 +25,7 @@ llama_memory_hybrid::llama_memory_hybrid( /* common */ uint32_t n_seq_max, bool offload, + bool unified, /* layer filters */ layer_filter_cb && filter_attn, layer_filter_cb && filter_recr) : @@ -38,6 +39,7 @@ llama_memory_hybrid::llama_memory_hybrid( type_v, v_trans, offload, + unified, kv_size, n_seq_max, n_pad, @@ -163,12 +165,16 @@ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const { return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id)); } -void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { +void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + GGML_UNUSED(flags); + mem_attn->state_write(io, seq_id); mem_recr->state_write(io, seq_id); } -void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) { +void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + GGML_UNUSED(flags); + mem_attn->state_read(io, seq_id); mem_recr->state_read(io, seq_id); } diff --git a/examples/talk-llama/llama-memory-hybrid.h b/examples/talk-llama/llama-memory-hybrid.h index 4ac31817578..acdbc26bfb6 100644 --- a/examples/talk-llama/llama-memory-hybrid.h +++ b/examples/talk-llama/llama-memory-hybrid.h @@ -39,6 +39,7 @@ class llama_memory_hybrid : public llama_memory_i { /* common */ uint32_t n_seq_max, bool offload, + bool unified, /* layer filters */ layer_filter_cb && filter_attn = nullptr, layer_filter_cb && filter_recr = nullptr); @@ -73,8 +74,8 @@ class llama_memory_hybrid : public llama_memory_i { // state write/load - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; // // llama_memory_hybrid specific API diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index 2c1ae67098c..849675c4188 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -446,7 +446,7 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) { // A slot should be always be contiguous. 
// can only process batches with an equal number of new tokens in each sequence - GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.equal_seqs()); int32_t min = size - 1; int32_t max = 0; @@ -680,7 +680,9 @@ size_t llama_memory_recurrent::size_s_bytes() const { return size_s_bytes; } -void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { +void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + GGML_UNUSED(flags); + std::vector> cell_ranges; // ranges, from inclusive, to exclusive uint32_t cell_count = 0; @@ -718,7 +720,9 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq state_write_data(io, cell_ranges); } -void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) { +void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + GGML_UNUSED(flags); + uint32_t cell_count; io.read_to(&cell_count, sizeof(cell_count)); @@ -768,6 +772,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) + if (r_l[il] == nullptr) continue; // Write key type const int32_t r_type_i = (int32_t)r_l[il]->type; @@ -787,6 +793,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: if (!s_trans) { for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) + if (s_l[il] == nullptr) continue; // Write value type const int32_t s_type_i = (int32_t)s_l[il]->type; @@ -807,6 +815,9 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t mem_size = size; for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) + if (s_l[il] == nullptr) continue; + const uint32_t n_embd_s = hparams.n_embd_s(); // Write value type @@ -951,6 +962,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers + if (r_l[il] == nullptr) continue; // Read type of key int32_t r_type_i_ref; @@ -978,11 +991,14 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell if (!s_trans) { for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers + if (s_l[il] == nullptr) continue; // Read type of value int32_t s_type_i_ref; io.read_to(&s_type_i_ref, sizeof(s_type_i_ref)); const int32_t s_type_i = (int32_t)s_l[il]->type; + if (s_type_i != s_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il); return false; @@ -1005,6 +1021,9 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell } else { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers + if (s_l[il] == nullptr) continue; + const uint32_t n_embd_s = hparams.n_embd_s(); // Read type of value diff --git a/examples/talk-llama/llama-memory-recurrent.h 
b/examples/talk-llama/llama-memory-recurrent.h index 4d094f9a057..95c617b2c94 100644 --- a/examples/talk-llama/llama-memory-recurrent.h +++ b/examples/talk-llama/llama-memory-recurrent.h @@ -63,8 +63,8 @@ class llama_memory_recurrent : public llama_memory_i { // state write/load - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) uint32_t size = 0; // total number of cells, shared across all sequences diff --git a/examples/talk-llama/llama-memory.h b/examples/talk-llama/llama-memory.h index e8ba336e852..171d312cc99 100644 --- a/examples/talk-llama/llama-memory.h +++ b/examples/talk-llama/llama-memory.h @@ -104,8 +104,8 @@ struct llama_memory_i { // state write/read // - virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; - virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; + virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const = 0; + virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) = 0; }; using llama_memory_ptr = std::unique_ptr; diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index bd9e6da8832..f71c40f8e3f 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -35,6 +35,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; + case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return "MXFP4 MoE"; case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; diff --git a/examples/talk-llama/llama-model-loader.h b/examples/talk-llama/llama-model-loader.h index 0f52b011b69..c9189f6cb44 100644 --- a/examples/talk-llama/llama-model-loader.h +++ b/examples/talk-llama/llama-model-loader.h @@ -58,8 +58,9 @@ struct llama_model_loader { } }; - static const int TENSOR_NOT_REQUIRED = 1; - static const int TENSOR_DUPLICATED = 2; + static const int TENSOR_NOT_REQUIRED = 1 << 0; + static const int TENSOR_DUPLICATED = 1 << 1; + static const int TENSOR_SKIP = 1 << 2; int n_kv = 0; int n_tensors = 0; diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index a322fc39352..23a26f0c64e 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -107,8 +107,12 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_17B_16E: return "17Bx16E (Scout)"; case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; case LLM_TYPE_A13B: return "A13B"; + case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; + case LLM_TYPE_106B_A12B: return "106B.A12B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; + case LLM_TYPE_300B_A47B: return "300B.A47B"; + case LLM_TYPE_355B_A32B: return "355B.A32B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; default: return "?B"; @@ -188,6 +192,13 
@@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); op_tensor = ggml_add(ctx, a, w); } break; + case GGML_OP_ADD_ID: + { + int n_expert_used = hparams.n_expert_used; + ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); + ggml_tensor * c = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); + op_tensor = ggml_add_id(ctx, a, w, c); + } break; case GGML_OP_MUL: { ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); @@ -256,6 +267,10 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1); op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16); } break; + case GGML_OP_SCALE: + { + op_tensor = ggml_scale(ctx, w, 1.0f); + } break; default: GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); } @@ -288,7 +303,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara } // CPU: ACCEL -> GPU host -> CPU extra -> CPU -static buft_list_t make_cpu_buft_list(const std::vector & devices) { +static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts) { buft_list_t buft_list; // add ACCEL buffer types @@ -317,21 +332,22 @@ static buft_list_t make_cpu_buft_list(const std::vector & de } } - // add extra buffer types, only if no GPU device is present - // ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094 - auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (cpu_dev == nullptr) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); - } + // add extra buffer types + if (use_extra_bufts) { + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } - auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); - auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) - ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); - if (ggml_backend_dev_get_extra_bufts_fn) { - ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); - while (extra_bufts && *extra_bufts) { - buft_list.emplace_back(cpu_dev, *extra_bufts); - ++extra_bufts; + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(cpu_dev, *extra_bufts); + ++extra_bufts; + } } } @@ -644,6 +660,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + // MiniCPM uses rope by default, unlike Granite which uses it as a switch + hparams.rope_finetuned = true; + switch (hparams.n_layer) { case 52: type = LLM_TYPE_1B; break; case 40: type = LLM_TYPE_2B; break; @@ -849,6 +868,36 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } 
break; + case LLM_ARCH_DREAM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // Dream models are primarily 7B with 28 layers + switch (hparams.n_layer) { + case 28: + type = LLM_TYPE_7B; + break; + default: + type = LLM_TYPE_UNKNOWN; + } + // Set non-causal attention for diffusion models + hparams.causal_attn = false; + } + break; + case LLM_ARCH_LLADA: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // LLaDA-8B has 32 layers, similar to LLaMA but for diffusion + switch (hparams.n_layer) { + case 32: + type = LLM_TYPE_8B; + break; + default: + type = LLM_TYPE_UNKNOWN; + } + // Set non-causal attention for diffusion models + hparams.causal_attn = false; + } + break; case LLM_ARCH_QWEN2MOE: { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); @@ -863,6 +912,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_QWEN3: { + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break; @@ -935,6 +985,33 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_PLAMO2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Load Mamba SSM parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + } + + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_1B; break; + case 32: + if (hparams.n_embd == 2048) { + type = LLM_TYPE_2B; + } else if (hparams.n_embd == 4096) { + type = LLM_TYPE_8B; + } + break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_GPT2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1018,6 +1095,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { + case 18: type = LLM_TYPE_537M; break; case 26: type = LLM_TYPE_1B; break; case 34: type = LLM_TYPE_4B; break; case 48: type = LLM_TYPE_12B; break; @@ -1322,7 +1400,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // that have no expert_gating_func model parameter set hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; } - ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); switch (hparams.n_layer) { case 27: type = LLM_TYPE_16B; break; @@ -1370,6 +1448,34 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GLM4_MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + 
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // Expert gating function (GLM-4.5 uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + switch (hparams.n_layer) { + case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) + case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1446,6 +1552,23 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_EXAONE4: + { + if (hparams.n_layer == 64) { // 32B + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 4096; + hparams.set_swa_pattern(4); + } + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 30: type = LLM_TYPE_1_2B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6QWEN2: { @@ -1483,7 +1606,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); switch (hparams.n_layer) { - case 12: type = LLM_TYPE_190M; break; + case 12: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_190M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; case 24: switch (hparams.n_embd) { case 1024: type = LLM_TYPE_450M; break; @@ -1496,7 +1623,17 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 3584: type = LLM_TYPE_7B; break; default: type = LLM_TYPE_UNKNOWN; } break; - case 32: type = LLM_TYPE_2_9B; break; // RWKV-7-World + case 32: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_2_9B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 61: + switch (hparams.n_embd) { + case 4096: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -1607,10 +1744,20 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } break; case LLM_ARCH_ERNIE4_5: + case LLM_ARCH_ERNIE4_5_MOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + if (arch == LLM_ARCH_ERNIE4_5_MOE) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + } + switch (hparams.n_layer) { case 18: type = LLM_TYPE_0_3B; break; + case 28: type = LLM_TYPE_21B_A3B; break; + case 54: type = LLM_TYPE_300B_A47B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -1656,6 +1803,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_HUNYUAN_DENSE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_embd) { + case 
1024: type = LLM_TYPE_0_5B; break; + case 2048: type = LLM_TYPE_1_8B; break; + case 3072: type = LLM_TYPE_4B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_SMOLLM3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1666,6 +1825,17 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_OPENAI_MOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(2); + + // TODO: switch (hparams.n_layer) + } break; case LLM_ARCH_LFM2: { ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); @@ -1680,6 +1850,29 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_SMALLTHINKER: + { + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 4096; + hparams.set_swa_pattern(4, true); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + hparams.n_no_rope_layer_step = hparams.n_layer; + } + + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_4B; break; + case 52: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -1713,7 +1906,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? 
"true" : "false"); // build a list of buffer types for the CPU and GPU devices - pimpl->cpu_buft_list = make_cpu_buft_list(devices); + pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts); for (auto * dev : devices) { buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); // add CPU buffer types as a fallback @@ -1809,6 +2002,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED; const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; + const auto TENSOR_SKIP = llama_model_loader::TENSOR_SKIP; // create tensors for the weights { @@ -1864,7 +2058,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // skip unused tensors - if (info.op == GGML_OP_NONE) { + if (info.op == GGML_OP_NONE || flags & TENSOR_SKIP) { const size_t nbytes = ggml_nbytes(t_meta); LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", tn.str().c_str(), nbytes); @@ -1874,11 +2068,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { return nullptr; } - // tensors with "bias" suffix are always used with GGML_OP_ADD + // tensors with "bias" suffix are always used with GGML_OP_ADD or GGML_OP_ADD_ID ggml_op op; bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0; if (bias) { - op = GGML_OP_ADD; + if (info.op == GGML_OP_MUL_MAT_ID) { + op = GGML_OP_ADD_ID; + } else { + op = GGML_OP_ADD; + } } else { op = info.op; } @@ -1918,7 +2116,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) { std::regex pattern(overrides->pattern); if (std::regex_search(tensor_name, pattern)) { - buft = overrides->buft; + if (overrides->buft == ggml_backend_cpu_buffer_type()) { + // when overriding to a CPU buffer, consider the extra buffer types + buft = select_weight_buft(hparams, t_meta, op, pimpl->cpu_buft_list); + } else { + buft = overrides->buft; + } + LLAMA_LOG_DEBUG("tensor %s (%zu MiB %s) buffer type overridden to %s\n", tensor_name.c_str(), ggml_nbytes(t_meta) / 1024 / 1024, ggml_type_name(t_meta->type), @@ -2038,6 +2242,53 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } } break; + case LLM_ARCH_LLADA: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = + create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + // Use separate Q, K, V projections without bias, matching LLaDALlamaBlock + layer.wq = + create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + // No bias for QKV projections as per config: include_bias=false, include_qkv_bias=false + layer.wo = + create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + layer.bo = 
create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot / 2 }, + TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + + // optional MLP bias + layer.ffn_gate_b = + create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = + create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED); + } + } + break; case LLM_ARCH_LLAMA4: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -2643,12 +2894,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2VL: + case LLM_ARCH_DREAM: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (output == NULL) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); @@ -2938,6 +3191,73 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_PLAMO2: + { + const uint32_t d_conv = hparams.ssm_d_conv; + const uint32_t d_state = hparams.ssm_d_state; + const uint32_t num_heads = hparams.ssm_dt_rank; + const uint32_t intermediate_size = hparams.ssm_d_inner; + const uint32_t head_dim = intermediate_size / num_heads; + const uint32_t qk_dim = head_dim; + const uint32_t v_dim = head_dim; + const int64_t num_attention_heads = hparams.n_head(); + const int64_t q_num_heads = num_attention_heads; + const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + bool is_mamba_layer = hparams.is_recurrent(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (is_mamba_layer) { + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {intermediate_size, dt_dim + 2*d_state}, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", 
i), {dt_dim, num_heads}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {num_heads}, 0); + + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {num_heads}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {num_heads}, 0); + + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {intermediate_size, n_embd}, 0); + + layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, i), {dt_dim}, 0); + layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); + layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); + } else { + const int64_t num_key_value_heads = hparams.n_head_kv(i); + const int64_t k_num_heads = num_key_value_heads; + const int64_t v_num_heads = num_key_value_heads; + const int64_t q_proj_dim = q_num_heads * qk_dim; + const int64_t k_proj_dim = k_num_heads * qk_dim; + const int64_t v_proj_dim = v_num_heads * v_dim; + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); + } + + // All layers have post-attention norm, FFN norm, and FFN tensors + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); + } + } break; case LLM_ARCH_GPT2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4165,6 +4485,105 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_GLM4_MOE: + { + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; + const int64_t n_expert_shared = hparams.n_expert_shared; + + GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers"); + GGML_ASSERT(hparams.n_expert_used > 0 && "n_expert_used must be > 0 for GLM4_MOE MoE layers"); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + // Load ALL tensors including NextN layer to satisfy total tensor count + // but only PROCESS up to last layer (skipping final NextN layer) in forward pass + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, flags); 
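
The TENSOR_NOT_REQUIRED / TENSOR_DUPLICATED / TENSOR_SKIP constants are now distinct bits (1 << 0, 1 << 1, 1 << 2), so a call site can OR them together and the loader can test each one independently; the GLM-4.5 NextN layers above rely on that to keep their tensors in the GGUF file while leaving them out of the compute graph. A minimal standalone sketch of the flag pattern, with simplified, hypothetical names rather than the real llama_model_loader code:

    #include <cstdio>

    // bit flags, mirroring the TENSOR_* constants in llama-model-loader.h
    static const int TENSOR_NOT_REQUIRED = 1 << 0;
    static const int TENSOR_DUPLICATED   = 1 << 1;
    static const int TENSOR_SKIP         = 1 << 2;

    // hypothetical helper: decide what to do with a tensor given its flags
    static void handle_tensor(const char * name, int flags) {
        if (flags & TENSOR_SKIP) {
            // the tensor stays in the file (it still counts toward n_tensors)
            // but is neither allocated nor wired into the graph
            std::printf("%-28s -> load metadata only, skip weights\n", name);
            return;
        }
        if (flags & TENSOR_NOT_REQUIRED) {
            std::printf("%-28s -> optional, NULL if absent\n", name);
        } else {
            std::printf("%-28s -> required\n", name);
        }
    }

    int main() {
        // a NextN layer combines per-tensor flags with TENSOR_SKIP,
        // as in the GLM4_MOE branch above (illustrative tensor names)
        handle_tensor("blk.92.attn_norm.weight",   TENSOR_SKIP);
        handle_tensor("blk.92.attn_q_norm.weight", TENSOR_NOT_REQUIRED | TENSOR_SKIP);
        handle_tensor("blk.0.attn_norm.weight",    0);
        return 0;
    }
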
+ + // GLM-style attention with bias terms + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, flags); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, flags); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags); + + // K/Q norm tensors (optional for GLM-4.5 355B variant) + layer.attn_q_norm = create_tensor( + tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags); + layer.attn_k_norm = create_tensor( + tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags); + + // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead + // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE + const bool use_moe = (static_cast(i) >= hparams.n_layer_dense_lead); + + if (use_moe) { + // MoE layers + layer.ffn_gate_inp = + create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags); + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor( + tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); + layer.ffn_down_exps = create_tensor( + tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags); + layer.ffn_up_exps = create_tensor( + tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); + + // Shared expert + if (n_expert_shared > 0) { + const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; + layer.ffn_gate_shexp = create_tensor( + tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); + layer.ffn_down_shexp = create_tensor( + tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags); + layer.ffn_up_shexp = create_tensor( + tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); + } + } else { + // Dense layers (first k layers) - GLM uses separate gate/up projections + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd 
}, flags); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags); + } + } + } + break; case LLM_ARCH_NEMOTRON: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4232,6 +4651,39 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_EXAONE4: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; case LLM_ARCH_RWKV6: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4747,6 +5199,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_ERNIE4_5: + case LLM_ARCH_ERNIE4_5_MOE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4775,9 +5228,27 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + if (arch == LLM_ARCH_ERNIE4_5_MOE && static_cast(i) >= hparams.n_layer_dense_lead) { // MoE layers + int n_ff_exp = hparams.n_ff_exp; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + 
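
Both the GLM4_MOE and ERNIE4_5_MOE branches split layers by n_layer_dense_lead: the first n_layer_dense_lead layers keep a plain FFN and every later layer gets the routed expert FFN (the comment above notes that GLM-4.5 has a single dense lead layer). A tiny, hypothetical sketch of that per-layer decision, ignoring ERNIE's interleave-step key:

    #include <cstdint>
    #include <cstdio>

    // assumed illustration of the hybrid dense/MoE layout, not the real loader
    static bool layer_is_moe(uint32_t il, uint32_t n_layer_dense_lead) {
        return il >= n_layer_dense_lead;
    }

    int main() {
        const uint32_t n_layer = 8;
        const uint32_t n_layer_dense_lead = 1; // GLM-4.5: layer 0 dense, layers 1+ MoE
        for (uint32_t il = 0; il < n_layer; ++il) {
            std::printf("layer %2u: %s\n", (unsigned) il,
                        layer_is_moe(il, n_layer_dense_lead) ? "MoE FFN" : "dense FFN");
        }
        return 0;
    }
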
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert (if present) + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); + } + } else { // Dense layers + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } } } break; case LLM_ARCH_FALCON_H1: @@ -4894,6 +5365,39 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); } } break; + case LLM_ARCH_HUNYUAN_DENSE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + } + } break; case LLM_ARCH_SMOLLM3: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4923,6 +5427,46 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_OPENAI_MOE: + { + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), 
{n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); + + layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // bias + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_head * n_rot}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_head_kv * n_rot}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_head_kv * n_rot}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_gate_inp_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0); + layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); + layer.ffn_down_exps_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), { n_embd, n_expert}, 0); + layer.ffn_up_exps_b = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); + } + } break; case LLM_ARCH_LFM2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4956,6 +5500,42 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } } break; + case LLM_ARCH_SMALLTHINKER: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for SMALLTHINKER"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for SMALLTHINKER"); + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp; + 
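
In each of these MoE branches, ffn_gate_inp ({n_embd, n_expert}) is the per-token router and the 3-D ffn_{gate,up,down}_exps tensors ({n_embd, n_ff_exp, n_expert}) stack one FFN per expert. Below is a scalar sketch of the routing step that build_moe_ffn performs on the router output; it assumes softmax gating with re-normalized top-k weights, and the exact gating function and normalization differ per architecture (see the expert_gating_func handling above), so treat it as conceptual only:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // logits: router output of size n_expert (x * ffn_gate_inp)
    // returns the indices and normalized weights of the n_used best experts
    static void route_top_k(const std::vector<float> & logits, int n_used,
                            std::vector<int> & idx, std::vector<float> & w) {
        idx.resize(logits.size());
        for (size_t i = 0; i < idx.size(); ++i) idx[i] = (int) i;
        std::partial_sort(idx.begin(), idx.begin() + n_used, idx.end(),
                          [&](int a, int b) { return logits[a] > logits[b]; });
        idx.resize(n_used);

        // softmax over the selected logits only (weights re-normalized)
        float mx = logits[idx[0]], sum = 0.0f;
        w.resize(n_used);
        for (int i = 0; i < n_used; ++i) { w[i] = std::exp(logits[idx[i]] - mx); sum += w[i]; }
        for (int i = 0; i < n_used; ++i) w[i] /= sum;
    }

    int main() {
        std::vector<float> router_logits = { 0.1f, 2.3f, -0.7f, 1.9f }; // n_expert = 4
        std::vector<int>   idx;
        std::vector<float> w;
        route_top_k(router_logits, /*n_expert_used=*/2, idx, w);
        for (size_t i = 0; i < idx.size(); ++i) {
            std::printf("expert %d  weight %.3f\n", idx[i], w[i]);
        }
        // the selected experts' slices of ffn_{gate,up,down}_exps are then applied
        // as ordinary gated FFNs and summed with these weights
        return 0;
    }
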
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -5209,6 +5789,7 @@ void llama_model::print_info() const { arch == LLM_ARCH_MAMBA2 || arch == LLM_ARCH_JAMBA || arch == LLM_ARCH_FALCON_H1 || + arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); @@ -5258,7 +5839,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } - if (arch == LLM_ARCH_QWEN3MOE) { + if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } @@ -5280,6 +5861,11 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); } + if (arch == LLM_ARCH_SMALLTHINKER) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + } + vocab.print_info(); } @@ -5381,7 +5967,7 @@ ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int i } struct llm_build_llama : public llm_graph_context { - llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -5457,7 +6043,7 @@ struct llm_build_llama : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); @@ -5537,7 +6123,7 @@ struct llm_build_llama : public llm_graph_context { }; struct llm_build_llama_iswa : public llm_graph_context { - llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -5631,7 +6217,7 @@ struct llm_build_llama_iswa : public llm_graph_context { cb(Kcur, "Kcur_normed", il); } - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); @@ -5720,7 +6306,7 @@ struct llm_build_llama_iswa : public llm_graph_context { }; struct llm_build_deci : public llm_graph_context { - llm_build_deci(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_deci(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const 
int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -5808,7 +6394,7 @@ struct llm_build_deci : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); } @@ -5876,7 +6462,7 @@ struct llm_build_deci : public llm_graph_context { }; struct llm_build_baichuan : public llm_graph_context { - llm_build_baichuan(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -5940,7 +6526,7 @@ struct llm_build_baichuan : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -5998,7 +6584,7 @@ struct llm_build_baichuan : public llm_graph_context { }; struct llm_build_xverse : public llm_graph_context { - llm_build_xverse(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_xverse(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -6055,7 +6641,7 @@ struct llm_build_xverse : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -6111,7 +6697,7 @@ struct llm_build_xverse : public llm_graph_context { }; struct llm_build_falcon : public llm_graph_context { - llm_build_falcon(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_falcon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -6178,7 +6764,7 @@ struct llm_build_falcon : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -6233,7 +6819,7 @@ struct llm_build_falcon : public llm_graph_context { }; struct llm_build_grok : public llm_graph_context { - llm_build_grok(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_grok(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -6308,7 +6894,7 @@ struct llm_build_grok : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } @@ -6395,7 +6981,7 @@ struct llm_build_grok : public llm_graph_context { }; struct llm_build_dbrx : public llm_graph_context { - llm_build_dbrx(const llama_model & model, const 
llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_dbrx(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -6457,7 +7043,7 @@ struct llm_build_dbrx : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -6520,7 +7106,7 @@ struct llm_build_dbrx : public llm_graph_context { }; struct llm_build_starcoder : public llm_graph_context { - llm_build_starcoder(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_starcoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -6571,7 +7157,7 @@ struct llm_build_starcoder : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -6629,7 +7215,7 @@ struct llm_build_starcoder : public llm_graph_context { }; struct llm_build_refact : public llm_graph_context { - llm_build_refact(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_refact(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -6670,7 +7256,7 @@ struct llm_build_refact : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -6728,7 +7314,7 @@ struct llm_build_refact : public llm_graph_context { }; struct llm_build_bert : public llm_graph_context { - llm_build_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -6827,7 +7413,7 @@ struct llm_build_bert : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); @@ -6914,7 +7500,7 @@ struct llm_build_bert : public llm_graph_context { }; struct llm_build_neo_bert : public llm_graph_context { - llm_build_neo_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_neo_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -6972,7 +7558,7 @@ struct llm_build_neo_bert : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, 
nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); @@ -7024,7 +7610,7 @@ struct llm_build_neo_bert : public llm_graph_context { }; struct llm_build_bloom : public llm_graph_context { - llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_bloom(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -7072,7 +7658,7 @@ struct llm_build_bloom : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -7130,7 +7716,7 @@ struct llm_build_bloom : public llm_graph_context { }; struct llm_build_mpt : public llm_graph_context { - llm_build_mpt(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_mpt(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -7219,7 +7805,7 @@ struct llm_build_mpt : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -7278,7 +7864,7 @@ struct llm_build_mpt : public llm_graph_context { }; struct llm_build_stablelm : public llm_graph_context { - llm_build_stablelm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_stablelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7365,7 +7951,7 @@ struct llm_build_stablelm : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -7430,7 +8016,7 @@ struct llm_build_stablelm : public llm_graph_context { }; struct llm_build_qwen : public llm_graph_context { - llm_build_qwen(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_qwen(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7486,7 +8072,7 @@ struct llm_build_qwen : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -7544,7 +8130,7 @@ struct llm_build_qwen : public llm_graph_context { }; struct llm_build_qwen2 : public llm_graph_context { - llm_build_qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_qwen2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; 
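
Each of these graph builders finishes its self-attention block with build_attn(...), passing 1.0f/sqrtf(float(n_embd_head)) as the KQ scale (Grok passes 1.0f instead). A single-head, single-query sketch of what that scale does, in plain C++ with no mask or batching; this is purely conceptual and not the ggml path:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // out = sum_i softmax(q . k_i * kq_scale) * v_i
    // kq_scale corresponds to the 1.0f/sqrtf(float(n_embd_head)) passed to build_attn
    static std::vector<float> attend(const std::vector<float> & q,
                                     const std::vector<std::vector<float>> & K,
                                     const std::vector<std::vector<float>> & V,
                                     float kq_scale) {
        const size_t n_kv = K.size(), d = q.size();
        std::vector<float> scores(n_kv);
        float mx = -1e30f;
        for (size_t i = 0; i < n_kv; ++i) {
            float s = 0.0f;
            for (size_t j = 0; j < d; ++j) s += q[j] * K[i][j];
            scores[i] = s * kq_scale;          // scaling keeps the logits in a stable range
            mx = std::max(mx, scores[i]);
        }
        float sum = 0.0f;
        for (size_t i = 0; i < n_kv; ++i) { scores[i] = std::exp(scores[i] - mx); sum += scores[i]; }
        std::vector<float> out(V[0].size(), 0.0f);
        for (size_t i = 0; i < n_kv; ++i) {
            for (size_t j = 0; j < out.size(); ++j) {
                out[j] += (scores[i] / sum) * V[i][j];
            }
        }
        return out;
    }

    int main() {
        const int   n_embd_head = 4;
        const float kq_scale    = 1.0f / std::sqrt((float) n_embd_head);
        std::vector<float>              q = { 1, 0, 0, 1 };
        std::vector<std::vector<float>> K = { { 1, 0, 0, 1 }, { 0, 1, 1, 0 } };
        std::vector<std::vector<float>> V = { { 1, 1, 1, 1 }, { 2, 2, 2, 2 } };
        for (float x : attend(q, K, V, kq_scale)) std::printf("%.3f ", x);
        std::printf("\n");
        return 0;
    }
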
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7606,7 +8192,7 @@ struct llm_build_qwen2 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -7654,6 +8240,10 @@ struct llm_build_qwen2 : public llm_graph_context { // lm_head cur = build_lora_mm(model.output, cur); + if (model.output_b != nullptr) { + cur = ggml_add(ctx0, cur, model.output_b); + } + cb(cur, "result_output", -1); res->t_logits = cur; @@ -7661,8 +8251,10 @@ struct llm_build_qwen2 : public llm_graph_context { } }; -struct llm_build_qwen2vl : public llm_graph_context { - llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +struct llm_build_dream : public llm_graph_context { + llm_build_dream(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + //copied from qwen2 const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7676,10 +8268,7 @@ struct llm_build_qwen2vl : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); - - int sections[4]; - std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + auto * inp_attn = build_attn_inp_no_cache(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -7687,53 +8276,44 @@ struct llm_build_qwen2vl : public llm_graph_context { ggml_tensor * inpSA = inpL; // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); // self-attention { // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_multi( - ctx0, Qcur, inp_pos, nullptr, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_multi( - ctx0, Kcur, inp_pos, nullptr, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, - 
model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, + nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -7741,17 +8321,11 @@ struct llm_build_qwen2vl : public llm_graph_context { cb(ffn_inp, "ffn_inp", il); // feed-forward network - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "ffn_norm", il); - cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); + cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); cur = ggml_add(ctx0, cur, ffn_inp); @@ -7765,9 +8339,7 @@ struct llm_build_qwen2vl : public llm_graph_context { cur = inpL; - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); res->t_embd = cur; @@ -7782,8 +8354,10 @@ struct llm_build_qwen2vl : public llm_graph_context { } }; -struct llm_build_qwen2moe : public llm_graph_context { - llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +struct llm_build_llada : public llm_graph_context { + llm_build_llada(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + // LLaDA is similar to LLaMA but uses non-causal attention for diffusion const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7797,7 +8371,8 @@ struct llm_build_qwen2moe : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + // Non-causal attention for diffusion + auto * inp_attn = build_attn_inp_no_cache(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -7805,113 +8380,53 @@ struct llm_build_qwen2moe : public llm_graph_context { ggml_tensor * inpSA = inpL; // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - // self_attention + // self-attention { - // compute Q and K and RoPE them + // compute separate Q, K, V projections without bias, matching LLaDALlamaBlock ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, 
n_embd_head, n_head, n_tokens); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, - model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, + 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); - // MoE branch - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); + // feed-forward network + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "ffn_norm", il); - ggml_tensor * moe_out = - build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - nullptr, - n_expert, n_expert_used, - LLM_FFN_SILU, false, - false, 0.0, - LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, - il); - cb(moe_out, "ffn_moe_out", il); - - // FFN shared expert - { - ggml_tensor * cur_gate_inp = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); - cb(cur_gate_inp, "ffn_shexp_gate_inp", il); - - // sigmoid - ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp); - cb(cur_gate, "ffn_shexp_gate", il); - - ggml_tensor * cur_ffn = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, - model.layers[il].ffn_gate_shexp, NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur_ffn, "ffn_shexp", il); - - ggml_tensor * ffn_shexp_out = ggml_mul(ctx0, cur_ffn, cur_gate); - cb(ffn_shexp_out, "ffn_shexp_out", il); - - moe_out = ggml_add(ctx0, moe_out, ffn_shexp_out); - cb(moe_out, "ffn_out", il); - - cur = moe_out; - } + cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); cur = ggml_add(ctx0, cur, ffn_inp); @@ -7924,9 +8439,7 @@ struct llm_build_qwen2moe : public llm_graph_context { cur = inpL; - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); res->t_embd = cur; @@ -7941,8 +8454,8 @@ struct llm_build_qwen2moe : public llm_graph_context { } }; -struct llm_build_qwen3 : public llm_graph_context { - llm_build_qwen3(const llama_model & model, const 
llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +struct llm_build_qwen2vl : public llm_graph_context { + llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7958,6 +8471,9 @@ struct llm_build_qwen3 : public llm_graph_context { auto * inp_attn = build_attn_inp_kv_unified(); + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { @@ -7973,33 +8489,30 @@ struct llm_build_qwen3 : public llm_graph_context { { // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); - cb(Qcur, "Qcur_normed", il); - - Qcur = ggml_rope_ext( + Qcur = ggml_rope_multi( ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); - cb(Kcur, "Kcur_normed", il); - - Kcur = ggml_rope_ext( + Kcur = ggml_rope_multi( ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8007,7 +8520,7 @@ struct llm_build_qwen3 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -8062,8 +8575,8 @@ struct llm_build_qwen3 : public llm_graph_context { } }; -struct llm_build_qwen3moe : public llm_graph_context { - llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +struct llm_build_qwen2moe : public llm_graph_context { + llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -8095,29 +8608,35 @@ struct llm_build_qwen3moe : public llm_graph_context { // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); + if 
(model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); - cb(Qcur, "Qcur_normed", il); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); - cb(Kcur, "Kcur_normed", il); - Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -8128,7 +8647,7 @@ struct llm_build_qwen3moe : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -8155,12 +8674,37 @@ struct llm_build_qwen3moe : public llm_graph_context { model.layers[il].ffn_down_exps, nullptr, n_expert, n_expert_used, - LLM_FFN_SILU, true, + LLM_FFN_SILU, false, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); - cur = moe_out; + + // FFN shared expert + { + ggml_tensor * cur_gate_inp = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); + cb(cur_gate_inp, "ffn_shexp_gate_inp", il); + + // sigmoid + ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp); + cb(cur_gate, "ffn_shexp_gate", il); + + ggml_tensor * cur_ffn = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur_ffn, "ffn_shexp", il); + + ggml_tensor * ffn_shexp_out = ggml_mul(ctx0, cur_ffn, cur_gate); + cb(ffn_shexp_out, "ffn_shexp_out", il); + + moe_out = ggml_add(ctx0, moe_out, ffn_shexp_out); + cb(moe_out, "ffn_out", il); + + cur = moe_out; + } cur = ggml_add(ctx0, cur, ffn_inp); @@ -8190,16 +8734,14 @@ struct llm_build_qwen3moe : public llm_graph_context { } }; -struct llm_build_phi2 : public llm_graph_context { - llm_build_phi2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +struct llm_build_qwen3 : public llm_graph_context { + llm_build_qwen3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); ggml_tensor * cur; - ggml_tensor * attn_norm_output; - ggml_tensor * ffn_output; ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); @@ -8212,48 +8754,42 @@ struct llm_build_phi2 : public llm_graph_context { ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { - attn_norm_output = build_norm(inpL, - model.layers[il].attn_norm, - model.layers[il].attn_norm_b, - LLM_NORM, il); - cb(attn_norm_output, "attn_norm", il); + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = 
nullptr; - - if (model.layers[il].wqkv) { - cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - } else { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - } - + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -8264,34 +8800,34 @@ struct llm_build_phi2 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - // with phi2, we scale the Q to avoid precision issues - // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66 - Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); - - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); - attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // FF - { - ffn_output = build_ffn(attn_norm_output, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - NULL, NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, - NULL, - LLM_FFN_GELU, LLM_FFN_SEQ, il); - cb(ffn_output, "ffn_out", il); - } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); - cur = ggml_add(ctx0, cur, ffn_output); - cur = ggml_add(ctx0, cur, inpL); + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + 
model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); cur = build_cvec(cur, il); cb(cur, "l_out", il); @@ -8300,18 +8836,17 @@ struct llm_build_phi2 : public llm_graph_context { inpL = cur; } - cur = build_norm(inpL, - model.output_norm, - model.output_norm_b, - LLM_NORM, -1); + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); res->t_embd = cur; + // lm_head cur = build_lora_mm(model.output, cur); - cb(cur, "result_output_no_bias", -1); - - cur = ggml_add(ctx0, cur, model.output_b); cb(cur, "result_output", -1); res->t_logits = cur; @@ -8320,13 +8855,12 @@ struct llm_build_phi2 : public llm_graph_context { } }; -template -struct llm_build_phi3 : public llm_graph_context { - llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +struct llm_build_qwen3moe : public llm_graph_context { + llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -8336,64 +8870,49 @@ struct llm_build_phi3 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - using inp_attn_type = std::conditional_t; - inp_attn_type * inp_attn = nullptr; - - if constexpr (iswa) { - inp_attn = build_attn_inp_kv_unified_iswa(); - } else { - inp_attn = build_attn_inp_kv_unified(); - } + auto * inp_attn = build_attn_inp_kv_unified(); ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { - auto * residual = inpL; - - // self-attention - { - // rope freq factors for 128k context - ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); - - ggml_tensor* attn_norm_output = build_norm(inpL, - model.layers[il].attn_norm, - model.layers[il].attn_norm_b, - LLM_NORM_RMS, il); - cb(attn_norm_output, "attn_norm", il); - - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv) { - cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); - cb(cur, "wqkv", il); + ggml_tensor * inpSA = inpL; - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa))); - } else { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - } + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + // self_attention + { + // compute Q 
and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, rope_factors, + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, rope_factors, + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8402,39 +8921,27 @@ struct llm_build_phi3 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); - cb(Qcur, "Qcur", il); - - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - residual = ggml_get_rows(ctx0, residual, inp_out_ids); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - cur = ggml_add(ctx0, cur, residual); - residual = cur; + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); - cur = build_norm(cur, - model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "ffn_norm", il); - // feed-forward network - if (model.layers[il].ffn_gate_inp == nullptr) { - cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - NULL, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); - cb(cur, "ffn_out", il); - } else { - // MoE branch - cur = build_moe_ffn(cur, + ggml_tensor * moe_out = + build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -8445,10 +8952,10 @@ struct llm_build_phi3 : public llm_graph_context { false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); - cb(cur, "ffn_moe_out", il); - } + cb(moe_out, "ffn_moe_out", il); + cur = moe_out; - cur = ggml_add(ctx0, residual, cur); + cur = ggml_add(ctx0, cur, ffn_inp); cur = build_cvec(cur, il); cb(cur, "l_out", il); @@ -8457,21 +8964,18 @@ struct llm_build_phi3 : public llm_graph_context { inpL = cur; } - cur = build_norm(inpL, - model.output_norm, - model.output_norm_b, + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); res->t_embd = cur; + // lm_head cur = build_lora_mm(model.output, cur); - if (model.output_b != nullptr) { - cb(cur, "result_output_no_bias", -1); - cur = ggml_add(ctx0, cur, model.output_b); - } - cb(cur, "result_output", -1); res->t_logits = cur; @@ -8479,14 +8983,16 @@ struct llm_build_phi3 : public llm_graph_context { } }; -struct 
llm_build_plamo : public llm_graph_context { - llm_build_plamo(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +struct llm_build_phi2 : public llm_graph_context { + llm_build_phi2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); ggml_tensor * cur; + ggml_tensor * attn_norm_output; + ggml_tensor * ffn_output; ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); @@ -8499,39 +9005,51 @@ struct llm_build_plamo : public llm_graph_context { ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - ggml_tensor * sa_inp = cur; + attn_norm_output = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(attn_norm_output, "attn_norm", il); // self-attention { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); + if (model.layers[il].wqkv) { + cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); + cb(cur, "wqkv", il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } else { + Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); + Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); + Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, - n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, - n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -8539,33 +9057,33 @@ struct llm_build_plamo : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, - model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + // 
with phi2, we scale the Q to avoid precision issues + // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66 + Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - sa_inp = ggml_get_rows(ctx0, sa_inp, inp_out_ids); - inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids); } - ggml_tensor * sa_out = cur; - - cur = sa_inp; - - // feed-forward network + // FF { - cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, + ffn_output = build_ffn(attn_norm_output, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(ffn_output, "ffn_out", il); } - cur = ggml_add(ctx0, cur, sa_out); + cur = ggml_add(ctx0, cur, ffn_output); cur = ggml_add(ctx0, cur, inpL); cur = build_cvec(cur, il); @@ -8575,17 +9093,18 @@ struct llm_build_plamo : public llm_graph_context { inpL = cur; } - cur = inpL; - - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); cb(cur, "result_norm", -1); res->t_embd = cur; - // lm_head cur = build_lora_mm(model.output, cur); + cb(cur, "result_output_no_bias", -1); + + cur = ggml_add(ctx0, cur, model.output_b); cb(cur, "result_output", -1); res->t_logits = cur; @@ -8594,15 +9113,15 @@ struct llm_build_plamo : public llm_graph_context { } }; -struct llm_build_gpt2 : public llm_graph_context { - llm_build_gpt2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +template +struct llm_build_phi3 : public llm_graph_context { + llm_build_phi3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); ggml_tensor * cur; - ggml_tensor * pos; ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); @@ -8610,75 +9129,349 @@ struct llm_build_gpt2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); - - pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); - cb(pos, "pos_embd", -1); + using inp_attn_type = std::conditional_t; + inp_attn_type * inp_attn = nullptr; - inpL = ggml_add(ctx0, inpL, pos); - cb(inpL, "inpL", -1); + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_unified_iswa(); + } else { + inp_attn = build_attn_inp_kv_unified(); + } ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { - cur = build_norm(inpL, - model.layers[il].attn_norm, - model.layers[il].attn_norm_b, - LLM_NORM, il); - cb(cur, "attn_norm", il); + auto * residual = inpL; // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); + // rope 
freq factors for 128k context + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); + ggml_tensor* attn_norm_output = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM_RMS, il); + cb(attn_norm_output, "attn_norm", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv) { + cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); + cb(cur, "wqkv", il); + + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa))); + } else { + Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); + Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); + Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cur = build_attn(inp_attn, gf, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); + cb(Qcur, "Qcur", il); + + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); } - // add the input - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); - cb(ffn_inp, "ffn_inp", il); + cur = ggml_add(ctx0, cur, residual); + residual = cur; - // FF - { - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, - model.layers[il].ffn_norm_b, - LLM_NORM, il); - cb(cur, "ffn_norm", il); + cur = build_norm(cur, + model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + // feed-forward network + if 
(model.layers[il].ffn_gate_inp == nullptr) { cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - NULL, NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, - LLM_FFN_GELU, LLM_FFN_SEQ, il); + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); } - cur = ggml_add(ctx0, cur, ffn_inp); + cur = ggml_add(ctx0, residual, cur); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + if (model.output_b != nullptr) { + cb(cur, "result_output_no_bias", -1); + cur = ggml_add(ctx0, cur, model.output_b); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_plamo : public llm_graph_context { + llm_build_plamo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + ggml_tensor * sa_inp = cur; + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + sa_inp = ggml_get_rows(ctx0, sa_inp, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + ggml_tensor * sa_out = cur; + + cur = sa_inp; + + // feed-forward network + { + cur = build_ffn(cur, + 
model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, sa_out); + cur = ggml_add(ctx0, cur, inpL); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_gpt2 : public llm_graph_context { + llm_build_gpt2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * pos; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); + cb(pos, "pos_embd", -1); + + inpL = ggml_add(ctx0, inpL, pos); + cb(inpL, "inpL", -1); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // add the input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); cur = build_cvec(cur, il); cb(cur, "l_out", il); @@ -8705,7 +9498,7 @@ struct llm_build_gpt2 : public llm_graph_context { }; struct llm_build_codeshell : public llm_graph_context { - llm_build_codeshell(const llama_model & model, const llm_graph_params & params, 
ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_codeshell(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -8761,7 +9554,7 @@ struct llm_build_codeshell : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -8819,7 +9612,7 @@ struct llm_build_codeshell : public llm_graph_context { }; struct llm_build_orion : public llm_graph_context { - llm_build_orion(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_orion(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -8890,7 +9683,7 @@ struct llm_build_orion : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -8946,7 +9739,7 @@ struct llm_build_orion : public llm_graph_context { }; struct llm_build_internlm2 : public llm_graph_context { - llm_build_internlm2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_internlm2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -9017,7 +9810,7 @@ struct llm_build_internlm2 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -9073,7 +9866,7 @@ struct llm_build_internlm2 : public llm_graph_context { }; struct llm_build_minicpm3 : public llm_graph_context { - llm_build_minicpm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_minicpm3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { //TODO: if the model varies, these parameters need to be read from the model const int64_t n_embd_base = 256; const float scale_embd = 12.0f; @@ -9205,7 +9998,7 @@ struct llm_build_minicpm3 : public llm_graph_context { ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); cb(k_states, "k_states", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); } @@ -9277,7 +10070,7 @@ struct llm_build_minicpm3 : public llm_graph_context { }; struct llm_build_gemma : public llm_graph_context { - llm_build_gemma(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_gemma(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; @@ -9335,7 +10128,7 @@ struct llm_build_gemma : public llm_graph_context { Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); cb(Qcur, 
"Qcur_scaled", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } @@ -9393,7 +10186,7 @@ struct llm_build_gemma : public llm_graph_context { }; struct llm_build_gemma2_iswa : public llm_graph_context { - llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k; ggml_tensor * cur; @@ -9450,7 +10243,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context { Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } @@ -9523,7 +10316,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context { }; struct llm_build_gemma3_iswa : public llm_graph_context { - llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k; ggml_tensor * cur; @@ -9592,7 +10385,7 @@ struct llm_build_gemma3_iswa : public llm_graph_context { // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315 Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } @@ -9661,7 +10454,6 @@ struct llm_build_gemma3_iswa : public llm_graph_context { struct llm_build_gemma3n_iswa : public llm_graph_context { const llama_model & model; - ggml_cgraph * gf; const int64_t n_embd_head; const int64_t n_embd_altup; @@ -9671,10 +10463,9 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { const int n_layer_sparsity = 10; // number of layers using activation sparsity const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95) - llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) + llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model), - gf(gf), n_embd_head(model.hparams.n_embd_head_k), n_embd_altup(model.hparams.n_embd_altup), n_altup(model.hparams.n_altup), @@ -9775,7 +10566,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { cb(Qcur, "Qcur_pos", il); cb(Kcur, "Kcur_pos", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il); } else { @@ -9793,7 +10584,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur_pos", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); } @@ -10087,7 +10878,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { // TODO: move up next to build_starcoder struct llm_build_starcoder2 : public llm_graph_context { - llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + 
llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -10158,7 +10949,7 @@ struct llm_build_starcoder2 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -10219,7 +11010,6 @@ struct llm_graph_context_mamba : public llm_graph_context { ggml_tensor * build_mamba_layer( llm_graph_input_rs * inp, - ggml_cgraph * gf, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, @@ -10244,13 +11034,13 @@ struct llm_graph_context_mamba : public llm_graph_context { const int64_t n_seq_tokens = ubatch.n_seq_tokens; GGML_ASSERT(n_seqs != 0); - GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); + ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} @@ -10331,7 +11121,7 @@ struct llm_graph_context_mamba : public llm_graph_context { return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); }; - ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, @@ -10358,11 +11148,10 @@ struct llm_graph_context_mamba : public llm_graph_context { ggml_tensor * build_mamba2_layer( llm_graph_input_rs * inp, - ggml_cgraph * gf, - ggml_tensor * cur, - const llama_model & model, - const llama_ubatch & ubatch, - int il) const { + ggml_tensor * cur, + const llama_model & model, + const llama_ubatch & ubatch, + int il) const { const auto * mctx_cur = inp->mctx; @@ -10379,13 +11168,13 @@ struct llm_graph_context_mamba : public llm_graph_context { const int64_t n_seq_tokens = ubatch.n_seq_tokens; GGML_ASSERT(n_seqs != 0); - GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); + ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} @@ -10455,7 +11244,7 @@ struct llm_graph_context_mamba : public llm_graph_context { return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); }; - ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, @@ -10491,7 +11280,7 @@ struct llm_graph_context_mamba : public llm_graph_context { }; struct llm_build_mamba : public llm_graph_context_mamba { - llm_build_mamba(const llama_model & model, const 
llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { + llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -10510,9 +11299,9 @@ struct llm_build_mamba : public llm_graph_context_mamba { cb(cur, "attn_norm", il); if (model.arch == LLM_ARCH_MAMBA2) { - cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il); + cur = build_mamba2_layer(rs_inp, cur, model, ubatch, il); } else { - cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il); + cur = build_mamba_layer(rs_inp, cur, model, ubatch, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -10548,7 +11337,7 @@ struct llm_build_mamba : public llm_graph_context_mamba { }; struct llm_build_jamba : public llm_graph_context_mamba { - llm_build_jamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { + llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; @@ -10568,7 +11357,7 @@ struct llm_build_jamba : public llm_graph_context_mamba { cb(cur, "attn_norm", il); if (n_head_kv == 0) { - cur = build_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il); + cur = build_mamba_layer(inp_hybrid->get_recr(), cur, model, ubatch, il); } else { // Attention @@ -10589,7 +11378,7 @@ struct llm_build_jamba : public llm_graph_context_mamba { cb(Vcur, "Vcur", il); // No RoPE :) - cur = build_attn(inp_hybrid->get_attn(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il); + cur = build_attn(inp_hybrid->get_attn(), model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -10657,7 +11446,7 @@ struct llm_build_jamba : public llm_graph_context_mamba { }; struct llm_build_command_r : public llm_graph_context { - llm_build_command_r(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_command_r(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -10745,7 +11534,7 @@ struct llm_build_command_r : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -10804,7 +11593,7 @@ struct llm_build_command_r : public llm_graph_context { }; struct llm_build_cohere2_iswa : public llm_graph_context { - llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -10880,7 +11669,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -10940,7 +11729,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context { // * 
removed bias // * removed MoE struct llm_build_olmo : public llm_graph_context { - llm_build_olmo(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_olmo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -11011,7 +11800,7 @@ struct llm_build_olmo : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -11068,7 +11857,7 @@ struct llm_build_olmo : public llm_graph_context { }; struct llm_build_olmo2 : public llm_graph_context { - llm_build_olmo2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_olmo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -11131,7 +11920,7 @@ struct llm_build_olmo2 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -11197,7 +11986,7 @@ struct llm_build_olmo2 : public llm_graph_context { // * removed bias // * added q, k norm struct llm_build_olmoe : public llm_graph_context { - llm_build_olmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_olmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -11264,7 +12053,7 @@ struct llm_build_olmoe : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -11325,7 +12114,7 @@ struct llm_build_olmoe : public llm_graph_context { }; struct llm_build_openelm : public llm_graph_context { - llm_build_openelm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_openelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -11397,7 +12186,7 @@ struct llm_build_openelm : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Qcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -11454,7 +12243,7 @@ struct llm_build_openelm : public llm_graph_context { }; struct llm_build_gptneox : public llm_graph_context { - llm_build_gptneox(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_gptneox(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -11509,7 +12298,7 @@ struct llm_build_gptneox : public llm_graph_context { 
cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -11600,7 +12389,7 @@ struct llm_build_gptneox : public llm_graph_context { }; struct llm_build_arctic : public llm_graph_context { - llm_build_arctic(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_arctic(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -11659,7 +12448,7 @@ struct llm_build_arctic : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -11738,7 +12527,7 @@ struct llm_build_arctic : public llm_graph_context { }; struct llm_build_deepseek : public llm_graph_context { - llm_build_deepseek(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_deepseek(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -11814,7 +12603,7 @@ struct llm_build_deepseek : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); } @@ -11900,7 +12689,7 @@ struct llm_build_deepseek : public llm_graph_context { }; struct llm_build_deepseek2 : public llm_graph_context { - llm_build_deepseek2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { bool is_lite = (hparams.n_layer == 27); const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); @@ -12042,7 +12831,7 @@ struct llm_build_deepseek2 : public llm_graph_context { cb(Vcur, "Vcur", il); // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group) - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, model.layers[il].wv_b, kq_scale, il); } else { @@ -12076,7 +12865,7 @@ struct llm_build_deepseek2 : public llm_graph_context { cb(Kcur, "Kcur", il); // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); } @@ -12163,7 +12952,7 @@ struct llm_build_deepseek2 : public llm_graph_context { }; struct llm_build_bitnet : public llm_graph_context { - llm_build_bitnet(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_bitnet(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -12243,7 +13032,7 @@ struct llm_build_bitnet : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = 
build_attn(inp_attn, gf, + cur = build_attn(inp_attn, NULL, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); @@ -12323,7 +13112,7 @@ struct llm_build_bitnet : public llm_graph_context { }; struct llm_build_t5_enc : public llm_graph_context { - llm_build_t5_enc(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_t5_enc(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -12366,7 +13155,7 @@ struct llm_build_t5_enc : public llm_graph_context { ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc; ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo_enc, nullptr, Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); cb(cur, "kqv_out", il); @@ -12424,7 +13213,7 @@ struct llm_build_t5_enc : public llm_graph_context { }; struct llm_build_t5_dec : public llm_graph_context { - llm_build_t5_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_t5_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; //const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -12472,7 +13261,7 @@ struct llm_build_t5_dec : public llm_graph_context { ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b; ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b); - cur = build_attn(inp_attn_self, gf, + cur = build_attn(inp_attn_self, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); cb(cur, "kqv_out", il); @@ -12504,7 +13293,7 @@ struct llm_build_t5_dec : public llm_graph_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc); - cur = build_attn(inp_attn_cross, gf, + cur = build_attn(inp_attn_cross, model.layers[il].wo_cross, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); @@ -12594,7 +13383,7 @@ struct llm_build_t5_dec : public llm_graph_context { }; struct llm_build_jais : public llm_graph_context { - llm_build_jais(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_jais(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -12636,7 +13425,7 @@ struct llm_build_jais : public llm_graph_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il); } @@ -12673,12 +13462,296 @@ struct llm_build_jais : public llm_graph_context { cur = build_norm(inpL, model.output_norm, - model.output_norm_b, - LLM_NORM, -1); + model.output_norm_b, + LLM_NORM, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + 
res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_chatglm : public llm_graph_context { + llm_build_chatglm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv == nullptr) { + Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + } else { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + if (model.layers[il].bqkv) { + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Add the input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + } + + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + 
cur = build_norm(inpL, + model.output_norm, + NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_glm4 : public llm_graph_context { + llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = build_norm(inpL, + model.layers[il].attn_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + + if (model.layers[il].wqkv == nullptr) { + Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + } else { + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + if (model.layers[il].bqkv) { + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Post-attention norm (new!) 
+ cur = build_norm(cur, + model.layers[il].attn_post_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + // Add the input (residual connection after post-attention norm) + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + // Pre-MLP norm + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // MLP + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + // Post-MLP norm + cur = build_norm(cur, + model.layers[il].ffn_post_norm, + NULL, + LLM_NORM_RMS, il); + cb(cur, "post_mlp_norm", il); + } + + // Add residual connection after post-MLP norm + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + // Final norm + cur = build_norm(inpL, + model.output_norm, + NULL, + LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); res->t_embd = cur; + // Output projection cur = build_lora_mm(model.output, cur); cb(cur, "result_output", -1); @@ -12688,10 +13761,9 @@ struct llm_build_jais : public llm_graph_context { } }; -struct llm_build_chatglm : public llm_graph_context { - llm_build_chatglm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +struct llm_build_glm4_moe : public llm_graph_context { + llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -12707,51 +13779,50 @@ struct llm_build_chatglm : public llm_graph_context { ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // Only process up to last layer (skip final NextN layer) + // Final layer tensors are loaded but not processed in forward pass + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; - cur = build_norm(inpL, - model.layers[il].attn_norm, - NULL, - LLM_NORM_RMS, il); + // Pre-attention norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + cb(Qcur, "Qcur", il); - if (model.layers[il].wqkv == nullptr) { - Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - } else { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 
0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); } + cb(Vcur, "Vcur", il); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); + // Apply Q/K norm if available (GLM-4.5 355B variant) + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + } + Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -12768,50 +13839,78 @@ struct llm_build_chatglm : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); + if (il == n_transformer_layers - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // Add the input ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); - // FF - { - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, - NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); + // Post-attention norm + cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) + if (static_cast(il) < hparams.n_layer_dense_lead) { + // Dense FFN layer cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, - NULL, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, - LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); + } else { + // Process routed experts using existing MoE infrastructure + ggml_tensor * routed_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(routed_out, "ffn_moe_out", il); + + // Process shared expert on original input + ggml_tensor * shared_out = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shared_out, 
"ffn_shexp_out", il); + // Final output: routed_output + shared_output + cur = ggml_add(ctx0, routed_out, shared_out); + cb(cur, "ffn_out", il); } - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } - cur = build_norm(inpL, - model.output_norm, - NULL, - LLM_NORM_RMS, -1); + cur = inpL; + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); res->t_embd = cur; + // lm_head cur = build_lora_mm(model.output, cur); cb(cur, "result_output", -1); @@ -12821,12 +13920,12 @@ struct llm_build_chatglm : public llm_graph_context { } }; -struct llm_build_glm4 : public llm_graph_context { - llm_build_glm4(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +struct llm_build_nemotron : public llm_graph_context { + llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + //GGML_ASSERT(n_embd_head == hparams.n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -12843,46 +13942,39 @@ struct llm_build_glm4 : public llm_graph_context { for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; - // Pre-attention norm + // norm cur = build_norm(inpL, model.layers[il].attn_norm, - NULL, - LLM_NORM_RMS, il); + model.layers[il].attn_norm_b, + LLM_NORM, il); cb(cur, "attn_norm", il); // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } - if (model.layers[il].wqkv == nullptr) { - Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - } else { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); } + 
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( @@ -12901,8 +13993,8 @@ struct llm_build_glm4 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, - model.layers[il].wo, NULL, + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -12911,58 +14003,43 @@ struct llm_build_glm4 : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // Post-attention norm (new!) - cur = build_norm(cur, - model.layers[il].attn_post_norm, - NULL, - LLM_NORM_RMS, il); - cb(cur, "post_attn_norm", il); - - // Add the input (residual connection after post-attention norm) ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); - // FF - { - // Pre-MLP norm - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, - NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); - // MLP - cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - NULL, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); - cb(cur, "ffn_out", il); + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il); - // Post-MLP norm - cur = build_norm(cur, - model.layers[il].ffn_post_norm, - NULL, - LLM_NORM_RMS, il); - cb(cur, "post_mlp_norm", il); - } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); - // Add residual connection after post-MLP norm - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } - // Final norm - cur = build_norm(inpL, - model.output_norm, - NULL, - LLM_NORM_RMS, -1); + cur = inpL; + + cur = build_norm(cur, + model.output_norm, model.output_norm_b, + LLM_NORM, -1); cb(cur, "result_norm", -1); res->t_embd = cur; - // Output projection + // lm_head cur = build_lora_mm(model.output, cur); cb(cur, "result_output", -1); @@ -12972,12 +14049,12 @@ struct llm_build_glm4 : public llm_graph_context { } }; -struct llm_build_nemotron : public llm_graph_context { - llm_build_nemotron(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +struct llm_build_exaone : public llm_graph_context { + llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - //GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -12996,13 +14073,15 @@ struct llm_build_nemotron : public llm_graph_context { // norm cur = build_norm(inpL, - model.layers[il].attn_norm, - model.layers[il].attn_norm_b, - LLM_NORM, il); + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); cb(cur, "attn_norm", il); // self-attention { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = 
model.get_rope_factors(cparams, il); + // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); @@ -13030,13 +14109,13 @@ struct llm_build_nemotron : public llm_graph_context { Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, + ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, + ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -13045,7 +14124,7 @@ struct llm_build_nemotron : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -13060,17 +14139,17 @@ struct llm_build_nemotron : public llm_graph_context { // feed-forward network cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, - model.layers[il].ffn_norm_b, - LLM_NORM, il); + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); cb(cur, "ffn_norm", il); cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - NULL, NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, - LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il); + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -13085,8 +14164,8 @@ struct llm_build_nemotron : public llm_graph_context { cur = inpL; cur = build_norm(cur, - model.output_norm, model.output_norm_b, - LLM_NORM, -1); + model.output_norm, NULL, + LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); res->t_embd = cur; @@ -13101,11 +14180,12 @@ struct llm_build_nemotron : public llm_graph_context { } }; -struct llm_build_exaone : public llm_graph_context { - llm_build_exaone(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +template +struct llm_build_exaone4 : public llm_graph_context { + llm_build_exaone4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v); GGML_ASSERT(n_embd_head == hparams.n_rot); ggml_tensor * cur; @@ -13116,69 +14196,69 @@ struct llm_build_exaone : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + using inp_attn_type = std::conditional_t; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_unified_iswa(); + } else { + inp_attn = build_attn_inp_kv_unified(); + } ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); + // use RoPE for SWA layers or non-SWA models + const bool use_rope = hparams.is_swa(il) || hparams.swa_type == LLAMA_SWA_TYPE_NONE; + + cur = inpL; // self-attention { - // rope freq 
factors for llama3; may return nullptr for llama2 and other models ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); - // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + if (use_rope) { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, - model.layers[il].wo, model.layers[il].bo, + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cb(cur, "attn_out", il); } if (il == n_layer - 1 && inp_out_ids) { @@ -13186,16 +14266,16 @@ struct llm_build_exaone : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } + cur = build_norm(cur, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); // feed-forward network - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - cur = build_ffn(cur, + cur = build_ffn(ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -13203,8 +14283,12 @@ struct llm_build_exaone : public llm_graph_context { LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", -1); + cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "ffn_out", il); cur = build_cvec(cur, il); cb(cur, "l_out", il); @@ -13269,7 +14353,6 @@ struct llm_build_rwkv6_base : public llm_graph_context { ggml_tensor * build_rwkv6_time_mix( llm_graph_input_rs * inp, - ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * x_prev, const llama_ubatch & ubatch, @@ -13396,7 +14479,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { } ggml_tensor * 
wkv_state = build_rs( - inp, gf, mctx_cur->get_s_l(il), + inp, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs); ggml_tensor * wkv_output; @@ -13442,7 +14525,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { }; struct llm_build_rwkv6 : public llm_build_rwkv6_base { - llm_build_rwkv6(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) { + llm_build_rwkv6(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv6_base(model, params) { GGML_ASSERT(hparams.token_shift_count == 2); ggml_tensor * cur; @@ -13463,7 +14546,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, ubatch, il); ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift)); @@ -13478,7 +14561,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { 1 ); - cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il); + cur = build_rwkv6_time_mix(rs_inp, att_norm, x_prev, ubatch, il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -13543,7 +14626,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { - llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) { + llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv6_base(model, params) { GGML_ASSERT(n_embd == hparams.n_embd_r()); ggml_tensor * cur; @@ -13563,7 +14646,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, ubatch, il); ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); cb(att_norm, "attn_norm", il); @@ -13575,7 +14658,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { 1 ); - cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il); + cur = build_rwkv6_time_mix(rs_inp, att_norm, x_prev, ubatch, il); token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); @@ -13665,7 +14748,6 @@ struct llm_build_rwkv7_base : public llm_graph_context { ggml_tensor * build_rwkv7_time_mix( llm_graph_input_rs * inp, - ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * x_prev, ggml_tensor *& first_layer_value, @@ -13751,7 +14833,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); ggml_tensor * wkv_state = build_rs( - inp, gf, mctx_cur->get_s_l(il), + inp, mctx_cur->get_s_l(il), 
hparams.n_embd_s(), n_seqs); ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); @@ -13798,7 +14880,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { }; struct llm_build_rwkv7 : public llm_build_rwkv7_base { - llm_build_rwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) { + llm_build_rwkv7(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) { GGML_ASSERT(hparams.token_shift_count == 2); ggml_tensor * cur; @@ -13820,7 +14902,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, ubatch, il); ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift)); @@ -13835,7 +14917,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { 1 ); - cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il); + cur = build_rwkv7_time_mix(rs_inp, att_norm, x_prev, v_first, ubatch, il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -13894,7 +14976,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { struct llm_build_arwkv7 : public llm_build_rwkv7_base { - llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) { + llm_build_arwkv7(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) { GGML_ASSERT(n_embd == hparams.n_embd_r()); ggml_tensor * cur; @@ -13915,7 +14997,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, ubatch, il); ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); cb(att_norm, "attn_norm", il); @@ -13927,7 +15009,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { 1 ); - cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il); + cur = build_rwkv7_time_mix(rs_inp, att_norm, x_prev, v_first, ubatch, il); token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); @@ -13984,8 +15066,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { struct llm_build_granite : public llm_graph_context { llm_build_granite( const llama_model & model, - const llm_graph_params & params, - ggml_cgraph * gf) + const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -14019,7 +15100,7 @@ struct llm_build_granite : public llm_graph_context { // self-attention cur = build_attention_layer( - gf, cur, inp_pos, inp_attn, + cur, inp_pos, inp_attn, model, n_embd_head, il); if (il == n_layer - 1 && 
inp_out_ids) { @@ -14055,7 +15136,6 @@ struct llm_build_granite : public llm_graph_context { } ggml_tensor * build_attention_layer( - ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv_unified * inp_attn, @@ -14110,7 +15190,7 @@ struct llm_build_granite : public llm_graph_context { cb(Vcur, "Vcur", il); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); @@ -14198,11 +15278,9 @@ struct llm_build_granite : public llm_graph_context { }; struct llm_build_granite_hybrid : public llm_graph_context_mamba { - llm_build_granite_hybrid( const llama_model & model, - const llm_graph_params & params, - ggml_cgraph * gf) : + const llm_graph_params & params) : llm_graph_context_mamba(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -14234,11 +15312,11 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { if (hparams.is_recurrent(il)) { // ssm layer // - cur = build_mamba2_layer(inp->get_recr(), gf, cur, model, ubatch, il); + cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); } else { // attention layer // cur = build_attention_layer( - gf, cur, inp_pos, inp->get_attn(), model, + cur, inp_pos, inp->get_attn(), model, n_embd_head, il); } @@ -14277,7 +15355,6 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { } ggml_tensor * build_attention_layer( - ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv_unified * inp_attn, @@ -14332,7 +15409,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { cb(Vcur, "Vcur", il); const float kq_scale = hparams.f_attention_scale == 0.0f ? 
1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); @@ -14426,7 +15503,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { // * removed bias // * removed MoE struct llm_build_chameleon : public llm_graph_context { - llm_build_chameleon(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_chameleon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -14517,7 +15594,7 @@ struct llm_build_chameleon : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -14603,7 +15680,7 @@ struct llm_build_chameleon : public llm_graph_context { }; struct llm_build_wavtokenizer_dec : public llm_graph_context { - llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -14755,7 +15832,7 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context { }; struct llm_build_plm : public llm_graph_context { - llm_build_plm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_plm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k)); const uint32_t n_embd_head_qk_rope = hparams.n_rot; @@ -14873,7 +15950,7 @@ struct llm_build_plm : public llm_graph_context { ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); cb(k_states, "k_states", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); } @@ -14927,7 +16004,7 @@ struct llm_build_plm : public llm_graph_context { }; struct llm_build_bailingmoe : public llm_graph_context { - llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -14996,7 +16073,7 @@ struct llm_build_bailingmoe : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); } @@ -15071,7 +16148,7 @@ struct llm_build_bailingmoe : public llm_graph_context { }; struct llm_build_dots1 : public llm_graph_context { - llm_build_dots1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_dots1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -15136,7 +16213,7 @@ struct 
llm_build_dots1 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -15149,13 +16226,170 @@ struct llm_build_dots1 : public llm_graph_context { ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); - // MoE branch - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "ffn_moe_out", il); + + { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_ernie4_5 : public llm_graph_context { + llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + { + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + } + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur 
= ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); - if ((uint32_t) il < hparams.n_layer_dense_lead) { cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, @@ -15163,33 +16397,6 @@ struct llm_build_dots1 : public llm_graph_context { NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); - } else { - ggml_tensor * moe_out = - build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - model.layers[il].ffn_exp_probs_b, - n_expert, n_expert_used, - LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, - (llama_expert_gating_func_type) hparams.expert_gating_func, - il); - cb(moe_out, "ffn_moe_out", il); - - { - ggml_tensor * ffn_shexp = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, - model.layers[il].ffn_gate_shexp, NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(ffn_shexp, "ffn_shexp", il); - - cur = ggml_add(ctx0, moe_out, ffn_shexp); - cb(cur, "ffn_out", il); - } } cur = ggml_add(ctx0, cur, ffn_inp); @@ -15220,8 +16427,8 @@ struct llm_build_dots1 : public llm_graph_context { } }; -struct llm_build_ernie4_5 : public llm_graph_context { - llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +struct llm_build_ernie4_5_moe : public llm_graph_context { + llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -15237,9 +16444,11 @@ struct llm_build_ernie4_5 : public llm_graph_context { auto * inp_attn = build_attn_inp_kv_unified(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Ernie 4.5 MoE requires n_moe_layer_step > 0"); for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; - // norm { cur = build_norm(inpL, @@ -15250,6 +16459,7 @@ struct llm_build_ernie4_5 : public llm_graph_context { // self-attention { + // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { @@ -15291,14 +16501,13 @@ struct llm_build_ernie4_5 : public 
llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cb(cur, "attn_out", il); } - if (il == n_layer - 1) { - // skip computing output for unused tokens - ggml_tensor * inp_out_ids = build_inp_out_ids(); + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -15307,7 +16516,9 @@ struct llm_build_ernie4_5 : public llm_graph_context { cb(ffn_inp, "ffn_inp", il); // feed-forward network - { + bool is_moe_layer = static_cast(il) >= hparams.n_layer_dense_lead && (il + 1) % hparams.n_moe_layer_step == 0; + + if (!is_moe_layer) { cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); @@ -15320,9 +16531,45 @@ struct llm_build_ernie4_5 : public llm_graph_context { NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + + // Shared expert (if present) + if (hparams.n_ff_shexp > 0) { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + } else { + cur = moe_out; + } + cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); cur = build_cvec(cur, il); cb(cur, "l_out", il); @@ -15351,7 +16598,7 @@ struct llm_build_ernie4_5 : public llm_graph_context { }; struct llm_build_falcon_h1 : public llm_graph_context_mamba { - llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { + llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; @@ -15407,7 +16654,7 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba { cb(Kcur, "Kcur-post-rope", il); cb(Vcur, "Vcur-post-rope", il); - ggml_tensor * attn_out = build_attn(inp->get_attn(), gf, + ggml_tensor * attn_out = build_attn(inp->get_attn(), model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(attn_out, "attn_out", il); @@ -15418,7 +16665,7 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba { // Mamba2 layer cb(cur, "ssm_in", il); - ggml_tensor * ssm_out = build_mamba2_layer(inp->get_recr(), gf, cur, model, ubatch, il); + ggml_tensor * ssm_out = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); cb(ssm_out, "ssm_out", il); // // Aggregation @@ -15450,34 +16697,347 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba { cur = ggml_add(ctx0, cur, inpSA); - cur = build_cvec(cur, il); - cb(cur, "l_out", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, 
+ model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_plamo2 : public llm_graph_context_mamba { + llm_build_plamo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "embedding_output", -1); + + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_hybrid = build_inp_mem_hybrid(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * residual = inpL; + + // ggml_graph_add_node(gf, model.layers[il].attn_norm); + // cb(model.layers[il].attn_norm, "attn_norm", il); + + // pre_mixer_norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + + // check if this layer is Mamba or Attention + bool is_mamba_layer = hparams.is_recurrent(il); + + if (is_mamba_layer) { + // PLaMo-2 Mamba layer + cur = build_plamo2_mamba_layer(inp_hybrid->get_recr(), cur, model, ubatch, il); + } else { + // PLaMo-2 Attention layer + cur = build_plamo2_attn_layer(inp_hybrid->get_attn(), inp_pos, cur, model, il); + } + + // post_mixer_norm + cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + // residual connection + cur = ggml_add(ctx0, cur, residual); + cb(cur, "attn_residual", il); + residual = cur; + + // pre-ffn norm + cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_pre_norm", il); + + // feed-forward network + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + // post ffn norm + cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_post_norm", il); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + } + + // residual connection + cur = ggml_add(ctx0, cur, residual); + cb(cur, "ffn_residual", il); + + inpL = cur; + } + + cur = inpL; + + // final norm + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + + // Explicitly mark as output tensor to ensure proper backend assignment + ggml_set_output(cur); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + +private: + ggml_tensor * build_plamo2_attn_layer( + llm_graph_input_attn_kv_unified * inp, + ggml_tensor * inp_pos, + ggml_tensor * cur, + const llama_model & model, + int il) { + + // self-attention + { + // PLaMo-2 uses combined QKV tensor + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "wqkv", il); + + // split QKV tensor into Q, K, V + const int64_t n_embd_head_q = hparams.n_embd_head_k; + const int64_t n_embd_head_k = hparams.n_embd_head_k; + const int64_t n_embd_head_v = hparams.n_embd_head_v; + int32_t n_head_kv = hparams.n_head_kv(il); + + const int64_t q_offset = 0; + const int64_t k_offset = n_embd_head_q * n_head; + const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv; + + ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, 
n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv)); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il); + } + + cb(cur, "attn_out", il); + + return cur; + } + + ggml_tensor * build_plamo2_mamba_layer( + llm_graph_input_rs * inp, + ggml_tensor * cur, + const llama_model & model, + const llama_ubatch & ubatch, + int il) { + + const auto * mctx_cur = inp->mctx; + + const auto kv_head = mctx_cur->get_head(); + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_heads = hparams.ssm_dt_rank; + const int64_t head_dim = d_inner / n_heads; + const int64_t n_group = hparams.ssm_n_group; + const int64_t n_seqs = ubatch.n_seqs; + + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} + ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur); + cb(zx, "mamba_in_proj", il); + // {8192, 5, 1, 1} -> {8192, 1, 5, 1} + zx = ggml_permute(ctx0, zx, 0, 2, 1, 3); + zx = ggml_cont(ctx0, zx); + zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs); + cb(zx, "mamba_in_proj_out", il); + + // split into z and x + // => {head_dim * n_heads, n_seq_tokens, n_seqs} + ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], head_dim*ggml_element_size(zx)); + x = ggml_cont(ctx0, x); + x = ggml_reshape_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs); + // x = ggml_permute(ctx0, x, 0, 2, 1, 3); + cb(x, "mamba_x_split", il); + + ggml_tensor * z = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], 0); + cb(z, "mamba_z_split", il); + + // conv1d + { + // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} + ggml_tensor * conv_x = 
ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0); + cb(conv_x, "mamba_conv1d_input", il); + + // copy last (d_conv - 1) columns back into the state cache + ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, + conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv, + ggml_view_1d(ctx0, conv_states_all, + (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); + cb(conv_states_all, "mamba_conv1d_state", il); + + // 1D convolution + x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + cb(x, "mamba_conv1d", il); + + x = ggml_silu(ctx0, x); + cb(x, "mamba_conv1d_silu", il); + } + + // SSM + { + // bcdt_proj: {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} + ggml_tensor * x_bcdt = build_lora_mm(model.layers[il].ssm_x, x); + cb(x_bcdt, "mamba_bcdt_proj", il); + + // split into dt, B, C + const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0); + ggml_tensor * C = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*d_state); + ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(2*d_state)); + cb(B, "mamba_B_raw", il); + cb(C, "mamba_C_raw", il); + cb(dt, "mamba_dt_raw", il); + + // Apply RMS norm to dt, B, C (PLaMo-2 specific) + B = build_norm(B, model.layers[il].ssm_b_norm, NULL, LLM_NORM_RMS, il); + C = build_norm(C, model.layers[il].ssm_c_norm, NULL, LLM_NORM_RMS, il); + dt = build_norm(dt, model.layers[il].ssm_dt_norm, NULL, LLM_NORM_RMS, il); + cb(B, "mamba_B_normed", il); + cb(C, "mamba_C_normed", il); + cb(dt, "mamba_dt_normed", il); + + // dt_proj: {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} + dt = build_lora_mm(model.layers[il].ssm_dt, dt); + dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); + cb(dt, "mamba_dt_proj", il); + + ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads); + cb(A, "mamba_A", il); + + x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0); + B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0); + C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0); + + // use the states and the indices provided by build_recurrent_state + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_heads, mctx_cur->get_size()); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. 
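// [editor's note, not part of the patch] shape sketch of the scan inputs as wired by the views built just above:
//   ssm  : {d_state, head_dim, n_heads, mctx_cur->get_size()}  - recurrent states for every cache slot
//   x    : {head_dim, n_heads, n_seq_tokens, n_seqs}
//   A    : {1, n_heads}                                        - a single A coefficient per head
//   B, C : {d_state, 1, n_seq_tokens, n_seqs}                  - one group broadcast over all heads
//   ids  : row indices selecting the cache slots of the sequences in this ubatch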
+ // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + cb(y_ssm, "mamba_ssm_scan", il); - // input for next layer - inpL = cur; - } + // store last states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, y_ssm, n_heads*head_dim*d_state*n_seqs, n_heads*head_dim*n_seq_tokens*n_seqs*ggml_element_size(y_ssm)), + ggml_view_1d(ctx0, ssm_states_all, n_heads*head_dim*d_state*n_seqs, kv_head*n_seqs*n_heads*head_dim*d_state*ggml_element_size(ssm_states_all)))); + cb(ssm_states_all, "mamba_ssm_states", il); - cur = inpL; + ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0); + cb(y, "mamba_y_view", il); - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); + // Add D parameter and apply gating with z + // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} + ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads); + y = ggml_add(ctx0, y, ggml_mul(ctx0, x, D)); + cb(y, "mamba_y_add_d", il); - cb(cur, "result_norm", -1); - res->t_embd = cur; + y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); + cb(y, "mamba_y_swiglu_z", il); - // lm_head - cur = build_lora_mm(model.output, cur); + // out_proj: {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + y = ggml_view_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs, y->nb[2], y->nb[3], 0); + cur = build_lora_mm(model.layers[il].ssm_out, y); + cb(cur, "mamba_out_proj", il); + } - cb(cur, "result_output", -1); - res->t_logits = cur; + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + cb(cur, "mamba_out", il); - ggml_build_forward_expand(gf, cur); + return cur; } }; struct llm_build_arcee : public llm_graph_context { - llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -15553,7 +17113,7 @@ struct llm_build_arcee : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); @@ -15612,7 +17172,7 @@ struct llm_build_arcee : public llm_graph_context { }; struct llm_build_hunyuan_moe : public llm_graph_context { - llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -15698,7 +17258,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context { LLM_NORM_RMS, il); cb(Qcur, "Qcur_norm", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", 
il); @@ -15742,10 +17302,284 @@ struct llm_build_hunyuan_moe : public llm_graph_context { il); cb(cur_moe, "ffn_moe_out", il); - ggml_tensor * ffn_out = ggml_add(ctx0, cur_moe, cur_mlp); - cb(ffn_out, "ffn_out", il); + ggml_tensor * ffn_out = ggml_add(ctx0, cur_moe, cur_mlp); + cb(ffn_out, "ffn_out", il); + + cur = ggml_add(ctx0, ffn_out, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_hunyuan_dense : public llm_graph_context { + llm_build_hunyuan_dense(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, nullptr, + LLM_NORM_RMS, il); + cb(Kcur, "Kcur_norm", il); + + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, nullptr, + LLM_NORM_RMS, il); + cb(Qcur, "Qcur_norm", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = 
ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + // feed-forward network (non-MoE) + ggml_tensor * cur_mlp = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur_mlp, "ffn_out", il); + + cur = ggml_add(ctx0, cur_mlp, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_smollm3 : public llm_graph_context { + llm_build_smollm3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (use_rope) { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } - cur = 
ggml_add(ctx0, ffn_out, ffn_inp); + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); cur = build_cvec(cur, il); cb(cur, "l_out", il); @@ -15765,6 +17599,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context { // lm_head cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); res->t_logits = cur; @@ -15772,13 +17607,8 @@ struct llm_build_hunyuan_moe : public llm_graph_context { } }; -struct llm_build_smollm3 : public llm_graph_context { - llm_build_smollm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); - +struct llm_build_openai_moe_iswa : public llm_graph_context { + llm_build_openai_moe_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -15787,20 +17617,14 @@ struct llm_build_smollm3 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); - - const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; - - ggml_tensor * inp_out_ids = build_inp_out_ids(); + auto * inp_attn = build_attn_inp_kv_unified_iswa(); for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; - const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0; - // norm cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, + model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); @@ -15828,35 +17652,36 @@ struct llm_build_smollm3 : public llm_graph_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens); - if (use_rope) { - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - } + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn_with_sinks(inp_attn, 
model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].attn_sinks, 1.0f/sqrtf(float(n_rot)), il); + cb(cur, "attn_out", il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -15864,24 +17689,27 @@ struct llm_build_smollm3 : public llm_graph_context { ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); - // feed-forward network - { - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); + cur = ffn_inp; + cur = build_norm(cur, + model.layers[il].attn_post_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); - cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); - } + // MoE branch + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, model.layers[il].ffn_gate_inp_b, + model.layers[il].ffn_up_exps, model.layers[il].ffn_up_exps_b, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b, + model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SWIGLU_OAI_MOE, false, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT, + il); + cb(cur, "ffn_moe_out", il); cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "ffn_out", il); cur = build_cvec(cur, il); cb(cur, "l_out", il); @@ -15912,7 +17740,7 @@ struct llm_build_smollm3 : public llm_graph_context { struct llm_build_lfm2 : public llm_graph_context { const llama_model & model; - llm_build_lfm2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) { + llm_build_lfm2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) { ggml_tensor * cur = build_inp_embd(model.tok_embd); cb(cur, "model.embed_tokens", -1); @@ -15927,8 +17755,8 @@ struct llm_build_lfm2 : public llm_graph_context { cb(cur, "model.layers.{}.operator_norm", il); cur = hparams.is_recurrent(il) ? 
- build_shortconv_block(gf, cur, inp_hybrid->get_recr(), il) : - build_attn_block(gf, cur, inp_pos, inp_hybrid->get_attn(), il) ; + build_shortconv_block(cur, inp_hybrid->get_recr(), il) : + build_attn_block(cur, inp_pos, inp_hybrid->get_attn(), il) ; if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); @@ -15971,8 +17799,7 @@ struct llm_build_lfm2 : public llm_graph_context { return cur; } - ggml_tensor * build_attn_block(ggml_cgraph * gf, - ggml_tensor * cur, + ggml_tensor * build_attn_block(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv_unified * inp_attn, int il) const { @@ -16009,7 +17836,7 @@ struct llm_build_lfm2 : public llm_graph_context { ext_factor, attn_factor, beta_fast, beta_slow ); - cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, q, k, v, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "model.layers.{}.self_attn.out_proj", il); @@ -16017,11 +17844,22 @@ struct llm_build_lfm2 : public llm_graph_context { return cur; } - ggml_tensor * build_shortconv_block(ggml_cgraph * gf, - ggml_tensor * cur, + ggml_tensor * build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il) { - const auto * mctx_cur = static_cast(mctx)->get_recr(); + const auto * mctx_cur = static_cast(mctx)->get_recr(); + const uint32_t kv_head = mctx_cur->get_head(); + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + GGML_ASSERT(hparams.n_shortconv_l_cache > 1); + const uint32_t d_conv = hparams.n_shortconv_l_cache - 1; + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur); cb(bcx, "model.layers.{}.conv.in_proj", il); @@ -16029,43 +17867,174 @@ struct llm_build_lfm2 : public llm_graph_context { constexpr auto n_chunks = 3; GGML_ASSERT(bcx->ne[0] % n_chunks == 0); auto const chunk_size = bcx->ne[0] / n_chunks; - auto * b = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 0 * chunk_size * ggml_element_size(bcx)); - auto * c = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 1 * chunk_size * ggml_element_size(bcx)); - auto * x = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 2 * chunk_size * ggml_element_size(bcx)); + auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 0*chunk_size*ggml_element_size(bcx)); + auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 1*chunk_size*ggml_element_size(bcx)); + auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 2*chunk_size*ggml_element_size(bcx)); auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x)); - // read conv state directly, with build_rs generation is slower - ggml_tensor * conv_state = mctx_cur->get_r_l(il); - const int64_t n_seqs = ubatch.n_seqs; - ggml_tensor * conv = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs); - conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs); + // read conv state + auto * conv_state = mctx_cur->get_r_l(il); + auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs); + auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs); 
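// [editor's note, not part of the patch] cache-layout sketch for the lines above: each sequence keeps
// d_conv = n_shortconv_l_cache - 1 columns of input history per channel, so one recurrent slot holds
// d_conv * n_embd values and the reshape yields {d_conv, n_embd, n_seqs}. With hypothetical values
// n_shortconv_l_cache = 4 and n_embd = 2048, that is 3 * 2048 values per slot, and the write-back below
// copies exactly ggml_nelements(new_conv) values starting at element offset kv_head * d_conv * n_embd.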
bx = ggml_concat(ctx0, conv, bx, 0); GGML_ASSERT(bx->ne[0] > conv->ne[0]); - auto * new_conv = ggml_view_2d(ctx0, bx, conv->ne[0], bx->ne[1], bx->nb[1], (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx)); + // last d_conv columns is a new conv state + auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2], (bx->ne[0] - conv->ne[0])*ggml_element_size(bx)); GGML_ASSERT(ggml_are_same_shape(conv, new_conv)); - // write conv state - ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, conv_state)); + // write new conv conv state + ggml_build_forward_expand( + gf, + ggml_cpy( + ctx0, + new_conv, + ggml_view_1d( + ctx0, + conv_state, + ggml_nelements(new_conv), + kv_head*d_conv*n_embd*ggml_element_size(new_conv) + ) + ) + ); auto * conv_kernel = model.layers[il].shortconv.conv; - GGML_ASSERT(hparams.n_shortconv_l_cache > 0); - - // construct ssm_conv op - ggml_tensor * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel); + auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel); cb(conv_out, "model.layers.{}.conv.conv", il); auto * y = ggml_mul(ctx0, c, conv_out); - y = build_lora_mm(model.layers[il].shortconv.out_proj, y); cb(y, "model.layers.{}.conv.out_proj", il); + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs); return y; } }; +template +struct llm_build_smallthinker : public llm_graph_context{ + llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){ + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + using inp_attn_type = std::conditional_t; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_unified_iswa(); + } else { + inp_attn = build_attn_inp_kv_unified(); + } + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + ggml_tensor * probs = nullptr; + + probs = build_lora_mm(model.layers[il].ffn_gate_inp, inpL); // [n_expert, n_tokens] + cb(probs, "ffn_moe_logits", il); + + // norm + cur = build_norm(inpL,model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (hparams.n_no_rope_layer_step == n_layer || il % hparams.n_no_rope_layer_step != 0) { + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + + cur = build_attn(inp_attn, + 
model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + probs = ggml_get_rows(ctx0, probs, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * ffn_out = + build_moe_ffn(cur, + nullptr, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_RELU, true, + false, 0.0, + static_cast(hparams.expert_gating_func), + il, probs); + + cb(ffn_out, "ffn_out", il); + cur = ffn_out; + + cur = ggml_add(ctx0, cur, ffn_inp); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * res; @@ -16078,6 +18047,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_NEO_BERT: case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_DREAM: + case LLM_ARCH_LLADA: { res = nullptr; } break; @@ -16113,12 +18084,24 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, /* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr, /* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? 
[&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr); } else { const auto padding = llama_kv_cache_unified::get_padding(cparams); - cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + uint32_t n_ctx_per_stream = cparams.n_ctx; + + if (!cparams.kv_unified) { + n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max; + n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding); + + cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max; + } else { + n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding); + + cparams.n_ctx = n_ctx_per_stream; + } LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); @@ -16132,7 +18115,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, !cparams.flash_attn, cparams.offload_kqv, params.swa_full, - cparams.n_ctx, + cparams.kv_unified, + n_ctx_per_stream, cparams.n_seq_max, cparams.n_ubatch, padding); @@ -16146,7 +18130,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, params.type_v, !cparams.flash_attn, cparams.offload_kqv, - cparams.n_ctx, + cparams.kv_unified, + n_ctx_per_stream, cparams.n_seq_max, padding, hparams.n_swa, @@ -16159,227 +18144,242 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, return res; } -llm_graph_result_ptr llama_model::build_graph( - const llm_graph_params & params, - ggml_cgraph * gf, - llm_graph_type type) const { +ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { std::unique_ptr llm; switch (arch) { case LLM_ARCH_LLAMA: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_LLAMA4: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_DECI: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_BAICHUAN: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_FALCON: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GROK: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_STARCODER: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_REFACT: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_NEO_BERT: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_BLOOM: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_MPT: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_STABLELM: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_QWEN: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_QWEN2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; + case LLM_ARCH_DREAM: + { + llm = std::make_unique(*this, params); + } + break; + case LLM_ARCH_LLADA: + { + llm = 
std::make_unique(*this, params); + } + break; case LLM_ARCH_QWEN2VL: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_QWEN2MOE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_QWEN3: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_QWEN3MOE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_PHI2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_PHI3: case LLM_ARCH_PHIMOE: { if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - llm = std::make_unique> (*this, params, gf); + llm = std::make_unique> (*this, params); } else { - llm = std::make_unique>(*this, params, gf); + llm = std::make_unique>(*this, params); } } break; case LLM_ARCH_PLAMO: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_PLAMO2: + { + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GPT2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_CODESHELL: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_ORION: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_INTERNLM2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_MINICPM3: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GEMMA: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GEMMA2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GEMMA3: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GEMMA3N: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_STARCODER2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_MAMBA: case LLM_ARCH_MAMBA2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_JAMBA: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_XVERSE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_COMMAND_R: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_COHERE2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_DBRX: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_OLMO: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_OLMO2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_OLMOE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_OPENELM: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case 
LLM_ARCH_GPTNEOX: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_ARCTIC: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_DEEPSEEK: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_DEEPSEEK2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_CHATGLM: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GLM4: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_GLM4_MOE: + { + llm = std::make_unique(*this, params); } break; case LLM_ARCH_BITNET: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_T5: { - switch (type) { + switch (params.gtype) { case LLM_GRAPH_TYPE_ENCODER: - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); break; case LLM_GRAPH_TYPE_DEFAULT: case LLM_GRAPH_TYPE_DECODER: - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); break; default: GGML_ABORT("invalid graph type"); @@ -16387,99 +18387,127 @@ llm_graph_result_ptr llama_model::build_graph( } break; case LLM_ARCH_T5ENCODER: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_JAIS: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_NEMOTRON: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_EXAONE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_EXAONE4: + { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + llm = std::make_unique>(*this, params); + } else { + llm = std::make_unique>(*this, params); + } } break; case LLM_ARCH_RWKV6: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_RWKV6QWEN2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_RWKV7: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_ARWKV7: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_MINICPM: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GRANITE_HYBRID: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_CHAMELEON: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_WAVTOKENIZER_DEC: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_PLM: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_BAILINGMOE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_DOTS1: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_ARCEE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case 
LLM_ARCH_ERNIE4_5: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_ERNIE4_5_MOE: + { + llm = std::make_unique(*this, params); } break; case LLM_ARCH_HUNYUAN_MOE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_HUNYUAN_DENSE: + { + llm = std::make_unique(*this, params); } break; case LLM_ARCH_SMOLLM3: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_OPENAI_MOE: + { + llm = std::make_unique(*this, params); } break; case LLM_ARCH_FALCON_H1: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_LFM2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_SMALLTHINKER: + { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + llm = std::make_unique> (*this, params); + } else { + llm = std::make_unique>(*this, params); + } } break; default: GGML_ABORT("fatal error"); } // add on pooling layer - llm->build_pooling(gf, cls, cls_b, cls_out, cls_out_b); + llm->build_pooling(cls, cls_b, cls_out, cls_out_b); - return std::move(llm->res); + return llm->res->get_gf(); } // @@ -16501,6 +18529,7 @@ llama_model_params llama_model_default_params() { /*.use_mmap =*/ true, /*.use_mlock =*/ false, /*.check_tensors =*/ false, + /*.use_extra_bufts =*/ true, }; #ifdef GGML_USE_METAL @@ -16603,6 +18632,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { // use what we call a normal RoPE, operating on pairs of consecutive head values case LLM_ARCH_LLAMA: + case LLM_ARCH_LLADA: case LLM_ARCH_LLAMA4: case LLM_ARCH_DECI: case LLM_ARCH_BAICHUAN: @@ -16628,6 +18658,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_SMOLLM3: case LLM_ARCH_ARCEE: case LLM_ARCH_ERNIE4_5: + case LLM_ARCH_ERNIE4_5_MOE: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -16642,6 +18673,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_BITNET: case LLM_ARCH_QWEN: case LLM_ARCH_QWEN2: + case LLM_ARCH_DREAM: case LLM_ARCH_QWEN2MOE: case LLM_ARCH_QWEN3: case LLM_ARCH_QWEN3MOE: @@ -16651,6 +18683,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_PHI3: case LLM_ARCH_PHIMOE: case LLM_ARCH_PLAMO: + case LLM_ARCH_PLAMO2: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: @@ -16662,10 +18695,15 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ORION: case LLM_ARCH_NEMOTRON: case LLM_ARCH_EXAONE: + case LLM_ARCH_EXAONE4: case LLM_ARCH_MINICPM3: case LLM_ARCH_DOTS1: case LLM_ARCH_HUNYUAN_MOE: + case LLM_ARCH_OPENAI_MOE: + case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: + case LLM_ARCH_SMALLTHINKER: + case LLM_ARCH_GLM4_MOE: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: @@ -16776,6 +18814,10 @@ bool llama_model_is_recurrent(const llama_model * model) { return llm_arch_is_recurrent(model->arch); } +bool llama_model_is_diffusion(const llama_model * model) { + return llm_arch_is_diffusion(model->arch); +} + const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { return model->tensors_by_name; } diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 027a7f0c3e2..46f7d0480fa 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -39,6 +39,7 @@ enum 
llm_type { LLM_TYPE_410M, LLM_TYPE_450M, LLM_TYPE_475M, + LLM_TYPE_537M, LLM_TYPE_700M, LLM_TYPE_770M, LLM_TYPE_780M, @@ -99,8 +100,12 @@ enum llm_type { LLM_TYPE_17B_16E, // llama4 Scout LLM_TYPE_17B_128E, // llama4 Maverick LLM_TYPE_A13B, + LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, + LLM_TYPE_106B_A12B, // GLM-4.5-Air LLM_TYPE_235B_A22B, + LLM_TYPE_300B_A47B, // Ernie MoE big + LLM_TYPE_355B_A32B, // GLM-4.5 LLM_TYPE_E2B, LLM_TYPE_E4B, }; @@ -164,6 +169,15 @@ struct llama_layer_shortconv { struct ggml_tensor * out_proj = nullptr; }; +struct llama_layer_nextn { + struct ggml_tensor * eh_proj = nullptr; + struct ggml_tensor * embed_tokens = nullptr; + struct ggml_tensor * enorm = nullptr; + struct ggml_tensor * hnorm = nullptr; + struct ggml_tensor * shared_head_head = nullptr; + struct ggml_tensor * shared_head_norm = nullptr; +}; + struct llama_layer { // normalization struct ggml_tensor * attn_norm = nullptr; @@ -239,10 +253,14 @@ struct llama_layer { struct ggml_tensor * ffn_up_enc = nullptr; // ff MoE - struct ggml_tensor * ffn_gate_inp = nullptr; - struct ggml_tensor * ffn_gate_exps = nullptr; - struct ggml_tensor * ffn_down_exps = nullptr; - struct ggml_tensor * ffn_up_exps = nullptr; + struct ggml_tensor * ffn_gate_inp = nullptr; + struct ggml_tensor * ffn_gate_exps = nullptr; + struct ggml_tensor * ffn_down_exps = nullptr; + struct ggml_tensor * ffn_up_exps = nullptr; + struct ggml_tensor * ffn_gate_inp_b = nullptr; + struct ggml_tensor * ffn_gate_exps_b = nullptr; + struct ggml_tensor * ffn_down_exps_b = nullptr; + struct ggml_tensor * ffn_up_exps_b = nullptr; // ff shared expert (shexp) struct ggml_tensor * ffn_gate_inp_shexp = nullptr; @@ -347,11 +365,16 @@ struct llama_layer { struct ggml_tensor * laurel_r = nullptr; struct ggml_tensor * laurel_post_norm = nullptr; + // openai-moe + struct ggml_tensor * attn_sinks = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; struct llama_layer_shortconv shortconv; + + struct llama_layer_nextn nextn; }; struct llama_model { @@ -452,10 +475,7 @@ struct llama_model { llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const; // TODO: move this to new llm_arch_model_i interface - llm_graph_result_ptr build_graph( - const llm_graph_params & params, - ggml_cgraph * gf, - llm_graph_type type) const; + ggml_cgraph * build_graph(const llm_graph_params & params) const; private: struct impl; diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 4dbd1e30991..1d0361cc166 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -211,7 +211,10 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t const int64_t nx = tensor->ne[0]; const int64_t qk_k = ggml_blck_size(new_type); - if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) { + if (ftype == LLAMA_FTYPE_MOSTLY_MXFP4_MOE) { + new_type = GGML_TYPE_Q8_0; + } + else if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) { new_type = GGML_TYPE_Q8_0; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || @@ -223,6 +226,14 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t new_type = GGML_TYPE_Q6_K; } } + } else if (ftype == LLAMA_FTYPE_MOSTLY_MXFP4_MOE) { + // MoE tensors -> MXFP4 + // other tensors -> Q8_0 + if (tensor->ne[2] > 1) { + new_type = GGML_TYPE_MXFP4; + } else { + new_type = GGML_TYPE_Q8_0; + } } else if (name == 
"token_embd.weight" || name == "per_layer_token_embd.weight") { if (qs.params->token_embedding_type < GGML_TYPE_COUNT) { new_type = qs.params->token_embedding_type; @@ -533,6 +544,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; + case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: default_type = GGML_TYPE_MXFP4; break; + // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K_S: case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break; @@ -875,23 +888,22 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // get more optimal quantization type based on the tensor shape, layer, etc. if (!params->pure && ggml_is_quantized(default_type)) { + int fallback = qs.n_fallback; new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); - // unless the user specifies a type - if (params->tensor_types) { + // unless the user specifies a type, and the tensor geometry will not require fallback quantisation + if (params->tensor_types && qs.n_fallback - fallback == 0) { const std::vector & tensor_types = *static_cast *>(params->tensor_types); const std::string tensor_name(tensor->name); for (const auto & [tname, qtype] : tensor_types) { if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) { if (qtype != new_type) { LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type)); - new_type = qtype; - break; // if two or more types are specified for the tensor, first match wins + new_type = qtype; // if two or more types are specified for the same tensor, the last match wins } } } } } - if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { new_type = params->token_embedding_type; } @@ -985,6 +997,29 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: const float * imatrix_03 = imatrix ? 
imatrix + i03 * n_per_row : nullptr; new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); + + // TODO: temporary sanity check that the F16 -> MXFP4 is lossless +#if 0 + if (new_type == GGML_TYPE_MXFP4) { + auto * x = f32_data_03; + + //LLAMA_LOG_INFO("nrows = %d, n_per_row = %d\n", nrows, n_per_row); + std::vector deq(nrows*n_per_row); + const ggml_type_traits * qtype = ggml_get_type_traits(new_type); + qtype->to_float(new_data_03, deq.data(), deq.size()); + + double err = 0.0f; + for (int i = 0; i < (int) deq.size(); ++i) { + err += fabsf(deq[i] - x[i]); + //if (fabsf(deq[i] - x[i]) > 0.00001 && i < 256) { + if (deq[i] != x[i]) { + LLAMA_LOG_INFO("deq[%d] = %f, x[%d] = %f\n", i, deq[i], i, x[i]); + } + } + //LLAMA_LOG_INFO("err = %f\n", err); + GGML_ASSERT(err == 0.00000); + } +#endif } LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); } diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index e0e578d6394..de5d1681dff 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -306,6 +307,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { }; break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM: + case LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE: regex_exprs = { "\\p{N}{1,3}", "[一-龥぀-ゟ゠-ヿ]+", @@ -404,6 +406,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_KIMI_K2: + regex_exprs = { + // K2 trigger pattern - this will activate the custom K2 handler in unicode.cpp + // The custom handler implements all K2 patterns with proper Han character exclusion + "\\p{Han}+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_SUPERBPE: regex_exprs = { "\\p{N}+", @@ -1196,6 +1205,284 @@ struct llm_tokenizer_rwkv_session { const llm_tokenizer_rwkv & tokenizer; }; +struct llm_tokenizer_plamo2 : llm_tokenizer { + llm_tokenizer_plamo2(const llama_vocab & vocab) { + build(vocab); + } + + void build(const llama_vocab & vocab) { + // Reset internal structures + tokens_.clear(); + bytes_.assign(256, 0); + to_suffix_id_.clear(); + table_.clear(); + + // Build token list and byte mapping + std::unordered_map suffix_to_score; + std::unordered_map token_to_id; + + for (size_t token_id = 0; token_id < vocab.n_tokens(); ++token_id) { + const auto & entry = vocab.get_token_data(token_id); + tokens_.push_back(entry.text); + token_to_id[entry.text] = static_cast(token_id); + + // Handle byte tokens + if (vocab.is_byte(token_id)) { + if (entry.text.length() == 6 && entry.text.substr(0, 3) == "<0x" && entry.text.back() == '>') { + std::string hex_str = entry.text.substr(3, 2); + int byte_val = std::stoi(hex_str, nullptr, 16); + bytes_[byte_val] = static_cast(token_id); + } + continue; + } + + // Add token and all its suffixes to suffix_to_score + suffix_to_score[entry.text] = entry.score; + + // Extract suffixes character by character (UTF-8 aware) + std::vector cpts = unicode_cpts_from_utf8(entry.text); + for (size_t i = 1; i < cpts.size(); ++i) { + std::string suffix; + for (size_t j = i; j < 
cpts.size(); ++j) { + suffix += unicode_cpt_to_utf8(cpts[j]); + } + if (suffix_to_score.find(suffix) == suffix_to_score.end()) { + suffix_to_score[suffix] = std::numeric_limits::quiet_NaN(); + } + } + } + + // Check that all byte tokens are set + for (int i = 0; i < 256; ++i) { + if (bytes_[i] == 0) { + throw std::runtime_error("Byte token for <0x" + std::to_string(i) + "> is not set"); + } + } + + // Build suffix list in lexicographical order of reversed strings + std::vector suffixes; + for (const auto & pair : suffix_to_score) { + suffixes.push_back(pair.first); + } + suffixes.push_back(""); // Empty suffix + + std::sort(suffixes.begin(), suffixes.end(), [](const std::string & a, const std::string & b) { + std::string rev_a(a.rbegin(), a.rend()); + std::string rev_b(b.rbegin(), b.rend()); + return rev_a < rev_b; + }); + + // Build suffix_to_id and to_suffix_id_ + std::unordered_map suffix_to_id; + int32_t num_pieces = 0; + + for (const auto & suffix : suffixes) { + suffix_to_id[suffix] = num_pieces; + if (!suffix.empty()) { + std::vector cpts = unicode_cpts_from_utf8(suffix); + + std::string remaining; + for (size_t i = 1; i < cpts.size(); ++i) { + remaining += unicode_cpt_to_utf8(cpts[i]); + } + + int64_t piece_code = (static_cast(cpts[0]) << 32) | suffix_to_id[remaining]; + to_suffix_id_[piece_code] = num_pieces; + + // Count number of pieces for this suffix + int32_t pieces_for_suffix = 1; // sentinel row + for (int32_t piece_length = static_cast(cpts.size()); piece_length > 0; --piece_length) { + std::string piece; + for (int32_t i = 0; i < piece_length; ++i) { + piece += unicode_cpt_to_utf8(cpts[i]); + } + if (suffix_to_score.find(piece) != suffix_to_score.end()) { + pieces_for_suffix++; + } + } + num_pieces += pieces_for_suffix; + } else { + num_pieces++; // Empty suffix contributes one piece (sentinel row) + } + } + + // Build flattened table + table_.resize(num_pieces, std::vector(4, 0)); + int32_t table_idx = 0; + + for (const auto & suffix : suffixes) { + // Add all prefixes of the suffix to the table (in decreasing order of length) + std::vector cpts = unicode_cpts_from_utf8(suffix); + for (int32_t piece_length = static_cast(cpts.size()); piece_length > 0; --piece_length) { + std::string piece; + for (int32_t i = 0; i < piece_length; ++i) { + piece += unicode_cpt_to_utf8(cpts[i]); + } + + auto score_it = suffix_to_score.find(piece); + if (score_it == suffix_to_score.end()) { + continue; + } + + table_[table_idx][TABLE_PIECE_LENGTH] = piece_length; + auto token_it = token_to_id.find(piece); + table_[table_idx][TABLE_TOKEN_ID] = (token_it != token_to_id.end()) ? token_it->second : -1; + + float score = score_it->second; + table_[table_idx][TABLE_SCORE] = std::isfinite(score) ? 
+                    static_cast<int32_t>(std::round(score * 1e4)) : INVALID_SCORE;
+                table_[table_idx][TABLE_PIECE_ID] = suffix_to_id[piece];
+
+                table_idx++;
+            }
+
+            // Add sentinel row
+            table_[table_idx][TABLE_PIECE_LENGTH] = 1;
+            table_[table_idx][TABLE_TOKEN_ID] = -1;
+            table_[table_idx][TABLE_SCORE] = UNKNOWN_SCORE;
+            table_idx++;
+        }
+    }
+
+    std::vector<llama_token> encode(const std::string & text) const {
+        std::vector<uint32_t> unicode_data = unicode_cpts_from_utf8(text);
+        // Skip the first code point if it is a BOM (Byte Order Mark)
+        if (!unicode_data.empty() && unicode_data[0] == 0xFEFF) {
+            unicode_data.erase(unicode_data.begin());
+        }
+
+        if (unicode_data.empty()) {
+            return {};
+        }
+
+        const size_t data_len = unicode_data.size();
+
+        // Initialize scores array (dynamic programming)
+        std::vector<int64_t> scores(data_len + 1, static_cast<int64_t>(1) << 60);
+        scores[data_len] = 0;
+
+        // Path array to track best tokenization
+        std::vector<std::vector<int32_t>> path(data_len + 1, std::vector<int32_t>(3, 0));
+
+        int32_t suffix_id = 0;
+
+        // Process from end to beginning
+        for (int i = static_cast<int>(data_len) - 1; i >= 0; --i) {
+            uint32_t c = unicode_data[i];
+
+            // Find next suffix ID
+            for (size_t p = suffix_id; p < table_.size(); ++p) {
+                int64_t piece_code = (static_cast<int64_t>(c) << 32) | table_[p][TABLE_PIECE_ID];
+                auto it = to_suffix_id_.find(piece_code);
+                suffix_id = (it != to_suffix_id_.end()) ? it->second : 0;
+
+                if (suffix_id > 0 || table_[p][TABLE_SCORE] == UNKNOWN_SCORE) {
+                    break;
+                }
+            }
+
+            // Update best path
+            for (size_t p = suffix_id; p < table_.size(); ++p) {
+                int32_t score = table_[p][TABLE_SCORE];
+                if (score > INVALID_SCORE) {
+                    int32_t piece_length = table_[p][TABLE_PIECE_LENGTH];
+                    int64_t s = scores[i + piece_length] - score;
+
+                    if (s < scores[i]) {
+                        scores[i] = s;
+                        path[i][PATH_TOKEN_LENGTH] = piece_length;
+                        path[i][PATH_TOKEN_ID] = table_[p][TABLE_TOKEN_ID];
+                        path[i][PATH_NUM_TOKENS] = path[i + piece_length][PATH_NUM_TOKENS] + 1;
+
+                        if (score == UNKNOWN_SCORE) {
+                            // Add UTF-8 byte count
+                            path[i][PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000);
+                        }
+                    }
+                }
+
+                if (score == UNKNOWN_SCORE) {
+                    break;
+                }
+            }
+        }
+
+        // Decode the best path
+        std::vector<llama_token> token_ids;
+        token_ids.reserve(path[0][PATH_NUM_TOKENS]);
+
+        int pos = 0;
+        while (pos < static_cast<int>(data_len)) {
+            if (path[pos][PATH_TOKEN_ID] >= 0) {
+                token_ids.push_back(path[pos][PATH_TOKEN_ID]);
+            } else {
+                // Fall back to byte tokens
+                uint32_t c = unicode_data[pos];
+                int s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000);
+
+                for (int i = 0; i < s; ++i) {
+                    uint8_t b;
+                    if (s == 1) {
+                        b = c;
+                    } else {
+                        if (i == 0) {
+                            b = (0xF00 >> s) & 0xFF;
+                        } else {
+                            b = 0x80;
+                        }
+                    }
+                    token_ids.push_back(bytes_[b | ((c >> ((s - i - 1) * 6)) & 0x3F)]);
+                }
+            }
+
+            assert(path[pos][PATH_TOKEN_LENGTH] > 0);
+            pos += path[pos][PATH_TOKEN_LENGTH];
+        }
+
+        return token_ids;
+    }
+private:
+    // Constants for table structure
+    static constexpr int32_t TABLE_PIECE_LENGTH = 0;
+    static constexpr int32_t TABLE_TOKEN_ID = 1;
+    static constexpr int32_t TABLE_SCORE = 2;
+    static constexpr int32_t TABLE_PIECE_ID = 3;
+
+    // Constants for path array
+    static constexpr int32_t PATH_TOKEN_LENGTH = 0;
+    static constexpr int32_t PATH_TOKEN_ID = 1;
+    static constexpr int32_t PATH_NUM_TOKENS = 2;
+
+    // Score constants
+    static constexpr int32_t INVALID_SCORE = -20000000;
+    static constexpr int32_t UNKNOWN_SCORE = -10000000;
+
+    // List of tokens in the vocabulary
+    std::vector<std::string> tokens_;
+
+    // Mapping from byte code point to token ID (for byte fallback)
+    std::vector<llama_token> bytes_;
+
+    // Mapping from piece code to suffix ID
+    std::unordered_map<int64_t, int32_t> to_suffix_id_;
+
+    // Flattened table representing the Trie structure
+    // Each row contains: [piece_length, token_id, score, piece_id]
+    std::vector<std::vector<int32_t>> table_;
+};
+
+struct llm_tokenizer_plamo2_session {
+    llm_tokenizer_plamo2_session(const llm_tokenizer_plamo2 & tokenizer) : tokenizer(tokenizer) {}
+
+    void tokenize(const std::string & text, std::vector<llama_token> & output) {
+        std::vector<llama_token> tokens = tokenizer.encode(text);
+        output.insert(output.end(), tokens.begin(), tokens.end());
+    }
+
+private:
+    const llm_tokenizer_plamo2 & tokenizer;
+};
+
 //
 // impl
 //
@@ -1499,6 +1786,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             special_unk_id = LLAMA_TOKEN_NULL;
             special_sep_id = LLAMA_TOKEN_NULL;
             special_pad_id = LLAMA_TOKEN_NULL;
+        } else if (tokenizer_model == "plamo2") {
+            type = LLAMA_VOCAB_TYPE_PLAMO2;
+
+            // PLaMo-2 default special tokens (these will be overridden by model config)
+            special_bos_id = 1; // <|plamo:bos|>
+            special_eos_id = 2; // <|plamo:eos|>
+            special_unk_id = 0; // <|plamo:unk|>
+            special_sep_id = LLAMA_TOKEN_NULL;
+            special_pad_id = 3; // <|plamo:pad|>
+            special_mask_id = LLAMA_TOKEN_NULL;
         } else {
             throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
         }
@@ -1559,7 +1856,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "gigachat" ||
                 tokenizer_pre == "jina-v2-es" ||
                 tokenizer_pre == "jina-v2-de" ||
-                tokenizer_pre == "a.x-4.0") {
+                tokenizer_pre == "a.x-4.0" ||
+                tokenizer_pre == "mellum") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
         } else if (
                 tokenizer_pre == "jina-v1-en" ||
@@ -1629,6 +1927,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         } else if (
                 tokenizer_pre == "exaone") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE;
+        } else if (
+                tokenizer_pre == "exaone4") {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
         } else if (
                 tokenizer_pre == "chameleon") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON;
@@ -1665,6 +1966,14 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "hunyuan") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
             clean_spaces = false;
+        } else if (
+                tokenizer_pre == "hunyuan-dense") {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE;
+            clean_spaces = false;
+        } else if (
+                tokenizer_pre == "kimi-k2") {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
+            clean_spaces = false;
         } else {
             throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
         }
@@ -1882,6 +2191,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|fim▁begin|>" // DeepSeek
                         || t.first == "<PRE>"
                         || t.first == "▁<PRE>"          // CodeLlama
+                        || t.first == "<|code_prefix|>" // GLM-4.5
                         ) {
                     special_fim_pre_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -1901,6 +2211,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|fim▁hole|>" // DeepSeek
                         || t.first == "<SUF>"
                         || t.first == "▁<SUF>"         // CodeLlama
+                        || t.first == "<|code_suffix|>" // GLM-4.5
                         ) {
                     special_fim_suf_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -1920,6 +2231,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|fim▁end|>"  // DeepSeek
                         || t.first == "<MID>"
                         || t.first == "▁<MID>"         // CodeLlama
+                        || t.first == "<|code_middle|>" // GLM-4.5
                         ) {
                     special_fim_mid_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2002,6 +2314,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<|eot_id|>"
                     || t.first == "<|im_end|>"
                     || t.first == "<|end|>"
+                    || t.first == "<|return|>" // o200k_harmony
+                    || t.first == "<|call|>"   // o200k_harmony
                     || t.first == "<end_of_turn>"
                     || t.first == "<|endoftext|>"
                     || t.first == "<|eom_id|>"
@@ -2025,6 +2339,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             }
         }
 
+        // @ngxson : quick hack for gpt-oss, always render these tokens
+        for (const auto & t : token_to_id) {
+            if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>" || t.first == "<|constrain|>") {
+                id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
+            }
+        }
+
         // sanity checks
         if (special_eos_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eos_id) == 0) {
             special_eog_ids.insert(special_eos_id);
@@ -2040,6 +2361,37 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             special_eog_ids.insert(special_eom_id);
             LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
         }
+
+        // TODO: workaround for o200k_harmony tokenizer: the "<|end|>" token should not be EOG
+        //       we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens,
+        //       we remove the "<|end|>" token from the EOG list
+        {
+            bool has_return = false;
+            bool has_call   = false;
+            bool has_end    = false;
+
+            llama_token end_id = LLAMA_TOKEN_NULL;
+
+            LLAMA_LOG_INFO("%s: printing all EOG tokens:\n", __func__);
+            for (auto tid : special_eog_ids) {
+                LLAMA_LOG_INFO("%s:   - %d ('%s')\n", __func__, tid, id_to_token[tid].text.c_str());
+
+                if (id_to_token[tid].text == "<|return|>") {
+                    has_return = true;
+                } else if (id_to_token[tid].text == "<|call|>") {
+                    has_call = true;
+                } else if (id_to_token[tid].text == "<|end|>") {
+                    has_end = true;
+                    end_id = tid;
+                }
+            }
+
+            if (has_return && has_call && has_end) {
+                special_eog_ids.erase(end_id);
+                id_to_token[end_id].attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
+                LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
+            }
+        }
     }
 
     // build special tokens cache
@@ -2145,13 +2497,14 @@ enum llama_vocab_type llama_vocab::impl::get_type() const {
 
 std::string llama_vocab::impl::type_name() const{
     switch (type) {
-        case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
-        case LLAMA_VOCAB_TYPE_SPM:  return "SPM";
-        case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
-        case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
-        case LLAMA_VOCAB_TYPE_UGM:  return "UGM";
-        case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
-        default:                    return "unknown";
+        case LLAMA_VOCAB_TYPE_NONE:   return "no vocab";
+        case LLAMA_VOCAB_TYPE_SPM:    return "SPM";
+        case LLAMA_VOCAB_TYPE_BPE:    return "BPE";
+        case LLAMA_VOCAB_TYPE_WPM:    return "WPM";
+        case LLAMA_VOCAB_TYPE_UGM:    return "UGM";
+        case LLAMA_VOCAB_TYPE_RWKV:   return "RWKV";
+        case LLAMA_VOCAB_TYPE_PLAMO2: return "PLaMo2";
+        default:                      return "unknown";
     }
 }
 
@@ -2234,6 +2587,9 @@ void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) {
         case LLAMA_VOCAB_TYPE_RWKV:
             tokenizer = std::make_unique<llm_tokenizer_rwkv>(vocab);
             break;
+        case LLAMA_VOCAB_TYPE_PLAMO2:
+            tokenizer = std::make_unique<llm_tokenizer_plamo2>(vocab);
+            break;
         default:
             GGML_ABORT("unsupported vocab type");
     }
@@ -2566,6 +2922,23 @@ std::vector<llama_token> llama_vocab::impl::tokenize(
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
 
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_PLAMO2:
+            {
+                llm_tokenizer_plamo2_session session(*static_cast<const llm_tokenizer_plamo2 *>(tokenizer.get()));
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
 #endif
@@ -2664,6 +3037,24 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t
                 memcpy(buf, result.data(), result.size());
                 return (int)result.size();
             }
+            case LLAMA_VOCAB_TYPE_PLAMO2: {
+                // PLaMo-2 uses similar token handling as BPE/SPM
+                if (vocab.is_byte(token)) {
+                    // Handle byte tokens like <0xXX>
+                    if (token_text.length() == 6 && token_text.substr(0, 3) == "<0x" && token_text.back() == '>') {
+                        int hex_val = std::stoi(token_text.substr(3, 2), nullptr, 16);
+                        if (length < 1) {
+                            return -1;
+                        }
+                        buf[0] = static_cast<char>(hex_val);
+                        return 1;
+                    }
+                }
+
+                // Normal token - just copy the text
+                std::string result = token_text;
+                return _try_copy(result.data(), result.size());
+            }
             default:
                 GGML_ABORT("fatal error");
         }
@@ -2908,6 +3299,12 @@ llama_token llama_vocab::byte_to_token(uint8_t ch) const {
         case LLAMA_VOCAB_TYPE_BPE: {
             return pimpl->token_to_id.at(unicode_byte_to_utf8(ch));
         }
+        case LLAMA_VOCAB_TYPE_PLAMO2: {
+            // PLaMo-2 uses byte tokens in format <0xXX>
+            char hex_str[8];
+            snprintf(hex_str, sizeof(hex_str), "<0x%02X>", ch);
+            return pimpl->token_to_id.at(hex_str);
+        }
         default:
             GGML_ABORT("fatal error");
     }
@@ -3009,6 +3406,10 @@ llama_token llama_vocab::token_fim_sep() const {
     return pimpl->special_fim_sep_id;
 }
 
+llama_token llama_vocab::token_mask() const {
+    return pimpl->special_mask_id;
+}
+
 bool llama_vocab::get_add_space_prefix() const {
     return pimpl->add_space_prefix;
 }
@@ -3249,6 +3650,10 @@ llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) {
     return vocab->token_fim_sep();
 }
 
+llama_token llama_vocab_mask(const struct llama_vocab* vocab) {
+    return vocab->token_mask();
+}
+
 // deprecated
 const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) {
     return llama_vocab_get_text(vocab, token);
@@ -3385,4 +3790,3 @@ int32_t llama_detokenize(
                         bool   unparse_special) {
     return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
 }
-
diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h
index 46a1ccecb51..61b81242168 100644
--- a/examples/talk-llama/llama-vocab.h
+++ b/examples/talk-llama/llama-vocab.h
@@ -45,6 +45,8 @@ enum llama_vocab_pre_type {
     LLAMA_VOCAB_PRE_TYPE_PIXTRAL        = 34,
     LLAMA_VOCAB_PRE_TYPE_SEED_CODER     = 35,
     LLAMA_VOCAB_PRE_TYPE_HUNYUAN        = 36,
+    LLAMA_VOCAB_PRE_TYPE_KIMI_K2        = 37,
+    LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE  = 38,
 };
 
 struct LLM_KV;
@@ -100,6 +102,7 @@ struct llama_vocab {
     llama_token token_sep() const;
     llama_token token_nl () const;
     llama_token token_pad() const;
+    llama_token token_mask() const;
 
     llama_token token_prefix() const;
     llama_token token_middle() const;
diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h
index f73b1ab65fe..135eaf1b655 100644
--- a/examples/talk-llama/llama.h
+++ b/examples/talk-llama/llama.h
@@ -71,12 +71,13 @@ extern "C" {
     typedef int32_t llama_seq_id;
 
     enum llama_vocab_type {
-        LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab
-        LLAMA_VOCAB_TYPE_SPM  = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
-        LLAMA_VOCAB_TYPE_BPE  = 2, // GPT-2 tokenizer based on byte-level BPE
-        LLAMA_VOCAB_TYPE_WPM  = 3, // BERT tokenizer based on WordPiece
-        LLAMA_VOCAB_TYPE_UGM  = 4, // T5 tokenizer based on Unigram
-        LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
+        LLAMA_VOCAB_TYPE_NONE   = 0, // For models without vocab
+        LLAMA_VOCAB_TYPE_SPM    = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
+        LLAMA_VOCAB_TYPE_BPE    = 2, // GPT-2 tokenizer based on byte-level BPE
+        LLAMA_VOCAB_TYPE_WPM    = 3, // BERT tokenizer based on WordPiece
+        LLAMA_VOCAB_TYPE_UGM    = 4, // T5 tokenizer based on Unigram
+        LLAMA_VOCAB_TYPE_RWKV   = 5, // RWKV tokenizer based on greedy tokenization
+        LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming
     };
 
     enum llama_rope_type {
@@ -151,6 +152,7 @@ extern "C" {
         //LLAMA_FTYPE_MOSTLY_Q4_0_8_8      = 35, // removed from gguf files, use Q4_0 and runtime repack
         LLAMA_FTYPE_MOSTLY_TQ1_0         = 36, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_TQ2_0         = 37, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_MXFP4_MOE     = 38, // except 1d tensors
 
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
@@ -283,10 +285,11 @@ extern "C" {
         const struct llama_model_kv_override * kv_overrides;
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
-        bool vocab_only;    // only load the vocabulary, no weights
-        bool use_mmap;      // use mmap if possible
-        bool use_mlock;     // force system to keep model in RAM
-        bool check_tensors; // validate model tensor data
+        bool vocab_only;      // only load the vocabulary, no weights
+        bool use_mmap;        // use mmap if possible
+        bool use_mlock;       // force system to keep model in RAM
+        bool check_tensors;   // validate model tensor data
+        bool use_extra_bufts; // use extra buffer types (used for weight repacking)
     };
 
     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
@@ -334,6 +337,9 @@ extern "C" {
         bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
                           // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
                           //       ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
+        bool kv_unified;  // use a unified buffer across the input sequences when computing the attention
+                          // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
+                          // ref: https://github.com/ggml-org/llama.cpp/pull/14363
     };
 
     // model quantization parameters
@@ -533,6 +539,9 @@ extern "C" {
     // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
     LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
 
+    // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
+    LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
@@ -724,7 +733,7 @@ extern "C" {
     //   - lazily on next llama_decode()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
-    DEPRECATED(void llama_kv_self_seq_div(
+    DEPRECATED(LLAMA_API void llama_kv_self_seq_div(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
                        llama_pos   p0,
@@ -861,6 +870,29 @@ extern "C" {
                           size_t   n_token_capacity,
                           size_t * n_token_count_out);
 
+#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
+
+    typedef uint32_t llama_state_seq_flags;
+
+    LLAMA_API size_t llama_state_seq_get_size_ext(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id,
+           llama_state_seq_flags   flags);
+
+    LLAMA_API size_t llama_state_seq_get_data_ext(
+            struct llama_context * ctx,
+                         uint8_t * dst,
+                          size_t   size,
+                    llama_seq_id   seq_id,
+           llama_state_seq_flags   flags);
+
+    LLAMA_API size_t llama_state_seq_set_data_ext(
+            struct llama_context * ctx,
+                   const uint8_t * src,
+                          size_t   size,
+                    llama_seq_id   dest_seq_id,
+           llama_state_seq_flags   flags);
+
     //
     // Decoding
     //
@@ -952,6 +984,7 @@ extern "C" {
     // in the order they have appeared in the batch.
     // Rows: number of tokens for which llama_batch.logits[i] != 0
     // Cols: n_vocab
+    // TODO: deprecate in favor of llama_get_logits_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
     LLAMA_API float * llama_get_logits(struct llama_context * ctx);
 
     // Logits for the ith token. For positive indices, Equivalent to:
@@ -966,6 +999,7 @@ extern "C" {
     // in the order they have appeared in the batch.
     // shape: [n_outputs*n_embd]
     // Otherwise, returns NULL.
+    // TODO: deprecate in favor of llama_get_embeddings_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 
     // Get the embeddings for the ith token. For positive indices, Equivalent to:
@@ -1004,6 +1038,7 @@ extern "C" {
     LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
     LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
     LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
+    LLAMA_API llama_token llama_vocab_mask(const struct llama_vocab * vocab); // mask
 
     LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
     LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
@@ -1389,6 +1424,7 @@ extern "C" {
 
         int32_t n_p_eval;
         int32_t n_eval;
+        int32_t n_reused; // number of times a ggml compute graph had been reused
     };
 
     struct llama_perf_sampler_data {
@@ -1424,6 +1460,8 @@ extern "C" {
 
         ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
         void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+
+        enum ggml_opt_optimizer_type optimizer_type;
     };
 
     LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp
index 43a4581b961..65f36651715 100644
--- a/examples/talk-llama/unicode.cpp
+++ b/examples/talk-llama/unicode.cpp
@@ -557,6 +557,178 @@ static std::vector<size_t> unicode_regex_split_stl(const std::string & text, con
     return bpe_offsets;
 }
 
+// K2 system regex patterns (from tokenization_kimi.py):
+// [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+
+static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector<size_t> & offsets) {
+    std::vector<size_t> bpe_offsets;
+    bpe_offsets.reserve(offsets.size());
+
+    const auto cpts = unicode_cpts_from_utf8(text);
+
+    size_t start = 0;
+    for (auto offset : offsets) {
+        const size_t offset_ini = start;
+        const size_t offset_end = start + offset;
+        assert(offset_end <= cpts.size());
+        start = offset_end;
+
+        static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
+        auto _get_cpt = [&] (const size_t pos) -> uint32_t {
+            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
+        };
+
+        auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
+        };
+
+        size_t _prev_end = offset_ini;
+        auto _add_token = [&] (const size_t end) -> size_t {
+            assert(_prev_end <= end && end <= offset_end);
+            size_t len = end - _prev_end;
+            if (len > 0) {
+                bpe_offsets.push_back(len);
+            }
+            _prev_end = end;
+            return len;
+        };
+
+        for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
+            const uint32_t cpt = _get_cpt(pos);
+            const auto flags = _get_flags(pos);
+
+            // Pattern 1: [\p{Han}]+ (Chinese characters)
+            if (unicode_cpt_is_han(cpt)) {
+                while (unicode_cpt_is_han(_get_cpt(pos))) {
+                    pos++;
+                }
+                _add_token(pos);
+                continue;
+            }
+
+            // Pattern 2 & 3: Letter words excluding Han characters with optional contractions
+            // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?:'s|'t|'re|'ve|'m|'ll|'d)?
+            // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?:'s|'t|'re|'ve|'m|'ll|'d)?
+            // Check if current char is a letter OR if current char could be a leading char and next char is a letter
+            bool is_letter_pattern = (flags.is_letter && !unicode_cpt_is_han(cpt)) ||
+                                     (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number) &&
+                                      _get_flags(pos + 1).is_letter && !unicode_cpt_is_han(_get_cpt(pos + 1)));
+
+            if (is_letter_pattern) {
+                // Handle optional leading non-letter/non-number character
+                bool has_leading_char = false;
+                if (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number)) {
+                    has_leading_char = true;
+                    pos++;
+                }
+
+                // Match letter sequence (excluding Han characters)
+                bool has_letters = false;
+                while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
+                    has_letters = true;
+                    pos++;
+                }
+
+                // Only proceed if we found letters (after potentially skipping leading char)
+                if (has_letters || (!has_leading_char && _get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos)))) {
+                    if (!has_letters) pos++; // consume the first letter if we didn't already
+
+                    // Continue consuming letters
+                    while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) {
+                        pos++;
+                    }
+
+                    // Check for optional contractions (?:'s|'t|'re|'ve|'m|'ll|'d)
+                    if (_get_cpt(pos) == '\'' && pos + 1 < offset_end) {
+                        uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1));
+                        if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
+                            pos += 2;
+                        } else if (pos + 2 < offset_end) {
+                            uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2));
+                            if ((cpt_next == 'r' && cpt_next_next == 'e') ||
+                                (cpt_next == 'v' && cpt_next_next == 'e') ||
+                                (cpt_next == 'l' && cpt_next_next == 'l')) {
+                                pos += 3;
+                            }
+                        }
+                    }
+
+                    _add_token(pos);
+                    continue;
+                } else if (has_leading_char) {
+                    // We consumed a leading char but found no letters, backtrack
+                    pos--;
+                }
+            }
+
+            // Pattern 4: \p{N}{1,3} (numbers 1-3 digits)
+            if (flags.is_number) {
+                size_t ini = pos;
+                while (_get_flags(pos).is_number) {
+                    if (++pos - ini >= 3) {
+                        _add_token(pos);
+                        ini = pos;
+                    }
+                }
+                _add_token(pos);
+                continue;
+            }
+
+            // Pattern 5:  ?[^\s\p{L}\p{N}]+[\r\n]* (optional space + non-word chars + optional newlines)
+            auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags);
+            if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
+                pos += (cpt == ' ');
+                while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) {
+                    flags2 = _get_flags(++pos);
+                }
+                // Match optional [\r\n]*
+                uint32_t cpt2 = _get_cpt(pos);
+                while (cpt2 == '\r' || cpt2 == '\n') {
+                    cpt2 = _get_cpt(++pos);
+                }
+                _add_token(pos);
+                continue;
+            }
+
+            // Count whitespace characters
+            size_t num_whitespaces = 0;
+            size_t last_end_r_or_n = 0;
+            while (_get_flags(pos + num_whitespaces).is_whitespace) {
+                uint32_t cpt2 = _get_cpt(pos + num_whitespaces);
+                if (cpt2 == '\r' || cpt2 == '\n') {
+                    last_end_r_or_n = pos + num_whitespaces + 1;
+                }
+                num_whitespaces++;
+            }
+
+            // Pattern 6: \s*[\r\n]+ (whitespace with newlines)
+            if (last_end_r_or_n > 0) {
+                pos = last_end_r_or_n;
+                _add_token(pos);
+                continue;
+            }
+
+            // Pattern 7: \s+(?!\S) (trailing whitespace)
+            if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) {
+                pos += num_whitespaces - 1;
+                _add_token(pos);
+                continue;
+            }
+
+            // Pattern 8: \s+ (general whitespace)
+            if (num_whitespaces > 0) {
+                pos += num_whitespaces;
+                _add_token(pos);
+                continue;
+            }
+
+            // No matches - consume single character
+            _add_token(++pos);
+        }
+    }
+
+    return bpe_offsets;
+}
+
 static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
     std::vector<size_t> bpe_offsets;
 
@@ -567,6 +739,9 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
             regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
 
         bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
+    } else if (regex_expr == "\\p{Han}+") {
+        // K2's first pattern - handle all K2 patterns together
+        bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets);
     }
 
     return bpe_offsets;
@@ -672,6 +847,38 @@ uint32_t unicode_tolower(uint32_t cpt) {
     return cpt;  // Return the original code point if no lowercase mapping is found
 }
 
+bool unicode_cpt_is_han(uint32_t cpt) {
+    // Han character ranges (Chinese/CJK characters)
+    // CJK Unified Ideographs (most common)
+    if (cpt >= 0x4E00 && cpt <= 0x9FFF) return true;
+
+    // CJK Extension A
+    if (cpt >= 0x3400 && cpt <= 0x4DBF) return true;
+
+    // CJK Extension B
+    if (cpt >= 0x20000 && cpt <= 0x2A6DF) return true;
+
+    // CJK Extension C
+    if (cpt >= 0x2A700 && cpt <= 0x2B73F) return true;
+
+    // CJK Extension D
+    if (cpt >= 0x2B740 && cpt <= 0x2B81F) return true;
+
+    // CJK Extension E
+    if (cpt >= 0x2B820 && cpt <= 0x2CEAF) return true;
+
+    // CJK Extension F
+    if (cpt >= 0x2CEB0 && cpt <= 0x2EBEF) return true;
+
+    // CJK Compatibility Ideographs
+    if (cpt >= 0xF900 && cpt <= 0xFAFF) return true;
+
+    // CJK Compatibility Ideographs Supplement
+    if (cpt >= 0x2F800 && cpt <= 0x2FA1F) return true;
+
+    return false;
+}
+
 std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
     // unicode categories
     static const std::map<std::string, int> k_ucat_enum = {
diff --git a/examples/talk-llama/unicode.h b/examples/talk-llama/unicode.h
index c27098df7d4..0a5fa2a78ce 100644
--- a/examples/talk-llama/unicode.h
+++ b/examples/talk-llama/unicode.h
@@ -63,4 +63,6 @@ uint8_t     unicode_utf8_to_byte(const std::string & utf8);
 
 uint32_t unicode_tolower(uint32_t cpt);
 
+bool unicode_cpt_is_han(uint32_t cpt);
+
 std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);
diff --git a/examples/whisper.wasm/README.md b/examples/whisper.wasm/README.md
index b267d3d242b..da629c5f3eb 100644
--- a/examples/whisper.wasm/README.md
+++ b/examples/whisper.wasm/README.md
@@ -52,6 +52,16 @@ cp bin/libmain.js        /path/to/html/
 cp bin/libmain.worker.js /path/to/html/
 ```
 
+> 📝 **Note:** By default this example is built with `WHISPER_WASM_SINGLE_FILE=ON`
+> which means that a separate .wasm file will not be generated. Instead, the
+> WASM module is embedded in the main JS file as a base64 encoded string. To
+> generate a separate .wasm file, you need to disable this option by passing
+> `-DWHISPER_WASM_SINGLE_FILE=OFF`:
+> ```console
+> emcmake cmake .. -DWHISPER_WASM_SINGLE_FILE=OFF
+> ```
+> This will generate a `libmain.wasm` file in the build/bin directory.
+
 > 📝 **Note:** As of Emscripten 3.1.58 (April 2024), separate worker.js files are no
 > longer generated and the worker is embedded in the main JS file. So the worker
 > file will not be generated for versions later than `3.1.58`.
diff --git a/examples/whisper.wasm/index-tmpl.html b/examples/whisper.wasm/index-tmpl.html
index 32fdf12f93e..d5f1be8929c 100644
--- a/examples/whisper.wasm/index-tmpl.html
+++ b/examples/whisper.wasm/index-tmpl.html
@@ -338,22 +338,22 @@
 
             function loadWhisper(model) {
                 let urls = {
-                    'tiny.en':  'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en.bin',
-                    'tiny':     'https://whisper.ggerganov.com/ggml-model-whisper-tiny.bin',
-                    'base.en':  'https://whisper.ggerganov.com/ggml-model-whisper-base.en.bin',
-                    'base':     'https://whisper.ggerganov.com/ggml-model-whisper-base.bin',
-                    'small.en': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en.bin',
-                    'small':    'https://whisper.ggerganov.com/ggml-model-whisper-small.bin',
-
-                    'tiny-en-q5_1':  'https://whisper.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin',
-                    'tiny-q5_1':     'https://whisper.ggerganov.com/ggml-model-whisper-tiny-q5_1.bin',
-                    'base-en-q5_1':  'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin',
-                    'base-q5_1':     'https://whisper.ggerganov.com/ggml-model-whisper-base-q5_1.bin',
-                    'small-en-q5_1': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en-q5_1.bin',
-                    'small-q5_1':    'https://whisper.ggerganov.com/ggml-model-whisper-small-q5_1.bin',
-                    'medium-en-q5_0':'https://whisper.ggerganov.com/ggml-model-whisper-medium.en-q5_0.bin',
-                    'medium-q5_0':   'https://whisper.ggerganov.com/ggml-model-whisper-medium-q5_0.bin',
-                    'large-q5_0':    'https://whisper.ggerganov.com/ggml-model-whisper-large-q5_0.bin',
+                    'tiny.en':  'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin',
+                    'tiny':     'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.bin',
+                    'base.en':  'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin',
+                    'base':     'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.bin',
+                    'small.en': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en.bin',
+                    'small':    'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.bin',
+
+                    'tiny-en-q5_1':  'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en-q5_1.bin',
+                    'tiny-q5_1':     'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny-q5_1.bin',
+                    'base-en-q5_1':  'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en-q5_1.bin',
+                    'base-q5_1':     'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base-q5_1.bin',
+                    'small-en-q5_1': 'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en-q5_1.bin',
+                    'small-q5_1':    'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small-q5_1.bin',
+                    'medium-en-q5_0':'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.en-q5_0.bin',
+                    'medium-q5_0':   'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium-q5_0.bin',
+                    'large-q5_0':    'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-q5_0.bin',
                 };
 
                 let sizes = {
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index eaba9c70469..90e274ccdbc 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -39,8 +39,9 @@ if (WIN32)
     set(CMAKE_SHARED_MODULE_PREFIX  "")
 endif()
 
-option(BUILD_SHARED_LIBS "ggml: build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT})
-option(GGML_BACKEND_DL   "ggml: build backends as dynamic libraries (requires BUILD_SHARED_LIBS)" OFF)
+option(BUILD_SHARED_LIBS           "ggml: build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT})
+option(GGML_BACKEND_DL             "ggml: build backends as dynamic libraries (requires BUILD_SHARED_LIBS)" OFF)
+set(GGML_BACKEND_DIR "" CACHE PATH "ggml: directory to load dynamic backends from (requires GGML_BACKEND_DL)")
 
 #
 # option list
@@ -131,7 +132,7 @@ option(GGML_RVV              "ggml: enable rvv"              ON)
 option(GGML_RV_ZFH           "ggml: enable riscv zfh"        OFF)
 option(GGML_XTHEADVECTOR     "ggml: enable xtheadvector"     OFF)
 option(GGML_VXE              "ggml: enable vxe"              ON)
-option(GGML_NNPA             "ggml: enable nnpa"             ON)
+option(GGML_NNPA             "ggml: enable nnpa"             OFF)  # temp disabled by default, see: https://github.com/ggml-org/llama.cpp/issues/14877
 
 option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
 set(GGML_CPU_ARM_ARCH        "" CACHE STRING "ggml: CPU architecture for ARM")
@@ -174,6 +175,10 @@ option(GGML_HIP_GRAPHS                      "ggml: use HIP graph, experimental,
 option(GGML_HIP_NO_VMM                      "ggml: do not try to use HIP VMM"                 ON)
 option(GGML_HIP_ROCWMMA_FATTN               "ggml: enable rocWMMA for FlashAttention"         OFF)
 option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12   "ggml: enable rocWMMA FlashAttention on GFX12"    OFF)
+option(GGML_HIP_MMQ_MFMA                    "ggml: enable MFMA MMA for CDNA in MMQ"           ON)
+option(GGML_HIP_EXPORT_METRICS              "ggml: enable kernel perf metrics output"         OFF)
+option(GGML_MUSA_GRAPHS                     "ggml: use MUSA graph, experimental, unstable"    OFF)
+option(GGML_MUSA_MUDNN_COPY                 "ggml: enable muDNN for accelerated copy"         OFF)
 option(GGML_VULKAN                          "ggml: use Vulkan"                                OFF)
 option(GGML_VULKAN_CHECK_RESULTS            "ggml: run Vulkan op checks"                      OFF)
 option(GGML_VULKAN_DEBUG                    "ggml: enable Vulkan debug output"                OFF)
@@ -181,6 +186,9 @@ option(GGML_VULKAN_MEMORY_DEBUG             "ggml: enable Vulkan memory debug ou
 option(GGML_VULKAN_SHADER_DEBUG_INFO        "ggml: enable Vulkan shader debug info"           OFF)
 option(GGML_VULKAN_VALIDATE                 "ggml: enable Vulkan validation"                  OFF)
 option(GGML_VULKAN_RUN_TESTS                "ggml: run Vulkan tests"                          OFF)
+option(GGML_WEBGPU                          "ggml: use WebGPU"                                OFF)
+option(GGML_WEBGPU_DEBUG                    "ggml: enable WebGPU debug output"                OFF)
+option(GGML_ZDNN                            "ggml: use zDNN"                                  OFF)
 option(GGML_METAL                           "ggml: use Metal"                                 ${GGML_METAL_DEFAULT})
 option(GGML_METAL_USE_BF16                  "ggml: use bfloat if available"                   OFF)
 option(GGML_METAL_NDEBUG                    "ggml: disable Metal debugging"                   OFF)
@@ -270,6 +278,7 @@ set(GGML_PUBLIC_HEADERS
     include/ggml-rpc.h
     include/ggml-sycl.h
     include/ggml-vulkan.h
+    include/ggml-webgpu.h
     include/gguf.h)
 
 set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
diff --git a/ggml/cmake/ggml-config.cmake.in b/ggml/cmake/ggml-config.cmake.in
index 8c2dc31c6da..91c9d5cd343 100644
--- a/ggml/cmake/ggml-config.cmake.in
+++ b/ggml/cmake/ggml-config.cmake.in
@@ -1,152 +1,191 @@
-
-@GGML_VARIABLES_EXPANDED@
-
 @PACKAGE_INIT@
 
-set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@")
-set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@")
-#set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@")
-
-find_package(Threads REQUIRED)
-
-find_library(GGML_LIBRARY ggml
-    REQUIRED
-    HINTS ${GGML_LIB_DIR}
-    NO_CMAKE_FIND_ROOT_PATH)
-
-add_library(ggml::ggml UNKNOWN IMPORTED)
-set_target_properties(ggml::ggml
-    PROPERTIES
-        IMPORTED_LOCATION "${GGML_LIBRARY}")
-
-find_library(GGML_BASE_LIBRARY ggml-base
-    REQUIRED
-    HINTS ${GGML_LIB_DIR}
-    NO_CMAKE_FIND_ROOT_PATH)
-
-add_library(ggml::ggml-base UNKNOWN IMPORTED)
-set_target_properties(ggml::ggml-base
-    PROPERTIES
-        IMPORTED_LOCATION "${GGML_BASE_LIBRARY}")
+@GGML_VARIABLES_EXPANDED@
 
+# Find all dependencies before creating any target.
+include(CMakeFindDependencyMacro)
+find_dependency(Threads)
 if (NOT GGML_SHARED_LIB)
+    set(GGML_CPU_INTERFACE_LINK_LIBRARIES "")
+    set(GGML_CPU_INTERFACE_LINK_OPTIONS   "")
+
     if (APPLE AND GGML_ACCELERATE)
-        find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED)
+        find_library(ACCELERATE_FRAMEWORK Accelerate)
+        if(NOT ACCELERATE_FRAMEWORK)
+            set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)
+            return()
+        endif()
         list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${ACCELERATE_FRAMEWORK})
     endif()
 
-    if (GGML_OPENMP)
-        find_package(OpenMP REQUIRED)
+    if (GGML_OPENMP_ENABLED)
+        find_dependency(OpenMP)
         list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
     endif()
 
     if (GGML_CPU_HBM)
-        find_library(memkind memkind REQUIRED)
+        find_library(memkind memkind)
+        if(NOT memkind)
+            set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)
+            return()
+        endif()
         list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES memkind)
     endif()
 
     if (GGML_BLAS)
-        find_package(BLAS REQUIRED)
-        list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES})
-        list(APPEND GGML_CPU_INTERFACE_LINK_OPTIONS   ${BLAS_LINKER_FLAGS})
+        find_dependency(BLAS)
+        list(APPEND GGML_BLAS_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES})
+        list(APPEND GGML_BLAS_INTERFACE_LINK_OPTIONS   ${BLAS_LINKER_FLAGS})
     endif()
 
     if (GGML_CUDA)
-        find_package(CUDAToolkit REQUIRED)
+        set(GGML_CUDA_INTERFACE_LINK_LIBRARIES "")
+        find_dependency(CUDAToolkit)
+        if (GGML_STATIC)
+            list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cudart_static>)
+            if (WIN32)
+                list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cublas> $<LINK_ONLY:CUDA::cublasLt>)
+            else()
+                list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cublas_static> $<LINK_ONLY:CUDA::cublasLt_static>)
+            endif()
+        endif()
+        if (NOT GGML_CUDA_NO_VMM)
+            list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cuda_driver>)
+        endif()
     endif()
 
     if (GGML_METAL)
-        find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
-        find_library(METAL_FRAMEWORK    Metal REQUIRED)
-        find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
+        find_library(FOUNDATION_LIBRARY Foundation)
+        find_library(METAL_FRAMEWORK    Metal)
+        find_library(METALKIT_FRAMEWORK MetalKit)
+        if(NOT FOUNDATION_LIBRARY OR NOT METAL_FRAMEWORK OR NOT METALKIT_FRAMEWORK)
+            set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)
+            return()
+        endif()
+        set(GGML_METAL_INTERFACE_LINK_LIBRARIES
+            ${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})
+    endif()
 
-        list(APPEND GGML_METAL_INTERFACE_LINK_LIBRARIES
-                    ${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})
+    if (GGML_OPENCL)
+        find_dependency(OpenCL)
+        set(GGML_OPENCL_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:OpenCL::OpenCL>)
     endif()
 
     if (GGML_VULKAN)
-        find_package(Vulkan REQUIRED)
-        list(APPEND GGML_VULKAN_INTERFACE_LINK_LIBRARIES Vulkan::Vulkan)
+        find_dependency(Vulkan)
+        set(GGML_VULKAN_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:Vulkan::Vulkan>)
     endif()
 
     if (GGML_HIP)
-        find_package(hip     REQUIRED)
-        find_package(hipblas REQUIRED)
-        find_package(rocblas REQUIRED)
-        list(APPEND GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas)
+        find_dependency(hip)
+        find_dependency(hipblas)
+        find_dependency(rocblas)
+        set(GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas)
     endif()
 
     if (GGML_SYCL)
+        set(GGML_SYCL_INTERFACE_LINK_LIBRARIES "")
         find_package(DNNL)
         if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
             list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES DNNL::dnnl)
         endif()
         if (WIN32)
-            find_package(IntelSYCL REQUIRED)
-            find_package(MKL       REQUIRED)
+            find_dependency(IntelSYCL)
+            find_dependency(MKL)
             list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
         endif()
     endif()
 endif()
 
-set(_ggml_all_targets "")
-foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
-    string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}")
-    string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx)
+set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@")
+set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@")
+#set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@")
 
-    find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend}
+if(NOT TARGET ggml::ggml)
+    find_package(Threads REQUIRED)
+
+    find_library(GGML_LIBRARY ggml
         REQUIRED
         HINTS ${GGML_LIB_DIR}
         NO_CMAKE_FIND_ROOT_PATH)
 
-    message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}")
+    add_library(ggml::ggml UNKNOWN IMPORTED)
+    set_target_properties(ggml::ggml
+        PROPERTIES
+            IMPORTED_LOCATION "${GGML_LIBRARY}")
+
+    find_library(GGML_BASE_LIBRARY ggml-base
+        REQUIRED
+        HINTS ${GGML_LIB_DIR}
+        NO_CMAKE_FIND_ROOT_PATH)
 
-    add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED)
-    set_target_properties(ggml::${_ggml_backend}
+    add_library(ggml::ggml-base UNKNOWN IMPORTED)
+    set_target_properties(ggml::ggml-base
         PROPERTIES
-            INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}"
-            IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
-            IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}"
-            INTERFACE_COMPILE_FEATURES c_std_90
-            POSITION_INDEPENDENT_CODE ON)
-
-    string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}")
-    if(is_cpu_variant)
-        list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
-        set_target_properties(ggml::${_ggml_backend}
-           PROPERTIES
-               INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}")
-
-        if(GGML_CPU_INTERFACE_LINK_OPTIONS)
-            set_target_properties(ggml::${_ggml_backend}
-                PROPERTIES
-                    INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}")
-        endif()
+            IMPORTED_LOCATION "${GGML_BASE_LIBRARY}")
+
+    set(_ggml_all_targets "")
+    if (NOT GGML_BACKEND_DL)
+        foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
+            string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}")
+            string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx)
 
-    else()
-        list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
-        set_target_properties(ggml::${_ggml_backend}
-            PROPERTIES
-                INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}")
+            find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend}
+                REQUIRED
+                HINTS ${GGML_LIB_DIR}
+                NO_CMAKE_FIND_ROOT_PATH)
 
-        if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS)
+            message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}")
+
+            add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED)
             set_target_properties(ggml::${_ggml_backend}
                 PROPERTIES
-                    INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}")
-        endif()
+                    INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}"
+                    IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
+                    IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}"
+                    INTERFACE_COMPILE_FEATURES c_std_90
+                    POSITION_INDEPENDENT_CODE ON)
+
+            string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}")
+            if(is_cpu_variant)
+                list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
+                set_target_properties(ggml::${_ggml_backend}
+                PROPERTIES
+                    INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}")
+
+                if(GGML_CPU_INTERFACE_LINK_OPTIONS)
+                    set_target_properties(ggml::${_ggml_backend}
+                        PROPERTIES
+                            INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}")
+                endif()
+
+            else()
+                list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
+                set_target_properties(ggml::${_ggml_backend}
+                    PROPERTIES
+                        INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}")
+
+                if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS)
+                    set_target_properties(ggml::${_ggml_backend}
+                        PROPERTIES
+                            INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}")
+                endif()
+            endif()
+
+            list(APPEND _ggml_all_targets ggml::${_ggml_backend})
+        endforeach()
     endif()
 
-    list(APPEND _ggml_all_targets ggml::${_ggml_backend})
-endforeach()
+    list(APPEND GGML_INTERFACE_LINK_LIBRARIES ggml::ggml-base "${_ggml_all_targets}")
+    set_target_properties(ggml::ggml
+        PROPERTIES
+            INTERFACE_LINK_LIBRARIES "${GGML_INTERFACE_LINK_LIBRARIES}")
 
-list(APPEND GGML_INTERFACE_LINK_LIBRARIES ggml::ggml-base "${_ggml_all_targets}")
-set_target_properties(ggml::ggml
-    PROPERTIES
-        INTERFACE_LINK_LIBRARIES "${GGML_INTERFACE_LINK_LIBRARIES}")
+    add_library(ggml::all INTERFACE IMPORTED)
+    set_target_properties(ggml::all
+        PROPERTIES
+            INTERFACE_LINK_LIBRARIES "${_ggml_all_targets}")
 
-add_library(ggml::all INTERFACE IMPORTED)
-set_target_properties(ggml::all
-    PROPERTIES
-        INTERFACE_LINK_LIBRARIES "${_ggml_all_targets}")
+endif()
 
 check_required_components(ggml)
diff --git a/ggml/include/ggml-kompute.h b/ggml/include/ggml-kompute.h
deleted file mode 100644
index 154aa56a742..00000000000
--- a/ggml/include/ggml-kompute.h
+++ /dev/null
@@ -1,50 +0,0 @@
-#pragma once
-
-#include "ggml.h"
-#include "ggml-backend.h"
-
-#include <stdbool.h>
-#include <stddef.h>
-#include <stdint.h>
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-#define GGML_KOMPUTE_MAX_DEVICES 16
-
-struct ggml_vk_device {
-    int index;
-    int type; // same as VkPhysicalDeviceType
-    size_t heapSize;
-    const char * name;
-    const char * vendor;
-    int subgroupSize;
-    uint64_t bufferAlignment;
-    uint64_t maxAlloc;
-};
-
-struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);
-bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name);
-bool ggml_vk_has_vulkan(void);
-bool ggml_vk_has_device(void);
-struct ggml_vk_device ggml_vk_current_device(void);
-
-//
-// backend API
-//
-
-// forward declaration
-typedef struct ggml_backend * ggml_backend_t;
-
-GGML_BACKEND_API ggml_backend_t ggml_backend_kompute_init(int device);
-
-GGML_BACKEND_API bool ggml_backend_is_kompute(ggml_backend_t backend);
-
-GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
-
-GGML_BACKEND_API ggml_backend_reg_t ggml_backend_kompute_reg(void);
-
-#ifdef __cplusplus
-}
-#endif
diff --git a/ggml/include/ggml-opt.h b/ggml/include/ggml-opt.h
index 74ec080a055..4703a05afe1 100644
--- a/ggml/include/ggml-opt.h
+++ b/ggml/include/ggml-opt.h
@@ -74,16 +74,26 @@ extern "C" {
         GGML_OPT_BUILD_TYPE_OPT     = 30,
     };
 
+    enum ggml_opt_optimizer_type {
+        GGML_OPT_OPTIMIZER_TYPE_ADAMW,
+        GGML_OPT_OPTIMIZER_TYPE_SGD,
+
+        GGML_OPT_OPTIMIZER_TYPE_COUNT
+    };
+
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct ggml_opt_optimizer_params {
-        // AdamW optimizer parameters
         struct {
             float alpha; // learning rate
-            float beta1;
-            float beta2;
+            float beta1; // first AdamW momentum
+            float beta2; // second AdamW momentum
             float eps;   // epsilon for numerical stability
-            float wd;    // weight decay for AdamW, use 0.0f to disable
+            float wd;    // weight decay - 0.0f to disable
         } adamw;
+        struct {
+            float alpha; // learning rate
+            float wd;    // weight decay
+        } sgd;
     };
 
     // callback to calculate optimizer parameters prior to a backward pass
@@ -112,8 +122,11 @@ extern "C" {
 
         int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
 
-        ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
-        void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+        ggml_opt_get_optimizer_params get_opt_pars;    // callback for calculating optimizer parameters
+        void *                        get_opt_pars_ud; // userdata for calculating optimizer parameters
+
+        // only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
+        enum ggml_opt_optimizer_type optimizer;
     };
 
     // get parameters for an optimization context with defaults set where possible
@@ -142,6 +155,10 @@ extern "C" {
     // get the gradient accumulator for a node from the forward graph
     GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
 
+    GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t); //TODO consistent naming scheme
+
+    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
+
     // ====== Optimization Result ======
 
     GGML_API ggml_opt_result_t ggml_opt_result_init(void);
@@ -226,12 +243,14 @@ extern "C" {
             struct ggml_tensor            * outputs,        // output tensor, must have shape [ne_label, ndata_batch] if labels are used
             ggml_opt_dataset_t              dataset,        // dataset with data and optionally also labels
             enum ggml_opt_loss_type         loss_type,      // loss to minimize
+            enum ggml_opt_optimizer_type    optimizer,      // sgd or adamw
             ggml_opt_get_optimizer_params   get_opt_pars,   // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
             int64_t                         nepoch,         // how many times the dataset should be iterated over
             int64_t                         nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
             float                           val_split,      // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
             bool                            silent);        // whether or not info prints to stderr should be suppressed
 
+
 #ifdef  __cplusplus
 }
 #endif
diff --git a/ggml/include/ggml-webgpu.h b/ggml/include/ggml-webgpu.h
new file mode 100644
index 00000000000..65b8ed9bb66
--- /dev/null
+++ b/ggml/include/ggml-webgpu.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+#define GGML_WEBGPU_NAME "WebGPU"
+
+// Needed for examples in ggml
+GGML_BACKEND_API ggml_backend_t ggml_backend_webgpu_init(void);
+
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_webgpu_reg(void);
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/ggml/include/ggml-zdnn.h b/ggml/include/ggml-zdnn.h
new file mode 100644
index 00000000000..c2c30c977c3
--- /dev/null
+++ b/ggml/include/ggml-zdnn.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+GGML_BACKEND_API ggml_backend_t ggml_backend_zdnn_init(void);
+
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_zdnn_reg(void);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 8a8775be365..da8813fd278 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -241,6 +241,8 @@
 #define GGML_ROPE_TYPE_MROPE  8
 #define GGML_ROPE_TYPE_VISION 24
 
+#define GGML_MROPE_SECTIONS   4
+
 #define GGML_UNUSED(x) (void)(x)
 
 #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
@@ -304,6 +306,16 @@
     GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
     GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
 
+#define GGML_TENSOR_TERNARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb2, src2, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
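For reference (not part of this patch), a minimal sketch of how a ternary op kernel could use the new locals macro. The fused multiply-add kernel, its shapes, and the assumption that all operands share the same shape with a contiguous innermost dimension are purely illustrative:

    #include "ggml.h"

    // Hypothetical kernel: dst = src0*src1 + src2 (F32, all operands same shape,
    // innermost dimension contiguous, i.e. every nb?0 == sizeof(float)).
    static void example_ternary_f32(const struct ggml_tensor * src0,
                                    const struct ggml_tensor * src1,
                                    const struct ggml_tensor * src2,
                                    struct ggml_tensor       * dst) {
        // expands to local ne00..ne03/nb00..nb03 (src0), ne10../nb10.. (src1),
        // ne20../nb20.. (src2) and ne0..ne3/nb0..nb3 (dst)
        GGML_TENSOR_TERNARY_OP_LOCALS

        for (int64_t i3 = 0; i3 < ne3; i3++) {
            for (int64_t i2 = 0; i2 < ne2; i2++) {
                for (int64_t i1 = 0; i1 < ne1; i1++) {
                    const float * a = (const float *)((const char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
                    const float * b = (const float *)((const char *) src1->data + i1*nb11 + i2*nb12 + i3*nb13);
                    const float * c = (const float *)((const char *) src2->data + i1*nb21 + i2*nb22 + i3*nb23);
                    float       * d = (float       *)((char       *) dst->data  + i1*nb1  + i2*nb2  + i3*nb3);
                    for (int64_t i0 = 0; i0 < ne0; i0++) {
                        d[i0] = a[i0]*b[i0] + c[i0];
                    }
                }
            }
        }
    }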
 #define GGML_TENSOR_BINARY_OP_LOCALS01 \
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
     GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
@@ -395,7 +407,8 @@ extern "C" {
         // GGML_TYPE_IQ4_NL_4_4 = 36,
         // GGML_TYPE_IQ4_NL_4_8 = 37,
         // GGML_TYPE_IQ4_NL_8_8 = 38,
-        GGML_TYPE_COUNT   = 39,
+        GGML_TYPE_MXFP4   = 39, // MXFP4 (1 block)
+        GGML_TYPE_COUNT   = 40,
     };
 
     // precision
@@ -430,6 +443,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ4_XS  = 22, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ1_M   = 23, // except 1d tensors
         GGML_FTYPE_MOSTLY_BF16    = 24, // except 1d tensors
+        GGML_FTYPE_MOSTLY_MXFP4   = 25, // except 1d tensors
     };
 
     // available tensor operations:
@@ -438,6 +452,7 @@ extern "C" {
 
         GGML_OP_DUP,
         GGML_OP_ADD,
+        GGML_OP_ADD_ID,
         GGML_OP_ADD1,
         GGML_OP_ACC,
         GGML_OP_SUB,
@@ -527,6 +542,7 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
+        GGML_OP_OPT_STEP_SGD,
 
         GGML_OP_GLU,
 
@@ -557,6 +573,7 @@ extern "C" {
         GGML_GLU_OP_REGLU,
         GGML_GLU_OP_GEGLU,
         GGML_GLU_OP_SWIGLU,
+        GGML_GLU_OP_SWIGLU_OAI,
         GGML_GLU_OP_GEGLU_ERF,
         GGML_GLU_OP_GEGLU_QUICK,
 
@@ -831,6 +848,13 @@ extern "C" {
             struct ggml_tensor  * b,
             enum   ggml_type      type);
 
+    // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
+    GGML_API struct ggml_tensor * ggml_add_id(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * ids);
+
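A hypothetical construction sketch for the new ggml_add_id op, following the comment above; the concrete shapes and the I32 type for the index tensor are assumptions, not taken from this patch:

    #include "ggml.h"

    // per-row bias selection: one bias row from b is added to each (token, sequence) column of a
    static struct ggml_tensor * build_add_id_example(struct ggml_context * ctx) {
        const int64_t n_embd = 8, n_tokens = 4, n_seq = 2, n_bias = 3;

        struct ggml_tensor * a   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_tokens, n_seq);
        struct ggml_tensor * b   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_bias);   // bias table
        struct ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_tokens, n_seq);  // row selector

        // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
        return ggml_add_id(ctx, a, b, ids);
    }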
     GGML_API struct ggml_tensor * ggml_add1(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1198,6 +1222,13 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_swiglu_oai(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            float                 alpha,
+            float                 limit);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
@@ -1570,6 +1601,10 @@ extern "C" {
             float                 scale,
             float                 max_bias);
 
+    GGML_API void ggml_soft_max_add_sinks(
+            struct ggml_tensor * a,
+            struct ggml_tensor * sinks);
+
     GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1628,7 +1663,7 @@ extern "C" {
             struct ggml_tensor  * b,
             struct ggml_tensor  * c,
             int                   n_dims,
-            int                   sections[4],
+            int                   sections[GGML_MROPE_SECTIONS],
             int                   mode,
             int                   n_ctx_orig,
             float                 freq_base,
@@ -1654,6 +1689,22 @@ extern "C" {
             float                 beta_fast,
             float                 beta_slow);
 
+    GGML_API struct ggml_tensor * ggml_rope_multi_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            int                   n_dims,
+            int                   sections[GGML_MROPE_SECTIONS],
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
     GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -2052,6 +2103,10 @@ extern "C" {
     GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
             const struct ggml_tensor * a);
 
+    GGML_API void ggml_flash_attn_ext_add_sinks(
+            struct ggml_tensor * a,
+            struct ggml_tensor * sinks);
+
     // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
            struct ggml_context * ctx,
@@ -2257,7 +2312,14 @@ extern "C" {
             struct ggml_tensor  * grad,
             struct ggml_tensor  * m,
             struct ggml_tensor  * v,
-            struct ggml_tensor  * adamw_params); // parameters such a the learning rate
+            struct ggml_tensor  * adamw_params); // parameters such as the learning rate
+
+    // stochastic gradient descent step (with weight decay)
+    GGML_API struct ggml_tensor * ggml_opt_step_sgd(
+        struct ggml_context * ctx,
+        struct ggml_tensor *  a,
+        struct ggml_tensor *  grad,
+        struct ggml_tensor *  sgd_params); // alpha, weight decay
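A minimal CPU sketch of the elementwise update this op is presumably expected to perform, assuming sgd_params carries { alpha, weight decay } and that the decay is applied in decoupled (AdamW-style) form; the actual kernels live in the backend implementations:

    #include <stddef.h>

    // hypothetical reference, not the backend kernel: decay the weight, then take
    // a plain gradient step
    static void sgd_step_ref(float * x, const float * grad, size_t n, float alpha, float wd) {
        const float keep = 1.0f - alpha * wd;
        for (size_t i = 0; i < n; i++) {
            x[i] = x[i] * keep - alpha * grad[i];
        }
    }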
 
     //
     // automatic differentiation
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 8760c2d35ec..2b5b8169d75 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -214,6 +214,13 @@ add_library(ggml
             ggml-backend-reg.cpp)
 add_library(ggml::ggml ALIAS ggml)
 
+if (GGML_BACKEND_DIR)
+    if (NOT GGML_BACKEND_DL)
+        message(FATAL_ERROR "GGML_BACKEND_DIR requires GGML_BACKEND_DL")
+    endif()
+    target_compile_definitions(ggml PUBLIC GGML_BACKEND_DIR="${GGML_BACKEND_DIR}")
+endif()
+
 target_link_libraries(ggml PUBLIC ggml-base)
 
 if (CMAKE_SYSTEM_NAME MATCHES "Linux")
@@ -227,7 +234,11 @@ function(ggml_add_backend_library backend)
         set_target_properties(${backend} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
         target_compile_definitions(${backend} PRIVATE GGML_BACKEND_DL)
         add_dependencies(ggml ${backend})
-        install(TARGETS ${backend} LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR})
+        if (GGML_BACKEND_DIR)
+            install(TARGETS ${backend} LIBRARY DESTINATION ${GGML_BACKEND_DIR})
+        else()
+            install(TARGETS ${backend} LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR})
+        endif()
     else()
         add_library(${backend} ${ARGN})
         target_link_libraries(ggml PUBLIC ${backend})
@@ -370,6 +381,8 @@ ggml_add_backend(MUSA)
 ggml_add_backend(RPC)
 ggml_add_backend(SYCL)
 ggml_add_backend(Vulkan)
+ggml_add_backend(WebGPU)
+ggml_add_backend(zDNN)
 ggml_add_backend(OpenCL)
 
 foreach (target ggml-base ggml)
diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
index 5fd379f6a94..8b6e6028361 100644
--- a/ggml/src/ggml-alloc.c
+++ b/ggml/src/ggml-alloc.c
@@ -22,21 +22,6 @@ static bool ggml_is_view(const struct ggml_tensor * t) {
     return t->view_src != NULL;
 }
 
-static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
-    if (a->type != b->type) {
-        return false;
-    }
-    for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        if (a->ne[i] != b->ne[i]) {
-            return false;
-        }
-        if (a->nb[i] != b->nb[i]) {
-            return false;
-        }
-    }
-    return true;
-}
-
 // ops that return true for this function must not use restrict pointers for their backend implementations
 static bool ggml_op_can_inplace(enum ggml_op op) {
     switch (op) {
@@ -44,6 +29,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
         case GGML_OP_DIAG_MASK_ZERO:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 042ea77aca7..5f02a710a14 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -45,6 +45,14 @@
 #include "ggml-vulkan.h"
 #endif
 
+#ifdef GGML_USE_WEBGPU
+#include "ggml-webgpu.h"
+#endif
+
+#ifdef GGML_USE_ZDNN
+#include "ggml-zdnn.h"
+#endif
+
 #ifdef GGML_USE_OPENCL
 #include "ggml-opencl.h"
 #endif
@@ -173,6 +181,12 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_VULKAN
         register_backend(ggml_backend_vk_reg());
 #endif
+#ifdef GGML_USE_WEBGPU
+        register_backend(ggml_backend_webgpu_reg());
+#endif
+#ifdef GGML_USE_ZDNN
+        register_backend(ggml_backend_zdnn_reg());
+#endif
 #ifdef GGML_USE_OPENCL
         register_backend(ggml_backend_opencl_reg());
 #endif
@@ -491,6 +505,9 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
 
     std::vector<fs::path> search_paths;
     if (user_search_path == nullptr) {
+#ifdef GGML_BACKEND_DIR
+        search_paths.push_back(fs::u8path(GGML_BACKEND_DIR));
+#endif
         // default search paths: executable directory, current directory
         search_paths.push_back(get_executable_path());
         search_paths.push_back(fs::current_path());
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 788861a365f..1b9d29e911f 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -352,21 +352,6 @@ ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
 
 // backend copy
 
-static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
-    if (a->type != b->type) {
-        return false;
-    }
-    for (int i = 0; i < GGML_MAX_DIMS; i++) {
-        if (a->ne[i] != b->ne[i]) {
-            return false;
-        }
-        if (a->nb[i] != b->nb[i]) {
-            return false;
-        }
-    }
-    return true;
-}
-
 void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
     GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
 
@@ -662,6 +647,7 @@ struct ggml_backend_sched {
     // pipeline parallelism support
     int n_copies;
     int cur_copy;
+    int next_copy;
     ggml_backend_event_t events[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES];
     struct ggml_tensor * graph_inputs[GGML_SCHED_MAX_SPLIT_INPUTS];
     int n_graph_inputs;
@@ -1085,6 +1071,11 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 }
             }
         }
+        // if the node is still unassigned, assign it to the first backend that supports it
+        for (int b = 0; b < sched->n_backends && *cur_backend_id == -1; b++) {
+            ggml_backend_sched_set_if_supported(sched, node, b, cur_backend_id);
+        }
+        GGML_ASSERT(*cur_backend_id != -1);
     }
 
     // pass 5: split graph, find tensors that need to be copied
@@ -1112,7 +1103,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
 
             const int node_backend_id = tensor_backend_id(node);
 
-            assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
+            GGML_ASSERT(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
 
             // check if we should start a new split based on the sources of the current node
             bool need_new_split = false;
@@ -1170,7 +1161,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
 
                 size_t src_id = hash_id(src);
                 const int src_backend_id = sched->hv_tensor_backend_ids[src_id];
-                assert(src_backend_id != -1); // all inputs should be assigned by now
+                GGML_ASSERT(src_backend_id != -1); // all inputs should be assigned by now
 
                 if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) {
                     if (tensor_id_copy(src_id, src_backend_id, 0) == NULL) {
@@ -1448,8 +1439,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
         }
     }
 
-    sched->cur_copy = (sched->cur_copy + 1) % sched->n_copies;
-
     return GGML_STATUS_SUCCESS;
 }
 
@@ -1550,10 +1539,10 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
 bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
     GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
 
-    ggml_backend_sched_split_graph(sched, measure_graph);
-
     ggml_backend_sched_synchronize(sched);
 
+    ggml_backend_sched_split_graph(sched, measure_graph);
+
     if (!ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) {
         return false;
     }
@@ -1565,6 +1554,10 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
 
 bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
     GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);
+    GGML_ASSERT(!sched->is_alloc);
+
+    sched->cur_copy = sched->next_copy;
+    sched->next_copy = (sched->next_copy + 1) % sched->n_copies;
 
     ggml_backend_sched_split_graph(sched, graph);
 
@@ -1605,7 +1598,7 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
         // if the graph is not already allocated, always use copy 0 after a synchronization
         // this ensures that during generation the same copy is used every time,
         // which avoids changes in the graph that could cause CUDA or other graphs to be disabled
-        sched->cur_copy = 0;
+        sched->next_copy = 0;
     }
 }
 
diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp
index ec158dfac6e..aeac2e57449 100644
--- a/ggml/src/ggml-blas/ggml-blas.cpp
+++ b/ggml/src/ggml-blas/ggml-blas.cpp
@@ -281,10 +281,10 @@ ggml_backend_t ggml_backend_blas_init(void) {
     ggml_backend_blas_context * ctx = new ggml_backend_blas_context;
 
     ggml_backend_t backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_blas_guid(),
-        /* .interface = */ blas_backend_i,
-        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
-        /* .context   = */ ctx,
+        /* .guid    = */ ggml_backend_blas_guid(),
+        /* .iface   = */ blas_backend_i,
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
+        /* .context = */ ctx,
     };
 
 #if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
diff --git a/ggml/src/ggml-cann/CMakeLists.txt b/ggml/src/ggml-cann/CMakeLists.txt
index 7742b39153f..aee5e7b06e5 100755
--- a/ggml/src/ggml-cann/CMakeLists.txt
+++ b/ggml/src/ggml-cann/CMakeLists.txt
@@ -31,6 +31,13 @@ string(REGEX MATCH "[0-9]+[a-zA-Z]" SOC_TYPE_MAJOR_SN "${SOC_VERSION}")
 set(SOC_TYPE_COMPILE_OPTION "ASCEND_${SOC_TYPE_MAJOR_SN}")
 string(TOUPPER ${SOC_TYPE_COMPILE_OPTION} SOC_TYPE_COMPILE_OPTION)
 message(STATUS "CANN: SOC_VERSION =  ${SOC_VERSION}")
+option(USE_ACL_GRAPH "Enable CANN graph execution (ACL graph mode)" OFF)
+
+if(USE_ACL_GRAPH AND (SOC_TYPE_MAJOR_SN STREQUAL "310P" OR SOC_TYPE_COMPILE_OPTION STREQUAL "ASCEND_310P"))
+    message(FATAL_ERROR
+        "CANN Graph (ACL graph mode) is not supported on 310P devices. "
+        "Please build with -DUSE_ACL_GRAPH=OFF or use a supported SOC.")
+endif()
 
 if (CANN_INSTALL_DIR)
     # Only Support Linux.
@@ -68,6 +75,13 @@ if (CANN_INSTALL_DIR)
 
     target_compile_definitions(ggml-cann PRIVATE "-D${SOC_TYPE_COMPILE_OPTION}")
 
+    if (USE_ACL_GRAPH)
+        target_compile_definitions(ggml-cann PRIVATE USE_ACL_GRAPH)
+        message(STATUS "CANN: USE_ACL_GRAPH is enabled.")
+    else()
+        message(STATUS "CANN: USE_ACL_GRAPH is disabled.")
+    endif()
+
     message(STATUS "CANN: CANN_INCLUDE_DIRS =  ${CANN_INCLUDE_DIRS}")
     message(STATUS "CANN: CANN_LIBRARIES =  ${CANN_LIBRARIES}")
 else()
diff --git a/ggml/src/ggml-cann/acl_tensor.cpp b/ggml/src/ggml-cann/acl_tensor.cpp
index f311864d486..8ffac31dd66 100755
--- a/ggml/src/ggml-cann/acl_tensor.cpp
+++ b/ggml/src/ggml-cann/acl_tensor.cpp
@@ -77,6 +77,8 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,
     for (int i = 0; i < final_dims; i++) {
         acl_storage_len += (acl_ne[i] - 1) * acl_stride[i];
     }
+    size_t elem_offset = offset / ggml_element_size(tensor);
+    acl_storage_len += elem_offset;
 
     // Reverse ne and stride.
     std::reverse(acl_ne, acl_ne + final_dims);
@@ -84,7 +86,7 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,
 
     aclTensor* acl_tensor = aclCreateTensor(
         acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
-        offset / ggml_element_size(tensor), format, &acl_storage_len, 1,
+        elem_offset, format, &acl_storage_len, 1,
         tensor->data);
 
     return acl_tensor;
diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 4d5c2c18252..259a2928b1f 100755
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -68,6 +68,8 @@
 #include 
 #include 
 #include 
+#include 
+#include 
 #include 
 
 #include 
@@ -99,7 +101,7 @@ void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclT
     }
 }
 
-void ggml_cann_unary_op(
+void ggml_cann_op_unary(
     std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
     ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src = dst->src[0];
@@ -111,6 +113,42 @@ void ggml_cann_unary_op(
     ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
+void ggml_cann_op_unary_gated(
+    std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
+    ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src0 = dst->src[0];
+    ggml_tensor* src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+    aclTensor *acl_src0 = nullptr, *acl_src1 = nullptr;
+    if(src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+
+        acl_src0 = ggml_cann_create_tensor(src0);
+        acl_src1 = ggml_cann_create_tensor(src1);
+    } else {
+        int64_t ne[] = {src0->ne[0] / 2, src0->ne[1], src0->ne[2], src0->ne[3]};
+        size_t nb[] = {src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]};
+        acl_src0 = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, 0);
+        acl_src1 = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, ne[0] * ggml_element_size(src0));
+        if (swapped) {
+            std::swap(acl_src0, acl_src1);
+        }
+    }
+
+    unary_op(ctx, acl_src0, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst, acl_src1);
+
+    ggml_cann_release_resources(ctx, acl_src0, acl_dst);
+    if(src1)
+        ggml_cann_release_resources(ctx, acl_src1);
+}
+
 /**
  * @brief Repeats elements of a tensor along each dimension according to the
  * specified repeat array.
@@ -715,69 +753,55 @@ static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src,
 void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src0 = dst->src[0];
 
-    aclTensor* acl_src = ggml_cann_create_tensor(src0);
-    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
     if (ggml_are_same_shape(src0, dst)) {
+        aclTensor* acl_src = ggml_cann_create_tensor(src0);
+        aclTensor* acl_dst = ggml_cann_create_tensor(dst);
         if (dst->type == src0->type) {
             cann_copy(ctx, acl_src, acl_dst);
         } else {
             aclnn_cast(ctx, acl_src, acl_dst, ggml_cann_type_mapping(dst->type));
         }
+        ggml_cann_release_resources(ctx, acl_src, acl_dst);
     } else {
-        if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
-            if (dst->type == src0->type) {
-                size_t cpy_size = ggml_nbytes(dst);
-                ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size,
-                    ACL_MEMCPY_DEVICE_TO_DEVICE);
-                return;
-            } else {
-                ggml_cann_pool_alloc src_buffer_allocator(
-                    ctx.pool(),
-                    ggml_nelements(dst) * ggml_type_size(dst->type));
-                void* src_trans_buffer = src_buffer_allocator.get();
-                size_t src_trans_nb[GGML_MAX_DIMS];
-                src_trans_nb[0] = ggml_type_size(dst->type);
-                for (int i = 1; i < GGML_MAX_DIMS; i++) {
-                    src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
-                }
-                aclTensor* src_trans_tensor = ggml_cann_create_tensor(
-                    src_trans_buffer, ggml_cann_type_mapping(dst->type),
-                    ggml_type_size(dst->type), src0->ne, src_trans_nb,
-                    GGML_MAX_DIMS);
-
-                aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
-                size_t cpy_size = ggml_nbytes(dst);
-                ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size,
-                    ACL_MEMCPY_DEVICE_TO_DEVICE);
-                ggml_cann_release_resources(ctx, src_trans_tensor);
-                return;
-            }
-        } else if (ggml_is_contiguous(dst)) {
-            ggml_cann_pool_alloc src_buffer_allocator(
-                ctx.pool(), ggml_nelements(dst) * ggml_type_size(dst->type));
-            void* src_trans_buffer = src_buffer_allocator.get();
+        void* src_trans_buffer = src0->data;
+        ggml_cann_pool_alloc src_buffer_allocator;
+        if (!ggml_is_contiguous(src0)) {
+            aclTensor* acl_src = ggml_cann_create_tensor(src0);
+            src_buffer_allocator.alloc(ctx.pool(),
+                ggml_nelements(src0) * ggml_type_size(src0->type));
+            src_trans_buffer = src_buffer_allocator.get();
             size_t src_trans_nb[GGML_MAX_DIMS];
-            src_trans_nb[0] = ggml_type_size(dst->type);
+            src_trans_nb[0] = ggml_type_size(src0->type);
             for (int i = 1; i < GGML_MAX_DIMS; i++) {
                 src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
             }
             aclTensor* src_trans_tensor = ggml_cann_create_tensor(
-                src_trans_buffer, ggml_cann_type_mapping(dst->type),
-                ggml_type_size(dst->type), src0->ne, src_trans_nb,
+                src_trans_buffer, ggml_cann_type_mapping(src0->type),
+                ggml_type_size(src0->type), src0->ne, src_trans_nb,
                 GGML_MAX_DIMS);
+            cann_copy(ctx, acl_src, src_trans_tensor);
+            ggml_cann_release_resources(ctx, acl_src, src_trans_tensor);
+        }
 
-            aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
+        size_t src_reshape_nb[GGML_MAX_DIMS];
+        src_reshape_nb[0] = ggml_type_size(src0->type);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            src_reshape_nb[i] = src_reshape_nb[i - 1] * dst->ne[i - 1];
+        }
 
-            size_t cpy_size = ggml_nbytes(dst);
-            ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size,
-                ACL_MEMCPY_DEVICE_TO_DEVICE);
-            ggml_cann_release_resources(ctx, src_trans_tensor);
-            return;
+        aclTensor* trans_acl_src = ggml_cann_create_tensor(src_trans_buffer,
+            ggml_cann_type_mapping(src0->type),ggml_type_size(src0->type),
+            dst->ne, src_reshape_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
+        aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+        if (dst->type == src0->type) {
+            cann_copy(ctx, trans_acl_src, acl_dst);
         } else {
-            GGML_ABORT("Unsupport dst is not tontiguous.");
+            aclnn_cast(ctx, trans_acl_src, acl_dst, ggml_cann_type_mapping(dst->type));
         }
+        ggml_cann_release_resources(ctx, trans_acl_src, acl_dst);
     }
-    ggml_cann_release_resources(ctx, acl_src, acl_dst);
+    return;
 }
 
 /**
@@ -1292,160 +1316,196 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
 }
 
 /**
- * @brief   Applies the Alibi (Attention with Linear Biases) mechanism to the
- * @details This function implements the Alibi mechanism, which introduces
- *          learnable biases into the attention scores to simulate relative
- *          position encoding without the need for explicit positional
- *          embeddings.
- *
- * @param ctx          The backend CANN context for executing operations.
- * @param acl_src      The source tensor representing the query or key.
- * @param acl_position The position tensor containing relative positions.
- * @param acl_dst      The destination tensor where the result will be stored.
- * @param n_head       The number of attention heads.
- * @param src_ne       The dimensions of the source tensor.
- * @param src_nb0      The byte size of the first dimension of the source
- tensor.
- * @param max_bias     The maximum bias value used in the Alibi mechanism.
- * @param dst          The destination tensor object for additional metadata.
- *
- * The function performs the following steps:
- * 1. Calculates the logarithm floor of the number of heads to determine the
-      base for bias calculation.
- * 2. Initializes arrays with arithmetic sequences and fills them with bias
-      values.
- * 3. Computes the bias tensor based on the calculated biases and arithmetic
-      sequences.
- * 4. Reshapes the bias tensor to match the dimensions of the input tensors.
- * 5. Multiplies the position tensor by the bias tensor.
- * 6. Adds the result of the multiplication to the source tensor to produce the
-      final output.
+ * @brief Generate a range of values and apply a scalar base exponentiation.
+ *
+ * This function creates an evenly spaced sequence from `start` to `stop` (exclusive),
+ * with step size `step`, stores it in a temporary buffer, and then computes:
+ *
+ * @f[
+ * slope[i] = m^{\left( start + i \cdot step \right)}, \quad 0 \le i < size
+ * @f]
+ *
+ * The results are written to the provided @p slope_buffer.
+ *
+ * @param ctx           CANN backend context for memory allocation and operator execution.
+ * @param slope_buffer  Pointer to the output buffer (float array) for the computed slope values.
+ * @param m             Scalar base for the exponentiation.
+ * @param size          Number of elements in the generated sequence.
+ * @param start         Starting exponent offset.
+ * @param stop          Stopping exponent offset (exclusive).
+ * @param step          Step size for the exponent increment.
  */
-static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
-                        aclTensor* acl_position, aclTensor* acl_dst,
-                        const int n_head, int64_t* src_ne, const size_t src_nb0,
-                        float max_bias, ggml_tensor* dst) {
-    const int64_t ne2_ne3 = src_ne[2] * src_ne[3];
-    GGML_ASSERT(src_nb0 == sizeof(float));
-    GGML_ASSERT(n_head == src_ne[2]);
-
-    const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
-
-    float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
-    float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
-    // init arange
-    ggml_cann_pool_alloc arange_allocator(ctx.pool(),
-                                          ne2_ne3 * ggml_type_size(dst->type));
-    void* tmp_arange_buffer = arange_allocator.get();
+static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
+    float m, int64_t size, float start, float stop, float step){
+    int64_t ne[] = {size};
+    size_t nb[] = {sizeof(float)};
 
-    // arange1: [1, ..., n_heads_log2_floor+1)
-    float start = 1;
-    float stop = n_heads_log2_floor + 1;
-    float step = 1;
-    int64_t n_elements_arange = n_heads_log2_floor;
+    ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float));
+    void* arange_buffer = arange_allocator.get();
 
-    int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
-    size_t tmp_arange1_nb[] = {sizeof(dst->type)};
-    aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
-        tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
-        ggml_type_size(dst->type), tmp_arange1_ne, tmp_arange1_nb,
-        GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+    aclTensor* arange_tensor = ggml_cann_create_tensor(
+        arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
+    aclnn_arange(ctx, arange_tensor, start, stop, step, size);
 
-    aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
-
-    aclTensor* tmp_arange2_tensor = nullptr;
-    if (n_heads_log2_floor < ne2_ne3) {
-        // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
-        start = 1;
-        stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
-        step = 2;
-        n_elements_arange = ne2_ne3 - n_heads_log2_floor;
-        int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
-        size_t tmp_arange2_nb[] = {sizeof(dst->type)};
-
-        aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
-            (char*)tmp_arange_buffer +
-                n_heads_log2_floor * ggml_type_size(dst->type),
-            ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
-            tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-        aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
-                     n_elements_arange);
-    }
+    aclTensor* slope_tensor = ggml_cann_create_tensor(
+        slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
 
-    // init mk_base
-    ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
-                                           ne2_ne3 * ggml_type_size(dst->type));
-    void* tmp_mk_base_buffer = mk_base_allocator.get();
-    int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
-    size_t tmp_mk_base1_nb[] = {sizeof(dst->type)};
-    aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
-        tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
-        ggml_type_size(dst->type), tmp_mk_base1_ne, tmp_mk_base1_nb,
-        GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+    aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);
 
-    aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
-
-    aclTensor* tmp_mk_base2_tensor = nullptr;
-    if (n_heads_log2_floor < ne2_ne3) {
-        int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
-        size_t tmp_mk_base2_nb[] = {sizeof(dst->type)};
-        aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
-            (char*)tmp_mk_base_buffer +
-                n_heads_log2_floor * ggml_type_size(dst->type),
-            ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
-            tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-        aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
-    }
+    GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, sc, arange_tensor, slope_tensor);
+    ggml_cann_release_resources(ctx, sc, arange_tensor, slope_tensor);
+}
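For clarity, a host-side equivalent of what aclnn_get_slope_inner computes on device (the device path materializes the arange and applies PowScalarTensor); this is only a reference sketch:

    #include <math.h>
    #include <stdint.h>

    // slope[i] = m^(start + i*step), for i in [0, size)
    static void get_slope_inner_ref(float * slope, float m, int64_t size, float start, float step) {
        for (int64_t i = 0; i < size; i++) {
            slope[i] = powf(m, start + (float) i * step);
        }
    }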
 
-    // init mk
-    int64_t tmp_mk_base_ne[] = {ne2_ne3};
-    size_t tmp_mk_base_nb[] = {sizeof(dst->type)};
-    aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
-        tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
-        ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
-        GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-    aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
-        tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
-        ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
-        GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-    aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
+/**
+ * @brief Compute slope values for multiple attention heads based on ALiBi bias parameters.
+ *
+ * This function generates slope values for each attention head according to the ALiBi
+ * (Attention with Linear Biases) method. It splits the computation into two ranges depending
+ * on whether the head index is less than @p n_head_log2 or not, and uses different base values
+ * (`m0` and `m1`) for the exponentiation.
+ *
+ * @f[
+ * slope[h] =
+ * \begin{cases}
+ * m_0^{(h + 1)}, & h < n\_head\_log2 \\
+ * m_1^{\left( 2 \cdot (h - n\_head\_log2) + 1 \right)}, & h \geq n\_head\_log2
+ * \end{cases}
+ * \quad , \quad \text{if } max\_bias > 0
+ * @f]
+ *
+ * If @p max_bias <= 0, all slope values are set to 1.0.
+ *
+ * @param ctx           CANN backend context for memory allocation and operator execution.
+ * @param n_head        Total number of attention heads.
+ * @param slope_buffer  Pointer to the output buffer (float array) for storing slopes.
+ * @param max_bias      Maximum bias value for slope computation.
+ *
+*/
+static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
+    void* slope_buffer, float max_bias) {
+    const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    float m0 = powf(2.0f, -(max_bias) / n_head_log2);
+    float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    // const float slope = (max_bias > 0.0f) ?
+    //                          h < n_head_log2 ?
+    //                              powf(m0, h + 1) :
+    //                              powf(m1, 2*(h - n_head_log2) + 1) :
+    //                          1.0f;
+    // arange1
+    float start = 0 + 1;
+    float end   = (n_head_log2 - 1) + 1;
+    float step  = 1;
+    float count = n_head_log2;
+    // end needs to be +1 because aclnn uses a left-closed, right-open interval.
+    aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step);
+    if (n_head_log2 < n_head) {
+        // arange2
+        start = 2 * (n_head_log2 - n_head_log2) + 1;
+        end   = 2 * ((n_head - 1) - n_head_log2) + 1;
+        step  = 2;
+        count = n_head - n_head_log2;
+        aclnn_get_slope_inner(
+            ctx, (char *) slope_buffer + n_head_log2 * sizeof(float),
+            m1, count, start, end + 1, step);
+    }
+}
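A CPU reference for the per-head slopes, matching the piecewise formula quoted in the comment above; the numeric example at the end is illustrative:

    #include <math.h>
    #include <stdint.h>

    // slope[h] = m0^(h+1) for h < n_head_log2, else m1^(2*(h - n_head_log2) + 1);
    // all 1.0f when max_bias <= 0 (as documented above)
    static void get_slopes_ref(float * slope, int64_t n_head, float max_bias) {
        const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
        const float m0 = powf(2.0f, -max_bias          / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        for (int64_t h = 0; h < n_head; h++) {
            slope[h] = max_bias > 0.0f
                ? (h < n_head_log2 ? powf(m0, (float) h + 1.0f)
                                   : powf(m1, 2.0f * (float)(h - n_head_log2) + 1.0f))
                : 1.0f;
        }
    }
    // e.g. n_head = 12, max_bias = 8: n_head_log2 = 8, m0 = 0.5, m1 ~ 0.7071;
    // heads 0..7 get 0.5^(h+1), heads 8..11 get m1^1, m1^3, m1^5, m1^7.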
 
-    // reshape mk
-    int64_t tmp_mk_ne[] = {1, 1, src_ne[2], src_ne[3]};
-    size_t tmp_mk_nb[GGML_MAX_DIMS];
-    tmp_mk_nb[0] = ggml_type_size(dst->type);
-    for (int i = 1; i < GGML_MAX_DIMS; i++) {
-        tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
+/**
+ * @brief Add ALiBi (Attention with Linear Biases) positional biases to the attention mask.
+ *
+ * This function computes the ALiBi slopes for each attention head (if max_bias > 0),
+ * multiplies them with the attention mask to produce bias tensors, and adds these biases
+ * to the destination tensor (@p dst).
+ *
+ * The function performs necessary broadcasting of the mask and slope tensors to match
+ * the shape of the destination tensor, then applies element-wise multiplication and addition
+ * using CANN operators.
+ *
+ * @param ctx         CANN backend context for memory management and operator execution.
+ * @param mask        Input attention mask tensor, assumed to be contiguous.
+ * @param dst         Destination tensor to which ALiBi biases will be added.
+ * @param dst_ptr     Pointer to the memory of the destination tensor.
+ * @param max_bias    Maximum bias value controlling the slope scaling.
+ *
+ * @note
+ * - Write data into dst_ptr using only the shape information of the dst tensor.
+ * - `GGML_MAX_DIMS + 2` is used to extend tensor dimensions for broadcasting.
+ */
+static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask,
+    ggml_tensor* dst, void* dst_ptr, float max_bias) {
+    void* slope_buffer = nullptr;
+    void* bias_buffer = nullptr;
+
+    if (max_bias > 0.0f) {
+        int64_t n_heads = dst->ne[2];
+        ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
+        slope_buffer = slope_allocator.get();
+        ggml_cann_pool_alloc bias_allocator(
+                    ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst));
+        bias_buffer = bias_allocator.get();
+        aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias);
     }
-    aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
-        tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
-        ggml_type_size(dst->type), tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
-        ACL_FORMAT_ND);
 
-    // acl_position * mk
-    int64_t tmp_output_ne[] = {src_ne[0], src_ne[1], src_ne[2], src_ne[3]};
-    size_t tmp_output_nb[GGML_MAX_DIMS];
-    tmp_output_nb[0] = ggml_type_size(dst->type);
-    for (int i = 1; i < GGML_MAX_DIMS; i++) {
-        tmp_output_nb[i] = tmp_output_nb[i - 1] * tmp_output_ne[i - 1];
+    // broadcast for mask, slope and dst
+    int64_t nr2 = dst->ne[2] / mask->ne[2];
+    int64_t nr3 = dst->ne[3] / mask->ne[3];
+
+    // broadcast the mask across rows
+    int64_t mask_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], 1, mask->ne[3], 1 };
+    size_t  mask_nb[] = {
+        mask->nb[0], mask->nb[1], mask->nb[2],
+        mask->nb[2], mask->nb[3], mask->nb[3]
+    };
+
+    int64_t dst_ne[] = { dst->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], nr3 };
+    size_t  dst_nb[] = {
+        dst->nb[0], dst->nb[1], dst->nb[2],
+        dst->nb[2], dst->nb[3], dst->nb[3]
+    };
+
+    // slope is a 1 dim tensor, slope.ne2 == dst.ne2
+    int64_t slope_ne[] = { 1, 1, mask->ne[2], nr2, 1, 1 };
+    size_t  slope_nb[GGML_MAX_DIMS + 2];
+    slope_nb[0] = sizeof(float);
+    for (int i = 1; i < GGML_MAX_DIMS + 2; i++) {
+        slope_nb[i] = slope_nb[i - 1] * slope_ne[i - 1];
     }
-    ggml_cann_pool_alloc output_allocator(ctx.pool(), ggml_nbytes(dst));
-    void* tmp_output_buffer = output_allocator.get();
-    aclTensor* tmp_output_tensor = ggml_cann_create_tensor(
-        tmp_output_buffer, ggml_cann_type_mapping(dst->type),
-        ggml_type_size(dst->type), tmp_output_ne, tmp_output_nb, GGML_MAX_DIMS,
-        ACL_FORMAT_ND);
-    aclnn_mul(ctx, acl_position, tmp_mk_tensor, tmp_output_tensor);
 
-    // add
-    aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst);
-    ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
-        tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
-        tmp_arange_tensor, tmp_mk_tensor, tmp_output_tensor);
+    aclTensor* acl_slope = ggml_cann_create_tensor(
+                            slope_buffer, ACL_FLOAT, sizeof(float),
+                            slope_ne, slope_nb, GGML_MAX_DIMS + 2);
+    aclTensor* acl_mask = ggml_cann_create_tensor(
+                            mask, mask_ne, mask_nb, GGML_MAX_DIMS + 2);
+
+    // write data into dst_ptr using only the shape information of the dst tensor.
+    aclTensor* acl_dst  = ggml_cann_create_tensor(
+                            dst_ptr, ggml_cann_type_mapping(dst->type),
+                            ggml_type_size(dst->type), dst_ne, dst_nb,
+                            GGML_MAX_DIMS + 2);
+
+    if (max_bias > 0.0f) {
+        int64_t bias_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], 1 };
+        size_t  bias_nb[GGML_MAX_DIMS + 2];
+        bias_nb[0] = sizeof(float);
+        for (int i = 1; i < GGML_MAX_DIMS + 2; i++) {
+            bias_nb[i] = bias_nb[i - 1] * bias_ne[i - 1];
+        }
+        aclTensor* bias_tensor = ggml_cann_create_tensor(
+                                    bias_buffer, ACL_FLOAT, sizeof(float),
+                                    bias_ne, bias_nb, GGML_MAX_DIMS + 2);
+
+        aclnn_mul(ctx, acl_slope, acl_mask, bias_tensor);
+        aclnn_add(ctx, acl_dst, bias_tensor);
+        ggml_cann_release_resources(ctx, bias_tensor);
+    } else {
+        aclnn_add(ctx, acl_dst, acl_mask);
+    }
+    ggml_cann_release_resources(ctx, acl_slope, acl_mask, acl_dst);
 }
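A shape-level sketch of the bias this function adds, assuming a contiguous [n_kv, n_q, n_head] layout and omitting the batch and broadcast bookkeeping handled by the 6-D views above; names are illustrative:

    #include <stdint.h>

    // bias[h][i][j] = slope[h] * mask[i][j]; dst += bias (max_bias > 0 case;
    // with max_bias <= 0 the mask is added as-is, slope == 1)
    static void add_alibi_ref(float * dst, const float * mask, const float * slope,
                              int64_t n_kv, int64_t n_q, int64_t n_head) {
        for (int64_t h = 0; h < n_head; h++) {
            for (int64_t i = 0; i < n_q; i++) {
                for (int64_t j = 0; j < n_kv; j++) {
                    dst[(h*n_q + i)*n_kv + j] += slope[h] * mask[i*n_kv + j];
                }
            }
        }
    }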
 
-void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     ggml_cann_dup(ctx, dst);
 }
 
@@ -1463,165 +1523,135 @@ void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  * @param acl_dst The destination tensor where the softmax results will be
  * stored.
  */
-static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
-                          int64_t dim, aclTensor* acl_dst) {
+static void aclnn_softmax(ggml_backend_cann_context & ctx,
+    aclTensor* acl_src, int64_t dim, aclTensor * acl_dst) {
     GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst);
 }
 
-void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     ggml_tensor* src0 = dst->src[0];
     ggml_tensor* src1 = dst->src[1];  // mask
 
     aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
-    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+    aclTensor* acl_dst  = ggml_cann_create_tensor(dst);
 
-    float scale = 1.0f;
+    float scale    = 1.0f;
     float max_bias = 0.0f;
 
-    memcpy(&scale, (float*)dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float*)dst->op_params + 1, sizeof(float));
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
     // input mul scale
     aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
+    ggml_cann_pool_alloc src_tensor_allocator(ctx.pool(), ggml_nbytes(src0));
+    void* src_tensor_buffer = src_tensor_allocator.get();
+    aclTensor* softmax_tensor = ggml_cann_create_tensor(
+        src_tensor_buffer, ggml_cann_type_mapping(src0->type),
+        ggml_element_size(src0), src0->ne, src0->nb, GGML_MAX_DIMS);
 
-    size_t n_bytes = ggml_nbytes(src0);
-    ggml_cann_pool_alloc mul_scale_allocator(ctx.pool(), n_bytes);
-    void* input_mul_scale_buffer = mul_scale_allocator.get();
-    aclTensor* acl_input_mul_scale_tensor = ggml_cann_create_tensor(
-        input_mul_scale_buffer, ACL_FLOAT, ggml_type_size(src0->type), src0->ne,
-        src0->nb, GGML_MAX_DIMS);
-
-    bool inplace = false;
-    aclnn_muls(ctx, acl_src0, scale, acl_input_mul_scale_tensor, inplace);
+    aclnn_muls(ctx, acl_src0, scale, softmax_tensor, false);
 
     // mask
-    aclTensor* acl_src1_fp32_tensor = nullptr;
-    aclTensor* tmp_mask_tensor = nullptr;
-    ggml_cann_pool_alloc src1_fp32_allocator(ctx.pool());
     if (src1) {
-        const bool use_f16 = src1->type == GGML_TYPE_F16;
-        if (use_f16) {
-            // cast to fp32
-            size_t n_bytes = ggml_nelements(src1) * sizeof(float_t);
-            size_t src1_fp32_nb[GGML_MAX_DIMS];
-            src1_fp32_nb[0] = sizeof(float_t);
-            for (int i = 1; i < GGML_MAX_DIMS; i++) {
-                src1_fp32_nb[i] = src1_fp32_nb[i - 1] * src1->ne[i - 1];
-            }
-            src1_fp32_allocator.alloc(n_bytes);
-            void* src1_fp32_buffer = src1_fp32_allocator.get();
-            acl_src1_fp32_tensor = ggml_cann_create_tensor(
-                src1_fp32_buffer, ACL_FLOAT, sizeof(float), src1->ne,
-                src1_fp32_nb, GGML_MAX_DIMS);
-            aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
-            aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT);
-            ggml_cann_release_resources(ctx, acl_src1);
-        } else {
-            acl_src1_fp32_tensor = ggml_cann_create_tensor(src1);
-        }
+        aclnn_add_alibi(ctx, src1, src0, src_tensor_buffer, max_bias);
+    }
+    // softmax
+    aclnn_softmax(ctx, softmax_tensor, 3, acl_dst);
+    ggml_cann_release_resources(ctx, acl_src0, acl_dst, acl_scale, softmax_tensor);
+}
 
-        // broadcast the mask across rows, only use ne11 of ne01 in mask
-        if (src1->ne[1] != src0->ne[1]) {
-            // mask shape: [1,1,ne11,ne10]
-            int64_t tmp_mask_ne[] = {src0->ne[0], src0->ne[1], 1, 1};
-            size_t tmp_mask_nb[GGML_MAX_DIMS];
-            tmp_mask_nb[0] = sizeof(float_t);
-            for (int i = 1; i < GGML_MAX_DIMS; i++) {
-                tmp_mask_nb[i] = tmp_mask_nb[i - 1] * tmp_mask_ne[i - 1];
-            }
-            tmp_mask_tensor = ggml_cann_create_tensor(
-                src1->data, ACL_FLOAT, sizeof(float), tmp_mask_ne, tmp_mask_nb,
-                GGML_MAX_DIMS, ACL_FORMAT_ND);
-        }
+/**
+ * @brief Performs index select operation on a 4D tensor using the CANN backend.
+ *
+ * This function applies the `IndexSelect` operation along a specific dimension
+ * of the source tensor (`src_buffer`) using the indices from the index tensor (`index`).
+ * It iterates over the last two dimensions of the source tensor, creates the corresponding
+ * CANN tensors for the source, index, and output slices, and executes the `IndexSelect`
+ * operation for each slice.
+ *
+ * @param ctx The context for CANN backend operations.
+ * @param src_buffer The source buffer containing the 4D input tensor data.
+ * @param src_ne The dimensions of the source tensor.
+ * @param src_nb The strides (byte offsets) of the source tensor.
+ * @param dst_buffer The destination buffer where the output tensor data will be written.
+ * @param dst_ne The dimensions of the destination tensor.
+ * @param dst_nb The strides (byte offsets) of the destination tensor.
+ * @param index The index tensor specifying the indices to select from the source tensor.
+ * @param type The data type of the source and destination tensors.
+ */
+static void aclnn_index_select_4d(ggml_backend_cann_context& ctx,
+                                void* src_buffer,int64_t* src_ne, size_t* src_nb,
+                                void* dst_buffer, int64_t* dst_ne, size_t* dst_nb,
+                                ggml_tensor* index, ggml_type type) {
+    for (int64_t i = 0; i < src_ne[3]; i++) {
+        for (int64_t j = 0; j < src_ne[2]; j++) {
+            // src
+            aclTensor* acl_src_tensor = ggml_cann_create_tensor(
+                (char*)src_buffer + i * src_nb[3] + j * src_nb[2],
+                ggml_cann_type_mapping(type), ggml_type_size(type),
+                src_ne, src_nb, 2);
 
-        // alibi
-        const int n_head = src0->ne[2];
-        const size_t src_nb0 = src0->nb[0];
-
-        n_bytes = ggml_nbytes(dst);
-        ggml_cann_pool_alloc output_allocator(ctx.pool(), n_bytes);
-        void* output_buffer = output_allocator.get();
-        aclTensor* alibi_output_tensor = ggml_cann_create_tensor(
-            output_buffer, ACL_FLOAT, ggml_type_size(dst->type), dst->ne,
-            dst->nb, GGML_MAX_DIMS);
-        if (max_bias <= 0.0f) {
-            // slope = 1.0
-            if (tmp_mask_tensor) {
-                aclnn_add(ctx, tmp_mask_tensor, acl_input_mul_scale_tensor,
-                          alibi_output_tensor);
-            } else {
-                aclnn_add(ctx, acl_src1_fp32_tensor, acl_input_mul_scale_tensor,
-                          alibi_output_tensor);
-            }
-        } else {
-            // slope != 1.0
-            if (tmp_mask_tensor) {
-                aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_mask_tensor,
-                            alibi_output_tensor, n_head, src0->ne, src_nb0,
-                            max_bias, dst);
-            } else {
-                aclnn_alibi(ctx, acl_input_mul_scale_tensor,
-                            acl_src1_fp32_tensor, alibi_output_tensor, n_head,
-                            src0->ne, src_nb0, max_bias, dst);
-            }
-        }
+            // index
+            aclTensor* acl_index = ggml_cann_create_tensor(
+                (char*)index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],
+                ggml_cann_type_mapping(index->type), ggml_element_size(index),
+                index->ne, index->nb, 1);
 
-        // softmax
-        aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst);
-        ggml_cann_release_resources(ctx, alibi_output_tensor);
-    } else {
-        aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst);
+            // out
+            aclTensor* acl_out = ggml_cann_create_tensor(
+                (char*)dst_buffer + i * dst_nb[3] + j * dst_nb[2],
+                ggml_cann_type_mapping(type), ggml_type_size(type),
+                dst_ne, dst_nb, 2);
+            GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, acl_src_tensor, 0, acl_index, acl_out);
+            ggml_cann_release_resources(ctx, acl_src_tensor, acl_index, acl_out);
+        }
     }
-
-    ggml_cann_release_resources(ctx, acl_src0, acl_src1_fp32_tensor, acl_dst,
-        acl_scale, acl_input_mul_scale_tensor, tmp_mask_tensor);
 }
 
 /**
- * @brief Performs embedding operation on a 4D tensor using the CANN backend.
+ * @brief Performs inplace index copy operation on a 4D tensor using the CANN backend.
  *
- * This function extracts slices from the source tensor (`src_buffer`),
- * index tensor (`index`), and destination tensor (`dst`), and performs an
- * embedding operation on them. The embedding operation is applied by iterating
- * over the last two dimensions of the source tensor, creating the necessary
- * tensors for the source, index, and output, and executing the embedding operation.
+ * This function applies the `IndexCopy` operation along a specific dimension of the
+ * destination tensor (`dst_buffer`) by copying elements from the source tensor (`src_buffer`)
+ * to positions specified by the index tensor (`index`).
+ * It iterates over the last two dimensions of the tensors, creates the corresponding
+ * CANN tensors for source, index, and destination slices, and performs the index copy
+ * operation for each slice.
  *
  * @param ctx The context for CANN backend operations.
- * @param src_buffer The source buffer holding the data for the source tensor.
+ * @param src_buffer The source buffer containing the 4D input tensor data to be copied.
  * @param src_ne The dimensions of the source tensor.
  * @param src_nb The strides (byte offsets) of the source tensor.
- * @param index The index tensor used in the embedding operation.
- * @param dst The destination tensor where the result will be stored.
+ * @param dst_buffer The destination buffer where values will be copied to.
+ * @param dst_ne The dimensions of the destination tensor.
+ * @param dst_nb The strides (byte offsets) of the destination tensor.
+ * @param index The index tensor specifying target positions in the destination tensor.
+ * @param type The data type of the source and destination tensors.
  */
-static void aclnn_embedding_4d(ggml_backend_cann_context& ctx, void* src_buffer,
-                            int64_t* src_ne, size_t* src_nb, ggml_tensor* index,
-                            ggml_tensor* dst) {
+static void aclnn_index_copy_4d(ggml_backend_cann_context& ctx,
+                                void* src_buffer,int64_t* src_ne, size_t* src_nb,
+                                void* dst_buffer, int64_t* dst_ne, size_t* dst_nb,
+                                ggml_tensor* index, ggml_type type) {
     for (int64_t i = 0; i < src_ne[3]; i++) {
         for (int64_t j = 0; j < src_ne[2]; j++) {
             // src
-            int64_t acl_src_ne[2] = {src_ne[0], src_ne[1]};
-            size_t acl_src_nb[2] = {src_nb[0], src_nb[1]};
             aclTensor* acl_src_tensor = ggml_cann_create_tensor(
                 (char*)src_buffer + i * src_nb[3] + j * src_nb[2],
-                ggml_cann_type_mapping(dst->type), ggml_element_size(dst),
-                acl_src_ne, acl_src_nb, 2);
+                ggml_cann_type_mapping(type), ggml_type_size(type),
+                src_ne, src_nb, 2);
 
             // index
-            int64_t acl_index_ne[1] = {index->ne[0]};
-            size_t acl_index_nb[1] = {index->nb[0]};
             aclTensor* acl_index = ggml_cann_create_tensor(
-                (char*)index->data + i * index->nb[2] + j * index->nb[1],
+                (char*)index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],
                 ggml_cann_type_mapping(index->type), ggml_element_size(index),
-                acl_index_ne, acl_index_nb, 1);
+                index->ne, index->nb, 1);
 
             // out
-            int64_t acl_out_ne[2] = {dst->ne[0], dst->ne[1]};
-            size_t acl_out_nb[2] = {dst->nb[0], dst->nb[1]};
             aclTensor* acl_out = ggml_cann_create_tensor(
-                (char*)dst->data + i * dst->nb[3] + j * dst->nb[2],
-                ggml_cann_type_mapping(dst->type), ggml_element_size(dst),
-                acl_out_ne, acl_out_nb, 2);
-            GGML_CANN_CALL_ACLNN_OP(ctx, Embedding, acl_src_tensor, acl_index, acl_out);
+                (char*)dst_buffer + i * dst_nb[3] + j * dst_nb[2],
+                ggml_cann_type_mapping(type), ggml_type_size(type),
+                dst_ne, dst_nb, 2);
+            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_out, 0, acl_index, acl_src_tensor);
             ggml_cann_release_resources(ctx, acl_src_tensor, acl_index, acl_out);
         }
     }
@@ -1633,8 +1663,9 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     switch (src0->type) {
         case GGML_TYPE_F32: {
-            aclnn_embedding_4d(ctx, src0->data, src0->ne, src0->nb, src1,
-                                   dst);
+            aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
+                                dst->data, dst->ne, dst->nb,
+                                src1, dst->type);
             break;
         }
         case GGML_TYPE_F16: {
@@ -1651,8 +1682,9 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type),
                 src0->ne, src_trans_nb, GGML_MAX_DIMS);
             aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
-            aclnn_embedding_4d(ctx, src_trans_buffer, src0->ne,
-                                   src_trans_nb, src1, dst);
+            aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
+                                dst->data, dst->ne, dst->nb,
+                                src1, dst->type);
             ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
             break;
         }
@@ -1712,8 +1744,10 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
             }
 
-            aclnn_embedding_4d(ctx, dequant_buffer_allocator.get(),
-                                   dequant_ne, dequant_nb, src1, dst);
+            aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(),
+                                   dequant_ne, dequant_nb,
+                                   dst->data, dst->ne, dst->nb,
+                                   src1, dst->type);
 
             ggml_cann_release_resources(ctx, dequant_tensor);
             break;
@@ -1724,6 +1758,43 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
 }
 
+void ggml_cann_set_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src0 = dst->src[0];  // src
+    ggml_tensor* src1 = dst->src[1];  // index
+
+    switch (dst->type) {
+        case GGML_TYPE_F32: {
+            aclnn_index_copy_4d(ctx, src0->data, src0->ne, src0->nb,
+                                dst->data, dst->ne, dst->nb,
+                                src1, dst->type);
+            break;
+        }
+        case GGML_TYPE_F16: {
+            aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
+            ggml_cann_pool_alloc src_buffer_allocator(
+                ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t));
+            void* src_trans_buffer = src_buffer_allocator.get();
+            size_t src_trans_nb[GGML_MAX_DIMS];
+            src_trans_nb[0] = sizeof(uint16_t);
+            for (int i = 1; i < GGML_MAX_DIMS; i++) {
+                src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
+            }
+            aclTensor* src_trans_tensor = ggml_cann_create_tensor(
+                src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type),
+                src0->ne, src_trans_nb, GGML_MAX_DIMS);
+            aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
+            aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
+                                dst->data, dst->ne, dst->nb,
+                                src1, dst->type);
+            ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
+            break;
+        }
+        default:
+            GGML_ABORT("Unsupported tensor type for GGML_OP_SET_ROWS");
+            break;
+    }
+}
+
 /**
  * @brief Repeats elements of a tensor along a specified dimension.
  *
@@ -1785,8 +1856,25 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
     size_t transpose_nb[] = {bcast_weight_nb[1], bcast_weight_nb[0],
                              bcast_weight_nb[2], bcast_weight_nb[3],
                              bcast_weight_nb[4], bcast_weight_nb[5]};
-    aclTensor* acl_weight_tensor =
-        ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims);
+    aclTensor* acl_weight_tensor;
+
+    // Only check env once.
+    static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or(""));
+    if (weight_to_nz && is_matmul_weight(weight)) {
+        int64_t acl_stride[2] = {1, transpose_ne[1]};
+
+        // Reverse ne.
+        std::reverse(transpose_ne, transpose_ne + n_dims);
+
+        std::vector<int64_t> storageDims = {transpose_ne[0], transpose_ne[1]};
+
+        acl_weight_tensor = aclCreateTensor(
+            transpose_ne, n_dims, ggml_cann_type_mapping(weight->type), acl_stride,
+            0, ACL_FORMAT_FRACTAL_NZ, storageDims.data(), 2, weight->data);
+    } else {
+        acl_weight_tensor =
+            ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
+    }
     aclTensor* acl_dst =
         ggml_cann_create_tensor(dst, bcast_dst_ne, bcast_dst_nb, n_dims);
 
@@ -3065,104 +3153,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             // Compute the slope if needed. Derived from ggml_cann_softmax().
             if(maxBias != 0.0f){
                 // alibi
-                const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
-                const int64_t n_head = src0->ne[2];
-                const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
-                float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor);
-                float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor);
-                // init arange
-                ggml_cann_pool_alloc arange_allocator(ctx.pool(),
-                                                    ne2_ne3 * faElemSize);
-                void* tmp_arange_buffer = arange_allocator.get();
-
-                // arange1: [1, ..., n_heads_log2_floor+1)
-                float start = 1;
-                float stop = n_heads_log2_floor + 1;
-                float step = 1;
-                int64_t n_elements_arange = n_heads_log2_floor;
-
-                int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
-                size_t tmp_arange1_nb[] = {faElemSize};
-                aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
-                    tmp_arange_buffer, faDataType, faElemSize,
-                    tmp_arange1_ne, tmp_arange1_nb,
-                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-
-                aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
-
-                aclTensor* tmp_arange2_tensor = nullptr;
-                if (n_heads_log2_floor < ne2_ne3) {
-                    // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
-                    start = 1;
-                    stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
-                    step = 2;
-                    n_elements_arange = ne2_ne3 - n_heads_log2_floor;
-                    int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
-                    size_t tmp_arange2_nb[] = {faElemSize};
-
-                    aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
-                        (char*)tmp_arange_buffer +
-                            n_heads_log2_floor * faElemSize,
-                        faDataType, faElemSize,
-                        tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-                    aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
-                                n_elements_arange);
+                const int64_t n_heads = src0->ne[2];
+                ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
+                void* slope_buffer = slope_allocator.get();
+                aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias);
+
+                int64_t slope_ne[] = {1, 1, n_heads, 1};
+                size_t slope_nb[GGML_MAX_DIMS];
+                slope_nb[0] = sizeof(float);
+                for (int i = 1; i < GGML_MAX_DIMS; i++) {
+                    slope_nb[i] = slope_nb[i - 1] * slope_ne[i - 1];
+                }
+
+                aclTensor* slope_tensor = ggml_cann_create_tensor(
+                    slope_buffer, ACL_FLOAT, sizeof(float),
+                    slope_ne, slope_nb, GGML_MAX_DIMS);
+                GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, slope_tensor);
-                int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]};
-                size_t tmp_mk_nb[GGML_MAX_DIMS];
-                tmp_mk_nb[0] = faElemSize;
-                for (int i = 1; i < GGML_MAX_DIMS; i++) {
-                    tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
-                }
-                aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
-                    tmp_mk_base_buffer, faDataType, faElemSize,
-                    tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
-                    ACL_FORMAT_ND);
-                GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);
-
-                ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
-                    tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
-                    tmp_arange_tensor, tmp_mk_tensor);
+                ggml_cann_release_resources(ctx, slope_tensor);
             }
         }
 
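Note: the alibi branch above replaces the inline arange/pow sequence with a single `aclnn_get_slope` helper. As an illustrative host-side reference only (derived from the removed code, not part of the patch), the slope values it is expected to produce can be sketched as:

#include <cmath>
#include <cstdint>
#include <vector>

// Reference ALiBi slopes, mirroring the removed arange/pow logic: the first
// n_heads_log2_floor heads use ratio m0 with exponents 1..floor, the remaining
// heads use ratio m1 with odd exponents 1, 3, 5, ...
static std::vector<float> alibi_slopes_reference(int64_t n_heads, float max_bias) {
    const int64_t n_heads_log2_floor = 1u << (uint32_t)std::floor(std::log2((double)n_heads));
    const float m0 = std::pow(2.0f, -max_bias / n_heads_log2_floor);
    const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

    std::vector<float> slopes(n_heads);
    for (int64_t h = 0; h < n_heads; h++) {
        slopes[h] = (h < n_heads_log2_floor)
            ? std::pow(m0, (float)(h + 1))
            : std::pow(m1, (float)(2 * (h - n_heads_log2_floor) + 1));
    }
    return slopes;
}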
diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h
index 80ce80baea0..5c510cc9932 100755
--- a/ggml/src/ggml-cann/aclnn_ops.h
+++ b/ggml/src/ggml-cann/aclnn_ops.h
@@ -23,6 +23,7 @@
 #ifndef CANN_ACLNN_OPS
 #define CANN_ACLNN_OPS
 
+#include <unordered_set>
 #include 
 #include 
 #include 
@@ -423,15 +424,25 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  *
  * @details This function retrieves rows from a source tensor src0 according to
  *          the indices provided in another tensor src1 and stores the result in
- *          a destination tensor (\p dst). It supports different data types
- *          including F32, F16, Q4_0, and Q8_0.
+ *          a destination tensor (\p dst).
  *
  * @param ctx The backend CANN context for executing operations.
  * @param dst The destination tensor where the extracted rows will be stored.
- *            dst->op is `GGML_OP_GET_ROWS`.
  */
 void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+/**
+ * @brief   Writes specific rows into a tensor at positions specified by indices.
+ *
+ * @details This function copies rows from a source tensor into a destination
+ *          tensor (\p dst) at the positions indicated by the indices in another
+ *          tensor.
+ *
+ * @param ctx The backend CANN context for executing operations.
+ * @param dst The destination tensor where the specified rows will be updated.
+ */
+void ggml_cann_set_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 /**
  * @brief   Executes matrix multiplication for the given tensor.
  *
@@ -1020,6 +1031,37 @@ inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffe
  */
 void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+/**
+ * @brief   Check whether a tensor is a weight tensor for matrix multiplication.
+ *
+ * @details Checks whether the given tensor serves as a weight in a matrix
+ *          multiplication, typically within Transformer layers. The function keeps a
+ *          static set of canonical weight-name suffixes and uses substring matching,
+ *          so hierarchical tensor names containing one of these suffixes are also
+ *          recognized.
+ *
+ * @param tensor Pointer to the target ggml_tensor object (const-qualified).
+ * @return true if the tensor name contains a known weight suffix, false otherwise.
+ */
+static bool is_matmul_weight(const ggml_tensor* tensor) {
+    std::string name = ggml_get_name(tensor);
+    static const std::unordered_set<std::string> weight_suffixes{
+        "output.weight",
+        "attn_q.weight",
+        "attn_k.weight",
+        "attn_v.weight",
+        "attn_output.weight",
+        "ffn_gate.weight",
+        "ffn_up.weight",
+        "ffn_down.weight"
+    };
+
+    for (const auto& suffix : weight_suffixes) {
+        if (name.find(suffix) != std::string::npos) {
+            return true;
+        }
+    }
+    return false;
+}
+
 /**
  * @brief Applies a element-wise operation to two input tensors using the CANN
  * backend.
@@ -1066,7 +1108,7 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  * @param dst The destination tensor. Its src[0] is treated as the input tensor.
  */
 template <void unary_op(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst)>
-    void ggml_cann_unary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    void ggml_cann_op_unary(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src = dst->src[0];
 
     aclTensor* acl_src = ggml_cann_create_tensor(src);
@@ -1077,49 +1119,125 @@ template 
 }
 
 /**
- * @brief   Applies a unary operation to a ggml tensor using the CANN backend.
+ * @brief Applies a unary operation to a ggml tensor using the CANN backend.
  *
- * @details This function performs a unary operation on the input tensor using
- * a user-provided lambda or callable object `unary_op`, which accepts the CANN
- * context and two ACL tensors (source and destination). Internally, this function
- * creates ACL representations of the ggml tensors and invokes the unary operation.
- * The result is stored in the destination tensor `dst`. This utility abstracts the
- * common boilerplate of tensor conversion and cleanup when implementing unary ops.
+ * @details This function applies a unary operation to the input tensor using
+ * a user-provided lambda or callable `unary_op`. The lambda receives the
+ * CANN backend context and two ACL tensors: the source and the destination.
  *
- * @param unary_op A callable that performs the unary operation using CANN APIs.
- * @param ctx The CANN context used for operations.
- * @param dst The destination tensor where the result will be stored.
- *            The source tensor is retrieved from `dst->src[0]`.
+ * Internally, this function handles the conversion from GGML tensors to ACL tensors,
+ * calls the provided unary op, and manages resource cleanup. The input is assumed
+ * to be `dst->src[0]`, and the result is written to `dst`.
+ *
+ * This utility simplifies writing unary op wrappers by abstracting tensor preparation.
+ *
+ * @param unary_op A callable that performs the unary operation using CANN ACL APIs.
+ * @param ctx The CANN context for operation execution.
+ * @param dst The destination ggml_tensor where the result will be stored.
+ *            The input tensor is assumed to be `dst->src[0]`.
+ *
+ * @see GGML_CANN_CALL_OP_UNARY
  */
-void ggml_cann_unary_op(
+void ggml_cann_op_unary(
     std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
     ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
 /**
- * @brief Helper macro to invoke a unary ACL operation using ggml_cann_unary_op.
+ * @brief Applies a gated (GLU-style) unary operation using the CANN backend.
+ *
+ * @details This function performs a gated activation such as GEGLU or ReGLU.
+ * It supports two input modes:
+ *
+ * 1. **Dual input mode**: `dst->src[0]` and `dst->src[1]` are both valid tensors.
+ *    These are used directly as the value and gate tensors.
  *
- * This macro defines an inline lambda wrapping a specific ACL operation name,
- * and passes it to the templated ggml_cann_unary_op function. It simplifies
- * calling unary ops by hiding the lambda boilerplate.
+ * 2. **Packed input mode**: Only `dst->src[0]` is valid, and it is assumed to
+ *    contain a concatenation of value and gate along the first dimension. This tensor
+ *    will be split into two equal halves to form the value and gate inputs.
+ *
+ * The function applies a user-provided unary operation (e.g., GELU) to the value tensor,
+ * then multiplies the result in-place with the gate tensor:
  *
- * Internally, the lambda will call:
  * @code
- * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
+ * dst = unary_op(value) * gate;
  * @endcode
  *
+ * The `swapped` parameter (from `dst->op_params[1]`) allows flipping the
+ * order of value/gate in the packed input case.
+ *
+ * @param unary_op A callable that performs the unary operation using CANN ACL APIs.
+ *                 It receives (ctx, acl_value_tensor, acl_output_tensor).
+ * @param ctx      The CANN context used for execution.
+ * @param dst      The destination ggml_tensor. Source tensors are in `dst->src[0]` and optionally `src[1]`.
+ *
+ * @see GGML_CANN_CALL_OP_UNARY_GATED
+ */
+void ggml_cann_op_unary_gated(
+    std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
+    ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Helper macro to call a unary ACL operator via ggml_cann_op_unary.
+ *
+ * This macro wraps the specified ACLNN unary operator name into a lambda expression,
+ * and passes it to `ggml_cann_op_unary`, which handles the common logic for executing
+ * unary ops in the CANN backend.
+ *
+ * Internally, this macro expands to a lambda like:
+ * @code
+ * [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
+ *     GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
+ * };
+ * @endcode
+ *
+ * This lambda is then passed to `ggml_cann_op_unary`, which applies the operation.
+ *
+ * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
+ *
+ * @see ggml_cann_op_unary
+ * @see GGML_CANN_CALL_ACLNN_OP
+ */
+#define GGML_CANN_CALL_OP_UNARY(OP_NAME)                              \
+    do {                                                              \
+        auto lambda = [](ggml_backend_cann_context& ctx,              \
+            aclTensor* acl_src,                                       \
+            aclTensor* acl_dst) {                                     \
+            GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);  \
+        };                                                            \
+        ggml_cann_op_unary(lambda, ctx, dst);                         \
+    }                                                                 \
+    while (0)
+
+/**
+ * @brief Helper macro to call a gated unary ACL operator via ggml_cann_op_unary_gated.
+ *
+ * This macro wraps the specified ACLNN unary operator name into a lambda expression,
+ * and passes it to `ggml_cann_op_unary_gated`, which handles the common logic for
+ * executing gated unary ops in the CANN backend.
+ *
+ * Internally, this macro expands to a lambda like:
+ * @code
+ * [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
+ *     GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
+ * };
+ * @endcode
+ *
+ * This lambda is then passed to `ggml_cann_op_unary_gated`, which applies the operation.
+ *
  * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
  *
- * @see ggml_cann_unary_op
+ * @see ggml_cann_op_unary_gated
  * @see GGML_CANN_CALL_ACLNN_OP
  */
-#define GGML_CANN_CALL_UNARY_OP(OP_NAME)                              \
+#define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME)                        \
     do {                                                              \
         auto lambda = [](ggml_backend_cann_context& ctx,              \
             aclTensor* acl_src,                                       \
             aclTensor* acl_dst) {                                     \
             GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);  \
         };                                                            \
-        ggml_cann_unary_op(lambda, ctx, dst);                         \
+        ggml_cann_op_unary_gated(lambda, ctx, dst);                   \
     }                                                                 \
     while (0)
+
 #endif  // CANN_ACLNN_OPS
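Note: as a semantic reference for `ggml_cann_op_unary_gated` and the `GGML_CANN_CALL_OP_UNARY_GATED` macro declared above, the per-row computation can be sketched on the CPU as below. This is an illustrative sketch only; which half of a packed row is treated as value vs. gate under `swapped` follows the doc comment and is an assumption here, not inspected backend behavior.

#include <cstddef>
#include <functional>
#include <vector>

// dst = unary_op(value) * gate, elementwise over one row.
static std::vector<float> gated_unary_reference(const std::vector<float>& value,
                                                const std::vector<float>& gate,
                                                const std::function<float(float)>& unary_op) {
    std::vector<float> out(value.size());
    for (size_t i = 0; i < value.size(); i++) {
        out[i] = unary_op(value[i]) * gate[i];
    }
    return out;
}

// Packed mode: one row stores value and gate concatenated along the first
// dimension; `swapped` flips which half plays which role.
static void split_packed_row(const std::vector<float>& packed, bool swapped,
                             std::vector<float>& value, std::vector<float>& gate) {
    const size_t half = packed.size() / 2;
    const float* lo = packed.data();
    const float* hi = packed.data() + half;
    value.assign(swapped ? hi : lo, (swapped ? hi : lo) + half);
    gate.assign(swapped ? lo : hi, (swapped ? lo : hi) + half);
}

For example, SWIGLU corresponds to gated_unary_reference(value, gate, silu) with silu(x) = x / (1 + exp(-x)).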
diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h
index 8dfe3b061c1..9d294f72b67 100755
--- a/ggml/src/ggml-cann/common.h
+++ b/ggml/src/ggml-cann/common.h
@@ -337,6 +337,29 @@ class cann_task_queue {
     int32_t device_;
 };
 
+#ifdef USE_ACL_GRAPH
+struct ggml_graph_node_properties {
+    void * node_address;
+    ggml_op node_op;
+    int64_t ne[GGML_MAX_DIMS];
+    size_t nb[GGML_MAX_DIMS];
+    void * src_address[GGML_MAX_SRC];
+    int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
+};
+
+struct ggml_cann_graph {
+    ~ggml_cann_graph() {
+        if (graph != nullptr) {
+            aclmdlRIDestroy(graph);
+        }
+    }
+
+    aclmdlRI graph = nullptr;
+
+    std::vector<ggml_graph_node_properties> ggml_graph_properties;
+};
+#endif  // USE_ACL_GRAPH
+
 /**
  * @brief Context for managing CANN backend operations.
  */
@@ -345,8 +368,13 @@ struct ggml_backend_cann_context {
     std::string name;                /**< Name of the device. */
     std::string description;         /**< Description of the device. */
     aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
+#ifdef USE_ACL_GRAPH
+    /// Cached CANN ACL graph used for executing the current ggml computation graph.
+    std::unique_ptr<ggml_cann_graph> cann_graph;
+#endif
     cann_task_queue task_queue;
     bool async_mode;
+    bool support_set_rows;
 
     aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
 
@@ -362,6 +390,14 @@ struct ggml_backend_cann_context {
         async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
         GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
             device, async_mode ? "ON" : "OFF");
+
+        support_set_rows = parse_bool(get_env("LLAMA_SET_ROWS").value_or(""));
+        GGML_LOG_INFO("%s: LLAMA_SET_ROWS is %s\n", __func__, support_set_rows ? "ON" : "OFF");
+
+        if (!support_set_rows) {
+            GGML_LOG_INFO("%s: CANN Graph currently only supports execution when LLAMA_SET_ROWS is ON. "
+                    "Falling back to eager mode.\n", __func__);
+        }
     }
 
     /**
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index ccb17eb072e..cb8af42ebf9 100755
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -24,6 +24,7 @@
 
 #include 
 #include 
+#include 
 
 #include 
 #include 
@@ -1115,6 +1116,61 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
     return GGML_STATUS_SUCCESS;
 }
 
+// ND to NZ Workspace Cache Management. Thread-safety: Not guaranteed
+namespace {
+    void* g_nz_workspace = nullptr;
+    size_t g_nz_workspace_allocated = 0;
+
+    void release_nz_workspace() {
+        if (g_nz_workspace) {
+            aclrtFree(g_nz_workspace);
+            g_nz_workspace = nullptr;
+            g_nz_workspace_allocated = 0;
+        }
+    }
+
+    void relloc_nz_workspace(size_t new_size) {
+        if (new_size > g_nz_workspace_allocated) {
+            if (g_nz_workspace) {
+                aclrtFree(g_nz_workspace);
+                g_nz_workspace = nullptr;
+            }
+            ACL_CHECK(aclrtMalloc(&g_nz_workspace, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
+            g_nz_workspace_allocated = new_size;
+        }
+    }
+}
+
+/**
+ * @brief Convert tensor weights to NZ format using Ascend CANN API.
+ *
+ * This function creates a transposed tensor descriptor and performs the
+ * TransMatmulWeight operation. Converting tensor formats can significantly
+ * improve performance on certain hardware.
+ *
+ * @param tensor Pointer to the input ggml_tensor containing the weights.
+ * @param data Pointer to the raw data buffer for the tensor weights.
+ * @param offset Byte offset within the tensor data buffer where weights start.
+ *
+ * @note The workspace buffer used in this function is managed globally and reused
+ *       across calls. This reduces overhead from repeated memory allocation and deallocation.
+ */
+static void weight_format_to_nz(ggml_tensor *tensor, const void *data, size_t offset) {
+    aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne,
+                                    tensor->nb, 2, ACL_FORMAT_ND, offset);
+    uint64_t workspaceSize = 0;
+    aclOpExecutor *executor;
+
+    // TransMatmulWeight
+    ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed,
+                                                    &workspaceSize, &executor));
+    // Avoid frequent malloc/free of the workspace.
+    relloc_nz_workspace(workspaceSize);
+
+    ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
+    ACL_CHECK(aclDestroyTensor(weightTransposed));
+}
+
 // TODO: need handle tensor which has paddings.
 /**
  * @brief Set tensor data in a CANN buffer.
@@ -1139,9 +1195,16 @@ static void ggml_backend_cann_buffer_set_tensor(
     // For acl, synchronous functions use this default stream.
     // Why aclrtSynchronizeDevice?
 
+    // Only check env once.
+    static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or(""));
     if (!need_transform(tensor->type)) {
         ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
                               ACL_MEMCPY_HOST_TO_DEVICE));
+        if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
+            GGML_ASSERT(tensor->ne[2] == 1);
+            GGML_ASSERT(tensor->ne[3] == 1);
+            weight_format_to_nz(tensor, data, offset);
+        }
     } else {
         void *transform_buffer = malloc(size);
         ggml_backend_cann_transform(tensor, data, transform_buffer);
@@ -1375,20 +1438,32 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(
     size_t size = ggml_nbytes(tensor);
     int64_t ne0 = tensor->ne[0];
 
+    // Only check env once.
+    static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or(""));
+
     // last line must bigger than 32, because every single op deal at
     // least 32 bytes.
     // TODO: quantized type?
     // int64_t line_size = ne0 * ggml_element_size(tensor);
     // int64_t line_size_align_32 = (line_size + 31) & ~31;
     // size += (line_size_align_32 - line_size);
-
-    // TODO: not support quantized yet.
-    // TODO: consider un-continue tensor.
     if (ggml_is_quantized(tensor->type)) {
         if (ne0 % MATRIX_ROW_PADDING != 0) {
             size += ggml_row_size(
                 tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
         }
+    } else if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
+        // NZ-format weights do not support quantized types yet.
+        // If an ND tensor is transformed to NZ, its size may change.
+        int64_t shape[] = {tensor->ne[1], tensor->ne[0]};
+        GGML_ASSERT(tensor->ne[2] == 1);
+        GGML_ASSERT(tensor->ne[3] == 1);
+        const aclIntArray *acl_shape = aclCreateIntArray(shape, 2);
+        size_t new_size;
+        ACL_CHECK(aclnnCalculateMatmulWeightSizeV2(acl_shape,
+                    ggml_cann_type_mapping(tensor->type), &new_size));
+        ACL_CHECK(aclDestroyIntArray(acl_shape));
+        size = std::max(size, new_size);
     }
 
     return size;
@@ -1594,6 +1669,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_GET_ROWS:
             ggml_cann_get_rows(ctx, dst);
             break;
+        case GGML_OP_SET_ROWS:
+            ggml_cann_set_rows(ctx, dst);
+            break;
         case GGML_OP_DUP:
             ggml_cann_dup(ctx, dst);
             break;
@@ -1616,16 +1694,18 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(dst)) {
                 case GGML_UNARY_OP_ABS:
-                    GGML_CANN_CALL_UNARY_OP(Abs);
+                    GGML_CANN_CALL_OP_UNARY(Abs);
                     break;
                 case GGML_UNARY_OP_NEG:
-                    GGML_CANN_CALL_UNARY_OP(Neg);
+                    GGML_CANN_CALL_OP_UNARY(Neg);
                     break;
                 case GGML_UNARY_OP_GELU:
-                    GGML_CANN_CALL_UNARY_OP(Gelu);
+                case GGML_UNARY_OP_GELU_ERF:
+                    // aclnnGelu computes the erf-based (exact) GELU, so GELU_ERF maps to it directly.
+                    GGML_CANN_CALL_OP_UNARY(Gelu);
                     break;
                 case GGML_UNARY_OP_SILU:
-                    GGML_CANN_CALL_UNARY_OP(Silu);
+                    GGML_CANN_CALL_OP_UNARY(Silu);
                     break;
                 case GGML_UNARY_OP_GELU_QUICK: {
                     auto lambda = [](ggml_backend_cann_context& ctx,
@@ -1633,31 +1713,31 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
                         aclTensor* acl_dst) {
                         GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
                     };
-                    ggml_cann_unary_op(lambda, ctx, dst);
+                    ggml_cann_op_unary(lambda, ctx, dst);
                 } break;
                 case GGML_UNARY_OP_TANH:
-                    GGML_CANN_CALL_UNARY_OP(Tanh);
+                    GGML_CANN_CALL_OP_UNARY(Tanh);
                     break;
                 case GGML_UNARY_OP_RELU:
-                    GGML_CANN_CALL_UNARY_OP(Relu);
+                    GGML_CANN_CALL_OP_UNARY(Relu);
                     break;
                 case GGML_UNARY_OP_SIGMOID:
-                    GGML_CANN_CALL_UNARY_OP(Sigmoid);
+                    GGML_CANN_CALL_OP_UNARY(Sigmoid);
                     break;
                 case GGML_UNARY_OP_HARDSIGMOID:
-                    GGML_CANN_CALL_UNARY_OP(Hardsigmoid);
+                    GGML_CANN_CALL_OP_UNARY(Hardsigmoid);
                     break;
                 case GGML_UNARY_OP_HARDSWISH:
-                    GGML_CANN_CALL_UNARY_OP(Hardswish);
+                    GGML_CANN_CALL_OP_UNARY(Hardswish);
                     break;
                 case GGML_UNARY_OP_EXP:
-                    GGML_CANN_CALL_UNARY_OP(Exp);
+                    GGML_CANN_CALL_OP_UNARY(Exp);
                     break;
                 case GGML_UNARY_OP_ELU:
                     ggml_cann_elu(ctx, dst);
                     break;
                 case GGML_UNARY_OP_SGN:
-                    GGML_CANN_CALL_UNARY_OP(Sign);
+                    GGML_CANN_CALL_OP_UNARY(Sign);
                     break;
                 case GGML_UNARY_OP_STEP:
                     ggml_cann_step(ctx, dst);
@@ -1666,6 +1746,31 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
                     return false;
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(dst)) {
+                case GGML_GLU_OP_REGLU:
+                    GGML_CANN_CALL_OP_UNARY_GATED(Relu);
+                    break;
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                    // aclnnGelu computes the erf-based (exact) GELU, so GEGLU_ERF maps to it directly.
+                    GGML_CANN_CALL_OP_UNARY_GATED(Gelu);
+                    break;
+                case GGML_GLU_OP_SWIGLU:
+                    GGML_CANN_CALL_OP_UNARY_GATED(Silu);
+                    break;
+                case GGML_GLU_OP_GEGLU_QUICK: {
+                    auto lambda = [](ggml_backend_cann_context& ctx,
+                        aclTensor* acl_src,
+                        aclTensor* acl_dst) {
+                        GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
+                    };
+                    ggml_cann_op_unary_gated(lambda, ctx, dst);
+                } break;
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_NORM:
             ggml_cann_norm(ctx, dst);
             break;
@@ -1708,7 +1813,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
             ggml_cann_binary_op(ctx, dst);
             break;
         case GGML_OP_SQRT:
-            GGML_CANN_CALL_UNARY_OP(Sqrt);
+            GGML_CANN_CALL_OP_UNARY(Sqrt);
             break;
         case GGML_OP_CLAMP:
             ggml_cann_clamp(ctx, dst);
@@ -1753,16 +1858,16 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
             ggml_cann_argmax(ctx, dst);
             break;
         case GGML_OP_COS:
-            ggml_cann_unary_op<aclnn_cos>(ctx, dst);
+            ggml_cann_op_unary<aclnn_cos>(ctx, dst);
             break;
         case GGML_OP_SIN:
-            ggml_cann_unary_op<aclnn_sin>(ctx, dst);
+            ggml_cann_op_unary<aclnn_sin>(ctx, dst);
             break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             ggml_cann_conv_transpose_1d(ctx, dst);
             break;
         case GGML_OP_LOG:
-            GGML_CANN_CALL_UNARY_OP(Log);
+            GGML_CANN_CALL_OP_UNARY(Log);
             break;
         case GGML_OP_MEAN:
             ggml_cann_mean(ctx, dst);
@@ -1911,6 +2016,9 @@ static bool ggml_backend_cann_cpy_tensor_async(
         (ggml_backend_cann_context*)backend_dst->context;
 
     size_t copy_size = ggml_nbytes(dst);
+    if (copy_size == 0) {
+        return true;
+    }
     if (backend_src != backend_dst) {
         ggml_backend_cann_buffer_context* buf_ctx_src =
             (ggml_backend_cann_buffer_context*)buf_src->context;
@@ -1967,6 +2075,160 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
     ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
 }
 
+#ifdef USE_ACL_GRAPH
+/**
+ * @brief Populate the internal CANN graph node properties from the ggml computation graph.
+ *
+ * This function copies all node attributes (operation type, dimensions, strides, input sources,
+ * and operation parameters) into the cached CANN graph structure for later reuse or comparison.
+ *
+ * @param cann_ctx  The CANN backend context.
+ * @param cgraph    The ggml computational graph.
+ */
+static void set_ggml_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
+    for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
+        ggml_tensor * node = cgraph->nodes[node_idx];
+        cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_address = node->data;
+        cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_op = node->op;
+
+        for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
+            cann_ctx->cann_graph->ggml_graph_properties[node_idx].ne[dim] = node->ne[dim];
+            cann_ctx->cann_graph->ggml_graph_properties[node_idx].nb[dim] = node->nb[dim];
+        }
+        for (int src = 0; src < GGML_MAX_SRC; src++) {
+            cann_ctx->cann_graph->ggml_graph_properties[node_idx].src_address[src] =
+                node->src[src] ? node->src[src]->data : nullptr;
+        }
+        memcpy(cann_ctx->cann_graph->ggml_graph_properties[node_idx].op_params, node->op_params, GGML_MAX_OP_PARAMS);
+    }
+}
+
+/**
+ * @brief Check if a ggml tensor node matches a previously captured CANN graph node.
+ *
+ * This function compares all relevant fields (address, op type, shape, source inputs, op params)
+ * to determine whether the current node matches a previously recorded version.
+ *
+ * @param node                  The current ggml tensor node.
+ * @param graph_node_properties The stored properties of a CANN graph node.
+ * @return true if all fields match (data addresses are not compared for GGML_OP_VIEW nodes); false otherwise.
+ */
+static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+    if (node->data != graph_node_properties->node_address &&
+           node->op != GGML_OP_VIEW) {
+        return false;
+    }
+    if (node->op != graph_node_properties->node_op) {
+        return false;
+    }
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (node->ne[i] != graph_node_properties->ne[i]) {
+            return false;
+        }
+        if (node->nb[i] != graph_node_properties->nb[i]) {
+            return false;
+        }
+    }
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (node->src[i] &&
+            node->src[i]->data != graph_node_properties->src_address[i] &&
+            node->op != GGML_OP_VIEW
+        ) {
+            return false;
+        }
+    }
+    if (node->op == GGML_OP_SCALE &&
+        memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
+        return false;
+    }
+    return true;
+}
+
+/**
+ * @brief Determine if the CANN graph needs to be rebuilt due to graph changes.
+ *
+ * This checks whether the number or properties of ggml graph nodes have changed
+ * compared to the last captured CANN graph. If so, the CANN graph must be re-captured.
+ *
+ * @param cann_ctx  The CANN backend context.
+ * @param cgraph    The current ggml computation graph.
+ * @return true if an update is required; false otherwise.
+ */
+static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
+    // The number of nodes is different, so the graph needs to be reconstructed.
+    if (cann_ctx->cann_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+        cann_ctx->cann_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+        return true;
+    }
+
+    // The number of nodes is the same; iterate over each node to check whether they match.
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        bool has_matching_properties = ggml_graph_node_has_matching_properties(
+            cgraph->nodes[i], &cann_ctx->cann_graph->ggml_graph_properties[i]);
+        if(!has_matching_properties) {
+            return true;
+        }
+    }
+    return false;
+}
+#endif  // USE_ACL_GRAPH
+
+/**
+ * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
+ *
+ * If CANN graph execution is enabled and graph capture is required, this function begins
+ * graph capture, runs the graph, ends capture, and stores the captured graph.
+ *
+ * Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
+ *
+ * @param cann_ctx                 The CANN backend context.
+ * @param cgraph                   The ggml computation graph.
+ * @param use_cann_graph           Whether to use CANN graph execution.
+ * @param cann_graph_update_required Whether graph capture is needed due to graph changes.
+ */
+static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph,
+    bool & use_cann_graph, bool & cann_graph_update_required) {
+#ifdef USE_ACL_GRAPH
+    if (use_cann_graph && cann_graph_update_required) {
+        if (cann_ctx->cann_graph->graph != nullptr) {
+            ACL_CHECK(aclmdlRIDestroy(cann_ctx->cann_graph->graph));
+            cann_ctx->cann_graph->graph = nullptr;
+        }
+        ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
+    }
+#endif // USE_ACL_GRAPH
+
+    // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
+    // With the use of CANN graphs, the execution will be performed by the graph launch.
+    if (!use_cann_graph || cann_graph_update_required) {
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            ggml_tensor * node = cgraph->nodes[i];
+
+            if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+                continue;
+            }
+
+            bool ok = ggml_cann_compute_forward(*cann_ctx, node);
+            if (!ok) {
+                GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+            }
+            GGML_ASSERT(ok);
+        }
+    }
+
+#ifdef USE_ACL_GRAPH
+    if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture
+        ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &cann_ctx->cann_graph->graph));
+    }
+
+    if (use_cann_graph) {
+        // Execute graph
+        ACL_CHECK(aclmdlRIExecuteAsync(cann_ctx->cann_graph->graph, cann_ctx->stream()));
+    }
+#endif // USE_ACL_GRAPH
+}
+
+
 /**
  * @brief Computes a computational graph using a CANN backend.
  *
@@ -1983,24 +2245,37 @@ static enum ggml_status ggml_backend_cann_graph_compute(
     ggml_backend_t backend, ggml_cgraph* cgraph) {
     ggml_backend_cann_context* cann_ctx =
         (ggml_backend_cann_context*)backend->context;
-
     ggml_cann_set_device(cann_ctx->device);
+    release_nz_workspace();
+#ifdef USE_ACL_GRAPH
+    bool use_cann_graph = true;
+    bool cann_graph_update_required = false;
+
+    // check environment LLAMA_SET_ROWS
+    if (!cann_ctx->support_set_rows) {
+        use_cann_graph = false;
+    }
 
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        ggml_tensor* node = cgraph->nodes[i];
-
-        if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
-            continue;
+    if (use_cann_graph) {
+        if (cann_ctx->cann_graph == nullptr) {
+            cann_ctx->cann_graph.reset(new ggml_cann_graph());
+            cann_graph_update_required = true;
         }
 
-        bool ok = ggml_cann_compute_forward(*cann_ctx, node);
-
-        if (!ok) {
-            GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
-                    node->name, ggml_op_name(node->op));
-        }
-        GGML_ASSERT(ok);
+        cann_graph_update_required = is_cann_graph_update_required(cann_ctx, cgraph);
+        set_ggml_graph_node_properties(cann_ctx, cgraph);
     }
+#else
+    bool use_cann_graph = false;
+    bool cann_graph_update_required = false;
+#endif  // USE_ACL_GRAPH
+
+    evaluate_and_capture_cann_graph(
+        cann_ctx,
+        cgraph,
+        use_cann_graph,
+        cann_graph_update_required
+    );
 
     return GGML_STATUS_SUCCESS;
 }
@@ -2036,10 +2311,23 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_SGN:
                 case GGML_UNARY_OP_STEP:
+                case GGML_UNARY_OP_GELU_ERF:
                     return true;
                 default:
                     return false;
             }
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
+                    return true;
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_MUL_MAT: {
             switch (op->src[0]->type) {
                 case GGML_TYPE_F16:
@@ -2086,12 +2374,15 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                     return false;
             }
         } break;
-        case GGML_OP_SET_ROWS:
-            {
-                // TODO: add support
-                // ref: https://github.com/ggml-org/llama.cpp/pull/14274
-                return false;
-            } break;
+        case GGML_OP_SET_ROWS: {
+            switch (op->type) {
+                case GGML_TYPE_F32:
+                case GGML_TYPE_F16:
+                    return true;
+                default:
+                    return false;
+            }
+        } break;
         case GGML_OP_CPY: {
             ggml_tensor *src = op->src[0];
             if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
@@ -2100,12 +2391,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 // only support F32 and F16.
                 return false;
             }
-
-            if (!ggml_are_same_shape(op, src) && !ggml_is_contiguous(op)) {
-                // unsupport dst is not contiguous.
-                return false;
-            }
-
             return true;
         } break;
         case GGML_OP_CONT: {
@@ -2171,8 +2456,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             // value of paddingW should be at most half of kernelW
             return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
         }
-        case GGML_OP_SUM:
         case GGML_OP_DUP:
+        case GGML_OP_SUM:
         case GGML_OP_IM2COL:
         case GGML_OP_CONCAT:
         case GGML_OP_REPEAT:
@@ -2214,9 +2499,11 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
             return bias == 0.0f; // TODO: support bias != 0.0f
         case GGML_OP_SOFT_MAX:
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
+            // TODO: support attention sinks [TAG_ATTN_SINKS]
+            if (op->src[2]) {
+                return false;
+            }
+            return true;
         case GGML_OP_FLASH_ATTN_EXT:{
             // derived from [ggml-cuda.cu]
             if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
@@ -2228,6 +2515,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
                 return false;
             }
+            // TODO: support attention sinks [TAG_ATTN_SINKS]
+            if (op->src[4]) {
+                return false;
+            }
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                 // different head sizes of K and V are not supported yet
                 return false;
@@ -2239,11 +2530,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 // DeepSeek MLA
                 return false;
             }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
             float logitSoftcap = 0.0f;
             memcpy(&logitSoftcap,  (float*)op->op_params + 2, sizeof(float));
             if(logitSoftcap != 0.0f) {
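Note: for the newly enabled GGML_OP_SET_ROWS path (dispatched to ggml_cann_set_rows and advertised in supports_op above), the intended per-2-D-slice semantics can be sketched as a plain CPU loop. This is a conceptual reference only, assuming valid in-range indices; the index dtype and layout follow ggml's SET_ROWS definition, and the backend realizes it with aclnn InplaceIndexCopy.

#include <cstdint>
#include <cstring>
#include <vector>

// Row i of `src` is written into row `rows[i]` of `dst`; both share row_len floats.
// Indices are assumed to be valid row positions in dst.
static void set_rows_reference(std::vector<float>& dst, int64_t row_len,
                               const std::vector<float>& src,
                               const std::vector<int64_t>& rows) {
    for (size_t i = 0; i < rows.size(); i++) {
        std::memcpy(dst.data() + rows[i] * row_len,
                    src.data() + (int64_t)i * row_len,
                    (size_t)row_len * sizeof(float));
    }
}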
diff --git a/ggml/src/ggml-cann/kernels/CMakeLists.txt b/ggml/src/ggml-cann/kernels/CMakeLists.txt
deleted file mode 100644
index d687220c3c5..00000000000
--- a/ggml/src/ggml-cann/kernels/CMakeLists.txt
+++ /dev/null
@@ -1,30 +0,0 @@
-file(GLOB SRC_FILES
-    get_row_f32.cpp
-    get_row_f16.cpp
-    get_row_q4_0.cpp
-    get_row_q8_0.cpp
-    quantize_f32_q8_0.cpp
-    quantize_f16_q8_0.cpp
-    quantize_float_to_q4_0.cpp
-    dup.cpp
-)
-
-set(ASCEND_CANN_PACKAGE_PATH ${CANN_INSTALL_DIR})
-set(RUN_MODE "npu" CACHE STRING "run mode: npu/sim")
-
-if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
-    set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
-elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake)
-    set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake)
-else()
-    message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the compiler package is installed.")
-endif()
-include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
-
-ascendc_library(ascendc_kernels STATIC
-    ${SRC_FILES}
-)
-
-message(STATUS "CANN: compile ascend kernels witch SOC_TYPE:${SOC_TYPE}, SOC_VERSION:${SOC_VERSION}, compile macro:-D${SOC_TYPE_COMPILE_OPTION}.")
-ascendc_compile_definitions(ascendc_kernels PRIVATE "-D${SOC_TYPE_COMPILE_OPTION}")
-# ascendc_compile_definitions(ascendc_kernels PRIVATE -DASCENDC_DUMP)
diff --git a/ggml/src/ggml-cann/kernels/ascendc_kernels.h b/ggml/src/ggml-cann/kernels/ascendc_kernels.h
deleted file mode 100644
index 7e153208cfd..00000000000
--- a/ggml/src/ggml-cann/kernels/ascendc_kernels.h
+++ /dev/null
@@ -1,19 +0,0 @@
-#ifndef ASCENDC_KERNELS_H
-#define ASCENDC_KERNELS_H
-
-#include "aclrtlaunch_ascendc_get_row_f32.h"
-#include "aclrtlaunch_ascendc_get_row_f16.h"
-#include "aclrtlaunch_ascendc_get_row_q8_0.h"
-#include "aclrtlaunch_ascendc_get_row_q4_0.h"
-
-#include "aclrtlaunch_ascendc_quantize_f32_q8_0.h"
-#include "aclrtlaunch_ascendc_quantize_f16_q8_0.h"
-#include "aclrtlaunch_ascendc_quantize_f16_to_q4_0.h"
-#include "aclrtlaunch_ascendc_quantize_f32_to_q4_0.h"
-
-#include "aclrtlaunch_ascendc_dup_by_rows_fp16.h"
-#include "aclrtlaunch_ascendc_dup_by_rows_fp32.h"
-#include "aclrtlaunch_ascendc_dup_by_rows_fp32_to_fp16.h"
-#include "aclrtlaunch_ascendc_dup_by_rows_fp16_to_fp32.h"
-
-#endif  // ASCENDC_KERNELS_H
diff --git a/ggml/src/ggml-cann/kernels/dup.cpp b/ggml/src/ggml-cann/kernels/dup.cpp
deleted file mode 100644
index d9b9574494b..00000000000
--- a/ggml/src/ggml-cann/kernels/dup.cpp
+++ /dev/null
@@ -1,234 +0,0 @@
-#include "kernel_operator.h"
-
-using namespace AscendC;
-
-#define BUFFER_NUM 2
-const int64_t SUPPORTED_MAX_DIM = 65535;  // currently the limit of max block dim supportted by dup kernel is 65535template 
-
-template 
-class DupByRows {
-   public:
-    __aicore__ inline DupByRows() {}
-    __aicore__ inline void init(GM_ADDR src, GM_ADDR dst, int64_t *input_ne_ub,
-                                size_t *input_nb_ub) {
-        /* Dup by rows when src is contigous on first dimension and dst is
-        contiguous, each kernel process one row.
-        */
-
-        // Input has four dims.
-        int64_t op_block_num = GetBlockNum();
-        int64_t op_block_idx = GetBlockIdx();
-
-        // param
-        num_rows = input_ne_ub[1] * input_ne_ub[2] * input_ne_ub[3];
-        num_elem = input_ne_ub[0];
-
-        // index for (ne[1], ne[2], ne[3]): (idx_ne1, idx_ne2, idx_ne3)
-        idx_ne3 = op_block_idx / (input_ne_ub[1] * input_ne_ub[2]);
-        idx_ne2 = (op_block_idx - idx_ne3 * (input_ne_ub[1] * input_ne_ub[2]))
-                  / (input_ne_ub[1]);
-        idx_ne1 = op_block_idx - idx_ne3 * (input_ne_ub[1] * input_ne_ub[2])
-                - idx_ne2 * input_ne_ub[1];
-
-        // src may not contiguous in dim [1,2,3], so stride decited by ne&nb
-        src_stride = input_nb_ub[3] * idx_ne3 + input_nb_ub[2] * idx_ne2
-                     + input_nb_ub[1] * idx_ne1;
-
-        // dst is contiguous
-        dst_stride = op_block_idx * (input_ne_ub[0] * sizeof(DST_T));
-
-        src_gm.SetGlobalBuffer(reinterpret_cast<__gm__ SRC_T *>(src +
-                                                                src_stride));
-        dst_gm.SetGlobalBuffer(reinterpret_cast<__gm__ DST_T *>(dst +
-                                                                dst_stride));
-
-        pipe.InitBuffer(src_queue, BUFFER_NUM, (sizeof(SRC_T) * num_elem +
-                                                32 - 1) / 32 * 32);
-        pipe.InitBuffer(dst_queue, BUFFER_NUM, (sizeof(DST_T) * num_elem +
-                                                32 - 1) / 32 * 32);
-    }
-
-    __aicore__ inline void copy_in() {
-        LocalTensor src_local = src_queue.AllocTensor();
-        const size_t elem_per_block = 32 / sizeof(SRC_T);
-        size_t tail = num_elem % elem_per_block;
-        size_t cpy_elements_len = tail > 0 ? num_elem + 1 : num_elem;
-        DataCopy(src_local, src_gm, cpy_elements_len);
-        src_queue.EnQue(src_local);
-    }
-
-    __aicore__ inline void copy_out() {
-        LocalTensor dst_local = dst_queue.DeQue();
-#ifdef ASCEND_310P
-        const size_t elem_per_block = 32 / sizeof(DST_T);
-        size_t tail = num_elem % elem_per_block;
-        size_t len = num_elem & ~(elem_per_block - 1);
-        if (len > 0) {
-            DataCopy(dst_gm, dst_local, len);
-        }
-        if(tail != 0) {
-            for (size_t i = tail; i < elem_per_block; i++) {
-                dst_local[len + i].SetValue(0, 0);
-            }
-            SetAtomicAdd();
-            DataCopy(dst_gm[len], dst_local[len], elem_per_block);
-            SetAtomicNone();
-        }
-#else
-        DataCopyExtParams dataCopyParams;
-        dataCopyParams.blockCount = 1;
-        dataCopyParams.blockLen = num_elem * sizeof(DST_T);
-        DataCopyPad(dst_gm, dst_local, dataCopyParams);
-#endif
-        dst_queue.FreeTensor(dst_local);
-    }
-
-    __aicore__ inline void dup() {
-        // main process, copy one row data from src to dst.
-        copy_in();
-
-        LocalTensor src_local = src_queue.DeQue();
-        LocalTensor dst_local = dst_queue.AllocTensor();
-
-        int32_t BLOCK_NUM = 32 / sizeof(DST_T);
-        DataCopy(dst_local, src_local, (num_elem + BLOCK_NUM - 1)
-                                        / BLOCK_NUM * BLOCK_NUM);
-        dst_queue.EnQue(dst_local);
-
-        src_queue.FreeTensor(src_local);
-        copy_out();
-    }
-
-    __aicore__ inline void dup_with_cast() {
-        // main process, copy one row data from src to dst.
-        // cast dtype from src to dst.
-        copy_in();
-
-        LocalTensor src_local = src_queue.DeQue();
-        LocalTensor dst_local = dst_queue.AllocTensor();
-
-        Cast(dst_local, src_local, RoundMode::CAST_NONE, num_elem);
-        dst_queue.EnQue(dst_local);
-
-        src_queue.FreeTensor(src_local);
-        copy_out();
-    }
-
-   private:
-
-    TPipe pipe;
-    GlobalTensor src_gm;
-    GlobalTensor dst_gm;
-
-    int64_t num_rows;
-    int64_t num_elem;
-    int64_t idx_ne3;
-    int64_t idx_ne2;
-    int64_t idx_ne1;
-    int64_t src_stride;
-    int64_t dst_stride;
-
-    TQue src_queue;
-    TQue dst_queue;
-};
-
-template 
-__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
-    auto gm_ptr = (__gm__ uint8_t *)gm;
-    auto ub_ptr = (uint8_t *)(ub);
-    for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
-        *ub_ptr = *gm_ptr;
-    }
-}
-
-extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp16(
-                                                        GM_ADDR src_gm,
-                                                        GM_ADDR dst_gm,
-                                                        GM_ADDR input_ne_gm,
-                                                        GM_ADDR input_nb_gm,
-                                                        GM_ADDR output_ne_gm,
-                                                        GM_ADDR output_nb_gm) {
-
-    int64_t input_ne_ub[4];
-    size_t input_nb_ub[4];
-    int64_t output_ne_ub[4];
-    size_t output_nb_ub[4];
-
-    copy_to_ub(input_ne_gm, input_ne_ub, 32);
-    copy_to_ub(input_nb_gm, input_nb_ub, 32);
-    copy_to_ub(output_ne_gm, output_ne_ub, 32);
-    copy_to_ub(output_nb_gm, output_nb_ub, 32);
-
-    DupByRows op;
-    op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
-    op.dup();
-}
-
-extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp32(
-                                                        GM_ADDR src_gm,
-                                                        GM_ADDR dst_gm,
-                                                        GM_ADDR input_ne_gm,
-                                                        GM_ADDR input_nb_gm,
-                                                        GM_ADDR output_ne_gm,
-                                                        GM_ADDR output_nb_gm) {
-    int64_t input_ne_ub[4];
-    size_t input_nb_ub[4];
-    int64_t output_ne_ub[4];
-    size_t output_nb_ub[4];
-
-    copy_to_ub(input_ne_gm, input_ne_ub, 32);
-    copy_to_ub(input_nb_gm, input_nb_ub, 32);
-    copy_to_ub(output_ne_gm, output_ne_ub, 32);
-    copy_to_ub(output_nb_gm, output_nb_ub, 32);
-
-    DupByRows op;
-    op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
-    op.dup();
-}
-
-extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp32_to_fp16(
-                                                        GM_ADDR src_gm,
-                                                        GM_ADDR dst_gm,
-                                                        GM_ADDR input_ne_gm,
-                                                        GM_ADDR input_nb_gm,
-                                                        GM_ADDR output_ne_gm,
-                                                        GM_ADDR output_nb_gm) {
-
-    int64_t input_ne_ub[4];
-    size_t input_nb_ub[4];
-    int64_t output_ne_ub[4];
-    size_t output_nb_ub[4];
-
-    copy_to_ub(input_ne_gm, input_ne_ub, 32);
-    copy_to_ub(input_nb_gm, input_nb_ub, 32);
-    copy_to_ub(output_ne_gm, output_ne_ub, 32);
-    copy_to_ub(output_nb_gm, output_nb_ub, 32);
-
-    DupByRows op;
-    op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
-    op.dup_with_cast();
-}
-
-extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp16_to_fp32(
-                                                        GM_ADDR src_gm,
-                                                        GM_ADDR dst_gm,
-                                                        GM_ADDR input_ne_gm,
-                                                        GM_ADDR input_nb_gm,
-                                                        GM_ADDR output_ne_gm,
-                                                        GM_ADDR output_nb_gm) {
-
-    // copy params from gm to ub.
-    int64_t input_ne_ub[4];
-    size_t input_nb_ub[4];
-    int64_t output_ne_ub[4];
-    size_t output_nb_ub[4];
-
-    copy_to_ub(input_ne_gm, input_ne_ub, 32);
-    copy_to_ub(input_nb_gm, input_nb_ub, 32);
-    copy_to_ub(output_ne_gm, output_ne_ub, 32);
-    copy_to_ub(output_nb_gm, output_nb_ub, 32);
-
-    DupByRows op;
-    op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
-    op.dup_with_cast();
-}
diff --git a/ggml/src/ggml-cann/kernels/get_row_f16.cpp b/ggml/src/ggml-cann/kernels/get_row_f16.cpp
deleted file mode 100644
index 416b45104de..00000000000
--- a/ggml/src/ggml-cann/kernels/get_row_f16.cpp
+++ /dev/null
@@ -1,197 +0,0 @@
-#include "kernel_operator.h"
-
-// optimize me. Use template to avoid copy code.
-using namespace AscendC;
-
-#define BUFFER_NUM 2
-
-class GET_ROW_F16 {
-   public:
-    __aicore__ inline GET_ROW_F16() {}
-    __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output,
-                                int64_t *input_ne_ub, size_t *input_nb_ub,
-                                int64_t *indices_ne_ub, size_t *indices_nb_ub,
-                                int64_t *output_ne_ub, size_t *output_nb_ub) {
-        // TODO, use template for F16/f32
-        int64_t op_block_num = GetBlockNum();
-        op_block_idx = GetBlockIdx();
-
-        for (int i = 0; i < 4; i++) {
-            input_ne[i] = input_ne_ub[i];
-            input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
-
-            indices_ne[i] = indices_ne_ub[i];
-            indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0];
-
-            output_ne[i] = output_ne_ub[i];
-            output_stride[i] = output_nb_ub[i] / output_nb_ub[0];
-        }
-
-        // Indices has two dims. n_elements = all rows should get.
-        // dr, all rows should this thread get.
-        uint64_t n_elements =
-            indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3];
-        dr = n_elements / op_block_num;
-
-        uint64_t tails = n_elements % op_block_num;
-        if (op_block_idx < tails) {
-            dr += 1;
-            ir = dr * op_block_idx;
-        } else {
-            ir = dr * op_block_idx + tails;
-        }
-
-        input_gm.SetGlobalBuffer((__gm__ half *)input);
-        indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices);
-        output_gm.SetGlobalBuffer((__gm__ float *)output);
-
-        uint64_t input_local_buffer_size = ((input_ne[0] * sizeof(half) + 31)
-                                             & ~31);
-        uint64_t output_local_buffer_size = ((input_ne[0] * sizeof(float) + 31)
-                                              & ~31);
-
-        local_buffer_elems = input_local_buffer_size / sizeof(half);
-
-        // TODO, consider long row that can't put in UB.
-        // All data should asign to 32. It's ok because all data is align to 32.
-        pipe.InitBuffer(input_queue, BUFFER_NUM, input_local_buffer_size);
-        pipe.InitBuffer(output_queue, BUFFER_NUM, output_local_buffer_size);
-    }
-
-    __aicore__ inline void copy_in(uint32_t offset, size_t len) {
-        size_t origin_len = len;
-        LocalTensor input_local = input_queue.AllocTensor();
-        const size_t elem_per_block = 32 / sizeof(half);
-        size_t tail = len % elem_per_block;
-        len = len & ~(elem_per_block - 1);
-        if(tail != 0) {
-            len += elem_per_block;
-        }
-        DataCopy(input_local, input_gm[offset], len);
-        input_queue.EnQue(input_local);
-    }
-
-    __aicore__ inline void copy_out(uint32_t offset, size_t len) {
-        LocalTensor output_local = output_queue.DeQue();
-        const size_t elem_per_block = 32 / sizeof(float);
-        size_t tail = len % elem_per_block;
-        len = len & ~(elem_per_block - 1);
-        if (len > 0) {
-            DataCopy(output_gm[offset], output_local, len);
-        }
-
-        if(tail != 0) {
-#ifdef ASCEND_310P
-            for (size_t i = tail; i < elem_per_block; i++) {
-                output_local[len + i].SetValue(0, 0);
-            }
-            SetAtomicAdd();
-            DataCopy(output_gm[offset + len], output_local[len], elem_per_block);
-            SetAtomicNone();
-#else
-            DataCopyExtParams dataCopyParams;
-            dataCopyParams.blockCount = 1;
-            dataCopyParams.blockLen = tail * sizeof(float);
-            DataCopyPad(output_gm[offset + len], output_local[len],
-                        dataCopyParams);
-#endif
-        }
-        output_queue.FreeTensor(output_local);
-    }
-
-    __aicore__ inline void calculate_row(int64_t idx) {
-        const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]);
-        const int64_t indices_ne1_idx =
-            (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) /
-            indices_ne[0];
-        const int64_t indices_ne0_idx =
-            (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] -
-             indices_ne1_idx * indices_ne[0]);
-
-        const int64_t indices_offset = indices_ne0_idx * indices_stride[0] +
-                                       indices_ne1_idx * indices_stride[1] +
-                                       indices_ne2_idx * indices_stride[2];
-        const int32_t selected_row_idx = indices_gm.GetValue(indices_offset);
-
-        const int64_t input_offset = selected_row_idx * input_stride[1] +
-                                     indices_ne1_idx * input_stride[2] +
-                                     indices_ne2_idx * input_stride[3];
-
-        const int64_t output_offset = indices_ne0_idx * output_stride[1] +
-                                      indices_ne1_idx * output_stride[2] +
-                                      indices_ne2_idx * output_stride[3];
-
-        copy_in(input_offset, input_ne[0]);
-        LocalTensor input_local = input_queue.DeQue();
-        LocalTensor output_local = output_queue.AllocTensor();
-
-        Cast(output_local, input_local, RoundMode::CAST_NONE,
-             local_buffer_elems);
-        output_queue.EnQue(output_local);
-        copy_out(output_offset, input_ne[0]);
-
-        input_queue.FreeTensor(input_local);
-    }
-
-    __aicore__ inline void calculate() {
-        for (int64_t i = ir; i < ir + dr; i++) {
-            calculate_row(i);
-        }
-    }
-
-   private:
-    int64_t input_ne[4];
-    size_t input_stride[4];
-
-    int64_t indices_ne[4];
-    size_t indices_stride[4];
-
-    int64_t output_ne[4];
-    size_t output_stride[4];
-
-    size_t local_buffer_elems;
-
-    int64_t ir;
-    int64_t dr;
-
-    TPipe pipe;
-    GlobalTensor input_gm;
-    GlobalTensor indices_gm;
-    GlobalTensor output_gm;
-    TQue input_queue;
-    TQue output_queue;
-    int64_t op_block_idx;
-};
-
-template <typename T>
-__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
-    auto gm_ptr = (__gm__ uint8_t *)gm;
-    auto ub_ptr = (uint8_t *)(ub);
-    for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
-        *ub_ptr = *gm_ptr;
-    }
-}
-
-extern "C" __global__ __aicore__ void ascendc_get_row_f16(
-    GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
-    GM_ADDR input_ne_gm, GM_ADDR input_nb_gm, GM_ADDR indices_ne_gm,
-    GM_ADDR indices_nb_gm, GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
-    int64_t input_ne_ub[4];
-    size_t input_nb_ub[4];
-    int64_t indices_ne_ub[4];
-    size_t indices_nb_ub[4];
-    int64_t output_ne_ub[4];
-    size_t output_nb_ub[4];
-
-    copy_to_ub(input_ne_gm, input_ne_ub, 32);
-    copy_to_ub(input_nb_gm, input_nb_ub, 32);
-    copy_to_ub(indices_ne_gm, indices_ne_ub, 32);
-    copy_to_ub(indices_nb_gm, indices_nb_ub, 32);
-    copy_to_ub(output_ne_gm, output_ne_ub, 32);
-    copy_to_ub(output_nb_gm, output_nb_ub, 32);
-
-    GET_ROW_F16 op;
-    op.init(input_gm, indices_gm, output_gm, input_ne_ub, input_nb_ub,
-            indices_ne_ub, indices_nb_ub, output_ne_ub, output_nb_ub);
-    op.calculate();
-}
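
All of the deleted CANN kernels split their work across AI cores with the same `dr`/`ir` arithmetic seen in `init()` above: each core gets `n / block_num` rows, the first `n % block_num` cores take one extra, and `ir` gives the starting row. A minimal host-side C sketch of that split (hypothetical helper, not part of the patch):

```c
#include <stdint.h>

/*
 * Mirrors the dr/ir/tails arithmetic in the deleted kernels' init() methods
 * (hypothetical helper for illustration): core `block_idx` out of `block_num`
 * owns the half-open row range [*ir, *ir + *dr).
 */
static void partition_rows(int64_t n_rows, int64_t block_num, int64_t block_idx,
                           int64_t *ir, int64_t *dr) {
    const int64_t per_core = n_rows / block_num;  // base share per core
    const int64_t tails    = n_rows % block_num;  // cores that take one extra row

    if (block_idx < tails) {
        *dr = per_core + 1;
        *ir = *dr * block_idx;
    } else {
        *dr = per_core;
        *ir = per_core * block_idx + tails;
    }
}
```

For example, 10 rows over 4 cores yields the ranges [0,3), [3,6), [6,8) and [8,10): disjoint, and covering every row.
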
diff --git a/ggml/src/ggml-cann/kernels/get_row_f32.cpp b/ggml/src/ggml-cann/kernels/get_row_f32.cpp
deleted file mode 100644
index 02116905b18..00000000000
--- a/ggml/src/ggml-cann/kernels/get_row_f32.cpp
+++ /dev/null
@@ -1,190 +0,0 @@
-#include "kernel_operator.h"
-
-// optimize me: use a template to avoid duplicated code.
-using namespace AscendC;
-
-#define BUFFER_NUM 2
-
-class GET_ROW_F32 {
-   public:
-    __aicore__ inline GET_ROW_F32() {}
-    __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output,
-                                int64_t *input_ne_ub, size_t *input_nb_ub,
-                                int64_t *indices_ne_ub, size_t *indices_nb_ub,
-                                int64_t *output_ne_ub, size_t *output_nb_ub) {
-        int64_t op_block_num = GetBlockNum();
-        op_block_idx = GetBlockIdx();
-
-        for (int i = 0; i < 4; i++) {
-            input_ne[i] = input_ne_ub[i];
-            input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
-
-            indices_ne[i] = indices_ne_ub[i];
-            indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0];
-
-            output_ne[i] = output_ne_ub[i];
-            output_stride[i] = output_nb_ub[i] / output_nb_ub[0];
-        }
-
-        // Indices has two dims. n_elements = total number of rows to gather.
-        // dr = number of rows this block handles.
-        uint64_t n_elements =
-            indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3];
-        dr = n_elements / op_block_num;
-
-        uint64_t tails = n_elements % op_block_num;
-        if (op_block_idx < tails) {
-            dr += 1;
-            ir = dr * op_block_idx;
-        } else {
-            ir = dr * op_block_idx + tails;
-        }
-
-        input_gm.SetGlobalBuffer((__gm__ float *)input);
-        indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices);
-        output_gm.SetGlobalBuffer((__gm__ float *)output);
-
-        uint64_t local_buffer_size = ((input_ne[0] * sizeof(float) + 31) & ~31);
-        local_buffer_elems = local_buffer_size / sizeof(float);
-
-        // TODO: handle rows that are too long to fit in the UB.
-        // All sizes are rounded up to 32 bytes; this is fine because the data is 32-byte aligned.
-        pipe.InitBuffer(input_queue, BUFFER_NUM, local_buffer_size);
-        pipe.InitBuffer(output_queue, BUFFER_NUM, local_buffer_size);
-    }
-
-    __aicore__ inline void copy_in(uint32_t offset, size_t len) {
-        LocalTensor input_local = input_queue.AllocTensor();
-        const size_t elem_per_block = 32 / sizeof(float);
-        size_t tail = len % elem_per_block;
-        len = len & ~(elem_per_block - 1);
-        if(tail != 0) {
-            len += elem_per_block;
-        }
-        DataCopy(input_local, input_gm[offset], len);
-        input_queue.EnQue(input_local);
-    }
-
-    __aicore__ inline void copy_out(uint32_t offset, size_t len) {
-        LocalTensor output_local = output_queue.DeQue();
-        const size_t elem_per_block = 32 / sizeof(float);
-        size_t tail = len % elem_per_block;
-        len = len & ~(elem_per_block - 1);
-        if (len > 0) {
-            DataCopy(output_gm[offset], output_local, len);
-        }
-
-        if(tail != 0) {
-#ifdef ASCEND_310P
-            for (size_t i = tail; i < elem_per_block; i++) {
-                output_local[len + i].SetValue(0, 0);
-            }
-            SetAtomicAdd();
-            DataCopy(output_gm[offset + len], output_local[len], elem_per_block);
-            SetAtomicNone();
-#else
-            DataCopyExtParams dataCopyParams;
-            dataCopyParams.blockCount = 1;
-            dataCopyParams.blockLen = tail * sizeof(float);
-            DataCopyPad(output_gm[offset + len], output_local[len],
-                        dataCopyParams);
-#endif
-        }
-        output_queue.FreeTensor(output_local);
-    }
-
-    __aicore__ inline void calculate_row(int64_t idx) {
-        const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]);
-        const int64_t indices_ne1_idx =
-            (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) /
-            indices_ne[0];
-        const int64_t indices_ne0_idx =
-            (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] -
-             indices_ne1_idx * indices_ne[0]);
-
-        const int64_t indices_offset = indices_ne0_idx * indices_stride[0] +
-                                       indices_ne1_idx * indices_stride[1] +
-                                       indices_ne2_idx * indices_stride[2];
-        const int32_t selected_row_idx = indices_gm.GetValue(indices_offset);
-
-        const int64_t input_offset = selected_row_idx * input_stride[1] +
-                                     indices_ne1_idx * input_stride[2] +
-                                     indices_ne2_idx * input_stride[3];
-
-        const int64_t output_offset = indices_ne0_idx * output_stride[1] +
-                                      indices_ne1_idx * output_stride[2] +
-                                      indices_ne2_idx * output_stride[3];
-
-        copy_in(input_offset, input_ne[0]);
-        LocalTensor input_local = input_queue.DeQue();
-        LocalTensor output_local = output_queue.AllocTensor();
-
-        DataCopy(output_local, input_local, local_buffer_elems);
-        output_queue.EnQue(output_local);
-        copy_out(output_offset, input_ne[0]);
-
-        input_queue.FreeTensor(input_local);
-    }
-
-    __aicore__ inline void calculate() {
-        for (int64_t i = ir; i < ir + dr; i++) {
-            calculate_row(i);
-        }
-    }
-
-   private:
-    int64_t input_ne[4];
-    size_t input_stride[4];
-
-    int64_t indices_ne[4];
-    size_t indices_stride[4];
-
-    int64_t output_ne[4];
-    size_t output_stride[4];
-
-    size_t local_buffer_elems;
-
-    int64_t ir;
-    int64_t dr;
-
-    TPipe pipe;
-    GlobalTensor input_gm;
-    GlobalTensor indices_gm;
-    GlobalTensor output_gm;
-    TQue input_queue;
-    TQue output_queue;
-    int64_t op_block_idx;
-};
-
-template <typename T>
-__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
-    auto gm_ptr = (__gm__ uint8_t *)gm;
-    auto ub_ptr = (uint8_t *)(ub);
-    for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
-        *ub_ptr = *gm_ptr;
-    }
-}
-
-extern "C" __global__ __aicore__ void ascendc_get_row_f32(
-    GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
-    GM_ADDR input_ne_gm, GM_ADDR input_nb_gm, GM_ADDR indices_ne_gm,
-    GM_ADDR indices_nb_gm, GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
-    int64_t input_ne_ub[4];
-    size_t input_nb_ub[4];
-    int64_t indices_ne_ub[4];
-    size_t indices_nb_ub[4];
-    int64_t output_ne_ub[4];
-    size_t output_nb_ub[4];
-
-    copy_to_ub(input_ne_gm, input_ne_ub, 32);
-    copy_to_ub(input_nb_gm, input_nb_ub, 32);
-    copy_to_ub(indices_ne_gm, indices_ne_ub, 32);
-    copy_to_ub(indices_nb_gm, indices_nb_ub, 32);
-    copy_to_ub(output_ne_gm, output_ne_ub, 32);
-    copy_to_ub(output_nb_gm, output_nb_ub, 32);
-
-    GET_ROW_F32 op;
-    op.init(input_gm, indices_gm, output_gm, input_ne_ub, input_nb_ub,
-            indices_ne_ub, indices_nb_ub, output_ne_ub, output_nb_ub);
-    op.calculate();
-}
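
Logically, the deleted `get_row_f32` kernel (and its F16 sibling above) implements GGML_OP_GET_ROWS: every entry of the index tensor selects one `ne0`-wide source row, which is copied to the corresponding output row. A contiguous-tensor CPU sketch of that gather, using a hypothetical helper name:

```c
#include <stdint.h>
#include <string.h>

/* CPU reference for the gather the kernel implements (GGML_OP_GET_ROWS on
 * contiguous F32 data): dst row i is a copy of src row idx[i]. Hypothetical
 * helper, for illustration only. */
static void get_rows_f32_ref(const float   *src,  // [n_src_rows][ne0]
                             const int32_t *idx,  // [n_idx]
                             float         *dst,  // [n_idx][ne0]
                             int64_t n_idx, int64_t ne0) {
    for (int64_t i = 0; i < n_idx; ++i) {
        memcpy(dst + i * ne0, src + (int64_t) idx[i] * ne0, (size_t) ne0 * sizeof(float));
    }
}
```
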
diff --git a/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp b/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp
deleted file mode 100644
index 4fbe722086c..00000000000
--- a/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp
+++ /dev/null
@@ -1,204 +0,0 @@
-#include "kernel_operator.h"
-
-// optimize me: use a template to avoid duplicated code.
-using namespace AscendC;
-#ifdef ASCEND_310P // 310P does not support 4-bit get_row
-    extern "C" __global__ __aicore__ void ascendc_get_row_q4_0(
-        GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
-        GM_ADDR input_ne_gm, GM_ADDR indices_ne_gm, GM_ADDR indices_nb_gm,
-        GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
-        // Let the following test cases continue to run by only printing an error here. Of course, the test case that calls this operator will fail.
-        printf("Ascend310P not support 4bit get row.\n");
-    }
-#else
-
-#define BUFFER_NUM 2
-
-#define QK4_0 32
-
-class GET_ROW_Q4_0 {
-   public:
-    __aicore__ inline GET_ROW_Q4_0() {}
-    __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output,
-                                int64_t *input_ne_ub, int64_t *indices_ne_ub,
-                                size_t *indices_nb_ub, int64_t *output_ne_ub,
-                                size_t *output_nb_ub) {
-        int64_t op_block_num = GetBlockNum();
-        int64_t op_block_idx = GetBlockIdx();
-
-        for (int i = 0; i < 4; i++) {
-            input_ne[i] = input_ne_ub[i];
-            indices_ne[i] = indices_ne_ub[i];
-            indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0];
-            scale_ne[i] = input_ne_ub[i];
-            output_ne[i] = output_ne_ub[i];
-            output_stride[i] = output_nb_ub[i] / output_nb_ub[0];
-        }
-
-        // one scale for a group.
-        scale_ne[0] /= QK4_0;
-
-        input_stride[0] = 1;
-        scale_stride[0] = 1;
-        output_stride[0] = 1;
-        for (int i = 1; i < 4; i++) {
-            input_stride[i] = input_stride[i - 1] * input_ne[i - 1];
-            scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
-        }
-
-        group_size_in_row = input_ne[0] / QK4_0;
-        int64_t scale_offset = input_ne[0] * input_ne[1] * input_ne[2] *
-                               input_ne[3] / 2;
-
-        // Indices has two dims. n_elements = total number of rows to gather.
-        // dr = number of rows this block handles.
-        uint64_t n_elements =
-            indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3];
-        dr = n_elements / op_block_num;
-
-        uint64_t tails = n_elements % op_block_num;
-        if (op_block_idx < tails) {
-            dr += 1;
-            ir = dr * op_block_idx;
-        } else {
-            ir = dr * op_block_idx + tails;
-        }
-
-        input_gm.SetGlobalBuffer((__gm__ int4b_t *)input);
-        scale_gm.SetGlobalBuffer((__gm__ half *)(input + scale_offset));
-        indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices);
-        output_gm.SetGlobalBuffer((__gm__ float *)output);
-
-        pipe.InitBuffer(input_queue, BUFFER_NUM, QK4_0 * sizeof(int4b_t));
-        pipe.InitBuffer(cast_queue, BUFFER_NUM, QK4_0 * sizeof(half));
-        pipe.InitBuffer(output_queue, BUFFER_NUM, QK4_0 * sizeof(float));
-    }
-
-    __aicore__ inline void copy_in(uint32_t offset) {
-        LocalTensor input_local = input_queue.AllocTensor();
-        // 32 * sizeof(int4b_t) = 16, which is not aligned to 32, why no error?
-        DataCopy(input_local, input_gm[offset], QK4_0);
-        input_queue.EnQue(input_local);
-    }
-
-    __aicore__ inline void copy_out(uint32_t offset) {
-        LocalTensor output_local = output_queue.DeQue();
-        DataCopy(output_gm[offset], output_local, QK4_0);
-        output_queue.FreeTensor(output_local);
-    }
-
-    __aicore__ inline void calculate_group(int64_t idx, int64_t group) {
-        const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]);
-        const int64_t indices_ne1_idx =
-            (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) /
-            indices_ne[0];
-        const int64_t indices_ne0_idx =
-            (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] -
-             indices_ne1_idx * indices_ne[0]);
-
-        const int64_t indices_offset = indices_ne0_idx * indices_stride[0] +
-                                       indices_ne1_idx * indices_stride[1] +
-                                       indices_ne2_idx * indices_stride[2];
-        const int32_t selected_row_idx = indices_gm.GetValue(indices_offset);
-
-        const int64_t input_offset = selected_row_idx * input_stride[1] +
-                                     indices_ne1_idx * input_stride[2] +
-                                     indices_ne2_idx * input_stride[3] +
-                                     group * QK4_0;
-        const int64_t scale_offset = selected_row_idx * scale_stride[1] +
-                                     indices_ne1_idx * scale_stride[2] +
-                                     indices_ne2_idx * scale_stride[3] + group;
-        const int64_t output_offset = indices_ne0_idx * output_stride[1] +
-                                      indices_ne1_idx * output_stride[2] +
-                                      indices_ne2_idx * output_stride[3] +
-                                      group * QK4_0;
-
-        copy_in(input_offset);
-        LocalTensor input_local = input_queue.DeQue();
-        LocalTensor cast_local = cast_queue.AllocTensor();
-        LocalTensor output_local = output_queue.AllocTensor();
-
-        // TODO: cast more data to speed up.
-        Cast(cast_local, input_local, RoundMode::CAST_NONE, QK4_0);
-        Cast(output_local, cast_local, RoundMode::CAST_NONE, QK4_0);
-
-        // Only the multiplication needs to be computed per group.
-        half scale = scale_gm.GetValue(scale_offset);
-
-        Muls(output_local, output_local, (float)scale, QK4_0);
-
-        input_queue.FreeTensor(input_local);
-        cast_queue.FreeTensor(cast_local);
-        output_queue.EnQue(output_local);
-
-        copy_out(output_offset);
-    }
-
-    __aicore__ inline void calculate() {
-        for (int64_t i = ir; i < ir + dr; i++) {
-            for (int64_t j = 0; j < group_size_in_row; j++) {
-                calculate_group(i, j);
-            }
-        }
-    }
-
-   private:
-    int64_t input_ne[4];
-    size_t input_stride[4];
-
-    int64_t scale_ne[4];
-    size_t scale_stride[4];
-
-    int64_t indices_ne[4];
-    size_t indices_stride[4];
-
-    int64_t output_ne[4];
-    size_t output_stride[4];
-
-    int64_t ir;
-    int64_t dr;
-
-    int64_t group_size_in_row;
-
-    TPipe pipe;
-    GlobalTensor input_gm;
-    GlobalTensor scale_gm;
-    GlobalTensor indices_gm;
-    GlobalTensor output_gm;
-    TQue input_queue;
-    TQue output_queue;
-    TQue cast_queue;
-};
-
-template <typename T>
-__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
-    auto gm_ptr = (__gm__ uint8_t *)gm;
-    auto ub_ptr = (uint8_t *)(ub);
-    for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
-        *ub_ptr = *gm_ptr;
-    }
-}
-
-extern "C" __global__ __aicore__ void ascendc_get_row_q4_0(
-    GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
-    GM_ADDR input_ne_gm, GM_ADDR indices_ne_gm, GM_ADDR indices_nb_gm,
-    GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
-    int64_t input_ne_ub[4];
-    int64_t indices_ne_ub[4];
-    size_t indices_nb_ub[4];
-    int64_t output_ne_ub[4];
-    size_t output_nb_ub[4];
-
-    copy_to_ub(input_ne_gm, input_ne_ub, 32);
-    copy_to_ub(indices_ne_gm, indices_ne_ub, 32);
-    copy_to_ub(indices_nb_gm, indices_nb_ub, 32);
-    copy_to_ub(output_ne_gm, output_ne_ub, 32);
-    copy_to_ub(output_nb_gm, output_nb_ub, 32);
-
-    GET_ROW_Q4_0 op;
-    op.init(input_gm, indices_gm, output_gm, input_ne_ub, indices_ne_ub,
-            indices_nb_ub, output_ne_ub, output_nb_ub);
-    op.calculate();
-}
-
-#endif // #ifdef ASCEND_310P
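
The deleted `get_row_q4_0` kernel reads a CANN-specific split layout in which the per-group F16 scales follow all of the 4-bit payload (that is what the `scale_offset = ne / 2` computation points at). The operation itself is still the q4_0 decode: 32 codes share one scale and each code maps to a small signed integer times that scale. A hedged CPU sketch using the CPU backend's conventions; the transformed byte layout the device kernel actually reads may differ:

```c
#include <stdint.h>

#define QK4_0 32

/* Dequantize one QK4_0 group given its packed nibbles and an already-converted
 * F32 scale, using the CPU q4_0 convention: value = (code - 8) * scale, with
 * the low nibbles holding the first half of the group. Illustrative only. */
static void dequantize_q4_0_group(const uint8_t *qs,    // QK4_0/2 packed nibble codes
                                  float          scale, // per-group scale (F16 on device)
                                  float         *dst) { // QK4_0 output floats
    for (int j = 0; j < QK4_0 / 2; ++j) {
        const int lo = (qs[j] & 0x0F) - 8;
        const int hi = (qs[j] >> 4)   - 8;
        dst[j]             = lo * scale;
        dst[j + QK4_0 / 2] = hi * scale;
    }
}
```
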
diff --git a/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp b/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp
deleted file mode 100644
index ba9ab3c0483..00000000000
--- a/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp
+++ /dev/null
@@ -1,191 +0,0 @@
-#include "kernel_operator.h"
-
-// optimize me: use a template to avoid duplicated code.
-using namespace AscendC;
-
-#define BUFFER_NUM 2
-
-#define QK8_0 32
-
-class GET_ROW_Q8_0 {
-   public:
-    __aicore__ inline GET_ROW_Q8_0() {}
-    __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output,
-                                int64_t *input_ne_ub, int64_t *indices_ne_ub,
-                                size_t *indices_nb_ub, int64_t *output_ne_ub,
-                                size_t *output_nb_ub) {
-        int64_t op_block_num = GetBlockNum();
-        int64_t op_block_idx = GetBlockIdx();
-
-        for (int i = 0; i < 4; i++) {
-            input_ne[i] = input_ne_ub[i];
-            indices_ne[i] = indices_ne_ub[i];
-            indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0];
-            scale_ne[i] = input_ne_ub[i];
-            output_ne[i] = output_ne_ub[i];
-            output_stride[i] = output_nb_ub[i] / output_nb_ub[0];
-        }
-
-        // one scale for a group.
-        scale_ne[0] /= QK8_0;
-
-        input_stride[0] = 1;
-        scale_stride[0] = 1;
-        output_stride[0] = 1;
-        for (int i = 1; i < 4; i++) {
-            input_stride[i] = input_stride[i - 1] * input_ne[i - 1];
-            scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
-        }
-
-        group_size_in_row = input_ne[0] / QK8_0;
-        int64_t scale_offset = input_ne[0] * input_ne[1] * input_ne[2] *
-                               input_ne[3] * sizeof(int8_t);
-
-        // Indices has two dims. n_elements = total number of rows to gather.
-        // dr = number of rows this block handles.
-        uint64_t n_elements =
-            indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3];
-        dr = n_elements / op_block_num;
-
-        uint64_t tails = n_elements % op_block_num;
-        if (op_block_idx < tails) {
-            dr += 1;
-            ir = dr * op_block_idx;
-        } else {
-            ir = dr * op_block_idx + tails;
-        }
-
-        input_gm.SetGlobalBuffer((__gm__ int8_t *)input);
-        scale_gm.SetGlobalBuffer((__gm__ half *)(input + scale_offset));
-        indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices);
-        output_gm.SetGlobalBuffer((__gm__ float *)output);
-
-        pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t));
-        pipe.InitBuffer(cast_queue, BUFFER_NUM, QK8_0 * sizeof(half));
-        pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(float));
-    }
-
-    __aicore__ inline void copy_in(uint32_t offset) {
-        LocalTensor input_local = input_queue.AllocTensor();
-        DataCopy(input_local, input_gm[offset], QK8_0);
-        input_queue.EnQue(input_local);
-    }
-
-    __aicore__ inline void copy_out(uint32_t offset) {
-        LocalTensor output_local = output_queue.DeQue();
-        DataCopy(output_gm[offset], output_local, QK8_0);
-        output_queue.FreeTensor(output_local);
-    }
-
-    __aicore__ inline void calculate_group(int64_t idx, int64_t group) {
-        const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]);
-        const int64_t indices_ne1_idx =
-            (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) /
-            indices_ne[0];
-        const int64_t indices_ne0_idx =
-            (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] -
-             indices_ne1_idx * indices_ne[0]);
-
-        const int64_t indices_offset = indices_ne0_idx * indices_stride[0] +
-                                       indices_ne1_idx * indices_stride[1] +
-                                       indices_ne2_idx * indices_stride[2];
-        const int32_t selected_row_idx = indices_gm.GetValue(indices_offset);
-
-        const int64_t input_offset = selected_row_idx * input_stride[1] +
-                                     indices_ne1_idx * input_stride[2] +
-                                     indices_ne2_idx * input_stride[3] +
-                                     group * QK8_0;
-        const int64_t scale_offset = selected_row_idx * scale_stride[1] +
-                                     indices_ne1_idx * scale_stride[2] +
-                                     indices_ne2_idx * scale_stride[3] + group;
-        const int64_t output_offset = indices_ne0_idx * output_stride[1] +
-                                      indices_ne1_idx * output_stride[2] +
-                                      indices_ne2_idx * output_stride[3] +
-                                      group * QK8_0;
-
-        copy_in(input_offset);
-        LocalTensor input_local = input_queue.DeQue();
-        LocalTensor cast_local = cast_queue.AllocTensor();
-        LocalTensor output_local = output_queue.AllocTensor();
-
-        // TODO: cast more data to speed up.
-        Cast(cast_local, input_local, RoundMode::CAST_NONE, QK8_0);
-        Cast(output_local, cast_local, RoundMode::CAST_NONE, QK8_0);
-
-        // Only the multiplication needs to be computed per group.
-        half scale = scale_gm.GetValue(scale_offset);
-        Muls(output_local, output_local, (float)scale, QK8_0);
-
-        input_queue.FreeTensor(input_local);
-        cast_queue.FreeTensor(cast_local);
-        output_queue.EnQue(output_local);
-
-        copy_out(output_offset);
-    }
-
-    __aicore__ inline void calculate() {
-        for (int64_t i = ir; i < ir + dr; i++) {
-            for (int64_t j = 0; j < group_size_in_row; j++) {
-                calculate_group(i, j);
-            }
-        }
-    }
-
-   private:
-    int64_t input_ne[4];
-    size_t input_stride[4];
-
-    int64_t scale_ne[4];
-    size_t scale_stride[4];
-
-    int64_t indices_ne[4];
-    size_t indices_stride[4];
-
-    int64_t output_ne[4];
-    size_t output_stride[4];
-
-    int64_t ir;
-    int64_t dr;
-
-    int64_t group_size_in_row;
-
-    TPipe pipe;
-    GlobalTensor input_gm;
-    GlobalTensor scale_gm;
-    GlobalTensor indices_gm;
-    GlobalTensor output_gm;
-    TQue input_queue;
-    TQue output_queue;
-    TQue cast_queue;
-};
-
-template <typename T>
-__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
-    auto gm_ptr = (__gm__ uint8_t *)gm;
-    auto ub_ptr = (uint8_t *)(ub);
-    for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
-        *ub_ptr = *gm_ptr;
-    }
-}
-
-extern "C" __global__ __aicore__ void ascendc_get_row_q8_0(
-    GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
-    GM_ADDR input_ne_gm, GM_ADDR indices_ne_gm, GM_ADDR indices_nb_gm,
-    GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
-    int64_t input_ne_ub[4];
-    int64_t indices_ne_ub[4];
-    size_t indices_nb_ub[4];
-    int64_t output_ne_ub[4];
-    size_t output_nb_ub[4];
-
-    copy_to_ub(input_ne_gm, input_ne_ub, 32);
-    copy_to_ub(indices_ne_gm, indices_ne_ub, 32);
-    copy_to_ub(indices_nb_gm, indices_nb_ub, 32);
-    copy_to_ub(output_ne_gm, output_ne_ub, 32);
-    copy_to_ub(output_nb_gm, output_nb_ub, 32);
-
-    GET_ROW_Q8_0 op;
-    op.init(input_gm, indices_gm, output_gm, input_ne_ub, indices_ne_ub,
-            indices_nb_ub, output_ne_ub, output_nb_ub);
-    op.calculate();
-}
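
`get_row_q8_0` follows the same split-layout pattern with 8-bit payloads (`scale_offset = ne * sizeof(int8_t)`), and the decode is simply the signed byte times its group scale. A hedged sketch of one group:

```c
#include <stdint.h>

#define QK8_0 32

/* Dequantize one QK8_0 group: each signed byte times the shared group scale.
 * Hypothetical reference helper; the device kernel does the same thing via an
 * int8 -> half -> float cast followed by Muls. */
static void dequantize_q8_0_group(const int8_t *qs,    // QK8_0 signed byte codes
                                  float         scale, // per-group scale (F16 on device)
                                  float        *dst) { // QK8_0 output floats
    for (int j = 0; j < QK8_0; ++j) {
        dst[j] = qs[j] * scale;
    }
}
```
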
diff --git a/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp b/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp
deleted file mode 100644
index 504b43afaa1..00000000000
--- a/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp
+++ /dev/null
@@ -1,218 +0,0 @@
-#include "kernel_operator.h"
-
-using namespace AscendC;
-#ifdef ASCEND_310P
-    extern "C" __global__ __aicore__ void ascendc_quantize_f16_q8_0(
-        GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
-        GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
-        // Let the following test cases continue to run by only printing an error here. Of course, the test case that calls this operator will fail.
-        printf("Ascend310P not support f16->8bit quantization.\n");
-    }
-#else
-
-#define BUFFER_NUM 2
-#define QK8_0 32
-
-class QUANTIZE_F16_Q8_0 {
-   public:
-    __aicore__ inline QUANTIZE_F16_Q8_0() {}
-    __aicore__ inline void init(GM_ADDR input, GM_ADDR output,
-                                int64_t *input_ne_ub, size_t *input_nb_ub,
-                                int64_t *output_ne_ub) {
-        int64_t op_block_num = GetBlockNum();
-        int64_t op_block_idx = GetBlockIdx();
-
-        for (int i = 0; i < 4; i++) {
-            input_ne[i] = input_ne_ub[i];
-            input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
-
-            output_ne[i] = output_ne_ub[i];
-        }
-
-        output_stride[0] = 1;
-        for (int i = 1; i < 4; i++) {
-            output_stride[i] = output_stride[i - 1] * output_ne[i - 1];
-        }
-
-        scale_ne = input_ne;
-        scale_stride[0] = 1;
-        scale_stride[1] = input_ne[0] / QK8_0;
-        for (int i = 2; i < 4; i++) {
-            scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
-        }
-
-        // split input tensor by rows.
-        uint64_t nr = input_ne[1] * input_ne[2] * input_ne[3];
-        dr = nr / op_block_num;
-
-        uint64_t tails = nr % op_block_num;
-        if (op_block_idx < tails) {
-            dr += 1;
-            ir = dr * op_block_idx;
-        } else {
-            ir = dr * op_block_idx + tails;
-        }
-
-        group_size_in_row = scale_stride[1];
-        int64_t output_size = output_ne[0] * output_ne[1] * output_ne[2] *
-                              output_ne[3] * sizeof(uint8_t);
-
-        input_gm.SetGlobalBuffer((__gm__ half *)input);
-        output_gm.SetGlobalBuffer((__gm__ int8_t *)output);
-        scale_gm.SetGlobalBuffer((__gm__ half *)(output + output_size + ir *
-                                                 group_size_in_row *
-                                                 sizeof(half)));
-
-        pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(half));
-        pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t));
-        pipe.InitBuffer(work_queue, 1, 32);
-        pipe.InitBuffer(max_queue, 1, 32);
-        pipe.InitBuffer(abs_queue, 1, QK8_0 * sizeof(float));
-        pipe.InitBuffer(scale_queue, 1, 32);
-        pipe.InitBuffer(cast_queue ,1 ,QK8_0 * sizeof(float));
-    }
-
-    __aicore__ inline void copy_in(uint32_t offset) {
-        LocalTensor input_local = input_queue.AllocTensor();
-        DataCopy(input_local, input_gm[offset], QK8_0);
-        input_queue.EnQue(input_local);
-    }
-
-    __aicore__ inline void copy_out(uint32_t offset) {
-        LocalTensor output_local = output_queue.DeQue();
-        DataCopy(output_gm[offset], output_local, QK8_0);
-        output_queue.FreeTensor(output_local);
-    }
-
-    __aicore__ inline half calculate_group(int64_t row, int64_t group) {
-        const int64_t i3 = row / (input_ne[1] * input_ne[2]);
-        const int64_t i2 = (row - i3 * input_ne[1] * input_ne[2]) / input_ne[1];
-        const int64_t i1 =
-            row - i3 * input_ne[1] * input_ne[2] - i2 * input_ne[1];
-
-        const int64_t input_offset = i1 * input_stride[1] +
-                                     i2 * input_stride[2] +
-                                     i3 * input_stride[3] + QK8_0 * group;
-
-        const int64_t output_offset = i1 * output_stride[1] +
-                                      i2 * output_stride[2] +
-                                      i3 * output_stride[3] + QK8_0 * group;
-
-        copy_in(input_offset);
-        LocalTensor input_local = input_queue.DeQue();
-        LocalTensor output_local = output_queue.AllocTensor();
-        LocalTensor work_local = work_queue.AllocTensor();
-        LocalTensor abs_local = abs_queue.AllocTensor();
-        LocalTensor max_local = max_queue.AllocTensor();
-        LocalTensor cast_local = cast_queue.AllocTensor();
-
-        Cast(cast_local, input_local, RoundMode::CAST_NONE, QK8_0);
-        Abs(abs_local, cast_local, QK8_0);
-        ReduceMax(max_local, abs_local, work_local, QK8_0);
-
-        pipe_barrier(PIPE_ALL);
-        float d = max_local.GetValue(0);
-        d = d / ((1 << 7) - 1);
-        if (d != 0) {
-            Muls(cast_local, cast_local, 1.0f / d, QK8_0);
-        }
-
-        Cast(cast_local, cast_local, RoundMode::CAST_ROUND, QK8_0);
-        Cast(input_local, cast_local, RoundMode::CAST_ROUND, QK8_0);
-        Cast(output_local, input_local, RoundMode::CAST_ROUND, QK8_0);
-        output_queue.EnQue(output_local);
-        copy_out(output_offset);
-
-        input_queue.FreeTensor(input_local);
-        work_queue.FreeTensor(work_local);
-        abs_queue.FreeTensor(abs_local);
-        max_queue.FreeTensor(max_local);
-        cast_queue.FreeTensor(cast_local);
-        return (half)d;
-    }
-
-    __aicore__ inline void calculate() {
-        LocalTensor scale_local = scale_queue.AllocTensor();
-        uint32_t scale_local_offset = 0;
-        uint32_t scale_global_offset = 0;
-        for (int64_t i = ir; i < ir + dr; i++) {
-            for (int64_t j = 0; j < group_size_in_row; j++) {
-                half scale = calculate_group(i, j);
-                scale_local.SetValue(scale_local_offset++, scale);
-                if (scale_local_offset == 16) {
-                    scale_local_offset = 0;
-                    // TODO: OPTIMIZE ME
-                    pipe_barrier(PIPE_ALL);
-                    DataCopy(scale_gm[scale_global_offset], scale_local, 16);
-                    pipe_barrier(PIPE_ALL);
-                    scale_global_offset += 16;
-                }
-            }
-        }
-
-        if (scale_local_offset != 0) {
-            pipe_barrier(PIPE_ALL);
-            DataCopyExtParams dataCopyParams;
-            dataCopyParams.blockCount = 1;
-            dataCopyParams.blockLen = scale_local_offset * sizeof(half);
-            DataCopyPad(scale_gm[scale_global_offset], scale_local,
-                        dataCopyParams);
-            pipe_barrier(PIPE_ALL);
-        }
-    }
-
-   private:
-    int64_t input_ne[4];
-    size_t input_stride[4];
-
-    int64_t *scale_ne;
-    size_t scale_stride[4];
-
-    int64_t output_ne[4];
-    size_t output_stride[4];
-
-    int64_t group_size_in_row;
-
-    int64_t ir;
-    int64_t dr;
-
-    TPipe pipe;
-    GlobalTensor input_gm;
-    GlobalTensor scale_gm;
-    GlobalTensor output_gm;
-    TQue input_queue;
-    TQue output_queue;
-    TQue work_queue;
-    TQue max_queue;
-    TQue abs_queue;
-    TQue scale_queue;
-    TQue cast_queue;
-
-};
-
-template <typename T>
-__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
-    auto gm_ptr = (__gm__ uint8_t *)gm;
-    auto ub_ptr = (uint8_t *)(ub);
-    for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
-        *ub_ptr = *gm_ptr;
-    }
-}
-
-extern "C" __global__ __aicore__ void ascendc_quantize_f16_q8_0(
-    GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
-    GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
-    int64_t input_ne_ub[4];
-    size_t input_nb_ub[4];
-    int64_t output_ne_ub[4];
-
-    copy_to_ub(input_ne_gm, input_ne_ub, 32);
-    copy_to_ub(input_nb_gm, input_nb_ub, 32);
-    copy_to_ub(output_ne_gm, output_ne_ub, 32);
-
-    QUANTIZE_F16_Q8_0 op;
-    op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
-    op.calculate();
-}
-
-#endif // #ifdef ASCEND_310P
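
The quantization kernels invert that decode. For q8_0 the group scale is the largest absolute value divided by 127, matching the `d = max / ((1 << 7) - 1)` seen in `calculate_group()` above, and every element is divided by the scale and rounded. A hedged CPU reference for one group:

```c
#include <stdint.h>
#include <math.h>

#define QK8_0 32

/* Quantize one group of QK8_0 floats to q8_0: the scale is amax / 127 and each
 * element becomes round(x / d). Hypothetical CPU reference, rounding to nearest. */
static float quantize_q8_0_group(const float *x, int8_t *qs) {
    float amax = 0.0f;
    for (int j = 0; j < QK8_0; ++j) {
        const float ax = fabsf(x[j]);
        if (ax > amax) {
            amax = ax;
        }
    }
    const float d  = amax / 127.0f;
    const float id = d != 0.0f ? 1.0f / d : 0.0f; // an all-zero group quantizes to zeros
    for (int j = 0; j < QK8_0; ++j) {
        qs[j] = (int8_t) nearbyintf(x[j] * id);
    }
    return d; // the kernel stores this as an F16 per-group scale
}
```
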
diff --git a/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp b/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp
deleted file mode 100644
index 05b0bc1df59..00000000000
--- a/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp
+++ /dev/null
@@ -1,216 +0,0 @@
-#include "kernel_operator.h"
-
-using namespace AscendC;
-#ifdef ASCEND_310P // 310P does not support f32->8bit quantization
-    extern "C" __global__ __aicore__ void ascendc_quantize_f32_q8_0(
-        GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
-        GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
-        // Let the following test cases continue to run by only printing an error here. Of course, the test case that calls this operator will fail.
-        printf("Ascend310P not support f32->8bit quantization.\n");
-    }
-#else
-
-#define BUFFER_NUM 2
-#define QK8_0 32
-
-class QUANTIZE_F32_Q8_0 {
-   public:
-    __aicore__ inline QUANTIZE_F32_Q8_0() {}
-    __aicore__ inline void init(GM_ADDR input, GM_ADDR output,
-                                int64_t *input_ne_ub, size_t *input_nb_ub,
-                                int64_t *output_ne_ub) {
-        int64_t op_block_num = GetBlockNum();
-        int64_t op_block_idx = GetBlockIdx();
-
-        for (int i = 0; i < 4; i++) {
-            input_ne[i] = input_ne_ub[i];
-            input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
-
-            output_ne[i] = output_ne_ub[i];
-        }
-
-        output_stride[0] = 1;
-        for (int i = 1; i < 4; i++) {
-            output_stride[i] = output_stride[i - 1] * output_ne[i - 1];
-        }
-
-        scale_ne = input_ne;
-        scale_stride[0] = 1;
-        scale_stride[1] = input_ne[0] / QK8_0;
-        for (int i = 2; i < 4; i++) {
-            scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
-        }
-
-        // split input tensor by rows.
-        uint64_t nr = input_ne[1] * input_ne[2] * input_ne[3];
-        dr = nr / op_block_num;
-
-        uint64_t tails = nr % op_block_num;
-        if (op_block_idx < tails) {
-            dr += 1;
-            ir = dr * op_block_idx;
-        } else {
-            ir = dr * op_block_idx + tails;
-        }
-
-        group_size_in_row = scale_stride[1];
-        int64_t output_size = output_ne[0] * output_ne[1] * output_ne[2] *
-                              output_ne[3] * sizeof(uint8_t);
-
-        input_gm.SetGlobalBuffer((__gm__ float *)input);
-        output_gm.SetGlobalBuffer((__gm__ int8_t *)output);
-        scale_gm.SetGlobalBuffer((__gm__ half *)(output + output_size +
-                                                 ir * group_size_in_row *
-                                                 sizeof(half)));
-
-        pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(float));
-        pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t));
-        pipe.InitBuffer(work_queue, 1, 32);
-        pipe.InitBuffer(max_queue, 1, 32);
-        pipe.InitBuffer(abs_queue, 1, QK8_0 * sizeof(float));
-        pipe.InitBuffer(cast_queue, 1, QK8_0 * sizeof(half));
-        pipe.InitBuffer(scale_queue, 1, 32);
-    }
-
-    __aicore__ inline void copy_in(uint32_t offset) {
-        LocalTensor input_local = input_queue.AllocTensor();
-        DataCopy(input_local, input_gm[offset], QK8_0);
-        input_queue.EnQue(input_local);
-    }
-
-    __aicore__ inline void copy_out(uint32_t offset) {
-        LocalTensor output_local = output_queue.DeQue();
-        DataCopy(output_gm[offset], output_local, QK8_0);
-        output_queue.FreeTensor(output_local);
-    }
-
-    __aicore__ inline half calculate_group(int64_t row, int64_t group) {
-        const int64_t i3 = row / (input_ne[1] * input_ne[2]);
-        const int64_t i2 = (row - i3 * input_ne[1] * input_ne[2]) / input_ne[1];
-        const int64_t i1 =
-            row - i3 * input_ne[1] * input_ne[2] - i2 * input_ne[1];
-
-        const int64_t input_offset = i1 * input_stride[1] +
-                                     i2 * input_stride[2] +
-                                     i3 * input_stride[3] + QK8_0 * group;
-
-        const int64_t output_offset = i1 * output_stride[1] +
-                                      i2 * output_stride[2] +
-                                      i3 * output_stride[3] + QK8_0 * group;
-
-        copy_in(input_offset);
-        LocalTensor input_local = input_queue.DeQue();
-        LocalTensor output_local = output_queue.AllocTensor();
-        LocalTensor work_local = work_queue.AllocTensor();
-        LocalTensor abs_local = abs_queue.AllocTensor();
-        LocalTensor max_local = max_queue.AllocTensor();
-        LocalTensor cast_local = cast_queue.AllocTensor();
-
-        Abs(abs_local, input_local, QK8_0);
-        ReduceMax(max_local, abs_local, work_local, QK8_0);
-        pipe_barrier(PIPE_ALL);
-        float d = max_local.GetValue(0);
-        d = d / ((1 << 7) - 1);
-        if (d != 0) {
-            Muls(input_local, input_local, 1.0f / d, QK8_0);
-        }
-
-        Cast(input_local, input_local, RoundMode::CAST_ROUND, QK8_0);
-        Cast(cast_local, input_local, RoundMode::CAST_ROUND, QK8_0);
-        Cast(output_local, cast_local, RoundMode::CAST_ROUND, QK8_0);
-        output_queue.EnQue(output_local);
-        copy_out(output_offset);
-
-        input_queue.FreeTensor(input_local);
-        work_queue.FreeTensor(work_local);
-        abs_queue.FreeTensor(abs_local);
-        max_queue.FreeTensor(max_local);
-        cast_queue.FreeTensor(cast_local);
-
-        return (half)d;
-    }
-
-    __aicore__ inline void calculate() {
-        LocalTensor scale_local = scale_queue.AllocTensor();
-        uint32_t scale_local_offset = 0;
-        uint32_t scale_global_offset = 0;
-        for (int64_t i = ir; i < ir + dr; i++) {
-            for (int64_t j = 0; j < group_size_in_row; j++) {
-                half scale = calculate_group(i, j);
-                scale_local.SetValue(scale_local_offset++, scale);
-                if (scale_local_offset == 16) {
-                    scale_local_offset = 0;
-                    // TODO: OPTIMIZE ME
-                    pipe_barrier(PIPE_ALL);
-                    DataCopy(scale_gm[scale_global_offset], scale_local, 16);
-                    pipe_barrier(PIPE_ALL);
-                    scale_global_offset += 16;
-                }
-            }
-        }
-
-        if (scale_local_offset != 0) {
-            pipe_barrier(PIPE_ALL);
-            DataCopyExtParams dataCopyParams;
-            dataCopyParams.blockCount = 1;
-            dataCopyParams.blockLen = scale_local_offset * sizeof(half);
-            DataCopyPad(scale_gm[scale_global_offset], scale_local,
-                        dataCopyParams);
-            pipe_barrier(PIPE_ALL);
-        }
-    }
-
-   private:
-    int64_t input_ne[4];
-    size_t input_stride[4];
-
-    int64_t *scale_ne;
-    size_t scale_stride[4];
-
-    int64_t output_ne[4];
-    size_t output_stride[4];
-
-    int64_t group_size_in_row;
-
-    int64_t ir;
-    int64_t dr;
-
-    TPipe pipe;
-    GlobalTensor input_gm;
-    GlobalTensor scale_gm;
-    GlobalTensor output_gm;
-    TQue input_queue;
-    TQue output_queue;
-    TQue work_queue;
-    TQue max_queue;
-    TQue abs_queue;
-    TQue cast_queue;
-    TQue scale_queue;
-};
-
-template <typename T>
-__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
-    auto gm_ptr = (__gm__ uint8_t *)gm;
-    auto ub_ptr = (uint8_t *)(ub);
-    for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
-        *ub_ptr = *gm_ptr;
-    }
-}
-
-extern "C" __global__ __aicore__ void ascendc_quantize_f32_q8_0(
-    GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
-    GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
-    int64_t input_ne_ub[4];
-    size_t input_nb_ub[4];
-    int64_t output_ne_ub[4];
-
-    copy_to_ub(input_ne_gm, input_ne_ub, 32);
-    copy_to_ub(input_nb_gm, input_nb_ub, 32);
-    copy_to_ub(output_ne_gm, output_ne_ub, 32);
-
-    QUANTIZE_F32_Q8_0 op;
-    op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
-    op.calculate();
-}
-
-#endif // #ifdef ASCEND_310P
diff --git a/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp b/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp
deleted file mode 100644
index 1188937b744..00000000000
--- a/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp
+++ /dev/null
@@ -1,295 +0,0 @@
-#include "kernel_operator.h"
-
-using namespace AscendC;
-#ifdef ASCEND_310P // 310P does not support float->4bit quantization
-    extern "C" __global__ __aicore__ void ascendc_quantize_f32_to_q4_0(
-        GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
-        GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
-        // Let the following test cases continue to run by only printing an error here. Of course, the test case that calls this operator will fail.
-        printf("Ascend310P not support f32->4bit quantization.\n");
-    }
-
-    extern "C" __global__ __aicore__ void ascendc_quantize_f16_to_q4_0(
-        GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
-        GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
-        // Let the following test cases continue to run by only printing an error here. Of course, the test case that calls this operator will fail.
-        printf("Ascend310P not support f16->4bit quantization.\n");
-    }
-#else
-
-#define BUFFER_NUM 2
-#define Group_Size 32
-
-template <typename SRC_T>
-class QUANTIZE_FLOAT_TO_Q4_0 {
-   public:
-    __aicore__ inline QUANTIZE_FLOAT_TO_Q4_0() {}
-    __aicore__ inline void init(GM_ADDR input, GM_ADDR output,
-                                int64_t *input_ne_ub, size_t *input_nb_ub,
-                                int64_t *output_ne_ub) {
-        // TODO: fix test_case CPY(type_src=f16,type_dst=q4_0,ne=[256,4,4,4],
-        //                         permute=[0,0,0,0]):
-        // [CPY] NMSE = 0.000008343 > 0.000001000 FAIL
-        int64_t op_block_num = GetBlockNum();
-        int64_t op_block_idx = GetBlockIdx();
-
-        // input stride of data elements
-        for (int i = 0; i < 4; i++) {
-            input_ne[i] = input_ne_ub[i];
-            input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
-            output_ne[i] = output_ne_ub[i];
-        }
-
-        // output stride of data elements
-        output_stride[0] = 1;
-        for (int i = 1; i < 4; i++) {
-            output_stride[i] = output_stride[i - 1] * output_ne[i - 1];
-        }
-
-        // scales are saved one by one after the data: [group1_scale, group2_scale, ...]
-        scale_ne = input_ne;
-        scale_stride[0] = 1;
-        scale_stride[1] = input_ne[0] / Group_Size;
-        for (int i = 2; i < 4; i++) {
-            scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
-        }
-
-        // split input tensor by rows.
-        uint64_t nr = input_ne[1] * input_ne[2] * input_ne[3];
-        dr = nr / op_block_num;
-
-        uint64_t tails = nr % op_block_num;
-        if (op_block_idx < tails) {
-            dr += 1;
-            ir = dr * op_block_idx;
-        } else {
-            ir = dr * op_block_idx + tails;
-        }
-
-        group_size_in_row = scale_stride[1];
-        int64_t scale_offset = output_ne[0] * output_ne[1] * output_ne[2] *
-                              output_ne[3] * sizeof(uint8_t) / 2;
-
-        input_gm.SetGlobalBuffer((__gm__ SRC_T *)input);
-        output_gm.SetGlobalBuffer((__gm__ int8_t *)output);
-        scale_gm.SetGlobalBuffer((__gm__ half *)(output + scale_offset + ir *
-                                                 group_size_in_row *
-                                                 sizeof(half)));
-
-        pipe.InitBuffer(input_queue, BUFFER_NUM, Group_Size * sizeof(SRC_T));
-        pipe.InitBuffer(output_queue, BUFFER_NUM,
-                            Group_Size * sizeof(int8_t) / 2);
-        pipe.InitBuffer(cast_queue , 1, Group_Size * sizeof(float));
-        pipe.InitBuffer(work_queue, 1, Group_Size * sizeof(float));
-        pipe.InitBuffer(max_queue, 1, Group_Size * sizeof(float));
-        pipe.InitBuffer(min_queue, 1, Group_Size * sizeof(float));
-        pipe.InitBuffer(scale_queue, 1, Group_Size / 2 * sizeof(half));
-        pipe.InitBuffer(int8_queue, 1, Group_Size * sizeof(int8_t));
-        pipe.InitBuffer(half_queue, 1, Group_Size * sizeof(half));
-    }
-
-    __aicore__ inline void copy_in(uint32_t offset) {
-        LocalTensor input_local = input_queue.AllocTensor();
-        DataCopy(input_local, input_gm[offset], Group_Size);
-        input_queue.EnQue(input_local);
-    }
-
-    __aicore__ inline void copy_out(uint32_t offset) {
-        // Reinterpret-cast Group_Size (32) int4b_t values as Group_Size / 2 int8_t values,
-        // and use DataCopyPad to avoid the 32-byte alignment requirement.
-        LocalTensor output_local = output_queue.DeQue();
-        LocalTensor<int8_t> output_int8_local =
-                                    output_local.ReinterpretCast<int8_t>();
-
-        DataCopyExtParams dataCopyParams;
-        dataCopyParams.blockCount = 1;
-        dataCopyParams.blockLen = Group_Size / 2  * sizeof(int8_t);
-        DataCopyPad(output_gm[offset], output_int8_local, dataCopyParams);
-
-        output_queue.FreeTensor(output_local);
-    }
-
-    __aicore__ inline void input_to_cast(LocalTensor<float> cast_local,
-                                         LocalTensor<float> input_local) {
-        DataCopy(cast_local, input_local, Group_Size);
-    }
-
-    __aicore__ inline void input_to_cast(LocalTensor<float> cast_local,
-                                         LocalTensor<half> input_local) {
-        Cast(cast_local, input_local, RoundMode::CAST_NONE, Group_Size);
-    }
-
-    __aicore__ inline half calculate_group(int64_t row, int64_t group) {
-        const int64_t i3 = row / (input_ne[1] * input_ne[2]);
-        const int64_t i2 = (row - i3 * input_ne[1] * input_ne[2]) / input_ne[1];
-        const int64_t i1 =
-            row - i3 * input_ne[1] * input_ne[2] - i2 * input_ne[1];
-
-        const int64_t input_offset = i1 * input_stride[1] +
-                                     i2 * input_stride[2] +
-                                     i3 * input_stride[3] + Group_Size * group;
-
-        // output_offset indexes output_gm, which is typed as int8_t; dividing
-        // by 2 converts the int4b_t element offset into a byte offset.
-        const int64_t output_offset = (i1 * output_stride[1] +
-                                       i2 * output_stride[2] +
-                                       i3 * output_stride[3] +
-                                       Group_Size * group) / 2;
-        copy_in(input_offset);
-
-        LocalTensor input_local = input_queue.DeQue();
-        LocalTensor output_local = output_queue.AllocTensor();
-        LocalTensor cast_local = cast_queue.AllocTensor();
-        LocalTensor work_local = work_queue.AllocTensor();
-        LocalTensor max_local = max_queue.AllocTensor();
-        LocalTensor min_local = min_queue.AllocTensor();
-        LocalTensor int8_local = int8_queue.AllocTensor();
-        LocalTensor half_local = half_queue.AllocTensor();
-
-        input_to_cast(cast_local, input_local);
-
-        ReduceMax(max_local, cast_local, work_local, Group_Size);
-        ReduceMin(min_local, cast_local, work_local, Group_Size);
-        const float max_value = max_local.GetValue(0);
-        const float min_value = min_local.GetValue(0);
-        float d = max_value;
-        if (min_value < 0 && (-1 * min_value) > max_value) {
-            d = min_value;
-        }
-
-        d = d / (-8);
-        if (d != 0) {
-            Muls(cast_local, cast_local, 1.0f / d, Group_Size);
-        }
-
-        // range: [-8,8] -> [0.5,16.5] -> [0,16] -> [0,15] -> [-8,7]
-        float scalar = 8.5f;
-        Adds(cast_local, cast_local, scalar, Group_Size);
-        Cast(cast_local, cast_local, RoundMode::CAST_FLOOR, Group_Size);
-        scalar = 15.0f;
-        Mins(cast_local, cast_local, scalar, Group_Size);
-        scalar = -8.0f;
-        Adds(cast_local, cast_local, scalar, Group_Size);
-
-        // float->half->int4b
-        Cast(half_local, cast_local, RoundMode::CAST_NONE, Group_Size);
-        Cast(output_local, half_local, RoundMode::CAST_NONE, Group_Size);
-
-        output_queue.EnQue(output_local);
-        copy_out(output_offset);
-
-        input_queue.FreeTensor(input_local);
-        work_queue.FreeTensor(work_local);
-        max_queue.FreeTensor(max_local);
-        min_queue.FreeTensor(min_local);
-        int8_queue.FreeTensor(int8_local);
-        half_queue.FreeTensor(half_local);
-        cast_queue.FreeTensor(cast_local);
-        return (half)d;
-    }
-
-    __aicore__ inline void calculate() {
-        LocalTensor scale_local = scale_queue.AllocTensor();
-        uint32_t scale_local_offset = 0;
-        uint32_t scale_global_offset = 0;
-        for (int64_t i = ir; i < ir + dr; i++) {
-            for (int64_t j = 0; j < group_size_in_row; j++) {
-                half scale = calculate_group(i, j);
-                scale_local.SetValue(scale_local_offset++, scale);
-                // Copy Group_Size/2 scale values at a time.
-                if (scale_local_offset == Group_Size / 2) {
-                    scale_local_offset = 0;
-                    // TODO: OPTIMIZE ME
-                    pipe_barrier(PIPE_ALL);
-                    DataCopy(scale_gm[scale_global_offset], scale_local,
-                                      Group_Size / 2);
-                    pipe_barrier(PIPE_ALL);
-                    scale_global_offset += Group_Size / 2;
-                }
-            }
-        }
-
-        if (scale_local_offset != 0) {
-            pipe_barrier(PIPE_ALL);
-            DataCopyExtParams dataCopyParams;
-            dataCopyParams.blockCount = 1;
-            dataCopyParams.blockLen = scale_local_offset * sizeof(half);
-            DataCopyPad(scale_gm[scale_global_offset], scale_local,
-                        dataCopyParams);
-            pipe_barrier(PIPE_ALL);
-        }
-        scale_queue.FreeTensor(scale_local);
-    }
-
-   private:
-    int64_t input_ne[4];
-    size_t input_stride[4];
-
-    int64_t *scale_ne;
-    size_t scale_stride[4];
-
-    int64_t output_ne[4];
-    size_t output_stride[4];
-
-    int64_t group_size_in_row;
-
-    int64_t ir;
-    int64_t dr;
-
-    TPipe pipe;
-    GlobalTensor input_gm;
-    GlobalTensor scale_gm;
-    GlobalTensor output_gm;
-    TQue input_queue;
-    TQue output_queue;
-    TQue work_queue;
-    TQue max_queue;
-    TQue min_queue;
-    TQue scale_queue;
-    TQue cast_queue;
-    TQue int8_queue;
-    TQue half_queue;
-};
-
-template <typename T>
-__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
-    auto gm_ptr = (__gm__ uint8_t *)gm;
-    auto ub_ptr = (uint8_t *)(ub);
-    for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
-        *ub_ptr = *gm_ptr;
-    }
-}
-
-extern "C" __global__ __aicore__ void ascendc_quantize_f16_to_q4_0(
-    GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
-    GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
-    int64_t input_ne_ub[4];
-    size_t input_nb_ub[4];
-    int64_t output_ne_ub[4];
-
-    copy_to_ub(input_ne_gm, input_ne_ub, 32);
-    copy_to_ub(input_nb_gm, input_nb_ub, 32);
-    copy_to_ub(output_ne_gm, output_ne_ub, 32);
-
-    QUANTIZE_FLOAT_TO_Q4_0<half> op;
-    op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
-    op.calculate();
-}
-
-extern "C" __global__ __aicore__ void ascendc_quantize_f32_to_q4_0(
-    GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
-    GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
-    int64_t input_ne_ub[4];
-    size_t input_nb_ub[4];
-    int64_t output_ne_ub[4];
-
-    copy_to_ub(input_ne_gm, input_ne_ub, 32);
-    copy_to_ub(input_nb_gm, input_nb_ub, 32);
-    copy_to_ub(output_ne_gm, output_ne_ub, 32);
-
-    QUANTIZE_FLOAT_TO_Q4_0<float> op;
-    op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
-    op.calculate();
-}
-
-#endif // #ifdef ASCEND_310P
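
`quantize_float_to_q4_0` instead scales by the signed extreme: whichever of the group maximum and minimum has the larger magnitude defines `d = extreme / -8`, and values pass through the `[-8, 8] -> [0, 15]` mapping spelled out in the kernel's range comment before being packed two per byte. A hedged CPU sketch of one group under the CPU backend's packing convention (the device packing may differ):

```c
#include <stdint.h>
#include <math.h>

#define QK4_0 32

/* Quantize one group of QK4_0 floats to q4_0. The signed extreme defines
 * d = vmax / -8, values are mapped through [-8, 8] -> [0, 15], and two codes
 * are packed per byte (low nibble = first half of the group). */
static float quantize_q4_0_group(const float *x, uint8_t *qs /* QK4_0/2 bytes */) {
    float amax = 0.0f; // largest magnitude in the group
    float vmax = 0.0f; // the signed value that attains it
    for (int j = 0; j < QK4_0; ++j) {
        if (fabsf(x[j]) > amax) {
            amax = fabsf(x[j]);
            vmax = x[j];
        }
    }
    const float d  = vmax / -8.0f;
    const float id = d != 0.0f ? 1.0f / d : 0.0f;
    for (int j = 0; j < QK4_0 / 2; ++j) {
        const float   v0 = x[j]             * id;
        const float   v1 = x[j + QK4_0 / 2] * id;
        const uint8_t q0 = (uint8_t) fminf(15.0f, floorf(v0 + 8.5f));
        const uint8_t q1 = (uint8_t) fminf(15.0f, floorf(v1 + 8.5f));
        qs[j] = (uint8_t) (q0 | (q1 << 4));
    }
    return d; // stored as an F16 per-group scale after the packed payload
}
```
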
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index fbb04426abe..93ab7ea446e 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -99,6 +99,9 @@ typedef sycl::half2 ggml_half2;
 #define QI4_1 (QK4_1 / (4 * QR4_1))
 #define QR4_1 2
 
+#define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))
+#define QR_MXFP4 2
+
 #define QI5_0 (QK5_0 / (4 * QR5_0))
 #define QR5_0 2
 
@@ -184,6 +187,13 @@ typedef struct {
 } block_q4_1;
 static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
 
+#define QK_MXFP4 32
+typedef struct {
+    uint8_t e; // E8M0
+    uint8_t qs[QK_MXFP4/2];
+} block_mxfp4;
+static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding");
+
 #define QK5_0 32
 typedef struct {
     ggml_half d;           // delta
@@ -1074,10 +1084,17 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
     0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
 GGML_TABLE_END()
 
+// TODO: fix name to kvalues_iq4_nl
 GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
     -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
 GGML_TABLE_END()
 
+// e2m1 values (doubled)
+// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
+    0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
+GGML_TABLE_END()
+
 #define NGRID_IQ1S 2048
 #define IQ1S_DELTA 0.125f
 #define IQ1M_DELTA 0.125f
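
The `block_mxfp4` type added above stores 32 FP4 (e2m1) values behind one shared E8M0 exponent byte, and `kvalues_mxfp4` holds the sixteen e2m1 code points doubled, so a decoder has to halve the block scale to compensate. A hedged C sketch of decoding one block; the nibble ordering and the omission of the E8M0 NaN encoding (e == 255) are simplifying assumptions:

```c
#include <stdint.h>
#include <math.h>

#define QK_MXFP4 32

// Doubled e2m1 magnitudes, matching the kvalues_mxfp4 table added above.
static const int8_t kvalues_mxfp4_ref[16] = {
    0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
};

/* Decode one mxfp4 block: a shared E8M0 exponent byte plus 16 packed FP4
 * nibbles. Because the table stores doubled e2m1 values, the 2^(e - 127)
 * block scale is halved here. Nibble order (low nibble = first half of the
 * block) is assumed to follow the other 4-bit formats. */
static void dequantize_mxfp4_block(uint8_t e, const uint8_t *qs, float *dst) {
    const float scale = ldexpf(1.0f, (int) e - 127 - 1); // 0.5f * 2^(e - 127)
    for (int j = 0; j < QK_MXFP4 / 2; ++j) {
        dst[j]                = scale * kvalues_mxfp4_ref[qs[j] & 0x0F];
        dst[j + QK_MXFP4 / 2] = scale * kvalues_mxfp4_ref[qs[j] >> 4];
    }
}
```
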
diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt
index 66a5ad8d2ed..ce0a3e1285e 100644
--- a/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ggml/src/ggml-cpu/CMakeLists.txt
@@ -70,10 +70,12 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
     if (GGML_OPENMP)
         find_package(OpenMP)
         if (OpenMP_FOUND)
+            set(GGML_OPENMP_ENABLED "ON" CACHE INTERNAL "")
             target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP)
 
             target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
         else()
+            set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "")
             message(WARNING "OpenMP not found")
         endif()
     endif()
@@ -456,8 +458,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             list(APPEND ARCH_FLAGS -march=z16)
         elseif (${S390X_M} MATCHES "9175|9176")
             # NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version.
+            #       binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15.
             message(STATUS "z17 target")
-            list(APPEND ARCH_FLAGS -march=z17)
+            list(APPEND ARCH_FLAGS -march=arch15)
         else()
             message(STATUS "Unknown target")
             message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
@@ -494,9 +497,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.9.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.11.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5  "2a8e1bb55d201557553545536489a017")
+        set(KLEIDIAI_ARCHIVE_MD5  "3fe9e5ab964c375c53839296eb71eaa2")
 
         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h
index 10e5342516a..f476127995b 100644
--- a/ggml/src/ggml-cpu/arch-fallback.h
+++ b/ggml/src/ggml-cpu/arch-fallback.h
@@ -13,6 +13,7 @@
 #define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
 #define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
 #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -37,17 +38,25 @@
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
 // repack.cpp
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
@@ -64,6 +73,7 @@
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -72,18 +82,23 @@
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #elif defined(__loongarch64)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -92,12 +107,16 @@
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #elif defined(__riscv)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -112,6 +131,7 @@
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -119,11 +139,15 @@
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #elif defined(__s390x__)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -139,6 +163,7 @@
 #define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
 #define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -147,12 +172,16 @@
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #elif defined(__wasm__)
 // quants.c
 #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
@@ -167,6 +196,7 @@
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -175,10 +205,14 @@
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #endif
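
The new `*_generic` entries follow the existing pattern of this header: on targets without a hand-written kernel for a given type, the generic symbol is renamed to the public one at preprocessing time, so callers always link against the un-suffixed function. A toy, self-contained illustration of the trick (the `toy_*` names are made up, not ggml's API):

```c
// Toy illustration of the arch-fallback renaming pattern; names are hypothetical.
#include <stdio.h>

// What arch-fallback.h does for a type with no specialized kernel on this target:
// rename the generic symbol to the public one before the generic body is compiled.
#define toy_vec_dot_generic toy_vec_dot

// This definition is therefore emitted under the public name toy_vec_dot.
static float toy_vec_dot_generic(const float *a, const float *b, int n) {
    float s = 0.0f;
    for (int i = 0; i < n; ++i) {
        s += a[i] * b[i];
    }
    return s;
}

int main(void) {
    const float a[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    const float b[4] = {4.0f, 3.0f, 2.0f, 1.0f};
    printf("%.1f\n", toy_vec_dot(a, b, 4)); // resolves to the generic body: 20.0
    return 0;
}
```
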
diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c
index 3e2d3d03d67..aadbb487ec0 100644
--- a/ggml/src/ggml-cpu/arch/arm/quants.c
+++ b/ggml/src/ggml-cpu/arch/arm/quants.c
@@ -589,6 +589,67 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sumf;
 }
 
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined __ARM_NEON
+    const int8x16_t values = vld1q_s8(kvalues_mxfp4);
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+    uint8x16x2_t q4bits;
+    int8x16x4_t q4b;
+    int8x16x4_t q8b;
+    int32x4_t prod_1;
+    int32x4_t prod_2;
+
+    for (; ib + 1 < nb; ib += 2) {
+        q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
+        q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
+        q8b.val[0]    = vld1q_s8(y[ib + 0].qs);
+        q8b.val[1]    = vld1q_s8(y[ib + 0].qs + 16);
+        q8b.val[2]    = vld1q_s8(y[ib + 1].qs);
+        q8b.val[3]    = vld1q_s8(y[ib + 1].qs + 16);
+
+        q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));
+        q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+        q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));
+        q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+        prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+        prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+        sumf +=
+            GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
+            GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
+    }
+
+#endif
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j +          0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
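
In the NEON path of `ggml_vec_dot_mxfp4_q8_0` added above, the 16-entry `kvalues_mxfp4` table fits in a single vector register, so each group of 16 codes is expanded with one byte-wise table lookup. A standalone sketch of that lookup-and-dot step for one 16-byte chunk, using the raw AArch64 intrinsics rather than ggml's `ggml_vqtbl1q_s8`/`ggml_vdotq_s32` wrappers, and assuming `__ARM_FEATURE_DOTPROD` is available:

```c
// Sketch: map 32 packed 4-bit codes to signed E2M1 values via TBL, then dot them
// against 32 int8 activations. Requires AArch64 NEON with the dot-product extension.
#include <arm_neon.h>
#include <stdint.h>

static const int8_t kvalues_mxfp4[16] = {
    0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
};

// qs: 16 bytes = 32 nibbles (low nibbles are values 0..15, high nibbles 16..31)
// q8: 32 int8 activations laid out in the same order
static int32_t dot_mxfp4_chunk(const uint8_t qs[16], const int8_t q8[32]) {
    const int8x16_t  values = vld1q_s8(kvalues_mxfp4);
    const uint8x16_t m4b    = vdupq_n_u8(0x0F);
    const uint8x16_t bits   = vld1q_u8(qs);

    // one table lookup per 16 codes: index = nibble, result = signed E2M1 value
    const int8x16_t lo = vqtbl1q_s8(values, vandq_u8(bits, m4b));
    const int8x16_t hi = vqtbl1q_s8(values, vshrq_n_u8(bits, 4));

    int32x4_t acc = vdupq_n_s32(0);
    acc = vdotq_s32(acc, lo, vld1q_s8(q8));       // first 16 products
    acc = vdotq_s32(acc, hi, vld1q_s8(q8 + 16));  // last 16 products
    return vaddvq_s32(acc);                       // horizontal sum
}
```

The kernel then scales this integer sum by the two block scales, `GGML_E8M0_TO_FP32_HALF(x.e) * GGML_CPU_FP16_TO_FP32(y.d)`, exactly as in the scalar tail loop.
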
@@ -1236,44 +1297,10 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = sumf;
 
 #else
-    const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
-
-    float sumf = 0.0f;
-
-    for (int i = 0; i < nb; ++i) {
-        int sum = 0;
-
-        for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
-            for (size_t l = 0; l < 5; ++l) {
-                for (size_t m = 0; m < 32; ++m) {
-                    uint8_t q = x[i].qs[j + m] * pow3[l];
-                    uint16_t xi = ((uint16_t) q * 3) >> 8;
-                    sum += (xi - 1) * y[i].qs[j*5 + l*32 + m];
-                }
-            }
-        }
-        for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
-            for (size_t l = 0; l < 5; ++l) {
-                for (size_t m = 0; m < 16; ++m) {
-                    uint8_t q = x[i].qs[j + m] * pow3[l];
-                    uint16_t xi = ((uint16_t) q * 3) >> 8;
-                    sum += (xi - 1) * y[i].qs[j*5 + l*16 + m];
-                }
-            }
-        }
-
-        for (size_t l = 0; l < 4; ++l) {
-            for (size_t j = 0; j < sizeof(x->qh); ++j) {
-                uint8_t q = x[i].qh[j] * pow3[l];
-                uint16_t xi = ((uint16_t) q * 3) >> 8;
-                sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j];
-            }
-        }
-
-        sumf += (float) sum * (GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d);
-    }
-
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1381,25 +1408,10 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = sumf;
 
 #else
-    float sumf = 0.0f;
-
-    for (int i = 0; i < nb; ++i) {
-        int32_t sumi = 0;
-
-        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
-            for (size_t l = 0; l < 4; ++l) {
-                for (size_t k = 0; k < 32; ++k) {
-                    sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1);
-                }
-            }
-        }
-
-        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
-
-        sumf += (float) sumi * d;
-    }
-
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1729,45 +1741,10 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sum;
 
 #else
-
-    float sumf = 0;
-
-    for (int i = 0; i < nb; ++i) {
-
-        const uint8_t * q2 = x[i].qs;
-        const  int8_t * q8 = y[i].qs;
-        const uint8_t * sc = x[i].scales;
-
-        int summs = 0;
-        for (int j = 0; j < 16; ++j) {
-            summs += y[i].bsums[j] * (sc[j] >> 4);
-        }
-
-        const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
-        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
-
-        int isum = 0;
-        int is = 0;
-        int d;
-        for (int k = 0; k < QK_K/128; ++k) {
-            int shift = 0;
-            for (int j = 0; j < 4; ++j) {
-                d = sc[is++] & 0xF;
-                int isuml = 0;
-                for (int l =  0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
-                isum += d * isuml;
-                d = sc[is++] & 0xF;
-                isuml = 0;
-                for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
-                isum += d * isuml;
-                shift += 2;
-                q8 += 32;
-            }
-            q2 += 32;
-        }
-        sumf += dall * isum - dmin * summs;
-    }
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -2057,68 +2034,12 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sum;
 
 #else
-    // scalar version
-    // This function is written like this so the compiler can manage to vectorize most of it
-    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
-    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
-    // The ideal situation would be if we could just write the code once, and the compiler would
-    // automatically produce the best possible set of machine instructions, instead of us having to manually
-    // write vectorized versions for AVX, ARM_NEON, etc.
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    uint32_t auxs[4];
-    const int8_t * scales = (const int8_t*)auxs;
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
-        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        uint8_t m = 1;
-        for (int j = 0; j < QK_K; j += 128) {
-            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            q3 += 32;
-        }
-        a = aux8;
-
-        memcpy(auxs, x[i].scales, 12);
-        uint32_t tmp = auxs[2];
-        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
-        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
-        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
-        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
-        for (int j = 0; j < QK_K/16; ++j) {
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
-
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 
 }
@@ -2431,61 +2352,14 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sumf;
 
 #else
-
-    const uint8_t * scales = (const uint8_t*)&utmp[0];
-    const uint8_t * mins   = (const uint8_t*)&utmp[2];
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        for (int j = 0; j < QK_K/64; ++j) {
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
-            a += 32;
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
-            a += 32; q4 += 32;
-        }
-        memcpy(utmp, x[i].scales, 12);
-        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
-        const uint32_t uaux = utmp[1] & kmask1;
-        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
-        utmp[2] = uaux;
-        utmp[0] &= kmask1;
-
-        int sumi = 0;
-        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            int32_t scale = scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
-        sumf -= dmin * sumi;
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    UNUSED(utmp);
+    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -2578,66 +2452,14 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sumf;
 
 #else
-
-    const uint8_t * scales = (const uint8_t*)&utmp[0];
-    const uint8_t * mins   = (const uint8_t*)&utmp[2];
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
-        const uint8_t * GGML_RESTRICT hm = x[i].qh;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        uint8_t m = 1;
-        for (int j = 0; j < QK_K/64; ++j) {
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
-            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
-            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
-            a += 32; m <<= 1;
-            q4 += 32;
-        }
-        memcpy(utmp, x[i].scales, 12);
-        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
-        const uint32_t uaux = utmp[1] & kmask1;
-        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
-        utmp[2] = uaux;
-        utmp[0] &= kmask1;
-
-        int sumi = 0;
-        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            int32_t scale = scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
-        sumf -= dmin * sumi;
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    UNUSED(utmp);
+    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -3093,47 +2915,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     }
     *s = sum;
 #else
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
-        const uint8_t * GGML_RESTRICT qh = x[i].qh;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        for (int j = 0; j < QK_K; j += 128) {
-            for (int l = 0; l < 32; ++l) {
-                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
-                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
-                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
-                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
-            }
-            a  += 128;
-            q4 += 64;
-            qh += 32;
-        }
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/16; ++j) {
-            int scale = x[i].scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -3229,34 +3014,10 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
     *s = 0.25f * sumf;
 
 #else
-
-    uint32_t aux32[2];
-    const uint8_t * aux8 = (const uint8_t *)aux32;
-
-    float sumf = 0.f;
-    for (int i = 0; i < nb; ++i) {
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
-        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
-        int32_t bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
-            memcpy(aux32, q2, 2*sizeof(uint32_t));
-            q2 += 4;
-            const uint32_t ls = 2*(aux32[1] >> 28) + 1;
-            int32_t sumi = 0;
-            for (int l = 0; l < 4; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
-                const uint8_t  signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
-                for (int j = 0; j < 8; ++j) {
-                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            bsum += sumi * ls;
-        }
-        sumf += d * bsum;
-    }
-    *s = 0.125f * sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -3327,42 +3088,10 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
     *s = 0.125f * sumf;
 
 #else
-
-    float sumf = 0.f;
-    for (int i = 0; i < nb; ++i) {
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
-        const uint8_t  * GGML_RESTRICT sc = x[i].scales;
-        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
-        int32_t bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
-            const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;
-            const uint16_t ls2 = 2*(sc[ib32] >>  4) + 1;
-            int32_t sumi = 0;
-            for (int l = 0; l < 2; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
-                const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];
-                for (int j = 0; j < 8; ++j) {
-                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            bsum += sumi * ls1;
-            sumi = 0;
-            for (int l = 2; l < 4; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
-                const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];
-                for (int j = 0; j < 8; ++j) {
-                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            bsum += sumi * ls2;
-            q2 += 4;
-        }
-        sumf += d * bsum;
-    }
-    *s = 0.125f * sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -3455,45 +3184,10 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = 0.125f * sumf;
 
 #else
-
-    float sumf = 0;
-    for (int i = 0; i < nb; i++) {
-
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const int8_t  * q8 = y[i].qs;
-        const uint8_t * qs = x[i].qs;
-        const uint8_t * qh = x[i].qh;
-        const uint8_t * signs = qs + QK_K/8;
-
-        int bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
-            int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf);
-            int ls2 = 1 + 2*(x[i].scales[ib32] >>  4);
-            int sumi1 = 0, sumi2 = 0;
-            for (int l = 0; l < 2; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
-                for (int j = 0; j < 8; ++j) {
-                    sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            for (int l = 2; l < 4; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
-                for (int j = 0; j < 8; ++j) {
-                    sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            bsum += ls1 * sumi1 + ls2 * sumi2;
-            qs += 4;
-            signs += 4;
-        }
-
-        sumf += d * bsum;
-    }
-
-    *s = 0.125f * sumf;
-
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 
 }
@@ -3553,36 +3247,10 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
     *s = 0.5f * sumf;
 
 #else
-
-    uint32_t aux32;
-
-    float sumf = 0.f;
-    for (int i = 0; i < nb; ++i) {
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
-        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
-        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
-        int32_t bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
-            memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
-            const uint32_t ls = 2*(aux32 >> 28) + 1;
-            int32_t sumi = 0;
-            for (int l = 0; l < 4; ++l) {
-                const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
-                const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
-                const uint8_t  signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
-                for (int j = 0; j < 4; ++j) {
-                    sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
-                    sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            q3 += 8;
-            bsum += sumi * ls;
-        }
-        sumf += d * bsum;
-    }
-    *s = 0.25f * sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -3689,48 +3357,10 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = sumf;
 
 #else
-
-    float sumf = 0.f;
-    for (int i = 0; i < nb; ++i) {
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * GGML_RESTRICT qs = x[i].qs;
-        const uint8_t * GGML_RESTRICT qh = x[i].qh;
-        const uint8_t * GGML_RESTRICT signs = x[i].signs;
-        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
-        int32_t bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
-            const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;
-            const uint32_t ls2 = 2*(x[i].scales[ib32/2] >>  4) + 1;
-            int32_t sumi = 0;
-            for (int l = 0; l < 4; ++l) {
-                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
-                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
-                for (int j = 0; j < 4; ++j) {
-                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
-                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            qs += 8;
-            signs += 4;
-            bsum += sumi * ls1;
-            sumi = 0;
-            for (int l = 0; l < 4; ++l) {
-                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
-                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
-                for (int j = 0; j < 4; ++j) {
-                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
-                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            qs += 8;
-            signs += 4;
-            bsum += sumi * ls2;
-        }
-        sumf += d * bsum;
-    }
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -3793,36 +3423,10 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = sumf;
 
 #else
-
-    float sumf = 0;
-    for (int i = 0; i < nb; i++) {
-
-        const int8_t   * q8 = y[i].qs;
-        const uint8_t  * qs = x[i].qs;
-        const uint16_t * qh = x[i].qh;
-
-        int sumi = 0, sumi1 = 0;
-        for (int ib = 0; ib < QK_K/32; ++ib) {
-            const int ls = 2*((qh[ib] >> 12) & 7) + 1;
-            const int delta = qh[ib] & 0x8000 ? -1 : 1;
-            int lsum = 0;
-            for (int l = 0; l < 4; ++l) {
-                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
-                for (int j = 0; j < 8; ++j) {
-                    lsum += q8[j] * grid[j];
-                }
-                q8 += 8;
-            }
-            sumi  += ls * lsum;
-            sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]);
-            qs += 4;
-        }
-
-        sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);
-    }
-
-    *s = sumf;
-
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -3912,52 +3516,11 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = sumf;
 
 #else
-
-    int sum1[2], sum2[2], delta[4];
-
-    float sumf = 0;
-    for (int i = 0; i < nb; i++) {
-
-        const int8_t   * q8 = y[i].qs;
-        const uint8_t  * qs = x[i].qs;
-        const uint8_t  * qh = x[i].qh;
-        const uint16_t * sc = (const uint16_t *)x[i].scales;
-
-        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
-
-        int sumi1 = 0, sumi2 = 0;
-        for (int ib = 0; ib < QK_K/32; ++ib) {
-            delta[0] = qh[0] & 0x08 ? -1 : 1;
-            delta[1] = qh[0] & 0x80 ? -1 : 1;
-            delta[2] = qh[1] & 0x08 ? -1 : 1;
-            delta[3] = qh[1] & 0x80 ? -1 : 1;
-            sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0;
-            for (int l = 0; l < 4; ++l) {
-                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700)));
-                int lsum1 = 0, lsum2 = 0;
-                for (int j = 0; j < 8; ++j) {
-                    lsum1 += q8[j] * grid[j];
-                    lsum2 += q8[j];
-                }
-                q8 += 8;
-                sum1[l/2] += lsum1;
-                sum2[l/2] += lsum2*delta[l];
-            }
-
-            const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1;
-            const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1;
-
-            sumi1 += sum1[0] * ls1 + sum1[1] * ls2;
-            sumi2 += sum2[0] * ls1 + sum2[1] * ls2;
-            qs += 4;
-            qh += 2;
-        }
-
-        sumf += GGML_CPU_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
-    }
-
-    *s = sumf;
-
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(scale);
+    ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -4078,37 +3641,10 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
     *s = sumf;
 
 #else
-    float sumf = 0;
-    for (int ibl = 0; ibl < nb; ++ibl) {
-        const float d4d8 = GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
-        uint16_t h = x[ibl].scales_h;
-        const uint8_t * qs = x[ibl].qs;
-        const int8_t  * q8 = y[ibl].qs;
-        for (int ib = 0; ib < QK_K/32; ib += 2) {
-            const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30);
-            const uint8_t ls2 = (x[ibl].scales_l[ib/2] >>  4) | ((h << 2) & 0x30);
-            h >>= 4;
-            const float d1 = d4d8*(ls1 - 32);
-            const float d2 = d4d8*(ls2 - 32);
-            int sumi1 = 0, sumi2 = 0;
-            for (int j = 0; j < 16; ++j) {
-                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
-                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];
-            }
-            sumf += d1 * (sumi1 + sumi2);
-            qs += 16;
-            q8 += 32;
-            sumi1 = sumi2 = 0;
-            for (int j = 0; j < 16; ++j) {
-                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
-                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];
-            }
-            sumf += d2 * (sumi1 + sumi2);
-            qs += 16;
-            q8 += 32;
-        }
-    }
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
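
The hunks above all make the same change: the scalar reference bodies are deleted from this architecture file, and the `#else` branch now forwards to the shared `*_generic` implementation, with `UNUSED(...)` keeping the locals that only the SIMD branch needs warning-free. A schematic sketch of the shape (simplified signature, hypothetical `toy_*` names):

```c
// Schematic of the new fallback shape; names and signature are illustrative only.
#include <stdint.h>

#define UNUSED(x) (void)(x)

// shared scalar reference, now kept in one place (the generic quants.c)
static void toy_vec_dot_generic(int n, float *s, const int8_t *vx, const int8_t *vy) {
    int sum = 0;
    for (int i = 0; i < n; ++i) sum += vx[i] * vy[i];
    *s = (float) sum;
}

// per-architecture wrapper after this patch
static void toy_vec_dot(int n, float *s, const int8_t *vx, const int8_t *vy) {
    const int nb = n / 32;            // set up for the SIMD branch
#if defined(TOY_HAVE_SIMD_KERNEL)     // stand-in for the real feature check
    // ... hand-written vectorized kernel using nb ...
#else
    UNUSED(nb);                       // keeps -Wunused-variable quiet
    toy_vec_dot_generic(n, s, vx, vy);
#endif
}
```
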
 
diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp
index 2f8bc9e2517..fdd0a513b83 100644
--- a/ggml/src/ggml-cpu/arch/arm/repack.cpp
+++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp
@@ -86,35 +86,9 @@ void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTR
         }
     }
 #else
-    // scalar
-    const int blck_size_interleave = 4;
-    float srcv[4][QK8_0];
-    float id[4];
-
-    for (int i = 0; i < nb; i++) {
-        for (int row_iter = 0; row_iter < 4; row_iter++) {
-            float amax = 0.0f; // absolute max
-
-            for (int j = 0; j < QK8_0; j++) {
-                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
-                amax = MAX(amax, fabsf(srcv[row_iter][j]));
-            }
-
-            const float d = amax / ((1 << 7) - 1);
-            id[row_iter] = d ? 1.0f / d : 0.0f;
-
-            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
-        }
-
-        for (int j = 0; j < QK8_0 * 4; j++) {
-            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
-            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
-            src_offset += (j % blck_size_interleave);
-
-            float x0 = srcv[src_id][src_offset] * id[src_id];
-            y[i].qs[j] = roundf(x0);
-        }
-    }
+    UNUSED(nb);
+    UNUSED(y);
+    ggml_quantize_mat_q8_0_4x4_generic(x, vy, k);
 #endif
 }
 
@@ -205,35 +179,9 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
     }
 
 #else
-    // scalar
-    const int blck_size_interleave = 8;
-    float srcv[4][QK8_0];
-    float id[4];
-
-    for (int i = 0; i < nb; i++) {
-        for (int row_iter = 0; row_iter < 4; row_iter++) {
-            float amax = 0.0f; // absolute max
-
-            for (int j = 0; j < QK8_0; j++) {
-                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
-                amax = MAX(amax, fabsf(srcv[row_iter][j]));
-            }
-
-            const float d = amax / ((1 << 7) - 1);
-            id[row_iter] = d ? 1.0f / d : 0.0f;
-
-            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
-        }
-
-        for (int j = 0; j < QK8_0 * 4; j++) {
-            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
-            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
-            src_offset += (j % blck_size_interleave);
-
-            float x0 = srcv[src_id][src_offset] * id[src_id];
-            y[i].qs[j] = roundf(x0);
-        }
-    }
+    UNUSED(nb);
+    UNUSED(y);
+    ggml_quantize_mat_q8_0_4x8_generic(x, vy, k);
 #endif
 }
 
@@ -295,29 +243,7 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
     }
     return;
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
-    float sumf[4];
-    int sumi;
-
-    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-    for (int x = 0; x < nc / ncols_interleaved; x++) {
-        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
-
-        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
-        for (int l = 0; l < nb; l++) {
-            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                for (int j = 0; j < ncols_interleaved; j++) {
-                    sumi = 0;
-                    for (int i = 0; i < blocklen; ++i) {
-                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
-                    }
-                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
-                }
-            }
-        }
-        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
-    }
+    ggml_gemv_q4_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
 void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -383,29 +309,7 @@ void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
     }
     return;
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
-    float sumf[4];
-    int sumi;
-
-    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-    for (int x = 0; x < nc / ncols_interleaved; x++) {
-        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
-
-        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
-        for (int l = 0; l < nb; l++) {
-            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                for (int j = 0; j < ncols_interleaved; j++) {
-                    sumi = 0;
-                    for (int i = 0; i < blocklen; ++i) {
-                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
-                    }
-                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
-                }
-            }
-        }
-        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
-    }
+    ggml_gemv_q4_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
 void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -497,31 +401,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 #endif // #if defined(__ARM_FEATURE_SVE)
 
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
-    {
-        float sumf[8];
-        int sumi;
-
-        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
-
-            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    for (int j = 0; j < ncols_interleaved; j++) {
-                        sumi = 0;
-                        for (int i = 0; i < blocklen; ++i) {
-                            const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                            const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
-                        }
-                        sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
-                    }
-                }
-            }
-            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
-        }
-    }
+    ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
 void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -591,31 +471,7 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     }
     return;
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    {
-        float sumf[4];
-        int sumi;
-
-        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
-
-            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    for (int j = 0; j < ncols_interleaved; j++) {
-                        sumi = 0;
-                        for (int i = 0; i < blocklen; ++i) {
-                            const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
-                            const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
-                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
-                        }
-                        sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
-                    }
-                }
-            }
-            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
-        }
-    }
+    ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -1096,40 +952,7 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
     );
     return;
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    {
-        float sumf[4][4];
-        int sumi;
-
-        for (int y = 0; y < nr / 4; y++) {
-            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
-            for (int x = 0; x < nc / ncols_interleaved; x++) {
-                const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
-                for (int m = 0; m < 4; m++) {
-                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
-                }
-                for (int l = 0; l < nb; l++) {
-                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                        for (int m = 0; m < 4; m++) {
-                            for (int j = 0; j < ncols_interleaved; j++) {
-                                sumi = 0;
-                                for (int i = 0; i < blocklen; ++i) {
-                                    const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                                    const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
-                                            (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
-                                }
-                                sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
-                            }
-                        }
-                    }
-                }
-                for (int m = 0; m < 4; m++) {
-                    for (int j = 0; j < ncols_interleaved; j++)
-                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
-                }
-            }
-        }
-    }
+    ggml_gemm_q4_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -1550,38 +1373,7 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
     );
     return;
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
-    float sumf[4][4];
-    int sumi;
-
-    for (int y = 0; y < nr / 4; y++) {
-        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
-            }
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    for (int m = 0; m < 4; m++) {
-                        for (int j = 0; j < ncols_interleaved; j++) {
-                            sumi = 0;
-                            for (int i = 0; i < blocklen; ++i) {
-                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
-                                        (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
-                            }
-                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
-                        }
-                    }
-                }
-            }
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++)
-                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
-            }
-        }
-    }
+    ggml_gemm_q4_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -2019,38 +1811,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 #endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
 
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
-    float sumf[4][8];
-    int sumi;
-
-    for (int y = 0; y < nr / 4; y++) {
-        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
-            }
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    for (int m = 0; m < 4; m++) {
-                        for (int j = 0; j < ncols_interleaved; j++) {
-                            sumi = 0;
-                            for (int i = 0; i < blocklen; ++i) {
-                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
-                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
-                            }
-                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
-                        }
-                    }
-                }
-            }
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++)
-                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
-            }
-        }
-    }
+    ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
 void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -2126,38 +1887,5 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     }
     return;
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    {
-        float sumf[4][4];
-        int sumi;
-
-        for (int y = 0; y < nr / 4; y++) {
-            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
-            for (int x = 0; x < nc / ncols_interleaved; x++) {
-                const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
-                for (int m = 0; m < 4; m++) {
-                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
-                }
-                for (int l = 0; l < nb; l++) {
-                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                        for (int m = 0; m < 4; m++) {
-                            for (int j = 0; j < ncols_interleaved; j++) {
-                                sumi = 0;
-                                for (int i = 0; i < blocklen; ++i) {
-                                    const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
-                                    const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
-                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
-                                            (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
-                                }
-                                sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
-                            }
-                        }
-                    }
-                }
-                for (int m = 0; m < 4; m++) {
-                    for (int j = 0; j < ncols_interleaved; j++)
-                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
-                }
-            }
-        }
-    }
+    ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
diff --git a/ggml/src/ggml-cpu/arch/loongarch/quants.c b/ggml/src/ggml-cpu/arch/loongarch/quants.c
index 9e33fb32286..0f9af7bf520 100644
--- a/ggml/src/ggml-cpu/arch/loongarch/quants.c
+++ b/ggml/src/ggml-cpu/arch/loongarch/quants.c
@@ -544,7 +544,7 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
         __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) );
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
         __m128 tmp = max4;
-        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
+        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x1 ));
         const float max_scalar = ((v4f32)max4)[0];
 
         // Quantize these floats
@@ -821,24 +821,15 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     sumf = hsum_float_8(acc) + summs;
 
-#endif
-    for (; ib < nb; ++ib) {
-        int sumi0 = 0;
-        int sumi1 = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const int v0 = (x[ib].qs[j] & 0x0F);
-            const int v1 = (x[ib].qs[j] >>   4);
-
-            sumi0 += (v0 * y[ib].qs[j]);
-            sumi1 += (v1 * y[ib].qs[j + qk/2]);
-        }
-
-        int sumi = sumi0 + sumi1;
-        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
-    }
-
     *s = sumf;
+#else
+    UNUSED(nb);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -883,30 +874,15 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     sumf = hsum_float_8(acc);
 
-#endif
-    for (; ib < nb; ++ib) {
-        uint32_t qh;
-        memcpy(&qh, x[ib].qh, sizeof(qh));
-
-        int sumi0 = 0;
-        int sumi1 = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
-
-            const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
-            const int32_t x1 = (int8_t)(((x[ib].qs[j] >>   4) | xh_1) - 16);
-
-            sumi0 += (x0 * y[ib].qs[j]);
-            sumi1 += (x1 * y[ib].qs[j + qk/2]);
-        }
-
-        int sumi = sumi0 + sumi1;
-        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;
-    }
-
     *s = sumf;
+#else
+    UNUSED(nb);
+    UNUSED(ib);
+    UNUSED(sumf);
+    UNUSED(x);
+    UNUSED(y);
+    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -954,30 +930,15 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     sumf = hsum_float_8(acc) + summs;
 
-#endif
-    for (; ib < nb; ++ib) {
-        uint32_t qh;
-        memcpy(&qh, x[ib].qh, sizeof(qh));
-
-        int sumi0 = 0;
-        int sumi1 = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
-            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
-
-            const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
-            const int32_t x1 = (x[ib].qs[j] >>  4) | xh_1;
-
-            sumi0 += (x0 * y[ib].qs[j]);
-            sumi1 += (x1 * y[ib].qs[j + qk/2]);
-        }
-
-        int sumi = sumi0 + sumi1;
-        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
-    }
-
     *s = sumf;
+#else
+    UNUSED(nb);
+    UNUSED(ib);
+    UNUSED(sumf);
+    UNUSED(x);
+    UNUSED(y);
+    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -1016,18 +977,15 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     sumf = hsum_float_8(acc);
 
-#endif
-    for (; ib < nb; ++ib) {
-        int sumi = 0;
-
-        for (int j = 0; j < qk; j++) {
-            sumi += x[ib].qs[j]*y[ib].qs[j];
-        }
-
-        sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));
-    }
-
     *s = sumf;
+#else
+    UNUSED(nb);
+    UNUSED(ib);
+    UNUSED(sumf);
+    UNUSED(x);
+    UNUSED(y);
+    ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -1103,45 +1061,10 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = hsum_float_8(acc);
 
 #else
-
-    float sumf = 0;
-
-    for (int i = 0; i < nb; ++i) {
-
-        const uint8_t * q2 = x[i].qs;
-        const  int8_t * q8 = y[i].qs;
-        const uint8_t * sc = x[i].scales;
-
-        int summs = 0;
-        for (int j = 0; j < 16; ++j) {
-            summs += y[i].bsums[j] * (sc[j] >> 4);
-        }
-
-        const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
-        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
-
-        int isum = 0;
-        int is = 0;
-        int d;
-        for (int k = 0; k < QK_K/128; ++k) {
-            int shift = 0;
-            for (int j = 0; j < 4; ++j) {
-                d = sc[is++] & 0xF;
-                int isuml = 0;
-                for (int l =  0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
-                isum += d * isuml;
-                d = sc[is++] & 0xF;
-                isuml = 0;
-                for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
-                isum += d * isuml;
-                shift += 2;
-                q8 += 32;
-            }
-            q2 += 32;
-        }
-        sumf += dall * isum - dmin * summs;
-    }
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1239,70 +1162,13 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = hsum_float_8(acc);
 
 #else
-    // scalar version
-    // This function is written like this so the compiler can manage to vectorize most of it
-    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
-    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
-    // The ideal situation would be if we could just write the code once, and the compiler would
-    // automatically produce the best possible set of machine instructions, instead of us having to manually
-    // write vectorized versions for AVX, ARM_NEON, etc.
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    uint32_t auxs[4];
-    const int8_t * scales = (const int8_t*)auxs;
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
-        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        uint8_t m = 1;
-        for (int j = 0; j < QK_K; j += 128) {
-            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            q3 += 32;
-        }
-        a = aux8;
-
-        memcpy(auxs, x[i].scales, 12);
-        uint32_t tmp = auxs[2];
-        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
-        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
-        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
-        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
-        for (int j = 0; j < QK_K/16; ++j) {
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
-
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
-
 }
 
 void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -1391,61 +1257,14 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
 
 #else
-
-    const uint8_t * scales = (const uint8_t*)&utmp[0];
-    const uint8_t * mins   = (const uint8_t*)&utmp[2];
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        for (int j = 0; j < QK_K/64; ++j) {
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
-            a += 32;
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
-            a += 32; q4 += 32;
-        }
-        memcpy(utmp, x[i].scales, 12);
-        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
-        const uint32_t uaux = utmp[1] & kmask1;
-        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
-        utmp[2] = uaux;
-        utmp[0] &= kmask1;
-
-        int sumi = 0;
-        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            int32_t scale = scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
-        sumf -= dmin * sumi;
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    UNUSED(utmp);
+    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1541,66 +1360,14 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
 
 #else
-
-    const uint8_t * scales = (const uint8_t*)&utmp[0];
-    const uint8_t * mins   = (const uint8_t*)&utmp[2];
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
-        const uint8_t * GGML_RESTRICT hm = x[i].qh;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        uint8_t m = 1;
-        for (int j = 0; j < QK_K/64; ++j) {
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
-            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
-            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
-            a += 32; m <<= 1;
-            q4 += 32;
-        }
-        memcpy(utmp, x[i].scales, 12);
-        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
-        const uint32_t uaux = utmp[1] & kmask1;
-        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
-        utmp[2] = uaux;
-        utmp[0] &= kmask1;
-
-        int sumi = 0;
-        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            int32_t scale = scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
-        sumf -= dmin * sumi;
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    UNUSED(utmp);
+    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1678,47 +1445,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = hsum_float_8(acc);
 
 #else
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
-        const uint8_t * GGML_RESTRICT qh = x[i].qh;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        for (int j = 0; j < QK_K; j += 128) {
-            for (int l = 0; l < 32; ++l) {
-                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
-                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
-                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
-                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
-            }
-            a  += 128;
-            q4 += 64;
-            qh += 32;
-        }
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/16; ++j) {
-            int scale = x[i].scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1815,34 +1545,10 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
     *s = 0.125f * hsum_float_8(accumf);
 
 #else
-
-    uint32_t aux32[2];
-    const uint8_t * aux8 = (const uint8_t *)aux32;
-
-    float sumf = 0.f;
-    for (int i = 0; i < nb; ++i) {
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
-        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
-        int32_t bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
-            memcpy(aux32, q2, 2*sizeof(uint32_t));
-            q2 += 4;
-            const uint32_t ls = 2*(aux32[1] >> 28) + 1;
-            int32_t sumi = 0;
-            for (int l = 0; l < 4; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
-                const uint8_t  signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
-                for (int j = 0; j < 8; ++j) {
-                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            bsum += sumi * ls;
-        }
-        sumf += d * bsum;
-    }
-    *s = 0.125f * sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1978,42 +1684,10 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
     *s = 0.125f * hsum_float_8(accumf);
 
 #else
-
-    float sumf = 0.f;
-    for (int i = 0; i < nb; ++i) {
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
-        const uint8_t  * GGML_RESTRICT sc = x[i].scales;
-        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
-        int32_t bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
-            const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;
-            const uint16_t ls2 = 2*(sc[ib32] >>  4) + 1;
-            int32_t sumi = 0;
-            for (int l = 0; l < 2; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
-                const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];
-                for (int j = 0; j < 8; ++j) {
-                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            bsum += sumi * ls1;
-            sumi = 0;
-            for (int l = 2; l < 4; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
-                const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];
-                for (int j = 0; j < 8; ++j) {
-                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            bsum += sumi * ls2;
-            q2 += 4;
-        }
-        sumf += d * bsum;
-    }
-    *s = 0.125f * sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -2105,47 +1779,11 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = 0.125f * hsum_float_8(accumf);
 
 #else
-
-    float sumf = 0;
-    for (int i = 0; i < nb; i++) {
-
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const int8_t  * q8 = y[i].qs;
-        const uint8_t * qs = x[i].qs;
-        const uint8_t * qh = x[i].qh;
-        const uint8_t * signs = qs + QK_K/8;
-
-        int bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
-            int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf);
-            int ls2 = 1 + 2*(x[i].scales[ib32] >>  4);
-            int sumi1 = 0, sumi2 = 0;
-            for (int l = 0; l < 2; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
-                for (int j = 0; j < 8; ++j) {
-                    sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            for (int l = 2; l < 4; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
-                for (int j = 0; j < 8; ++j) {
-                    sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            bsum += ls1 * sumi1 + ls2 * sumi2;
-            qs += 4;
-            signs += 4;
-        }
-
-        sumf += d * bsum;
-    }
-
-    *s = 0.125f * sumf;
-
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
-
 }
 
 void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -2209,36 +1847,10 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
     *s = 0.25f * hsum_float_8(accumf);
 
 #else
-
-    uint32_t aux32;
-
-    float sumf = 0.f;
-    for (int i = 0; i < nb; ++i) {
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
-        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
-        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
-        int32_t bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
-            memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
-            const uint32_t ls = 2*(aux32 >> 28) + 1;
-            int32_t sumi = 0;
-            for (int l = 0; l < 4; ++l) {
-                const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
-                const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
-                const uint8_t  signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
-                for (int j = 0; j < 4; ++j) {
-                    sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
-                    sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            q3 += 8;
-            bsum += sumi * ls;
-        }
-        sumf += d * bsum;
-    }
-    *s = 0.25f * sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -2338,48 +1950,10 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = hsum_float_8(accumf);
 
 #else
-
-    float sumf = 0.f;
-    for (int i = 0; i < nb; ++i) {
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * GGML_RESTRICT qs = x[i].qs;
-        const uint8_t * GGML_RESTRICT qh = x[i].qh;
-        const uint8_t * GGML_RESTRICT signs = x[i].signs;
-        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
-        int32_t bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
-            const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;
-            const uint32_t ls2 = 2*(x[i].scales[ib32/2] >>  4) + 1;
-            int32_t sumi = 0;
-            for (int l = 0; l < 4; ++l) {
-                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
-                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
-                for (int j = 0; j < 4; ++j) {
-                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
-                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            qs += 8;
-            signs += 4;
-            bsum += sumi * ls1;
-            sumi = 0;
-            for (int l = 0; l < 4; ++l) {
-                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
-                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
-                for (int j = 0; j < 4; ++j) {
-                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
-                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            qs += 8;
-            signs += 4;
-            bsum += sumi * ls2;
-        }
-        sumf += d * bsum;
-    }
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -2460,36 +2034,10 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
 
 #else
-
-    float sumf = 0;
-    for (int i = 0; i < nb; i++) {
-
-        const int8_t   * q8 = y[i].qs;
-        const uint8_t  * qs = x[i].qs;
-        const uint16_t * qh = x[i].qh;
-
-        int sumi = 0, sumi1 = 0;
-        for (int ib = 0; ib < QK_K/32; ++ib) {
-            const int ls = 2*((qh[ib] >> 12) & 7) + 1;
-            const int delta = qh[ib] & 0x8000 ? -1 : 1;
-            int lsum = 0;
-            for (int l = 0; l < 4; ++l) {
-                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
-                for (int j = 0; j < 8; ++j) {
-                    lsum += q8[j] * grid[j];
-                }
-                q8 += 8;
-            }
-            sumi  += ls * lsum;
-            sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]);
-            qs += 4;
-        }
-
-        sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);
-    }
-
-    *s = sumf;
-
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -2603,37 +2151,10 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
     *s = hsum_float_8(accum);
 
 #else
-    float sumf = 0;
-    for (int ibl = 0; ibl < nb; ++ibl) {
-        const float d4d8 = GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
-        uint16_t h = x[ibl].scales_h;
-        const uint8_t * qs = x[ibl].qs;
-        const int8_t  * q8 = y[ibl].qs;
-        for (int ib = 0; ib < QK_K/32; ib += 2) {
-            const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30);
-            const uint8_t ls2 = (x[ibl].scales_l[ib/2] >>  4) | ((h << 2) & 0x30);
-            h >>= 4;
-            const float d1 = d4d8*(ls1 - 32);
-            const float d2 = d4d8*(ls2 - 32);
-            int sumi1 = 0, sumi2 = 0;
-            for (int j = 0; j < 16; ++j) {
-                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
-                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];
-            }
-            sumf += d1 * (sumi1 + sumi2);
-            qs += 16;
-            q8 += 32;
-            sumi1 = sumi2 = 0;
-            for (int j = 0; j < 16; ++j) {
-                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
-                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];
-            }
-            sumf += d2 * (sumi1 + sumi2);
-            qs += 16;
-            q8 += 32;
-        }
-    }
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
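
The hunks above and in the files that follow all apply the same refactoring: the per-architecture scalar fallback loops are removed, and the non-SIMD branch now marks its leftover locals with UNUSED() and forwards to a single shared ggml_vec_dot_*_generic() / ggml_gemm_*_generic() kernel. A minimal, self-contained sketch of that dispatch pattern is shown here; vec_dot_generic, UNUSED and __SOME_SIMD_EXTENSION__ are illustrative stand-ins, not ggml's exact identifiers.

    #include <stdio.h>

    #define UNUSED(x) (void)(x)

    /* Shared scalar kernel, compiled once and reused by every architecture. */
    static void vec_dot_generic(int n, float * s, const float * x, const float * y) {
        float sumf = 0.0f;
        for (int i = 0; i < n; ++i) {
            sumf += x[i]*y[i];
        }
        *s = sumf;
    }

    /* Per-architecture wrapper: keeps only the SIMD path and defers the
     * scalar fallback to the shared kernel instead of duplicating it. */
    static void vec_dot(int n, float * s, const float * x, const float * y) {
        const int nb   = n / 4;   /* block count, used only by the SIMD path */
        float     sumf = 0.0f;    /* SIMD accumulator */

    #if defined(__SOME_SIMD_EXTENSION__)
        /* intrinsics would accumulate nb blocks into sumf here */
        *s = sumf;
    #else
        UNUSED(nb);               /* silence -Wunused in the fallback branch */
        UNUSED(sumf);
        vec_dot_generic(n, s, x, y);
    #endif
    }

    int main(void) {
        const float a[4] = {1, 2, 3, 4};
        const float b[4] = {5, 6, 7, 8};
        float s = 0.0f;
        vec_dot(4, &s, a, b);
        printf("%g\n", s);        /* prints 70 */
        return 0;
    }
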
 
diff --git a/ggml/src/ggml-cpu/arch/powerpc/quants.c b/ggml/src/ggml-cpu/arch/powerpc/quants.c
index 053d5cbdc7b..49aae7a23bb 100644
--- a/ggml/src/ggml-cpu/arch/powerpc/quants.c
+++ b/ggml/src/ggml-cpu/arch/powerpc/quants.c
@@ -201,24 +201,14 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     sumf = vec_extract(vsumf0, 0);
 
-#endif
-    for (; ib < nb; ++ib) {
-        int sumi0 = 0;
-        int sumi1 = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const int v0 = (x[ib].qs[j] & 0x0F) - 8;
-            const int v1 = (x[ib].qs[j] >>   4) - 8;
-
-            sumi0 += (v0 * y[ib].qs[j]);
-            sumi1 += (v1 * y[ib].qs[j + qk/2]);
-        }
-
-        int sumi = sumi0 + sumi1;
-        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
-    }
-
     *s = sumf;
+#else
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_q4_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -278,24 +268,14 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     sumf = vec_extract(vsumf0, 0);
 
-#endif
-    for (; ib < nb; ++ib) {
-        int sumi0 = 0;
-        int sumi1 = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const int v0 = (x[ib].qs[j] & 0x0F);
-            const int v1 = (x[ib].qs[j] >>   4);
-
-            sumi0 += (v0 * y[ib].qs[j]);
-            sumi1 += (v1 * y[ib].qs[j + qk/2]);
-        }
-
-        int sumi = sumi0 + sumi1;
-        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
-    }
-
     *s = sumf;
+#else
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -360,30 +340,14 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     sumf = vec_extract(vsumf0, 0);
 
-#endif
-    for (; ib < nb; ++ib) {
-        uint32_t qh;
-        memcpy(&qh, x[ib].qh, sizeof(qh));
-
-        int sumi0 = 0;
-        int sumi1 = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
-
-            const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
-            const int32_t x1 = (int8_t)(((x[ib].qs[j] >>   4) | xh_1) - 16);
-
-            sumi0 += (x0 * y[ib].qs[j]);
-            sumi1 += (x1 * y[ib].qs[j + qk/2]);
-        }
-
-        int sumi = sumi0 + sumi1;
-        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;
-    }
-
     *s = sumf;
+#else
+    UNUSED(ib);
+    UNUSED(sumf);
+    UNUSED(x);
+    UNUSED(y);
+    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -451,30 +415,15 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     sumf = vec_extract(vsumf0, 0);
 
-#endif
-    for (; ib < nb; ++ib) {
-        uint32_t qh;
-        memcpy(&qh, x[ib].qh, sizeof(qh));
-
-        int sumi0 = 0;
-        int sumi1 = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
-            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
-
-            const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
-            const int32_t x1 = (x[ib].qs[j] >>  4) | xh_1;
-
-            sumi0 += (x0 * y[ib].qs[j]);
-            sumi1 += (x1 * y[ib].qs[j + qk/2]);
-        }
-
-        int sumi = sumi0 + sumi1;
-        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
-    }
-
     *s = sumf;
+#else
+    UNUSED(nb);
+    UNUSED(ib);
+    UNUSED(sumf);
+    UNUSED(x);
+    UNUSED(y);
+    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -535,18 +484,15 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     sumf = vec_extract(vsumf0, 0);
 
-#endif
-    for (; ib < nb; ++ib) {
-        int sumi = 0;
-
-        for (int j = 0; j < qk; j++) {
-            sumi += x[ib].qs[j]*y[ib].qs[j];
-        }
-
-        sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));
-    }
-
     *s = sumf;
+#else
+    UNUSED(nb);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -695,45 +641,10 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = vec_extract(vsumf0, 0);
 
 #else
-
-    float sumf = 0;
-
-    for (int i = 0; i < nb; ++i) {
-
-        const uint8_t * q2 = x[i].qs;
-        const  int8_t * q8 = y[i].qs;
-        const uint8_t * sc = x[i].scales;
-
-        int summs = 0;
-        for (int j = 0; j < 16; ++j) {
-            summs += y[i].bsums[j] * (sc[j] >> 4);
-        }
-
-        const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
-        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
-
-        int isum = 0;
-        int is = 0;
-        int d;
-        for (int k = 0; k < QK_K/128; ++k) {
-            int shift = 0;
-            for (int j = 0; j < 4; ++j) {
-                d = sc[is++] & 0xF;
-                int isuml = 0;
-                for (int l =  0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
-                isum += d * isuml;
-                d = sc[is++] & 0xF;
-                isuml = 0;
-                for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
-                isum += d * isuml;
-                shift += 2;
-                q8 += 32;
-            }
-            q2 += 32;
-        }
-        sumf += dall * isum - dmin * summs;
-    }
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -907,70 +818,13 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = vec_extract(vsumf0, 0);
 
 #else
-    // scalar version
-    // This function is written like this so the compiler can manage to vectorize most of it
-    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
-    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
-    // The ideal situation would be if we could just write the code once, and the compiler would
-    // automatically produce the best possible set of machine instructions, instead of us having to manually
-    // write vectorized versions for AVX, ARM_NEON, etc.
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    uint32_t auxs[4];
-    const int8_t * scales = (const int8_t*)auxs;
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
-        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        uint8_t m = 1;
-        for (int j = 0; j < QK_K; j += 128) {
-            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            q3 += 32;
-        }
-        a = aux8;
-
-        memcpy(auxs, x[i].scales, 12);
-        uint32_t tmp = auxs[2];
-        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
-        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
-        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
-        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
-        for (int j = 0; j < QK_K/16; ++j) {
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
-
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
-
 }
 
 void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -1130,61 +984,14 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = vec_extract(vsumf0, 0);
 
 #else
-
-    const uint8_t * scales = (const uint8_t*)&utmp[0];
-    const uint8_t * mins   = (const uint8_t*)&utmp[2];
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        for (int j = 0; j < QK_K/64; ++j) {
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
-            a += 32;
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
-            a += 32; q4 += 32;
-        }
-        memcpy(utmp, x[i].scales, 12);
-        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
-        const uint32_t uaux = utmp[1] & kmask1;
-        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
-        utmp[2] = uaux;
-        utmp[0] &= kmask1;
-
-        int sumi = 0;
-        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            int32_t scale = scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
-        sumf -= dmin * sumi;
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    UNUSED(utmp);
+    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1342,66 +1149,14 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = vec_extract(vsumf0, 0);
 
 #else
-
-    const uint8_t * scales = (const uint8_t*)&utmp[0];
-    const uint8_t * mins   = (const uint8_t*)&utmp[2];
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
-        const uint8_t * GGML_RESTRICT hm = x[i].qh;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        uint8_t m = 1;
-        for (int j = 0; j < QK_K/64; ++j) {
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
-            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
-            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
-            a += 32; m <<= 1;
-            q4 += 32;
-        }
-        memcpy(utmp, x[i].scales, 12);
-        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
-        const uint32_t uaux = utmp[1] & kmask1;
-        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
-        utmp[2] = uaux;
-        utmp[0] &= kmask1;
-
-        int sumi = 0;
-        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            int32_t scale = scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
-        sumf -= dmin * sumi;
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    UNUSED(utmp);
+    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1556,47 +1311,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = vec_extract(vsumf0, 0);
 
 #else
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
-        const uint8_t * GGML_RESTRICT qh = x[i].qh;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        for (int j = 0; j < QK_K; j += 128) {
-            for (int l = 0; l < 32; ++l) {
-                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
-                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
-                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
-                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
-            }
-            a  += 128;
-            q4 += 64;
-            qh += 32;
-        }
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/16; ++j) {
-            int scale = x[i].scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1737,34 +1455,10 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
     *s = 0.125f * vec_extract(vsumf0, 0);
 
 #else
-
-    uint32_t aux32[2];
-    const uint8_t * aux8 = (const uint8_t *)aux32;
-
-    float sumf = 0.f;
-    for (int i = 0; i < nb; ++i) {
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
-        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
-        int32_t bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
-            memcpy(aux32, q2, 2*sizeof(uint32_t));
-            q2 += 4;
-            const uint32_t ls = 2*(aux32[1] >> 28) + 1;
-            int32_t sumi = 0;
-            for (int l = 0; l < 4; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
-                const uint8_t  signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
-                for (int j = 0; j < 8; ++j) {
-                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            bsum += sumi * ls;
-        }
-        sumf += d * bsum;
-    }
-    *s = 0.125f * sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1869,42 +1563,10 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
     *s = 0.125f * vec_extract(vsumf0, 0);
 
 #else
-
-    float sumf = 0.f;
-    for (int i = 0; i < nb; ++i) {
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
-        const uint8_t  * GGML_RESTRICT sc = x[i].scales;
-        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
-        int32_t bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
-            const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;
-            const uint16_t ls2 = 2*(sc[ib32] >>  4) + 1;
-            int32_t sumi = 0;
-            for (int l = 0; l < 2; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
-                const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];
-                for (int j = 0; j < 8; ++j) {
-                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            bsum += sumi * ls1;
-            sumi = 0;
-            for (int l = 2; l < 4; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
-                const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];
-                for (int j = 0; j < 8; ++j) {
-                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            bsum += sumi * ls2;
-            q2 += 4;
-        }
-        sumf += d * bsum;
-    }
-    *s = 0.125f * sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -2030,47 +1692,11 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = 0.125f * vec_extract(vsumf0, 0);
 
 #else
-
-    float sumf = 0;
-    for (int i = 0; i < nb; i++) {
-
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const int8_t  * q8 = y[i].qs;
-        const uint8_t * qs = x[i].qs;
-        const uint8_t * qh = x[i].qh;
-        const uint8_t * signs = qs + QK_K/8;
-
-        int bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
-            int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf);
-            int ls2 = 1 + 2*(x[i].scales[ib32] >>  4);
-            int sumi1 = 0, sumi2 = 0;
-            for (int l = 0; l < 2; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
-                for (int j = 0; j < 8; ++j) {
-                    sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            for (int l = 2; l < 4; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
-                for (int j = 0; j < 8; ++j) {
-                    sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            bsum += ls1 * sumi1 + ls2 * sumi2;
-            qs += 4;
-            signs += 4;
-        }
-
-        sumf += d * bsum;
-    }
-
-    *s = 0.125f * sumf;
-
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
-
 }
 
 void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -2172,36 +1798,10 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
     *s = 0.25f * vec_extract(vsumf0, 0);
 
 #else
-
-    uint32_t aux32;
-
-    float sumf = 0.f;
-    for (int i = 0; i < nb; ++i) {
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
-        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
-        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
-        int32_t bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
-            memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
-            const uint32_t ls = 2*(aux32 >> 28) + 1;
-            int32_t sumi = 0;
-            for (int l = 0; l < 4; ++l) {
-                const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
-                const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
-                const uint8_t  signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
-                for (int j = 0; j < 4; ++j) {
-                    sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
-                    sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            q3 += 8;
-            bsum += sumi * ls;
-        }
-        sumf += d * bsum;
-    }
-    *s = 0.25f * sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -2327,48 +1927,10 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = vec_extract(vsumf0, 0);
 
 #else
-
-    float sumf = 0.f;
-    for (int i = 0; i < nb; ++i) {
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * GGML_RESTRICT qs = x[i].qs;
-        const uint8_t * GGML_RESTRICT qh = x[i].qh;
-        const uint8_t * GGML_RESTRICT signs = x[i].signs;
-        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
-        int32_t bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
-            const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;
-            const uint32_t ls2 = 2*(x[i].scales[ib32/2] >>  4) + 1;
-            int32_t sumi = 0;
-            for (int l = 0; l < 4; ++l) {
-                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
-                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
-                for (int j = 0; j < 4; ++j) {
-                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
-                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            qs += 8;
-            signs += 4;
-            bsum += sumi * ls1;
-            sumi = 0;
-            for (int l = 0; l < 4; ++l) {
-                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
-                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
-                for (int j = 0; j < 4; ++j) {
-                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
-                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            qs += 8;
-            signs += 4;
-            bsum += sumi * ls2;
-        }
-        sumf += d * bsum;
-    }
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -2481,36 +2043,10 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = vec_extract(vsumf0, 0);
 
 #else
-
-    float sumf = 0;
-    for (int i = 0; i < nb; i++) {
-
-        const int8_t   * q8 = y[i].qs;
-        const uint8_t  * qs = x[i].qs;
-        const uint16_t * qh = x[i].qh;
-
-        int sumi = 0, sumi1 = 0;
-        for (int ib = 0; ib < QK_K/32; ++ib) {
-            const int ls = 2*((qh[ib] >> 12) & 7) + 1;
-            const int delta = qh[ib] & 0x8000 ? -1 : 1;
-            int lsum = 0;
-            for (int l = 0; l < 4; ++l) {
-                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
-                for (int j = 0; j < 8; ++j) {
-                    lsum += q8[j] * grid[j];
-                }
-                q8 += 8;
-            }
-            sumi  += ls * lsum;
-            sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]);
-            qs += 4;
-        }
-
-        sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);
-    }
-
-    *s = sumf;
-
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -2581,17 +2117,15 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
 
     sumf = vec_extract(vsumf0, 0);
 
-#endif
-    for (; ib < nb; ++ib) {
-        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_FP16_TO_FP32(x[ib].d);
-        int sumi1 = 0, sumi2 = 0;
-        for (int j = 0; j < QK4_NL/2; ++j) {
-            sumi1 += y[ib].qs[j+       0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
-            sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >>  4];
-        }
-        sumf += d * (sumi1 + sumi2);
-    }
     *s = sumf;
+#else
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_iq4_nl_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -2696,37 +2230,10 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
     *s = vec_extract(vsumf0, 0);
 
 #else
-    float sumf = 0;
-    for (int ibl = 0; ibl < nb; ++ibl) {
-        const float d4d8 = GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
-        uint16_t h = x[ibl].scales_h;
-        const uint8_t * qs = x[ibl].qs;
-        const int8_t  * q8 = y[ibl].qs;
-        for (int ib = 0; ib < QK_K/32; ib += 2) {
-            const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30);
-            const uint8_t ls2 = (x[ibl].scales_l[ib/2] >>  4) | ((h << 2) & 0x30);
-            h >>= 4;
-            const float d1 = d4d8*(ls1 - 32);
-            const float d2 = d4d8*(ls2 - 32);
-            int sumi1 = 0, sumi2 = 0;
-            for (int j = 0; j < 16; ++j) {
-                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
-                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];
-            }
-            sumf += d1 * (sumi1 + sumi2);
-            qs += 16;
-            q8 += 32;
-            sumi1 = sumi2 = 0;
-            for (int j = 0; j < 16; ++j) {
-                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
-                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];
-            }
-            sumf += d2 * (sumi1 + sumi2);
-            qs += 16;
-            q8 += 32;
-        }
-    }
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
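Note: the pattern applied throughout this patch is the same in every arch file — each architecture-specific dot product keeps only its SIMD path, and the former inline scalar tail behind the `#else` is replaced by a call to the shared `*_generic` reference implementation, with `UNUSED(...)` silencing warnings for locals that the SIMD-less build never touches. A minimal, self-contained sketch of that structure, using a hypothetical vec_dot_example and a stand-in UNUSED macro (not the actual ggml sources):

#include <stddef.h>

/* Stand-in for ggml's UNUSED macro: evaluate and discard without a warning. */
#define UNUSED(x) (void)(x)

/* Portable reference implementation shared by every architecture. */
static void vec_dot_example_generic(int n, float * s, const float * x, const float * y) {
    float sumf = 0.0f;
    for (int i = 0; i < n; ++i) {
        sumf += x[i] * y[i];
    }
    *s = sumf;
}

/* Architecture wrapper: SIMD body when available, otherwise delegate. */
void vec_dot_example(int n, float * s, const float * x, const float * y) {
    /* Setup shared with the SIMD path (block count, accumulator, ...). */
    const int nb   = n / 4;
    float     sumf = 0.0f;

#if defined(EXAMPLE_HAVE_SIMD)
    for (int ib = 0; ib < nb; ++ib) {
        /* a real port would use vector intrinsics here */
        for (int j = 0; j < 4; ++j) {
            sumf += x[4*ib + j] * y[4*ib + j];
        }
    }
    *s = sumf;
#else
    /* The setup above is dead in this branch; mark it used and fall back
     * to the shared scalar reference instead of duplicating it inline. */
    UNUSED(nb);
    UNUSED(sumf);
    vec_dot_example_generic(n, s, x, y);
#endif
}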
 
diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c
index 8b64d8adc48..6c74417c90c 100644
--- a/ggml/src/ggml-cpu/arch/riscv/quants.c
+++ b/ggml/src/ggml-cpu/arch/riscv/quants.c
@@ -116,6 +116,7 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
 //===================================== Dot products =================================
 
 void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined(__riscv_v)
     const int qk = QK8_0;
     const int nb = n / qk;
 
@@ -132,7 +133,6 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
     int ib = 0;
     float sumf = 0;
 
-#if defined(__riscv_v)
     size_t vl = qk / 2;
 
     for (; ib < nb; ++ib) {
@@ -164,27 +164,14 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
         sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
     }
 
-#endif
-    for (; ib < nb; ++ib) {
-        int sumi0 = 0;
-        int sumi1 = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const int v0 = (x[ib].qs[j] & 0x0F) - 8;
-            const int v1 = (x[ib].qs[j] >>   4) - 8;
-
-            sumi0 += (v0 * y[ib].qs[j]);
-            sumi1 += (v1 * y[ib].qs[j + qk/2]);
-        }
-
-        int sumi = sumi0 + sumi1;
-        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
-    }
-
     *s = sumf;
+#else
+    ggml_vec_dot_q4_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined(__riscv_v)
     const int qk = QK8_1;
     const int nb = n / qk;
 
@@ -201,7 +188,6 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
     int ib = 0;
     float sumf = 0;
 
-#if defined(__riscv_v)
     size_t vl = qk / 2;
 
     for (; ib < nb; ++ib) {
@@ -229,27 +215,14 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
         sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
     }
 
-#endif
-    for (; ib < nb; ++ib) {
-        int sumi0 = 0;
-        int sumi1 = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const int v0 = (x[ib].qs[j] & 0x0F);
-            const int v1 = (x[ib].qs[j] >>   4);
-
-            sumi0 += (v0 * y[ib].qs[j]);
-            sumi1 += (v1 * y[ib].qs[j + qk/2]);
-        }
-
-        int sumi = sumi0 + sumi1;
-        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
-    }
-
     *s = sumf;
+#else
+    ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined(__riscv_v)
     const int qk = QK8_0;
     const int nb = n / qk;
 
@@ -267,7 +240,6 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
     const block_q5_0 * GGML_RESTRICT x = vx;
     const block_q8_0 * GGML_RESTRICT y = vy;
 
-#if defined(__riscv_v)
     size_t vl;
     size_t vlenb = __riscv_vlenb();
 
@@ -297,33 +269,14 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
         sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;
     }
 
-#endif
-    for (; ib < nb; ++ib) {
-        uint32_t qh;
-        memcpy(&qh, x[ib].qh, sizeof(qh));
-
-        int sumi0 = 0;
-        int sumi1 = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
-
-            const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
-            const int32_t x1 = (int8_t)(((x[ib].qs[j] >>   4) | xh_1) - 16);
-
-            sumi0 += (x0 * y[ib].qs[j]);
-            sumi1 += (x1 * y[ib].qs[j + qk/2]);
-        }
-
-        int sumi = sumi0 + sumi1;
-        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;
-    }
-
     *s = sumf;
+#else
+    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined(__riscv_v)
     const int qk = QK8_1;
     const int nb = n / qk;
 
@@ -341,7 +294,6 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
     const block_q5_1 * GGML_RESTRICT x = vx;
     const block_q8_1 * GGML_RESTRICT y = vy;
 
-#if defined(__riscv_v)
     size_t vl;
     size_t vlenb = __riscv_vlenb();
 
@@ -370,30 +322,10 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
         sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
     }
 
-#endif
-    for (; ib < nb; ++ib) {
-        uint32_t qh;
-        memcpy(&qh, x[ib].qh, sizeof(qh));
-
-        int sumi0 = 0;
-        int sumi1 = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
-            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
-
-            const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
-            const int32_t x1 = (x[ib].qs[j] >>  4) | xh_1;
-
-            sumi0 += (x0 * y[ib].qs[j]);
-            sumi1 += (x1 * y[ib].qs[j + qk/2]);
-        }
-
-        int sumi = sumi0 + sumi1;
-        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
-    }
-
     *s = sumf;
+#else
+    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -431,18 +363,17 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
         sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));
     }
 
-#endif
-    for (; ib < nb; ++ib) {
-        int sumi = 0;
+    *s = sumf;
+#else
 
-        for (int j = 0; j < qk; j++) {
-            sumi += x[ib].qs[j]*y[ib].qs[j];
-        }
+    UNUSED(nb);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    UNUSED(sumf);
 
-        sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));
-    }
-
-    *s = sumf;
+    ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -738,44 +669,11 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
 #else
 
-    float sumf = 0;
-
-    for (int i = 0; i < nb; ++i) {
-
-        const uint8_t * q2 = x[i].qs;
-        const  int8_t * q8 = y[i].qs;
-        const uint8_t * sc = x[i].scales;
-
-        int summs = 0;
-        for (int j = 0; j < 16; ++j) {
-            summs += y[i].bsums[j] * (sc[j] >> 4);
-        }
-
-        const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
-        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
 
-        int isum = 0;
-        int is = 0;
-        int d;
-        for (int k = 0; k < QK_K/128; ++k) {
-            int shift = 0;
-            for (int j = 0; j < 4; ++j) {
-                d = sc[is++] & 0xF;
-                int isuml = 0;
-                for (int l =  0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
-                isum += d * isuml;
-                d = sc[is++] & 0xF;
-                isuml = 0;
-                for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
-                isum += d * isuml;
-                shift += 2;
-                q8 += 32;
-            }
-            q2 += 32;
-        }
-        sumf += dall * isum - dmin * summs;
-    }
-    *s = sumf;
+    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1147,68 +1045,14 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sumf;
 
 #else
-    // scalar version
-    // This function is written like this so the compiler can manage to vectorize most of it
-    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
-    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
-    // The ideal situation would be if we could just write the code once, and the compiler would
-    // automatically produce the best possible set of machine instructions, instead of us having to manually
-    // write vectorized versions for AVX, ARM_NEON, etc.
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    uint32_t auxs[4];
-    const int8_t * scales = (const int8_t*)auxs;
 
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
-        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        uint8_t m = 1;
-        for (int j = 0; j < QK_K; j += 128) {
-            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            q3 += 32;
-        }
-        a = aux8;
-
-        memcpy(auxs, x[i].scales, 12);
-        uint32_t tmp = auxs[2];
-        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
-        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
-        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
-        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
-        for (int j = 0; j < QK_K/16; ++j) {
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
 
+    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 
 }
@@ -1534,60 +1378,15 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
 #else
 
-    const uint8_t * scales = (const uint8_t*)&utmp[0];
-    const uint8_t * mins   = (const uint8_t*)&utmp[2];
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    UNUSED(nb);
+    UNUSED(utmp);
 
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        for (int j = 0; j < QK_K/64; ++j) {
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
-            a += 32;
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
-            a += 32; q4 += 32;
-        }
-        memcpy(utmp, x[i].scales, 12);
-        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
-        const uint32_t uaux = utmp[1] & kmask1;
-        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
-        utmp[2] = uaux;
-        utmp[0] &= kmask1;
-
-        int sumi = 0;
-        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            int32_t scale = scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
-        sumf -= dmin * sumi;
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1698,65 +1497,15 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
 #else
 
-    const uint8_t * scales = (const uint8_t*)&utmp[0];
-    const uint8_t * mins   = (const uint8_t*)&utmp[2];
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
-        const uint8_t * GGML_RESTRICT hm = x[i].qh;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        uint8_t m = 1;
-        for (int j = 0; j < QK_K/64; ++j) {
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
-            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
-            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
-            a += 32; m <<= 1;
-            q4 += 32;
-        }
-        memcpy(utmp, x[i].scales, 12);
-        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
-        const uint32_t uaux = utmp[1] & kmask1;
-        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
-        utmp[2] = uaux;
-        utmp[0] &= kmask1;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    UNUSED(nb);
+    UNUSED(utmp);
 
-        int sumi = 0;
-        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            int32_t scale = scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
-        sumf -= dmin * sumi;
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -2024,46 +1773,11 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
 #else
 
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
 
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
-        const uint8_t * GGML_RESTRICT qh = x[i].qh;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        for (int j = 0; j < QK_K; j += 128) {
-            for (int l = 0; l < 32; ++l) {
-                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
-                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
-                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
-                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
-            }
-            a  += 128;
-            q4 += 64;
-            qh += 32;
-        }
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/16; ++j) {
-            int scale = x[i].scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
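Note: on RISC-V the same idea is taken one step further — the `#if defined(__riscv_v)` guard is hoisted to the top of each function, so when the build has no vector extension the whole body (including the block-count and pointer setup) is compiled out and the call goes straight to the generic reference, with no `UNUSED` bookkeeping needed. A minimal sketch of that shape with a hypothetical vec_dot_rvv_example (not the actual kernels):

#include <stddef.h>

/* Shared scalar reference, identical on every architecture. */
static void vec_dot_rvv_example_generic(int n, float * s, const float * x, const float * y) {
    float sumf = 0.0f;
    for (int i = 0; i < n; ++i) {
        sumf += x[i] * y[i];
    }
    *s = sumf;
}

void vec_dot_rvv_example(int n, float * s, const float * x, const float * y) {
#if defined(__riscv_v)
    /* With the guard at the top, the block setup only exists when the
     * RVV path is actually compiled in. */
    const int nb   = n / 4;
    float     sumf = 0.0f;
    for (int ib = 0; ib < nb; ++ib) {
        /* a real kernel would use RVV intrinsics (e.g. __riscv_vfmacc_*) here */
        for (int j = 0; j < 4; ++j) {
            sumf += x[4*ib + j] * y[4*ib + j];
        }
    }
    *s = sumf;
#else
    /* Nothing was declared in this branch, so simply delegate. */
    vec_dot_rvv_example_generic(n, s, x, y);
#endif
}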
 
diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp
index 45c91a69482..2a35ff9ad87 100644
--- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp
+++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp
@@ -112,31 +112,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
     }
 
 #endif
-    {
-        float sumf[8];
-        int sumi;
-
-        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
-
-            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    for (int j = 0; j < ncols_interleaved; j++) {
-                        sumi = 0;
-                        for (int i = 0; i < blocklen; ++i) {
-                            const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                            const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
-                        }
-                        sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
-                    }
-                }
-            }
-            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
-        }
-    }
+    ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -361,37 +337,6 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
         return;
     }
 
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
-    float sumf[4][8];
-    int sumi;
-
-    for (int y = 0; y < nr / 4; y++) {
-        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
-            }
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    for (int m = 0; m < 4; m++) {
-                        for (int j = 0; j < ncols_interleaved; j++) {
-                            sumi = 0;
-                            for (int i = 0; i < blocklen; ++i) {
-                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
-                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
-                            }
-                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
-                        }
-                    }
-                }
-            }
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++)
-                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
-            }
-        }
-    }
+#endif
+    ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
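Note: for the repacked q4_0 x q8_0 GEMV/GEMM, the scalar reference that now lives only in the `_generic` helpers unpacks both 4-bit quants of a byte via a sign-extension trick — shift the byte left by 4 for the low nibble, mask the high nibble, multiply as int8, and apply a single >> 4 afterwards to remove the common x16 factor. A small self-contained illustration of that arithmetic on one packed byte (hypothetical values, not the repack code itself):

#include <stdint.h>
#include <stdio.h>

int main(void) {
    /* One packed byte holds two 4-bit quants; two int8 activations pair with them. */
    const uint8_t qs = 0x93;          /* low nibble = 3, high nibble = 9 (-7 as signed 4-bit) */
    const int8_t  a0 = 17, a1 = -5;   /* example int8 activations */

    /* (int8_t)(qs << 4) sign-extends the low nibble, scaled by 16;
     * (int8_t)(qs & 0xF0) keeps the high nibble, also scaled by 16. */
    const int v0 = (int8_t)(qs << 4);   /*  48 =  3 * 16 */
    const int v1 = (int8_t)(qs & 0xF0); /* -112 = -7 * 16 */

    /* A single >> 4 at the end removes the common x16 factor from both products. */
    const int sumi = ((v0 * a0) + (v1 * a1)) >> 4;

    /* Cross-check against the straightforward unpack-then-multiply form. */
    const int lo = v0 / 16;  /*  3 */
    const int hi = v1 / 16;  /* -7 */
    printf("sumi = %d, check = %d\n", sumi, lo * a0 + hi * a1);  /* both print 86 */
    return 0;
}

In the real kernels this per-byte product is accumulated over the interleaved columns of a block and then scaled by the two fp16-derived block scales before being written to the output row.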
diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c
index a840219a4fc..7e4229d0e46 100644
--- a/ggml/src/ggml-cpu/arch/s390/quants.c
+++ b/ggml/src/ggml-cpu/arch/s390/quants.c
@@ -172,24 +172,15 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     sumf = acc[0] + acc[1] + acc[2] + acc[3];
 
-#endif
-    for (; ib < nb; ++ib) {
-        int sumi0 = 0;
-        int sumi1 = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const int v0 = (x[ib].qs[j] & 0x0F) - 8;
-            const int v1 = (x[ib].qs[j] >>   4) - 8;
-
-            sumi0 += (v0 * y[ib].qs[j]);
-            sumi1 += (v1 * y[ib].qs[j + qk/2]);
-        }
-
-        int sumi = sumi0 + sumi1;
-        sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
-    }
-
     *s = sumf;
+#else
+    UNUSED(nb);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_q4_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -239,24 +230,15 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     sumf = acc[0] + acc[1] + acc[2] + acc[3] + summs;
 
-#endif
-    for (; ib < nb; ++ib) {
-        int sumi0 = 0;
-        int sumi1 = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const int v0 = (x[ib].qs[j] & 0x0F);
-            const int v1 = (x[ib].qs[j] >>   4);
-
-            sumi0 += (v0 * y[ib].qs[j]);
-            sumi1 += (v1 * y[ib].qs[j + qk/2]);
-        }
-
-        int sumi = sumi0 + sumi1;
-        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
-    }
-
     *s = sumf;
+#else
+    UNUSED(nb);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -298,18 +280,15 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     sumf = acc[0] + acc[1] + acc[2] + acc[3];
 
-#endif
-    for (; ib < nb; ++ib) {
-        int sumi = 0;
-
-        for (int j = 0; j < qk; j++) {
-            sumi += x[ib].qs[j]*y[ib].qs[j];
-        }
-
-        sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));
-    }
-
     *s = sumf;
+#else
+    UNUSED(nb);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -442,70 +421,13 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sum;
 
 #else
-    // scalar version
-    // This function is written like this so the compiler can manage to vectorize most of it
-    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
-    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
-    // The ideal situation would be if we could just write the code once, and the compiler would
-    // automatically produce the best possible set of machine instructions, instead of us having to manually
-    // write vectorized versions for AVX, ARM_NEON, etc.
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    uint32_t auxs[4];
-    const int8_t * scales = (const int8_t*)auxs;
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
-        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        uint8_t m = 1;
-        for (int j = 0; j < QK_K; j += 128) {
-            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            q3 += 32;
-        }
-        a = aux8;
-
-        memcpy(auxs, x[i].scales, 12);
-        uint32_t tmp = auxs[2];
-        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
-        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
-        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
-        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
-        for (int j = 0; j < QK_K/16; ++j) {
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
-
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
-
 }
 
 void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -600,61 +522,14 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sumf;
 
 #else
-
-    const uint8_t * scales = (const uint8_t*)&utmp[0];
-    const uint8_t * mins   = (const uint8_t*)&utmp[2];
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        for (int j = 0; j < QK_K/64; ++j) {
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
-            a += 32;
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
-            a += 32; q4 += 32;
-        }
-        memcpy(utmp, x[i].scales, 12);
-        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
-        const uint32_t uaux = utmp[1] & kmask1;
-        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
-        utmp[2] = uaux;
-        utmp[0] &= kmask1;
-
-        int sumi = 0;
-        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            int32_t scale = scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
-        sumf -= dmin * sumi;
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    UNUSED(utmp);
+    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -767,66 +642,14 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sumf;
 
 #else
-
-    const uint8_t * scales = (const uint8_t*)&utmp[0];
-    const uint8_t * mins   = (const uint8_t*)&utmp[2];
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
-        const uint8_t * GGML_RESTRICT hm = x[i].qh;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        uint8_t m = 1;
-        for (int j = 0; j < QK_K/64; ++j) {
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
-            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
-            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
-            a += 32; m <<= 1;
-            q4 += 32;
-        }
-        memcpy(utmp, x[i].scales, 12);
-        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
-        const uint32_t uaux = utmp[1] & kmask1;
-        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
-        utmp[2] = uaux;
-        utmp[0] &= kmask1;
-
-        int sumi = 0;
-        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            int32_t scale = scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
-        sumf -= dmin * sumi;
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    UNUSED(utmp);
+    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -969,47 +792,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sum;
 
 #else
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
-        const uint8_t * GGML_RESTRICT qh = x[i].qh;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        for (int j = 0; j < QK_K; j += 128) {
-            for (int l = 0; l < 32; ++l) {
-                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
-                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
-                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
-                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
-            }
-            a  += 128;
-            q4 += 64;
-            qh += 32;
-        }
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/16; ++j) {
-            int scale = x[i].scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1186,17 +972,15 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
         sumf += GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d) * (v_xy[0] + v_xy[1] + v_xy[2] + v_xy[3]);
     }
 
-#endif
-    for (; ib < nb; ++ib) {
-        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_FP16_TO_FP32(x[ib].d);
-        int sumi1 = 0, sumi2 = 0;
-        for (int j = 0; j < QK4_NL/2; ++j) {
-            sumi1 += y[ib].qs[j+       0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
-            sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >>  4];
-        }
-        sumf += d * (sumi1 + sumi2);
-    }
     *s = sumf;
+#else
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_iq4_nl_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -1264,37 +1048,10 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
     *s = sumf;
 
 #else
-    float sumf = 0;
-    for (int ibl = 0; ibl < nb; ++ibl) {
-        const float d4d8 = GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
-        uint16_t h = x[ibl].scales_h;
-        const uint8_t * qs = x[ibl].qs;
-        const int8_t  * q8 = y[ibl].qs;
-        for (int ib = 0; ib < QK_K/32; ib += 2) {
-            const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30);
-            const uint8_t ls2 = (x[ibl].scales_l[ib/2] >>  4) | ((h << 2) & 0x30);
-            h >>= 4;
-            const float d1 = d4d8*(ls1 - 32);
-            const float d2 = d4d8*(ls2 - 32);
-            int sumi1 = 0, sumi2 = 0;
-            for (int j = 0; j < 16; ++j) {
-                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
-                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];
-            }
-            sumf += d1 * (sumi1 + sumi2);
-            qs += 16;
-            q8 += 32;
-            sumi1 = sumi2 = 0;
-            for (int j = 0; j < 16; ++j) {
-                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
-                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];
-            }
-            sumf += d2 * (sumi1 + sumi2);
-            qs += 16;
-            q8 += 32;
-        }
-    }
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
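Note: the iq4_nl / iq4_xs scalar references that now exist only as `_generic` helpers are a plain codebook lookup — each 4-bit index selects an int8 value from a 16-entry non-linear table (ggml's kvalues_iq4nl), the products with the q8 activations are accumulated per half-block, and the sum is scaled by the block scales. A compact sketch of that lookup with a made-up placeholder table (the real table values are ggml's, not these):

#include <stdint.h>
#include <stdio.h>

#define HALF 16  /* packed bytes per block in this sketch: 32 indices total */

/* Placeholder 16-entry codebook; the real non-linear values live in
 * ggml's kvalues_iq4nl and are not reproduced here. */
static const int8_t codebook[16] = {
    -120, -100, -80, -60, -45, -30, -18, -8, 0, 8, 18, 30, 45, 60, 80, 100
};

/* Dot product of one block: 16 packed bytes (32 codebook indices)
 * against 32 int8 activations, scaled by d = d_x * d_y. */
static float iq4_block_dot(const uint8_t qs[HALF], const int8_t q8[2 * HALF], float d) {
    int sumi1 = 0, sumi2 = 0;
    for (int j = 0; j < HALF; ++j) {
        sumi1 += q8[j]        * codebook[qs[j] & 0x0F];  /* low nibbles  */
        sumi2 += q8[j + HALF] * codebook[qs[j] >>   4];  /* high nibbles */
    }
    return d * (float)(sumi1 + sumi2);
}

int main(void) {
    uint8_t qs[HALF];
    int8_t  q8[2 * HALF];
    for (int j = 0; j < HALF; ++j) {
        qs[j]        = (uint8_t)((j * 37) & 0xFF);  /* arbitrary packed indices */
        q8[j]        = (int8_t)(j - 8);
        q8[j + HALF] = (int8_t)(7 - j);
    }
    printf("block dot = %f\n", iq4_block_dot(qs, q8, 0.01f));
    return 0;
}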
 
diff --git a/ggml/src/ggml-cpu/arch/wasm/quants.c b/ggml/src/ggml-cpu/arch/wasm/quants.c
index b0904d8a3ab..74a359e6d12 100644
--- a/ggml/src/ggml-cpu/arch/wasm/quants.c
+++ b/ggml/src/ggml-cpu/arch/wasm/quants.c
@@ -435,30 +435,15 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
     sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
            wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
 
-#endif
-    for (; ib < nb; ++ib) {
-        uint32_t qh;
-        memcpy(&qh, x[ib].qh, sizeof(qh));
-
-        int sumi0 = 0;
-        int sumi1 = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
-
-            const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
-            const int32_t x1 = (int8_t)(((x[ib].qs[j] >>   4) | xh_1) - 16);
-
-            sumi0 += (x0 * y[ib].qs[j]);
-            sumi1 += (x1 * y[ib].qs[j + qk/2]);
-        }
-
-        int sumi = sumi0 + sumi1;
-        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;
-    }
-
     *s = sumf;
+#else
+    UNUSED(nb);
+    UNUSED(ib);
+    UNUSED(sumf);
+    UNUSED(x);
+    UNUSED(y);
+    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -545,30 +530,15 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
     sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
            wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
 
-#endif
-    for (; ib < nb; ++ib) {
-        uint32_t qh;
-        memcpy(&qh, x[ib].qh, sizeof(qh));
-
-        int sumi0 = 0;
-        int sumi1 = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
-            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
-
-            const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
-            const int32_t x1 = (x[ib].qs[j] >>  4) | xh_1;
-
-            sumi0 += (x0 * y[ib].qs[j]);
-            sumi1 += (x1 * y[ib].qs[j + qk/2]);
-        }
-
-        int sumi = sumi0 + sumi1;
-        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
-    }
-
     *s = sumf;
+#else
+    UNUSED(nb);
+    UNUSED(ib);
+    UNUSED(sumf);
+    UNUSED(x);
+    UNUSED(y);
+    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -628,18 +598,15 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
     sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
            wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
 
-#endif
-    for (; ib < nb; ++ib) {
-        int sumi = 0;
-
-        for (int j = 0; j < qk; j++) {
-            sumi += x[ib].qs[j]*y[ib].qs[j];
-        }
-
-        sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));
-    }
-
     *s = sumf;
+#else
+    UNUSED(nb);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
 }
 
 void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -755,45 +722,10 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sumf;
 
 #else
-
-    float sumf = 0;
-
-    for (int i = 0; i < nb; ++i) {
-
-        const uint8_t * q2 = x[i].qs;
-        const  int8_t * q8 = y[i].qs;
-        const uint8_t * sc = x[i].scales;
-
-        int summs = 0;
-        for (int j = 0; j < 16; ++j) {
-            summs += y[i].bsums[j] * (sc[j] >> 4);
-        }
-
-        const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
-        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
-
-        int isum = 0;
-        int is = 0;
-        int d;
-        for (int k = 0; k < QK_K/128; ++k) {
-            int shift = 0;
-            for (int j = 0; j < 4; ++j) {
-                d = sc[is++] & 0xF;
-                int isuml = 0;
-                for (int l =  0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
-                isum += d * isuml;
-                d = sc[is++] & 0xF;
-                isuml = 0;
-                for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
-                isum += d * isuml;
-                shift += 2;
-                q8 += 32;
-            }
-            q2 += 32;
-        }
-        sumf += dall * isum - dmin * summs;
-    }
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -902,68 +834,12 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sumf;
 
 #else
-    // scalar version
-    // This function is written like this so the compiler can manage to vectorize most of it
-    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
-    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
-    // The ideal situation would be if we could just write the code once, and the compiler would
-    // automatically produce the best possible set of machine instructions, instead of us having to manually
-    // write vectorized versions for AVX, ARM_NEON, etc.
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    uint32_t auxs[4];
-    const int8_t * scales = (const int8_t*)auxs;
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
-        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        uint8_t m = 1;
-        for (int j = 0; j < QK_K; j += 128) {
-            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            q3 += 32;
-        }
-        a = aux8;
-
-        memcpy(auxs, x[i].scales, 12);
-        uint32_t tmp = auxs[2];
-        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
-        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
-        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
-        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
-        for (int j = 0; j < QK_K/16; ++j) {
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
-
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 
 }
@@ -1089,61 +965,14 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sumf;
 
 #else
-
-    const uint8_t * scales = (const uint8_t*)&utmp[0];
-    const uint8_t * mins   = (const uint8_t*)&utmp[2];
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        for (int j = 0; j < QK_K/64; ++j) {
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
-            a += 32;
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
-            a += 32; q4 += 32;
-        }
-        memcpy(utmp, x[i].scales, 12);
-        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
-        const uint32_t uaux = utmp[1] & kmask1;
-        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
-        utmp[2] = uaux;
-        utmp[0] &= kmask1;
-
-        int sumi = 0;
-        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            int32_t scale = scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
-        sumf -= dmin * sumi;
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    UNUSED(utmp);
+    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1279,66 +1108,14 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sumf;
 
 #else
-
-    const uint8_t * scales = (const uint8_t*)&utmp[0];
-    const uint8_t * mins   = (const uint8_t*)&utmp[2];
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
-        const uint8_t * GGML_RESTRICT hm = x[i].qh;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        uint8_t m = 1;
-        for (int j = 0; j < QK_K/64; ++j) {
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
-            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
-            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
-            a += 32; m <<= 1;
-            q4 += 32;
-        }
-        memcpy(utmp, x[i].scales, 12);
-        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
-        const uint32_t uaux = utmp[1] & kmask1;
-        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
-        utmp[2] = uaux;
-        utmp[0] &= kmask1;
-
-        int sumi = 0;
-        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            int32_t scale = scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
-        sumf -= dmin * sumi;
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    UNUSED(utmp);
+    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1435,47 +1212,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sumf;
 
 #else
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
-        const uint8_t * GGML_RESTRICT qh = x[i].qh;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        for (int j = 0; j < QK_K; j += 128) {
-            for (int l = 0; l < 32; ++l) {
-                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
-                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
-                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
-                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
-            }
-            a  += 128;
-            q4 += 64;
-            qh += 32;
-        }
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/16; ++j) {
-            int scale = x[i].scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
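
Every #else fallback touched in this patch follows the same refactor: the per-architecture hand-written scalar loop is deleted and the branch now forwards to the shared ggml_vec_dot_*_generic (or ggml_quantize_mat_*_generic) implementation, with UNUSED() casts keeping -Wunused-variable quiet for locals that only the deleted scalar code consumed. Below is a minimal, self-contained sketch of the idiom; everything here other than the shape of the pattern (and the usual (void) cast behind UNUSED) is a placeholder, not ggml's real code.

    // Illustrative sketch of the "forward the #else branch to a shared generic
    // kernel" pattern applied throughout this patch. The function names and the
    // fake __SOME_SIMD_EXTENSION__ macro are placeholders.
    #include <cstdio>

    #define UNUSED(x) (void)(x)   // assumption: ggml's UNUSED macro boils down to this

    // single scalar reference implementation, shared by every architecture
    static void vec_dot_example_generic(int n, float * s, const float * x, const float * y) {
        float sumf = 0.0f;
        for (int i = 0; i < n; ++i) sumf += x[i] * y[i];
        *s = sumf;
    }

    static void vec_dot_example(int n, float * s, const float * x, const float * y) {
        const int nb = n / 32;            // block count: only the SIMD path would need it
    #if defined(__SOME_SIMD_EXTENSION__)
        // ... vectorized kernel would write *s here ...
    #else
        UNUSED(nb);                       // silence -Wunused-variable on scalar-only builds
        vec_dot_example_generic(n, s, x, y);
    #endif
    }

    int main() {
        float x[4] = {1, 2, 3, 4}, y[4] = {4, 3, 2, 1}, s = 0;
        vec_dot_example(4, &s, x, y);
        printf("%g\n", s);                // 20
        return 0;
    }

The payoff is a single scalar reference per quant type instead of one diverging copy per architecture file.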
 
diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c
index e7527c00a8f..cb49320a67f 100644
--- a/ggml/src/ggml-cpu/arch/x86/quants.c
+++ b/ggml/src/ggml-cpu/arch/x86/quants.c
@@ -66,6 +66,12 @@ static inline int hsum_i32_4(const __m128i a) {
 }
 
 #if defined(__AVX2__) || defined(__AVX512F__)
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return _mm256_maddubs_epi16(ax, sy);
+}
+
 // spread 32 bits to 32 bytes { 0x00, 0xFF }
 static inline __m256i bytes_from_bits_32(const uint8_t * x) {
     uint32_t x32;
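
A note on the mul_add_epi8() helper added above (it was previously defined further down in this file under an __AVX2__-only guard and is hoisted here so the new MXFP4 kernel can reuse it): _mm256_maddubs_epi16 multiplies an unsigned byte by a signed byte, so two signed int8 vectors cannot be fed to it directly. The helper uses the standard sign trick instead: pass |x| together with a copy of y whose sign has been adjusted by x via _mm256_sign_epi8, which leaves every per-lane product equal to x*y. Here is a scalar model of that identity, checked over the ranges these kernels actually see (the 4-bit side comes out of a 16-entry LUT, and q8_0 quantization never emits -128, the one value where negating y would wrap).

    // Scalar model of the identity behind mul_add_epi8():
    //     |x| * sign_adjust(y, x) == x * y
    // sign_adjust() mirrors one lane of _mm256_sign_epi8(y, x): negate y where
    // x < 0, zero it where x == 0, pass it through where x > 0.
    #include <cstdint>
    #include <cstdio>
    #include <cstdlib>

    static int8_t sign_adjust(int8_t y, int8_t x) {
        if (x < 0)  return (int8_t)(-y);
        if (x == 0) return 0;
        return y;
    }

    int main() {
        // x models the LUT-decoded 4-bit operand (small magnitude);
        // y models a q8_0 quant, which stays in [-127, 127].
        for (int x = -16; x <= 15; ++x) {
            for (int y = -127; y <= 127; ++y) {
                const int via_maddubs_operands = abs(x) * sign_adjust((int8_t)y, (int8_t)x);
                if (via_maddubs_operands != x * y) {
                    printf("mismatch at x=%d y=%d\n", x, y);
                    return 1;
                }
            }
        }
        printf("x*y == |x| * sign_adjust(y, x) for all tested pairs\n");
        return 0;
    }

maddubs additionally saturates each adjacent pair-sum to 16 bits, but with the small 4-bit-side magnitudes used here those sums stay far below the saturation point.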
@@ -261,6 +267,11 @@ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const
     return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)),
                            _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
 }
+
+static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) {
+    return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
+                           _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
+}
 #endif
 #elif defined(__SSSE3__)
 // horizontally add 4x4 floats
@@ -702,7 +713,6 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
     const block_q8_1 * GGML_RESTRICT y = vy;
 
     int ib = 0;
-    float sumf = 0;
 
 #if defined(__AVX2__) || defined(__AVX__)
     // Initialize accumulator with zeros
@@ -737,25 +747,98 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
 #endif
     }
 
-    sumf = hsum_float_8(acc) + summs;
-
+    *s = hsum_float_8(acc) + summs;
+#else
+    UNUSED(nb);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
-    for (; ib < nb; ++ib) {
-        int sumi0 = 0;
-        int sumi1 = 0;
+}
 
-        for (int j = 0; j < qk/2; ++j) {
-            const int v0 = (x[ib].qs[j] & 0x0F);
-            const int v1 = (x[ib].qs[j] >>   4);
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
 
-            sumi0 += (v0 * y[ib].qs[j]);
-            sumi1 += (v1 * y[ib].qs[j + qk/2]);
-        }
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
 
-        int sumi = sumi0 + sumi1;
-        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined __AVX2__
+
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+    const __m256i mone = _mm256_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
+        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
+        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
+        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
+        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
+        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
+                _mm256_cvtepi32_ps(p_1), accum1);
+        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
+                _mm256_cvtepi32_ps(p_2), accum2);
     }
 
+    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
+
+#elif defined __AVX__
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+
+    __m256 accum = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
+
+        const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
+        const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
+        const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
+        const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
+
+        const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
+        const __m256 deltas = quad_mx_delta_float(x[ib].e, y[ib].d, x[ib + 1].e, y[ib + 1].d);
+        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
+    }
+
+    sumf = hsum_float_8(accum);
+
+#endif
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j +          0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
     *s = sumf;
 }
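
The AVX2/AVX paths of the new ggml_vec_dot_mxfp4_q8_0 decode nibbles with _mm_shuffle_epi8(values128, ...) used as a 16-entry table lookup into kvalues_mxfp4, the vector form of what the scalar tail spells out as kvalues_mxfp4[qs[j] & 0xf] and kvalues_mxfp4[qs[j] >> 4]; the resulting integer dot product is then scaled by GGML_CPU_FP16_TO_FP32(y[ib].d) * GGML_E8M0_TO_FP32_HALF(x[ib].e). A scalar sketch of the pshufb-as-LUT decode follows; the table contents are a stand-in chosen for the demo, not necessarily the real kvalues_mxfp4.

    // Scalar model of the _mm_shuffle_epi8(values128, nibbles) decode: pshufb with
    // a constant first operand is a 16-entry byte LUT indexed by the low 4 bits of
    // each selector byte. The table below is a stand-in, not the real kvalues_mxfp4.
    #include <cstdint>
    #include <cstdio>

    static const int8_t kvalues_demo[16] = {
        0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12 // demo values only
    };

    int main() {
        const uint8_t packed[4] = { 0x1F, 0x52, 0x80, 0x3C }; // two 4-bit codes per byte
        // Low nibbles pair with the first half of the q8 block, high nibbles with the
        // second, mirroring the tail loop's qs[j] & 0xf / qs[j] >> 4 split.
        for (int j = 0; j < 4; ++j) {
            const int lo = kvalues_demo[packed[j] & 0x0F];
            const int hi = kvalues_demo[packed[j] >> 4];
            printf("byte %d: lo nibble -> %d, hi nibble -> %d\n", j, lo, hi);
        }
        return 0;
    }

pshufb zeroes an output byte whenever bit 7 of its selector is set, which is why the kernel masks the selectors with 0x0F (m4b) before every shuffle.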
 
@@ -764,7 +847,6 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
     const int nb = n / qk;
 
     int ib = 0;
-    float sumf = 0;
 
     assert(n % qk == 0);
     assert(qk == QK5_0);
@@ -799,7 +881,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
         acc = _mm256_fmadd_ps(d, q, acc);
     }
 
-    sumf = hsum_float_8(acc);
+    *s = hsum_float_8(acc);
 #elif defined(__AVX__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -830,32 +912,14 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
         acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
     }
 
-    sumf = hsum_float_8(acc);
-
+    *s = hsum_float_8(acc);
+#else
+    UNUSED(nb);
+    UNUSED(ib);
+    UNUSED(x);
+    UNUSED(y);
+    ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
-    for (; ib < nb; ++ib) {
-        uint32_t qh;
-        memcpy(&qh, x[ib].qh, sizeof(qh));
-
-        int sumi0 = 0;
-        int sumi1 = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
-
-            const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
-            const int32_t x1 = (int8_t)(((x[ib].qs[j] >>   4) | xh_1) - 16);
-
-            sumi0 += (x0 * y[ib].qs[j]);
-            sumi1 += (x1 * y[ib].qs[j + qk/2]);
-        }
-
-        int sumi = sumi0 + sumi1;
-        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;
-    }
-
-    *s = sumf;
 }
 
 void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -863,7 +927,6 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
     const int nb = n / qk;
 
     int ib = 0;
-    float sumf = 0;
 
     assert(n % qk == 0);
     assert(qk == QK5_1);
@@ -901,7 +964,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
         acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
     }
 
-    sumf = hsum_float_8(acc) + summs;
+    *s = hsum_float_8(acc) + summs;
 #elif defined(__AVX__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -935,32 +998,14 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
         acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
     }
 
-    sumf = hsum_float_8(acc) + summs;
-
+    *s = hsum_float_8(acc) + summs;
+#else
+    UNUSED(nb);
+    UNUSED(ib);
+    UNUSED(x);
+    UNUSED(y);
+    ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
-    for (; ib < nb; ++ib) {
-        uint32_t qh;
-        memcpy(&qh, x[ib].qh, sizeof(qh));
-
-        int sumi0 = 0;
-        int sumi1 = 0;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
-            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
-
-            const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
-            const int32_t x1 = (x[ib].qs[j] >>  4) | xh_1;
-
-            sumi0 += (x0 * y[ib].qs[j]);
-            sumi1 += (x1 * y[ib].qs[j + qk/2]);
-        }
-
-        int sumi = sumi0 + sumi1;
-        sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
-    }
-
-    *s = sumf;
 }
 
 void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -1017,7 +1062,6 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
     }
 
     sumf = hsum_float_8(accum);
-
 #endif
     for (; ib < nb; ++ib) {
         int sumi = 0;
@@ -1157,44 +1201,10 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = hsum_float_8(sumf);
 
 #else
-    const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
-
-    float sumf = 0.0f;
-
-    for (int i = 0; i < nb; ++i) {
-        int sum = 0;
-
-        for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
-            for (size_t l = 0; l < 5; ++l) {
-                for (size_t m = 0; m < 32; ++m) {
-                    uint8_t q = x[i].qs[j + m] * pow3[l];
-                    uint16_t xi = ((uint16_t) q * 3) >> 8;
-                    sum += (xi - 1) * y[i].qs[j*5 + l*32 + m];
-                }
-            }
-        }
-        for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
-            for (size_t l = 0; l < 5; ++l) {
-                for (size_t m = 0; m < 16; ++m) {
-                    uint8_t q = x[i].qs[j + m] * pow3[l];
-                    uint16_t xi = ((uint16_t) q * 3) >> 8;
-                    sum += (xi - 1) * y[i].qs[j*5 + l*16 + m];
-                }
-            }
-        }
-
-        for (size_t l = 0; l < 4; ++l) {
-            for (size_t j = 0; j < sizeof(x->qh); ++j) {
-                uint8_t q = x[i].qh[j] * pow3[l];
-                uint16_t xi = ((uint16_t) q * 3) >> 8;
-                sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j];
-            }
-        }
-
-        sumf += (float) sum * (GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d);
-    }
-
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
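
The TQ1_0 fallback deleted above (it now lives only in the generic kernel) packs five ternary digits ("trits") per byte and unpacks them with a fixed-point trick: multiplying the byte by pow3[l] rotates the l-th trit to the top of the base-3 fraction, and ((uint16_t)q * 3) >> 8 then reads it off; the "- 1" in the dot product maps the stored digits {0,1,2} back to weights {-1,0,+1}. The check below exercises that decode against a packing of the form ceil(v*256/243); the ceiling-style packing is an assumption made only so the round trip can be demonstrated self-contained.

    // Round-trip check of the base-3 digit extraction used by the (removed) TQ1_0
    // scalar path: trit_l(q) = ((uint8_t)(q * pow3[l]) * 3) >> 8.
    // Packing as ceil(v*256/243) is an assumption made to keep the demo standalone.
    #include <cstdint>
    #include <cstdio>

    static const uint8_t pow3[5] = {1, 3, 9, 27, 81};

    int main() {
        for (int v = 0; v < 243; ++v) {                           // v = t0*81 + t1*27 + t2*9 + t3*3 + t4
            const uint8_t q = (uint8_t)((v * 256 + 242) / 243);   // ceil(v*256/243)
            int rest = v;
            for (int l = 0; l < 5; ++l) {
                const int expected = rest / pow3[4 - l];          // t_l, most significant trit first
                rest %= pow3[4 - l];
                const uint8_t rotated = (uint8_t)(q * pow3[l]);   // bring trit l to the top
                const int got = ((uint16_t)rotated * 3) >> 8;     // read it off
                if (got != expected) { printf("mismatch v=%d l=%d\n", v, l); return 1; }
            }
        }
        printf("all 243 five-trit bytes decode correctly\n");
        return 0;
    }
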
 
@@ -1257,25 +1267,10 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = hsum_float_8(sumf);
 
 #else
-    float sumf = 0.0f;
-
-    for (int i = 0; i < nb; ++i) {
-        int32_t sumi = 0;
-
-        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
-            for (size_t l = 0; l < 4; ++l) {
-                for (size_t k = 0; k < 32; ++k) {
-                    sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1);
-                }
-            }
-        }
-
-        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
-
-        sumf += (float) sumi * d;
-    }
-
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1464,45 +1459,10 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = hsum_float_8(acc);
 
 #else
-
-    float sumf = 0;
-
-    for (int i = 0; i < nb; ++i) {
-
-        const uint8_t * q2 = x[i].qs;
-        const  int8_t * q8 = y[i].qs;
-        const uint8_t * sc = x[i].scales;
-
-        int summs = 0;
-        for (int j = 0; j < 16; ++j) {
-            summs += y[i].bsums[j] * (sc[j] >> 4);
-        }
-
-        const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
-        const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
-
-        int isum = 0;
-        int is = 0;
-        int d;
-        for (int k = 0; k < QK_K/128; ++k) {
-            int shift = 0;
-            for (int j = 0; j < 4; ++j) {
-                d = sc[is++] & 0xF;
-                int isuml = 0;
-                for (int l =  0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
-                isum += d * isuml;
-                d = sc[is++] & 0xF;
-                isuml = 0;
-                for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
-                isum += d * isuml;
-                shift += 2;
-                q8 += 32;
-            }
-            q2 += 32;
-        }
-        sumf += dall * isum - dmin * summs;
-    }
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -1769,70 +1729,13 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = hsum_float_8(acc);
 
 #else
-    // scalar version
-    // This function is written like this so the compiler can manage to vectorize most of it
-    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
-    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
-    // The ideal situation would be if we could just write the code once, and the compiler would
-    // automatically produce the best possible set of machine instructions, instead of us having to manually
-    // write vectorized versions for AVX, ARM_NEON, etc.
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    uint32_t auxs[4];
-    const int8_t * scales = (const int8_t*)auxs;
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
-        const uint8_t * GGML_RESTRICT hm = x[i].hmask;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        uint8_t m = 1;
-        for (int j = 0; j < QK_K; j += 128) {
-            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
-            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
-            a += 32; m <<= 1;
-            q3 += 32;
-        }
-        a = aux8;
-
-        memcpy(auxs, x[i].scales, 12);
-        uint32_t tmp = auxs[2];
-        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
-        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
-        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
-        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
-        for (int j = 0; j < QK_K/16; ++j) {
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
-
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
-
 }
 
 void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -2002,61 +1905,14 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
 
 #else
-
-    const uint8_t * scales = (const uint8_t*)&utmp[0];
-    const uint8_t * mins   = (const uint8_t*)&utmp[2];
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        for (int j = 0; j < QK_K/64; ++j) {
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
-            a += 32;
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
-            a += 32; q4 += 32;
-        }
-        memcpy(utmp, x[i].scales, 12);
-        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
-        const uint32_t uaux = utmp[1] & kmask1;
-        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
-        utmp[2] = uaux;
-        utmp[0] &= kmask1;
-
-        int sumi = 0;
-        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            int32_t scale = scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
-        sumf -= dmin * sumi;
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    UNUSED(utmp);
+    ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -2259,66 +2115,14 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = hsum_float_8(acc) + summs;
 
 #else
-
-    const uint8_t * scales = (const uint8_t*)&utmp[0];
-    const uint8_t * mins   = (const uint8_t*)&utmp[2];
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].qs;
-        const uint8_t * GGML_RESTRICT hm = x[i].qh;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        uint8_t m = 1;
-        for (int j = 0; j < QK_K/64; ++j) {
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
-            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
-            a += 32; m <<= 1;
-            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
-            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
-            a += 32; m <<= 1;
-            q4 += 32;
-        }
-        memcpy(utmp, x[i].scales, 12);
-        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
-        const uint32_t uaux = utmp[1] & kmask1;
-        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
-        utmp[2] = uaux;
-        utmp[0] &= kmask1;
-
-        int sumi = 0;
-        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/32; ++j) {
-            int32_t scale = scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-        const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
-        sumf -= dmin * sumi;
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    UNUSED(utmp);
+    ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -2520,47 +2324,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = hsum_float_8(acc);
 
 #else
-
-    int8_t  aux8[QK_K];
-    int16_t aux16[8];
-    float   sums [8];
-    int32_t aux32[8];
-    memset(sums, 0, 8*sizeof(float));
-
-    float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
-        const uint8_t * GGML_RESTRICT q4 = x[i].ql;
-        const uint8_t * GGML_RESTRICT qh = x[i].qh;
-        const  int8_t * GGML_RESTRICT q8 = y[i].qs;
-        memset(aux32, 0, 8*sizeof(int32_t));
-        int8_t * GGML_RESTRICT a = aux8;
-        for (int j = 0; j < QK_K; j += 128) {
-            for (int l = 0; l < 32; ++l) {
-                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
-                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
-                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
-                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
-            }
-            a  += 128;
-            q4 += 64;
-            qh += 32;
-        }
-        a = aux8;
-        int is = 0;
-        for (int j = 0; j < QK_K/16; ++j) {
-            int scale = x[i].scales[is++];
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
-            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
-            q8 += 8; a += 8;
-        }
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
-    }
-    for (int l = 0; l < 8; ++l) sumf += sums[l];
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -2712,34 +2479,10 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
     *s = 0.125f * hsum_float_8(accumf);
 
 #else
-
-    uint32_t aux32[2];
-    const uint8_t * aux8 = (const uint8_t *)aux32;
-
-    float sumf = 0.f;
-    for (int i = 0; i < nb; ++i) {
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
-        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
-        int32_t bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
-            memcpy(aux32, q2, 2*sizeof(uint32_t));
-            q2 += 4;
-            const uint32_t ls = 2*(aux32[1] >> 28) + 1;
-            int32_t sumi = 0;
-            for (int l = 0; l < 4; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
-                const uint8_t  signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
-                for (int j = 0; j < 8; ++j) {
-                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            bsum += sumi * ls;
-        }
-        sumf += d * bsum;
-    }
-    *s = 0.125f * sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -3033,42 +2776,10 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
     *s = 0.125f * hsum_float_8(accumf);
 
 #else
-
-    float sumf = 0.f;
-    for (int i = 0; i < nb; ++i) {
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint16_t * GGML_RESTRICT q2 = x[i].qs;
-        const uint8_t  * GGML_RESTRICT sc = x[i].scales;
-        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
-        int32_t bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
-            const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;
-            const uint16_t ls2 = 2*(sc[ib32] >>  4) + 1;
-            int32_t sumi = 0;
-            for (int l = 0; l < 2; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
-                const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];
-                for (int j = 0; j < 8; ++j) {
-                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            bsum += sumi * ls1;
-            sumi = 0;
-            for (int l = 2; l < 4; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
-                const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];
-                for (int j = 0; j < 8; ++j) {
-                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            bsum += sumi * ls2;
-            q2 += 4;
-        }
-        sumf += d * bsum;
-    }
-    *s = 0.125f * sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -3250,47 +2961,11 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = 0.125f * hsum_float_8(accumf);
 
 #else
-
-    float sumf = 0;
-    for (int i = 0; i < nb; i++) {
-
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const int8_t  * q8 = y[i].qs;
-        const uint8_t * qs = x[i].qs;
-        const uint8_t * qh = x[i].qh;
-        const uint8_t * signs = qs + QK_K/8;
-
-        int bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
-            int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf);
-            int ls2 = 1 + 2*(x[i].scales[ib32] >>  4);
-            int sumi1 = 0, sumi2 = 0;
-            for (int l = 0; l < 2; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
-                for (int j = 0; j < 8; ++j) {
-                    sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            for (int l = 2; l < 4; ++l) {
-                const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
-                for (int j = 0; j < 8; ++j) {
-                    sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            bsum += ls1 * sumi1 + ls2 * sumi2;
-            qs += 4;
-            signs += 4;
-        }
-
-        sumf += d * bsum;
-    }
-
-    *s = 0.125f * sumf;
-
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
-
 }
 
 void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -3410,36 +3085,10 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
     *s = 0.25f * hsum_float_8(accumf);
 
 #else
-
-    uint32_t aux32;
-
-    float sumf = 0.f;
-    for (int i = 0; i < nb; ++i) {
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * GGML_RESTRICT q3 = x[i].qs;
-        const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
-        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
-        int32_t bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
-            memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
-            const uint32_t ls = 2*(aux32 >> 28) + 1;
-            int32_t sumi = 0;
-            for (int l = 0; l < 4; ++l) {
-                const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
-                const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
-                const uint8_t  signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
-                for (int j = 0; j < 4; ++j) {
-                    sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
-                    sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            q3 += 8;
-            bsum += sumi * ls;
-        }
-        sumf += d * bsum;
-    }
-    *s = 0.25f * sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -3646,59 +3295,13 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = hsum_float_8(accumf);
 
 #else
-
-    float sumf = 0.f;
-    for (int i = 0; i < nb; ++i) {
-        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
-        const uint8_t * GGML_RESTRICT qs = x[i].qs;
-        const uint8_t * GGML_RESTRICT qh = x[i].qh;
-        const uint8_t * GGML_RESTRICT signs = x[i].signs;
-        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
-        int32_t bsum = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
-            const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;
-            const uint32_t ls2 = 2*(x[i].scales[ib32/2] >>  4) + 1;
-            int32_t sumi = 0;
-            for (int l = 0; l < 4; ++l) {
-                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
-                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
-                for (int j = 0; j < 4; ++j) {
-                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
-                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            qs += 8;
-            signs += 4;
-            bsum += sumi * ls1;
-            sumi = 0;
-            for (int l = 0; l < 4; ++l) {
-                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
-                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
-                for (int j = 0; j < 4; ++j) {
-                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
-                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
-                }
-                q8 += 8;
-            }
-            qs += 8;
-            signs += 4;
-            bsum += sumi * ls2;
-        }
-        sumf += d * bsum;
-    }
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
-#if defined(__AVX2__)
-static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
-    const __m256i ax = _mm256_sign_epi8(x, x);
-    const __m256i sy = _mm256_sign_epi8(y, x);
-    return _mm256_maddubs_epi16(ax, sy);
-}
-#endif
-
 void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
@@ -3811,36 +3414,10 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
 
 #else
-
-    float sumf = 0;
-    for (int i = 0; i < nb; i++) {
-
-        const int8_t   * q8 = y[i].qs;
-        const uint8_t  * qs = x[i].qs;
-        const uint16_t * qh = x[i].qh;
-
-        int sumi = 0, sumi1 = 0;
-        for (int ib = 0; ib < QK_K/32; ++ib) {
-            const int ls = 2*((qh[ib] >> 12) & 7) + 1;
-            const int delta = qh[ib] & 0x8000 ? -1 : 1;
-            int lsum = 0;
-            for (int l = 0; l < 4; ++l) {
-                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
-                for (int j = 0; j < 8; ++j) {
-                    lsum += q8[j] * grid[j];
-                }
-                q8 += 8;
-            }
-            sumi  += ls * lsum;
-            sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]);
-            qs += 4;
-        }
-
-        sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);
-    }
-
-    *s = sumf;
-
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -4043,52 +3620,11 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
 
 #else
-
-    int sum1[2], sum2[2], delta[4];
-
-    float sumf = 0;
-    for (int i = 0; i < nb; i++) {
-
-        const int8_t   * q8 = y[i].qs;
-        const uint8_t  * qs = x[i].qs;
-        const uint8_t  * qh = x[i].qh;
-        const uint16_t * sc = (const uint16_t *)x[i].scales;
-
-        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
-
-        int sumi1 = 0, sumi2 = 0;
-        for (int ib = 0; ib < QK_K/32; ++ib) {
-            delta[0] = qh[0] & 0x08 ? -1 : 1;
-            delta[1] = qh[0] & 0x80 ? -1 : 1;
-            delta[2] = qh[1] & 0x08 ? -1 : 1;
-            delta[3] = qh[1] & 0x80 ? -1 : 1;
-            sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0;
-            for (int l = 0; l < 4; ++l) {
-                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700)));
-                int lsum1 = 0, lsum2 = 0;
-                for (int j = 0; j < 8; ++j) {
-                    lsum1 += q8[j] * grid[j];
-                    lsum2 += q8[j];
-                }
-                q8 += 8;
-                sum1[l/2] += lsum1;
-                sum2[l/2] += lsum2*delta[l];
-            }
-
-            const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1;
-            const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1;
-
-            sumi1 += sum1[0] * ls1 + sum1[1] * ls2;
-            sumi2 += sum2[0] * ls1 + sum2[1] * ls2;
-            qs += 4;
-            qh += 2;
-        }
-
-        sumf += GGML_CPU_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
-    }
-
-    *s = sumf;
-
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    UNUSED(scale);
+    ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
@@ -4275,37 +3811,10 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
     *s = hsum_float_8(accum);
 
 #else
-    float sumf = 0;
-    for (int ibl = 0; ibl < nb; ++ibl) {
-        const float d4d8 = GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
-        uint16_t h = x[ibl].scales_h;
-        const uint8_t * qs = x[ibl].qs;
-        const int8_t  * q8 = y[ibl].qs;
-        for (int ib = 0; ib < QK_K/32; ib += 2) {
-            const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30);
-            const uint8_t ls2 = (x[ibl].scales_l[ib/2] >>  4) | ((h << 2) & 0x30);
-            h >>= 4;
-            const float d1 = d4d8*(ls1 - 32);
-            const float d2 = d4d8*(ls2 - 32);
-            int sumi1 = 0, sumi2 = 0;
-            for (int j = 0; j < 16; ++j) {
-                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
-                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];
-            }
-            sumf += d1 * (sumi1 + sumi2);
-            qs += 16;
-            q8 += 32;
-            sumi1 = sumi2 = 0;
-            for (int j = 0; j < 16; ++j) {
-                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
-                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];
-            }
-            sumf += d2 * (sumi1 + sumi2);
-            qs += 16;
-            q8 += 32;
-        }
-    }
-    *s = sumf;
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
 
diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp
index c00c1e541cb..d95bb6d8aaf 100644
--- a/ggml/src/ggml-cpu/arch/x86/repack.cpp
+++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp
@@ -281,35 +281,9 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
     }
 
 #else
-    // scalar
-    const int blck_size_interleave = 8;
-    float srcv[4][QK8_0];
-    float id[4];
-
-    for (int i = 0; i < nb; i++) {
-        for (int row_iter = 0; row_iter < 4; row_iter++) {
-            float amax = 0.0f; // absolute max
-
-            for (int j = 0; j < QK8_0; j++) {
-                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
-                amax = MAX(amax, fabsf(srcv[row_iter][j]));
-            }
-
-            const float d = amax / ((1 << 7) - 1);
-            id[row_iter] = d ? 1.0f / d : 0.0f;
-
-            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
-        }
-
-        for (int j = 0; j < QK8_0 * 4; j++) {
-            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
-            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
-            src_offset += (j % blck_size_interleave);
-
-            float x0 = srcv[src_id][src_offset] * id[src_id];
-            y[i].qs[j] = roundf(x0);
-        }
-    }
+    UNUSED(nb);
+    UNUSED(y);
+    ggml_quantize_mat_q8_0_4x8_generic(x, vy, k);
 #endif
 }
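
For reference, the scalar body removed above (now only in ggml_quantize_mat_q8_0_4x8_generic) interleaves four quantized rows in 8-byte chunks: output byte j comes from source row (j % 32) / 8 at offset (j / 32) * 8 + (j % 8). A tiny standalone sketch of just that index mapping, with the quantization itself left out:

    // Index mapping used by the (removed) scalar ggml_quantize_mat_q8_0_4x8 body:
    // four rows of QK8_0 = 32 quants are interleaved in 8-byte chunks, i.e. the
    // output order is row0[0..7], row1[0..7], row2[0..7], row3[0..7], row0[8..15], ...
    #include <cstdio>

    int main() {
        const int blck = 8; // blck_size_interleave
        // first 24 destination bytes shown; the pattern repeats every 32 bytes
        for (int j = 0; j < 24; ++j) {
            const int src_offset = (j / (4 * blck)) * blck + (j % blck);
            const int src_id     = (j % (4 * blck)) / blck;
            printf("dst[%2d] = row%d[%d]\n", j, src_id, src_offset);
        }
        return 0;
    }
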
 
@@ -531,84 +505,40 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
     }
 
 #else
+    UNUSED(nb);
+    UNUSED(y);
+    ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
+#endif
+}
 
-    // scalar
-    const int blck_size_interleave = 8;
-    float srcv[4][QK_K];
-    float iscale[4];
-
-    for (int i = 0; i < nb; i++) {
-        for (int row_iter = 0; row_iter < 4; row_iter++) {
-            float amax = 0.0f; // absolute max
-            float max = 0;
-
-            for (int j = 0; j < QK_K; j++) {
-                srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
-                // Update the maximum value of the corresponding super block
-                if(amax < fabsf(srcv[row_iter][j])) {
-                    amax = fabsf(srcv[row_iter][j]);
-                    max = srcv[row_iter][j];
-                }
-            }
-
-            iscale[row_iter] = amax ? -127.f/max : 0;
-
-            y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
-        }
+//
+// GEMV/GEMM templates
+//
 
-        for (int j = 0; j < QK_K / 4; j++) {
-            y[i].bsums[j] = 0;
-        }
+#if defined(__AVX2__) || defined(__AVX512F__)
 
-        // Quants values are interleaved in sequence of eight bytes from corresponding super blocks
-        // Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving
-        // i.e first four bsums from the first super block, followed by first four bsums from second super block and so on
-        for (int j = 0; j < QK_K * 4; j++) {
-            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
-            int src_id     = (j % (4 * blck_size_interleave)) / blck_size_interleave;
-            src_offset += (j % blck_size_interleave);
-            int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
-
-            float x0 = srcv[src_id][src_offset] * iscale[src_id];
-            y[i].qs[j] = nearest_int(x0);
-            y[i].bsums[index] += y[i].qs[j];
-        }
-    }
-#endif
-}
+// GEMV for 8x blocks of 32 4-bit quants with a single scale factor per block
+template <typename block_tx8>
+static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) {
+    static_assert(
+            std::is_same_v ||
+            std::is_same_v,
+            "Unsupported block type");
 
-void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
-    const int ncols_interleaved = 8;
-    const int blocklen = 8;
 
-    assert (n % qk == 0);
-    assert (nc % ncols_interleaved == 0);
-
-    UNUSED(s);
     UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
-    UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
 
-#if defined(__AVX2__)
-    // Lookup table to convert signed nibbles to signed bytes
-    __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
-    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
     __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
     __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
 
     // Permute mask used for easier vector processing at later stages
     const __m256i m4b = _mm256_set1_epi8(0x0F);
 
-    int64_t b_nb = n / QK4_0;
+    int64_t b_nb = n / 32;
 
-    const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx;
+    const block_tx8  * b_ptr_start = (const block_tx8  *)vx;
     const block_q8_0 * a_ptr_start = (const block_q8_0 *)vy;
 
     // Process Q8_0 blocks one by one
@@ -617,17 +547,17 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
         // Pointers to LHS blocks of block_q8_0 format
         const block_q8_0 * a_ptr = a_ptr_start + (y * nb);
 
-        // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation
+        // Take group of eight blocks at each pass of the loop and perform dot product operation
         for (int64_t x = 0; x < nc / 8; x++) {
 
             // Pointers to RHS blocks
-            const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
+            const block_tx8 * b_ptr = b_ptr_start + (x * b_nb);
 
             // Master FP accumulator
             __m256 acc_row = _mm256_setzero_ps();
 
             for (int64_t b = 0; b < nb; b++) {
-                // Load 8 blocks of Q4_0 interleaved as 8 bytes (B0 - B7)
+                // Load 8 blocks of 32 interleaved as 8 bytes (B0 - B7)
                 const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
                 const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 1);
                 const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 2);
@@ -644,8 +574,13 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
                 const __m256i rhs_vec_0123_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b)); // B0(24-31) B1(24-31) B2(24-31) B3(24-31)
                 const __m256i rhs_vec_4567_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b)); // B4(24-31) B5(24-31) B6(24-31) B7(24-31)
 
-                // Load the scale values for the 8 blocks interleaved in block_q4_0x8
-                const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask);
+                // Load the scale values for the 8 blocks interleaved in block_tx8
+                __m256 col_scale_f32;
+                if constexpr (
+                        std::is_same_v ||
+                        std::is_same_v) {
+                    col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask);
+                }
 
                 // Load and convert to FP32 scale from block_q8_0
                 const __m256 row_scale_f32 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(a_ptr[b].d));
@@ -686,1097 +621,2813 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
             _mm256_storeu_ps(s + (y * nr + x * 8), acc_row);
         }
     }
-    return;
-
-#endif
-    {
-        float sumf[8];
-        int sumi;
-
-        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
-
-            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    for (int j = 0; j < ncols_interleaved; j++) {
-                        sumi = 0;
-                        for (int i = 0; i < blocklen; ++i) {
-                            const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                            const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
-                        }
-                        sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
-                    }
-                }
-            }
-            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
-        }
-    }
 }
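
The signextendlut argument taken by the new gemv/gemm templates is the lookup table that ggml_gemv_q4_0_8x8_q8_0 used to build locally (the deleted _mm256_castsi128_si256 / _mm256_permute2f128_si256 lines above); passing it in evidently lets the same template body serve more than one 4-bit block layout just by swapping the 16-entry table. For the q4_0 case the table reinterprets each nibble as a signed 4-bit two's-complement value: _mm_set_epi8(-1, -2, ..., -8, 7, 6, ..., 0) maps indices 0..7 to themselves and 8..15 to -8..-1, and since _mm256_shuffle_epi8 works per 128-bit lane the table is duplicated into both halves, decoding 32 nibbles per shuffle. A scalar check of that mapping:

    // The sign-extend LUT maps a nibble, read as an unsigned index 0..15, to the
    // int8 obtained by sign-extending it as a 4-bit two's-complement value.
    #include <cstdint>
    #include <cstdio>

    int main() {
        // same contents as _mm_set_epi8(-1,-2,-3,-4,-5,-6,-7,-8, 7,6,5,4,3,2,1,0),
        // written here in index order lut[0] .. lut[15]
        const int8_t lut[16] = { 0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1 };
        for (int nib = 0; nib < 16; ++nib) {
            const int8_t sign_extended = (int8_t)((nib & 0x8) ? (nib | 0xF0) : nib);
            printf("nibble %2d -> lut %3d, sign-extend %3d\n", nib, lut[nib], sign_extended);
            if (lut[nib] != sign_extended) return 1;
        }
        return 0;
    }
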
 
-void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
-    const int qk = QK_K;
-    const int nb = n / qk;
-    const int ncols_interleaved = 8;
-    const int blocklen = 8;
-    static const uint32_t kmask1 = 0x3f3f3f3f;
-    static const uint32_t kmask2 = 0x0f0f0f0f;
-    static const uint32_t kmask3 = 0x03030303;
-
-    assert (n % qk == 0);
-    assert (nc % ncols_interleaved == 0);
+// GEMM for 8x blocks of 32 4-bit quants with a single scale factor per block
+template <typename block_tx8>
+static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) {
+    static_assert(
+            std::is_same_v ||
+            std::is_same_v,
+            "Unsupported block type");
 
-    UNUSED(s);
-    UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
-    UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
+    const int qk = QK8_0;
+    const int nb = n / qk;
 
-#if defined(__AVX2__)
-    // Lookup table to convert signed nibbles to signed bytes
-    __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
-    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
-    // Shuffle masks to rearrange delta and scale values to multiply with appropriate scales
-    __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
-    __m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
-    // Permute mask used for easier vector processing at later stages
-    __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
+    const block_tx8    * b_ptr_start = (const block_tx8    *)vx;
+    const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy;
 
-    // Mask to extract nibbles from bytes
+    int64_t b_nb = n / 32;
+    int64_t y = 0;
+    // Mask to mask out nibbles from packed bytes
     const __m256i m4b = _mm256_set1_epi8(0x0F);
+    const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3);
+    // Permute mask used for easier vector processing at later stages
+    __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
+    int64_t xstart = 0;
+    int anr = nr - nr%16; // Used to align nr with boundary of 16
+#ifdef __AVX512F__
+    int anc = nc - nc%16; // Used to align nc with boundary of 16
+                          // Mask to mask out nibbles from packed bytes expanded to 512 bit length
+    const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
+    // Lookup table to convert signed nibbles to signed bytes expanded to 512 bit length
+    __m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1);
 
-    int64_t b_nb = n / QK_K;
-
-    const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 *)vx;
-    const block_q8_K * a_ptr_start = (const block_q8_K *)vy;
+    // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
+    for (; y < anr / 4; y += 4) {
 
-    // Process Q8_K blocks one by one
-    for (int64_t y = 0; y < nr; y++) {
+        const block_q8_0x4 * a_ptrs[4];
 
-        // Pointers to LHS blocks of block_q8_K format
-        const block_q8_K * a_ptr = a_ptr_start + (y * nb);
+        a_ptrs[0] = a_ptr_start + (y * nb);
+        for (int i = 0; i < 3; ++i) {
+            a_ptrs[i + 1] = a_ptrs[i] + nb;
+        }
 
-        // Take group of eight interleaved block_q4_K structures at each pass of the loop and perform dot product operation
-        for (int64_t x = 0; x < nc / 8; x++) {
+        // Take group of two block_tx8 structures at each pass of the loop and perform dot product operation
+        for (int64_t x = 0; x < anc / 8; x += 2) {
 
-            // Pointers to RHS blocks
-            const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
+            const block_tx8 * b_ptr_0 = b_ptr_start + ((x)     * b_nb);
+            const block_tx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
 
             // Master FP accumulators
-            __m256 acc_row = _mm256_setzero_ps();
-            __m256 acc_min_rows = _mm256_setzero_ps();
+            __m512 acc_rows[16];
+            for (int i = 0; i < 16; i++) {
+                acc_rows[i] = _mm512_setzero_ps();
+            }
 
             for (int64_t b = 0; b < nb; b++) {
+                // Load the sixteen blocks of quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
+                const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
+                const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
+                const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
+                const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
+
+                const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
+                const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
+                const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
+                const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
+
+                // Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values
+                const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
+                const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
+                const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
+                const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
+                const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
+
+                const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
+                const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
+                const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
+                const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
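+                // rhs_raw_mat_014589CD_* now packs one 8 byte chunk from each of columns 0,1,4,5
+                // (lower 256 bits) and 8,9,C,D (upper 256 bits); the 2367ABEF registers hold the
+                // remaining eight columns in the same layout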
 
-                // Load and convert to FP32 scale from block_q8_K
-                const __m256 row_scale_f32 = _mm256_set1_ps((a_ptr[b].d));
+                // 4-bit -> 8-bit - Sign is maintained
+                const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
+                const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
 
-                // Load the scale values for the 8 blocks interleaved in block_q4_Kx8
-                // col_scale_f32 rearranged so as to multiply with appropriate quants
-                const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, deltamask);
-                const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
+                const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
+                const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
 
-                __m256i iacc_b = _mm256_setzero_si256();
-                __m256i iacc_min_b = _mm256_setzero_si256();
+                const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
+                const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
 
-                const __m256i q8sums = _mm256_loadu_si256((const __m256i * )(a_ptr[b].bsums));
-                __m256i q8s = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1)));
-                q8s = _mm256_permute2f128_si256(q8s, q8s, 0);
+                const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
+                const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
 
-                // Processes two sub blocks from each Q4_K in each iteration
-                for (int sb = 0; sb < QK_K / 64; sb++) {
+                // Shuffle pattern one - right side input
+                const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
+                const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
 
-                    // Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
-                    const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
-                    const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
-                    const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
-                    const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
-                    const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
-                    const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
-                    const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
-                    const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
+                const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
+                const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
 
-                    // 4-bit -> 8-bit
-                    // Values of the first sub block of eight block_q4_K structures for the sb loop
-                    const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m4b);
-                    const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m4b);
-                    const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m4b);
-                    const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m4b);
-                    const __m256i rhs_vec_0123_02 = _mm256_and_si256(rhs_raw_vec_0123_2, m4b);
-                    const __m256i rhs_vec_4567_02 = _mm256_and_si256(rhs_raw_vec_4567_2, m4b);
-                    const __m256i rhs_vec_0123_03 = _mm256_and_si256(rhs_raw_vec_0123_3, m4b);
-                    const __m256i rhs_vec_4567_03 = _mm256_and_si256(rhs_raw_vec_4567_3, m4b);
+                const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
+                const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
 
-                    // Values of the second sub block of eight block_q4_K structures when sb = 1
-                    const __m256i rhs_vec_0123_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b);
-                    const __m256i rhs_vec_4567_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b);
-                    const __m256i rhs_vec_0123_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b);
-                    const __m256i rhs_vec_4567_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b);
-                    const __m256i rhs_vec_0123_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m4b);
-                    const __m256i rhs_vec_4567_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m4b);
-                    const __m256i rhs_vec_0123_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m4b);
-                    const __m256i rhs_vec_4567_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m4b);
+                const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
+                const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
 
-                    uint32_t utmp_0[4], utmp_1[4];
+                // Shuffle pattern two - right side input
 
-                    // Scales and Mins of corresponding sub blocks from different Q8_K structures are stored together
-                    // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
-                    memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
-                    utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
-                    const uint32_t uaux_0 = utmp_0[1] & kmask1;
-                    utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
-                    utmp_0[2] = uaux_0;
-                    utmp_0[0] &= kmask1;
+                const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
+                const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
 
-                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
-                    memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
-                    utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
-                    const uint32_t uaux_1 = utmp_1[1] & kmask1;
-                    utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
-                    utmp_1[2] = uaux_1;
-                    utmp_1[0] &= kmask1;
+                const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
+                const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
 
-                    // Scales of first sub block in the sb loop
-                    const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
-                    __m128i scales_rearrange_0 = _mm_shuffle_epi8(mins_and_scales_0, scalemask);
-                    __m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0);
+                const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
+                const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
 
-                    // Scales of second sub block in the sb loop
-                    __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
-                    __m128i scales_rearrange_1 = _mm_shuffle_epi8(mins_and_scales_1, scalemask);
-                    __m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1);
+                const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
+                const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
 
-                    // Mins of first and second sub block of Q4_K block are arranged side by side
-                    __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
+                // Scale values - Load the weight scale values of two block_tx8
+                __m512 col_scale_f32;
+                if constexpr (
+                        std::is_same_v<block_tx8, block_q4_0x8> ||
+                        std::is_same_v<block_tx8, block_iq4_nlx8>) {
+                    col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+                }
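+                // For these block types col_scale_f32 holds the sixteen per-column FP32 scales:
+                // eight from b_ptr_0[b].d and eight from b_ptr_1[b].d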
 
-                    // Load the two sub block values corresponding to sb in block_q8_K in batches of 16 bytes and replicate the same across 256 bit vector
-                    __m256i lhs_vec_00 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + sb * 64)));
-                    __m256i lhs_vec_01 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16 + sb * 64)));
-                    __m256i lhs_vec_10 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32 + sb * 64)));
-                    __m256i lhs_vec_11 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48 + sb * 64)));
+                // Process LHS in pairs of rows
+                for (int rp = 0; rp < 4; rp++) {
 
-                    lhs_vec_00 = _mm256_permute2f128_si256(lhs_vec_00, lhs_vec_00, 0);
-                    lhs_vec_01 = _mm256_permute2f128_si256(lhs_vec_01, lhs_vec_01, 0);
-                    lhs_vec_10 = _mm256_permute2f128_si256(lhs_vec_10, lhs_vec_10, 0);
-                    lhs_vec_11 = _mm256_permute2f128_si256(lhs_vec_11, lhs_vec_11, 0);
+                    // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
+                    // Loaded as a set of 128 bit vectors, repeated into a 256 bit vector and then repeated again into a 512 bit vector
+                    __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
+                    __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
+                    __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
+                    __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
+                    __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
+                    __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
+                    __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
+                    __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
+                    __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
+                    __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
+                    __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
+                    __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
 
-                    // Dot product done within 32 bit lanes and accumulated in the same vector
-                    // First done for first sub block and thenn for second sub block in each sb
-                    // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
-                    // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
-                    // ...........................................................................
-                    // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
+                    __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
+                    __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
+                    __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
+                    __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
+                    __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
+                    __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
+                    __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
+                    __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
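+                    // The 16 byte chunk holding 8 bytes from each of two LHS rows is broadcast to
+                    // every 128 bit lane of the 512 bit register, so the same register can be
+                    // multiplied against each eight-column RHS register above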
 
+                    // Shuffle pattern one - left side input
 
-                    __m256i iacc_0 = _mm256_setzero_si256();
-                    __m256i iacc_1 = _mm256_setzero_si256();
+                    const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+                    const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
 
-                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 0)));
-                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_00, 85)));
+                    const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+                    const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
 
-                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 170)));
-                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_00, 255)));
+                    const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+                    const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
 
-                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_02 ,_mm256_shuffle_epi32(rhs_vec_4567_02, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 0)));
-                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_02, 177) ,rhs_vec_4567_02, 170), _mm256_shuffle_epi32(lhs_vec_01, 85)));
+                    const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+                    const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
 
-                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_03 ,_mm256_shuffle_epi32(rhs_vec_4567_03, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 170)));
-                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_03, 177) ,rhs_vec_4567_03, 170), _mm256_shuffle_epi32(lhs_vec_01, 255)));
+                    // Shuffle pattern two - left side input
 
-                    iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);
+                    const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+                    const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
 
-                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 0)));
-                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_shuffle_epi32(lhs_vec_10, 85)));
+                    const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+                    const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
 
-                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 170)));
-                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_shuffle_epi32(lhs_vec_10, 255)));
+                    const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+                    const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
 
-                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_12 ,_mm256_shuffle_epi32(rhs_vec_4567_12, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 0)));
-                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_12, 177) ,rhs_vec_4567_12, 170), _mm256_shuffle_epi32(lhs_vec_11, 85)));
-
-                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_13 ,_mm256_shuffle_epi32(rhs_vec_4567_13, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 170)));
-                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_13, 177) ,rhs_vec_4567_13, 170), _mm256_shuffle_epi32(lhs_vec_11, 255)));
+                    const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+                    const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
 
-                    iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);
+                    // The values arranged in the shuffle patterns are combined with a dot product within each 32 bit lane, i.e. corresponding bytes are multiplied and the products accumulated into 32 bit integers within the lane
+                    // Resembles the MMLAs into 2x2 matrices of the ARM version
+                    const __m512i zero = _mm512_setzero_epi32();
+                    __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1);
+                    __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1);
+                    __m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1);
+                    __m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1);
+                    __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2);
+                    __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2);
+                    __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2);
+                    __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2);
 
-                    // Accumulate the iacc value for one sb
-                    __m256i iacc_sb = _mm256_add_epi32(iacc_0, iacc_1);
+                    // Outputs of both shuffle patterns are added in order to sum the dot product outputs of all 32 values in the block
+                    __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+                    __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+                    __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+                    __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
 
-                    // Broadcast the bsums of the two sub blocks  of the iteration of Q8_K across the vector
-                    // Multiply-Add with corresponding mins of Q4_Kx8 with bsums
-                    __m256i q8s_sb = _mm256_shuffle_epi32(q8s, 0);
-                    __m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);
-                    q8s = _mm256_bsrli_epi128(q8s, 4);
 
-                    // Accumulate for the complete block
-                    iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);
-                    iacc_min_b = _mm256_add_epi32(iacc_min_b, iacc_min_sb);
-                }
+                    // Straighten out to make 4 row vectors
+                    __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
+                    __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
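+                    // iacc_row_0..3 now hold the sixteen integer dot products of the four LHS rows
+                    // of a_ptrs[rp] against the sixteen RHS columns handled in this pass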
 
-                // Multiply-Add with scale values for the complete super block
-                acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_b), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
-                acc_min_rows = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_min_b), _mm256_mul_ps(col_dmin_f32, row_scale_f32), acc_min_rows);
+                    // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
+                    const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
+                    const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
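+                    // row_scale_f32 repeats the four row scales in every 128 bit lane; the
+                    // _mm512_shuffle_ps immediates below (0, 85, 170, 255) broadcast the scale of
+                    // row 0, 1, 2 or 3 across the whole vector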
 
+                    // Multiply with appropriate scales and accumulate
+                    acc_rows[rp * 4]     = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)),   acc_rows[rp * 4]);
+                    acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)),  acc_rows[rp * 4 + 1]);
+                    acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
+                    acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
+                }
             }
 
-            // Accumulated output values permuted so as to be stored in appropriate order post accumulation
-            acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
-            _mm256_storeu_ps(s + (y * nr + x * 8), _mm256_sub_ps(acc_row, acc_min_rows));
+            // Store the accumulated values
+            for (int i = 0; i < 16; i++) {
+                _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+            }
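+            // Each of the 16 accumulator rows is written to its own output row (stride bs floats),
+            // covering output columns x*8 .. x*8 + 15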
         }
     }
 
-#else
+    // Take one block_q8_0x4 structure at each pass of the loop and perform dot product operation
+    for (; y < nr / 4; y ++) {
+        const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
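+        // This tail covers the row groups left over when nr is not a multiple of 16:
+        // four LHS rows (one block_q8_0x4) per iteration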
 
-    float sumf[8];
-    float sum_minf[8];
-    uint32_t utmp[32];
-    int sumi1;
-    int sumi2;
-    int sumi;
+        // Take group of two block_tx8 structures at each pass of the loop and perform dot product operation
+        for (int64_t x = 0; x < anc / 8; x += 2) {
 
-    const block_q8_K * a_ptr = (const block_q8_K *) vy;
-    for (int x = 0; x < nc / ncols_interleaved; x++) {
-        const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+            const block_tx8 * b_ptr_0 = b_ptr_start + ((x)     * b_nb);
+            const block_tx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
 
-        for (int j = 0; j < ncols_interleaved; j++) {
-            sumf[j] = 0.0;
-            sum_minf[j] = 0.0;
-        }
-        for (int l = 0; l < nb; l++) {
-            for (int sb = 0; sb < 8; sb++) {
-                memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
-                utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
-                const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
-                utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
-                utmp[sb * 4 + 2] = uaux_0;
-                utmp[sb * 4 + 0] &= kmask1;
-            }
-            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;
-                uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;
-                for (int j = 0; j < ncols_interleaved; j++) {
-                    sumi1 = 0;
-                    sumi2 = 0;
-                    sumi = 0;
-                    for (int i = 0; i < blocklen; ++i) {
-                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
-                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
-                        sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i]);
-                        sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i + 32]);
-                        sumi1 = sumi1 * scales_0[j];
-                        sumi2 = sumi2 * scales_1[j];
-                        sumi += sumi1 + sumi2;
-                    }
-                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
-                }
-            }
-            for (int sb = 0; sb < 8; sb++) {
-                uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;
-                for (int j = 0; j < ncols_interleaved; j++) {
-                    sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
-                }
+            // Master FP accumulators
+            __m512 acc_rows[4];
+            for (int i = 0; i < 4; i++) {
+                acc_rows[i] = _mm512_setzero_ps();
             }
-        }
-        for (int j = 0; j < ncols_interleaved; j++) {
-            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
-        }
-    }
-#endif
-}
 
-void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
-    const int qk = QK8_0;
-    const int nb = n / qk;
-    const int ncols_interleaved = 8;
-    const int blocklen = 8;
+            for (int64_t b = 0; b < nb; b++) {
+                // Load the sixteen blocks of quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
+                const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
+                const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
+                const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
+                const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
+
+                const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
+                const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
+                const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
+                const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
+
+                // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
+                const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
+                const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
+                const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
+                const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
+                const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
+
+                const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
+                const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
+                const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
+                const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
 
-    assert (n % qk == 0);
-    assert (nr % 4 == 0);
-    assert (nc % ncols_interleaved == 0);
+                // 4-bit -> 8-bit - Sign is maintained
+                const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
+                const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
 
-    UNUSED(s);
-    UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
-    UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
+                const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
+                const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
 
-#if defined(__AVX2__) || defined(__AVX512F__)
-    {
-        const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx;
-        const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy;
-        int64_t b_nb = n / QK4_0;
-        int64_t y = 0;
-        // Mask to mask out nibbles from packed bytes
-        const __m256i m4b = _mm256_set1_epi8(0x0F);
-        const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3);
-        // Lookup table to convert signed nibbles to signed bytes
-        __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
-        signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
-        // Permute mask used for easier vector processing at later stages
-        __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
-        int64_t xstart = 0;
-        int anr = nr - nr%16; // Used to align nr with boundary of 16
-    #ifdef __AVX512F__
-        int anc = nc - nc%16; // Used to align nc with boundary of 16
-        // Mask to mask out nibbles from packed bytes expanded to 512 bit length
-        const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
-        // Lookup table to convert signed nibbles to signed bytes expanded to 512 bit length
-        __m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1);
-
-        // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
-        for (; y < anr / 4; y += 4) {
-
-            const block_q8_0x4 * a_ptrs[4];
-
-            a_ptrs[0] = a_ptr_start + (y * nb);
-            for (int i = 0; i < 3; ++i) {
-                a_ptrs[i + 1] = a_ptrs[i] + nb;
-            }
+                const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
+                const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
 
-            // Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation
-            for (int64_t x = 0; x < anc / 8; x += 2) {
+                const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
+                const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
 
-                const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x)     * b_nb);
-                const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
+                // Shuffle pattern one - right side input
+                const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
+                const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
 
-                // Master FP accumulators
-                __m512 acc_rows[16];
-                for (int i = 0; i < 16; i++) {
-                    acc_rows[i] = _mm512_setzero_ps();
-                }
+                const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
+                const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
 
-                for (int64_t b = 0; b < nb; b++) {
-                    // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
-                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
-                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
-                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
-                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
+                const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
+                const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
 
-                    const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
-                    const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
-                    const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
-                    const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
+                const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
+                const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
 
-                    // Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values
-                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
-                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+                // Shuffle pattern two - right side input
 
-                    const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
-                    const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
+                const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
+                const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
 
-                    const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
-                    const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
-                    const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
-                    const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
+                const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
+                const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
 
-                    // 4-bit -> 8-bit - Sign is maintained
-                    const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
-                    const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
+                const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
+                const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
 
-                    const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
-                    const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
+                const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
+                const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
 
-                    const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
-                    const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
 
-                    const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
-                    const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
+                // Scale values - Load the weight scale values of two block_tx8
+                __m512 col_scale_f32;
+                if constexpr (
+                        std::is_same_v<block_tx8, block_q4_0x8> ||
+                        std::is_same_v<block_tx8, block_iq4_nlx8>) {
+                    col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+                }
 
-                    // Shuffle pattern one - right side input
-                    const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
-                    const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
+                // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
+                // Loaded as a set of 128-bit vectors, repeated into a 256-bit vector and then again into a 512-bit vector
+                __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
+                __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
+                __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
+                __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
+                __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
+                __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
+                __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
+                __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
+                __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
+                __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
+                __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
+                __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
+
+                __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
+                __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
+                __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
+                __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
+                __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
+                __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
+                __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
+                __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
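+                // Note: at this point each 512-bit register holds one LHS row pair (A0A1 or A2A3)
+                // repeated across all four 128-bit lanes, so a single register can be multiplied
+                // against the four different column pairs packed into the wide RHS registers.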
+
+                // Shuffle pattern one - left side input
+
+                const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+                const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+
+                const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+                const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+
+                const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+                const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+
+                const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+                const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+
+                // Shuffle pattern two - left side input
+
+                const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+                const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+
+                const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+                const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+
+                const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+                const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+
+                const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+                const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+
+                // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e. corresponding bytes are multiplied and added into 32 bit integers within 32 bit lane
+                // Resembles MMLAs into 2x2 matrices in ARM Version
+                const __m512i zero = _mm512_setzero_epi32();
+                __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1);
+                __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1);
+                __m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1);
+                __m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1);
+                __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2);
+                __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2);
+                __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2);
+                __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2);
+
+                // Outputs of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
+                __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+                __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+                __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+                __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
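+                // Note: each 128-bit lane of iacc_mat_XY now holds a 2x2 tile of dot products
+                // [Ai.Bp, Ai.Bq, Aj.Bp, Aj.Bq], where (Ai,Aj) is the row pair (01 or 23) and
+                // (Bp,Bq) is one of the four column pairs carried by the selected RHS register.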
+
+
+                // Straighten out to make 4 row vectors
+                __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
+                __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
+                __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
+                __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
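+                // Note: the shuffle with immediate 78 swaps the 64-bit halves of every 128-bit lane,
+                // so after the masked blends iacc_row_0..3 hold the dot products of rows A0..A3 with
+                // all sixteen interleaved columns B0..BF, in column order.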
+
+                // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
+                const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
+                const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
+
+                // Multiply with appropriate scales and accumulate
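+                // Note: col_scale_f32 carries the sixteen per-column deltas while the shuffles with
+                // 0/85/170/255 broadcast the delta of rows 0..3, so each FMA effectively adds
+                // iacc_row_r[c] * d_A[r] * d_B[c] to the accumulator.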
+                acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)),   acc_rows[0]);
+                acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)),  acc_rows[1]);
+                acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
+                acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
+            }
 
-                    const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
-                    const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
+            // Store the accumulated values
+            for (int i = 0; i < 4; i++) {
+                _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+            }
+        }
+    }
+    if (anc != nc) {
+        xstart = anc/8;
+        y = 0;
+    }
+#endif // __AVX512F__
 
-                    const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
-                    const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
+    // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
 
-                    const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
-                    const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
+    for (; y < anr / 4; y += 4) {
+        const block_q8_0x4 * a_ptrs[4];
 
-                    // Shuffle pattern two - right side input
+        a_ptrs[0] = a_ptr_start + (y * nb);
+        for (int i = 0; i < 3; ++i) {
+            a_ptrs[i + 1] = a_ptrs[i] + nb;
+        }
 
-                    const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
-                    const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
+        // Take group of eight block_tx8 structures at each pass of the loop and perform dot product operation
+        for (int64_t x = xstart; x < nc / 8; x++) {
 
-                    const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
-                    const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
+            const block_tx8 * b_ptr = b_ptr_start + (x * b_nb);
 
-                    const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
-                    const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
+            // Master FP accumulators
+            __m256 acc_rows[16];
+            for (int i = 0; i < 16; i++) {
+                acc_rows[i] = _mm256_setzero_ps();
+            }
 
-                    const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
-                    const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
+            for (int64_t b = 0; b < nb; b++) {
+                // Load the eight blocks of quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
+                const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
+                const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
+                const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
+                const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
+
+                // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
+                const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
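+                // Note: the blends with mask 240 keep the low 128 bits of the first operand and take
+                // the high 128 bits from the lane-swapped second operand, so each register ends up
+                // carrying one column pair per 128-bit lane (B0B1 | B4B5 and B2B3 | B6B7).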
 
-                    // Scale values - Load the weight scale values of two block_q4_0x8
-                    const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+                // 4-bit -> 8-bit - Sign is maintained
+                const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
+                const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
 
-                    // Process LHS in pairs of rows
-                    for (int rp = 0; rp < 4; rp++) {
+                const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
+                const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
 
-                        // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
-                        // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector
-                        __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
-                        __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
-                        __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
-                        __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
-                        __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
-                        __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
-                        __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
-                        __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
-                        __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
-                        __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
-                        __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
-                        __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
-
-                        __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
-                        __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
-                        __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
-                        __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
-                        __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
-                        __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
-                        __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
-                        __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
+                const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
+                const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
 
-                        // Shuffle pattern one - left side input
+                const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
+                const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
 
-                        const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
-                        const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+                // Shuffle pattern one - right side input
+                const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136);  //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
+                const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136);  //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
 
-                        const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
-                        const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+                const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136);  //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
+                const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136);  //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
 
-                        const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
-                        const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+                const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136);  //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
+                const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136);  //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
 
-                        const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
-                        const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+                const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136);  //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
+                const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136);  //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
 
-                        // Shuffle pattern two - left side input
+                // Shuffle pattern two - right side input
 
-                        const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
-                        const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+                const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221);  //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
+                const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221);  //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
 
-                        const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
-                        const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+                const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221);  //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
+                const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221);  //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
 
-                        const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
-                        const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+                const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221);  //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
+                const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221);  //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
 
-                        const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
-                        const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+                const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221);  //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
+                const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221);  //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
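+                // Note: shuffle immediates 136 (0b10001000) and 221 (0b11011101) pick the even and odd
+                // 32-bit elements of each 128-bit lane, so the _sp1/_sp2 variants split every block into
+                // two interleaved 4-byte groups whose partial dot products are summed further below.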
 
-                        // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
-                        // Resembles MMLAs into 2x2 matrices in ARM Version
-                        const __m512i zero = _mm512_setzero_epi32();
-                        __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1);
-                        __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1);
-                        __m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1);
-                        __m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1);
-                        __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2);
-                        __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2);
-                        __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2);
-                        __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2);
+                // Scale values - Load the weight scale values of block_tx8
+                __m256 col_scale_f32;
+                if constexpr (
+                        std::is_same_v ||
+                        std::is_same_v) {
+                    col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+                }
 
-                        // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
-                        __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
-                        __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
-                        __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
-                        __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+                // Process LHS in groups of four
+                for (int rp = 0; rp < 4; rp++) {
+                    // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
+                    // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
+                    __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
+                    __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
+                    __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
+                    __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
+                    __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
+                    __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
+                    __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
+                    __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
+                    __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
+                    __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
+                    __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
+                    __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
 
+                    // Shuffle pattern one - left side input
+                    const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+                    const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
 
-                        // Straighten out to make 4 row vectors
-                        __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
-                        __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
-                        __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
-                        __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
+                    const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+                    const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
 
-                        // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
-                        const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
-                        const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
+                    const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+                    const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
 
-                        // Multiply with appropiate scales and accumulate
-                        acc_rows[rp * 4]     = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)),   acc_rows[rp * 4]);
-                        acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)),  acc_rows[rp * 4 + 1]);
-                        acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
-                        acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
-                    }
-                }
+                    const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+                    const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
 
-                // Store the accumulated values
-                for (int i = 0; i < 16; i++) {
-                    _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
-                }
-            }
-        }
-        // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation
-        for (; y < nr / 4; y ++) {
+                    // Shuffle pattern two - left side input
+                    const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+                    const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
 
-            const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
+                    const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+                    const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
 
-            // Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation
-            for (int64_t x = 0; x < anc / 8; x += 2) {
+                    const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+                    const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
 
-                const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x)     * b_nb);
-                const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
+                    const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+                    const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
 
-                // Master FP accumulators
-                __m512 acc_rows[4];
-                for (int i = 0; i < 4; i++) {
-                    acc_rows[i] = _mm512_setzero_ps();
-                }
+                    // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e. corresponding bytes are multiplied and added into 32 bit integers within 32 bit lane
+                    // Resembles MMLAs into 2x2 matrices in ARM Version
+                    const __m256i zero = _mm256_setzero_si256();
+                    __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1);
+                    __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1);
+                    __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1);
+                    __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1);
+                    __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2);
+                    __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2);
+                    __m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2);
+                    __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2);
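+                    // Note: same 2x2 tiling as the AVX512 path above, except each 256-bit accumulator
+                    // only covers two column pairs (B0B1/B4B5 or B2B3/B6B7), one per 128-bit lane.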
 
-                for (int64_t b = 0; b < nb; b++) {
-                    // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
-                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
-                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
-                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
-                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
+                    // Outputs of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
+                    __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+                    __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+                    __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+                    __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
 
-                    const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
-                    const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
-                    const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
-                    const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
+                    // Straighten out to make 4 row vectors
+                    __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
+                    __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
+                    __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
+                    __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
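+                    // Note: iacc_row_0..3 now hold the dot products of the four LHS rows of this
+                    // block_q8_0x4 with columns B0..B7, in column order.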
 
-                    // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess
-                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
-                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+                    // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
+                    const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
 
-                    const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
-                    const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
+                    // Multiply with appropriate scales and accumulate
+                    acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
+                    acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
+                    acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
+                    acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32,  255)), acc_rows[rp * 4 + 3]);
+                }
+            }
 
-                    const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
-                    const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
-                    const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
-                    const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
+            // Store the accumulated values
+            for (int i = 0; i < 16; i++) {
+                _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+            }
+        }
+    }
 
-                    // 4-bit -> 8-bit - Sign is maintained
-                    const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
-                    const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
+    // Take a block_q8_0x4 structure at each pass of the loop and perform dot product operation
+    for (; y < nr / 4; y ++) {
+        const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
 
-                    const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
-                    const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
+        // Take group of eight block_tx8 structures at each pass of the loop and perform dot product operation
+        for (int64_t x = xstart; x < nc / 8; x++) {
+            const block_tx8 * b_ptr = b_ptr_start + (x * b_nb);
 
-                    const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
-                    const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
+            // Master FP accumulators
+            __m256 acc_rows[4];
+            for (int i = 0; i < 4; i++) {
+                acc_rows[i] = _mm256_setzero_ps();
+            }
 
-                    const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
-                    const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
+            for (int64_t b = 0; b < nb; b++) {
+                // Load the eight blocks of quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
+                const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
+                const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
+                const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
+                const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
+
+                // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
+                const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
 
-                    // Shuffle pattern one - right side input
-                    const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
-                    const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
+                // 4-bit -> 8-bit - Sign is maintained
+                const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b));  //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
+                const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b));  //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
 
-                    const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
-                    const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
+                const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b));  //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
+                const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b));  //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
 
-                    const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
-                    const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
+                const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b));  //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
+                const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b));  //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
 
-                    const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
-                    const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
+                const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b));  //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
+                const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b));  //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
 
-                    // Shuffle pattern two - right side input
+                // Shuffle pattern one - right side input
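+                // imm 136 (0b10001000) selects dwords 0,2,0,2 of each 128 bit lane and imm 221 (0b11011101)
+                // selects 1,3,1,3 - together the two patterns cover all four 4-byte chunks of each column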
+                const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136);  //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
+                const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136);  //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
 
-                    const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
-                    const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
+                const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136);  //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
+                const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136);  //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
 
-                    const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
-                    const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
+                const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136);  //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
+                const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136);  //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
 
-                    const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
-                    const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
+                const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136);  //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
+                const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136);  //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
 
-                    const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
-                    const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
+                // Shuffle pattern two - right side input
 
+                const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221);  //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
+                const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221);  //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
 
-                    // Scale values - Load the weight scale values of two block_q4_0x8
-                    const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+                const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221);  //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
+                const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221);  //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
 
-                    // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
-                    // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector
-                    __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
-                    __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
-                    __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
-                    __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
-                    __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
-                    __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
-                    __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
-                    __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
-                    __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
-                    __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
-                    __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
-                    __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
+                const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221);  //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
+                const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221);  //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
 
-                    __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
-                    __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
-                    __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
-                    __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
-                    __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
-                    __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
-                    __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
-                    __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
+                const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221);  //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
+                const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221);  //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
 
-                    // Shuffle pattern one - left side input
+                // Scale values - Load the weight scale values of block_tx8
+                __m256 col_scale_f32;
+                if constexpr (
+                        std::is_same_v<block_tx8, block_q4_0x8> ||
+                        std::is_same_v<block_tx8, block_iq4_nlx8>) {
+                    col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+                }
 
-                    const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
-                    const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+                // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
+                // Loaded as a set of 128 bit vectors and repeated into a 256 bit vector
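+                // permute2f128 imm 0 duplicates the low 128 bits (rows 0 and 1) into both lanes,
+                // imm 17 duplicates the high 128 bits (rows 2 and 3)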
+                __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
+                __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
+                __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
+                __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
+                __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
+                __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
+                __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
+                __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
+                __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
+                __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
+                __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
+                __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
+
+                // Shuffle pattern one - left side input
+
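+                // imm 160 (0b10100000) selects dwords 0,0,2,2 of each lane and imm 245 (0b11110101) selects
+                // 1,1,3,3, so each activation 4-byte chunk is duplicated to line up with the B column pairs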
+                const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+                const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+
+                const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+                const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+
+                const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+                const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+
+                const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+                const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+
+                // Shuffle pattern two - left side input
+
+                const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+                const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+
+                const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+                const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+
+                const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+                const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+
+                const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+                const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+
+                // The values arranged in the shuffle patterns are combined with a dot product within each 32 bit lane, i.e. corresponding bytes are multiplied and the products added into 32 bit integers within the lane
+                // Resembles the MMLA accumulation into 2x2 matrices in the ARM version
+                const __m256i zero = _mm256_setzero_si256();
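+                // each mul_sum_i8_pairs_acc_int32x8 call multiplies corresponding signed bytes, sums every
+                // group of four products into a 32 bit lane and adds the result to its accumulator, so the
+                // nested calls below cover all four 4-byte chunks of the block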
+                __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1);
+                __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1);
+                __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1);
+                __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1);
+                __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2);
+                __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2);
+                __m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2);
+                __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2);
+
+                // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
+                __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+                __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+                __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+                __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+
+
+                // Straighten out to make 4 row vectors
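+                // shuffle imm 78 (0b01001110) swaps the two 64 bit halves of each 128 bit lane and blend mask
+                // 204 (0b11001100) keeps dwords 0,1 / 4,5 from the first operand and 2,3 / 6,7 from the second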
+                __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
+                __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
+                __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
+                __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
+
+                // Load the scale values for all 4 Q8_0 blocks and repeat them across lanes
+                const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask);
+
+                // Multiply with appropriate scales and accumulate
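+                // _mm256_shuffle_ps with imm 0/85/170/255 broadcasts element 0/1/2/3 of each 128 bit lane,
+                // i.e. the delta of row 0/1/2/3, before it is multiplied with the per-column scales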
+                acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
+                acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
+                acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
+                acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
+            }
 
-                    const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
-                    const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+            // Store the accumulated values
+            for (int i = 0; i < 4; i++) {
+                _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+            }
+        }
+    }
+}
+
+#endif // defined(__AVX2__) || defined(__AVX512F__)
+
+void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+#if defined(__AVX2__) || defined(__AVX512F__)
+    {
+        // Lookup table to convert signed nibbles to signed bytes
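+        // byte i of the table holds the value of i reinterpreted as a signed 4-bit number (0..7, then -8..-1)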
+        __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
+        signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+
+        gemv_q4_b32_8x8_q8_0_lut_avx(n, s, bs, vx, vy, nr, nc, signextendlut);
+
+        return;
+    }
+#endif
+
+    ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__AVX2__)
+    // Lookup table to convert signed nibbles to signed bytes
+    __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
+    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+    // Shuffle masks to rearrange delta and scale values to multiply with appropriate scales
+    __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
+    __m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
+    // Permute mask used for easier vector processing at later stages
+    __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
+
+    // Mask to extract nibbles from bytes
+    const __m256i m4b = _mm256_set1_epi8(0x0F);
+
+    int64_t b_nb = n / QK_K;
+
+    const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 *)vx;
+    const block_q8_K * a_ptr_start = (const block_q8_K *)vy;
+
+    // Process Q8_K blocks one by one
+    for (int64_t y = 0; y < nr; y++) {
+
+        // Pointers to LHS blocks of block_q8_K format
+        const block_q8_K * a_ptr = a_ptr_start + (y * nb);
+
+        // Take group of eight interleaved block_q4_K structures at each pass of the loop and perform dot product operation
+        for (int64_t x = 0; x < nc / 8; x++) {
+
+            // Pointers to RHS blocks
+            const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
+
+            // Master FP accumulators
+            __m256 acc_row = _mm256_setzero_ps();
+            __m256 acc_min_rows = _mm256_setzero_ps();
+
+            for (int64_t b = 0; b < nb; b++) {
+
+                // Load the scale from block_q8_K and convert it to FP32
+                const __m256 row_scale_f32 = _mm256_set1_ps((a_ptr[b].d));
+
+                // Load the scale values for the 8 blocks interleaved in block_q4_Kx8
+                // col_scale_f32 rearranged so as to multiply with appropriate quants
+                const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, deltamask);
+                const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
+
+                __m256i iacc_b = _mm256_setzero_si256();
+                __m256i iacc_min_b = _mm256_setzero_si256();
+
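+                // bsums stores 16 partial sums of 16 quants each; the horizontal add below collapses them
+                // into 8 per-sub-block sums (32 quants each), duplicated across both 128 bit lanes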
+                const __m256i q8sums = _mm256_loadu_si256((const __m256i * )(a_ptr[b].bsums));
+                __m256i q8s = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1)));
+                q8s = _mm256_permute2f128_si256(q8s, q8s, 0);
+
+                // Processes two sub blocks from each Q4_K in each iteration
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+
+                    // Load the quantized values of two sub blocks from the eight interleaved block_q4_K structures, in chunks of eight - B0,B1 .... B6,B7
+                    const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
+                    const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
+                    const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
+                    const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
+                    const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
+                    const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
+                    const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
+                    const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
+
+                    // 4-bit -> 8-bit
+                    // Values of the first sub block of eight block_q4_K structures for the sb loop
+                    const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m4b);
+                    const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m4b);
+                    const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m4b);
+                    const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m4b);
+                    const __m256i rhs_vec_0123_02 = _mm256_and_si256(rhs_raw_vec_0123_2, m4b);
+                    const __m256i rhs_vec_4567_02 = _mm256_and_si256(rhs_raw_vec_4567_2, m4b);
+                    const __m256i rhs_vec_0123_03 = _mm256_and_si256(rhs_raw_vec_0123_3, m4b);
+                    const __m256i rhs_vec_4567_03 = _mm256_and_si256(rhs_raw_vec_4567_3, m4b);
+
+                    // Values of the second sub block of the eight block_q4_K structures for the sb loop (upper nibbles)
+                    const __m256i rhs_vec_0123_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b);
+                    const __m256i rhs_vec_4567_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b);
+                    const __m256i rhs_vec_0123_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b);
+                    const __m256i rhs_vec_4567_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b);
+                    const __m256i rhs_vec_0123_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m4b);
+                    const __m256i rhs_vec_4567_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m4b);
+                    const __m256i rhs_vec_0123_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m4b);
+                    const __m256i rhs_vec_4567_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m4b);
+
+                    uint32_t utmp_0[4], utmp_1[4];
+
+                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
+                    // The block below extracts, for example, the first sub block's scales and mins from the different Q4_K structures for the sb loop
+                    memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
+                    utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_0 = utmp_0[1] & kmask1;
+                    utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
+                    utmp_0[2] = uaux_0;
+                    utmp_0[0] &= kmask1;
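+                    // after unpacking, utmp_0[0..1] hold one 6-bit scale per interleaved Q4_K block and
+                    // utmp_0[2..3] hold the matching 6-bit mins, one byte each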
+
+                    // The block below extracts the second sub block's scales and mins from the different Q4_K structures for the sb loop
+                    memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
+                    utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_1 = utmp_1[1] & kmask1;
+                    utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
+                    utmp_1[2] = uaux_1;
+                    utmp_1[0] &= kmask1;
+
+                    // Scales of first sub block in the sb loop
+                    const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
+                    __m128i scales_rearrange_0 = _mm_shuffle_epi8(mins_and_scales_0, scalemask);
+                    __m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0);
+
+                    // Scales of second sub block in the sb loop
+                    __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
+                    __m128i scales_rearrange_1 = _mm_shuffle_epi8(mins_and_scales_1, scalemask);
+                    __m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1);
+
+                    // Mins of first and second sub block of Q4_K block are arranged side by side
+                    __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
+
+                    // Load the two sub block values corresponding to sb in block_q8_K in batches of 16 bytes and replicate the same across 256 bit vector
+                    __m256i lhs_vec_00 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + sb * 64)));
+                    __m256i lhs_vec_01 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16 + sb * 64)));
+                    __m256i lhs_vec_10 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32 + sb * 64)));
+                    __m256i lhs_vec_11 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48 + sb * 64)));
+
+                    lhs_vec_00 = _mm256_permute2f128_si256(lhs_vec_00, lhs_vec_00, 0);
+                    lhs_vec_01 = _mm256_permute2f128_si256(lhs_vec_01, lhs_vec_01, 0);
+                    lhs_vec_10 = _mm256_permute2f128_si256(lhs_vec_10, lhs_vec_10, 0);
+                    lhs_vec_11 = _mm256_permute2f128_si256(lhs_vec_11, lhs_vec_11, 0);
+
+                    // Dot product done within 32 bit lanes and accumulated in the same vector
+                    // First done for the first sub block and then for the second sub block in each sb
+                    // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
+                    // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
+                    // ...........................................................................
+                    // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
+
+
+                    __m256i iacc_0 = _mm256_setzero_si256();
+                    __m256i iacc_1 = _mm256_setzero_si256();
+
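+                    // shuffle imm 177 swaps adjacent dwords and blend mask 170 (0b10101010) takes the odd
+                    // dwords, interleaving the 0123 and 4567 column groups; shuffle imm 0/85/170/255 on the
+                    // lhs broadcasts the matching 4-byte activation chunk across each lane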
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 0)));
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_00, 85)));
+
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 170)));
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_00, 255)));
+
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_02 ,_mm256_shuffle_epi32(rhs_vec_4567_02, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 0)));
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_02, 177) ,rhs_vec_4567_02, 170), _mm256_shuffle_epi32(lhs_vec_01, 85)));
+
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_03 ,_mm256_shuffle_epi32(rhs_vec_4567_03, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 170)));
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_03, 177) ,rhs_vec_4567_03, 170), _mm256_shuffle_epi32(lhs_vec_01, 255)));
+
+                    iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);
+
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 0)));
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_shuffle_epi32(lhs_vec_10, 85)));
+
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 170)));
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_shuffle_epi32(lhs_vec_10, 255)));
+
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_12 ,_mm256_shuffle_epi32(rhs_vec_4567_12, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 0)));
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_12, 177) ,rhs_vec_4567_12, 170), _mm256_shuffle_epi32(lhs_vec_11, 85)));
+
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_13 ,_mm256_shuffle_epi32(rhs_vec_4567_13, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 170)));
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_13, 177) ,rhs_vec_4567_13, 170), _mm256_shuffle_epi32(lhs_vec_11, 255)));
+
+                    iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);
+
+                    // Accumulate the iacc value for one sb
+                    __m256i iacc_sb = _mm256_add_epi32(iacc_0, iacc_1);
+
+                    // Broadcast the bsums of the two sub blocks of this Q8_K iteration across the vector
+                    // Multiply-Add the corresponding mins of Q4_Kx8 with the bsums
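+                    // each dword of q8s packs the two 16 bit sub-block sums for one sb iteration; the 4-byte
+                    // shift afterwards moves the next pair into place for the following iteration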
+                    __m256i q8s_sb = _mm256_shuffle_epi32(q8s, 0);
+                    __m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);
+                    q8s = _mm256_bsrli_epi128(q8s, 4);
+
+                    // Accumulate for the complete block
+                    iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);
+                    iacc_min_b = _mm256_add_epi32(iacc_min_b, iacc_min_sb);
+                }
+
+                // Multiply-Add with scale values for the complete super block
+                acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_b), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
+                acc_min_rows = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_min_b), _mm256_mul_ps(col_dmin_f32, row_scale_f32), acc_min_rows);
+
+            }
+
+            // Accumulated output values permuted so as to be stored in appropriate order post accumulation
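+            // acc_row accumulates d * d8 * sum(scale * q4 * q8) and acc_min_rows accumulates dmin * d8 * sum(min * bsum),
+            // so their difference below is the Q4_K dot product; the permute restores B0..B7 column order
+            // from the interleaved B0,B4,B1,B5,B2,B6,B3,B7 accumulation layout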
+            acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
+            _mm256_storeu_ps(s + (y * nr + x * 8), _mm256_sub_ps(acc_row, acc_min_rows));
+        }
+    }
+
+#else
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+#endif
+}
+
+void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+#if defined(__AVX2__)
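+    // same shared AVX kernel as Q4_0, but the lookup table is the IQ4_NL codebook (kvalues_iq4nl)
+    // instead of a plain sign-extension table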
+    __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_iq4nl));
+    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+
+    gemv_q4_b32_8x8_q8_0_lut_avx(n, s, bs, vx, vy, nr, nc, signextendlut);
+
+    return;
+#endif
+
+    ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__AVX2__)
+    // Lookup table to convert signed nibbles to signed bytes
+    __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
+    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+    // Shuffle masks to rearrange delta values to multiply with appropriate scales
+    __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
+    // Permute mask used for easier vector processing at later stages
+    __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
+
+    const __m256i m3b = _mm256_set1_epi8(3);
+    const __m128i m4b_sse = _mm_set1_epi8(0xF);
+
+    // Masks to get the appropriate scales
+    __m128i scalemask1 = _mm_set_epi8(14,14,6,6,12,12,4,4,10,10,2,2,8,8,0,0);
+    __m128i scalemask2 = _mm_set_epi8(15,15,7,7,13,13,5,5,11,11,3,3,9,9,1,1);
+
+    int64_t b_nb = n / QK_K;
+
+    const block_q2_Kx8 * b_ptr_start = (const block_q2_Kx8 *)vx;
+    const block_q8_K * a_ptr_start = (const block_q8_K *)vy;
+
+    // Process Q8_K blocks one by one
+    for (int64_t y = 0; y < nr; y++) {
+
+        // Pointers to LHS blocks of block_q8_K format
+        const block_q8_K * a_ptr = a_ptr_start + (y * nb);
+
+        // Take group of eight interleaved block_q2_K structures at each pass of the loop and perform dot product operation
+        for(int64_t x = 0; x < nc / 8; x++) {
+
+            // Pointers to RHS blocks
+            const block_q2_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
+
+            // Master FP accumulators
+            __m256 acc_row = _mm256_setzero_ps();
+            __m256 acc_min_rows = _mm256_setzero_ps();
+
+            for (int64_t b = 0; b < nb; b++) {
+
+                // Load the delta from block_q8_K and convert it to FP32
+                const __m256 row_scale_f32 = _mm256_set1_ps((a_ptr[b].d));
+
+                // Load the delta values for the 8 blocks interleaved in block_q2_Kx8
+                // col_scale_f32 rearranged so as to multiply with appropriate quants
+                const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, deltamask);
+                const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
+
+                __m256i iacc_b = _mm256_setzero_si256();
+                __m256i iacc_min_b = _mm256_setzero_si256();
+
+                // Processes eight sub blocks from each Q2_K in each iteration
+                for(int sb = 0; sb < QK_K / 128; sb++) {
+
+                    // Load the quantized values of eight sub blocks from the eight interleaved block_q2_K structures, in chunks of eight - B0,B1 .... B6,B7
+                    const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
+                    const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
+                    const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
+                    const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
+                    const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
+                    const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
+                    const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
+                    const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
+
+                    // 2-bit -> 8-bit
+                    // Values of the 0th,2nd,4th,6th sub blocks of eight block_q2_K structures for the sb loop
+                    const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m3b); //B00(0-7) B01(0-7) B02(0-7) B03(0-7)
+                    const __m256i rhs_vec_0123_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 2), m3b); //B20(0-7) B21(0-7) B22(0-7) B23(0-7)
+                    const __m256i rhs_vec_0123_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m3b); //B40(0-7) B41(0-7) B42(0-7) B43(0-7)
+                    const __m256i rhs_vec_0123_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 6), m3b); //B60(0-7) B61(0-7) B62(0-7) B63(0-7)
+
+                    const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m3b); //B04(0-7) B05(0-7) B06(0-7) B07(0-7)
+                    const __m256i rhs_vec_4567_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 2), m3b); //B24(0-7) B25(0-7) B26(0-7) B27(0-7)
+                    const __m256i rhs_vec_4567_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m3b); //B44(0-7) B45(0-7) B46(0-7) B47(0-7)
+                    const __m256i rhs_vec_4567_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 6), m3b); //B64(0-7) B65(0-7) B66(0-7) B67(0-7)
+
+                    const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m3b); //B00(8-15) B01(8-15) B02(8-15) B03(8-15)
+                    const __m256i rhs_vec_0123_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 2), m3b); //B20(8-15) B21(8-15) B22(8-15) B23(8-15)
+                    const __m256i rhs_vec_0123_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m3b); //B40(8-15) B41(8-15) B42(8-15) B43(8-15)
+                    const __m256i rhs_vec_0123_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 6), m3b); //B60(8-15) B61(8-15) B62(8-15) B63(8-15)
+
+                    const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m3b); //B04(8-15) B05(8-15) B06(8-15) B07(8-15)
+                    const __m256i rhs_vec_4567_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 2), m3b); //B24(8-15) B25(8-15) B26(8-15) B27(8-15)
+                    const __m256i rhs_vec_4567_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m3b); //B44(8-15) B45(8-15) B46(8-15) B47(8-15)
+                    const __m256i rhs_vec_4567_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 6), m3b); //B64(8-15) B65(8-15) B66(8-15) B67(8-15)
+
+                    // Values of the 1st,3rd,5th,7th sub blocks of eight block_q2_K structures for the sb loop
+                    const __m256i rhs_vec_0123_10 = _mm256_and_si256(rhs_raw_vec_0123_2, m3b); //B10(0-7) B11(0-7) B12(0-7) B13(0-7)
+                    const __m256i rhs_vec_0123_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 2), m3b); //B30(0-7) B31(0-7) B32(0-7) B33(0-7)
+                    const __m256i rhs_vec_0123_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m3b); //B50(0-7) B51(0-7) B52(0-7) B53(0-7)
+                    const __m256i rhs_vec_0123_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 6), m3b); //B70(0-7) B71(0-7) B72(0-7) B73(0-7)
+
+                    const __m256i rhs_vec_4567_10 = _mm256_and_si256(rhs_raw_vec_4567_2, m3b); //B14(0-7) B15(0-7) B16(0-7) B17(0-7)
+                    const __m256i rhs_vec_4567_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 2), m3b); //B34(0-7) B35(0-7) B36(0-7) B37(0-7)
+                    const __m256i rhs_vec_4567_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m3b); //B54(0-7) B55(0-7) B56(0-7) B57(0-7)
+                    const __m256i rhs_vec_4567_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 6), m3b); //B74(0-7) B75(0-7) B76(0-7) B77(0-7)
+
+                    const __m256i rhs_vec_0123_11 = _mm256_and_si256(rhs_raw_vec_0123_3, m3b); //B10(8-15) B11(8-15) B12(8-15) B13(8-15)
+                    const __m256i rhs_vec_0123_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 2), m3b); //B30(8-15) B31(8-15) B32(8-15) B33(8-15)
+                    const __m256i rhs_vec_0123_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m3b); //B50(8-15) B51(8-15) B52(8-15) B53(8-15)
+                    const __m256i rhs_vec_0123_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 6), m3b); //B70(8-15) B71(8-15) B72(8-15) B73(8-15)
+
+                    const __m256i rhs_vec_4567_11 = _mm256_and_si256(rhs_raw_vec_4567_3, m3b); //B14(8-15) B15(8-15) B16(8-15) B17(8-15)
+                    const __m256i rhs_vec_4567_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 2), m3b); //B34(8-15) B35(8-15) B36(8-15) B37(8-15)
+                    const __m256i rhs_vec_4567_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m3b); //B54(8-15) B55(8-15) B56(8-15) B57(8-15)
+                    const __m256i rhs_vec_4567_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 6), m3b); //B74(8-15) B75(8-15) B76(8-15) B77(8-15)
+
+                    // Scales and Mins of corresponding sub blocks from different Q2_K structures are stored together
+                    // s00 m00  s01 m01   s10 m10  s11 m11  s20 m20  s21 m21   s30 m30  s31 m31  s40 m40  s41 m41   s50 m50  s51 m51  s60 m60  s61 m61   s70 m70  s71 m71
+
+                    const __m128i mins_and_scales_01 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + sb * 64));
+                    const __m128i mins_and_scales_23 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 16 + sb * 64));
+                    const __m128i mins_and_scales_45 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 32 + sb * 64));
+                    const __m128i mins_and_scales_67 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 48 + sb * 64));
+
+                    // Extract scales which is lower half from mins_and_scales
+                    const __m128i scales_01 = _mm_and_si128(mins_and_scales_01, m4b_sse);
+                    const __m128i scales_23 = _mm_and_si128(mins_and_scales_23, m4b_sse);
+                    const __m128i scales_45 = _mm_and_si128(mins_and_scales_45, m4b_sse);
+                    const __m128i scales_67 = _mm_and_si128(mins_and_scales_67, m4b_sse);
+
+                    // Extract mins which is upper half from mins_and_scales
+                    const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_01, 4), m4b_sse));
+                    const __m256i mins_23 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_23, 4), m4b_sse));
+                    const __m256i mins_45 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_45, 4), m4b_sse));
+                    const __m256i mins_67 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_67, 4), m4b_sse));
+
+                    // Scales of sub blocks in the sb loop
+                    // Scales of the 0th sub block from each super block
+                    __m128i scales_rearrange_0 = _mm_shuffle_epi8(scales_01, scalemask1);
+                    __m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0);
+
+                    // Scales of the 1st sub block from each super block
+                    __m128i scales_rearrange_1 = _mm_shuffle_epi8(scales_01, scalemask2);
+                    __m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1);
+
+                    // Scales of the 2nd sub block from each super block
+                    __m128i scales_rearrange_2 = _mm_shuffle_epi8(scales_23, scalemask1);
+                    __m256i scales_2 = _mm256_cvtepu8_epi16(scales_rearrange_2);
+
+                    // Scales of the 3rd sub block from each super block
+                    __m128i scales_rearrange_3 = _mm_shuffle_epi8(scales_23, scalemask2);
+                    __m256i scales_3 = _mm256_cvtepu8_epi16(scales_rearrange_3);
+
+                    // Scales of the 4th sub block from each super block
+                    __m128i scales_rearrange_4 = _mm_shuffle_epi8(scales_45, scalemask1);
+                    __m256i scales_4 = _mm256_cvtepu8_epi16(scales_rearrange_4);
+
+                    // Scales of the 5th sub block from each super block
+                    __m128i scales_rearrange_5 = _mm_shuffle_epi8(scales_45, scalemask2);
+                    __m256i scales_5 = _mm256_cvtepu8_epi16(scales_rearrange_5);
+
+                    // Scales of the 6th sub block from each super block
+                    __m128i scales_rearrange_6 = _mm_shuffle_epi8(scales_67, scalemask1);
+                    __m256i scales_6 = _mm256_cvtepu8_epi16(scales_rearrange_6);
+
+                    // Scales of the 7th sub block from each super block
+                    __m128i scales_rearrange_7 = _mm_shuffle_epi8(scales_67, scalemask2);
+                    __m256i scales_7 = _mm256_cvtepu8_epi16(scales_rearrange_7);
+
+                    // Load the sub block values corresponding to sb in block_q8_K in batches of 16 bytes and replicate the same across 256 bit vector
+                    __m256i lhs_vec_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + sb * 128)));
+                    __m256i lhs_vec_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16 + sb * 128)));
+                    __m256i lhs_vec_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32 + sb * 128)));
+                    __m256i lhs_vec_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48 + sb * 128)));
+                    __m256i lhs_vec_4 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 64 + sb * 128)));
+                    __m256i lhs_vec_5 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 80 + sb * 128)));
+                    __m256i lhs_vec_6 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 96 + sb * 128)));
+                    __m256i lhs_vec_7 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 112 + sb * 128)));
+
+                    lhs_vec_0 = _mm256_permute2f128_si256(lhs_vec_0, lhs_vec_0, 0);
+                    lhs_vec_1 = _mm256_permute2f128_si256(lhs_vec_1, lhs_vec_1, 0);
+                    lhs_vec_2 = _mm256_permute2f128_si256(lhs_vec_2, lhs_vec_2, 0);
+                    lhs_vec_3 = _mm256_permute2f128_si256(lhs_vec_3, lhs_vec_3, 0);
+                    lhs_vec_4 = _mm256_permute2f128_si256(lhs_vec_4, lhs_vec_4, 0);
+                    lhs_vec_5 = _mm256_permute2f128_si256(lhs_vec_5, lhs_vec_5, 0);
+                    lhs_vec_6 = _mm256_permute2f128_si256(lhs_vec_6, lhs_vec_6, 0);
+                    lhs_vec_7 = _mm256_permute2f128_si256(lhs_vec_7, lhs_vec_7, 0);
+
+                    __m256i iacc_0 = _mm256_setzero_si256();
+                    __m256i iacc_1 = _mm256_setzero_si256();
+                    __m256i iacc_2 = _mm256_setzero_si256();
+                    __m256i iacc_3 = _mm256_setzero_si256();
+                    __m256i iacc_4 = _mm256_setzero_si256();
+                    __m256i iacc_5 = _mm256_setzero_si256();
+                    __m256i iacc_6 = _mm256_setzero_si256();
+                    __m256i iacc_7 = _mm256_setzero_si256();
+
+                    // Dot product done within 32 bit lanes and accumulated in the same vector
+                    // First done for 0th sub block and then for seven (1st - 7th) other sub blocks processed for each sb (sb < QK_K/128 loop)
+                    // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
+                    // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
+                    // B0(8-11) B4(8-11) B1(8-11) B5(8-11) B2(8-11) B6(8-11) B3(8-11) B7(8-11) with A0(8-11)
+                    // B0(12-15) B4(12-15) B1(12-15) B5(12-15) B2(12-15) B6(12-15) B3(12-15) B7(12-15) with A0(12-15)
+
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)));
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)));
+
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)));
+                    iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)));
+
+                    iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);
+
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)));
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)));
+
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)));
+                    iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)));
+
+                    iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);
+
+                    iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_20 ,_mm256_shuffle_epi32(rhs_vec_4567_20, 177), 170), _mm256_shuffle_epi32(lhs_vec_2, 0)));
+                    iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_20, 177) ,rhs_vec_4567_20, 170), _mm256_shuffle_epi32(lhs_vec_2, 85)));
+
+                    iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_21 ,_mm256_shuffle_epi32(rhs_vec_4567_21, 177), 170), _mm256_shuffle_epi32(lhs_vec_2, 170)));
+                    iacc_2 = _mm256_add_epi16(iacc_2, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_21, 177) ,rhs_vec_4567_21, 170), _mm256_shuffle_epi32(lhs_vec_2, 255)));
+
+                    iacc_2 = _mm256_madd_epi16(iacc_2, scales_2);
+
+                    iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_30 ,_mm256_shuffle_epi32(rhs_vec_4567_30, 177), 170), _mm256_shuffle_epi32(lhs_vec_3, 0)));
+                    iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_30, 177) ,rhs_vec_4567_30, 170), _mm256_shuffle_epi32(lhs_vec_3, 85)));
+
+                    iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_31 ,_mm256_shuffle_epi32(rhs_vec_4567_31, 177), 170), _mm256_shuffle_epi32(lhs_vec_3, 170)));
+                    iacc_3 = _mm256_add_epi16(iacc_3, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_31, 177) ,rhs_vec_4567_31, 170), _mm256_shuffle_epi32(lhs_vec_3, 255)));
+
+                    iacc_3 = _mm256_madd_epi16(iacc_3, scales_3);
+
+                    iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_40 ,_mm256_shuffle_epi32(rhs_vec_4567_40, 177), 170), _mm256_shuffle_epi32(lhs_vec_4, 0)));
+                    iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_40, 177) ,rhs_vec_4567_40, 170), _mm256_shuffle_epi32(lhs_vec_4, 85)));
+
+                    iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_41 ,_mm256_shuffle_epi32(rhs_vec_4567_41, 177), 170), _mm256_shuffle_epi32(lhs_vec_4, 170)));
+                    iacc_4 = _mm256_add_epi16(iacc_4, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_41, 177) ,rhs_vec_4567_41, 170), _mm256_shuffle_epi32(lhs_vec_4, 255)));
+
+                    iacc_4 = _mm256_madd_epi16(iacc_4, scales_4);
+
+                    iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_50 ,_mm256_shuffle_epi32(rhs_vec_4567_50, 177), 170), _mm256_shuffle_epi32(lhs_vec_5, 0)));
+                    iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_50, 177) ,rhs_vec_4567_50, 170), _mm256_shuffle_epi32(lhs_vec_5, 85)));
+
+                    iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_51 ,_mm256_shuffle_epi32(rhs_vec_4567_51, 177), 170), _mm256_shuffle_epi32(lhs_vec_5, 170)));
+                    iacc_5 = _mm256_add_epi16(iacc_5, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_51, 177) ,rhs_vec_4567_51, 170), _mm256_shuffle_epi32(lhs_vec_5, 255)));
+
+                    iacc_5 = _mm256_madd_epi16(iacc_5, scales_5);
+
+                    iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_60 ,_mm256_shuffle_epi32(rhs_vec_4567_60, 177), 170), _mm256_shuffle_epi32(lhs_vec_6, 0)));
+                    iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_60, 177) ,rhs_vec_4567_60, 170), _mm256_shuffle_epi32(lhs_vec_6, 85)));
+
+                    iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_61 ,_mm256_shuffle_epi32(rhs_vec_4567_61, 177), 170), _mm256_shuffle_epi32(lhs_vec_6, 170)));
+                    iacc_6 = _mm256_add_epi16(iacc_6, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_61, 177) ,rhs_vec_4567_61, 170), _mm256_shuffle_epi32(lhs_vec_6, 255)));
+
+                    iacc_6 = _mm256_madd_epi16(iacc_6, scales_6);
+
+                    iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_70 ,_mm256_shuffle_epi32(rhs_vec_4567_70, 177), 170), _mm256_shuffle_epi32(lhs_vec_7, 0)));
+                    iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_70, 177) ,rhs_vec_4567_70, 170), _mm256_shuffle_epi32(lhs_vec_7, 85)));
+
+                    iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_71 ,_mm256_shuffle_epi32(rhs_vec_4567_71, 177), 170), _mm256_shuffle_epi32(lhs_vec_7, 170)));
+                    iacc_7 = _mm256_add_epi16(iacc_7, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_71, 177) ,rhs_vec_4567_71, 170), _mm256_shuffle_epi32(lhs_vec_7, 255)));
+
+                    iacc_7 = _mm256_madd_epi16(iacc_7, scales_7);
+
+                    // Accumulate the iacc value for one sb
+                    __m256i iacc_sb = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_0, iacc_1), _mm256_add_epi32(iacc_2, iacc_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_4, iacc_5), _mm256_add_epi32(iacc_6, iacc_7)));
+
+                    __m128i q8sums = _mm_loadu_si128((const __m128i *)(a_ptr[b].bsums + sb * 8));
+                    __m256i q8s = _mm256_castsi128_si256(q8sums);
+                    q8s = _mm256_permute2f128_si256(q8s, q8s, 0);
+
+                    // Broadcast the bsums of the two corresponding sub blocks of Q8_K
+                    // Multiply-Add the corresponding mins of Q2_Kx8 with the bsums
+                    __m256i iacc_min_sb_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 0), mins_01);
+                    __m256i iacc_min_sb_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 85), mins_23);
+                    __m256i iacc_min_sb_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 170), mins_45);
+                    __m256i iacc_min_sb_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(q8s, 255), mins_67);
+
+                    __m256i iacc_min_sb = _mm256_add_epi32(_mm256_add_epi32(iacc_min_sb_01, iacc_min_sb_23), _mm256_add_epi32(iacc_min_sb_45, iacc_min_sb_67));
+
+                    // Accumulate for the complete block
+                    iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);
+                    iacc_min_b = _mm256_add_epi32(iacc_min_b, iacc_min_sb);
+                }
+
+                // Multiply-Add with scale values for complete super block
+                acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_b), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
+                acc_min_rows = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_min_b), _mm256_mul_ps(col_dmin_f32, row_scale_f32), acc_min_rows);
+            }
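+            // Per the Q2_K layout (value = d * scale * q - dmin * min), the mins contribution accumulated in
+            // acc_min_rows is subtracted from the scaled dot products in acc_row when storing below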
+            // Permute the accumulated output values so they are stored in the appropriate order after accumulation
+            acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
+            _mm256_storeu_ps(s + (y * nr + x * 8), _mm256_sub_ps(acc_row, acc_min_rows));
+        }
+    }
+#else
+
+    ggml_gemv_q2_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+
+#endif
+}
+
+void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+#if defined(__AVX2__) || defined(__AVX512F__)
+    {
+        // Lookup table to convert signed nibbles to signed bytes
+        __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
+        signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
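+        // Entry i of the table is the sign-extended value of nibble i (0..7 -> 0..7, 8..15 -> -8..-1),
+        // so a byte shuffle indexed by the 4-bit quants converts them to signed 8-bit values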
+
+        gemm_q4_b32_8x8_q8_0_lut_avx(n, s, bs, vx, vy, nr, nc, signextendlut);
+
+        return;
+    }
+#endif // defined(__AVX2__) || defined(__AVX512F__)
+
+    ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__AVX2__) || defined(__AVX512F__)
+    const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 * ) vx;
+    const block_q8_Kx4 * a_ptr_start = (const block_q8_Kx4 * ) vy;
+    int64_t b_nb = n / QK_K;
+    int64_t y = 0;
+
+    // Mask used to extract the 4-bit nibbles from the packed bytes
+    const __m256i m4b = _mm256_set1_epi8(0x0F);
+    // Permute mask used for easier vector processing at later stages
+    __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
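+    // requiredOrder swaps the two 128-bit halves of a 256-bit vector; combined with the blends further below
+    // it regroups the interleaved columns into the {0,1,4,5} and {2,3,6,7} sets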
+    int64_t xstart = 0;
+    int anr = nr - nr % 16; // Used to align nr with boundary of 16
+#ifdef __AVX512F__
+    int anc = nc - nc % 16; // Used to align nc with boundary of 16
+    // Mask to extract nibbles from packed bytes, expanded to 512 bit length
+    const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
+    // Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
+    for (; y < anr / 4; y += 4) {
+
+        const block_q8_Kx4 * a_ptrs[4];
+
+        a_ptrs[0] = a_ptr_start + (y * nb);
+        for (int i = 0; i < 3; ++i) {
+            a_ptrs[i + 1] = a_ptrs[i] + nb;
+        }
+
+        // Take group of eight block_q4_Kx8 structures at each pass of the loop and perform dot product operation
+        for (int64_t x = 0; x < anc / 8; x += 2) {
+
+            const block_q4_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
+            const block_q4_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
+
+            // Master FP accumulators
+            __m512 acc_rows[16];
+            for (int i = 0; i < 16; i++) {
+                acc_rows[i] = _mm512_setzero_ps();
+            }
+
+            __m512 acc_min_rows[16];
+            for (int i = 0; i < 16; i++) {
+                acc_min_rows[i] = _mm512_setzero_ps();
+            }
+
+            // For super block
+            for (int64_t b = 0; b < nb; b++) {
+                // Scale values - Load the sixteen scale values from two block_q4_Kx8 structures
+                const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+
+                // dmin values - Load the sixteen dmin values from two block_q4_Kx8 structures
+                const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin);
+
+                // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256));
+
+                    const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256));
+                    const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256));
+                    const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256));
+                    const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256));
+                    const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256));
+                    const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256));
+                    const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256));
+                    const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256));
+
+                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+                    const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
+                    const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
+
+                    const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
+                    const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
+                    const __m256i rhs_raw_mat_89CD_2 = _mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240);
+                    const __m256i rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), rhs_raw_mat_CDEF_3, 240);
+
+                    const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
+                    const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
+
+                    const __m512i rhs_raw_mat_014589CD_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1);
+                    const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1);
+
+                    // 4-bit -> 8-bit
+                    const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7)
+                    const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7)
+                    const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15)
+                    const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15)
+
+                    const __m512i rhs_mat_014589CD_02 = _mm512_and_si512(rhs_raw_mat_014589CD_2, m4bexpanded); //B00(16-23) B01(16-23) B04(16-23) B05(16-23) B08(16-23) B09(16-23) B0C(16-23) B0D(16-23)
+                    const __m512i rhs_mat_2367ABEF_02 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2, m4bexpanded); //B02(16-23) B03(16-23) B06(16-23) B07(16-23) B0A(16-23) B0B(16-23) B0E(16-23) B0F(16-23)
+                    const __m512i rhs_mat_014589CD_03 = _mm512_and_si512(rhs_raw_mat_014589CD_3, m4bexpanded); //B00(24-31) B01(24-31) B04(24-31) B05(24-31) B08(24-31) B09(24-31) B0C(24-31) B0D(24-31)
+                    const __m512i rhs_mat_2367ABEF_03 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3, m4bexpanded); //B02(24-31) B03(24-31) B06(24-31) B07(24-31) B0A(24-31) B0B(24-31) B0E(24-31) B0F(24-31)
+
+                    const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7)
+                    const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7)
+                    const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15)
+                    const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15)
+
+                    const __m512i rhs_mat_014589CD_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m4bexpanded); //B10(16-23) B11(16-23) B14(16-23) B15(16-23) B18(16-23) B19(16-23) B1C(16-23) B1D(16-23)
+                    const __m512i rhs_mat_2367ABEF_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m4bexpanded); //B12(16-23) B13(16-23) B16(16-23) B17(16-23) B1A(16-23) B1B(16-23) B1E(16-23) B1F(16-23)
+                    const __m512i rhs_mat_014589CD_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m4bexpanded); //B10(24-31) B11(24-31) B14(24-31) B15(24-31) B18(24-31) B19(24-31) B1C(24-31) B1D(24-31)
+                    const __m512i rhs_mat_2367ABEF_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m4bexpanded); //B12(24-31) B13(24-31) B16(24-31) B17(24-31) B1A(24-31) B1B(24-31) B1E(24-31) B1F(24-31)
+
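+                    // The two shuffle patterns below pick complementary 4-byte groups: pattern one (control 136)
+                    // keeps the even dword positions and pattern two (control 221) keeps the odd ones, so their
+                    // dot products together cover all 32 bytes of each sub block row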
+                    // Shuffle pattern one - right side input
+                    const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3)
+                    const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3)
+                    const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11)
+                    const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11)
+                    const __m512i rhs_mat_014589CD_02_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19) B08(16-19) B09(16-19) B08(16-19) B09(16-19) B0C(16-19) B0D(16-19) B0C(16-19) B0D(16-19)
+                    const __m512i rhs_mat_2367ABEF_02_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19) B0A(16-19) B0B(16-19) B0A(16-19) B0B(16-19) B0E(16-19) B0F(16-19) B0E(16-19) B0F(16-19)
+                    const __m512i rhs_mat_014589CD_03_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27) B08(24-27) B09(24-27) B08(24-27) B09(24-27) B0C(24-27) B0D(24-27) B0C(24-27) B0D(24-27)
+                    const __m512i rhs_mat_2367ABEF_03_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27) B0A(24-27) B0B(24-27) B0A(24-27) B0B(24-27) B0E(24-27) B0F(24-27) B0E(24-27) B0F(24-27)
+
+                    const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3)
+                    const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3)
+                    const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11)
+                    const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11)
+                    const __m512i rhs_mat_014589CD_12_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19) B18(16-19) B19(16-19) B18(16-19) B19(16-19) B1C(16-19) B1D(16-19) B1C(16-19) B1D(16-19)
+                    const __m512i rhs_mat_2367ABEF_12_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19) B1A(16-19) B1B(16-19) B1A(16-19) B1B(16-19) B1E(16-19) B1F(16-19) B1E(16-19) B1F(16-19)
+                    const __m512i rhs_mat_014589CD_13_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27) B18(24-27) B19(24-27) B18(24-27) B19(24-27) B1C(24-27) B1D(24-27) B1C(24-27) B1D(24-27)
+                    const __m512i rhs_mat_2367ABEF_13_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27) B1A(24-27) B1B(24-27) B1A(24-27) B1B(24-27) B1E(24-27) B1F(24-27) B1E(24-27) B1F(24-27)
+
+                    // Shuffle pattern two - right side input
+                    const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7)
+                    const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7)
+                    const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15)
+                    const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15)
+                    const __m512i rhs_mat_014589CD_02_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23) B08(20-23) B09(20-23) B08(20-23) B09(20-23) B0C(20-23) B0D(20-23) B0C(20-23) B0D(20-23)
+                    const __m512i rhs_mat_2367ABEF_02_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23) B0A(20-23) B0B(20-23) B0A(20-23) B0B(20-23) B0E(20-23) B0F(20-23) B0E(20-23) B0F(20-23)
+                    const __m512i rhs_mat_014589CD_03_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31) B08(28-31) B09(28-31) B08(28-31) B09(28-31) B0C(28-31) B0D(28-31) B0C(28-31) B0D(28-31)
+                    const __m512i rhs_mat_2367ABEF_03_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31) B0A(28-31) B0B(28-31) B0A(28-31) B0B(28-31) B0E(28-31) B0F(28-31) B0E(28-31) B0F(28-31)
+
+                    const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7)
+                    const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7)
+                    const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15)
+                    const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15)
+                    const __m512i rhs_mat_014589CD_12_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23) B18(20-23) B19(20-23) B18(20-23) B19(20-23) B1C(20-23) B1D(20-23) B1C(20-23) B1D(20-23)
+                    const __m512i rhs_mat_2367ABEF_12_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23) B1A(20-23) B1B(20-23) B1A(20-23) B1B(20-23) B1E(20-23) B1F(20-23) B1E(20-23) B1F(20-23)
+                    const __m512i rhs_mat_014589CD_13_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31) B18(28-31) B19(28-31) B18(28-31) B19(28-31) B1C(28-31) B1D(28-31) B1C(28-31) B1D(28-31)
+                    const __m512i rhs_mat_2367ABEF_13_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31) B1A(28-31) B1B(28-31) B1A(28-31) B1B(28-31) B1E(28-31) B1F(28-31) B1E(28-31) B1F(28-31)
+
+                    uint32_t utmp_00[4], utmp_01[4], utmp_10[4], utmp_11[4];
+
+                    // Scales and mins of the corresponding sub blocks from different Q4_K structures are stored together
+                    // The block below extracts, for example, the first sub block's scales and mins from the different Q4_K structures for this sb iteration
+                    memcpy(utmp_00, b_ptr_0[b].scales + 24 * sb, 12);
+                    utmp_00[3] = ((utmp_00[2] >> 4) & kmask2) | (((utmp_00[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_00 = utmp_00[1] & kmask1;
+                    utmp_00[1] = (utmp_00[2] & kmask2) | (((utmp_00[0] >> 6) & kmask3) << 4);
+                    utmp_00[2] = uaux_00;
+                    utmp_00[0] &= kmask1;
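+                    // After unpacking, utmp_00[0..1] hold the eight 6-bit scales and utmp_00[2..3] hold the
+                    // eight 6-bit mins of this sub block (one per interleaved Q4_K column)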
+
+                    // The block below extracts the second sub block's scales and mins from the different Q4_K structures for this sb iteration
+                    memcpy(utmp_01, b_ptr_0[b].scales + 12 + sb * 24, 12);
+                    utmp_01[3] = ((utmp_01[2] >> 4) & kmask2) | (((utmp_01[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_01 = utmp_01[1] & kmask1;
+                    utmp_01[1] = (utmp_01[2] & kmask2) | (((utmp_01[0] >> 6) & kmask3) << 4);
+                    utmp_01[2] = uaux_01;
+                    utmp_01[0] &= kmask1;
+
+                    memcpy(utmp_10, b_ptr_1[b].scales + sb * 24, 12);
+                    utmp_10[3] = ((utmp_10[2] >> 4) & kmask2) | (((utmp_10[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_10 = utmp_10[1] & kmask1;
+                    utmp_10[1] = (utmp_10[2] & kmask2) | (((utmp_10[0] >> 6) & kmask3) << 4);
+                    utmp_10[2] = uaux_10;
+                    utmp_10[0] &= kmask1;
+
+                    // The block below extracts the second sub block's scales and mins from the different Q4_K structures for this sb iteration
+                    memcpy(utmp_11, b_ptr_1[b].scales + 12 + sb * 24, 12);
+                    utmp_11[3] = ((utmp_11[2] >> 4) & kmask2) | (((utmp_11[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_11 = utmp_11[1] & kmask1;
+                    utmp_11[1] = (utmp_11[2] & kmask2) | (((utmp_11[0] >> 6) & kmask3) << 4);
+                    utmp_11[2] = uaux_11;
+                    utmp_11[0] &= kmask1;
+
+                    // Scales of first sub block in the sb loop
+                    const __m256i mins_and_scales_0 = _mm256_set_epi32(utmp_10[3], utmp_10[2], utmp_10[1], utmp_10[0], utmp_00[3], utmp_00[2], utmp_00[1], utmp_00[0]);
+                    const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
+
+                    // Scales of second sub block in the sb loop
+                    const __m256i mins_and_scales_1 = _mm256_set_epi32(utmp_11[3], utmp_11[2], utmp_11[1], utmp_11[0], utmp_01[3], utmp_01[2], utmp_01[1], utmp_01[0]);
+                    const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
+
+                    // Mins of first and second sub block of Q4_K block are arranged side by side
+                    const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(_mm256_shuffle_epi32(mins_and_scales_0, 78), _mm256_shuffle_epi32(mins_and_scales_1, 78)));
+
+                    const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238);
+
+                    const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238);
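+                    // scale_014589CD_* carry the scales for the {0,1,4,5,8,9,C,D} column group and
+                    // scale_2367ABEF_* those for the {2,3,6,7,A,B,E,F} group, matching the rhs grouping above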
+
+                    for (int rp = 0; rp < 4; rp++) {
+
+                        // Load the four block_q8_K quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
+                        // Loaded as a set of 128 bit vectors, repeated into a 256 bit vector and then repeated again into a 512 bit vector
+                        __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb)));
+                        __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0);
+                        __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17);
+                        __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb)));
+                        __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0);
+                        __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17);
+                        __m256i lhs_mat_ymm_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb)));
+                        __m256i lhs_mat_ymm_01_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 0);
+                        __m256i lhs_mat_ymm_23_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 17);
+                        __m256i lhs_mat_ymm_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb)));
+                        __m256i lhs_mat_ymm_01_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 0);
+                        __m256i lhs_mat_ymm_23_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 17);
+                        __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb)));
+                        __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0);
+                        __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17);
+                        __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb)));
+                        __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0);
+                        __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17);
+                        __m256i lhs_mat_ymm_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb)));
+                        __m256i lhs_mat_ymm_01_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 0);
+                        __m256i lhs_mat_ymm_23_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 17);
+                        __m256i lhs_mat_ymm_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb)));
+                        __m256i lhs_mat_ymm_01_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 0);
+                        __m256i lhs_mat_ymm_23_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 17);
+
+                        __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1);
+                        __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1);
+                        __m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1);
+                        __m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1);
+                        __m512i lhs_mat_01_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_02), lhs_mat_ymm_01_02, 1);
+                        __m512i lhs_mat_23_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_02), lhs_mat_ymm_23_02, 1);
+                        __m512i lhs_mat_01_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_03), lhs_mat_ymm_01_03, 1);
+                        __m512i lhs_mat_23_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_03), lhs_mat_ymm_23_03, 1);
+
+                        __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1);
+                        __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1);
+                        __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1);
+                        __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1);
+                        __m512i lhs_mat_01_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_12), lhs_mat_ymm_01_12, 1);
+                        __m512i lhs_mat_23_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_12), lhs_mat_ymm_23_12, 1);
+                        __m512i lhs_mat_01_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_13), lhs_mat_ymm_01_13, 1);
+                        __m512i lhs_mat_23_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_13), lhs_mat_ymm_23_13, 1);
+
+                        // Load the bsums - four bsums (for the two sub blocks) are loaded for each of the different Q8_K blocks
+                        __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));
+                        __m256i lhs_bsums_hsum_ymm_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
+                        lhs_bsums_hsum_ymm_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_ymm_0123_01, lhs_bsums_hsum_ymm_0123_01, 0);
+                        __m512i lhs_bsums_hsum_0123_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_hsum_ymm_0123_01), lhs_bsums_hsum_ymm_0123_01, 1);
+
+                        // Shuffle pattern one - left side input
+                        const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
+                        const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3)
+                        const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
+                        const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11)
+                        const __m512i lhs_mat_01_02_sp1 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
+                        const __m512i lhs_mat_23_02_sp1 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)160); //A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19)
+                        const __m512i lhs_mat_01_03_sp1 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
+                        const __m512i lhs_mat_23_03_sp1 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)160); //A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27)
+
+                        const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
+                        const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3)
+                        const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
+                        const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11)
+                        const __m512i lhs_mat_01_12_sp1 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
+                        const __m512i lhs_mat_23_12_sp1 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)160); //A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19)
+                        const __m512i lhs_mat_01_13_sp1 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
+                        const __m512i lhs_mat_23_13_sp1 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)160); //A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27)
+
+                        const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
+                        const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7)
+                        const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
+                        const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15)
+                        const __m512i lhs_mat_01_02_sp2 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
+                        const __m512i lhs_mat_23_02_sp2 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)245); //A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23)
+                        const __m512i lhs_mat_01_03_sp2 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
+                        const __m512i lhs_mat_23_03_sp2 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)245); //A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31)
+
+                        const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
+                        const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7)
+                        const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
+                        const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15)
+                        const __m512i lhs_mat_01_12_sp2 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
+                        const __m512i lhs_mat_23_12_sp2 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)245); //A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23)
+                        const __m512i lhs_mat_01_13_sp2 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
+                        const __m512i lhs_mat_23_13_sp2 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)245); //A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31)
+
+                        // The values arranged in the shuffle patterns are combined with a dot product operation within each 32 bit lane, i.e. corresponding bytes are multiplied and their products added together within the lane
+                        __m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1));
+                        __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1));
+                        __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1));
+                        __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1));
+                        __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1));
+                        __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1));
+                        __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1));
+                        __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1));
+
+                        __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2));
+                        __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2));
+                        __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2));
+                        __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2));
+                        __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2));
+                        __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2));
+                        __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2));
+                        __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2));
+
+                        // Outputs of both shuffle patterns are added in order to sum the dot product outputs of all 32 values in the block
+                        __m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
+                        __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
+                        __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
+                        __m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
+
+                        __m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
+                        __m512i iacc_mat_01_1 = _mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
+                        __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
+                        __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
+
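+                        // Multiply the 16-bit dot sums with the 16-bit sub block scales; madd widens and
+                        // accumulates adjacent pairs into 32-bit integers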
+                        iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0);
+                        iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0);
+                        iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0);
+                        iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0);
+
+                        iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1);
+                        iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1);
+                        iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1);
+                        iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1);
+
+                        // Straighten out to make 4 row vectors per sub block (the two sub blocks' results are accumulated together in the next step)
+                        __m512i iacc_row_0_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_0, _mm512_shuffle_epi32(iacc_mat_01_0, (_MM_PERM_ENUM)78));
+                        __m512i iacc_row_1_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_0, (_MM_PERM_ENUM)78), iacc_mat_01_0);
+                        __m512i iacc_row_2_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_0, _mm512_shuffle_epi32(iacc_mat_11_0, (_MM_PERM_ENUM)78));
+                        __m512i iacc_row_3_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_0, (_MM_PERM_ENUM)78), iacc_mat_11_0);
+                        __m512i iacc_row_0_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_1, _mm512_shuffle_epi32(iacc_mat_01_1, (_MM_PERM_ENUM)78));
+                        __m512i iacc_row_1_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_1, (_MM_PERM_ENUM)78), iacc_mat_01_1);
+                        __m512i iacc_row_2_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_1, _mm512_shuffle_epi32(iacc_mat_11_1, (_MM_PERM_ENUM)78));
+                        __m512i iacc_row_3_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_1, (_MM_PERM_ENUM)78), iacc_mat_11_1);
+
+                        __m512i iacc_row_0 = _mm512_add_epi32(iacc_row_0_0, iacc_row_0_1);
+                        __m512i iacc_row_1 = _mm512_add_epi32(iacc_row_1_0, iacc_row_1_1);
+                        __m512i iacc_row_2 = _mm512_add_epi32(iacc_row_2_0, iacc_row_2_1);
+                        __m512i iacc_row_3 = _mm512_add_epi32(iacc_row_3_0, iacc_row_3_1);
+
+                        // Load the d scale values for all 4 Q8_K blocks and repeat them across lanes
+                        const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);
+                        const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);
+                        const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1);
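+                        // The shuffle_ps immediates 0/85/170/255 below broadcast element 0/1/2/3 of every 128 bit lane, i.e. the d scale of Q8_K row 0/1/2/3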
+
+                        // Multiply with the appropriate scales and accumulate (for both d and dmin) below
+                        acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
+                        acc_rows[rp * 4  + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
+                        acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
+                        acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
+
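+                        // dmin correction: Q4_K dequantizes as d*scale*q - dmin*min, so the sub block mins times the Q8_K bsums give the term subtracted at the store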
+                        __m512i iacc_row_min_0 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)0), mins_01);
+                        __m512i iacc_row_min_1 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)85), mins_01);
+                        __m512i iacc_row_min_2 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)170), mins_01);
+                        __m512i iacc_row_min_3 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)255), mins_01);
+
+                        acc_min_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]);
+                        acc_min_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);
+                        acc_min_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]);
+                        acc_min_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]);
+                    }
+                }
+            }
+            // Store the accumulated values
+            for (int i = 0; i < 16; i++) {
+                _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i]));
+            }
+        }
+    }
+
+    for (; y < nr / 4; y++) {
+
+        const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);
+
+        // Take a group of two block_q4_Kx8 structures at each pass of the loop and perform the dot product operation
+        for (int64_t x = 0; x < anc / 8; x += 2) {
+
+            const block_q4_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
+            const block_q4_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
+
+            // Master FP accumulators
+            __m512 acc_rows[4];
+            for (int i = 0; i < 4; i++) {
+                acc_rows[i] = _mm512_setzero_ps();
+            }
+
+            __m512 acc_min_rows[4];
+            for (int i = 0; i < 4; i++) {
+                acc_min_rows[i] = _mm512_setzero_ps();
+            }
+
+            // For super block
+            for (int64_t b = 0; b < nb; b++) {
+                // Scale values - Load the sixteen scale values from two block_q4_Kx8 structures
+                const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+
+                // dmin values - Load the sixteen dmin values from two block_q4_Kx8 structures
+                const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin);
+
+                // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256));
+
+                    const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256));
+                    const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256));
+                    const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256));
+                    const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256));
+                    const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256));
+                    const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256));
+                    const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256));
+                    const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256));
+
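+                    // Blend mask 240 keeps 32 bit elements 0-3 of the first operand and takes elements 4-7 from the permuted second operand, giving the 0145 / 2367 (and 89CD / ABEF) packing used below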
+                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+                    const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
+                    const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
+
+                    const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
+                    const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
+                    const __m256i rhs_raw_mat_89CD_2 = _mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240);
+                    const __m256i rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), rhs_raw_mat_CDEF_3, 240);
+
+                    const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
+                    const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
+
+                    const __m512i rhs_raw_mat_014589CD_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1);
+                    const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1);
+
+                    // 4-bit -> 8-bit
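+                    // Low nibbles hold the first sub block (B0x) and the nibbles shifted down by 4 the second sub block (B1x) of the pair processed in this sb iteration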
+                    const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7)
+                    const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7)
+                    const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15)
+                    const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15)
+
+                    const __m512i rhs_mat_014589CD_02 = _mm512_and_si512(rhs_raw_mat_014589CD_2, m4bexpanded); //B00(16-23) B01(16-23) B04(16-23) B05(16-23) B08(16-23) B09(16-23) B0C(16-23) B0D(16-23)
+                    const __m512i rhs_mat_2367ABEF_02 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2, m4bexpanded); //B02(16-23) B03(16-23) B06(16-23) B07(16-23) B0A(16-23) B0B(16-23) B0E(16-23) B0F(16-23)
+                    const __m512i rhs_mat_014589CD_03 = _mm512_and_si512(rhs_raw_mat_014589CD_3, m4bexpanded); //B00(24-31) B01(24-31) B04(24-31) B05(24-31) B08(24-31) B09(24-31) B0C(24-31) B0D(24-31)
+                    const __m512i rhs_mat_2367ABEF_03 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3, m4bexpanded); //B02(24-31) B03(24-31) B06(24-31) B07(24-31) B0A(24-31) B0B(24-31) B0E(24-31) B0F(24-31)
+
+                    const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7)
+                    const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7)
+                    const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15)
+                    const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15)
+
+                    const __m512i rhs_mat_014589CD_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m4bexpanded); //B10(16-23) B11(16-23) B14(16-23) B15(16-23) B18(16-23) B19(16-23) B1C(16-23) B1D(16-23)
+                    const __m512i rhs_mat_2367ABEF_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m4bexpanded); //B12(16-23) B13(16-23) B16(16-23) B17(16-23) B1A(16-23) B1B(16-23) B1E(16-23) B1F(16-23)
+                    const __m512i rhs_mat_014589CD_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m4bexpanded); //B10(24-31) B11(24-31) B14(24-31) B15(24-31) B18(24-31) B19(24-31) B1C(24-31) B1D(24-31)
+                    const __m512i rhs_mat_2367ABEF_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m4bexpanded); //B12(24-31) B13(24-31) B16(24-31) B17(24-31) B1A(24-31) B1B(24-31) B1E(24-31) B1F(24-31)
+
+                    // Shuffle pattern one - right side input
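+                    // Immediate 136 = 0b10001000 repeats 32 bit elements {0,2} of each 128 bit lane; pattern two below uses 221 = 0b11011101 for elements {1,3}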
+                    const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3)
+                    const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3)
+                    const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11)
+                    const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11)
+                    const __m512i rhs_mat_014589CD_02_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19) B08(16-19) B09(16-19) B08(16-19) B09(16-19) B0C(16-19) B0D(16-19) B0C(16-19) B0D(16-19)
+                    const __m512i rhs_mat_2367ABEF_02_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19) B0A(16-19) B0B(16-19) B0A(16-19) B0B(16-19) B0E(16-19) B0F(16-19) B0E(16-19) B0F(16-19)
+                    const __m512i rhs_mat_014589CD_03_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27) B08(24-27) B09(24-27) B08(24-27) B09(24-27) B0C(24-27) B0D(24-27) B0C(24-27) B0D(24-27)
+                    const __m512i rhs_mat_2367ABEF_03_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27) B0A(24-27) B0B(24-27) B0A(24-27) B0B(24-27) B0E(24-27) B0F(24-27) B0E(24-27) B0F(24-27)
+
+                    const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3)
+                    const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3)
+                    const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11)
+                    const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11)
+                    const __m512i rhs_mat_014589CD_12_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19) B18(16-19) B19(16-19) B18(16-19) B19(16-19) B1C(16-19) B1D(16-19) B1C(16-19) B1D(16-19)
+                    const __m512i rhs_mat_2367ABEF_12_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19) B1A(16-19) B1B(16-19) B1A(16-19) B1B(16-19) B1E(16-19) B1F(16-19) B1E(16-19) B1F(16-19)
+                    const __m512i rhs_mat_014589CD_13_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27) B18(24-27) B19(24-27) B18(24-27) B19(24-27) B1C(24-27) B1D(24-27) B1C(24-27) B1D(24-27)
+                    const __m512i rhs_mat_2367ABEF_13_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27) B1A(24-27) B1B(24-27) B1A(24-27) B1B(24-27) B1E(24-27) B1F(24-27) B1E(24-27) B1F(24-27)
+
+                    // Shuffle pattern two - right side input
+                    const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7)
+                    const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7)
+                    const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15)
+                    const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15)
+                    const __m512i rhs_mat_014589CD_02_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23) B08(20-23) B09(20-23) B08(20-23) B09(20-23) B0C(20-23) B0D(20-23) B0C(20-23) B0D(20-23)
+                    const __m512i rhs_mat_2367ABEF_02_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23) B0A(20-23) B0B(20-23) B0A(20-23) B0B(20-23) B0E(20-23) B0F(20-23) B0E(20-23) B0F(20-23)
+                    const __m512i rhs_mat_014589CD_03_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31) B08(28-31) B09(28-31) B08(28-31) B09(28-31) B0C(28-31) B0D(28-31) B0C(28-31) B0D(28-31)
+                    const __m512i rhs_mat_2367ABEF_03_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31) B0A(28-31) B0B(28-31) B0A(28-31) B0B(28-31) B0E(28-31) B0F(28-31) B0E(28-31) B0F(28-31)
+
+                    const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7)
+                    const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7)
+                    const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15)
+                    const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15)
+                    const __m512i rhs_mat_014589CD_12_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23) B18(20-23) B19(20-23) B18(20-23) B19(20-23) B1C(20-23) B1D(20-23) B1C(20-23) B1D(20-23)
+                    const __m512i rhs_mat_2367ABEF_12_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23) B1A(20-23) B1B(20-23) B1A(20-23) B1B(20-23) B1E(20-23) B1F(20-23) B1E(20-23) B1F(20-23)
+                    const __m512i rhs_mat_014589CD_13_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31) B18(28-31) B19(28-31) B18(28-31) B19(28-31) B1C(28-31) B1D(28-31) B1C(28-31) B1D(28-31)
+                    const __m512i rhs_mat_2367ABEF_13_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31) B1A(28-31) B1B(28-31) B1A(28-31) B1B(28-31) B1E(28-31) B1F(28-31) B1E(28-31) B1F(28-31)
+
+                    uint32_t utmp_00[4], utmp_01[4], utmp_10[4], utmp_11[4];
+
+                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
+                    // The block below extracts the first sub block's scales and mins from the eight Q4_K blocks interleaved in b_ptr_0 for this sb iteration
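+                    // Each 12 byte chunk packs sixteen 6 bit values; after the unpacking below utmp_xx[0..1] hold eight scales and utmp_xx[2..3] eight mins as bytes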
+                    memcpy(utmp_00, b_ptr_0[b].scales + 24 * sb, 12);
+                    utmp_00[3] = ((utmp_00[2] >> 4) & kmask2) | (((utmp_00[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_00 = utmp_00[1] & kmask1;
+                    utmp_00[1] = (utmp_00[2] & kmask2) | (((utmp_00[0] >> 6) & kmask3) << 4);
+                    utmp_00[2] = uaux_00;
+                    utmp_00[0] &= kmask1;
+
+                    // The block below extracts the second sub block's scales and mins from the eight Q4_K blocks interleaved in b_ptr_0 for this sb iteration
+                    memcpy(utmp_01, b_ptr_0[b].scales + 12 + sb * 24, 12);
+                    utmp_01[3] = ((utmp_01[2] >> 4) & kmask2) | (((utmp_01[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_01 = utmp_01[1] & kmask1;
+                    utmp_01[1] = (utmp_01[2] & kmask2) | (((utmp_01[0] >> 6) & kmask3) << 4);
+                    utmp_01[2] = uaux_01;
+                    utmp_01[0] &= kmask1;
+
+                    // The block below extracts the first sub block's scales and mins from the eight Q4_K blocks interleaved in b_ptr_1 for this sb iteration
+                    memcpy(utmp_10, b_ptr_1[b].scales + sb * 24, 12);
+                    utmp_10[3] = ((utmp_10[2] >> 4) & kmask2) | (((utmp_10[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_10 = utmp_10[1] & kmask1;
+                    utmp_10[1] = (utmp_10[2] & kmask2) | (((utmp_10[0] >> 6) & kmask3) << 4);
+                    utmp_10[2] = uaux_10;
+                    utmp_10[0] &= kmask1;
+
+                    // The block below extracts the second sub block's scales and mins from the eight Q4_K blocks interleaved in b_ptr_1 for this sb iteration
+                    memcpy(utmp_11, b_ptr_1[b].scales + 12 + sb * 24, 12);
+                    utmp_11[3] = ((utmp_11[2] >> 4) & kmask2) | (((utmp_11[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_11 = utmp_11[1] & kmask1;
+                    utmp_11[1] = (utmp_11[2] & kmask2) | (((utmp_11[0] >> 6) & kmask3) << 4);
+                    utmp_11[2] = uaux_11;
+                    utmp_11[0] &= kmask1;
+
+                    // Scales of first sub block in the sb loop
+                    const __m256i mins_and_scales_0 = _mm256_set_epi32(utmp_10[3], utmp_10[2], utmp_10[1], utmp_10[0], utmp_00[3], utmp_00[2], utmp_00[1], utmp_00[0]);
+                    const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
+
+                    // Scales of second sub block in the sb loop
+                    const __m256i mins_and_scales_1 = _mm256_set_epi32(utmp_11[3], utmp_11[2], utmp_11[1], utmp_11[0], utmp_01[3], utmp_01[2], utmp_01[1], utmp_01[0]);
+                    const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
+
+                    // Mins of first and second sub block of Q4_K block are arranged side by side
+                    const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(_mm256_shuffle_epi32(mins_and_scales_0, 78), _mm256_shuffle_epi32(mins_and_scales_1, 78)));
+
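+                    // 68 = 0b01000100 repeats 32 bit elements {0,1} and 238 = 0b11101110 elements {2,3} of each 128 bit lane, matching the 014589CD / 2367ABEF column packing of the RHS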
+                    const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238);
+
+                    const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238);
+
+                    // Load the quantized values of the four block_q8_K rows, interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
+                    // Loaded as a set of 128 bit vectors and repeated into a 256 bit vector
+                    __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 * sb)));
+                    __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0);
+                    __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17);
+                    __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 256 * sb)));
+                    __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0);
+                    __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17);
+                    __m256i lhs_mat_ymm_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 256 * sb)));
+                    __m256i lhs_mat_ymm_01_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 0);
+                    __m256i lhs_mat_ymm_23_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 17);
+                    __m256i lhs_mat_ymm_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 256 * sb)));
+                    __m256i lhs_mat_ymm_01_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 0);
+                    __m256i lhs_mat_ymm_23_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 17);
+                    __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 256 * sb)));
+                    __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0);
+                    __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17);
+                    __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 256 * sb)));
+                    __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0);
+                    __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17);
+                    __m256i lhs_mat_ymm_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 256 * sb)));
+                    __m256i lhs_mat_ymm_01_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 0);
+                    __m256i lhs_mat_ymm_23_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 17);
+                    __m256i lhs_mat_ymm_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 256 * sb)));
+                    __m256i lhs_mat_ymm_01_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 0);
+                    __m256i lhs_mat_ymm_23_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 17);
+
+                    // Loaded as a set of 128 bit vectors, repeated into 256 bit vectors and then repeated again into 512 bit vectors
+                    __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1);
+                    __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1);
+                    __m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1);
+                    __m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1);
+                    __m512i lhs_mat_01_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_02), lhs_mat_ymm_01_02, 1);
+                    __m512i lhs_mat_23_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_02), lhs_mat_ymm_23_02, 1);
+                    __m512i lhs_mat_01_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_03), lhs_mat_ymm_01_03, 1);
+                    __m512i lhs_mat_23_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_03), lhs_mat_ymm_23_03, 1);
+
+                    __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1);
+                    __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1);
+                    __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1);
+                    __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1);
+                    __m512i lhs_mat_01_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_12), lhs_mat_ymm_01_12, 1);
+                    __m512i lhs_mat_23_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_12), lhs_mat_ymm_23_12, 1);
+                    __m512i lhs_mat_01_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_13), lhs_mat_ymm_01_13, 1);
+                    __m512i lhs_mat_23_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_13), lhs_mat_ymm_23_13, 1);
+
+                    // Load the bsums - four bsums (covering the two sub blocks) are loaded for each of the four Q8_K rows
+                    __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].bsums + 16 * sb)));
+                    __m256i lhs_bsums_hsum_ymm_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
+                    lhs_bsums_hsum_ymm_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_ymm_0123_01, lhs_bsums_hsum_ymm_0123_01, 0);
+                    __m512i lhs_bsums_hsum_0123_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_hsum_ymm_0123_01), lhs_bsums_hsum_ymm_0123_01, 1);
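+                    // After the horizontal add each 16 bit element holds the sum of all 32 Q8 values of one sub block for one of the four rows, replicated across the 512 bit vector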
+
+                    // Shuffle pattern one - left side input
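+                    // Immediate 160 = 0b10100000 repeats 32 bit elements {0,2} of each 128 bit lane as {0,0,2,2}; pattern two below uses 245 = 0b11110101 for {1,1,3,3}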
+                    const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
+                    const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3)
+                    const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
+                    const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11)
+                    const __m512i lhs_mat_01_02_sp1 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
+                    const __m512i lhs_mat_23_02_sp1 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)160); //A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19)
+                    const __m512i lhs_mat_01_03_sp1 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
+                    const __m512i lhs_mat_23_03_sp1 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)160); //A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27)
+
+                    const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
+                    const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3)
+                    const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
+                    const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11)
+                    const __m512i lhs_mat_01_12_sp1 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
+                    const __m512i lhs_mat_23_12_sp1 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)160); //A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19)
+                    const __m512i lhs_mat_01_13_sp1 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
+                    const __m512i lhs_mat_23_13_sp1 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)160); //A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27)
+
+                    const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
+                    const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7)
+                    const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
+                    const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15)
+                    const __m512i lhs_mat_01_02_sp2 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
+                    const __m512i lhs_mat_23_02_sp2 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)245); //A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23)
+                    const __m512i lhs_mat_01_03_sp2 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
+                    const __m512i lhs_mat_23_03_sp2 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)245); //A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31)
+
+                    const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
+                    const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7)
+                    const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
+                    const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15)
+                    const __m512i lhs_mat_01_12_sp2 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
+                    const __m512i lhs_mat_23_12_sp2 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)245); //A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23)
+                    const __m512i lhs_mat_01_13_sp2 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
+                    const __m512i lhs_mat_23_13_sp2 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)245); //A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31)
+
+                    // The values arranged in the shuffle patterns are combined with a dot product within each 32 bit lane, i.e. corresponding bytes are multiplied and adjacent products added into 16 bit integers
+                    __m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1));
+                    __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1));
+                    __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1));
+                    __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1));
+                    __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1));
+                    __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1));
+                    __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1));
+                    __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1));
+
+                    __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2));
+                    __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2));
+                    __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2));
+                    __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2));
+                    __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2));
+                    __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2));
+                    __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2));
+                    __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2));
+
+                    // Outputs of both shuffle patterns are added together to sum the dot products over all 32 values of the sub block
+                    __m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
+                    __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
+                    __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
+                    __m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
+
+                    __m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
+                    __m512i iacc_mat_01_1 = _mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
+                    __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
+                    __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
+
+                    iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0);
+                    iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0);
+                    iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0);
+                    iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0);
+
+                    iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1);
+                    iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1);
+                    iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1);
+                    iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1);
+
+                    // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
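+                    // Shuffle 78 = 0b01001110 swaps the two 64 bit halves of each 128 bit lane and mask 0xCCCC picks those swapped halves, so each result vector holds one output row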
+                    __m512i iacc_row_0_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_0, _mm512_shuffle_epi32(iacc_mat_01_0, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_1_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_0, (_MM_PERM_ENUM)78), iacc_mat_01_0);
+                    __m512i iacc_row_2_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_0, _mm512_shuffle_epi32(iacc_mat_11_0, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_3_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_0, (_MM_PERM_ENUM)78), iacc_mat_11_0);
+                    __m512i iacc_row_0_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_1, _mm512_shuffle_epi32(iacc_mat_01_1, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_1_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_1, (_MM_PERM_ENUM)78), iacc_mat_01_1);
+                    __m512i iacc_row_2_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_1, _mm512_shuffle_epi32(iacc_mat_11_1, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_3_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_1, (_MM_PERM_ENUM)78), iacc_mat_11_1);
+
+                    __m512i iacc_row_0 = _mm512_add_epi32(iacc_row_0_0, iacc_row_0_1);
+                    __m512i iacc_row_1 = _mm512_add_epi32(iacc_row_1_0, iacc_row_1_1);
+                    __m512i iacc_row_2 = _mm512_add_epi32(iacc_row_2_0, iacc_row_2_1);
+                    __m512i iacc_row_3 = _mm512_add_epi32(iacc_row_3_0, iacc_row_3_1);
+
+                    // Load the scale (d) values of all four Q8_K blocks and repeat them across lanes
+                    const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);
+                    const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);
+                    const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1);
+
+                    // Multiply with the appropriate scales and accumulate (for both d and dmin) below
+                    acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
+                    acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
+                    acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
+                    acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
+
+                    __m512i iacc_row_min_0 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)0), mins_01);
+                    __m512i iacc_row_min_1 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)85), mins_01);
+                    __m512i iacc_row_min_2 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)170), mins_01);
+                    __m512i iacc_row_min_3 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)255), mins_01);
+
+                    acc_min_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]);
+                    acc_min_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]);
+                    acc_min_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]);
+                    acc_min_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);
+                }
+            }
+            // Store the accumulated values
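+            // acc_rows holds the d*scale contribution and acc_min_rows the dmin*min contribution of the Q4_K dot product, hence the subtraction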
+            for (int i = 0; i < 4; i++) {
+                _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i]));
+            }
+        }
+    }
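+    // If nc is not a multiple of the 16 column tile processed above, restart the 256 bit path below at the first column the 512 bit path did not cover (for all rows)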
+    if (anc != nc) {
+        xstart = anc/8;
+        y = 0;
+    }
+#endif //AVX512F
+
+    // Take a group of four block_q8_Kx4 structures at each pass of the loop and perform the dot product operation
+    for (; y < anr / 4; y += 4) {
+
+        const block_q8_Kx4 * a_ptrs[4];
+
+        a_ptrs[0] = a_ptr_start + (y * nb);
+        for (int i = 0; i < 3; ++i) {
+            a_ptrs[i + 1] = a_ptrs[i] + nb;
+        }
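+        // a_ptrs spans four consecutive block_q8_Kx4 row groups, i.e. 16 rows of the LHS are processed per pass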
+
+        // Take a group of eight interleaved Q4_K blocks (one block_q4_Kx8 structure) at each pass of the loop and perform the dot product operation
+        for (int64_t x = xstart; x < nc / 8; x++) {
+
+            const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
+
+            // Master FP accumulators
+            __m256 acc_rows[16];
+            for (int i = 0; i < 16; i++) {
+                acc_rows[i] = _mm256_setzero_ps();
+            }
+
+            __m256 acc_min_rows[16];
+            for (int i = 0; i < 16; i++) {
+                acc_min_rows[i] = _mm256_setzero_ps();
+            }
+
+            // For super block
+            for (int64_t b = 0; b < nb; b++) {
+
+                // Scale values - Load the eight scale values of block_q4_Kx8
+                const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+
+                // dmin values - Load the eight dmin values of block_q4_Kx8
+                const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
+
+                // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+
+                    // Load the quantized values of the eight interleaved block_q4_K structures for two sub blocks, in chunks of eight bytes - B0,B1 ... B6,B7
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
 
-                    const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
-                    const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+                    // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
+                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+                    const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
+                    const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
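+                    // Blend mask 240 (0b11110000) keeps the lower four dwords from the first operand and takes the
+                    // upper four from the permuted second operand, regrouping the rows as B0B1B4B5 / B2B3B6B7.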
 
-                    const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
-                    const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+                    // 4-bit -> 8-bit
+                    // First sub block of the two sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m4b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)
+                    const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m4b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)
 
-                    // Shuffle pattern two - left side input
+                    const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m4b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)
+                    const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m4b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)
 
-                    const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
-                    const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+                    const __m256i rhs_mat_0145_02 = _mm256_and_si256(rhs_raw_mat_0145_2, m4b); //B00(16-23) B01(16-23) B04(16-23) B05(16-23)
+                    const __m256i rhs_mat_2367_02 = _mm256_and_si256(rhs_raw_mat_2367_2, m4b); //B02(16-23) B03(16-23) B06(16-23) B07(16-23)
 
-                    const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
-                    const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+                    const __m256i rhs_mat_0145_03 = _mm256_and_si256(rhs_raw_mat_0145_3, m4b); //B00(24-31) B01(24-31) B04(24-31) B05(24-31)
+                    const __m256i rhs_mat_2367_03 = _mm256_and_si256(rhs_raw_mat_2367_3, m4b); //B02(24-31) B03(24-31) B06(24-31) B07(24-31)
 
-                    const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
-                    const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+                    // Second sub block of the two sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)
+                    const __m256i rhs_mat_2367_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)
+
+                    const __m256i rhs_mat_0145_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)
+                    const __m256i rhs_mat_2367_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)
+
+                    const __m256i rhs_mat_0145_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m4b); //B10(16-23) B11(16-23) B14(16-23) B15(16-23)
+                    const __m256i rhs_mat_2367_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m4b); //B12(16-23) B13(16-23) B16(16-23) B17(16-23)
+
+                    const __m256i rhs_mat_0145_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m4b); //B10(24-31) B11(24-31) B14(24-31) B15(24-31)
+                    const __m256i rhs_mat_2367_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m4b); //B12(24-31) B13(24-31) B16(24-31) B17(24-31)
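+                    // Each byte packs two 4-bit weights: the low nibble belongs to the first sub block, the high
+                    // nibble (shifted right by 4) to the second. Masking with m4b keeps only the low 4 bits; unlike
+                    // q4_0 no sign-extension LUT is needed, since q4_K weights are unsigned and the min is handled via dmin.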
+
+                    // Shuffle pattern one - right side input
+                    const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)
+                    const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3)
+
+                    const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)
+                    const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)
+
+                    const __m256i rhs_mat_0145_02_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_02, 136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19)
+                    const __m256i rhs_mat_2367_02_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_02, 136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19)
+
+                    const __m256i rhs_mat_0145_03_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_03, 136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27)
+                    const __m256i rhs_mat_2367_03_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_03, 136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27)
+
+                    const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)
+                    const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3)
+
+                    const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)
+                    const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)
+
+                    const __m256i rhs_mat_0145_12_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_12, 136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19)
+                    const __m256i rhs_mat_2367_12_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_12, 136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19)
+
+                    const __m256i rhs_mat_0145_13_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_13, 136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27)
+                    const __m256i rhs_mat_2367_13_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_13, 136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27)
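+                    // Shuffle immediate 136 (0b10001000) selects dwords {0,2,0,2} per 128-bit lane, i.e. the first
+                    // 4-byte group of each of the two B columns in that lane, repeated to line up with the LHS shuffles.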
+
+
+                    // Shuffle pattern two - right side input
+                    const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7)
+                    const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7)
+
+                    const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)
+                    const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)
+
+                    const __m256i rhs_mat_0145_02_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_02, 221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23)
+                    const __m256i rhs_mat_2367_02_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_02, 221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23)
+
+                    const __m256i rhs_mat_0145_03_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_03, 221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31)
+                    const __m256i rhs_mat_2367_03_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_03, 221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31)
+
+                    const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)
+                    const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)
+
+                    const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)
+                    const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)
+
+                    const __m256i rhs_mat_0145_12_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_12, 221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23)
+                    const __m256i rhs_mat_2367_12_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_12, 221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23)
+
+                    const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)
+                    const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)
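+                    // Shuffle immediate 221 (0b11011101) selects dwords {1,3,1,3}, the complementary 4-byte groups,
+                    // so the two patterns together cover all 32 quantized values of a sub block.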
+
+                    uint32_t utmp_0[4], utmp_1[4];
+
+                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
+                    // The block below extracts the first sub block's scales and mins from the different Q4_K structures for this sb iteration
+                    memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
+                    utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_0 = utmp_0[1] & kmask1;
+                    utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
+                    utmp_0[2] = uaux_0;
+                    utmp_0[0] &= kmask1;
+
+                    // The block below extracts the second sub block's scales and mins from the different Q4_K structures for this sb iteration
+                    memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
+                    utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_1 = utmp_1[1] & kmask1;
+                    utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
+                    utmp_1[2] = uaux_1;
+                    utmp_1[0] &= kmask1;
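+                    // The packed scales field stores the 6-bit sub block scales and 6-bit mins of the eight Q4_K
+                    // structures; the shift/mask sequence above unpacks them into one byte per value, scales in the
+                    // low half and mins in the high half of the 128-bit vectors built below.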
+
+                    // Scales of first sub block in the sb loop
+                    const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
+                    const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
+
+                    // Scales of second sub block in the sb loop
+                    const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
+                    const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
+
+                    // Mins of first and second sub block of Q4_K block are arranged side by side
+                    const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
+
+                    const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);
+                    const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);
+
+                    const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);
+                    const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);
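+                    // _mm_shuffle_epi32(..., 78) swaps the 64-bit halves so the mins land in the low half before being
+                    // interleaved and widened to 16 bits. Immediates 68 (dwords {0,1,0,1}) and 238 (dwords {2,3,2,3})
+                    // broadcast the scale pairs for columns 0,1,4,5 and 2,3,6,7 respectively.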
+
+                    for (int rp = 0; rp < 4; rp++) {
+
+                        // Load the quantized values of the four block_q8_K rows, interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
+                        // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
+                        __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb)));
+                        __m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);
+                        __m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);
+                        __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb)));
+                        __m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);
+                        __m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);
+                        __m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb)));
+                        __m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);
+                        __m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);
+                        __m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb)));
+                        __m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);
+                        __m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);
+                        __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb)));
+                        __m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);
+                        __m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);
+                        __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb)));
+                        __m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);
+                        __m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);
+                        __m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb)));
+                        __m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);
+                        __m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);
+                        __m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb)));
+                        __m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);
+                        __m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);
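+                        // _mm256_permute2f128_si256 with immediate 0 duplicates the low 128 bits (rows 0 and 1) into
+                        // both halves, while immediate 17 duplicates the high 128 bits (rows 2 and 3).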
+
+                        // Load the bsums - four bsums (covering the two sub blocks) from each of the different Q8_K rows
+                        __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));
+                        __m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
+                        lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
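+                        // _mm_hadd_epi16 collapses adjacent per-16-weight bsums into per-sub-block sums; the result is
+                        // duplicated into both 128-bit halves so one row's pair of sums can be broadcast per lane below.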
+
+                        // Shuffle pattern one - left side input
+                        const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
+                        const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3)
+
+                        const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160);  //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
+                        const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160);  //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)
+
+                        const __m256i lhs_mat_01_02_sp1 = _mm256_shuffle_epi32(lhs_mat_01_02, 160);  //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
+                        const __m256i lhs_mat_23_02_sp1 = _mm256_shuffle_epi32(lhs_mat_23_02, 160);  //A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19)
+
+                        const __m256i lhs_mat_01_03_sp1 = _mm256_shuffle_epi32(lhs_mat_01_03, 160);  //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
+                        const __m256i lhs_mat_23_03_sp1 = _mm256_shuffle_epi32(lhs_mat_23_03, 160); //A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27)
+
+                        const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
+                        const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3)
+
+                        const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160);  //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
+                        const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160);  //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)
+
+                        const __m256i lhs_mat_01_12_sp1 = _mm256_shuffle_epi32(lhs_mat_01_12, 160);  //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
+                        const __m256i lhs_mat_23_12_sp1 = _mm256_shuffle_epi32(lhs_mat_23_12, 160);  //A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19)
+
+                        const __m256i lhs_mat_01_13_sp1 = _mm256_shuffle_epi32(lhs_mat_01_13, 160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
+                        const __m256i lhs_mat_23_13_sp1 = _mm256_shuffle_epi32(lhs_mat_23_13, 160); //A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27)
+
+                        // Shuffle pattern two- left side input
+                        const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
+                        const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7) A02(4-7) A03(4-7)
+
+                        const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
+                        const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15)
+
+                        const __m256i lhs_mat_01_02_sp2 = _mm256_shuffle_epi32(lhs_mat_01_02, 245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
+                        const __m256i lhs_mat_23_02_sp2 = _mm256_shuffle_epi32(lhs_mat_23_02, 245); //A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23)
+
+                        const __m256i lhs_mat_01_03_sp2 = _mm256_shuffle_epi32(lhs_mat_01_03, 245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
+                        const __m256i lhs_mat_23_03_sp2 = _mm256_shuffle_epi32(lhs_mat_23_03, 245); //A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31)
+
+                        const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
+                        const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7)
+
+                        const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
+                        const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15)
+
+                        const __m256i lhs_mat_01_12_sp2 = _mm256_shuffle_epi32(lhs_mat_01_12, 245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
+                        const __m256i lhs_mat_23_12_sp2 = _mm256_shuffle_epi32(lhs_mat_23_12, 245); //A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23)
+
+                        const __m256i lhs_mat_01_13_sp2 = _mm256_shuffle_epi32(lhs_mat_01_13, 245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
+                        const __m256i lhs_mat_23_13_sp2 = _mm256_shuffle_epi32(lhs_mat_23_13, 245); //A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31)
+
+                        // The values arranged in the shuffle patterns are combined with a dot product: within each 32-bit lane the corresponding bytes are multiplied and the products summed into 16-bit integers
+                        __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1));
+                        __m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1));
+                        __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1));
+                        __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1));
+                        __m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1));
+                        __m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1));
+                        __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1));
+                        __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1));
+
+                        __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2));
+                        __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2));
+                        __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2));
+                        __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2));
+                        __m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2));
+                        __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2));
+                        __m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2));
+                        __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2));
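+                        // maddubs multiplies the unsigned 4-bit RHS weights with the signed Q8 LHS bytes and sums adjacent
+                        // byte pairs into 16-bit integers; the nested adds then accumulate the four 8-byte groups. Even after
+                        // adding both shuffle patterns the worst case (16 products of 15 * 128 = 30720) stays within int16 range.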
+
+                        // Outputs of both shuffle patterns are added in order to sum the dot products of all 32 values in the sub block
+                        __m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
+                        __m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
+                        __m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
+                        __m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
+
+                        __m256i iacc_mat_00_1 = _mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
+                        __m256i iacc_mat_01_1 = _mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
+                        __m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
+                        __m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
+
+                        // Multiply the summed dot products by the corresponding sub block scales; madd multiplies each 16-bit sum by its scale and adds adjacent pairs into 32-bit integers
+                        iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);
+                        iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);
+                        iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0);
+                        iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0);
 
-                    const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
-                    const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+                        iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1);
+                        iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1);
+                        iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);
+                        iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);
 
-                    // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
-                    // Resembles MMLAs into 2x2 matrices in ARM Version
-                    const __m512i zero = _mm512_setzero_epi32();
-                    __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1);
-                    __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1);
-                    __m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1);
-                    __m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1);
-                    __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2);
-                    __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2);
-                    __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2);
-                    __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2);
+                        // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
+                        __m256i iacc_row_0_0 = _mm256_blend_epi32(iacc_mat_00_0, _mm256_shuffle_epi32(iacc_mat_01_0, 78), 204);
+                        __m256i iacc_row_1_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_0, 78), iacc_mat_01_0, 204);
+                        __m256i iacc_row_2_0 = _mm256_blend_epi32(iacc_mat_10_0, _mm256_shuffle_epi32(iacc_mat_11_0, 78), 204);
+                        __m256i iacc_row_3_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_0, 78), iacc_mat_11_0, 204);
+                        __m256i iacc_row_0_1 = _mm256_blend_epi32(iacc_mat_00_1, _mm256_shuffle_epi32(iacc_mat_01_1, 78), 204);
+                        __m256i iacc_row_1_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_1, 78), iacc_mat_01_1, 204);
+                        __m256i iacc_row_2_1 = _mm256_blend_epi32(iacc_mat_10_1, _mm256_shuffle_epi32(iacc_mat_11_1, 78), 204);
+                        __m256i iacc_row_3_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_1, 78), iacc_mat_11_1, 204);
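+                        // Blend mask 204 (0b11001100) takes dwords {2,3,6,7} from the second operand; combined with
+                        // shuffle 78 (which swaps 64-bit halves) this interleaves the 2x2 tile accumulators into vectors
+                        // that each hold one output row across all eight columns.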
 
-                    // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
-                    __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
-                    __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
-                    __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
-                    __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+                        __m256i iacc_row_0 = _mm256_add_epi32(iacc_row_0_0, iacc_row_0_1);
+                        __m256i iacc_row_1 = _mm256_add_epi32(iacc_row_1_0, iacc_row_1_1);
+                        __m256i iacc_row_2 = _mm256_add_epi32(iacc_row_2_0, iacc_row_2_1);
+                        __m256i iacc_row_3 = _mm256_add_epi32(iacc_row_3_0, iacc_row_3_1);
 
+                        // Load the scale (d) values for all the 4 Q8_K rows and repeat them across lanes
+                        const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);
+                        const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);//GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
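+                        // The four per-row d scales are duplicated into both 128-bit halves so _mm256_shuffle_ps below
+                        // (immediates 0, 85, 170, 255) can broadcast a single row's scale across all eight columns.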
 
-                    // Straighten out to make 4 row vectors
-                    __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
-                    __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
-                    __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
-                    __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
+                        // Multiply with the appropriate scales and accumulate (for both d and dmin) below
+                        acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
+                        acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
+                        acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
+                        acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
 
-                    // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
-                    const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
-                    const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
+                        __m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);
+                        __m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);
+                        __m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);
+                        __m256i iacc_row_min_3 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 255), mins_01);
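+                        // Each iacc_row_min combines one row's pair of sub block bsums with the eight column mins via madd;
+                        // scaled by dmin below, this becomes the min correction subtracted from the final result.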
 
-                    // Multiply with appropiate scales and accumulate
-                    acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)),   acc_rows[0]);
-                    acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)),  acc_rows[1]);
-                    acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
-                    acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
-                }
+                        acc_min_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]);
+                        acc_min_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);
+                        acc_min_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]);
+                        acc_min_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]);
 
-                // Store the accumulated values
-                for (int i = 0; i < 4; i++) {
-                    _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+                    }
                 }
             }
+            // Store the accumulated values
+            for (int i = 0; i < 16; i++) {
+                _mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));
+            }
         }
-        if (anc != nc) {
-            xstart = anc/8;
-            y = 0;
-        }
-    #endif // __AVX512F__
+    }
+    for (; y < nr / 4; y++) {
 
-        // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
+        const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);
+
+        for (int64_t x = xstart; x < nc / 8; x++) {
+
+            const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
 
-        for (; y < anr / 4; y += 4) {
-            const block_q8_0x4 * a_ptrs[4];
+            // Master FP accumulators
+            __m256 acc_rows[4];
+            for (int i = 0; i < 4; i++) {
+                acc_rows[i] = _mm256_setzero_ps();
+            }
 
-            a_ptrs[0] = a_ptr_start + (y * nb);
-            for (int i = 0; i < 3; ++i) {
-                a_ptrs[i + 1] = a_ptrs[i] + nb;
+            __m256 acc_min_rows[4];
+            for (int i = 0; i < 4; i++) {
+                acc_min_rows[i] = _mm256_setzero_ps();
             }
 
-            // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation
-            for (int64_t x = xstart; x < nc / 8; x++) {
+            for (int64_t b = 0; b < nb; b++) {
+
+                // Scale values - Load the eight scale values of block_q4_Kx8
+                const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
 
-                const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
+                // dmin values - Load the eight dmin values of block_q4_Kx8
+                const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
 
-                // Master FP accumulators
-                __m256 acc_rows[16];
-                for (int i = 0; i < 16; i++) {
-                    acc_rows[i] = _mm256_setzero_ps();
-                }
+                // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
+                for (int sb = 0; sb < QK_K / 64; sb++) {
 
-                for (int64_t b = 0; b < nb; b++) {
-                    // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
-                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
-                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
-                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
-                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
+                    // Load the quantized values of the eight block_q4_K structures for the two sub blocks, interleaved with each other in chunks of eight bytes - B0,B1 ... B6,B7
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
 
                     // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
                     const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
                     const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
                     const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
                     const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+                    const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
+                    const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
 
-                    // 4-bit -> 8-bit - Sign is maintained
-                    const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
-                    const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
-
-                    const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
-                    const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
-
-                    const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
-                    const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
-
-                    const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
-                    const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
-
-                    // Shuffle pattern one - right side input
-                    const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136);  //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
-                    const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136);  //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
-
-                    const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136);  //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
-                    const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136);  //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
-
-                    const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136);  //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
-                    const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136);  //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
-
-                    const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136);  //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
-                    const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136);  //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
-
-                    // Shuffle pattern two - right side input
+                    // 4-bit -> 8-bit
+                    // First sub block of the two sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m4b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)
+                    const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m4b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)
 
-                    const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221);  //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
-                    const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221);  //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
+                    const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m4b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)
+                    const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m4b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)
 
-                    const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221);  //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
-                    const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221);  //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
+                    const __m256i rhs_mat_0145_02 = _mm256_and_si256(rhs_raw_mat_0145_2, m4b); //B00(16-23) B01(16-23) B04(16-23) B05(16-23)
+                    const __m256i rhs_mat_2367_02 = _mm256_and_si256(rhs_raw_mat_2367_2, m4b); //B02(16-23) B03(16-23) B06(16-23) B07(16-23)
 
-                    const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221);  //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
-                    const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221);  //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
+                    const __m256i rhs_mat_0145_03 = _mm256_and_si256(rhs_raw_mat_0145_3, m4b); //B00(24-31) B01(24-31) B04(24-31) B05(24-31)
+                    const __m256i rhs_mat_2367_03 = _mm256_and_si256(rhs_raw_mat_2367_3, m4b); //B02(24-31) B03(24-31) B06(24-31) B07(24-31)
 
-                    const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221);  //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
-                    const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221);  //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
+                    // Second sub block of the two sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)
+                    const __m256i rhs_mat_2367_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)
 
-                    // Scale values - Load the wight scale values of block_q4_0x8
-                    const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+                    const __m256i rhs_mat_0145_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)
+                    const __m256i rhs_mat_2367_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)
 
-                    // Process LHS in groups of four
-                    for (int rp = 0; rp < 4; rp++) {
-                        // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
-                        // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
-                        __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
-                        __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
-                        __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
-                        __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
-                        __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
-                        __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
-                        __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
-                        __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
-                        __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
-                        __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
-                        __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
-                        __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
+                    const __m256i rhs_mat_0145_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m4b); //B10(16-23) B11(16-23) B14(16-23) B15(16-23)
+                    const __m256i rhs_mat_2367_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m4b); //B12(16-23) B13(16-23) B16(16-23) B17(16-23)
 
-                        // Shuffle pattern one - left side input
-                        const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
-                        const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+                    const __m256i rhs_mat_0145_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m4b); //B10(24-31) B11(24-31) B14(24-31) B15(24-31)
+                    const __m256i rhs_mat_2367_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m4b); //B12(24-31) B13(24-31) B16(24-31) B17(24-31)
 
-                        const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
-                        const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+                    // Shuffle pattern one - right side input
+                    const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)
+                    const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3)
 
-                        const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
-                        const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+                    const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)
+                    const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)
 
-                        const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
-                        const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+                    const __m256i rhs_mat_0145_02_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_02, 136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19)
+                    const __m256i rhs_mat_2367_02_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_02, 136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19)
 
-                        // Shuffle pattern two - left side input
-                        const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
-                        const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+                    const __m256i rhs_mat_0145_03_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_03, 136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27)
+                    const __m256i rhs_mat_2367_03_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_03, 136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27)
 
-                        const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
-                        const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+                    const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)
+                    const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3)
 
-                        const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
-                        const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+                    const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)
+                    const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)
 
-                        const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
-                        const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+                    const __m256i rhs_mat_0145_12_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_12, 136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19)
+                    const __m256i rhs_mat_2367_12_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_12, 136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19)
 
-                        // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
-                        // Resembles MMLAs into 2x2 matrices in ARM Version
-                        const __m256i zero = _mm256_setzero_si256();
-                        __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1);
-                        __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1);
-                        __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1);
-                        __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1);
-                        __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2);
-                        __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2);
-                        __m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2);
-                        __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2);
+                    const __m256i rhs_mat_0145_13_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_13, 136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27)
+                    const __m256i rhs_mat_2367_13_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_13, 136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27)
 
-                        // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
-                        __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
-                        __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
-                        __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
-                        __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+                    // Shuffle pattern two - right side input
+                    const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7)
+                    const __m256i rhs_mat_2367_00_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_00, 221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7)
 
-                        // Straighten out to make 4 row vectors
-                        __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
-                        __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
-                        __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
-                        __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
+                    const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)
+                    const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)
 
-                        // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
-                        const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
+                    const __m256i rhs_mat_0145_02_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_02, 221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23)
+                    const __m256i rhs_mat_2367_02_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_02, 221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23)
 
-                        // Multiply with appropiate scales and accumulate
-                        acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
-                        acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
-                        acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
-                        acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32,  255)), acc_rows[rp * 4 + 3]);
-                    }
-                }
+                    const __m256i rhs_mat_0145_03_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_03, 221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31)
+                    const __m256i rhs_mat_2367_03_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_03, 221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31)
 
-                // Store the accumulated values
-                for (int i = 0; i < 16; i++) {
-                    _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
-                }
-            }
-        }
+                    const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)
+                    const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)
 
-        // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation
-        for (; y < nr / 4; y ++) {
+                    const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)
+                    const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)
 
-            const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
+                    const __m256i rhs_mat_0145_12_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_12, 221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23)
+                    const __m256i rhs_mat_2367_12_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_12, 221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23)
 
-            // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
-            for (int64_t x = xstart; x < nc / 8; x++) {
+                    const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)
+                    const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)
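+                    // Note on the shuffle immediates: within each 128 bit lane _mm256_shuffle_epi32 picks the
+                    // 32 bit elements (0,2,0,2) for 136 (0b10001000) and (1,3,1,3) for 221 (0b11011101), i.e. sp1 keeps
+                    // quants 0-3 of every eight byte group while sp2 keeps quants 4-7. Paired with the LHS shuffles further
+                    // below (160 -> (0,0,2,2) and 245 -> (1,1,3,3)), each dword fed to _mm256_maddubs_epi16 matches one
+                    // row/column combination of a 2x2 tile, and adding the sp1 and sp2 results covers all 32 quants.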
 
-                const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
+                    uint32_t utmp_0[4], utmp_1[4];
 
-                // Master FP accumulators
-                __m256 acc_rows[4];
-                for (int i = 0; i < 4; i++) {
-                    acc_rows[i] = _mm256_setzero_ps();
-                }
+                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
+                    // The block below extracts the scales and mins of the first of the two sub blocks processed in this sb iteration from the different Q4_K structures
+                    memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
+                    utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_0 = utmp_0[1] & kmask1;
+                    utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
+                    utmp_0[2] = uaux_0;
+                    utmp_0[0] &= kmask1;
 
-                for (int64_t b = 0; b < nb; b++) {
-                    // Load the eight block_q8_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
-                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
-                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
-                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
-                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
+                    // The block below extracts the scales and mins of the second of the two sub blocks processed in this sb iteration
+                    memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
+                    utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_1 = utmp_1[1] & kmask1;
+                    utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
+                    utmp_1[2] = uaux_1;
+                    utmp_1[0] &= kmask1;
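+                    // Note: in the interleaved block_q4_Kx8 layout each group of 12 bytes read above appears to pack
+                    // eight 6 bit scales and eight 6 bit mins (one pair per interleaved column) for one sub block.
+                    // The kmask1/kmask2/kmask3 shifts split those 96 bits so that the low 8 bytes of utmp_* hold the
+                    // per column scales and the high 8 bytes hold the per column mins, mirroring the scalar
+                    // get_scale_min_k4() unpacking used elsewhere in ggml.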
 
-                    // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess
-                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
-                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
-                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+                    // Scales of first sub block in the sb loop
+                    const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
+                    const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
 
-                    // 4-bit -> 8-bit - Sign is maintained
-                    const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b));  //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
-                    const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b));  //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
+                    // Scales of second sub block in the sb loop
+                    const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
+                    const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
 
-                    const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b));  //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
-                    const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b));  //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
+                    // Mins of first and second sub block of Q4_K block are arranged side by side
+                    const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
 
-                    const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b));  //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
-                    const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b));  //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
+                    const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);
+                    const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);
 
-                    const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b));  //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
-                    const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b));  //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
+                    const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);
+                    const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);
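+                    // Note: scales_0/scales_1 hold the eight per column scales with every value duplicated into two
+                    // adjacent 16 bit words (via the unpacklo above). _mm256_shuffle_epi32 with 68 (elements 0,1,0,1)
+                    // and 238 (elements 2,3,2,3) then arranges them so that scale_0145_* lines up with the B00/B01/B04/B05
+                    // tiles and scale_2367_* with the B02/B03/B06/B07 tiles in the madd_epi16 step further below.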
 
-                    // Shuffle pattern one - right side input
-                    const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136);  //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
-                    const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136);  //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
+                    // Load the four block_q8_K quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
+                    // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
+                    __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 * sb)));
+                    __m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);
+                    __m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);
+                    __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 256 * sb)));
+                    __m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);
+                    __m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);
+                    __m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 256 * sb)));
+                    __m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);
+                    __m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);
+                    __m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 256 * sb)));
+                    __m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);
+                    __m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);
+                    __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 256 * sb)));
+                    __m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);
+                    __m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);
+                    __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 256 * sb)));
+                    __m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);
+                    __m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);
+                    __m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 256 * sb)));
+                    __m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);
+                    __m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);
+                    __m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 256 * sb)));
+                    __m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);
+                    __m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);
 
-                    const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136);  //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
-                    const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136);  //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
+                    // Load the bsums - four partial sums per Q8_K block (covering the two sub blocks of this iteration) for each of the four blocks
+                    __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].bsums + 16 * sb)));
+                    __m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
+                    lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
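+                    // Note: Q8_K keeps the sums of its quants in groups of 16 in bsums; the horizontal add above pairs
+                    // adjacent group sums into per sub block (32 element) sums for the four rows, which are later
+                    // multiplied with the per column mins (mins_01) to form the "dmin * min" part of the Q4_K dot product.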
 
-                    const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136);  //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
-                    const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136);  //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
+                    // Shuffle pattern one - left side input
+                    const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
+                    const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3)
 
-                    const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136);  //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
-                    const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136);  //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
+                    const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160);  //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
+                    const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160);  //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11)
 
-                    // Shuffle pattern two - right side input
+                    const __m256i lhs_mat_01_02_sp1 = _mm256_shuffle_epi32(lhs_mat_01_02, 160);  //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
+                    const __m256i lhs_mat_23_02_sp1 = _mm256_shuffle_epi32(lhs_mat_23_02, 160);  //A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19)
 
-                    const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221);  //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
-                    const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221);  //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
+                    const __m256i lhs_mat_01_03_sp1 = _mm256_shuffle_epi32(lhs_mat_01_03, 160);  //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
+                    const __m256i lhs_mat_23_03_sp1 = _mm256_shuffle_epi32(lhs_mat_23_03, 160); //A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27)
 
-                    const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221);  //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
-                    const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221);  //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
+                    const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
+                    const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3)
 
-                    const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221);  //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
-                    const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221);  //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
+                    const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160);  //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
+                    const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160);  //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11)
 
-                    const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221);  //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
-                    const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221);  //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
+                    const __m256i lhs_mat_01_12_sp1 = _mm256_shuffle_epi32(lhs_mat_01_12, 160);  //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
+                    const __m256i lhs_mat_23_12_sp1 = _mm256_shuffle_epi32(lhs_mat_23_12, 160);  //A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19)
 
-                    // Scale values - Load the wight scale values of block_q4_0x8
-                    const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+                    const __m256i lhs_mat_01_13_sp1 = _mm256_shuffle_epi32(lhs_mat_01_13, 160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
+                    const __m256i lhs_mat_23_13_sp1 = _mm256_shuffle_epi32(lhs_mat_23_13, 160); //A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27)
 
-                    // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
-                    // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
-                    __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
-                    __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
-                    __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
-                    __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
-                    __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
-                    __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
-                    __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
-                    __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
-                    __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
-                    __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
-                    __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
-                    __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
+                    // Shuffle pattern two - left side input
+                    const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
+                    const __m256i lhs_mat_23_00_sp2 = _mm256_shuffle_epi32(lhs_mat_23_00, 245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7)
 
-                    // Shuffle pattern one - left side input
+                    const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
+                    const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15)
 
-                    const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
-                    const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+                    const __m256i lhs_mat_01_02_sp2 = _mm256_shuffle_epi32(lhs_mat_01_02, 245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
+                    const __m256i lhs_mat_23_02_sp2 = _mm256_shuffle_epi32(lhs_mat_23_02, 245); //A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23)
 
-                    const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
-                    const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+                    const __m256i lhs_mat_01_03_sp2 = _mm256_shuffle_epi32(lhs_mat_01_03, 245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
+                    const __m256i lhs_mat_23_03_sp2 = _mm256_shuffle_epi32(lhs_mat_23_03, 245); //A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31)
 
-                    const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
-                    const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+                    const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
+                    const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7)
 
-                    const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
-                    const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+                    const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
+                    const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15)
 
-                    // Shuffle pattern two - left side input
+                    const __m256i lhs_mat_01_12_sp2 = _mm256_shuffle_epi32(lhs_mat_01_12, 245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
+                    const __m256i lhs_mat_23_12_sp2 = _mm256_shuffle_epi32(lhs_mat_23_12, 245); //A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23)
 
-                    const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
-                    const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+                    const __m256i lhs_mat_01_13_sp2 = _mm256_shuffle_epi32(lhs_mat_01_13, 245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
+                    const __m256i lhs_mat_23_13_sp2 = _mm256_shuffle_epi32(lhs_mat_23_13, 245); //A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31)
 
-                    const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
-                    const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+                    // The values arranged in the shuffle patterns are combined with a dot product style operation within each 32 bit lane: corresponding bytes are multiplied and adjacent pairs are summed into 16 bit integers by _mm256_maddubs_epi16
+                    __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1));
+                    __m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1));
+                    __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1));
+                    __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1));
+                    __m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1));
+                    __m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1));
+                    __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1));
+                    __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1));
 
-                    const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
-                    const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+                    __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2));
+                    __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2));
+                    __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2));
+                    __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2));
+                    __m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2));
+                    __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2));
+                    __m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2));
+                    __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2));
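+                    // Note: the 4 bit quants stay unsigned (0..15) here, so _mm256_maddubs_epi16 (unsigned RHS x signed LHS)
+                    // can be used directly. Each 16 bit slot accumulates at most 16 byte products
+                    // (4 maddubs x 2 products x sp1+sp2), bounded by 16 * 15 * 128 = 30720 < 32767, so the 16 bit
+                    // intermediate sums cannot overflow before they are widened by the scale madd below.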
 
-                    const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
-                    const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+                    // The outputs of both shuffle patterns are added in order to sum the dot product outputs of all 32 values in the block
+                    __m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
+                    __m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
+                    __m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
+                    __m256i iacc_mat_11_0 = _mm256_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2);
 
-                    // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
-                    // Resembles MMLAs into 2x2 matrices in ARM Version
-                    const __m256i zero = _mm256_setzero_si256();
-                    __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1);
-                    __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1);
-                    __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1);
-                    __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1);
-                    __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2);
-                    __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2);
-                    __m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2);
-                    __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2);
+                    __m256i iacc_mat_00_1 = _mm256_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2);
+                    __m256i iacc_mat_01_1 = _mm256_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2);
+                    __m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
+                    __m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
 
-                    // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
+                    // Multiply the summed 16 bit dot products with the corresponding sub block scales (madd_epi16 widens them to 32 bit integers)
-                    __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
-                    __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
-                    __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
-                    __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+                    iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);
+                    iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);
+                    iacc_mat_10_0 = _mm256_madd_epi16(iacc_mat_10_0, scale_0145_0);
+                    iacc_mat_11_0 = _mm256_madd_epi16(iacc_mat_11_0, scale_2367_0);
 
+                    iacc_mat_00_1 = _mm256_madd_epi16(iacc_mat_00_1, scale_0145_1);
+                    iacc_mat_01_1 = _mm256_madd_epi16(iacc_mat_01_1, scale_2367_1);
+                    iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);
+                    iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);
 
-                    // Straighten out to make 4 row vectors
-                    __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
-                    __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
-                    __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
-                    __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
+                    // Straighten out to make 4 row vectors for each of the two sub blocks (the two sets are accumulated together in the next step)
+                    __m256i iacc_row_0_0 = _mm256_blend_epi32(iacc_mat_00_0, _mm256_shuffle_epi32(iacc_mat_01_0, 78), 204);
+                    __m256i iacc_row_1_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_0, 78), iacc_mat_01_0, 204);
+                    __m256i iacc_row_2_0 = _mm256_blend_epi32(iacc_mat_10_0, _mm256_shuffle_epi32(iacc_mat_11_0, 78), 204);
+                    __m256i iacc_row_3_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_0, 78), iacc_mat_11_0, 204);
+                    __m256i iacc_row_0_1 = _mm256_blend_epi32(iacc_mat_00_1, _mm256_shuffle_epi32(iacc_mat_01_1, 78), 204);
+                    __m256i iacc_row_1_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_1, 78), iacc_mat_01_1, 204);
+                    __m256i iacc_row_2_1 = _mm256_blend_epi32(iacc_mat_10_1, _mm256_shuffle_epi32(iacc_mat_11_1, 78), 204);
+                    __m256i iacc_row_3_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_1, 78), iacc_mat_11_1, 204);
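+                    // Note: _mm256_shuffle_epi32(x, 78) swaps the two 64 bit halves of each 128 bit lane and the blend
+                    // mask 204 (0b11001100) takes dwords 2,3 and 6,7 from the second operand, so the four 2x2 tiles are
+                    // rearranged into one 8 column accumulator vector per output row.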
 
-                    // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
-                    const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask);
+                    __m256i iacc_row_0 = _mm256_add_epi32(iacc_row_0_0, iacc_row_0_1);
+                    __m256i iacc_row_1 = _mm256_add_epi32(iacc_row_1_0, iacc_row_1_1);
+                    __m256i iacc_row_2 = _mm256_add_epi32(iacc_row_2_0, iacc_row_2_1);
+                    __m256i iacc_row_3 = _mm256_add_epi32(iacc_row_3_0, iacc_row_3_1);
 
-                    // Multiply with appropiate scales and accumulate
+                    // Load the scale values of the 4 Q8_K blocks and repeat them across lanes
+                    const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);
+                    const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);
+
+                    // Multiply with the appropriate scales and accumulate (for both d and dmin) below
                     acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
                     acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
                     acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
                     acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
-                }
 
-                // Store the accumulated values
-                for (int i = 0; i < 4; i++) {
-                    _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+                    __m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);
+                    __m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);
+                    __m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);
+                    __m256i iacc_row_min_3 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 255), mins_01);
+
+                    acc_min_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]);
+                    acc_min_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]);
+                    acc_min_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]);
+                    acc_min_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);
                 }
             }
+
+            // Store the accumulated values
+            for (int i = 0; i < 4; i++) {
+                _mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));
+            }
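+            // Note: a Q4_K weight dequantizes roughly as w = d * scale * q4 - dmin * min. The dot product therefore splits
+            // into col_scale_f32 * row_scale_f32 * sum(scale * q4 * q8), accumulated in acc_rows, minus
+            // col_dmin_f32 * row_scale_f32 * sum(min * bsum), accumulated in acc_min_rows; the subtraction in the store
+            // above combines the two halves.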
         }
-        return;
     }
 
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
-    float sumf[4][8];
-    int sumi;
+#else
+    UNUSED(kmask1);
+    UNUSED(kmask2);
+    UNUSED(kmask3);
+    ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+#endif
+}
 
-    for (int y = 0; y < nr / 4; y++) {
-        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
-            }
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    for (int m = 0; m < 4; m++) {
-                        for (int j = 0; j < ncols_interleaved; j++) {
-                            sumi = 0;
-                            for (int i = 0; i < blocklen; ++i) {
-                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
-                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
-                            }
-                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
-                        }
-                    }
-                }
-            }
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++)
-                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
-            }
-        }
+void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+#if defined(__AVX2__) || defined(__AVX512F__)
+    {
+        __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_iq4nl));
+        signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
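+        // The iq4_nl codebook is broadcast to both 128-bit lanes and passed as the lookup table to the LUT-based 8x8 GEMM kernel below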
+
+        gemm_q4_b32_8x8_q8_0_lut_avx(n, s, bs, vx, vy, nr, nc, signextendlut);
+
+        return;
     }
+#endif // defined(__AVX2__) || defined(__AVX512F__)
+
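+    // Without AVX2/AVX512, fall back to the 4x4 interleaved iq4_nl kernel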
+    ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK_K;
     const int nb = n / qk;
     const int ncols_interleaved = 8;
     const int blocklen = 8;
-    static const uint32_t kmask1 = 0x3f3f3f3f;
-    static const uint32_t kmask2 = 0x0f0f0f0f;
-    static const uint32_t kmask3 = 0x03030303;
 
     assert (n % qk == 0);
     assert (nr % 4 == 0);
@@ -1793,21 +3444,37 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     UNUSED(blocklen);
 
 #if defined(__AVX2__) || defined(__AVX512F__)
-    const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 * ) vx;
+    const block_q2_Kx8 * b_ptr_start = (const block_q2_Kx8 * ) vx;
     const block_q8_Kx4 * a_ptr_start = (const block_q8_Kx4 * ) vy;
     int64_t b_nb = n / QK_K;
     int64_t y = 0;
 
-    // Mask to mask out nibbles from packed bytes
-    const __m256i m4b = _mm256_set1_epi8(0x0F);
     // Permute mask used for easier vector processing at later stages
     __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
     int64_t xstart = 0;
-    int anr = nr - nr % 16;; // Used to align nr with boundary of 16
+    int anr = nr - nr % 16; // Used to align nr with boundary of 16
+
+    // Masks to convert packed 2-bit and 4-bit values into bytes
+    const __m256i m3b = _mm256_set1_epi8(3);
+    const __m128i m4b_sse = _mm_set1_epi8(0xF);
+
+    // Shuffle masks to extract the appropriate scales
+    __m128i scalesmask1_sse = _mm_set_epi8(14,14,12,12,10,10,8,8,6,6,4,4,2,2,0,0);
+    __m128i scalesmask2_sse = _mm_set_epi8(15,15,13,13,11,11,9,9,7,7,5,5,3,3,1,1);
+
+    __m256i scalesmask1 = _mm256_castsi128_si256(scalesmask1_sse);
+    scalesmask1 = _mm256_permute2f128_si256(scalesmask1, scalesmask1, 0);
+    __m256i scalesmask2 = _mm256_castsi128_si256(scalesmask2_sse);
+    scalesmask2 = _mm256_permute2f128_si256(scalesmask2, scalesmask2, 0);
+
 #ifdef __AVX512F__
+
     int anc = nc - nc % 16; // Used to align nc with boundary of 16
+
+    // Mask to mask out nibbles from packed bytes
+    const __m256i m4b = _mm256_set1_epi8(0x0F);
     // Mask to mask out nibbles from packed bytes expanded to 512 bit length
-    const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
+    const __m512i m3bexpanded = _mm512_set1_epi8(3);
     //Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
     for (; y < anr / 4; y += 4) {
 
@@ -1818,11 +3485,11 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
             a_ptrs[i + 1] = a_ptrs[i] + nb;
         }
 
-        // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation
+        // Take group of eight block_q2_kx8 structures at each pass of the loop and perform dot product operation
         for (int64_t x = 0; x < anc / 8; x += 2) {
 
-            const block_q4_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
-            const block_q4_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
+            const block_q2_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
+            const block_q2_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
 
             // Master FP accumulators
             __m512 acc_rows[16];
@@ -1834,18 +3501,18 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
             for (int i = 0; i < 16; i++) {
                 acc_min_rows[i] = _mm512_setzero_ps();
             }
-
             // For super block
             for (int64_t b = 0; b < nb; b++) {
-                // Scale values - Load the sixteen scale values from two block_q4_kx8 structures
+                // Delta values - Load the sixteen delta (d) values from two block_q2_kx8 structures
                 const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
 
-                // dmin values - Load the sixteen dmin values from two block_q4_kx8 structures
+                // dmin values - Load the sixteen dmin values from two block_q2_kx8 structures
                 const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin);
 
-                // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
-                for (int sb = 0; sb < QK_K / 64; sb++) {
+                // Loop to iterate over the sixteen sub blocks of a super block - eight sub blocks are processed per iteration
+                for (int sb = 0; sb < QK_K / 128; sb++) {
 
+                    // Load the eight block_q2_K quantized values for the eight sub blocks, interleaved with each other in chunks of eight bytes - B0,B1 ... B6,B7
                     const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256));
                     const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256));
                     const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256));
@@ -1892,109 +3559,187 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1);
                     const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1);
 
-                    //4-bit -> 8-bit
-                    const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7)
-                    const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7)
-                    const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15)
-                    const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15)
+                    //2-bit -> 8-bit
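+                    // Each packed byte holds four 2-bit values; shifting right by 0, 2, 4 and 6 bits and masking with m3bexpanded unpacks them into the eight sub-block matrices below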
+                    const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0,m3bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7)
+                    const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0,m3bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7)
+                    const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1,m3bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15)
+                    const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1,m3bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15)
+                    const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(rhs_raw_mat_014589CD_2,m3bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7)
+                    const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2,m3bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7)
+                    const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(rhs_raw_mat_014589CD_3,m3bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15)
+                    const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3,m3bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15)
 
-                    const __m512i rhs_mat_014589CD_02 = _mm512_and_si512(rhs_raw_mat_014589CD_2, m4bexpanded); //B00(16-23) B01(16-23) B04(16-23) B05(16-23) B08(16-23) B09(16-23) B0C(16-23) B0D(16-23)
-                    const __m512i rhs_mat_2367ABEF_02 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2, m4bexpanded); //B02(16-23) B03(16-23) B06(16-23) B07(16-23) B0A(16-23) B0B(16-23) B0E(16-23) B0F(16-23)
-                    const __m512i rhs_mat_014589CD_03 = _mm512_and_si512(rhs_raw_mat_014589CD_3, m4bexpanded); //B00(24-31) B01(24-31) B04(24-31) B05(24-31) B08(24-31) B09(24-31) B0C(24-31) B0D(24-31)
-                    const __m512i rhs_mat_2367ABEF_03 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3, m4bexpanded); //B02(24-31) B03(24-31) B06(24-31) B07(24-31) B0A(24-31) B0B(24-31) B0E(24-31) B0F(24-31)
+                    const __m512i rhs_mat_014589CD_20 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 2), m3bexpanded); //B20(0-7) B21(0-7) B24(0-7) B25(0-7) B28(0-7) B29(0-7) B2C(0-7) B2D(0-7)
+                    const __m512i rhs_mat_2367ABEF_20 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 2), m3bexpanded); //B22(0-7) B23(0-7) B26(0-7) B27(0-7) B2A(0-7) B2B(0-7) B2E(0-7) B2F(0-7)
 
-                    const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7)
-                    const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7)
-                    const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15)
-                    const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15)
+                    const __m512i rhs_mat_014589CD_21 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 2), m3bexpanded); //B20(8-15) B21(8-15) B24(8-15) B25(8-15) B28(8-15) B29(8-15) B2C(8-15) B2D(8-15)
+                    const __m512i rhs_mat_2367ABEF_21 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 2), m3bexpanded); //B22(8-15) B23(8-15) B26(8-15) B27(8-15) B2A(8-15) B2B(8-15) B2E(8-15) B2F(8-15)
 
-                    const __m512i rhs_mat_014589CD_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m4bexpanded); //B10(16-23) B11(16-23) B14(16-23) B15(16-23) B18(16-23) B19(16-23) B1C(16-23) B1D(16-23)
-                    const __m512i rhs_mat_2367ABEF_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m4bexpanded); //B12(16-23) B13(16-23) B16(16-23) B17(16-23) B1A(16-23) B1B(16-23) B1E(16-23) B1F(16-23)
-                    const __m512i rhs_mat_014589CD_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m4bexpanded); //B10(24-31) B11(24-31) B14(24-31) B15(24-31) B18(24-31) B19(24-31) B1C(24-31) B1D(24-31)
-                    const __m512i rhs_mat_2367ABEF_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m4bexpanded); //B12(24-31) B13(24-31) B16(24-31) B17(24-31) B1A(24-31) B1B(24-31) B1E(24-31) B1F(24-31)
+                    const __m512i rhs_mat_014589CD_30 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 2), m3bexpanded); //B30(0-7) B31(0-7) B34(0-7) B35(0-7) B38(0-7) B39(0-7) B3C(0-7) B3D(0-7)
+                    const __m512i rhs_mat_2367ABEF_30 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 2), m3bexpanded); //B32(0-7) B33(0-7) B36(0-7) B37(0-7) B3A(0-7) B3B(0-7) B3E(0-7) B3F(0-7)
+
+                    const __m512i rhs_mat_014589CD_31 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 2), m3bexpanded); //B30(8-15) B31(8-15) B34(8-15) B35(8-15) B38(8-15) B39(8-15) B3C(8-15) B3D(8-15)
+                    const __m512i rhs_mat_2367ABEF_31 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 2), m3bexpanded); //B32(8-15) B33(8-15) B36(8-15) B37(8-15) B3A(8-15) B3B(8-15) B3E(8-15) B3F(8-15)
+
+                    const __m512i rhs_mat_014589CD_40 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m3bexpanded); //B40(0-7) B41(0-7) B44(0-7) B45(0-7) B48(0-7) B49(0-7) B4C(0-7) B4D(0-7)
+                    const __m512i rhs_mat_2367ABEF_40 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m3bexpanded); //B42(0-7) B43(0-7) B46(0-7) B47(0-7) B4A(0-7) B4B(0-7) B4E(0-7) B4F(0-7)
+
+                    const __m512i rhs_mat_014589CD_41 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m3bexpanded); //B40(8-15) B41(8-15) B44(8-15) B45(8-15) B48(8-15) B49(8-15) B4C(8-15) B4D(8-15)
+                    const __m512i rhs_mat_2367ABEF_41 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m3bexpanded); //B42(8-15) B43(8-15) B46(8-15) B47(8-15) B4A(8-15) B4B(8-15) B4E(8-15) B4F(8-15)
+
+                    const __m512i rhs_mat_014589CD_50 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m3bexpanded); //B50(0-7) B51(0-7) B54(0-7) B55(0-7) B58(0-7) B59(0-7) B5C(0-7) B5D(0-7)
+                    const __m512i rhs_mat_2367ABEF_50 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m3bexpanded); //B52(0-7) B53(0-7) B56(0-7) B57(0-7) B5A(0-7) B5B(0-7) B5E(0-7) B5F(0-7)
+
+                    const __m512i rhs_mat_014589CD_51 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m3bexpanded); //B50(8-15) B51(8-15) B54(8-15) B55(8-15) B58(8-15) B59(8-15) B5C(8-15) B5D(8-15)
+                    const __m512i rhs_mat_2367ABEF_51 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m3bexpanded); //B52(8-15) B53(8-15) B56(8-15) B57(8-15) B5A(8-15) B5B(8-15) B5E(8-15) B5F(8-15)
+
+                    const __m512i rhs_mat_014589CD_60 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 6), m3bexpanded); //B60(0-7) B61(0-7) B64(0-7) B65(0-7) B68(0-7) B69(0-7) B6C(0-7) B6D(0-7)
+                    const __m512i rhs_mat_2367ABEF_60 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 6), m3bexpanded); //B62(0-7) B63(0-7) B66(0-7) B67(0-7) B6A(0-7) B6B(0-7) B6E(0-7) B6F(0-7)
+
+                    const __m512i rhs_mat_014589CD_61 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 6), m3bexpanded); //B60(8-15) B61(8-15) B64(8-15) B65(8-15) B68(8-15) B69(8-15) B6C(8-15) B6D(8-15)
+                    const __m512i rhs_mat_2367ABEF_61 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 6), m3bexpanded); //B62(8-15) B63(8-15) B66(8-15) B67(8-15) B6A(8-15) B6B(8-15) B6E(8-15) B6F(8-15)
+
+                    const __m512i rhs_mat_014589CD_70 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 6), m3bexpanded); //B70(0-7) B71(0-7) B74(0-7) B75(0-7) B78(0-7) B79(0-7) B7C(0-7) B7D(0-7)
+                    const __m512i rhs_mat_2367ABEF_70 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 6), m3bexpanded); //B72(0-7) B73(0-7) B76(0-7) B77(0-7) B7A(0-7) B7B(0-7) B7E(0-7) B7F(0-7)
+
+                    const __m512i rhs_mat_014589CD_71 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 6), m3bexpanded); //B70(8-15) B71(8-15) B74(8-15) B75(8-15) B78(8-15) B79(8-15) B7C(8-15) B7D(8-15)
+                    const __m512i rhs_mat_2367ABEF_71 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 6), m3bexpanded); //B72(8-15) B73(8-15) B76(8-15) B77(8-15) B7A(8-15) B7B(8-15) B7E(8-15) B7F(8-15)
 
-                    // Shuffle pattern one - right side input
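+                    // Shuffle pattern one - right side input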
                     const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3)
                     const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3)
+
                     const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11)
                     const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11)
-                    const __m512i rhs_mat_014589CD_02_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19) B08(16-19) B09(16-19) B08(16-19) B09(16-19) B0C(16-19) B0D(16-19) B0C(16-19) B0D(16-19)
-                    const __m512i rhs_mat_2367ABEF_02_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19) B0A(16-19) B0B(16-19) B0A(16-19) B0B(16-19) B0E(16-19) B0F(16-19) B0E(16-19) B0F(16-19)
-                    const __m512i rhs_mat_014589CD_03_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27) B08(24-27) B09(24-27) B08(24-27) B09(24-27) B0C(24-27) B0D(24-27) B0C(24-27) B0D(24-27)
-                    const __m512i rhs_mat_2367ABEF_03_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27) B0A(24-27) B0B(24-27) B0A(24-27) B0B(24-27) B0E(24-27) B0F(24-27) B0E(24-27) B0F(24-27)
 
                     const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3)
                     const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3)
+
                     const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11)
                     const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11)
-                    const __m512i rhs_mat_014589CD_12_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19) B18(16-19) B19(16-19) B18(16-19) B19(16-19) B1C(16-19) B1D(16-19) B1C(16-19) B1D(16-19)
-                    const __m512i rhs_mat_2367ABEF_12_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19) B1A(16-19) B1B(16-19) B1A(16-19) B1B(16-19) B1E(16-19) B1F(16-19) B1E(16-19) B1F(16-19)
-                    const __m512i rhs_mat_014589CD_13_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27) B18(24-27) B19(24-27) B18(24-27) B19(24-27) B1C(24-27) B1D(24-27) B1C(24-27) B1D(24-27)
-                    const __m512i rhs_mat_2367ABEF_13_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27) B1A(24-27) B1B(24-27) B1A(24-27) B1B(24-27) B1E(24-27) B1F(24-27) B1E(24-27) B1F(24-27)
 
-                    // Shuffle pattern two - right side input
+                    const __m512i rhs_mat_014589CD_20_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_20, (_MM_PERM_ENUM)136); //B20(0-3) B21(0-3) B20(0-3) B21(0-3) B24(0-3) B25(0-3) B24(0-3) B25(0-3) B28(0-3) B29(0-3) B28(0-3) B29(0-3) B2C(0-3) B2D(0-3) B2C(0-3) B2D(0-3)
+                    const __m512i rhs_mat_2367ABEF_20_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_20, (_MM_PERM_ENUM)136); //B22(0-3) B23(0-3) B22(0-3) B23(0-3) B26(0-3) B27(0-3) B26(0-3) B27(0-3) B2A(0-3) B2B(0-3) B2A(0-3) B2B(0-3) B2E(0-3) B2F(0-3) B2E(0-3) B2F(0-3)
+
+                    const __m512i rhs_mat_014589CD_21_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_21, (_MM_PERM_ENUM)136); //B20(8-11) B21(8-11) B20(8-11) B21(8-11) B24(8-11) B25(8-11) B24(8-11) B25(8-11) B28(8-11) B29(8-11) B28(8-11) B29(8-11) B2C(8-11) B2D(8-11) B2C(8-11) B2D(8-11)
+                    const __m512i rhs_mat_2367ABEF_21_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_21, (_MM_PERM_ENUM)136); //B22(8-11) B23(8-11) B22(8-11) B23(8-11) B26(8-11) B27(8-11) B26(8-11) B27(8-11) B2A(8-11) B2B(8-11) B2A(8-11) B2B(8-11) B2E(8-11) B2F(8-11) B2E(8-11) B2F(8-11)
+
+                    const __m512i rhs_mat_014589CD_30_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_30, (_MM_PERM_ENUM)136); //B30(0-3) B31(0-3) B30(0-3) B31(0-3) B34(0-3) B35(0-3) B34(0-3) B35(0-3) B38(0-3) B39(0-3) B38(0-3) B39(0-3) B3C(0-3) B3D(0-3) B3C(0-3) B3D(0-3)
+                    const __m512i rhs_mat_2367ABEF_30_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_30, (_MM_PERM_ENUM)136); //B32(0-3) B33(0-3) B32(0-3) B33(0-3) B36(0-3) B37(0-3) B36(0-3) B37(0-3) B3A(0-3) B3B(0-3) B3A(0-3) B3B(0-3) B3E(0-3) B3F(0-3) B3E(0-3) B3F(0-3)
+
+                    const __m512i rhs_mat_014589CD_31_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_31, (_MM_PERM_ENUM)136); //B30(8-11) B31(8-11) B30(8-11) B31(8-11) B34(8-11) B35(8-11) B34(8-11) B35(8-11) B38(8-11) B39(8-11) B38(8-11) B39(8-11) B3C(8-11) B3D(8-11) B3C(8-11) B3D(8-11)
+                    const __m512i rhs_mat_2367ABEF_31_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_31, (_MM_PERM_ENUM)136); //B32(8-11) B33(8-11) B32(8-11) B33(8-11) B36(8-11) B37(8-11) B36(8-11) B37(8-11) B3A(8-11) B3B(8-11) B3A(8-11) B3B(8-11) B3E(8-11) B3F(8-11) B3E(8-11) B3F(8-11)
+
+                    const __m512i rhs_mat_014589CD_40_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_40, (_MM_PERM_ENUM)136); //B40(0-3) B41(0-3) B40(0-3) B41(0-3) B44(0-3) B45(0-3) B44(0-3) B45(0-3) B48(0-3) B49(0-3) B48(0-3) B49(0-3) B4C(0-3) B4D(0-3) B4C(0-3) B4D(0-3)
+                    const __m512i rhs_mat_2367ABEF_40_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_40, (_MM_PERM_ENUM)136); //B42(0-3) B43(0-3) B42(0-3) B43(0-3) B46(0-3) B47(0-3) B46(0-3) B47(0-3) B4A(0-3) B4B(0-3) B4A(0-3) B4B(0-3) B4E(0-3) B4F(0-3) B4E(0-3) B4F(0-3)
+
+                    const __m512i rhs_mat_014589CD_41_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_41, (_MM_PERM_ENUM)136); //B40(8-11) B41(8-11) B40(8-11) B41(8-11) B44(8-11) B45(8-11) B44(8-11) B45(8-11) B48(8-11) B49(8-11) B48(8-11) B49(8-11) B4C(8-11) B4D(8-11) B4C(8-11) B4D(8-11)
+                    const __m512i rhs_mat_2367ABEF_41_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_41, (_MM_PERM_ENUM)136); //B42(8-11) B43(8-11) B42(8-11) B43(8-11) B46(8-11) B47(8-11) B46(8-11) B47(8-11) B4A(8-11) B4B(8-11) B4A(8-11) B4B(8-11) B4E(8-11) B4F(8-11) B4E(8-11) B4F(8-11)
+
+                    const __m512i rhs_mat_014589CD_50_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_50, (_MM_PERM_ENUM)136); //B50(0-3) B51(0-3) B50(0-3) B51(0-3) B54(0-3) B55(0-3) B54(0-3) B55(0-3) B58(0-3) B59(0-3) B58(0-3) B59(0-3) B5C(0-3) B5D(0-3) B5C(0-3) B5D(0-3)
+                    const __m512i rhs_mat_2367ABEF_50_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_50, (_MM_PERM_ENUM)136); //B52(0-3) B53(0-3) B52(0-3) B53(0-3) B56(0-3) B57(0-3) B56(0-3) B57(0-3) B5A(0-3) B5B(0-3) B5A(0-3) B5B(0-3) B5E(0-3) B5F(0-3) B5E(0-3) B5F(0-3)
+
+                    const __m512i rhs_mat_014589CD_51_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_51, (_MM_PERM_ENUM)136); //B50(8-11) B51(8-11) B50(8-11) B51(8-11) B54(8-11) B55(8-11) B54(8-11) B55(8-11) B58(8-11) B59(8-11) B58(8-11) B59(8-11) B5C(8-11) B5D(8-11) B5C(8-11) B5D(8-11)
+                    const __m512i rhs_mat_2367ABEF_51_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_51, (_MM_PERM_ENUM)136); //B52(8-11) B53(8-11) B52(8-11) B53(8-11) B56(8-11) B57(8-11) B56(8-11) B57(8-11) B5A(8-11) B5B(8-11) B5A(8-11) B5B(8-11) B5E(8-11) B5F(8-11) B5E(8-11) B5F(8-11)
+
+                    const __m512i rhs_mat_014589CD_60_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_60, (_MM_PERM_ENUM)136); //B60(0-3) B61(0-3) B60(0-3) B61(0-3) B64(0-3) B65(0-3) B64(0-3) B65(0-3) B68(0-3) B69(0-3) B68(0-3) B69(0-3) B6C(0-3) B6D(0-3) B6C(0-3) B6D(0-3)
+                    const __m512i rhs_mat_2367ABEF_60_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_60, (_MM_PERM_ENUM)136); //B62(0-3) B63(0-3) B62(0-3) B63(0-3) B66(0-3) B67(0-3) B66(0-3) B67(0-3) B6A(0-3) B6B(0-3) B6A(0-3) B6B(0-3) B6E(0-3) B6F(0-3) B6E(0-3) B6F(0-3)
+
+                    const __m512i rhs_mat_014589CD_61_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_61, (_MM_PERM_ENUM)136); //B60(8-11) B61(8-11) B60(8-11) B61(8-11) B64(8-11) B65(8-11) B64(8-11) B65(8-11) B68(8-11) B69(8-11) B68(8-11) B69(8-11) B6C(8-11) B6D(8-11) B6C(8-11) B6D(8-11)
+                    const __m512i rhs_mat_2367ABEF_61_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_61, (_MM_PERM_ENUM)136); //B62(8-11) B63(8-11) B62(8-11) B63(8-11) B66(8-11) B67(8-11) B66(8-11) B67(8-11) B6A(8-11) B6B(8-11) B6A(8-11) B6B(8-11) B6E(8-11) B6F(8-11) B6E(8-11) B6F(8-11)
+
+                    const __m512i rhs_mat_014589CD_70_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_70, (_MM_PERM_ENUM)136); //B70(0-3) B71(0-3) B70(0-3) B71(0-3) B74(0-3) B75(0-3) B74(0-3) B75(0-3) B78(0-3) B79(0-3) B78(0-3) B79(0-3) B7C(0-3) B7D(0-3) B7C(0-3) B7D(0-3)
+                    const __m512i rhs_mat_2367ABEF_70_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_70, (_MM_PERM_ENUM)136); //B72(0-3) B73(0-3) B72(0-3) B73(0-3) B76(0-3) B77(0-3) B76(0-3) B77(0-3) B7A(0-3) B7B(0-3) B7A(0-3) B7B(0-3) B7E(0-3) B7F(0-3) B7E(0-3) B7F(0-3)
+
+                    const __m512i rhs_mat_014589CD_71_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_71, (_MM_PERM_ENUM)136); //B70(8-11) B71(8-11) B70(8-11) B71(8-11) B74(8-11) B75(8-11) B74(8-11) B75(8-11) B78(8-11) B79(8-11) B78(8-11) B79(8-11) B7C(8-11) B7D(8-11) B7C(8-11) B7D(8-11)
+                    const __m512i rhs_mat_2367ABEF_71_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_71, (_MM_PERM_ENUM)136); //B72(8-11) B73(8-11) B72(8-11) B73(8-11) B76(8-11) B77(8-11) B76(8-11) B77(8-11) B7A(8-11) B7B(8-11) B7A(8-11) B7B(8-11) B7E(8-11) B7F(8-11) B7E(8-11) B7F(8-11)
+
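+                    // Shuffle pattern two - right side input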
                     const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7)
                     const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7)
+
                     const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15)
                     const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15)
-                    const __m512i rhs_mat_014589CD_02_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23) B08(20-23) B09(20-23) B08(20-23) B09(20-23) B0C(20-23) B0D(20-23) B0C(20-23) B0D(20-23)
-                    const __m512i rhs_mat_2367ABEF_02_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23) B0A(20-23) B0B(20-23) B0A(20-23) B0B(20-23) B0E(20-23) B0F(20-23) B0E(20-23) B0F(20-23)
-                    const __m512i rhs_mat_014589CD_03_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31) B08(28-31) B09(28-31) B08(28-31) B09(28-31) B0C(28-31) B0D(28-31) B0C(28-31) 0BD(28-31)
-                    const __m512i rhs_mat_2367ABEF_03_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31) B0A(28-31) B0B(28-31) B0A(28-31) B0B(28-31) B0E(28-31) B0F(28-31) B0E(28-31) B0F(28-31)
 
                     const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7)
                     const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7)
+
                     const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15)
                     const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15)
-                    const __m512i rhs_mat_014589CD_12_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23) B18(20-23) B19(20-23) B18(20-23) B19(20-23) B1C(20-23) B1D(20-23) B1C(20-23) B1D(20-23)
-                    const __m512i rhs_mat_2367ABEF_12_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23) B1A(20-23) B1B(20-23) B1A(20-23) B1B(20-23) B1E(20-23) B1F(20-23) B1E(20-23) B1F(20-23)
-                    const __m512i rhs_mat_014589CD_13_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31) B18(28-31) B19(28-31) B18(28-31) B19(28-31) B1C(28-31) B1D(28-31) B1C(28-31) B1D(28-31)
-                    const __m512i rhs_mat_2367ABEF_13_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31) B1A(28-31) B1B(28-31) B1A(28-31) B1B(28-31) B1E(28-31) B1F(28-31) B1E(28-31) B1F(28-31)
 
-                    uint32_t utmp_00[4], utmp_01[4], utmp_10[4], utmp_11[4];
+                    const __m512i rhs_mat_014589CD_20_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_20, (_MM_PERM_ENUM)221); //B20(4-7) B21(4-7) B20(4-7) B21(4-7) B24(4-7) B25(4-7) B24(4-7) B25(4-7) B28(4-7) B29(4-7) B28(4-7) B29(4-7) B2C(4-7) B2D(4-7) B2C(4-7) B2D(4-7)
+                    const __m512i rhs_mat_2367ABEF_20_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_20, (_MM_PERM_ENUM)221); //B22(4-7) B23(4-7) B22(4-7) B23(4-7) B26(4-7) B27(4-7) B26(4-7) B27(4-7) B2A(4-7) B2B(4-7) B2A(4-7) B2B(4-7) B2E(4-7) B2F(4-7) B2E(4-7) B2F(4-7)
+
+                    const __m512i rhs_mat_014589CD_21_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_21, (_MM_PERM_ENUM)221); //B20(12-15) B21(12-15) B20(12-15) B21(12-15) B24(12-15) B25(12-15) B24(12-15) B25(12-15) B28(12-15) B29(12-15) B28(12-15) B29(12-15) B2C(12-15) B2D(12-15) B2C(12-15) B2D(12-15)
+                    const __m512i rhs_mat_2367ABEF_21_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_21, (_MM_PERM_ENUM)221); //B22(12-15) B23(12-15) B22(12-15) B23(12-15) B26(12-15) B27(12-15) B26(12-15) B27(12-15) B2A(12-15) B2B(12-15) B2A(12-15) B2B(12-15) B2E(12-15) B2F(12-15) B2E(12-15) B2F(12-15)
+
+                    const __m512i rhs_mat_014589CD_30_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_30, (_MM_PERM_ENUM)221); //B30(4-7) B31(4-7) B30(4-7) B31(4-7) B34(4-7) B35(4-7) B34(4-7) B35(4-7) B38(4-7) B39(4-7) B38(4-7) B39(4-7) B3C(4-7) B3D(4-7) B3C(4-7) B3D(4-7)
+                    const __m512i rhs_mat_2367ABEF_30_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_30, (_MM_PERM_ENUM)221); //B32(4-7) B33(4-7) B32(4-7) B33(4-7) B36(4-7) B37(4-7) B36(4-7) B37(4-7) B3A(4-7) B3B(4-7) B3A(4-7) B3B(4-7) B3E(4-7) B3F(4-7) B3E(4-7) B3F(4-7)
+
+                    const __m512i rhs_mat_014589CD_31_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_31, (_MM_PERM_ENUM)221); //B30(12-15) B31(12-15) B30(12-15) B31(12-15) B34(12-15) B35(12-15) B34(12-15) B35(12-15) B38(12-15) B39(12-15) B38(12-15) B39(12-15) B3C(12-15) B3D(12-15) B3C(12-15) B3D(12-15)
+                    const __m512i rhs_mat_2367ABEF_31_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_31, (_MM_PERM_ENUM)221); //B32(12-15) B33(12-15) B32(12-15) B33(12-15) B36(12-15) B37(12-15) B36(12-15) B37(12-15) B3A(12-15) B3B(12-15) B3A(12-15) B3B(12-15) B3E(12-15) B3F(12-15) B3E(12-15) B3F(12-15)
+
+                    const __m512i rhs_mat_014589CD_40_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_40, (_MM_PERM_ENUM)221); //B40(4-7) B41(4-7) B40(4-7) B41(4-7) B44(4-7) B45(4-7) B44(4-7) B45(4-7) B48(4-7) B49(4-7) B48(4-7) B49(4-7) B4C(4-7) B4D(4-7) B4C(4-7) B4D(4-7)
+                    const __m512i rhs_mat_2367ABEF_40_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_40, (_MM_PERM_ENUM)221); //B42(4-7) B43(4-7) B42(4-7) B43(4-7) B46(4-7) B47(4-7) B46(4-7) B47(4-7) B4A(4-7) B4B(4-7) B4A(4-7) B4B(4-7) B4E(4-7) B4F(4-7) B4E(4-7) B4F(4-7)
+
+                    const __m512i rhs_mat_014589CD_41_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_41, (_MM_PERM_ENUM)221); //B40(12-15) B41(12-15) B40(12-15) B41(12-15) B44(12-15) B45(12-15) B44(12-15) B45(12-15) B48(12-15) B49(12-15) B48(12-15) B49(12-15) B4C(12-15) B4D(12-15) B4C(12-15) B4D(12-15)
+                    const __m512i rhs_mat_2367ABEF_41_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_41, (_MM_PERM_ENUM)221); //B42(12-15) B43(12-15) B42(12-15) B43(12-15) B46(12-15) B47(12-15) B46(12-15) B47(12-15) B4A(12-15) B4B(12-15) B4A(12-15) B4B(12-15) B4E(12-15) B4F(12-15) B4E(12-15) B4F(12-15)
+
+                    const __m512i rhs_mat_014589CD_50_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_50, (_MM_PERM_ENUM)221); //B50(4-7) B51(4-7) B50(4-7) B51(4-7) B54(4-7) B55(4-7) B54(4-7) B55(4-7) B58(4-7) B59(4-7) B58(4-7) B59(4-7) B5C(4-7) B5D(4-7) B5C(4-7) B5D(4-7)
+                    const __m512i rhs_mat_2367ABEF_50_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_50, (_MM_PERM_ENUM)221); //B52(4-7) B53(4-7) B52(4-7) B53(4-7) B56(4-7) B57(4-7) B56(4-7) B57(4-7) B5A(4-7) B5B(4-7) B5A(4-7) B5B(4-7) B5E(4-7) B5F(4-7) B5E(4-7) B5F(4-7)
+
+                    const __m512i rhs_mat_014589CD_51_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_51, (_MM_PERM_ENUM)221); //B50(12-15) B51(12-15) B50(12-15) B51(12-15) B54(12-15) B55(12-15) B54(12-15) B55(12-15) B58(12-15) B59(12-15) B58(12-15) B59(12-15) B5C(12-15) B5D(12-15) B5C(12-15) B5D(12-15)
+                    const __m512i rhs_mat_2367ABEF_51_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_51, (_MM_PERM_ENUM)221); //B52(12-15) B53(12-15) B52(12-15) B53(12-15) B56(12-15) B57(12-15) B56(12-15) B57(12-15) B5A(12-15) B5B(12-15) B5A(12-15) B5B(12-15) B5E(12-15) B5F(12-15) B5E(12-15) B5F(12-15)
 
-                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
-                    // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
-                    memcpy(utmp_00, b_ptr_0[b].scales + 24 * sb, 12);
-                    utmp_00[3] = ((utmp_00[2] >> 4) & kmask2) | (((utmp_00[1] >> 6) & kmask3) << 4);
-                    const uint32_t uaux_00 = utmp_00[1] & kmask1;
-                    utmp_00[1] = (utmp_00[2] & kmask2) | (((utmp_00[0] >> 6) & kmask3) << 4);
-                    utmp_00[2] = uaux_00;
-                    utmp_00[0] &= kmask1;
+                    const __m512i rhs_mat_014589CD_60_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_60, (_MM_PERM_ENUM)221); //B60(4-7) B61(4-7) B60(4-7) B61(4-7) B64(4-7) B65(4-7) B64(4-7) B65(4-7) B68(4-7) B69(4-7) B68(4-7) B69(4-7) B6C(4-7) B6D(4-7) B6C(4-7) B6D(4-7)
+                    const __m512i rhs_mat_2367ABEF_60_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_60, (_MM_PERM_ENUM)221); //B62(4-7) B63(4-7) B62(4-7) B63(4-7) B66(4-7) B67(4-7) B66(4-7) B67(4-7) B6A(4-7) B6B(4-7) B6A(4-7) B6B(4-7) B6E(4-7) B6F(4-7) B6E(4-7) B6F(4-7)
 
-                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
-                    memcpy(utmp_01, b_ptr_0[b].scales + 12 + sb * 24, 12);
-                    utmp_01[3] = ((utmp_01[2] >> 4) & kmask2) | (((utmp_01[1] >> 6) & kmask3) << 4);
-                    const uint32_t uaux_01 = utmp_01[1] & kmask1;
-                    utmp_01[1] = (utmp_01[2] & kmask2) | (((utmp_01[0] >> 6) & kmask3) << 4);
-                    utmp_01[2] = uaux_01;
-                    utmp_01[0] &= kmask1;
+                    const __m512i rhs_mat_014589CD_61_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_61, (_MM_PERM_ENUM)221); //B60(12-15) B61(12-15) B60(12-15) B61(12-15) B64(12-15) B65(12-15) B64(12-15) B65(12-15) B68(12-15) B69(12-15) B68(12-15) B69(12-15) B6C(12-15) B6D(12-15) B6C(12-15) B6D(12-15)
+                    const __m512i rhs_mat_2367ABEF_61_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_61, (_MM_PERM_ENUM)221); //B62(12-15) B63(12-15) B62(12-15) B63(12-15) B66(12-15) B67(12-15) B66(12-15) B67(12-15) B6A(12-15) B6B(12-15) B6A(12-15) B6B(12-15) B6E(12-15) B6F(12-15) B6E(12-15) B6F(12-15)
 
-                    memcpy(utmp_10, b_ptr_1[b].scales + sb * 24, 12);
-                    utmp_10[3] = ((utmp_10[2] >> 4) & kmask2) | (((utmp_10[1] >> 6) & kmask3) << 4);
-                    const uint32_t uaux_10 = utmp_10[1] & kmask1;
-                    utmp_10[1] = (utmp_10[2] & kmask2) | (((utmp_10[0] >> 6) & kmask3) << 4);
-                    utmp_10[2] = uaux_10;
-                    utmp_10[0] &= kmask1;
+                    const __m512i rhs_mat_014589CD_70_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_70, (_MM_PERM_ENUM)221); //B70(4-7) B71(4-7) B70(4-7) B71(4-7) B74(4-7) B75(4-7) B74(4-7) B75(4-7) B78(4-7) B79(4-7) B78(4-7) B79(4-7) B7C(4-7) B7D(4-7) B7C(4-7) B7D(4-7)
+                    const __m512i rhs_mat_2367ABEF_70_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_70, (_MM_PERM_ENUM)221); //B72(4-7) B73(4-7) B72(4-7) B73(4-7) B76(4-7) B77(4-7) B76(4-7) B77(4-7) B7A(4-7) B7B(4-7) B7A(4-7) B7B(4-7) B7E(4-7) B7F(4-7) B7E(4-7) B7F(4-7)
 
-                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
-                    memcpy(utmp_11, b_ptr_1[b].scales + 12 + sb * 24, 12);
-                    utmp_11[3] = ((utmp_11[2] >> 4) & kmask2) | (((utmp_11[1] >> 6) & kmask3) << 4);
-                    const uint32_t uaux_11 = utmp_11[1] & kmask1;
-                    utmp_11[1] = (utmp_11[2] & kmask2) | (((utmp_11[0] >> 6) & kmask3) << 4);
-                    utmp_11[2] = uaux_11;
-                    utmp_11[0] &= kmask1;
+                    const __m512i rhs_mat_014589CD_71_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_71, (_MM_PERM_ENUM)221); //B70(12-15) B71(12-15) B70(12-15) B71(12-15) B74(12-15) B75(12-15) B74(12-15) B75(12-15) B78(12-15) B79(12-15) B78(12-15) B79(12-15) B7C(12-15) B7D(12-15) B7C(12-15) B7D(12-15)
+                    const __m512i rhs_mat_2367ABEF_71_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_71, (_MM_PERM_ENUM)221); //B72(12-15) B73(12-15) B72(12-15) B73(12-15) B76(12-15) B77(12-15) B76(12-15) B77(12-15) B7A(12-15) B7B(12-15) B7A(12-15) B7B(12-15) B7E(12-15) B7F(12-15) B7E(12-15) B7F(12-15)
 
-                    // Scales of first sub block in the sb loop
-                    const __m256i mins_and_scales_0 = _mm256_set_epi32(utmp_10[3], utmp_10[2], utmp_10[1], utmp_10[0], utmp_00[3], utmp_00[2], utmp_00[1], utmp_00[0]);
-                    const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
+                    // notation: superblock subblock
+                    //s00 m00  s01 m01   s10 m10  s11 m11  s20 m20  s21 m21   s30 m30  s31 m31  s40 m40  s41 m41   s50 m50  s51 m51  s60 m60  s61 m61   s70 m70  s71 m71
 
-                    // Scales of second sub block in the sb loop
-                    const __m256i mins_and_scales_1 = _mm256_set_epi32(utmp_11[3], utmp_11[2], utmp_11[1], utmp_11[0], utmp_01[3], utmp_01[2], utmp_01[1], utmp_01[0]);
-                    const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
+                    const __m128i mins_and_scales_01_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + sb * 64));
+                    const __m128i mins_and_scales_23_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 16 + sb * 64));
+                    const __m128i mins_and_scales_45_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 32 + sb * 64));
+                    const __m128i mins_and_scales_67_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 48 + sb * 64));
 
-                    // Mins of first and second sub block of Q4_K block are arranged side by side
-                    const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(_mm256_shuffle_epi32(mins_and_scales_0, 78), _mm256_shuffle_epi32(mins_and_scales_1, 78)));
+                    const __m128i mins_and_scales_01_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + sb * 64));
+                    const __m128i mins_and_scales_23_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 16 + sb * 64));
+                    const __m128i mins_and_scales_45_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 32 + sb * 64));
+                    const __m128i mins_and_scales_67_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 48 + sb * 64));
+
+                    // Combine mins and scales for sub-blocks: 0-1, 2-3, 4-5, 6-7 in the sb loop
+                    const __m256i mins_and_scales_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_01_0), mins_and_scales_01_1, 1);
+                    const __m256i mins_and_scales_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_23_0), mins_and_scales_23_1, 1);
+                    const __m256i mins_and_scales_45 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_45_0), mins_and_scales_45_1, 1);
+                    const __m256i mins_and_scales_67 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_67_0), mins_and_scales_67_1, 1);
+
+                    // Extract the scales, which are the lower nibbles of mins_and_scales
+                    const __m256i scales_01 = _mm256_and_si256(mins_and_scales_01, m4b);
+                    const __m256i scales_23 = _mm256_and_si256(mins_and_scales_23, m4b);
+                    const __m256i scales_45 = _mm256_and_si256(mins_and_scales_45, m4b);
+                    const __m256i scales_67 = _mm256_and_si256(mins_and_scales_67, m4b);
+
+                    // Extract the mins, which are the upper nibbles of mins_and_scales
+                    const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_01, 4), m4b));
+                    const __m512i mins_23 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_23, 4), m4b));
+                    const __m512i mins_45 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_45, 4), m4b));
+                    const __m512i mins_67 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_67, 4), m4b));
+
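+                    // De-interleave the per-sub-block scales: scalesmask1 picks the even-indexed scale bytes, scalesmask2 the odd-indexed ones, each byte duplicated before widening to 16 bits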
+                    const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_01,scalesmask1));
+                    const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_01,scalesmask2));
+                    const __m512i scales_2 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_23,scalesmask1));
+                    const __m512i scales_3 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_23,scalesmask2));
+                    const __m512i scales_4 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_45,scalesmask1));
+                    const __m512i scales_5 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_45,scalesmask2));
+                    const __m512i scales_6 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_67,scalesmask1));
+                    const __m512i scales_7 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_67,scalesmask2));
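+                    // The 68/238 shuffles below duplicate the scale dword pairs so that scale_014589CD_* holds the scales for the {0,1,4,5,8,9,C,D} column group and scale_2367ABEF_* those for {2,3,6,7,A,B,E,F}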
 
                     const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68);
                     const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238);
@@ -2002,116 +3747,330 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68);
                     const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238);
 
+                    const __m512i scale_014589CD_2 = _mm512_shuffle_epi32(scales_2, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_2 = _mm512_shuffle_epi32(scales_2, (_MM_PERM_ENUM)238);
+
+                    const __m512i scale_014589CD_3 = _mm512_shuffle_epi32(scales_3, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_3 = _mm512_shuffle_epi32(scales_3, (_MM_PERM_ENUM)238);
+
+                    const __m512i scale_014589CD_4 = _mm512_shuffle_epi32(scales_4, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_4 = _mm512_shuffle_epi32(scales_4, (_MM_PERM_ENUM)238);
+
+                    const __m512i scale_014589CD_5 = _mm512_shuffle_epi32(scales_5, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_5 = _mm512_shuffle_epi32(scales_5, (_MM_PERM_ENUM)238);
+
+                    const __m512i scale_014589CD_6 = _mm512_shuffle_epi32(scales_6, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_6 = _mm512_shuffle_epi32(scales_6, (_MM_PERM_ENUM)238);
+
+                    const __m512i scale_014589CD_7 = _mm512_shuffle_epi32(scales_7, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_7 = _mm512_shuffle_epi32(scales_7, (_MM_PERM_ENUM)238);
+
+
                     for (int rp = 0; rp < 4; rp++) {
 
                         // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
                         // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector
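+                        // Each sb iteration spans 512 bytes of q8_K data (128 quantized values x 4 interleaved rows), hence the 512 * sb offsets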
-                        __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb)));
+                        __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 512 * sb)));
                         __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0);
                         __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17);
-                        __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb)));
+                        __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 512 * sb)));
                         __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0);
                         __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17);
-                        __m256i lhs_mat_ymm_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb)));
-                        __m256i lhs_mat_ymm_01_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 0);
-                        __m256i lhs_mat_ymm_23_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 17);
-                        __m256i lhs_mat_ymm_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb)));
-                        __m256i lhs_mat_ymm_01_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 0);
-                        __m256i lhs_mat_ymm_23_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 17);
-                        __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb)));
+                        __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 512 * sb)));
                         __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0);
                         __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17);
-                        __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb)));
+                        __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 512 * sb)));
                         __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0);
                         __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17);
-                        __m256i lhs_mat_ymm_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb)));
-                        __m256i lhs_mat_ymm_01_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 0);
-                        __m256i lhs_mat_ymm_23_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 17);
-                        __m256i lhs_mat_ymm_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb)));
-                        __m256i lhs_mat_ymm_01_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 0);
-                        __m256i lhs_mat_ymm_23_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 17);
+                        __m256i lhs_mat_ymm_0123_20 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 512 * sb)));
+                        __m256i lhs_mat_ymm_01_20 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_20, lhs_mat_ymm_0123_20, 0);
+                        __m256i lhs_mat_ymm_23_20 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_20, lhs_mat_ymm_0123_20, 17);
+                        __m256i lhs_mat_ymm_0123_21 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 512 * sb)));
+                        __m256i lhs_mat_ymm_01_21 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_21, lhs_mat_ymm_0123_21, 0);
+                        __m256i lhs_mat_ymm_23_21 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_21, lhs_mat_ymm_0123_21, 17);
+                        __m256i lhs_mat_ymm_0123_30 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 512 * sb)));
+                        __m256i lhs_mat_ymm_01_30 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_30, lhs_mat_ymm_0123_30, 0);
+                        __m256i lhs_mat_ymm_23_30 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_30, lhs_mat_ymm_0123_30, 17);
+                        __m256i lhs_mat_ymm_0123_31 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 512 * sb)));
+                        __m256i lhs_mat_ymm_01_31 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_31, lhs_mat_ymm_0123_31, 0);
+                        __m256i lhs_mat_ymm_23_31 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_31, lhs_mat_ymm_0123_31, 17);
+
+                        __m256i lhs_mat_ymm_0123_40 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 + 512 * sb)));
+                        __m256i lhs_mat_ymm_01_40 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_40, lhs_mat_ymm_0123_40, 0);
+                        __m256i lhs_mat_ymm_23_40 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_40, lhs_mat_ymm_0123_40, 17);
+                        __m256i lhs_mat_ymm_0123_41 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 288 + 512 * sb)));
+                        __m256i lhs_mat_ymm_01_41 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_41, lhs_mat_ymm_0123_41, 0);
+                        __m256i lhs_mat_ymm_23_41 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_41, lhs_mat_ymm_0123_41, 17);
+                        __m256i lhs_mat_ymm_0123_50 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 320 + 512 * sb)));
+                        __m256i lhs_mat_ymm_01_50 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_50, lhs_mat_ymm_0123_50, 0);
+                        __m256i lhs_mat_ymm_23_50 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_50, lhs_mat_ymm_0123_50, 17);
+                        __m256i lhs_mat_ymm_0123_51 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 352 + 512 * sb)));
+                        __m256i lhs_mat_ymm_01_51 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_51, lhs_mat_ymm_0123_51, 0);
+                        __m256i lhs_mat_ymm_23_51 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_51, lhs_mat_ymm_0123_51, 17);
+                        __m256i lhs_mat_ymm_0123_60 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 384 + 512 * sb)));
+                        __m256i lhs_mat_ymm_01_60 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_60, lhs_mat_ymm_0123_60, 0);
+                        __m256i lhs_mat_ymm_23_60 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_60, lhs_mat_ymm_0123_60, 17);
+                        __m256i lhs_mat_ymm_0123_61 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 416 + 512 * sb)));
+                        __m256i lhs_mat_ymm_01_61 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_61, lhs_mat_ymm_0123_61, 0);
+                        __m256i lhs_mat_ymm_23_61 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_61, lhs_mat_ymm_0123_61, 17);
+                        __m256i lhs_mat_ymm_0123_70 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 448 + 512 * sb)));
+                        __m256i lhs_mat_ymm_01_70 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_70, lhs_mat_ymm_0123_70, 0);
+                        __m256i lhs_mat_ymm_23_70 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_70, lhs_mat_ymm_0123_70, 17);
+                        __m256i lhs_mat_ymm_0123_71 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 480 + 512 * sb)));
+                        __m256i lhs_mat_ymm_01_71 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_71, lhs_mat_ymm_0123_71, 0);
+                        __m256i lhs_mat_ymm_23_71 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_71, lhs_mat_ymm_0123_71, 17);
+
 
                         __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1);
                         __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1);
                         __m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1);
                         __m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1);
-                        __m512i lhs_mat_01_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_02), lhs_mat_ymm_01_02, 1);
-                        __m512i lhs_mat_23_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_02), lhs_mat_ymm_23_02, 1);
-                        __m512i lhs_mat_01_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_03), lhs_mat_ymm_01_03, 1);
-                        __m512i lhs_mat_23_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_03), lhs_mat_ymm_23_03, 1);
 
                         __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1);
                         __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1);
                         __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1);
                         __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1);
-                        __m512i lhs_mat_01_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_12), lhs_mat_ymm_01_12, 1);
-                        __m512i lhs_mat_23_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_12), lhs_mat_ymm_23_12, 1);
-                        __m512i lhs_mat_01_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_13), lhs_mat_ymm_01_13, 1);
-                        __m512i lhs_mat_23_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_13), lhs_mat_ymm_23_13, 1);
 
-                        // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
-                        __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));
-                        __m256i lhs_bsums_hsum_ymm_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
-                        lhs_bsums_hsum_ymm_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_ymm_0123_01, lhs_bsums_hsum_ymm_0123_01, 0);
-                        __m512i lhs_bsums_hsum_0123_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_hsum_ymm_0123_01), lhs_bsums_hsum_ymm_0123_01, 1);
+                        __m512i lhs_mat_01_20 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_20), lhs_mat_ymm_01_20, 1);
+                        __m512i lhs_mat_23_20 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_20), lhs_mat_ymm_23_20, 1);
+                        __m512i lhs_mat_01_21 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_21), lhs_mat_ymm_01_21, 1);
+                        __m512i lhs_mat_23_21 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_21), lhs_mat_ymm_23_21, 1);
+
+                        __m512i lhs_mat_01_30 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_30), lhs_mat_ymm_01_30, 1);
+                        __m512i lhs_mat_23_30 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_30), lhs_mat_ymm_23_30, 1);
+                        __m512i lhs_mat_01_31 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_31), lhs_mat_ymm_01_31, 1);
+                        __m512i lhs_mat_23_31 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_31), lhs_mat_ymm_23_31, 1);
+
+                        __m512i lhs_mat_01_40 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_40), lhs_mat_ymm_01_40, 1);
+                        __m512i lhs_mat_23_40 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_40), lhs_mat_ymm_23_40, 1);
+                        __m512i lhs_mat_01_41 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_41), lhs_mat_ymm_01_41, 1);
+                        __m512i lhs_mat_23_41 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_41), lhs_mat_ymm_23_41, 1);
+
+                        __m512i lhs_mat_01_50 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_50), lhs_mat_ymm_01_50, 1);
+                        __m512i lhs_mat_23_50 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_50), lhs_mat_ymm_23_50, 1);
+                        __m512i lhs_mat_01_51 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_51), lhs_mat_ymm_01_51, 1);
+                        __m512i lhs_mat_23_51 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_51), lhs_mat_ymm_23_51, 1);
+
+                        __m512i lhs_mat_01_60 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_60), lhs_mat_ymm_01_60, 1);
+                        __m512i lhs_mat_23_60 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_60), lhs_mat_ymm_23_60, 1);
+                        __m512i lhs_mat_01_61 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_61), lhs_mat_ymm_01_61, 1);
+                        __m512i lhs_mat_23_61 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_61), lhs_mat_ymm_23_61, 1);
+
+                        __m512i lhs_mat_01_70 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_70), lhs_mat_ymm_01_70, 1);
+                        __m512i lhs_mat_23_70 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_70), lhs_mat_ymm_23_70, 1);
+                        __m512i lhs_mat_01_71 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_71), lhs_mat_ymm_01_71, 1);
+                        __m512i lhs_mat_23_71 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_71), lhs_mat_ymm_23_71, 1);
+
+                        // Bsums are loaded for the different Q8_K blocks
+                        __m128i lhs_raw_bsums_01_0123 = _mm_loadu_si128((const __m128i *)((a_ptrs[rp][b].bsums + 32 * sb)));
+                        __m128i lhs_raw_bsums_23_0123 = _mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].bsums + 8 + 32 * sb));
+                        __m128i lhs_raw_bsums_01_4567 = _mm_loadu_si128((const __m128i *)((a_ptrs[rp][b].bsums + 16 + 32 * sb)));
+                        __m128i lhs_raw_bsums_23_4567 = _mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].bsums + 24 + 32 * sb));
+
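+                        // Duplicate each 128-bit group of bsums across the full 512-bit width for multiplication with the mins values below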
+                        __m256i lhs_bsums_ymm_01_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_0123), lhs_raw_bsums_01_0123, 1);
+                        __m512i lhs_bsums_01_0123 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_01_0123), lhs_bsums_ymm_01_0123, 1);
+                        __m256i lhs_bsums_ymm_23_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_0123), lhs_raw_bsums_23_0123, 1);
+                        __m512i lhs_bsums_23_0123 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_23_0123), lhs_bsums_ymm_23_0123, 1);
+                        __m256i lhs_bsums_ymm_01_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_4567), lhs_raw_bsums_01_4567, 1);
+                        __m512i lhs_bsums_01_4567 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_01_4567), lhs_bsums_ymm_01_4567, 1);
+                        __m256i lhs_bsums_ymm_23_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_4567), lhs_raw_bsums_23_4567, 1);
+                        __m512i lhs_bsums_23_4567 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_23_4567), lhs_bsums_ymm_23_4567, 1);
 
                         // Shuffle pattern one - left side input
                         const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
                         const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3)
+
                         const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
                         const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11)
-                        const __m512i lhs_mat_01_02_sp1 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
-                        const __m512i lhs_mat_23_02_sp1 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)160); //A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19)
-                        const __m512i lhs_mat_01_03_sp1 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
-                        const __m512i lhs_mat_23_03_sp1 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)160); //A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27)
 
                         const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
                         const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3)
+
                         const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
                         const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11)
-                        const __m512i lhs_mat_01_12_sp1 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
-                        const __m512i lhs_mat_23_12_sp1 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)160); //A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19)
-                        const __m512i lhs_mat_01_13_sp1 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
-                        const __m512i lhs_mat_23_13_sp1 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)160); //A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27)
+
+                        const __m512i lhs_mat_01_20_sp1 = _mm512_shuffle_epi32(lhs_mat_01_20, (_MM_PERM_ENUM)160); //A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3)
+                        const __m512i lhs_mat_23_20_sp1 = _mm512_shuffle_epi32(lhs_mat_23_20, (_MM_PERM_ENUM)160); //A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3)
+
+                        const __m512i lhs_mat_01_21_sp1 = _mm512_shuffle_epi32(lhs_mat_01_21, (_MM_PERM_ENUM)160); //A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11)
+                        const __m512i lhs_mat_23_21_sp1 = _mm512_shuffle_epi32(lhs_mat_23_21, (_MM_PERM_ENUM)160); //A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11)
+
+                        const __m512i lhs_mat_01_30_sp1 = _mm512_shuffle_epi32(lhs_mat_01_30, (_MM_PERM_ENUM)160); //A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3)
+                        const __m512i lhs_mat_23_30_sp1 = _mm512_shuffle_epi32(lhs_mat_23_30, (_MM_PERM_ENUM)160); //A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3)
+
+                        const __m512i lhs_mat_01_31_sp1 = _mm512_shuffle_epi32(lhs_mat_01_31, (_MM_PERM_ENUM)160); //A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11)
+                        const __m512i lhs_mat_23_31_sp1 = _mm512_shuffle_epi32(lhs_mat_23_31, (_MM_PERM_ENUM)160); //A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) A33(8-11)
+
+                        const __m512i lhs_mat_01_40_sp1 = _mm512_shuffle_epi32(lhs_mat_01_40, (_MM_PERM_ENUM)160); //A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3)
+                        const __m512i lhs_mat_23_40_sp1 = _mm512_shuffle_epi32(lhs_mat_23_40, (_MM_PERM_ENUM)160); //A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3)
+
+                        const __m512i lhs_mat_01_41_sp1 = _mm512_shuffle_epi32(lhs_mat_01_41, (_MM_PERM_ENUM)160); //A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11)
+                        const __m512i lhs_mat_23_41_sp1 = _mm512_shuffle_epi32(lhs_mat_23_41, (_MM_PERM_ENUM)160); //A42(8-11) A42(8-11) A43(8-11) A43(8-11) A42(8-11) A42(8-11) A43(8-11) A43(8-11) A42(8-11) A42(8-11) A43(8-11) A43(8-11) A42(8-11) A42(8-11) A43(8-11) A43(8-11)
+
+                        const __m512i lhs_mat_01_50_sp1 = _mm512_shuffle_epi32(lhs_mat_01_50, (_MM_PERM_ENUM)160); //A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3)
+                        const __m512i lhs_mat_23_50_sp1 = _mm512_shuffle_epi32(lhs_mat_23_50, (_MM_PERM_ENUM)160); //A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3)
+
+                        const __m512i lhs_mat_01_51_sp1 = _mm512_shuffle_epi32(lhs_mat_01_51, (_MM_PERM_ENUM)160); //A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11)
+                        const __m512i lhs_mat_23_51_sp1 = _mm512_shuffle_epi32(lhs_mat_23_51, (_MM_PERM_ENUM)160); //A52(8-11) A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11)
+
+                        const __m512i lhs_mat_01_60_sp1 = _mm512_shuffle_epi32(lhs_mat_01_60, (_MM_PERM_ENUM)160); //A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3)
+                        const __m512i lhs_mat_23_60_sp1 = _mm512_shuffle_epi32(lhs_mat_23_60, (_MM_PERM_ENUM)160); //A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3)
+
+                        const __m512i lhs_mat_01_61_sp1 = _mm512_shuffle_epi32(lhs_mat_01_61, (_MM_PERM_ENUM)160); //A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11)
+                        const __m512i lhs_mat_23_61_sp1 = _mm512_shuffle_epi32(lhs_mat_23_61, (_MM_PERM_ENUM)160); //A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11)
+
+                        const __m512i lhs_mat_01_70_sp1 = _mm512_shuffle_epi32(lhs_mat_01_70, (_MM_PERM_ENUM)160); //A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3)
+                        const __m512i lhs_mat_23_70_sp1 = _mm512_shuffle_epi32(lhs_mat_23_70, (_MM_PERM_ENUM)160); //A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3)
+
+                        const __m512i lhs_mat_01_71_sp1 = _mm512_shuffle_epi32(lhs_mat_01_71, (_MM_PERM_ENUM)160); //A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11)
+                        const __m512i lhs_mat_23_71_sp1 = _mm512_shuffle_epi32(lhs_mat_23_71, (_MM_PERM_ENUM)160); //A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11)
 
                         const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
                         const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7)
+
                         const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
                         const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15)
-                        const __m512i lhs_mat_01_02_sp2 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
-                        const __m512i lhs_mat_23_02_sp2 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)245); //A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23)
-                        const __m512i lhs_mat_01_03_sp2 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
-                        const __m512i lhs_mat_23_03_sp2 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)245); //A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31)
 
                         const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
                         const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7)
+
                         const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
                         const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15)
-                        const __m512i lhs_mat_01_12_sp2 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
-                        const __m512i lhs_mat_23_12_sp2 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)245); //A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23)
-                        const __m512i lhs_mat_01_13_sp2 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
-                        const __m512i lhs_mat_23_13_sp2 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)245); //A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31)
+
+                        const __m512i lhs_mat_01_20_sp2 = _mm512_shuffle_epi32(lhs_mat_01_20, (_MM_PERM_ENUM)245); //A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7)
+                        const __m512i lhs_mat_23_20_sp2 = _mm512_shuffle_epi32(lhs_mat_23_20, (_MM_PERM_ENUM)245); //A22(4-7) A22(4-7) A23(4-7) A23(4-7) A22(4-7) A22(4-7) A23(4-7) A23(4-7) A22(4-7) A22(4-7) A23(4-7) A23(4-7) A22(4-7) A22(4-7) A23(4-7) A23(4-7)
+
+                        const __m512i lhs_mat_01_21_sp2 = _mm512_shuffle_epi32(lhs_mat_01_21, (_MM_PERM_ENUM)245); //A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15)
+                        const __m512i lhs_mat_23_21_sp2 = _mm512_shuffle_epi32(lhs_mat_23_21, (_MM_PERM_ENUM)245); //A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15)
+
+                        const __m512i lhs_mat_01_30_sp2 = _mm512_shuffle_epi32(lhs_mat_01_30, (_MM_PERM_ENUM)245); //A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7)
+                        const __m512i lhs_mat_23_30_sp2 = _mm512_shuffle_epi32(lhs_mat_23_30, (_MM_PERM_ENUM)245); //A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7)
+
+                        const __m512i lhs_mat_01_31_sp2 = _mm512_shuffle_epi32(lhs_mat_01_31, (_MM_PERM_ENUM)245); //A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15)
+                        const __m512i lhs_mat_23_31_sp2 = _mm512_shuffle_epi32(lhs_mat_23_31, (_MM_PERM_ENUM)245); //A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15)
+
+                        const __m512i lhs_mat_01_40_sp2 = _mm512_shuffle_epi32(lhs_mat_01_40, (_MM_PERM_ENUM)245); //A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7)
+                        const __m512i lhs_mat_23_40_sp2 = _mm512_shuffle_epi32(lhs_mat_23_40, (_MM_PERM_ENUM)245); //A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7)
+
+                        const __m512i lhs_mat_01_41_sp2 = _mm512_shuffle_epi32(lhs_mat_01_41, (_MM_PERM_ENUM)245); //A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15)
+                        const __m512i lhs_mat_23_41_sp2 = _mm512_shuffle_epi32(lhs_mat_23_41, (_MM_PERM_ENUM)245); //A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15)
+
+                        const __m512i lhs_mat_01_50_sp2 = _mm512_shuffle_epi32(lhs_mat_01_50, (_MM_PERM_ENUM)245); //A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7)
+                        const __m512i lhs_mat_23_50_sp2 = _mm512_shuffle_epi32(lhs_mat_23_50, (_MM_PERM_ENUM)245); //A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7)
+
+                        const __m512i lhs_mat_01_51_sp2 = _mm512_shuffle_epi32(lhs_mat_01_51, (_MM_PERM_ENUM)245); //A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15)
+                        const __m512i lhs_mat_23_51_sp2 = _mm512_shuffle_epi32(lhs_mat_23_51, (_MM_PERM_ENUM)245); //A52(12-15) A52(12-15) A53(12-15) A53(12-15) A52(12-15) A52(12-15) A53(12-15) A53(12-15) A52(12-15) A52(12-15) A53(12-15) A53(12-15) A52(12-15) A52(12-15) A53(12-15) A53(12-15)
+
+                        const __m512i lhs_mat_01_60_sp2 = _mm512_shuffle_epi32(lhs_mat_01_60, (_MM_PERM_ENUM)245); //A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7)
+                        const __m512i lhs_mat_23_60_sp2 = _mm512_shuffle_epi32(lhs_mat_23_60, (_MM_PERM_ENUM)245); //A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7)
+
+                        const __m512i lhs_mat_01_61_sp2 = _mm512_shuffle_epi32(lhs_mat_01_61, (_MM_PERM_ENUM)245); //A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15)
+                        const __m512i lhs_mat_23_61_sp2 = _mm512_shuffle_epi32(lhs_mat_23_61, (_MM_PERM_ENUM)245); //A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) A62(12-15) A63(12-15) A63(12-15)
+
+                        const __m512i lhs_mat_01_70_sp2 = _mm512_shuffle_epi32(lhs_mat_01_70, (_MM_PERM_ENUM)245); //A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7)
+                        const __m512i lhs_mat_23_70_sp2 = _mm512_shuffle_epi32(lhs_mat_23_70, (_MM_PERM_ENUM)245); //A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7)
+
+                        const __m512i lhs_mat_01_71_sp2 = _mm512_shuffle_epi32(lhs_mat_01_71, (_MM_PERM_ENUM)245); //A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15)
+                        const __m512i lhs_mat_23_71_sp2 = _mm512_shuffle_epi32(lhs_mat_23_71, (_MM_PERM_ENUM)245); //A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15)
 
                         // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
-                        __m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1));
-                        __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1));
-                        __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1));
-                        __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1));
-                        __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1));
-                        __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1));
-                        __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1));
-                        __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1));
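+                        // For each sub-block, sum the byte dot products of its two LHS/RHS register pairs (shuffle pattern one)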
+                        __m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1));
+                        __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1));
 
-                        __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2));
-                        __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2));
-                        __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2));
-                        __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2));
-                        __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2));
-                        __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2));
-                        __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2));
-                        __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2));
+                        __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1));
+                        __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1));
 
-                        // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
+                        __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1));
+                        __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1));
+
+                        __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1));
+                        __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1));
+
+                        __m512i iacc_mat_00_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp1, lhs_mat_01_20_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp1, lhs_mat_01_21_sp1));
+                        __m512i iacc_mat_01_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp1, lhs_mat_01_20_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp1, lhs_mat_01_21_sp1));
+
+                        __m512i iacc_mat_10_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp1, lhs_mat_23_20_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp1, lhs_mat_23_21_sp1));
+                        __m512i iacc_mat_11_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp1, lhs_mat_23_20_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp1, lhs_mat_23_21_sp1));
+
+                        __m512i iacc_mat_00_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp1, lhs_mat_01_30_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp1, lhs_mat_01_31_sp1));
+                        __m512i iacc_mat_01_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp1, lhs_mat_01_30_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp1, lhs_mat_01_31_sp1));
+
+                        __m512i iacc_mat_10_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp1, lhs_mat_23_30_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp1, lhs_mat_23_31_sp1));
+                        __m512i iacc_mat_11_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp1, lhs_mat_23_30_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp1, lhs_mat_23_31_sp1));
+
+                        __m512i iacc_mat_00_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp1, lhs_mat_01_40_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp1, lhs_mat_01_41_sp1));
+                        __m512i iacc_mat_01_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp1, lhs_mat_01_40_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp1, lhs_mat_01_41_sp1));
+
+                        __m512i iacc_mat_10_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp1, lhs_mat_23_40_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp1, lhs_mat_23_41_sp1));
+                        __m512i iacc_mat_11_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp1, lhs_mat_23_40_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp1, lhs_mat_23_41_sp1));
+
+                        __m512i iacc_mat_00_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp1, lhs_mat_01_50_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp1, lhs_mat_01_51_sp1));
+                        __m512i iacc_mat_01_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp1, lhs_mat_01_50_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp1, lhs_mat_01_51_sp1));
+
+                        __m512i iacc_mat_10_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp1, lhs_mat_23_50_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp1, lhs_mat_23_51_sp1));
+                        __m512i iacc_mat_11_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp1, lhs_mat_23_50_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp1, lhs_mat_23_51_sp1));
+
+                        __m512i iacc_mat_00_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp1, lhs_mat_01_60_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp1, lhs_mat_01_61_sp1));
+                        __m512i iacc_mat_01_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp1, lhs_mat_01_60_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp1, lhs_mat_01_61_sp1));
+
+                        __m512i iacc_mat_10_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp1, lhs_mat_23_60_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp1, lhs_mat_23_61_sp1));
+                        __m512i iacc_mat_11_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp1, lhs_mat_23_60_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp1, lhs_mat_23_61_sp1));
+
+                        __m512i iacc_mat_00_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp1, lhs_mat_01_70_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp1, lhs_mat_01_71_sp1));
+                        __m512i iacc_mat_01_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp1, lhs_mat_01_70_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp1, lhs_mat_01_71_sp1));
+
+                        __m512i iacc_mat_10_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp1, lhs_mat_23_70_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp1, lhs_mat_23_71_sp1));
+                        __m512i iacc_mat_11_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp1, lhs_mat_23_70_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp1, lhs_mat_23_71_sp1));
+
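+                        // Repeat the dot products for shuffle pattern two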
+                        __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2));
+                        __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2));
+
+                        __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2));
+                        __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2));
+
+                        __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2));
+                        __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2));
+
+                        __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2));
+                        __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2));
+
+                        __m512i iacc_mat_00_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp2, lhs_mat_01_20_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp2, lhs_mat_01_21_sp2));
+                        __m512i iacc_mat_01_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp2, lhs_mat_01_20_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp2, lhs_mat_01_21_sp2));
+
+                        __m512i iacc_mat_10_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp2, lhs_mat_23_20_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp2, lhs_mat_23_21_sp2));
+                        __m512i iacc_mat_11_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp2, lhs_mat_23_20_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp2, lhs_mat_23_21_sp2));
+
+                        __m512i iacc_mat_00_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp2, lhs_mat_01_30_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp2, lhs_mat_01_31_sp2));
+                        __m512i iacc_mat_01_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp2, lhs_mat_01_30_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp2, lhs_mat_01_31_sp2));
+
+                        __m512i iacc_mat_10_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp2, lhs_mat_23_30_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp2, lhs_mat_23_31_sp2));
+                        __m512i iacc_mat_11_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp2, lhs_mat_23_30_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp2, lhs_mat_23_31_sp2));
+
+                        __m512i iacc_mat_00_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp2, lhs_mat_01_40_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp2, lhs_mat_01_41_sp2));
+                        __m512i iacc_mat_01_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp2, lhs_mat_01_40_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp2, lhs_mat_01_41_sp2));
+
+                        __m512i iacc_mat_10_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp2, lhs_mat_23_40_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp2, lhs_mat_23_41_sp2));
+                        __m512i iacc_mat_11_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp2, lhs_mat_23_40_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp2, lhs_mat_23_41_sp2));
+
+                        __m512i iacc_mat_00_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp2, lhs_mat_01_50_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp2, lhs_mat_01_51_sp2));
+                        __m512i iacc_mat_01_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp2, lhs_mat_01_50_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp2, lhs_mat_01_51_sp2));
+
+                        __m512i iacc_mat_10_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp2, lhs_mat_23_50_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp2, lhs_mat_23_51_sp2));
+                        __m512i iacc_mat_11_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp2, lhs_mat_23_50_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp2, lhs_mat_23_51_sp2));
+
+                        __m512i iacc_mat_00_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp2, lhs_mat_01_60_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp2, lhs_mat_01_61_sp2));
+                        __m512i iacc_mat_01_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp2, lhs_mat_01_60_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp2, lhs_mat_01_61_sp2));
+
+                        __m512i iacc_mat_10_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp2, lhs_mat_23_60_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp2, lhs_mat_23_61_sp2));
+                        __m512i iacc_mat_11_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp2, lhs_mat_23_60_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp2, lhs_mat_23_61_sp2));
+
+                        __m512i iacc_mat_00_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp2, lhs_mat_01_70_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp2, lhs_mat_01_71_sp2));
+                        __m512i iacc_mat_01_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp2, lhs_mat_01_70_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp2, lhs_mat_01_71_sp2));
+
+                        __m512i iacc_mat_10_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp2, lhs_mat_23_70_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp2, lhs_mat_23_71_sp2));
+                        __m512i iacc_mat_11_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp2, lhs_mat_23_70_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp2, lhs_mat_23_71_sp2));
+
+                        // Combine results from both shuffle patterns for each output block
                         __m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
                         __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
                         __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
@@ -2122,6 +4081,37 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                         __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
                         __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
 
+                        __m512i iacc_mat_00_2 = _mm512_add_epi16(iacc_mat_00_2_sp1, iacc_mat_00_2_sp2);
+                        __m512i iacc_mat_01_2 = _mm512_add_epi16(iacc_mat_01_2_sp1, iacc_mat_01_2_sp2);
+                        __m512i iacc_mat_10_2 = _mm512_add_epi16(iacc_mat_10_2_sp1, iacc_mat_10_2_sp2);
+                        __m512i iacc_mat_11_2 = _mm512_add_epi16(iacc_mat_11_2_sp1, iacc_mat_11_2_sp2);
+
+                        __m512i iacc_mat_00_3 = _mm512_add_epi16(iacc_mat_00_3_sp1, iacc_mat_00_3_sp2);
+                        __m512i iacc_mat_01_3 = _mm512_add_epi16(iacc_mat_01_3_sp1, iacc_mat_01_3_sp2);
+                        __m512i iacc_mat_10_3 = _mm512_add_epi16(iacc_mat_10_3_sp1, iacc_mat_10_3_sp2);
+                        __m512i iacc_mat_11_3 = _mm512_add_epi16(iacc_mat_11_3_sp1, iacc_mat_11_3_sp2);
+
+                        __m512i iacc_mat_00_4 = _mm512_add_epi16(iacc_mat_00_4_sp1, iacc_mat_00_4_sp2);
+                        __m512i iacc_mat_01_4 = _mm512_add_epi16(iacc_mat_01_4_sp1, iacc_mat_01_4_sp2);
+                        __m512i iacc_mat_10_4 = _mm512_add_epi16(iacc_mat_10_4_sp1, iacc_mat_10_4_sp2);
+                        __m512i iacc_mat_11_4 = _mm512_add_epi16(iacc_mat_11_4_sp1, iacc_mat_11_4_sp2);
+
+                        __m512i iacc_mat_00_5 = _mm512_add_epi16(iacc_mat_00_5_sp1, iacc_mat_00_5_sp2);
+                        __m512i iacc_mat_01_5 = _mm512_add_epi16(iacc_mat_01_5_sp1, iacc_mat_01_5_sp2);
+                        __m512i iacc_mat_10_5 = _mm512_add_epi16(iacc_mat_10_5_sp1, iacc_mat_10_5_sp2);
+                        __m512i iacc_mat_11_5 = _mm512_add_epi16(iacc_mat_11_5_sp1, iacc_mat_11_5_sp2);
+
+                        __m512i iacc_mat_00_6 = _mm512_add_epi16(iacc_mat_00_6_sp1, iacc_mat_00_6_sp2);
+                        __m512i iacc_mat_01_6 = _mm512_add_epi16(iacc_mat_01_6_sp1, iacc_mat_01_6_sp2);
+                        __m512i iacc_mat_10_6 = _mm512_add_epi16(iacc_mat_10_6_sp1, iacc_mat_10_6_sp2);
+                        __m512i iacc_mat_11_6 = _mm512_add_epi16(iacc_mat_11_6_sp1, iacc_mat_11_6_sp2);
+
+                        __m512i iacc_mat_00_7 = _mm512_add_epi16(iacc_mat_00_7_sp1, iacc_mat_00_7_sp2);
+                        __m512i iacc_mat_01_7 = _mm512_add_epi16(iacc_mat_01_7_sp1, iacc_mat_01_7_sp2);
+                        __m512i iacc_mat_10_7 = _mm512_add_epi16(iacc_mat_10_7_sp1, iacc_mat_10_7_sp2);
+                        __m512i iacc_mat_11_7 = _mm512_add_epi16(iacc_mat_11_7_sp1, iacc_mat_11_7_sp2);
+
+                        // Multiply the 16-bit sub-block dot-product sums with the corresponding scale values, accumulating adjacent pairs into 32-bit integers
                         iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0);
                         iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0);
                         iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0);
@@ -2132,20 +4122,46 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                         iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1);
                         iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1);
 
-                        // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
-                        __m512i iacc_row_0_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_0, _mm512_shuffle_epi32(iacc_mat_01_0, (_MM_PERM_ENUM)78));
-                        __m512i iacc_row_1_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_0, (_MM_PERM_ENUM)78), iacc_mat_01_0);
-                        __m512i iacc_row_2_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_0, _mm512_shuffle_epi32(iacc_mat_11_0, (_MM_PERM_ENUM)78));
-                        __m512i iacc_row_3_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_0, (_MM_PERM_ENUM)78), iacc_mat_11_0);
-                        __m512i iacc_row_0_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_1, _mm512_shuffle_epi32(iacc_mat_01_1, (_MM_PERM_ENUM)78));
-                        __m512i iacc_row_1_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_1, (_MM_PERM_ENUM)78), iacc_mat_01_1);
-                        __m512i iacc_row_2_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_1, _mm512_shuffle_epi32(iacc_mat_11_1, (_MM_PERM_ENUM)78));
-                        __m512i iacc_row_3_1 = _mm512_mask_blend_epi32(0xCCCC,_mm512_shuffle_epi32(iacc_mat_10_1, (_MM_PERM_ENUM)78), iacc_mat_11_1);
+                        iacc_mat_00_2 = _mm512_madd_epi16(iacc_mat_00_2, scale_014589CD_2);
+                        iacc_mat_01_2 = _mm512_madd_epi16(iacc_mat_01_2, scale_2367ABEF_2);
+                        iacc_mat_10_2 = _mm512_madd_epi16(iacc_mat_10_2, scale_014589CD_2);
+                        iacc_mat_11_2 = _mm512_madd_epi16(iacc_mat_11_2, scale_2367ABEF_2);
+
+                        iacc_mat_00_3 = _mm512_madd_epi16(iacc_mat_00_3, scale_014589CD_3);
+                        iacc_mat_01_3 = _mm512_madd_epi16(iacc_mat_01_3, scale_2367ABEF_3);
+                        iacc_mat_10_3 = _mm512_madd_epi16(iacc_mat_10_3, scale_014589CD_3);
+                        iacc_mat_11_3 = _mm512_madd_epi16(iacc_mat_11_3, scale_2367ABEF_3);
+
+                        iacc_mat_00_4 = _mm512_madd_epi16(iacc_mat_00_4, scale_014589CD_4);
+                        iacc_mat_01_4 = _mm512_madd_epi16(iacc_mat_01_4, scale_2367ABEF_4);
+                        iacc_mat_10_4 = _mm512_madd_epi16(iacc_mat_10_4, scale_014589CD_4);
+                        iacc_mat_11_4 = _mm512_madd_epi16(iacc_mat_11_4, scale_2367ABEF_4);
+
+                        iacc_mat_00_5 = _mm512_madd_epi16(iacc_mat_00_5, scale_014589CD_5);
+                        iacc_mat_01_5 = _mm512_madd_epi16(iacc_mat_01_5, scale_2367ABEF_5);
+                        iacc_mat_10_5 = _mm512_madd_epi16(iacc_mat_10_5, scale_014589CD_5);
+                        iacc_mat_11_5 = _mm512_madd_epi16(iacc_mat_11_5, scale_2367ABEF_5);
+
+                        iacc_mat_00_6 = _mm512_madd_epi16(iacc_mat_00_6, scale_014589CD_6);
+                        iacc_mat_01_6 = _mm512_madd_epi16(iacc_mat_01_6, scale_2367ABEF_6);
+                        iacc_mat_10_6 = _mm512_madd_epi16(iacc_mat_10_6, scale_014589CD_6);
+                        iacc_mat_11_6 = _mm512_madd_epi16(iacc_mat_11_6, scale_2367ABEF_6);
+
+                        iacc_mat_00_7 = _mm512_madd_epi16(iacc_mat_00_7, scale_014589CD_7);
+                        iacc_mat_01_7 = _mm512_madd_epi16(iacc_mat_01_7, scale_2367ABEF_7);
+                        iacc_mat_10_7 = _mm512_madd_epi16(iacc_mat_10_7, scale_014589CD_7);
+                        iacc_mat_11_7 = _mm512_madd_epi16(iacc_mat_11_7, scale_2367ABEF_7);
+
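+                        // Accumulate the scaled sums of all eight sub-blocks into one 32-bit accumulator per output tile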
+                        __m512i iacc_mat_00 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_00_0, iacc_mat_00_1), _mm512_add_epi32(iacc_mat_00_2, iacc_mat_00_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_00_4, iacc_mat_00_5), _mm512_add_epi32(iacc_mat_00_6, iacc_mat_00_7)));
+                        __m512i iacc_mat_01 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_01_0, iacc_mat_01_1), _mm512_add_epi32(iacc_mat_01_2, iacc_mat_01_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_01_4, iacc_mat_01_5), _mm512_add_epi32(iacc_mat_01_6, iacc_mat_01_7)));
+                        __m512i iacc_mat_10 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_10_0, iacc_mat_10_1), _mm512_add_epi32(iacc_mat_10_2, iacc_mat_10_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_10_4, iacc_mat_10_5), _mm512_add_epi32(iacc_mat_10_6, iacc_mat_10_7)));
+                        __m512i iacc_mat_11 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_11_0, iacc_mat_11_1), _mm512_add_epi32(iacc_mat_11_2, iacc_mat_11_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_11_4, iacc_mat_11_5), _mm512_add_epi32(iacc_mat_11_6, iacc_mat_11_7)));
 
-                        __m512i iacc_row_0 = _mm512_add_epi32(iacc_row_0_0, iacc_row_0_1);
-                        __m512i iacc_row_1 = _mm512_add_epi32(iacc_row_1_0, iacc_row_1_1);
-                        __m512i iacc_row_2 = _mm512_add_epi32(iacc_row_2_0, iacc_row_2_1);
-                        __m512i iacc_row_3 = _mm512_add_epi32(iacc_row_3_0, iacc_row_3_1);
+                        // Straighten out to make 4 row vectors
+                        __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
+                        __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
+                        __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
+                        __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
 
                         // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes
                         const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);
@@ -2158,10 +4174,31 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                         acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
                         acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
 
-                        __m512i iacc_row_min_0 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)0), mins_01);
-                        __m512i iacc_row_min_1 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)85), mins_01);
-                        __m512i iacc_row_min_2 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)170), mins_01);
-                        __m512i iacc_row_min_3 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)255), mins_01);
+                        // Take two bsums from two Q8_Ks at a time and multiply them with the corresponding mins values from each Q2_K
+                        __m512i iacc_row_min_0_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)0), mins_01);
+                        __m512i iacc_row_min_1_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)170), mins_01);
+                        __m512i iacc_row_min_2_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)0), mins_01);
+                        __m512i iacc_row_min_3_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)170), mins_01);
+
+                        __m512i iacc_row_min_0_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)85), mins_23);
+                        __m512i iacc_row_min_1_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)255), mins_23);
+                        __m512i iacc_row_min_2_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)85), mins_23);
+                        __m512i iacc_row_min_3_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)255), mins_23);
+
+                        __m512i iacc_row_min_0_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)0), mins_45);
+                        __m512i iacc_row_min_1_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)170), mins_45);
+                        __m512i iacc_row_min_2_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)0), mins_45);
+                        __m512i iacc_row_min_3_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)170), mins_45);
+
+                        __m512i iacc_row_min_0_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)85), mins_67);
+                        __m512i iacc_row_min_1_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)255), mins_67);
+                        __m512i iacc_row_min_2_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)85), mins_67);
+                        __m512i iacc_row_min_3_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)255), mins_67);
+
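+                        // Accumulate the min contributions of all eight sub blocks (processed in pairs) for each of the four rows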
+                        __m512i iacc_row_min_0 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_0_01, iacc_row_min_0_23), _mm512_add_epi32(iacc_row_min_0_45, iacc_row_min_0_67));
+                        __m512i iacc_row_min_1 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_1_01, iacc_row_min_1_23), _mm512_add_epi32(iacc_row_min_1_45, iacc_row_min_1_67));
+                        __m512i iacc_row_min_2 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_2_01, iacc_row_min_2_23), _mm512_add_epi32(iacc_row_min_2_45, iacc_row_min_2_67));
+                        __m512i iacc_row_min_3 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_3_01, iacc_row_min_3_23), _mm512_add_epi32(iacc_row_min_3_45, iacc_row_min_3_67));
 
                         acc_min_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]);
                         acc_min_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);
@@ -2177,15 +4214,15 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
         }
     }
 
     for (; y < nr / 4; y++) {
 
         const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);
 
-        // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation
+        // Take two block_q2_kx8 structures (sixteen interleaved columns) at each pass of the loop and perform dot product operation
         for (int64_t x = 0; x < anc / 8; x += 2) {
 
-            const block_q4_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
-            const block_q4_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
+            const block_q2_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb);
+            const block_q2_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
 
             // Master FP accumulators
             __m512 acc_rows[4];
@@ -2197,18 +4234,18 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
             for (int i = 0; i < 4; i++) {
                 acc_min_rows[i] = _mm512_setzero_ps();
             }
-
             // For super block
             for (int64_t b = 0; b < nb; b++) {
-                // Scale values - Load the sixteen scale values from two block_q4_kx8 structures
+                // Delta values - Load the sixteen delta values from two block_q2_kx8 structures
                 const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
 
-                // dmin values - Load the sixteen dmin values from two block_q4_kx8 structures
+                // dmin values - Load the sixteen dmin values from two block_q2_kx8 structures
                 const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin);
 
-                // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
-                for (int sb = 0; sb < QK_K / 64; sb++) {
+                // Loop to iterate over the sixteen sub blocks of a super block - eight sub blocks are processed per iteration
+                for (int sb = 0; sb < QK_K / 128; sb++) {
 
+                    // Load the quantized values for eight sub blocks from the eight interleaved block_q2_k structures - interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7
                     const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256));
                     const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256));
                     const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256));
@@ -2255,110 +4292,186 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1);
                     const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1);
 
-                    //4-bit -> 8-bit
-                    const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7)
-                    const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7)
-                    const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15)
-                    const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15)
+                    // 2-bit -> 8-bit: each byte packs four 2-bit values which are extracted with shifts of 0, 2, 4 and 6 bits and masked with m3bexpanded
+                    const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0, m3bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7)
+                    const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m3bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7)
+                    const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1, m3bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15)
+                    const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m3bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15)
+                    const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(rhs_raw_mat_014589CD_2, m3bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7)
+                    const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2, m3bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7)
+                    const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(rhs_raw_mat_014589CD_3, m3bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15)
+                    const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3, m3bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15)
 
-                    const __m512i rhs_mat_014589CD_02 = _mm512_and_si512(rhs_raw_mat_014589CD_2, m4bexpanded); //B00(16-23) B01(16-23) B04(16-23) B05(16-23) B08(16-23) B09(16-23) B0C(16-23) B0D(16-23)
-                    const __m512i rhs_mat_2367ABEF_02 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2, m4bexpanded); //B02(16-23) B03(16-23) B06(16-23) B07(16-23) B0A(16-23) B0B(16-23) B0E(16-23) B0F(16-23)
-                    const __m512i rhs_mat_014589CD_03 = _mm512_and_si512(rhs_raw_mat_014589CD_3, m4bexpanded); //B00(24-31) B01(24-31) B04(24-31) B05(24-31) B08(24-31) B09(24-31) B0C(24-31) B0D(24-31)
-                    const __m512i rhs_mat_2367ABEF_03 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3, m4bexpanded); //B02(24-31) B03(24-31) B06(24-31) B07(24-31) B0A(24-31) B0B(24-31) B0E(24-31) B0F(24-31)
+                    const __m512i rhs_mat_014589CD_20 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 2), m3bexpanded); //B20(0-7) B21(0-7) B24(0-7) B25(0-7) B28(0-7) B29(0-7) B2C(0-7) B2D(0-7)
+                    const __m512i rhs_mat_2367ABEF_20 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 2), m3bexpanded); //B22(0-7) B23(0-7) B26(0-7) B27(0-7) B2A(0-7) B2B(0-7) B2E(0-7) B2F(0-7)
 
-                    const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7)
-                    const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7)
-                    const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15)
-                    const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15)
+                    const __m512i rhs_mat_014589CD_21 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 2), m3bexpanded); //B20(8-15) B21(8-15) B24(8-15) B25(8-15) B28(8-15) B29(8-15) B2C(8-15) B2D(8-15)
+                    const __m512i rhs_mat_2367ABEF_21 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 2), m3bexpanded); //B22(8-15) B23(8-15) B26(8-15) B27(8-15) B2A(8-15) B2B(8-15) B2E(8-15) B2F(8-15)
 
-                    const __m512i rhs_mat_014589CD_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m4bexpanded); //B10(16-23) B11(16-23) B14(16-23) B15(16-23) B18(16-23) B19(16-23) B1C(16-23) B1D(16-23)
-                    const __m512i rhs_mat_2367ABEF_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m4bexpanded); //B12(16-23) B13(16-23) B16(16-23) B17(16-23) B1A(16-23) B1B(16-23) B1E(16-23) B1F(16-23)
-                    const __m512i rhs_mat_014589CD_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m4bexpanded); //B10(24-31) B11(24-31) B14(24-31) B15(24-31) B18(24-31) B19(24-31) B1C(24-31) B1D(24-31)
-                    const __m512i rhs_mat_2367ABEF_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m4bexpanded); //B12(24-31) B13(24-31) B16(24-31) B17(24-31) B1A(24-31) B1B(24-31) B1E(24-31) B1F(24-31)
+                    const __m512i rhs_mat_014589CD_30 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 2), m3bexpanded); //B30(0-7) B31(0-7) B34(0-7) B35(0-7) B38(0-7) B39(0-7) B3C(0-7) B3D(0-7)
+                    const __m512i rhs_mat_2367ABEF_30 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 2), m3bexpanded); //B32(0-7) B33(0-7) B36(0-7) B37(0-7) B3A(0-7) B3B(0-7) B3E(0-7) B3F(0-7)
+
+                    const __m512i rhs_mat_014589CD_31 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 2), m3bexpanded); //B30(8-15) B31(8-15) B34(8-15) B35(8-15) B38(8-15) B39(8-15) B3C(8-15) B3D(8-15)
+                    const __m512i rhs_mat_2367ABEF_31 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 2), m3bexpanded); //B32(8-15) B33(8-15) B36(8-15) B37(8-15) B3A(8-15) B3B(8-15) B3E(8-15) B3F(8-15)
+
+                    const __m512i rhs_mat_014589CD_40 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m3bexpanded); //B40(0-7) B41(0-7) B44(0-7) B45(0-7) B48(0-7) B49(0-7) B4C(0-7) B4D(0-7)
+                    const __m512i rhs_mat_2367ABEF_40 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m3bexpanded); //B42(0-7) B43(0-7) B46(0-7) B47(0-7) B4A(0-7) B4B(0-7) B4E(0-7) B4F(0-7)
+
+                    const __m512i rhs_mat_014589CD_41 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m3bexpanded); //B40(8-15) B41(8-15) B44(8-15) B45(8-15) B48(8-15) B49(8-15) B4C(8-15) B4D(8-15)
+                    const __m512i rhs_mat_2367ABEF_41 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m3bexpanded); //B42(8-15) B43(8-15) B46(8-15) B47(8-15) B4A(8-15) B4B(8-15) B4E(8-15) B4F(8-15)
+
+                    const __m512i rhs_mat_014589CD_50 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m3bexpanded); //B50(0-7) B51(0-7) B54(0-7) B55(0-7) B58(0-7) B59(0-7) B5C(0-7) B5D(0-7)
+                    const __m512i rhs_mat_2367ABEF_50 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m3bexpanded); //B52(0-7) B53(0-7) B56(0-7) B57(0-7) B5A(0-7) B5B(0-7) B5E(0-7) B5F(0-7)
+
+                    const __m512i rhs_mat_014589CD_51 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m3bexpanded); //B50(8-15) B51(8-15) B54(8-15) B55(8-15) B58(8-15) B59(8-15) B5C(8-15) B5D(8-15)
+                    const __m512i rhs_mat_2367ABEF_51 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m3bexpanded); //B52(8-15) B53(8-15) B56(8-15) B57(8-15) B5A(8-15) B5B(8-15) B5E(8-15) B5F(8-15)
+
+                    const __m512i rhs_mat_014589CD_60 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 6), m3bexpanded); //B60(0-7) B61(0-7) B64(0-7) B65(0-7) B68(0-7) B69(0-7) B6C(0-7) B6D(0-7)
+                    const __m512i rhs_mat_2367ABEF_60 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 6), m3bexpanded); //B62(0-7) B63(0-7) B66(0-7) B67(0-7) B6A(0-7) B6B(0-7) B6E(0-7) B6F(0-7)
+
+                    const __m512i rhs_mat_014589CD_61 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 6), m3bexpanded); //B60(8-15) B61(8-15) B64(8-15) B65(8-15) B68(8-15) B69(8-15) B6C(8-15) B6D(8-15)
+                    const __m512i rhs_mat_2367ABEF_61 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 6), m3bexpanded); //B62(8-15) B63(8-15) B66(8-15) B67(8-15) B6A(8-15) B6B(8-15) B6E(8-15) B6F(8-15)
+
+                    const __m512i rhs_mat_014589CD_70 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 6), m3bexpanded); //B70(0-7) B71(0-7) B74(0-7) B75(0-7) B78(0-7) B79(0-7) B7C(0-7) B7D(0-7)
+                    const __m512i rhs_mat_2367ABEF_70 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 6), m3bexpanded); //B72(0-7) B73(0-7) B76(0-7) B77(0-7) B7A(0-7) B7B(0-7) B7E(0-7) B7F(0-7)
+
+                    const __m512i rhs_mat_014589CD_71 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 6), m3bexpanded); //B70(8-15) B71(8-15) B74(8-15) B75(8-15) B78(8-15) B79(8-15) B7C(8-15) B7D(8-15)
+                    const __m512i rhs_mat_2367ABEF_71 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 6), m3bexpanded); //B72(8-15) B73(8-15) B76(8-15) B77(8-15) B7A(8-15) B7B(8-15) B7E(8-15) B7F(8-15)
 
                     // Shuffle pattern one - right side input
                     const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3)
                     const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3)
+
                     const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11)
                     const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11)
-                    const __m512i rhs_mat_014589CD_02_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19) B08(16-19) B09(16-19) B08(16-19) B09(16-19) B0C(16-19) B0D(16-19) B0C(16-19) B0D(16-19)
-                    const __m512i rhs_mat_2367ABEF_02_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19) B0A(16-19) B0B(16-19) B0A(16-19) B0B(16-19) B0E(16-19) B0F(16-19) B0E(16-19) B0F(16-19)
-                    const __m512i rhs_mat_014589CD_03_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27) B08(24-27) B09(24-27) B08(24-27) B09(24-27) B0C(24-27) B0D(24-27) B0C(24-27) B0D(24-27)
-                    const __m512i rhs_mat_2367ABEF_03_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27) B0A(24-27) B0B(24-27) B0A(24-27) B0B(24-27) B0E(24-27) B0F(24-27) B0E(24-27) B0F(24-27)
 
                     const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3)
                     const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3)
+
                     const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11)
                     const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11)
-                    const __m512i rhs_mat_014589CD_12_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19) B18(16-19) B19(16-19) B18(16-19) B19(16-19) B1C(16-19) B1D(16-19) B1C(16-19) B1D(16-19)
-                    const __m512i rhs_mat_2367ABEF_12_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19) B1A(16-19) B1B(16-19) B1A(16-19) B1B(16-19) B1E(16-19) B1F(16-19) B1E(16-19) B1F(16-19)
-                    const __m512i rhs_mat_014589CD_13_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27) B18(24-27) B19(24-27) B18(24-27) B19(24-27) B1C(24-27) B1D(24-27) B1C(24-27) B1D(24-27)
-                    const __m512i rhs_mat_2367ABEF_13_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27) B1A(24-27) B1B(24-27) B1A(24-27) B1B(24-27) B1E(24-27) B1F(24-27) B1E(24-27) B1F(24-27)
 
-                    // Shuffle pattern two - right side input
+                    const __m512i rhs_mat_014589CD_20_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_20, (_MM_PERM_ENUM)136); //B20(0-3) B21(0-3) B20(0-3) B21(0-3) B24(0-3) B25(0-3) B24(0-3) B25(0-3) B28(0-3) B29(0-3) B28(0-3) B29(0-3) B2C(0-3) B2D(0-3) B2C(0-3) B2D(0-3)
+                    const __m512i rhs_mat_2367ABEF_20_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_20, (_MM_PERM_ENUM)136); //B22(0-3) B23(0-3) B22(0-3) B23(0-3) B26(0-3) B27(0-3) B26(0-3) B27(0-3) B2A(0-3) B2B(0-3) B2A(0-3) B2B(0-3) B2E(0-3) B2F(0-3) B2E(0-3) B2F(0-3)
+
+                    const __m512i rhs_mat_014589CD_21_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_21, (_MM_PERM_ENUM)136); //B20(8-11) B21(8-11) B20(8-11) B21(8-11) B24(8-11) B25(8-11) B24(8-11) B25(8-11) B28(8-11) B29(8-11) B28(8-11) B29(8-11) B2C(8-11) B2D(8-11) B2C(8-11) B2D(8-11)
+                    const __m512i rhs_mat_2367ABEF_21_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_21, (_MM_PERM_ENUM)136); //B22(8-11) B23(8-11) B22(8-11) B23(8-11) B26(8-11) B27(8-11) B26(8-11) B27(8-11) B2A(8-11) B2B(8-11) B2A(8-11) B2B(8-11) B2E(8-11) B2F(8-11) B2E(8-11) B2F(8-11)
+                    const __m512i rhs_mat_014589CD_30_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_30, (_MM_PERM_ENUM)136); //B30(0-3) B31(0-3) B30(0-3) B31(0-3) B34(0-3) B35(0-3) B34(0-3) B35(0-3) B38(0-3) B39(0-3) B38(0-3) B39(0-3) B3C(0-3) B3D(0-3) B3C(0-3) B3D(0-3)
+                    const __m512i rhs_mat_2367ABEF_30_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_30, (_MM_PERM_ENUM)136); //B32(0-3) B33(0-3) B32(0-3) B33(0-3) B36(0-3) B37(0-3) B36(0-3) B37(0-3) B3A(0-3) B3B(0-3) B3A(0-3) B3B(0-3) B3E(0-3) B3F(0-3) B3E(0-3) B3F(0-3)
+
+                    const __m512i rhs_mat_014589CD_31_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_31, (_MM_PERM_ENUM)136); //B30(8-11) B31(8-11) B30(8-11) B31(8-11) B34(8-11) B35(8-11) B34(8-11) B35(8-11) B38(8-11) B39(8-11) B38(8-11) B39(8-11) B3C(8-11) B3D(8-11) B3C(8-11) B3D(8-11)
+                    const __m512i rhs_mat_2367ABEF_31_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_31, (_MM_PERM_ENUM)136); //B32(8-11) B33(8-11) B32(8-11) B33(8-11) B36(8-11) B37(8-11) B36(8-11) B37(8-11) B3A(8-11) B3B(8-11) B3A(8-11) B3B(8-11) B3E(8-11) B3F(8-11) B3E(8-11) B3F(8-11)
+
+                    const __m512i rhs_mat_014589CD_40_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_40, (_MM_PERM_ENUM)136); //B40(0-3) B41(0-3) B40(0-3) B41(0-3) B44(0-3) B45(0-3) B44(0-3) B45(0-3) B48(0-3) B49(0-3) B48(0-3) B49(0-3) B4C(0-3) B4D(0-3) B4C(0-3) B4D(0-3)
+                    const __m512i rhs_mat_2367ABEF_40_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_40, (_MM_PERM_ENUM)136); //B42(0-3) B43(0-3) B42(0-3) B43(0-3) B46(0-3) B47(0-3) B46(0-3) B47(0-3) B4A(0-3) B4B(0-3) B4A(0-3) B4B(0-3) B4E(0-3) B4F(0-3) B4E(0-3) B4F(0-3)
+
+                    const __m512i rhs_mat_014589CD_41_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_41, (_MM_PERM_ENUM)136); //B40(8-11) B41(8-11) B40(8-11) B41(8-11) B44(8-11) B45(8-11) B44(8-11) B45(8-11) B48(8-11) B49(8-11) B48(8-11) B49(8-11) B4C(8-11) B4D(8-11) B4C(8-11) B4D(8-11)
+                    const __m512i rhs_mat_2367ABEF_41_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_41, (_MM_PERM_ENUM)136); //B42(8-11) B43(8-11) B42(8-11) B43(8-11) B46(8-11) B47(8-11) B46(8-11) B47(8-11) B4A(8-11) B4B(8-11) B4A(8-11) B4B(8-11) B4E(8-11) B4F(8-11) B4E(8-11) B4F(8-11)
+
+                    const __m512i rhs_mat_014589CD_50_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_50, (_MM_PERM_ENUM)136); //B50(0-3) B51(0-3) B50(0-3) B51(0-3) B54(0-3) B55(0-3) B54(0-3) B55(0-3) B58(0-3) B59(0-3) B58(0-3) B59(0-3) B5C(0-3) B5D(0-3) B5C(0-3) B5D(0-3)
+                    const __m512i rhs_mat_2367ABEF_50_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_50, (_MM_PERM_ENUM)136); //B52(0-3) B53(0-3) B52(0-3) B53(0-3) B56(0-3) B57(0-3) B56(0-3) B57(0-3) B5A(0-3) B5B(0-3) B5A(0-3) B5B(0-3) B5E(0-3) B5F(0-3) B5E(0-3) B5F(0-3)
+
+                    const __m512i rhs_mat_014589CD_51_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_51, (_MM_PERM_ENUM)136); //B50(8-11) B51(8-11) B50(8-11) B51(8-11) B54(8-11) B55(8-11) B54(8-11) B55(8-11) B58(8-11) B59(8-11) B58(8-11) B59(8-11) B5C(8-11) B5D(8-11) B5C(8-11) B5D(8-11)
+                    const __m512i rhs_mat_2367ABEF_51_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_51, (_MM_PERM_ENUM)136); //B52(8-11) B53(8-11) B52(8-11) B53(8-11) B56(8-11) B57(8-11) B56(8-11) B57(8-11) B5A(8-11) B5B(8-11) B5A(8-11) B5B(8-11) B5E(8-11) B5F(8-11) B5E(8-11) B5F(8-11)
+
+                    const __m512i rhs_mat_014589CD_60_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_60, (_MM_PERM_ENUM)136); //B60(0-3) B61(0-3) B60(0-3) B61(0-3) B64(0-3) B65(0-3) B64(0-3) B65(0-3) B68(0-3) B69(0-3) B68(0-3) B69(0-3) B6C(0-3) B6D(0-3) B6C(0-3) B6D(0-3)
+                    const __m512i rhs_mat_2367ABEF_60_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_60, (_MM_PERM_ENUM)136); //B62(0-3) B63(0-3) B62(0-3) B63(0-3) B66(0-3) B67(0-3) B66(0-3) B67(0-3) B6A(0-3) B6B(0-3) B6A(0-3) B6B(0-3) B6E(0-3) B6F(0-3) B6E(0-3) B6F(0-3)
+
+                    const __m512i rhs_mat_014589CD_61_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_61, (_MM_PERM_ENUM)136); //B60(8-11) B61(8-11) B60(8-11) B61(8-11) B64(8-11) B65(8-11) B64(8-11) B65(8-11) B68(8-11) B69(8-11) B68(8-11) B69(8-11) B6C(8-11) B6D(8-11) B6C(8-11) B6D(8-11)
+                    const __m512i rhs_mat_2367ABEF_61_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_61, (_MM_PERM_ENUM)136); //B62(8-11) B63(8-11) B62(8-11) B63(8-11) B66(8-11) B67(8-11) B66(8-11) B67(8-11) B6A(8-11) B6B(8-11) B6A(8-11) B6B(8-11) B6E(8-11) B6F(8-11) B6E(8-11) B6F(8-11)
+
+                    const __m512i rhs_mat_014589CD_70_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_70, (_MM_PERM_ENUM)136); //B70(0-3) B71(0-3) B70(0-3) B71(0-3) B74(0-3) B75(0-3) B74(0-3) B75(0-3) B78(0-3) B79(0-3) B78(0-3) B79(0-3) B7C(0-3) B7D(0-3) B7C(0-3) B7D(0-3)
+                    const __m512i rhs_mat_2367ABEF_70_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_70, (_MM_PERM_ENUM)136); //B72(0-3) B73(0-3) B72(0-3) B73(0-3) B76(0-3) B77(0-3) B76(0-3) B77(0-3) B7A(0-3) B7B(0-3) B7A(0-3) B7B(0-3) B7E(0-3) B7F(0-3) B7E(0-3) B7F(0-3)
+
+                    const __m512i rhs_mat_014589CD_71_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_71, (_MM_PERM_ENUM)136); //B70(8-11) B71(8-11) B70(8-11) B71(8-11) B74(8-11) B75(8-11) B74(8-11) B75(8-11) B78(8-11) B79(8-11) B78(8-11) B79(8-11) B7C(8-11) B7D(8-11) B7C(8-11) B7D(8-11)
+                    const __m512i rhs_mat_2367ABEF_71_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_71, (_MM_PERM_ENUM)136); //B72(8-11) B73(8-11) B72(8-11) B73(8-11) B76(8-11) B77(8-11) B76(8-11) B77(8-11) B7A(8-11) B7B(8-11) B7A(8-11) B7B(8-11) B7E(8-11) B7F(8-11) B7E(8-11) B7F(8-11)
+
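+                    // Shuffle pattern two - right side input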
                     const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7)
                     const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7)
+
                     const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15)
                     const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15)
-                    const __m512i rhs_mat_014589CD_02_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23) B08(20-23) B09(20-23) B08(20-23) B09(20-23) B0C(20-23) B0D(20-23) B0C(20-23) B0D(20-23)
-                    const __m512i rhs_mat_2367ABEF_02_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23) B0A(20-23) B0B(20-23) B0A(20-23) B0B(20-23) B0E(20-23) B0F(20-23) B0E(20-23) B0F(20-23)
-                    const __m512i rhs_mat_014589CD_03_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31) B08(28-31) B09(28-31) B08(28-31) B09(28-31) B0C(28-31) B0D(28-31) B0C(28-31) 0BD(28-31)
-                    const __m512i rhs_mat_2367ABEF_03_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31) B0A(28-31) B0B(28-31) B0A(28-31) B0B(28-31) B0E(28-31) B0F(28-31) B0E(28-31) B0F(28-31)
 
                     const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7)
                     const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7)
+
                     const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15)
                     const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15)
-                    const __m512i rhs_mat_014589CD_12_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23) B18(20-23) B19(20-23) B18(20-23) B19(20-23) B1C(20-23) B1D(20-23) B1C(20-23) B1D(20-23)
-                    const __m512i rhs_mat_2367ABEF_12_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23) B1A(20-23) B1B(20-23) B1A(20-23) B1B(20-23) B1E(20-23) B1F(20-23) B1E(20-23) B1F(20-23)
-                    const __m512i rhs_mat_014589CD_13_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31) B18(28-31) B19(28-31) B18(28-31) B19(28-31) B1C(28-31) B1D(28-31) B1C(28-31) B1D(28-31)
-                    const __m512i rhs_mat_2367ABEF_13_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31) B1A(28-31) B1B(28-31) B1A(28-31) B1B(28-31) B1E(28-31) B1F(28-31) B1E(28-31) B1F(28-31)
 
-                    uint32_t utmp_00[4], utmp_01[4], utmp_10[4], utmp_11[4];
+                    const __m512i rhs_mat_014589CD_20_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_20, (_MM_PERM_ENUM)221); //B20(4-7) B21(4-7) B20(4-7) B21(4-7) B24(4-7) B25(4-7) B24(4-7) B25(4-7) B28(4-7) B29(4-7) B28(4-7) B29(4-7) B2C(4-7) B2D(4-7) B2C(4-7) B2D(4-7)
+                    const __m512i rhs_mat_2367ABEF_20_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_20, (_MM_PERM_ENUM)221); //B22(4-7) B23(4-7) B22(4-7) B23(4-7) B26(4-7) B27(4-7) B26(4-7) B27(4-7) B2A(4-7) B2B(4-7) B2A(4-7) B2B(4-7) B2E(4-7) B2F(4-7) B2E(4-7) B2F(4-7)
 
-                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
-                    // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
-                    memcpy(utmp_00, b_ptr_0[b].scales + 24 * sb, 12);
-                    utmp_00[3] = ((utmp_00[2] >> 4) & kmask2) | (((utmp_00[1] >> 6) & kmask3) << 4);
-                    const uint32_t uaux_00 = utmp_00[1] & kmask1;
-                    utmp_00[1] = (utmp_00[2] & kmask2) | (((utmp_00[0] >> 6) & kmask3) << 4);
-                    utmp_00[2] = uaux_00;
-                    utmp_00[0] &= kmask1;
+                    const __m512i rhs_mat_014589CD_21_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_21, (_MM_PERM_ENUM)221); //B20(12-15) B21(12-15) B20(12-15) B21(12-15) B24(12-15) B25(12-15) B24(12-15) B25(12-15) B28(12-15) B29(12-15) B28(12-15) B29(12-15) B2C(12-15) B2D(12-15) B2C(12-15) B2D(12-15)
+                    const __m512i rhs_mat_2367ABEF_21_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_21, (_MM_PERM_ENUM)221); //B22(12-15) B23(12-15) B22(12-15) B23(12-15) B26(12-15) B27(12-15) B26(12-15) B27(12-15) B2A(12-15) B2B(12-15) B2A(12-15) B2B(12-15) B2E(12-15) B2F(12-15) B2E(12-15) B2F(12-15)
 
-                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
-                    memcpy(utmp_01, b_ptr_0[b].scales + 12 + sb * 24, 12);
-                    utmp_01[3] = ((utmp_01[2] >> 4) & kmask2) | (((utmp_01[1] >> 6) & kmask3) << 4);
-                    const uint32_t uaux_01 = utmp_01[1] & kmask1;
-                    utmp_01[1] = (utmp_01[2] & kmask2) | (((utmp_01[0] >> 6) & kmask3) << 4);
-                    utmp_01[2] = uaux_01;
-                    utmp_01[0] &= kmask1;
+                    const __m512i rhs_mat_014589CD_30_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_30, (_MM_PERM_ENUM)221); //B30(4-7) B31(4-7) B30(4-7) B31(4-7) B34(4-7) B35(4-7) B34(4-7) B35(4-7) B38(4-7) B39(4-7) B38(4-7) B39(4-7) B3C(4-7) B3D(4-7) B3C(4-7) B3D(4-7)
+                    const __m512i rhs_mat_2367ABEF_30_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_30, (_MM_PERM_ENUM)221); //B32(4-7) B33(4-7) B32(4-7) B33(4-7) B36(4-7) B37(4-7) B36(4-7) B37(4-7) B3A(4-7) B3B(4-7) B3A(4-7) B3B(4-7) B3E(4-7) B3F(4-7) B3E(4-7) B3F(4-7)
 
-                    // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
-                    memcpy(utmp_10, b_ptr_1[b].scales + sb * 24, 12);
-                    utmp_10[3] = ((utmp_10[2] >> 4) & kmask2) | (((utmp_10[1] >> 6) & kmask3) << 4);
-                    const uint32_t uaux_10 = utmp_10[1] & kmask1;
-                    utmp_10[1] = (utmp_10[2] & kmask2) | (((utmp_10[0] >> 6) & kmask3) << 4);
-                    utmp_10[2] = uaux_10;
-                    utmp_10[0] &= kmask1;
+                    const __m512i rhs_mat_014589CD_31_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_31, (_MM_PERM_ENUM)221); //B30(12-15) B31(12-15) B30(12-15) B31(12-15) B34(12-15) B35(12-15) B34(12-15) B35(12-15) B38(12-15) B39(12-15) B38(12-15) B39(12-15) B3C(12-15) B3D(12-15) B3C(12-15) B3D(12-15)
+                    const __m512i rhs_mat_2367ABEF_31_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_31, (_MM_PERM_ENUM)221); //B32(12-15) B33(12-15) B32(12-15) B33(12-15) B36(12-15) B37(12-15) B36(12-15) B37(12-15) B3A(12-15) B3B(12-15) B3A(12-15) B3B(12-15) B3E(12-15) B3F(12-15) B3E(12-15) B3F(12-15)
 
-                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
-                    memcpy(utmp_11, b_ptr_1[b].scales + 12 + sb * 24, 12);
-                    utmp_11[3] = ((utmp_11[2] >> 4) & kmask2) | (((utmp_11[1] >> 6) & kmask3) << 4);
-                    const uint32_t uaux_11 = utmp_11[1] & kmask1;
-                    utmp_11[1] = (utmp_11[2] & kmask2) | (((utmp_11[0] >> 6) & kmask3) << 4);
-                    utmp_11[2] = uaux_11;
-                    utmp_11[0] &= kmask1;
+                    const __m512i rhs_mat_014589CD_40_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_40, (_MM_PERM_ENUM)221); //B40(4-7) B41(4-7) B40(4-7) B41(4-7) B44(4-7) B45(4-7) B44(4-7) B45(4-7) B48(4-7) B49(4-7) B48(4-7) B49(4-7) B4C(4-7) B4D(4-7) B4C(4-7) B4D(4-7)
+                    const __m512i rhs_mat_2367ABEF_40_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_40, (_MM_PERM_ENUM)221); //B42(4-7) B43(4-7) B42(4-7) B43(4-7) B46(4-7) B47(4-7) B46(4-7) B47(4-7) B4A(4-7) B4B(4-7) B4A(4-7) B4B(4-7) B4E(4-7) B4F(4-7) B4E(4-7) B4F(4-7)
 
-                    // Scales of first sub block in the sb loop
-                    const __m256i mins_and_scales_0 = _mm256_set_epi32(utmp_10[3], utmp_10[2], utmp_10[1], utmp_10[0], utmp_00[3], utmp_00[2], utmp_00[1], utmp_00[0]);
-                    const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
+                    const __m512i rhs_mat_014589CD_41_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_41, (_MM_PERM_ENUM)221); //B40(12-15) B41(12-15) B40(12-15) B41(12-15) B44(12-15) B45(12-15) B44(12-15) B45(12-15) B48(12-15) B49(12-15) B48(12-15) B49(12-15) B4C(12-15) B4D(12-15) B4C(12-15) B4D(12-15)
+                    const __m512i rhs_mat_2367ABEF_41_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_41, (_MM_PERM_ENUM)221); //B42(12-15) B43(12-15) B42(12-15) B43(12-15) B46(12-15) B47(12-15) B46(12-15) B47(12-15) B4A(12-15) B4B(12-15) B4A(12-15) B4B(12-15) B4E(12-15) B4F(12-15) B4E(12-15) B4F(12-15)
 
-                    // Scales of second sub block in the sb loop
-                    const __m256i mins_and_scales_1 = _mm256_set_epi32(utmp_11[3], utmp_11[2], utmp_11[1], utmp_11[0], utmp_01[3], utmp_01[2], utmp_01[1], utmp_01[0]);
-                    const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
+                    const __m512i rhs_mat_014589CD_50_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_50, (_MM_PERM_ENUM)221); //B50(4-7) B51(4-7) B50(4-7) B51(4-7) B54(4-7) B55(4-7) B54(4-7) B55(4-7) B58(4-7) B59(4-7) B58(4-7) B59(4-7) B5C(4-7) B5D(4-7) B5C(4-7) B5D(4-7)
+                    const __m512i rhs_mat_2367ABEF_50_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_50, (_MM_PERM_ENUM)221); //B52(4-7) B53(4-7) B52(4-7) B53(4-7) B56(4-7) B57(4-7) B56(4-7) B57(4-7) B5A(4-7) B5B(4-7) B5A(4-7) B5B(4-7) B5E(4-7) B5F(4-7) B5E(4-7) B5F(4-7)
 
-                    // Mins of first and second sub block of Q4_K block are arranged side by side
-                    const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(_mm256_shuffle_epi32(mins_and_scales_0, 78), _mm256_shuffle_epi32(mins_and_scales_1, 78)));
+                    const __m512i rhs_mat_014589CD_51_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_51, (_MM_PERM_ENUM)221); //B50(12-15) B51(12-15) B50(12-15) B51(12-15) B54(12-15) B55(12-15) B54(12-15) B55(12-15) B58(12-15) B59(12-15) B58(12-15) B59(12-15) B5C(12-15) B5D(12-15) B5C(12-15) B5D(12-15)
+                    const __m512i rhs_mat_2367ABEF_51_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_51, (_MM_PERM_ENUM)221); //B52(12-15) B53(12-15) B52(12-15) B53(12-15) B56(12-15) B57(12-15) B56(12-15) B57(12-15) B5A(12-15) B5B(12-15) B5A(12-15) B5B(12-15) B5E(12-15) B5F(12-15) B5E(12-15) B5F(12-15)
+
+                    const __m512i rhs_mat_014589CD_60_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_60, (_MM_PERM_ENUM)221); //B60(4-7) B61(4-7) B60(4-7) B61(4-7) B64(4-7) B65(4-7) B64(4-7) B65(4-7) B68(4-7) B69(4-7) B68(4-7) B69(4-7) B6C(4-7) B6D(4-7) B6C(4-7) B6D(4-7)
+                    const __m512i rhs_mat_2367ABEF_60_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_60, (_MM_PERM_ENUM)221); //B62(4-7) B63(4-7) B62(4-7) B63(4-7) B66(4-7) B67(4-7) B66(4-7) B67(4-7) B6A(4-7) B6B(4-7) B6A(4-7) B6B(4-7) B6E(4-7) B6F(4-7) B6E(4-7) B6F(4-7)
+
+                    const __m512i rhs_mat_014589CD_61_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_61, (_MM_PERM_ENUM)221); //B60(12-15) B61(12-15) B60(12-15) B61(12-15) B64(12-15) B65(12-15) B64(12-15) B65(12-15) B68(12-15) B69(12-15) B68(12-15) B69(12-15) B6C(12-15) B6D(12-15) B6C(12-15) B6D(12-15)
+                    const __m512i rhs_mat_2367ABEF_61_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_61, (_MM_PERM_ENUM)221); //B62(12-15) B63(12-15) B62(12-15) B63(12-15) B66(12-15) B67(12-15) B66(12-15) B67(12-15) B6A(12-15) B6B(12-15) B6A(12-15) B6B(12-15) B6E(12-15) B6F(12-15) B6E(12-15) B6F(12-15)
+
+                    const __m512i rhs_mat_014589CD_70_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_70, (_MM_PERM_ENUM)221); //B70(4-7) B71(4-7) B70(4-7) B71(4-7) B74(4-7) B75(4-7) B74(4-7) B75(4-7) B78(4-7) B79(4-7) B78(4-7) B79(4-7) B7C(4-7) B7D(4-7) B7C(4-7) B7D(4-7)
+                    const __m512i rhs_mat_2367ABEF_70_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_70, (_MM_PERM_ENUM)221); //B72(4-7) B73(4-7) B72(4-7) B73(4-7) B76(4-7) B77(4-7) B76(4-7) B77(4-7) B7A(4-7) B7B(4-7) B7A(4-7) B7B(4-7) B7E(4-7) B7F(4-7) B7E(4-7) B7F(4-7)
+
+                    const __m512i rhs_mat_014589CD_71_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_71, (_MM_PERM_ENUM)221); //B70(12-15) B71(12-15) B70(12-15) B71(12-15) B74(12-15) B75(12-15) B74(12-15) B75(12-15) B78(12-15) B79(12-15) B78(12-15) B79(12-15) B7C(12-15) B7D(12-15) B7C(12-15) B7D(12-15)
+                    const __m512i rhs_mat_2367ABEF_71_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_71, (_MM_PERM_ENUM)221); //B72(12-15) B73(12-15) B72(12-15) B73(12-15) B76(12-15) B77(12-15) B76(12-15) B77(12-15) B7A(12-15) B7B(12-15) B7A(12-15) B7B(12-15) B7E(12-15) B7F(12-15) B7E(12-15) B7F(12-15)
+
+                    // notation: superblock subblock
+                    // s00 m00  s01 m01   s10 m10  s11 m11  s20 m20  s21 m21   s30 m30  s31 m31  s40 m40  s41 m41   s50 m50  s51 m51  s60 m60  s61 m61   s70 m70  s71 m71
+
+                    const __m128i mins_and_scales_01_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + sb * 64));
+                    const __m128i mins_and_scales_23_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 16 + sb * 64));
+                    const __m128i mins_and_scales_45_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 32 + sb * 64));
+                    const __m128i mins_and_scales_67_0 = _mm_loadu_si128((const __m128i *)(b_ptr_0[b].scales + 48 + sb * 64));
+
+                    const __m128i mins_and_scales_01_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + sb * 64));
+                    const __m128i mins_and_scales_23_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 16 + sb * 64));
+                    const __m128i mins_and_scales_45_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 32 + sb * 64));
+                    const __m128i mins_and_scales_67_1 = _mm_loadu_si128((const __m128i *)(b_ptr_1[b].scales + 48 + sb * 64));
+
+                    // Combine mins and scales for sub-blocks: 0-1, 2-3, 4-5, 6-7 in the sb loop
+                    const __m256i mins_and_scales_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_01_0), mins_and_scales_01_1, 1);
+                    const __m256i mins_and_scales_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_23_0), mins_and_scales_23_1, 1);
+                    const __m256i mins_and_scales_45 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_45_0), mins_and_scales_45_1, 1);
+                    const __m256i mins_and_scales_67 = _mm256_insertf128_si256(_mm256_castsi128_si256(mins_and_scales_67_0), mins_and_scales_67_1, 1);
+
+                    // Extract the scales, which are stored in the lower 4 bits of each byte of mins_and_scales
+                    const __m256i scales_01 = _mm256_and_si256(mins_and_scales_01, m4b);
+                    const __m256i scales_23 = _mm256_and_si256(mins_and_scales_23, m4b);
+                    const __m256i scales_45 = _mm256_and_si256(mins_and_scales_45, m4b);
+                    const __m256i scales_67 = _mm256_and_si256(mins_and_scales_67, m4b);
+
+                    // Extract the mins, which are stored in the upper 4 bits of each byte of mins_and_scales
+                    const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_01, 4), m4b));
+                    const __m512i mins_23 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_23, 4), m4b));
+                    const __m512i mins_45 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_45, 4), m4b));
+                    const __m512i mins_67 = _mm512_cvtepu8_epi16(_mm256_and_si256(_mm256_srli_epi16(mins_and_scales_67, 4), m4b));
+
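+                    // Expand the 8-bit scales to 16-bit - one vector of scales per sub block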
+                    const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_01, scalesmask1));
+                    const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_01, scalesmask2));
+                    const __m512i scales_2 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_23, scalesmask1));
+                    const __m512i scales_3 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_23, scalesmask2));
+                    const __m512i scales_4 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_45, scalesmask1));
+                    const __m512i scales_5 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_45, scalesmask2));
+                    const __m512i scales_6 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_67, scalesmask1));
+                    const __m512i scales_7 = _mm512_cvtepu8_epi16(_mm256_shuffle_epi8(scales_67, scalesmask2));
 
                     const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68);
                     const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238);
@@ -2366,115 +4479,327 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68);
                     const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238);
 
+                    const __m512i scale_014589CD_2 = _mm512_shuffle_epi32(scales_2, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_2 = _mm512_shuffle_epi32(scales_2, (_MM_PERM_ENUM)238);
+
+                    const __m512i scale_014589CD_3 = _mm512_shuffle_epi32(scales_3, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_3 = _mm512_shuffle_epi32(scales_3, (_MM_PERM_ENUM)238);
+
+                    const __m512i scale_014589CD_4 = _mm512_shuffle_epi32(scales_4, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_4 = _mm512_shuffle_epi32(scales_4, (_MM_PERM_ENUM)238);
+
+                    const __m512i scale_014589CD_5 = _mm512_shuffle_epi32(scales_5, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_5 = _mm512_shuffle_epi32(scales_5, (_MM_PERM_ENUM)238);
+
+                    const __m512i scale_014589CD_6 = _mm512_shuffle_epi32(scales_6, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_6 = _mm512_shuffle_epi32(scales_6, (_MM_PERM_ENUM)238);
+
+                    const __m512i scale_014589CD_7 = _mm512_shuffle_epi32(scales_7, (_MM_PERM_ENUM)68);
+                    const __m512i scale_2367ABEF_7 = _mm512_shuffle_epi32(scales_7, (_MM_PERM_ENUM)238);
+
                     // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
                     // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
-                    __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 * sb)));
+                    __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 512 * sb)));
                     __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0);
                     __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17);
-                    __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 256 * sb)));
+                    __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 512 * sb)));
                     __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0);
                     __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17);
-                    __m256i lhs_mat_ymm_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 256 * sb)));
-                    __m256i lhs_mat_ymm_01_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 0);
-                    __m256i lhs_mat_ymm_23_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 17);
-                    __m256i lhs_mat_ymm_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 256 * sb)));
-                    __m256i lhs_mat_ymm_01_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 0);
-                    __m256i lhs_mat_ymm_23_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 17);
-                    __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 256 * sb)));
+                    __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 512 * sb)));
                     __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0);
                     __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17);
-                    __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 256 * sb)));
+                    __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 512 * sb)));
                     __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0);
                     __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17);
-                    __m256i lhs_mat_ymm_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 256 * sb)));
-                    __m256i lhs_mat_ymm_01_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 0);
-                    __m256i lhs_mat_ymm_23_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 17);
-                    __m256i lhs_mat_ymm_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 256 * sb)));
-                    __m256i lhs_mat_ymm_01_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 0);
-                    __m256i lhs_mat_ymm_23_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 17);
+                    __m256i lhs_mat_ymm_0123_20 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 512 * sb)));
+                    __m256i lhs_mat_ymm_01_20 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_20, lhs_mat_ymm_0123_20, 0);
+                    __m256i lhs_mat_ymm_23_20 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_20, lhs_mat_ymm_0123_20, 17);
+                    __m256i lhs_mat_ymm_0123_21 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 512 * sb)));
+                    __m256i lhs_mat_ymm_01_21 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_21, lhs_mat_ymm_0123_21, 0);
+                    __m256i lhs_mat_ymm_23_21 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_21, lhs_mat_ymm_0123_21, 17);
+                    __m256i lhs_mat_ymm_0123_30 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 512 * sb)));
+                    __m256i lhs_mat_ymm_01_30 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_30, lhs_mat_ymm_0123_30, 0);
+                    __m256i lhs_mat_ymm_23_30 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_30, lhs_mat_ymm_0123_30, 17);
+                    __m256i lhs_mat_ymm_0123_31 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 512 * sb)));
+                    __m256i lhs_mat_ymm_01_31 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_31, lhs_mat_ymm_0123_31, 0);
+                    __m256i lhs_mat_ymm_23_31 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_31, lhs_mat_ymm_0123_31, 17);
+
+                    __m256i lhs_mat_ymm_0123_40 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 + 512 * sb)));
+                    __m256i lhs_mat_ymm_01_40 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_40, lhs_mat_ymm_0123_40, 0);
+                    __m256i lhs_mat_ymm_23_40 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_40, lhs_mat_ymm_0123_40, 17);
+                    __m256i lhs_mat_ymm_0123_41 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 288 + 512 * sb)));
+                    __m256i lhs_mat_ymm_01_41 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_41, lhs_mat_ymm_0123_41, 0);
+                    __m256i lhs_mat_ymm_23_41 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_41, lhs_mat_ymm_0123_41, 17);
+                    __m256i lhs_mat_ymm_0123_50 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 320 + 512 * sb)));
+                    __m256i lhs_mat_ymm_01_50 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_50, lhs_mat_ymm_0123_50, 0);
+                    __m256i lhs_mat_ymm_23_50 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_50, lhs_mat_ymm_0123_50, 17);
+                    __m256i lhs_mat_ymm_0123_51 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 352 + 512 * sb)));
+                    __m256i lhs_mat_ymm_01_51 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_51, lhs_mat_ymm_0123_51, 0);
+                    __m256i lhs_mat_ymm_23_51 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_51, lhs_mat_ymm_0123_51, 17);
+                    __m256i lhs_mat_ymm_0123_60 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 384 + 512 * sb)));
+                    __m256i lhs_mat_ymm_01_60 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_60, lhs_mat_ymm_0123_60, 0);
+                    __m256i lhs_mat_ymm_23_60 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_60, lhs_mat_ymm_0123_60, 17);
+                    __m256i lhs_mat_ymm_0123_61 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 416 + 512 * sb)));
+                    __m256i lhs_mat_ymm_01_61 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_61, lhs_mat_ymm_0123_61, 0);
+                    __m256i lhs_mat_ymm_23_61 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_61, lhs_mat_ymm_0123_61, 17);
+                    __m256i lhs_mat_ymm_0123_70 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 448 + 512 * sb)));
+                    __m256i lhs_mat_ymm_01_70 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_70, lhs_mat_ymm_0123_70, 0);
+                    __m256i lhs_mat_ymm_23_70 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_70, lhs_mat_ymm_0123_70, 17);
+                    __m256i lhs_mat_ymm_0123_71 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 480 + 512 * sb)));
+                    __m256i lhs_mat_ymm_01_71 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_71, lhs_mat_ymm_0123_71, 0);
+                    __m256i lhs_mat_ymm_23_71 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_71, lhs_mat_ymm_0123_71, 17);
 
-                    //Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into a 512 bit vector
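+                    // Each 256-bit row-pair vector is duplicated into both 256-bit halves of a 512-bit vector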
                     __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1);
                     __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1);
                     __m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1);
                     __m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1);
-                    __m512i lhs_mat_01_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_02), lhs_mat_ymm_01_02, 1);
-                    __m512i lhs_mat_23_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_02), lhs_mat_ymm_23_02, 1);
-                    __m512i lhs_mat_01_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_03), lhs_mat_ymm_01_03, 1);
-                    __m512i lhs_mat_23_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_03), lhs_mat_ymm_23_03, 1);
 
                     __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1);
                     __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1);
                     __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1);
                     __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1);
-                    __m512i lhs_mat_01_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_12), lhs_mat_ymm_01_12, 1);
-                    __m512i lhs_mat_23_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_12), lhs_mat_ymm_23_12, 1);
-                    __m512i lhs_mat_01_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_13), lhs_mat_ymm_01_13, 1);
-                    __m512i lhs_mat_23_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_13), lhs_mat_ymm_23_13, 1);
 
-                    // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
-                    __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].bsums + 16 * sb)));
-                    __m256i lhs_bsums_hsum_ymm_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
-                    lhs_bsums_hsum_ymm_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_ymm_0123_01, lhs_bsums_hsum_ymm_0123_01, 0);
-                    __m512i lhs_bsums_hsum_0123_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_hsum_ymm_0123_01), lhs_bsums_hsum_ymm_0123_01, 1);
+                    __m512i lhs_mat_01_20 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_20), lhs_mat_ymm_01_20, 1);
+                    __m512i lhs_mat_23_20 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_20), lhs_mat_ymm_23_20, 1);
+                    __m512i lhs_mat_01_21 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_21), lhs_mat_ymm_01_21, 1);
+                    __m512i lhs_mat_23_21 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_21), lhs_mat_ymm_23_21, 1);
+
+                    __m512i lhs_mat_01_30 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_30), lhs_mat_ymm_01_30, 1);
+                    __m512i lhs_mat_23_30 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_30), lhs_mat_ymm_23_30, 1);
+                    __m512i lhs_mat_01_31 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_31), lhs_mat_ymm_01_31, 1);
+                    __m512i lhs_mat_23_31 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_31), lhs_mat_ymm_23_31, 1);
+
+                    __m512i lhs_mat_01_40 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_40), lhs_mat_ymm_01_40, 1);
+                    __m512i lhs_mat_23_40 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_40), lhs_mat_ymm_23_40, 1);
+                    __m512i lhs_mat_01_41 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_41), lhs_mat_ymm_01_41, 1);
+                    __m512i lhs_mat_23_41 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_41), lhs_mat_ymm_23_41, 1);
+
+                    __m512i lhs_mat_01_50 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_50), lhs_mat_ymm_01_50, 1);
+                    __m512i lhs_mat_23_50 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_50), lhs_mat_ymm_23_50, 1);
+                    __m512i lhs_mat_01_51 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_51), lhs_mat_ymm_01_51, 1);
+                    __m512i lhs_mat_23_51 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_51), lhs_mat_ymm_23_51, 1);
+
+                    __m512i lhs_mat_01_60 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_60), lhs_mat_ymm_01_60, 1);
+                    __m512i lhs_mat_23_60 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_60), lhs_mat_ymm_23_60, 1);
+                    __m512i lhs_mat_01_61 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_61), lhs_mat_ymm_01_61, 1);
+                    __m512i lhs_mat_23_61 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_61), lhs_mat_ymm_23_61, 1);
+
+                    __m512i lhs_mat_01_70 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_70), lhs_mat_ymm_01_70, 1);
+                    __m512i lhs_mat_23_70 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_70), lhs_mat_ymm_23_70, 1);
+                    __m512i lhs_mat_01_71 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_71), lhs_mat_ymm_01_71, 1);
+                    __m512i lhs_mat_23_71 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_71), lhs_mat_ymm_23_71, 1);
+
+                    // Bsums are loaded for the different Q8_K blocks
+                    __m128i lhs_raw_bsums_01_0123 = _mm_loadu_si128((const __m128i *)((a_ptr[b].bsums + 32 * sb)));
+                    __m128i lhs_raw_bsums_23_0123 = _mm_loadu_si128((const __m128i *)(a_ptr[b].bsums + 8 + 32 * sb));
+                    __m128i lhs_raw_bsums_01_4567 = _mm_loadu_si128((const __m128i *)((a_ptr[b].bsums + 16 + 32 * sb)));
+                    __m128i lhs_raw_bsums_23_4567 = _mm_loadu_si128((const __m128i *)(a_ptr[b].bsums + 24 + 32 * sb));
+
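+                    // Broadcast each 128-bit bsum vector into every 128-bit lane of a 512-bit vector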
+                    __m256i lhs_bsums_ymm_01_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_0123), lhs_raw_bsums_01_0123, 1);
+                    __m512i lhs_bsums_01_0123 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_01_0123), lhs_bsums_ymm_01_0123, 1);
+                    __m256i lhs_bsums_ymm_23_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_0123), lhs_raw_bsums_23_0123, 1);
+                    __m512i lhs_bsums_23_0123 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_23_0123), lhs_bsums_ymm_23_0123, 1);
+                    __m256i lhs_bsums_ymm_01_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_4567), lhs_raw_bsums_01_4567, 1);
+                    __m512i lhs_bsums_01_4567 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_01_4567), lhs_bsums_ymm_01_4567, 1);
+                    __m256i lhs_bsums_ymm_23_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_4567), lhs_raw_bsums_23_4567, 1);
+                    __m512i lhs_bsums_23_4567 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_ymm_23_4567), lhs_bsums_ymm_23_4567, 1);
 
                     // Shuffle pattern one - left side input
                     const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
                     const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3)
+
                     const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
                     const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11)
-                    const __m512i lhs_mat_01_02_sp1 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
-                    const __m512i lhs_mat_23_02_sp1 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)160); //A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19)
-                    const __m512i lhs_mat_01_03_sp1 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
-                    const __m512i lhs_mat_23_03_sp1 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)160); //A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27)
 
                     const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
                     const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3)
+
                     const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
                     const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11)
-                    const __m512i lhs_mat_01_12_sp1 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
-                    const __m512i lhs_mat_23_12_sp1 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)160); //A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19)
-                    const __m512i lhs_mat_01_13_sp1 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
-                    const __m512i lhs_mat_23_13_sp1 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)160); //A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27)
+
+                    const __m512i lhs_mat_01_20_sp1 = _mm512_shuffle_epi32(lhs_mat_01_20, (_MM_PERM_ENUM)160); //A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3)
+                    const __m512i lhs_mat_23_20_sp1 = _mm512_shuffle_epi32(lhs_mat_23_20, (_MM_PERM_ENUM)160); //A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3) A22(0-3) A22(0-3) A23(0-3) A23(0-3)
+
+                    const __m512i lhs_mat_01_21_sp1 = _mm512_shuffle_epi32(lhs_mat_01_21, (_MM_PERM_ENUM)160); //A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11)
+                    const __m512i lhs_mat_23_21_sp1 = _mm512_shuffle_epi32(lhs_mat_23_21, (_MM_PERM_ENUM)160); //A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11) A22(8-11) A22(8-11) A23(8-11) A23(8-11)
+
+                    const __m512i lhs_mat_01_30_sp1 = _mm512_shuffle_epi32(lhs_mat_01_30, (_MM_PERM_ENUM)160); //A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3)
+                    const __m512i lhs_mat_23_30_sp1 = _mm512_shuffle_epi32(lhs_mat_23_30, (_MM_PERM_ENUM)160); //A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3) A32(0-3) A32(0-3) A33(0-3) A33(0-3)
+
+                    const __m512i lhs_mat_01_31_sp1 = _mm512_shuffle_epi32(lhs_mat_01_31, (_MM_PERM_ENUM)160); //A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11)
+                    const __m512i lhs_mat_23_31_sp1 = _mm512_shuffle_epi32(lhs_mat_23_31, (_MM_PERM_ENUM)160); //A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) A33(8-11) A32(8-11) A32(8-11) A33(8-11) A33(8-11)
+
+                    const __m512i lhs_mat_01_40_sp1 = _mm512_shuffle_epi32(lhs_mat_01_40, (_MM_PERM_ENUM)160); //A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3)
+                    const __m512i lhs_mat_23_40_sp1 = _mm512_shuffle_epi32(lhs_mat_23_40, (_MM_PERM_ENUM)160); //A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3) A42(0-3) A42(0-3) A43(0-3) A43(0-3)
+
+                    const __m512i lhs_mat_01_41_sp1 = _mm512_shuffle_epi32(lhs_mat_01_41, (_MM_PERM_ENUM)160); //A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11)
+                    const __m512i lhs_mat_23_41_sp1 = _mm512_shuffle_epi32(lhs_mat_23_41, (_MM_PERM_ENUM)160); //A42(8-11) A42(8-11) A43(8-11) A43(8-11) A42(8-11) A42(8-11) A43(8-11) A43(8-11) A42(8-11) A42(8-11) A43(8-11) A43(8-11) A42(8-11) A42(8-11) A43(8-11) A43(8-11)
+
+                    const __m512i lhs_mat_01_50_sp1 = _mm512_shuffle_epi32(lhs_mat_01_50, (_MM_PERM_ENUM)160); //A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3)
+                    const __m512i lhs_mat_23_50_sp1 = _mm512_shuffle_epi32(lhs_mat_23_50, (_MM_PERM_ENUM)160); //A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3) A52(0-3) A52(0-3) A53(0-3) A53(0-3)
+
+                    const __m512i lhs_mat_01_51_sp1 = _mm512_shuffle_epi32(lhs_mat_01_51, (_MM_PERM_ENUM)160); //A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11)
+                    const __m512i lhs_mat_23_51_sp1 = _mm512_shuffle_epi32(lhs_mat_23_51, (_MM_PERM_ENUM)160); //A52(8-11) A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11) A52(8-11) A52(8-11) A53(8-11) A53(8-11)
+
+                    const __m512i lhs_mat_01_60_sp1 = _mm512_shuffle_epi32(lhs_mat_01_60, (_MM_PERM_ENUM)160); //A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3)
+                    const __m512i lhs_mat_23_60_sp1 = _mm512_shuffle_epi32(lhs_mat_23_60, (_MM_PERM_ENUM)160); //A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3) A62(0-3) A62(0-3) A63(0-3) A63(0-3)
+
+                    const __m512i lhs_mat_01_61_sp1 = _mm512_shuffle_epi32(lhs_mat_01_61, (_MM_PERM_ENUM)160); //A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11)
+                    const __m512i lhs_mat_23_61_sp1 = _mm512_shuffle_epi32(lhs_mat_23_61, (_MM_PERM_ENUM)160); //A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11) A62(8-11) A62(8-11) A63(8-11) A63(8-11)
+
+                    const __m512i lhs_mat_01_70_sp1 = _mm512_shuffle_epi32(lhs_mat_01_70, (_MM_PERM_ENUM)160); //A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3)
+                    const __m512i lhs_mat_23_70_sp1 = _mm512_shuffle_epi32(lhs_mat_23_70, (_MM_PERM_ENUM)160); //A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3) A72(0-3) A72(0-3) A73(0-3) A73(0-3)
+
+                    const __m512i lhs_mat_01_71_sp1 = _mm512_shuffle_epi32(lhs_mat_01_71, (_MM_PERM_ENUM)160); //A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11)
+                    const __m512i lhs_mat_23_71_sp1 = _mm512_shuffle_epi32(lhs_mat_23_71, (_MM_PERM_ENUM)160); //A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11) A72(8-11) A72(8-11) A73(8-11) A73(8-11)
 
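+                    // Shuffle pattern two - left side input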
                     const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
                     const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7)
+
                     const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
                     const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15)
-                    const __m512i lhs_mat_01_02_sp2 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
-                    const __m512i lhs_mat_23_02_sp2 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)245); //A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23)
-                    const __m512i lhs_mat_01_03_sp2 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
-                    const __m512i lhs_mat_23_03_sp2 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)245); //A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31)
 
                     const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
                     const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7)
+
                     const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
                     const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15)
-                    const __m512i lhs_mat_01_12_sp2 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
-                    const __m512i lhs_mat_23_12_sp2 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)245); //A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23)
-                    const __m512i lhs_mat_01_13_sp2 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
-                    const __m512i lhs_mat_23_13_sp2 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)245); //A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31)
+
+                    const __m512i lhs_mat_01_20_sp2 = _mm512_shuffle_epi32(lhs_mat_01_20, (_MM_PERM_ENUM)245); //A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7)
+                    const __m512i lhs_mat_23_20_sp2 = _mm512_shuffle_epi32(lhs_mat_23_20, (_MM_PERM_ENUM)245); //A22(4-7) A22(4-7) A23(4-7) A23(4-7) A22(4-7) A22(4-7) A23(4-7) A23(4-7) A22(4-7) A22(4-7) A23(4-7) A23(4-7) A22(4-7) A22(4-7) A23(4-7) A23(4-7)
+
+                    const __m512i lhs_mat_01_21_sp2 = _mm512_shuffle_epi32(lhs_mat_01_21, (_MM_PERM_ENUM)245); //A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15)
+                    const __m512i lhs_mat_23_21_sp2 = _mm512_shuffle_epi32(lhs_mat_23_21, (_MM_PERM_ENUM)245); //A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15) A22(12-15) A22(12-15) A23(12-15) A23(12-15)
+
+                    const __m512i lhs_mat_01_30_sp2 = _mm512_shuffle_epi32(lhs_mat_01_30, (_MM_PERM_ENUM)245); //A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7)
+                    const __m512i lhs_mat_23_30_sp2 = _mm512_shuffle_epi32(lhs_mat_23_30, (_MM_PERM_ENUM)245); //A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7) A32(4-7) A32(4-7) A33(4-7) A33(4-7)
+
+                    const __m512i lhs_mat_01_31_sp2 = _mm512_shuffle_epi32(lhs_mat_01_31, (_MM_PERM_ENUM)245); //A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15)
+                    const __m512i lhs_mat_23_31_sp2 = _mm512_shuffle_epi32(lhs_mat_23_31, (_MM_PERM_ENUM)245); //A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15) A32(12-15) A32(12-15) A33(12-15) A33(12-15)
+
+                    const __m512i lhs_mat_01_40_sp2 = _mm512_shuffle_epi32(lhs_mat_01_40, (_MM_PERM_ENUM)245); //A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7)
+                    const __m512i lhs_mat_23_40_sp2 = _mm512_shuffle_epi32(lhs_mat_23_40, (_MM_PERM_ENUM)245); //A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7) A42(4-7) A42(4-7) A43(4-7) A43(4-7)
+
+                    const __m512i lhs_mat_01_41_sp2 = _mm512_shuffle_epi32(lhs_mat_01_41, (_MM_PERM_ENUM)245); //A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15)
+                    const __m512i lhs_mat_23_41_sp2 = _mm512_shuffle_epi32(lhs_mat_23_41, (_MM_PERM_ENUM)245); //A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15) A42(12-15) A42(12-15) A43(12-15) A43(12-15)
+
+                    const __m512i lhs_mat_01_50_sp2 = _mm512_shuffle_epi32(lhs_mat_01_50, (_MM_PERM_ENUM)245); //A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7)
+                    const __m512i lhs_mat_23_50_sp2 = _mm512_shuffle_epi32(lhs_mat_23_50, (_MM_PERM_ENUM)245); //A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7) A52(4-7) A52(4-7) A53(4-7) A53(4-7)
+
+                    const __m512i lhs_mat_01_51_sp2 = _mm512_shuffle_epi32(lhs_mat_01_51, (_MM_PERM_ENUM)245); //A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15)
+                    const __m512i lhs_mat_23_51_sp2 = _mm512_shuffle_epi32(lhs_mat_23_51, (_MM_PERM_ENUM)245); //A52(12-15) A52(12-15) A53(12-15) A53(12-15) A52(12-15) A52(12-15) A53(12-15) A53(12-15) A52(12-15) A52(12-15) A53(12-15) A53(12-15) A52(12-15) A52(12-15) A53(12-15) A53(12-15)
+
+                    const __m512i lhs_mat_01_60_sp2 = _mm512_shuffle_epi32(lhs_mat_01_60, (_MM_PERM_ENUM)245); //A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7)
+                    const __m512i lhs_mat_23_60_sp2 = _mm512_shuffle_epi32(lhs_mat_23_60, (_MM_PERM_ENUM)245); //A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7) A62(4-7) A62(4-7) A63(4-7) A63(4-7)
+
+                    const __m512i lhs_mat_01_61_sp2 = _mm512_shuffle_epi32(lhs_mat_01_61, (_MM_PERM_ENUM)245); //A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15)
+                    const __m512i lhs_mat_23_61_sp2 = _mm512_shuffle_epi32(lhs_mat_23_61, (_MM_PERM_ENUM)245); //A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) A62(12-15) A63(12-15) A63(12-15) A62(12-15) A62(12-15) A63(12-15) A63(12-15)
+
+                    const __m512i lhs_mat_01_70_sp2 = _mm512_shuffle_epi32(lhs_mat_01_70, (_MM_PERM_ENUM)245); //A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7)
+                    const __m512i lhs_mat_23_70_sp2 = _mm512_shuffle_epi32(lhs_mat_23_70, (_MM_PERM_ENUM)245); //A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7) A72(4-7) A72(4-7) A73(4-7) A73(4-7)
+
+                    const __m512i lhs_mat_01_71_sp2 = _mm512_shuffle_epi32(lhs_mat_01_71, (_MM_PERM_ENUM)245); //A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15)
+                    const __m512i lhs_mat_23_71_sp2 = _mm512_shuffle_epi32(lhs_mat_23_71, (_MM_PERM_ENUM)245); //A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15) A72(12-15) A72(12-15) A73(12-15) A73(12-15)
 
                     // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
-                    __m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1));
-                    __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1));
-                    __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1));
-                    __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1));
-                    __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1));
-                    __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1));
-                    __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1));
-                    __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1));
+                    __m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1));
+                    __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1));
 
-                    __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2));
-                    __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2));
-                    __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2));
-                    __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2));
-                    __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2));
-                    __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2));
-                    __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2));
-                    __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2));
+                    __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1));
+                    __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1));
 
-                    // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
+                    __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1));
+                    __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1));
+
+                    __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1));
+                    __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1));
+
+                    __m512i iacc_mat_00_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp1, lhs_mat_01_20_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp1, lhs_mat_01_21_sp1));
+                    __m512i iacc_mat_01_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp1, lhs_mat_01_20_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp1, lhs_mat_01_21_sp1));
+
+                    __m512i iacc_mat_10_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp1, lhs_mat_23_20_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp1, lhs_mat_23_21_sp1));
+                    __m512i iacc_mat_11_2_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp1, lhs_mat_23_20_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp1, lhs_mat_23_21_sp1));
+
+                    __m512i iacc_mat_00_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp1, lhs_mat_01_30_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp1, lhs_mat_01_31_sp1));
+                    __m512i iacc_mat_01_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp1, lhs_mat_01_30_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp1, lhs_mat_01_31_sp1));
+
+                    __m512i iacc_mat_10_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp1, lhs_mat_23_30_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp1, lhs_mat_23_31_sp1));
+                    __m512i iacc_mat_11_3_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp1, lhs_mat_23_30_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp1, lhs_mat_23_31_sp1));
+
+                    __m512i iacc_mat_00_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp1, lhs_mat_01_40_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp1, lhs_mat_01_41_sp1));
+                    __m512i iacc_mat_01_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp1, lhs_mat_01_40_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp1, lhs_mat_01_41_sp1));
+
+                    __m512i iacc_mat_10_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp1, lhs_mat_23_40_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp1, lhs_mat_23_41_sp1));
+                    __m512i iacc_mat_11_4_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp1, lhs_mat_23_40_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp1, lhs_mat_23_41_sp1));
+
+                    __m512i iacc_mat_00_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp1, lhs_mat_01_50_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp1, lhs_mat_01_51_sp1));
+                    __m512i iacc_mat_01_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp1, lhs_mat_01_50_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp1, lhs_mat_01_51_sp1));
+
+                    __m512i iacc_mat_10_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp1, lhs_mat_23_50_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp1, lhs_mat_23_51_sp1));
+                    __m512i iacc_mat_11_5_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp1, lhs_mat_23_50_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp1, lhs_mat_23_51_sp1));
+
+                    __m512i iacc_mat_00_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp1, lhs_mat_01_60_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp1, lhs_mat_01_61_sp1));
+                    __m512i iacc_mat_01_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp1, lhs_mat_01_60_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp1, lhs_mat_01_61_sp1));
+
+                    __m512i iacc_mat_10_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp1, lhs_mat_23_60_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp1, lhs_mat_23_61_sp1));
+                    __m512i iacc_mat_11_6_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp1, lhs_mat_23_60_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp1, lhs_mat_23_61_sp1));
+
+                    __m512i iacc_mat_00_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp1, lhs_mat_01_70_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp1, lhs_mat_01_71_sp1));
+                    __m512i iacc_mat_01_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp1, lhs_mat_01_70_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp1, lhs_mat_01_71_sp1));
+
+                    __m512i iacc_mat_10_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp1, lhs_mat_23_70_sp1),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp1, lhs_mat_23_71_sp1));
+                    __m512i iacc_mat_11_7_sp1 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp1, lhs_mat_23_70_sp1),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp1, lhs_mat_23_71_sp1));
+
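+                    // Dot products for shuffle pattern two, combined in the same layout as above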
+                    __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2));
+                    __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2));
+
+                    __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2));
+                    __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2));
+
+                    __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2));
+                    __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2));
+
+                    __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2));
+                    __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2));
+
+                    __m512i iacc_mat_00_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp2, lhs_mat_01_20_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp2, lhs_mat_01_21_sp2));
+                    __m512i iacc_mat_01_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp2, lhs_mat_01_20_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp2, lhs_mat_01_21_sp2));
+
+                    __m512i iacc_mat_10_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_20_sp2, lhs_mat_23_20_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_21_sp2, lhs_mat_23_21_sp2));
+                    __m512i iacc_mat_11_2_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_20_sp2, lhs_mat_23_20_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_21_sp2, lhs_mat_23_21_sp2));
+
+                    __m512i iacc_mat_00_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp2, lhs_mat_01_30_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp2, lhs_mat_01_31_sp2));
+                    __m512i iacc_mat_01_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp2, lhs_mat_01_30_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp2, lhs_mat_01_31_sp2));
+
+                    __m512i iacc_mat_10_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_30_sp2, lhs_mat_23_30_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_31_sp2, lhs_mat_23_31_sp2));
+                    __m512i iacc_mat_11_3_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_30_sp2, lhs_mat_23_30_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_31_sp2, lhs_mat_23_31_sp2));
+
+                    __m512i iacc_mat_00_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp2, lhs_mat_01_40_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp2, lhs_mat_01_41_sp2));
+                    __m512i iacc_mat_01_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp2, lhs_mat_01_40_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp2, lhs_mat_01_41_sp2));
+
+                    __m512i iacc_mat_10_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_40_sp2, lhs_mat_23_40_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_41_sp2, lhs_mat_23_41_sp2));
+                    __m512i iacc_mat_11_4_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_40_sp2, lhs_mat_23_40_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_41_sp2, lhs_mat_23_41_sp2));
+
+                    __m512i iacc_mat_00_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp2, lhs_mat_01_50_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp2, lhs_mat_01_51_sp2));
+                    __m512i iacc_mat_01_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp2, lhs_mat_01_50_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp2, lhs_mat_01_51_sp2));
+
+                    __m512i iacc_mat_10_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_50_sp2, lhs_mat_23_50_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_51_sp2, lhs_mat_23_51_sp2));
+                    __m512i iacc_mat_11_5_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_50_sp2, lhs_mat_23_50_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_51_sp2, lhs_mat_23_51_sp2));
+
+                    __m512i iacc_mat_00_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp2, lhs_mat_01_60_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp2, lhs_mat_01_61_sp2));
+                    __m512i iacc_mat_01_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp2, lhs_mat_01_60_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp2, lhs_mat_01_61_sp2));
+
+                    __m512i iacc_mat_10_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_60_sp2, lhs_mat_23_60_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_61_sp2, lhs_mat_23_61_sp2));
+                    __m512i iacc_mat_11_6_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_60_sp2, lhs_mat_23_60_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_61_sp2, lhs_mat_23_61_sp2));
+
+                    __m512i iacc_mat_00_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp2, lhs_mat_01_70_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp2, lhs_mat_01_71_sp2));
+                    __m512i iacc_mat_01_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp2, lhs_mat_01_70_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp2, lhs_mat_01_71_sp2));
+
+                    __m512i iacc_mat_10_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_70_sp2, lhs_mat_23_70_sp2),_mm512_maddubs_epi16(rhs_mat_014589CD_71_sp2, lhs_mat_23_71_sp2));
+                    __m512i iacc_mat_11_7_sp2 = _mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_70_sp2, lhs_mat_23_70_sp2),_mm512_maddubs_epi16(rhs_mat_2367ABEF_71_sp2, lhs_mat_23_71_sp2));
+
+                    // Combine results from both shuffle patterns for each output block
                     __m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
                     __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
                     __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
@@ -2485,6 +4810,37 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
                     __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
 
+                    __m512i iacc_mat_00_2 = _mm512_add_epi16(iacc_mat_00_2_sp1, iacc_mat_00_2_sp2);
+                    __m512i iacc_mat_01_2 = _mm512_add_epi16(iacc_mat_01_2_sp1, iacc_mat_01_2_sp2);
+                    __m512i iacc_mat_10_2 = _mm512_add_epi16(iacc_mat_10_2_sp1, iacc_mat_10_2_sp2);
+                    __m512i iacc_mat_11_2 = _mm512_add_epi16(iacc_mat_11_2_sp1, iacc_mat_11_2_sp2);
+
+                    __m512i iacc_mat_00_3 = _mm512_add_epi16(iacc_mat_00_3_sp1, iacc_mat_00_3_sp2);
+                    __m512i iacc_mat_01_3 = _mm512_add_epi16(iacc_mat_01_3_sp1, iacc_mat_01_3_sp2);
+                    __m512i iacc_mat_10_3 = _mm512_add_epi16(iacc_mat_10_3_sp1, iacc_mat_10_3_sp2);
+                    __m512i iacc_mat_11_3 = _mm512_add_epi16(iacc_mat_11_3_sp1, iacc_mat_11_3_sp2);
+
+                    __m512i iacc_mat_00_4 = _mm512_add_epi16(iacc_mat_00_4_sp1, iacc_mat_00_4_sp2);
+                    __m512i iacc_mat_01_4 = _mm512_add_epi16(iacc_mat_01_4_sp1, iacc_mat_01_4_sp2);
+                    __m512i iacc_mat_10_4 = _mm512_add_epi16(iacc_mat_10_4_sp1, iacc_mat_10_4_sp2);
+                    __m512i iacc_mat_11_4 = _mm512_add_epi16(iacc_mat_11_4_sp1, iacc_mat_11_4_sp2);
+
+                    __m512i iacc_mat_00_5 = _mm512_add_epi16(iacc_mat_00_5_sp1, iacc_mat_00_5_sp2);
+                    __m512i iacc_mat_01_5 = _mm512_add_epi16(iacc_mat_01_5_sp1, iacc_mat_01_5_sp2);
+                    __m512i iacc_mat_10_5 = _mm512_add_epi16(iacc_mat_10_5_sp1, iacc_mat_10_5_sp2);
+                    __m512i iacc_mat_11_5 = _mm512_add_epi16(iacc_mat_11_5_sp1, iacc_mat_11_5_sp2);
+
+                    __m512i iacc_mat_00_6 = _mm512_add_epi16(iacc_mat_00_6_sp1, iacc_mat_00_6_sp2);
+                    __m512i iacc_mat_01_6 = _mm512_add_epi16(iacc_mat_01_6_sp1, iacc_mat_01_6_sp2);
+                    __m512i iacc_mat_10_6 = _mm512_add_epi16(iacc_mat_10_6_sp1, iacc_mat_10_6_sp2);
+                    __m512i iacc_mat_11_6 = _mm512_add_epi16(iacc_mat_11_6_sp1, iacc_mat_11_6_sp2);
+
+                    __m512i iacc_mat_00_7 = _mm512_add_epi16(iacc_mat_00_7_sp1, iacc_mat_00_7_sp2);
+                    __m512i iacc_mat_01_7 = _mm512_add_epi16(iacc_mat_01_7_sp1, iacc_mat_01_7_sp2);
+                    __m512i iacc_mat_10_7 = _mm512_add_epi16(iacc_mat_10_7_sp1, iacc_mat_10_7_sp2);
+                    __m512i iacc_mat_11_7 = _mm512_add_epi16(iacc_mat_11_7_sp1, iacc_mat_11_7_sp2);
+
+                    // Multiply the accumulated 16-bit dot products with the corresponding sub-block scales and widen to 32-bit integers
                     iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0);
                     iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0);
                     iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0);
@@ -2495,20 +4851,46 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1);
                     iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1);
 
-                    // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
-                    __m512i iacc_row_0_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_0, _mm512_shuffle_epi32(iacc_mat_01_0, (_MM_PERM_ENUM)78));
-                    __m512i iacc_row_1_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_0, (_MM_PERM_ENUM)78), iacc_mat_01_0);
-                    __m512i iacc_row_2_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_0, _mm512_shuffle_epi32(iacc_mat_11_0, (_MM_PERM_ENUM)78));
-                    __m512i iacc_row_3_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_0, (_MM_PERM_ENUM)78), iacc_mat_11_0);
-                    __m512i iacc_row_0_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_1, _mm512_shuffle_epi32(iacc_mat_01_1, (_MM_PERM_ENUM)78));
-                    __m512i iacc_row_1_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_1, (_MM_PERM_ENUM)78), iacc_mat_01_1);
-                    __m512i iacc_row_2_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_1, _mm512_shuffle_epi32(iacc_mat_11_1, (_MM_PERM_ENUM)78));
-                    __m512i iacc_row_3_1 = _mm512_mask_blend_epi32(0xCCCC,_mm512_shuffle_epi32(iacc_mat_10_1, (_MM_PERM_ENUM)78), iacc_mat_11_1);
+                    iacc_mat_00_2 = _mm512_madd_epi16(iacc_mat_00_2, scale_014589CD_2);
+                    iacc_mat_01_2 = _mm512_madd_epi16(iacc_mat_01_2, scale_2367ABEF_2);
+                    iacc_mat_10_2 = _mm512_madd_epi16(iacc_mat_10_2, scale_014589CD_2);
+                    iacc_mat_11_2 = _mm512_madd_epi16(iacc_mat_11_2, scale_2367ABEF_2);
+
+                    iacc_mat_00_3 = _mm512_madd_epi16(iacc_mat_00_3, scale_014589CD_3);
+                    iacc_mat_01_3 = _mm512_madd_epi16(iacc_mat_01_3, scale_2367ABEF_3);
+                    iacc_mat_10_3 = _mm512_madd_epi16(iacc_mat_10_3, scale_014589CD_3);
+                    iacc_mat_11_3 = _mm512_madd_epi16(iacc_mat_11_3, scale_2367ABEF_3);
+
+                    iacc_mat_00_4 = _mm512_madd_epi16(iacc_mat_00_4, scale_014589CD_4);
+                    iacc_mat_01_4 = _mm512_madd_epi16(iacc_mat_01_4, scale_2367ABEF_4);
+                    iacc_mat_10_4 = _mm512_madd_epi16(iacc_mat_10_4, scale_014589CD_4);
+                    iacc_mat_11_4 = _mm512_madd_epi16(iacc_mat_11_4, scale_2367ABEF_4);
+
+                    iacc_mat_00_5 = _mm512_madd_epi16(iacc_mat_00_5, scale_014589CD_5);
+                    iacc_mat_01_5 = _mm512_madd_epi16(iacc_mat_01_5, scale_2367ABEF_5);
+                    iacc_mat_10_5 = _mm512_madd_epi16(iacc_mat_10_5, scale_014589CD_5);
+                    iacc_mat_11_5 = _mm512_madd_epi16(iacc_mat_11_5, scale_2367ABEF_5);
+
+                    iacc_mat_00_6 = _mm512_madd_epi16(iacc_mat_00_6, scale_014589CD_6);
+                    iacc_mat_01_6 = _mm512_madd_epi16(iacc_mat_01_6, scale_2367ABEF_6);
+                    iacc_mat_10_6 = _mm512_madd_epi16(iacc_mat_10_6, scale_014589CD_6);
+                    iacc_mat_11_6 = _mm512_madd_epi16(iacc_mat_11_6, scale_2367ABEF_6);
+
+                    iacc_mat_00_7 = _mm512_madd_epi16(iacc_mat_00_7, scale_014589CD_7);
+                    iacc_mat_01_7 = _mm512_madd_epi16(iacc_mat_01_7, scale_2367ABEF_7);
+                    iacc_mat_10_7 = _mm512_madd_epi16(iacc_mat_10_7, scale_014589CD_7);
+                    iacc_mat_11_7 = _mm512_madd_epi16(iacc_mat_11_7, scale_2367ABEF_7);
+
+                    __m512i iacc_mat_00 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_00_0, iacc_mat_00_1), _mm512_add_epi32(iacc_mat_00_2, iacc_mat_00_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_00_4, iacc_mat_00_5), _mm512_add_epi32(iacc_mat_00_6, iacc_mat_00_7)));
+                    __m512i iacc_mat_01 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_01_0, iacc_mat_01_1), _mm512_add_epi32(iacc_mat_01_2, iacc_mat_01_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_01_4, iacc_mat_01_5), _mm512_add_epi32(iacc_mat_01_6, iacc_mat_01_7)));
+                    __m512i iacc_mat_10 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_10_0, iacc_mat_10_1), _mm512_add_epi32(iacc_mat_10_2, iacc_mat_10_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_10_4, iacc_mat_10_5), _mm512_add_epi32(iacc_mat_10_6, iacc_mat_10_7)));
+                    __m512i iacc_mat_11 = _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(iacc_mat_11_0, iacc_mat_11_1), _mm512_add_epi32(iacc_mat_11_2, iacc_mat_11_3)), _mm512_add_epi32(_mm512_add_epi32(iacc_mat_11_4, iacc_mat_11_5), _mm512_add_epi32(iacc_mat_11_6, iacc_mat_11_7)));
 
-                    __m512i iacc_row_0 = _mm512_add_epi32(iacc_row_0_0, iacc_row_0_1);
-                    __m512i iacc_row_1 = _mm512_add_epi32(iacc_row_1_0, iacc_row_1_1);
-                    __m512i iacc_row_2 = _mm512_add_epi32(iacc_row_2_0, iacc_row_2_1);
-                    __m512i iacc_row_3 = _mm512_add_epi32(iacc_row_3_0, iacc_row_3_1);
+                    // Straighten out to make 4 row vectors
+                    __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
+                    __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
 
                     // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes
                     const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);
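The `_mm512_madd_epi16` calls above fold the per sub-block scales into the int16 accumulators: multiplying by a scale carried in 16-bit lanes applies the scale and pairwise-sums adjacent lanes into int32 in a single instruction. Below is a minimal standalone sketch of that identity, using AVX2 for brevity; the array and variable names are illustrative and not taken from the patch, and the kernel actually carries a different scale per column pair rather than one broadcast value.

```c
#include <immintrin.h>
#include <stdio.h>

/* AVX2 sketch (compile with -mavx2) of the identity the scale application
 * above relies on: madd_epi16(iacc, scale) multiplies each int16 lane by its
 * scale and sums adjacent pairs into int32, so it applies the per sub-block
 * scale and widens the accumulator in one step.  In the kernel each 16-bit
 * pair carries the scale of its column; here one scale is broadcast for
 * simplicity. */
int main(void) {
    short sums[16];                      // pretend int16 partial dot products
    for (int i = 0; i < 16; i++) sums[i] = (short)(i + 1);
    const short scale = 3;

    __m256i v  = _mm256_loadu_si256((const __m256i *)sums);
    __m256i sc = _mm256_set1_epi16(scale);
    __m256i r  = _mm256_madd_epi16(v, sc);    // (sums[2k] + sums[2k+1]) * scale

    int out[8];
    _mm256_storeu_si256((__m256i *)out, r);
    for (int k = 0; k < 8; k++) {
        int ref = (sums[2 * k] + sums[2 * k + 1]) * scale;
        printf("%d %d %s\n", out[k], ref, out[k] == ref ? "ok" : "MISMATCH");
    }
    return 0;
}
```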
@@ -2521,10 +4903,31 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
                     acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
 
-                    __m512i iacc_row_min_0 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)0), mins_01);
-                    __m512i iacc_row_min_1 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)85), mins_01);
-                    __m512i iacc_row_min_2 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)170), mins_01);
-                    __m512i iacc_row_min_3 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)255), mins_01);
+                    // Take two bsums from two Q8_Ks at a time and multiply with corresponding mins values from each Q2_K
+                    __m512i iacc_row_min_0_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)0), mins_01);
+                    __m512i iacc_row_min_1_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)170), mins_01);
+                    __m512i iacc_row_min_2_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)0), mins_01);
+                    __m512i iacc_row_min_3_01 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)170), mins_01);
+
+                    __m512i iacc_row_min_0_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)85), mins_23);
+                    __m512i iacc_row_min_1_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_0123, (_MM_PERM_ENUM)255), mins_23);
+                    __m512i iacc_row_min_2_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)85), mins_23);
+                    __m512i iacc_row_min_3_23 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_0123, (_MM_PERM_ENUM)255), mins_23);
+
+                    __m512i iacc_row_min_0_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)0), mins_45);
+                    __m512i iacc_row_min_1_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)170), mins_45);
+                    __m512i iacc_row_min_2_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)0), mins_45);
+                    __m512i iacc_row_min_3_45 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)170), mins_45);
+
+                    __m512i iacc_row_min_0_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)85), mins_67);
+                    __m512i iacc_row_min_1_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_01_4567, (_MM_PERM_ENUM)255), mins_67);
+                    __m512i iacc_row_min_2_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)85), mins_67);
+                    __m512i iacc_row_min_3_67 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_23_4567, (_MM_PERM_ENUM)255), mins_67);
+
+                    __m512i iacc_row_min_0 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_0_01, iacc_row_min_0_23), _mm512_add_epi32(iacc_row_min_0_45,iacc_row_min_0_67));
+                    __m512i iacc_row_min_1 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_1_01, iacc_row_min_1_23), _mm512_add_epi32(iacc_row_min_1_45,iacc_row_min_1_67));
+                    __m512i iacc_row_min_2 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_2_01, iacc_row_min_2_23), _mm512_add_epi32(iacc_row_min_2_45,iacc_row_min_2_67));
+                    __m512i iacc_row_min_3 = _mm512_add_epi32(_mm512_add_epi32(iacc_row_min_3_01, iacc_row_min_3_23), _mm512_add_epi32(iacc_row_min_3_45,iacc_row_min_3_67));
 
                     acc_min_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]);
                     acc_min_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]);
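For reference, the two floating-point accumulators above correspond to the two terms of the q2_K dequantization, x = d*scale*q - dmin*min: acc_rows collects d8*d*sum(scale*qdot) and acc_min_rows collects d8*dmin*sum(min*bsum), with the min term subtracted from the scale term when the results are stored. A scalar sketch of one super-block dot product, written against plain arrays rather than the real block_q2_K/block_q8_K structs (names and layout here are illustrative assumptions):

```c
#include <stdint.h>
#include <stdio.h>

/* Scalar sketch of one q2_K x q8_K super-block dot product over plain arrays.
 * acc mirrors what acc_rows accumulates above and acc_min what acc_min_rows
 * accumulates; the min term is subtracted from the scale term at the end. */
float q2k_q8k_superblock_dot(float d2, float dmin,
                             const uint8_t sc[16], const uint8_t mn[16],
                             const int8_t q[256], float d8,
                             const int8_t y[256], const int16_t bsums[16]) {
    int32_t acc = 0, acc_min = 0;
    for (int sb = 0; sb < 16; sb++) {
        int32_t qdot = 0;
        for (int i = 0; i < 16; i++) {
            qdot += q[sb * 16 + i] * y[sb * 16 + i];   // 2-bit quant x int8
        }
        acc     += sc[sb] * qdot;       // scale part  -> acc_rows
        acc_min += mn[sb] * bsums[sb];  // min part    -> acc_min_rows
    }
    return d8 * d2 * (float)acc - d8 * dmin * (float)acc_min;
}

int main(void) {
    uint8_t sc[16], mn[16];
    int8_t  q[256], y[256];
    int16_t bsums[16];
    for (int i = 0; i < 16; i++)  { sc[i] = 1; mn[i] = 0; bsums[i] = 16; }
    for (int i = 0; i < 256; i++) { q[i] = 1; y[i] = 1; }
    // All-ones input: 16 sub blocks * 16 values * scale 1 = 256
    printf("%f\n", q2k_q8k_superblock_dot(1.0f, 0.0f, sc, mn, q, 1.0f, y, bsums));
    return 0;
}
```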
@@ -2538,10 +4941,12 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
             }
         }
     }
+
     if (anc != nc) {
         xstart = anc/8;
         y = 0;
     }
+
 #endif //AVX512F
 
     // Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation
@@ -2554,10 +4959,10 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
             a_ptrs[i + 1] = a_ptrs[i] + nb;
         }
 
-        // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation
+        // Take group of eight block_q2_kx8 structures at each pass of the loop and perform dot product operation
         for (int64_t x = xstart; x < nc / 8; x++) {
 
-            const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
+            const block_q2_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
 
             // Master FP accumulators
             __m256 acc_rows[16];
@@ -2572,62 +4977,95 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 
             // For super block
             for (int64_t b = 0; b < nb; b++) {
-
-                // Scale values - Load the eight scale values of block_q4_kx8
+                // Delta values - Load the eight scale values of block_q2_kx8
                 const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
 
-                // dmin values - Load the eight dmin values of block_q4_kx8
+                // dmin values - Load the eight dmin values of block_q2_kx8
                 const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
 
-                // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
-                for (int sb = 0; sb < QK_K / 64; sb++) {
+                // Loop to iterate over the sixteen sub blocks of a super block - eight sub blocks are processed per iteration
+                for (int sb = 0; sb < QK_K / 128; sb++) {
 
-                    // Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7
-                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
-                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
-                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
-                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
-                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
-                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
-                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
-                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
+                    // Load the quantized values of the eight block_q2_K structures for eight sub blocks, interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + sb * 256));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 128 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 160 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 192 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 224 + sb * 256));
 
                     // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
+                    // superblock - sub block - which part of sub block
                     const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
                     const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+
                     const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
                     const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
                     const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
                     const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
+
                     const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
                     const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
 
-                    // 4-bit -> 8-bit
-                    // First sub block of the two sub blocks processed in the iteration
-                    const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m4b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)
-                    const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m4b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)
+                    // 2-bit -> 8-bit
+                    // First sub block of the eight sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m3b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)
+                    const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m3b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)
+
+                    const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m3b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)
+                    const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m3b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)
+
+                    // Second sub block of the eight sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_10 = _mm256_and_si256(rhs_raw_mat_0145_2, m3b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)
+                    const __m256i rhs_mat_2367_10 = _mm256_and_si256(rhs_raw_mat_2367_2, m3b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)
+
+                    const __m256i rhs_mat_0145_11 = _mm256_and_si256(rhs_raw_mat_0145_3, m3b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)
+                    const __m256i rhs_mat_2367_11 = _mm256_and_si256(rhs_raw_mat_2367_3, m3b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)
+
+                    // Third sub block of the eight sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 2), m3b); //B20(0-7) B21(0-7) B24(0-7) B25(0-7)
+                    const __m256i rhs_mat_2367_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 2), m3b); //B22(0-7) B23(0-7) B26(0-7) B27(0-7)
+
+                    const __m256i rhs_mat_0145_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 2), m3b); //B20(8-15) B21(8-15) B24(8-15) B25(8-15)
+                    const __m256i rhs_mat_2367_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 2), m3b); //B22(8-15) B23(8-15) B26(8-15) B27(8-15)
+
+                    // Fourth sub block of the eight sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 2), m3b); //B30(0-7) B31(0-7) B34(0-7) B35(0-7)
+                    const __m256i rhs_mat_2367_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 2), m3b); //B32(0-7) B33(0-7) B36(0-7) B37(0-7)
 
-                    const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m4b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)
-                    const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m4b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)
+                    const __m256i rhs_mat_0145_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 2), m3b); //B30(8-15) B31(8-15) B34(8-15) B35(8-15)
+                    const __m256i rhs_mat_2367_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 2), m3b); //B32(8-15) B33(8-15) B36(8-15) B37(8-15)
 
-                    const __m256i rhs_mat_0145_02 = _mm256_and_si256(rhs_raw_mat_0145_2, m4b); //B00(16-23) B01(16-23) B04(16-23) B05(16-23)
-                    const __m256i rhs_mat_2367_02 = _mm256_and_si256(rhs_raw_mat_2367_2, m4b); //B02(16-23) B03(16-23) B06(16-23) B07(16-23)
+                    // Fifth sub block of the eight sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m3b); //B40(0-7) B41(0-7) B44(0-7) B45(0-7)
+                    const __m256i rhs_mat_2367_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m3b); //B42(0-7) B43(0-7) B46(0-7) B47(0-7)
 
-                    const __m256i rhs_mat_0145_03 = _mm256_and_si256(rhs_raw_mat_0145_3, m4b); //B00(24-31) B01(24-31) B04(24-31) B05(24-31)
-                    const __m256i rhs_mat_2367_03 = _mm256_and_si256(rhs_raw_mat_2367_3, m4b); //B02(24-31) B03(24-31) B06(24-31) B07(24-31)
+                    const __m256i rhs_mat_0145_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m3b); //B40(8-15) B41(8-15) B44(8-15) B45(8-15)
+                    const __m256i rhs_mat_2367_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m3b); //B42(8-15) B43(8-15) B46(8-15) B47(8-15)
 
-                    // Second sub block of the two sub blocks processed in the iteration
-                    const __m256i rhs_mat_0145_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)
-                    const __m256i rhs_mat_2367_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)
+                    // Sixth sub block of the eight sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m3b); //B50(0-7) B51(0-7) B54(0-7) B55(0-7)
+                    const __m256i rhs_mat_2367_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m3b); //B52(0-7) B53(0-7) B56(0-7) B57(0-7)
 
-                    const __m256i rhs_mat_0145_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)
-                    const __m256i rhs_mat_2367_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)
+                    const __m256i rhs_mat_0145_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m3b); //B50(8-15) B51(8-15) B54(8-15) B55(8-15)
+                    const __m256i rhs_mat_2367_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m3b); //B52(8-15) B53(8-15) B56(8-15) B57(8-15)
 
-                    const __m256i rhs_mat_0145_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m4b); //B10(16-23) B11(16-23) B14(16-23) B15(16-23)
-                    const __m256i rhs_mat_2367_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m4b); //B12(16-23) B13(16-23) B16(16-23) B17(16-23)
+                    // Seventh sub block of the eight sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 6), m3b); //B60(0-7) B61(0-7) B64(0-7) B65(0-7)
+                    const __m256i rhs_mat_2367_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 6), m3b); //B62(0-7) B63(0-7) B66(0-7) B67(0-7)
 
-                    const __m256i rhs_mat_0145_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m4b); //B10(24-31) B11(24-31) B14(24-31) B15(24-31)
-                    const __m256i rhs_mat_2367_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m4b); //B12(24-31) B13(24-31) B16(24-31) B17(24-31)
+                    const __m256i rhs_mat_0145_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 6), m3b); //B60(8-15) B61(8-15) B64(8-15) B65(8-15)
+                    const __m256i rhs_mat_2367_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 6), m3b); //B62(8-15) B63(8-15) B66(8-15) B67(8-15)
+
+                    // Eighth sub block of the eight sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 6), m3b); //B70(0-7) B71(0-7) B74(0-7) B75(0-7)
+                    const __m256i rhs_mat_2367_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 6), m3b); //B72(0-7) B73(0-7) B76(0-7) B77(0-7)
+
+                    const __m256i rhs_mat_0145_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 6), m3b); //B70(8-15) B71(8-15) B74(8-15) B75(8-15)
+                    const __m256i rhs_mat_2367_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 6), m3b); //B72(8-15) B73(8-15) B76(8-15) B77(8-15)
 
                     // Shuffle pattern one - right side input
                     const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)
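The eight shift-and-mask groups above recover the four 2-bit planes packed into each byte of `qs`: plane p is a right shift by 2*p bits followed by an AND with `m3b` (0x03). `_mm256_srli_epi16` bleeds bits across byte boundaries within each 16-bit lane, but the mask discards them. A self-contained AVX2 sketch of the same unpacking with a scalar cross-check (compile with -mavx2; the buffer contents are arbitrary test data):

```c
#include <immintrin.h>
#include <stdio.h>
#include <stdint.h>

/* Standalone sketch of the 2-bit -> 8-bit unpacking used above: each byte
 * packs four 2-bit quants, and plane p is recovered with a 2*p bit shift
 * followed by an AND with 0x03. */
int main(void) {
    uint8_t packed[32];
    for (int i = 0; i < 32; i++) packed[i] = (uint8_t)(i * 37 + 11);

    const __m256i raw = _mm256_loadu_si256((const __m256i *)packed);
    const __m256i m3b = _mm256_set1_epi8(3);

    __m256i plane[4];
    plane[0] = _mm256_and_si256(raw, m3b);                        // bits 0-1
    plane[1] = _mm256_and_si256(_mm256_srli_epi16(raw, 2), m3b);  // bits 2-3
    plane[2] = _mm256_and_si256(_mm256_srli_epi16(raw, 4), m3b);  // bits 4-5
    plane[3] = _mm256_and_si256(_mm256_srli_epi16(raw, 6), m3b);  // bits 6-7

    int ok = 1;
    for (int p = 0; p < 4; p++) {
        uint8_t out[32];
        _mm256_storeu_si256((__m256i *)out, plane[p]);
        for (int i = 0; i < 32; i++) {
            ok &= (out[i] == ((packed[i] >> (2 * p)) & 3));
        }
    }
    printf("%s\n", ok ? "unpack ok" : "unpack MISMATCH");
    return 0;
}
```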
@@ -2636,23 +5074,47 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)
                     const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)
 
-                    const __m256i rhs_mat_0145_02_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_02, 136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19)
-                    const __m256i rhs_mat_2367_02_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_02, 136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19)
-
-                    const __m256i rhs_mat_0145_03_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_03, 136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27)
-                    const __m256i rhs_mat_2367_03_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_03, 136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27)
-
                     const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)
                     const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3)
 
                     const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)
                     const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)
 
-                    const __m256i rhs_mat_0145_12_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_12, 136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19)
-                    const __m256i rhs_mat_2367_12_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_12, 136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19)
+                    const __m256i rhs_mat_0145_20_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_20, 136); //B20(0-3) B21(0-3) B20(0-3) B21(0-3) B24(0-3) B25(0-3) B24(0-3) B25(0-3)
+                    const __m256i rhs_mat_2367_20_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_20, 136); //B22(0-3) B23(0-3) B22(0-3) B23(0-3) B26(0-3) B27(0-3) B26(0-3) B27(0-3)
 
-                    const __m256i rhs_mat_0145_13_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_13, 136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27)
-                    const __m256i rhs_mat_2367_13_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_13, 136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27)
+                    const __m256i rhs_mat_0145_21_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_21, 136); //B20(8-11) B21(8-11) B20(8-11) B21(8-11) B24(8-11) B25(8-11) B24(8-11) B25(8-11)
+                    const __m256i rhs_mat_2367_21_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_21, 136); //B22(8-11) B23(8-11) B22(8-11) B23(8-11) B26(8-11) B27(8-11) B26(8-11) B27(8-11)
+
+                    const __m256i rhs_mat_0145_30_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_30, 136); //B30(0-3) B31(0-3) B30(0-3) B31(0-3) B34(0-3) B35(0-3) B34(0-3) B35(0-3)
+                    const __m256i rhs_mat_2367_30_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_30, 136); //B32(0-3) B33(0-3) B32(0-3) B33(0-3) B36(0-3) B37(0-3) B36(0-3) B37(0-3)
+
+                    const __m256i rhs_mat_0145_31_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_31, 136); //B30(8-11) B31(8-11) B30(8-11) B31(8-11) B34(8-11) B35(8-11) B34(8-11) B35(8-11)
+                    const __m256i rhs_mat_2367_31_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_31, 136); //B32(8-11) B33(8-11) B32(8-11) B33(8-11) B36(8-11) B37(8-11) B36(8-11) B37(8-11)
+
+                    const __m256i rhs_mat_0145_40_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_40, 136); //B40(0-3) B41(0-3) B40(0-3) B41(0-3) B44(0-3) B45(0-3) B44(0-3) B45(0-3)
+                    const __m256i rhs_mat_2367_40_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_40, 136); //B42(0-3) B43(0-3) B42(0-3) B43(0-3) B46(0-3) B47(0-3) B46(0-3) B47(0-3)
+
+                    const __m256i rhs_mat_0145_41_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_41, 136); //B40(8-11) B41(8-11) B40(8-11) B41(8-11) B44(8-11) B45(8-11) B44(8-11) B45(8-11)
+                    const __m256i rhs_mat_2367_41_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_41, 136); //B42(8-11) B43(8-11) B42(8-11) B43(8-11) B46(8-11) B47(8-11) B46(8-11) B47(8-11)
+
+                    const __m256i rhs_mat_0145_50_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_50, 136); //B50(0-3) B51(0-3) B50(0-3) B51(0-3) B54(0-3) B55(0-3) B54(0-3) B55(0-3)
+                    const __m256i rhs_mat_2367_50_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_50, 136); //B52(0-3) B53(0-3) B52(0-3) B53(0-3) B56(0-3) B57(0-3) B56(0-3) B57(0-3)
+
+                    const __m256i rhs_mat_0145_51_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_51, 136); //B50(8-11) B51(8-11) B50(8-11) B51(8-11) B54(8-11) B55(8-11) B54(8-11) B55(8-11)
+                    const __m256i rhs_mat_2367_51_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_51, 136); //B52(8-11) B53(8-11) B52(8-11) B53(8-11) B56(8-11) B57(8-11) B56(8-11) B57(8-11)
+
+                    const __m256i rhs_mat_0145_60_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_60, 136); //B60(0-3) B61(0-3) B60(0-3) B61(0-3) B64(0-3) B65(0-3) B64(0-3) B65(0-3)
+                    const __m256i rhs_mat_2367_60_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_60, 136); //B62(0-3) B63(0-3) B62(0-3) B63(0-3) B66(0-3) B67(0-3) B66(0-3) B67(0-3)
+
+                    const __m256i rhs_mat_0145_61_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_61, 136); //B60(8-11) B61(8-11) B60(8-11) B61(8-11) B64(8-11) B65(8-11) B64(8-11) B65(8-11)
+                    const __m256i rhs_mat_2367_61_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_61, 136); //B62(8-11) B63(8-11) B62(8-11) B63(8-11) B66(8-11) B67(8-11) B66(8-11) B67(8-11)
+
+                    const __m256i rhs_mat_0145_70_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_70, 136); //B70(0-3) B71(0-3) B70(0-3) B71(0-3) B74(0-3) B75(0-3) B74(0-3) B75(0-3)
+                    const __m256i rhs_mat_2367_70_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_70, 136); //B72(0-3) B73(0-3) B72(0-3) B73(0-3) B76(0-3) B77(0-3) B76(0-3) B77(0-3)
+
+                    const __m256i rhs_mat_0145_71_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_71, 136); //B70(8-11) B71(8-11) B70(8-11) B71(8-11) B74(8-11) B75(8-11) B74(8-11) B75(8-11)
+                    const __m256i rhs_mat_2367_71_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_71, 136); //B72(8-11) B73(8-11) B72(8-11) B73(8-11) B76(8-11) B77(8-11) B76(8-11) B77(8-11)
 
 
                     // Shuffle pattern two - right side input
@@ -2662,53 +5124,80 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)
                     const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)
 
-                    const __m256i rhs_mat_0145_02_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_02, 221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23)
-                    const __m256i rhs_mat_2367_02_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_02, 221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23)
-
-                    const __m256i rhs_mat_0145_03_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_03, 221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31)
-                    const __m256i rhs_mat_2367_03_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_03, 221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31)
-
                     const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)
                     const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)
 
                     const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)
                     const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)
 
-                    const __m256i rhs_mat_0145_12_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_12, 221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23)
-                    const __m256i rhs_mat_2367_12_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_12, 221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23)
+                    const __m256i rhs_mat_0145_20_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_20, 221); //B20(4-7) B21(4-7) B20(4-7) B21(4-7) B24(4-7) B25(4-7) B24(4-7) B25(4-7)
+                    const __m256i rhs_mat_2367_20_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_20, 221); //B22(4-7) B23(4-7) B22(4-7) B23(4-7) B26(4-7) B27(4-7) B26(4-7) B27(4-7)
 
-                    const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)
-                    const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)
+                    const __m256i rhs_mat_0145_21_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_21, 221); //B20(12-15) B21(12-15) B20(12-15) B21(12-15) B24(12-15) B25(12-15) B24(12-15) B25(12-15)
+                    const __m256i rhs_mat_2367_21_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_21, 221); //B22(12-15) B23(12-15) B22(12-15) B23(12-15) B26(12-15) B27(12-15) B26(12-15) B27(12-15)
 
-                    uint32_t utmp_0[4], utmp_1[4];
+                    const __m256i rhs_mat_0145_30_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_30, 221); //B30(4-7) B31(4-7) B30(4-7) B31(4-7) B34(4-7) B35(4-7) B34(4-7) B35(4-7)
+                    const __m256i rhs_mat_2367_30_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_30, 221); //B32(4-7) B33(4-7) B32(4-7) B33(4-7) B36(4-7) B37(4-7) B36(4-7) B37(4-7)
 
-                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
-                    // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
-                    memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
-                    utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
-                    const uint32_t uaux_0 = utmp_0[1] & kmask1;
-                    utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
-                    utmp_0[2] = uaux_0;
-                    utmp_0[0] &= kmask1;
+                    const __m256i rhs_mat_0145_31_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_31, 221); //B30(12-15) B31(12-15) B30(12-15) B31(12-15) B34(12-15) B35(12-15) B34(12-15) B35(12-15)
+                    const __m256i rhs_mat_2367_31_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_31, 221); //B32(12-15) B33(12-15) B32(12-15) B33(12-15) B36(12-15) B37(12-15) B36(12-15) B37(12-15)
 
-                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
-                    memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
-                    utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
-                    const uint32_t uaux_1 = utmp_1[1] & kmask1;
-                    utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
-                    utmp_1[2] = uaux_1;
-                    utmp_1[0] &= kmask1;
+                    const __m256i rhs_mat_0145_40_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_40, 221); //B40(4-7) B41(4-7) B40(4-7) B41(4-7) B44(4-7) B45(4-7) B44(4-7) B45(4-7)
+                    const __m256i rhs_mat_2367_40_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_40, 221); //B42(4-7) B43(4-7) B42(4-7) B43(4-7) B46(4-7) B47(4-7) B46(4-7) B47(4-7)
 
-                    // Scales of first sub block in the sb loop
-                    const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
-                    const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
+                    const __m256i rhs_mat_0145_41_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_41, 221); //B40(12-15) B41(12-15) B40(12-15) B41(12-15) B44(12-15) B45(12-15) B44(12-15) B45(12-15)
+                    const __m256i rhs_mat_2367_41_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_41, 221); //B42(12-15) B43(12-15) B42(12-15) B43(12-15) B46(12-15) B47(12-15) B46(12-15) B47(12-15)
 
-                    // Scales of second sub block in the sb loop
-                    const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
-                    const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
+                    const __m256i rhs_mat_0145_50_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_50, 221); //B50(4-7) B51(4-7) B50(4-7) B51(4-7) B54(4-7) B55(4-7) B54(4-7) B55(4-7)
+                    const __m256i rhs_mat_2367_50_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_50, 221); //B52(4-7) B53(4-7) B52(4-7) B53(4-7) B56(4-7) B57(4-7) B56(4-7) B57(4-7)
 
-                    // Mins of first and second sub block of Q4_K block are arranged side by side
-                    const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
+                    const __m256i rhs_mat_0145_51_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_51, 221); //B50(12-15) B51(12-15) B50(12-15) B51(12-15) B54(12-15) B55(12-15) B54(12-15) B55(12-15)
+                    const __m256i rhs_mat_2367_51_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_51, 221); //B52(12-15) B53(12-15) B52(12-15) B53(12-15) B56(12-15) B57(12-15) B56(12-15) B57(12-15)
+
+                    const __m256i rhs_mat_0145_60_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_60, 221); //B60(4-7) B61(4-7) B60(4-7) B61(4-7) B64(4-7) B65(4-7) B64(4-7) B65(4-7)
+                    const __m256i rhs_mat_2367_60_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_60, 221); //B62(4-7) B63(4-7) B62(4-7) B63(4-7) B66(4-7) B67(4-7) B66(4-7) B67(4-7)
+
+                    const __m256i rhs_mat_0145_61_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_61, 221); //B60(12-15) B61(12-15) B60(12-15) B61(12-15) B64(12-15) B65(12-15) B64(12-15) B65(12-15)
+                    const __m256i rhs_mat_2367_61_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_61, 221); //B62(12-15) B63(12-15) B62(12-15) B63(12-15) B66(12-15) B67(12-15) B66(12-15) B67(12-15)
+
+                    const __m256i rhs_mat_0145_70_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_70, 221); //B70(4-7) B71(4-7) B70(4-7) B71(4-7) B74(4-7) B75(4-7) B74(4-7) B75(4-7)
+                    const __m256i rhs_mat_2367_70_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_70, 221); //B72(4-7) B73(4-7) B72(4-7) B73(4-7) B76(4-7) B77(4-7) B76(4-7) B77(4-7)
+
+                    const __m256i rhs_mat_0145_71_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_71, 221); //B70(12-15) B71(12-15) B70(12-15) B71(12-15) B74(12-15) B75(12-15) B74(12-15) B75(12-15)
+                    const __m256i rhs_mat_2367_71_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_71, 221); //B72(12-15) B73(12-15) B72(12-15) B73(12-15) B76(12-15) B77(12-15) B76(12-15) B77(12-15)
+
+                    // Scales and mins of corresponding sub blocks from different Q2_K structures are stored together
+                    // s00 m00  s01 m01   s10 m10  s11 m11  s20 m20  s21 m21   s30 m30  s31 m31  s40 m40  s41 m41   s50 m50  s51 m51  s60 m60  s61 m61   s70 m70  s71 m71
+
+                    // Combine mins and scales for sub-blocks: 0-1, 2-3, 4-5, 6-7 in the sb loop
+                    const __m128i mins_and_scales_01 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + sb * 64));
+                    const __m128i mins_and_scales_23 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 16 + sb * 64));
+                    const __m128i mins_and_scales_45 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 32 + sb * 64));
+                    const __m128i mins_and_scales_67 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 48 + sb * 64));
+
+                    // Extract the scales (lower 4 bits of each byte) from mins_and_scales
+                    const __m128i scales_01 = _mm_and_si128(mins_and_scales_01, m4b_sse);
+                    const __m128i scales_23 = _mm_and_si128(mins_and_scales_23, m4b_sse);
+                    const __m128i scales_45 = _mm_and_si128(mins_and_scales_45, m4b_sse);
+                    const __m128i scales_67 = _mm_and_si128(mins_and_scales_67, m4b_sse);
+
+                    // Extract the mins (upper 4 bits of each byte) from mins_and_scales
+                    const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_01, 4), m4b_sse));
+                    const __m256i mins_23 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_23, 4), m4b_sse));
+                    const __m256i mins_45 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_45, 4), m4b_sse));
+                    const __m256i mins_67 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_67, 4), m4b_sse));
+
+                    const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_01, scalesmask1_sse));
+                    const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_01, scalesmask2_sse));
+
+                    const __m256i scales_2 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_23, scalesmask1_sse));
+                    const __m256i scales_3 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_23, scalesmask2_sse));
+
+                    const __m256i scales_4 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_45, scalesmask1_sse));
+                    const __m256i scales_5 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_45, scalesmask2_sse));
+
+                    const __m256i scales_6 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_67, scalesmask1_sse));
+                    const __m256i scales_7 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_67, scalesmask2_sse));
 
                     const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);
                     const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);
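The loads above split each q2_K scales byte into a 4-bit scale (low nibble) and a 4-bit min (high nibble), then widen the mins to 16 bits for the later bsums madd. A minimal SSE sketch of just the nibble split with a scalar cross-check; the `scalesmask` interleaving shuffles are left out, and the buffer contents are arbitrary test data:

```c
#include <immintrin.h>
#include <stdio.h>
#include <stdint.h>

/* Standalone sketch (SSE2) of the scales/mins nibble split used above: each
 * byte stores a 4-bit scale in its low nibble and a 4-bit min in its high
 * nibble. */
int main(void) {
    uint8_t scales[16];
    for (int i = 0; i < 16; i++) scales[i] = (uint8_t)(i * 19 + 5);

    const __m128i v       = _mm_loadu_si128((const __m128i *)scales);
    const __m128i m4b_sse = _mm_set1_epi8(0x0F);

    const __m128i sc = _mm_and_si128(v, m4b_sse);                     // low nibble
    const __m128i mn = _mm_and_si128(_mm_srli_epi16(v, 4), m4b_sse);  // high nibble

    uint8_t sc_out[16], mn_out[16];
    _mm_storeu_si128((__m128i *)sc_out, sc);
    _mm_storeu_si128((__m128i *)mn_out, mn);

    int ok = 1;
    for (int i = 0; i < 16; i++) {
        ok &= (sc_out[i] == (scales[i] & 0x0F)) && (mn_out[i] == (scales[i] >> 4));
    }
    printf("%s\n", ok ? "nibble split ok" : "MISMATCH");
    return 0;
}
```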
@@ -2716,64 +5205,133 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);
                     const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);
 
+                    const __m256i scale_0145_2 = _mm256_shuffle_epi32(scales_2, 68);
+                    const __m256i scale_2367_2 = _mm256_shuffle_epi32(scales_2, 238);
+
+                    const __m256i scale_0145_3 = _mm256_shuffle_epi32(scales_3, 68);
+                    const __m256i scale_2367_3 = _mm256_shuffle_epi32(scales_3, 238);
+
+                    const __m256i scale_0145_4 = _mm256_shuffle_epi32(scales_4, 68);
+                    const __m256i scale_2367_4 = _mm256_shuffle_epi32(scales_4, 238);
+
+                    const __m256i scale_0145_5 = _mm256_shuffle_epi32(scales_5, 68);
+                    const __m256i scale_2367_5 = _mm256_shuffle_epi32(scales_5, 238);
+
+                    const __m256i scale_0145_6 = _mm256_shuffle_epi32(scales_6, 68);
+                    const __m256i scale_2367_6 = _mm256_shuffle_epi32(scales_6, 238);
+
+                    const __m256i scale_0145_7 = _mm256_shuffle_epi32(scales_7, 68);
+                    const __m256i scale_2367_7 = _mm256_shuffle_epi32(scales_7, 238);
+
+
                     for (int rp = 0; rp < 4; rp++) {
 
                         // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
                         // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
-                        __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb)));
+                        __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 512 * sb)));
                         __m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);
                         __m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);
-                        __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb)));
+                        __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 512 * sb)));
                         __m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);
                         __m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);
-                        __m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb)));
-                        __m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);
-                        __m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);
-                        __m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb)));
-                        __m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);
-                        __m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);
-                        __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb)));
+                        __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 512 * sb)));
                         __m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);
                         __m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);
-                        __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb)));
+                        __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 512 * sb)));
                         __m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);
                         __m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);
-                        __m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb)));
-                        __m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);
-                        __m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);
-                        __m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb)));
-                        __m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);
-                        __m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);
-
-                        // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
-                        __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));
-                        __m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
-                        lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
+                        __m256i lhs_mat_0123_20 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 512 * sb)));
+                        __m256i lhs_mat_01_20 = _mm256_permute2f128_si256(lhs_mat_0123_20, lhs_mat_0123_20, 0);
+                        __m256i lhs_mat_23_20 = _mm256_permute2f128_si256(lhs_mat_0123_20, lhs_mat_0123_20, 17);
+                        __m256i lhs_mat_0123_21 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 512 * sb)));
+                        __m256i lhs_mat_01_21 = _mm256_permute2f128_si256(lhs_mat_0123_21, lhs_mat_0123_21, 0);
+                        __m256i lhs_mat_23_21 = _mm256_permute2f128_si256(lhs_mat_0123_21, lhs_mat_0123_21, 17);
+                        __m256i lhs_mat_0123_30 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 512 * sb)));
+                        __m256i lhs_mat_01_30 = _mm256_permute2f128_si256(lhs_mat_0123_30, lhs_mat_0123_30, 0);
+                        __m256i lhs_mat_23_30 = _mm256_permute2f128_si256(lhs_mat_0123_30, lhs_mat_0123_30, 17);
+                        __m256i lhs_mat_0123_31 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 512 * sb)));
+                        __m256i lhs_mat_01_31 = _mm256_permute2f128_si256(lhs_mat_0123_31, lhs_mat_0123_31, 0);
+                        __m256i lhs_mat_23_31 = _mm256_permute2f128_si256(lhs_mat_0123_31, lhs_mat_0123_31, 17);
+
+                        __m256i lhs_mat_0123_40 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 + 512 * sb)));
+                        __m256i lhs_mat_01_40 = _mm256_permute2f128_si256(lhs_mat_0123_40, lhs_mat_0123_40, 0);
+                        __m256i lhs_mat_23_40 = _mm256_permute2f128_si256(lhs_mat_0123_40, lhs_mat_0123_40, 17);
+                        __m256i lhs_mat_0123_41 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 288 + 512 * sb)));
+                        __m256i lhs_mat_01_41 = _mm256_permute2f128_si256(lhs_mat_0123_41, lhs_mat_0123_41, 0);
+                        __m256i lhs_mat_23_41 = _mm256_permute2f128_si256(lhs_mat_0123_41, lhs_mat_0123_41, 17);
+                        __m256i lhs_mat_0123_50 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 320 + 512 * sb)));
+                        __m256i lhs_mat_01_50 = _mm256_permute2f128_si256(lhs_mat_0123_50, lhs_mat_0123_50, 0);
+                        __m256i lhs_mat_23_50 = _mm256_permute2f128_si256(lhs_mat_0123_50, lhs_mat_0123_50, 17);
+                        __m256i lhs_mat_0123_51 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 352 + 512 * sb)));
+                        __m256i lhs_mat_01_51 = _mm256_permute2f128_si256(lhs_mat_0123_51, lhs_mat_0123_51, 0);
+                        __m256i lhs_mat_23_51 = _mm256_permute2f128_si256(lhs_mat_0123_51, lhs_mat_0123_51, 17);
+                        __m256i lhs_mat_0123_60 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 384 + 512 * sb)));
+                        __m256i lhs_mat_01_60 = _mm256_permute2f128_si256(lhs_mat_0123_60, lhs_mat_0123_60, 0);
+                        __m256i lhs_mat_23_60 = _mm256_permute2f128_si256(lhs_mat_0123_60, lhs_mat_0123_60, 17);
+                        __m256i lhs_mat_0123_61 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 416 + 512 * sb)));
+                        __m256i lhs_mat_01_61 = _mm256_permute2f128_si256(lhs_mat_0123_61, lhs_mat_0123_61, 0);
+                        __m256i lhs_mat_23_61 = _mm256_permute2f128_si256(lhs_mat_0123_61, lhs_mat_0123_61, 17);
+                        __m256i lhs_mat_0123_70 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 448 + 512 * sb)));
+                        __m256i lhs_mat_01_70 = _mm256_permute2f128_si256(lhs_mat_0123_70, lhs_mat_0123_70, 0);
+                        __m256i lhs_mat_23_70 = _mm256_permute2f128_si256(lhs_mat_0123_70, lhs_mat_0123_70, 17);
+                        __m256i lhs_mat_0123_71 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 480 + 512 * sb)));
+                        __m256i lhs_mat_01_71 = _mm256_permute2f128_si256(lhs_mat_0123_71, lhs_mat_0123_71, 0);
+                        __m256i lhs_mat_23_71 = _mm256_permute2f128_si256(lhs_mat_0123_71, lhs_mat_0123_71, 17);
+
+                        // Bsums are loaded for the different Q8_K blocks
+                        __m128i lhs_raw_bsums_01_0123 = _mm_loadu_si128((const __m128i *)((a_ptrs[rp][b].bsums + 32 * sb)));
+                        __m128i lhs_raw_bsums_23_0123 = _mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].bsums + 8 + 32 * sb));
+                        __m128i lhs_raw_bsums_01_4567 = _mm_loadu_si128((const __m128i *)((a_ptrs[rp][b].bsums + 16 + 32 * sb)));
+                        __m128i lhs_raw_bsums_23_4567 = _mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].bsums + 24 + 32 * sb));
 
                         // Shuffle pattern one - left side input
                         const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
                         const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3)
 
-                        const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160);  //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
-                        const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160);  //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)
-
-                        const __m256i lhs_mat_01_02_sp1 = _mm256_shuffle_epi32(lhs_mat_01_02, 160);  //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
-                        const __m256i lhs_mat_23_02_sp1 = _mm256_shuffle_epi32(lhs_mat_23_02, 160);  //A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19)
-
-                        const __m256i lhs_mat_01_03_sp1 = _mm256_shuffle_epi32(lhs_mat_01_03, 160);  //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
-                        const __m256i lhs_mat_23_03_sp1 = _mm256_shuffle_epi32(lhs_mat_23_03, 160); //A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27)
+                        const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
+                        const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160); //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)
 
                         const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
                         const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3)
 
-                        const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160);  //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
-                        const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160);  //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)
+                        const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
+                        const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160); //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)
 
-                        const __m256i lhs_mat_01_12_sp1 = _mm256_shuffle_epi32(lhs_mat_01_12, 160);  //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
-                        const __m256i lhs_mat_23_12_sp1 = _mm256_shuffle_epi32(lhs_mat_23_12, 160);  //A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19)
+                        const __m256i lhs_mat_01_20_sp1 = _mm256_shuffle_epi32(lhs_mat_01_20, 160); //A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3)
+                        const __m256i lhs_mat_23_20_sp1 = _mm256_shuffle_epi32(lhs_mat_23_20, 160); //A22(0-3) A23(0-3) A22(0-3) A23(0-3) A22(0-3) A23(0-3) A22(0-3) A23(0-3)
 
-                        const __m256i lhs_mat_01_13_sp1 = _mm256_shuffle_epi32(lhs_mat_01_13, 160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
-                        const __m256i lhs_mat_23_13_sp1 = _mm256_shuffle_epi32(lhs_mat_23_13, 160); //A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27)
+                        const __m256i lhs_mat_01_21_sp1 = _mm256_shuffle_epi32(lhs_mat_01_21, 160); //A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11)
+                        const __m256i lhs_mat_23_21_sp1 = _mm256_shuffle_epi32(lhs_mat_23_21, 160); //A22(8-11) A23(8-11) A22(8-11) A23(8-11) A22(8-11) A23(8-11) A22(8-11) A23(8-11)
+
+                        const __m256i lhs_mat_01_30_sp1 = _mm256_shuffle_epi32(lhs_mat_01_30, 160); //A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3)
+                        const __m256i lhs_mat_23_30_sp1 = _mm256_shuffle_epi32(lhs_mat_23_30, 160); //A32(0-3) A33(0-3) A32(0-3) A33(0-3) A32(0-3) A33(0-3) A32(0-3) A33(0-3)
+
+                        const __m256i lhs_mat_01_31_sp1 = _mm256_shuffle_epi32(lhs_mat_01_31, 160); //A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11)
+                        const __m256i lhs_mat_23_31_sp1 = _mm256_shuffle_epi32(lhs_mat_23_31, 160); //A32(8-11) A33(8-11) A32(8-11) A33(8-11) A32(8-11) A33(8-11) A32(8-11) A33(8-11)
+
+                        const __m256i lhs_mat_01_40_sp1 = _mm256_shuffle_epi32(lhs_mat_01_40, 160); //A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3)
+                        const __m256i lhs_mat_23_40_sp1 = _mm256_shuffle_epi32(lhs_mat_23_40, 160); //A42(0-3) A43(0-3) A42(0-3) A43(0-3) A42(0-3) A43(0-3) A42(0-3) A43(0-3)
+
+                        const __m256i lhs_mat_01_41_sp1 = _mm256_shuffle_epi32(lhs_mat_01_41, 160); //A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11)
+                        const __m256i lhs_mat_23_41_sp1 = _mm256_shuffle_epi32(lhs_mat_23_41, 160); //A42(8-11) A43(8-11) A42(8-11) A43(8-11) A42(8-11) A43(8-11) A42(8-11) A43(8-11)
+
+                        const __m256i lhs_mat_01_50_sp1 = _mm256_shuffle_epi32(lhs_mat_01_50, 160); //A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3)
+                        const __m256i lhs_mat_23_50_sp1 = _mm256_shuffle_epi32(lhs_mat_23_50, 160); //A52(0-3) A53(0-3) A52(0-3) A53(0-3) A52(0-3) A53(0-3) A52(0-3) A53(0-3)
+
+                        const __m256i lhs_mat_01_51_sp1 = _mm256_shuffle_epi32(lhs_mat_01_51, 160); //A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11)
+                        const __m256i lhs_mat_23_51_sp1 = _mm256_shuffle_epi32(lhs_mat_23_51, 160); //A52(8-11) A53(8-11) A52(8-11) A53(8-11) A52(8-11) A53(8-11) A52(8-11) A53(8-11)
+
+                        const __m256i lhs_mat_01_60_sp1 = _mm256_shuffle_epi32(lhs_mat_01_60, 160); //A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3)
+                        const __m256i lhs_mat_23_60_sp1 = _mm256_shuffle_epi32(lhs_mat_23_60, 160); //A62(0-3) A63(0-3) A62(0-3) A63(0-3) A62(0-3) A63(0-3) A62(0-3) A63(0-3)
+
+                        const __m256i lhs_mat_01_61_sp1 = _mm256_shuffle_epi32(lhs_mat_01_61, 160); //A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11)
+                        const __m256i lhs_mat_23_61_sp1 = _mm256_shuffle_epi32(lhs_mat_23_61, 160); //A62(8-11) A63(8-11) A62(8-11) A63(8-11) A62(8-11) A63(8-11) A62(8-11) A63(8-11)
+
+                        const __m256i lhs_mat_01_70_sp1 = _mm256_shuffle_epi32(lhs_mat_01_70, 160); //A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3)
+                        const __m256i lhs_mat_23_70_sp1 = _mm256_shuffle_epi32(lhs_mat_23_70, 160); //A72(0-3) A73(0-3) A72(0-3) A73(0-3) A72(0-3) A73(0-3) A72(0-3) A73(0-3)
+
+                        const __m256i lhs_mat_01_71_sp1 = _mm256_shuffle_epi32(lhs_mat_01_71, 160); //A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11)
+                        const __m256i lhs_mat_23_71_sp1 = _mm256_shuffle_epi32(lhs_mat_23_71, 160); //A72(8-11) A73(8-11) A72(8-11) A73(8-11) A72(8-11) A73(8-11) A72(8-11) A73(8-11)
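+                        // Note: _mm256_shuffle_epi32 with immediate 160 (0b10'10'00'00) repeats dwords {0,2} of each
+                        // 128-bit lane, e.g. a lane holding dwords [a b c d] becomes [a a c c]; immediate 245
+                        // (0b11'11'01'01), used for the *_sp2 vectors below, repeats dwords {1,3} and gives [b b d d].
+                        // Between them the two patterns cover all 16 bytes of each sub block row.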
 
                         // Shuffle pattern two- left side input
                         const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
@@ -2782,44 +5340,147 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                         const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
                         const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15)
 
-                        const __m256i lhs_mat_01_02_sp2 = _mm256_shuffle_epi32(lhs_mat_01_02, 245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
-                        const __m256i lhs_mat_23_02_sp2 = _mm256_shuffle_epi32(lhs_mat_23_02, 245); //A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23)
-
-                        const __m256i lhs_mat_01_03_sp2 = _mm256_shuffle_epi32(lhs_mat_01_03, 245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
-                        const __m256i lhs_mat_23_03_sp2 = _mm256_shuffle_epi32(lhs_mat_23_03, 245); //A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31)
-
                         const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
                         const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7)
 
                         const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
                         const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15)
 
-                        const __m256i lhs_mat_01_12_sp2 = _mm256_shuffle_epi32(lhs_mat_01_12, 245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
-                        const __m256i lhs_mat_23_12_sp2 = _mm256_shuffle_epi32(lhs_mat_23_12, 245); //A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23)
+                        const __m256i lhs_mat_01_20_sp2 = _mm256_shuffle_epi32(lhs_mat_01_20, 245); //A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7)
+                        const __m256i lhs_mat_23_20_sp2 = _mm256_shuffle_epi32(lhs_mat_23_20, 245); //A22(4-7) A23(4-7) A22(4-7) A23(4-7) A22(4-7) A23(4-7) A22(4-7) A23(4-7)
 
-                        const __m256i lhs_mat_01_13_sp2 = _mm256_shuffle_epi32(lhs_mat_01_13, 245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
-                        const __m256i lhs_mat_23_13_sp2 = _mm256_shuffle_epi32(lhs_mat_23_13, 245); //A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31)
+                        const __m256i lhs_mat_01_21_sp2 = _mm256_shuffle_epi32(lhs_mat_01_21, 245); //A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15)
+                        const __m256i lhs_mat_23_21_sp2 = _mm256_shuffle_epi32(lhs_mat_23_21, 245); //A22(12-15) A23(12-15) A22(12-15) A23(12-15) A22(12-15) A23(12-15) A22(12-15) A23(12-15)
+
+                        const __m256i lhs_mat_01_30_sp2 = _mm256_shuffle_epi32(lhs_mat_01_30, 245); //A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7)
+                        const __m256i lhs_mat_23_30_sp2 = _mm256_shuffle_epi32(lhs_mat_23_30, 245); //A32(4-7) A33(4-7) A32(4-7) A33(4-7) A32(4-7) A33(4-7) A32(4-7) A33(4-7)
+
+                        const __m256i lhs_mat_01_31_sp2 = _mm256_shuffle_epi32(lhs_mat_01_31, 245); //A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15)
+                        const __m256i lhs_mat_23_31_sp2 = _mm256_shuffle_epi32(lhs_mat_23_31, 245); //A32(12-15) A33(12-15) A32(12-15) A33(12-15) A32(12-15) A33(12-15) A32(12-15) A33(12-15)
+
+                        const __m256i lhs_mat_01_40_sp2 = _mm256_shuffle_epi32(lhs_mat_01_40, 245); //A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7)
+                        const __m256i lhs_mat_23_40_sp2 = _mm256_shuffle_epi32(lhs_mat_23_40, 245); //A42(4-7) A43(4-7) A42(4-7) A43(4-7) A42(4-7) A43(4-7) A42(4-7) A43(4-7)
+
+                        const __m256i lhs_mat_01_41_sp2 = _mm256_shuffle_epi32(lhs_mat_01_41, 245); //A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15)
+                        const __m256i lhs_mat_23_41_sp2 = _mm256_shuffle_epi32(lhs_mat_23_41, 245); //A42(12-15) A43(12-15) A42(12-15) A43(12-15) A42(12-15) A43(12-15) A42(12-15) A43(12-15)
+
+                        const __m256i lhs_mat_01_50_sp2 = _mm256_shuffle_epi32(lhs_mat_01_50, 245); //A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7)
+                        const __m256i lhs_mat_23_50_sp2 = _mm256_shuffle_epi32(lhs_mat_23_50, 245); //A52(4-7) A53(4-7) A52(4-7) A53(4-7) A52(4-7) A53(4-7) A52(4-7) A53(4-7)
+
+                        const __m256i lhs_mat_01_51_sp2 = _mm256_shuffle_epi32(lhs_mat_01_51, 245); //A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15)
+                        const __m256i lhs_mat_23_51_sp2 = _mm256_shuffle_epi32(lhs_mat_23_51, 245); //A52(12-15) A53(12-15) A52(12-15) A53(12-15) A52(12-15) A53(12-15) A52(12-15) A53(12-15)
+
+                        const __m256i lhs_mat_01_60_sp2 = _mm256_shuffle_epi32(lhs_mat_01_60, 245); //A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7)
+                        const __m256i lhs_mat_23_60_sp2 = _mm256_shuffle_epi32(lhs_mat_23_60, 245); //A62(4-7) A63(4-7) A62(4-7) A63(4-7) A62(4-7) A63(4-7) A62(4-7) A63(4-7)
+
+                        const __m256i lhs_mat_01_61_sp2 = _mm256_shuffle_epi32(lhs_mat_01_61, 245); //A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15)
+                        const __m256i lhs_mat_23_61_sp2 = _mm256_shuffle_epi32(lhs_mat_23_61, 245); //A62(12-15) A63(12-15) A62(12-15) A63(12-15) A62(12-15) A63(12-15) A62(12-15) A63(12-15)
+
+                        const __m256i lhs_mat_01_70_sp2 = _mm256_shuffle_epi32(lhs_mat_01_70, 245); //A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7)
+                        const __m256i lhs_mat_23_70_sp2 = _mm256_shuffle_epi32(lhs_mat_23_70, 245); //A72(4-7) A73(4-7) A72(4-7) A73(4-7) A72(4-7) A73(4-7) A72(4-7) A73(4-7)
+
+                        const __m256i lhs_mat_01_71_sp2 = _mm256_shuffle_epi32(lhs_mat_01_71, 245); //A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15)
+                        const __m256i lhs_mat_23_71_sp2 = _mm256_shuffle_epi32(lhs_mat_23_71, 245); //A72(12-15) A73(12-15) A72(12-15) A73(12-15) A72(12-15) A73(12-15) A72(12-15) A73(12-15)
 
                        // The values arranged in the shuffle patterns are combined with dot product operations within each 32-bit lane, i.e. corresponding bytes are multiplied and their products accumulated into integers within the 32-bit lane
-                        __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1));
-                        __m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1));
-                        __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1));
-                        __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1));
-                        __m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1));
-                        __m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1));
-                        __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1));
-                        __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1));
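+                        // For Q2_K each sub block spans only 16 quantized values, so each accumulator needs just two
+                        // maddubs terms: the sp1 pair covers bytes 0-3 and 8-11, and the matching sp2 pair further down
+                        // covers bytes 4-7 and 12-15. _mm256_maddubs_epi16 multiplies the unsigned low-bit rhs bytes
+                        // with the signed q8 lhs bytes and adds adjacent products into 16-bit lanes.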
+                        __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1),_mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1));
+                        __m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1),_mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1));
 
-                        __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2));
-                        __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2));
-                        __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2));
-                        __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2));
-                        __m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2));
-                        __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2));
-                        __m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2));
-                        __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2));
+                        __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1),_mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1));
+                        __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1),_mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1));
 
-                        // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
+                        __m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1),_mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1));
+                        __m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1),_mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1));
+
+                        __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1),_mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1));
+                        __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1),_mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1));
+
+                        __m256i iacc_mat_00_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp1, lhs_mat_01_20_sp1),_mm256_maddubs_epi16(rhs_mat_0145_21_sp1, lhs_mat_01_21_sp1));
+                        __m256i iacc_mat_01_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp1, lhs_mat_01_20_sp1),_mm256_maddubs_epi16(rhs_mat_2367_21_sp1, lhs_mat_01_21_sp1));
+
+                        __m256i iacc_mat_10_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp1, lhs_mat_23_20_sp1),_mm256_maddubs_epi16(rhs_mat_0145_21_sp1, lhs_mat_23_21_sp1));
+                        __m256i iacc_mat_11_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp1, lhs_mat_23_20_sp1),_mm256_maddubs_epi16(rhs_mat_2367_21_sp1, lhs_mat_23_21_sp1));
+
+                        __m256i iacc_mat_00_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp1, lhs_mat_01_30_sp1),_mm256_maddubs_epi16(rhs_mat_0145_31_sp1, lhs_mat_01_31_sp1));
+                        __m256i iacc_mat_01_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp1, lhs_mat_01_30_sp1),_mm256_maddubs_epi16(rhs_mat_2367_31_sp1, lhs_mat_01_31_sp1));
+
+                        __m256i iacc_mat_10_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp1, lhs_mat_23_30_sp1),_mm256_maddubs_epi16(rhs_mat_0145_31_sp1, lhs_mat_23_31_sp1));
+                        __m256i iacc_mat_11_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp1, lhs_mat_23_30_sp1),_mm256_maddubs_epi16(rhs_mat_2367_31_sp1, lhs_mat_23_31_sp1));
+
+                        __m256i iacc_mat_00_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp1, lhs_mat_01_40_sp1),_mm256_maddubs_epi16(rhs_mat_0145_41_sp1, lhs_mat_01_41_sp1));
+                        __m256i iacc_mat_01_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp1, lhs_mat_01_40_sp1),_mm256_maddubs_epi16(rhs_mat_2367_41_sp1, lhs_mat_01_41_sp1));
+
+                        __m256i iacc_mat_10_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp1, lhs_mat_23_40_sp1),_mm256_maddubs_epi16(rhs_mat_0145_41_sp1, lhs_mat_23_41_sp1));
+                        __m256i iacc_mat_11_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp1, lhs_mat_23_40_sp1),_mm256_maddubs_epi16(rhs_mat_2367_41_sp1, lhs_mat_23_41_sp1));
+
+                        __m256i iacc_mat_00_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp1, lhs_mat_01_50_sp1),_mm256_maddubs_epi16(rhs_mat_0145_51_sp1, lhs_mat_01_51_sp1));
+                        __m256i iacc_mat_01_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp1, lhs_mat_01_50_sp1),_mm256_maddubs_epi16(rhs_mat_2367_51_sp1, lhs_mat_01_51_sp1));
+
+                        __m256i iacc_mat_10_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp1, lhs_mat_23_50_sp1),_mm256_maddubs_epi16(rhs_mat_0145_51_sp1, lhs_mat_23_51_sp1));
+                        __m256i iacc_mat_11_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp1, lhs_mat_23_50_sp1),_mm256_maddubs_epi16(rhs_mat_2367_51_sp1, lhs_mat_23_51_sp1));
+
+                        __m256i iacc_mat_00_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp1, lhs_mat_01_60_sp1),_mm256_maddubs_epi16(rhs_mat_0145_61_sp1, lhs_mat_01_61_sp1));
+                        __m256i iacc_mat_01_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp1, lhs_mat_01_60_sp1),_mm256_maddubs_epi16(rhs_mat_2367_61_sp1, lhs_mat_01_61_sp1));
+
+                        __m256i iacc_mat_10_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp1, lhs_mat_23_60_sp1),_mm256_maddubs_epi16(rhs_mat_0145_61_sp1, lhs_mat_23_61_sp1));
+                        __m256i iacc_mat_11_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp1, lhs_mat_23_60_sp1),_mm256_maddubs_epi16(rhs_mat_2367_61_sp1, lhs_mat_23_61_sp1));
+
+                        __m256i iacc_mat_00_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp1, lhs_mat_01_70_sp1),_mm256_maddubs_epi16(rhs_mat_0145_71_sp1, lhs_mat_01_71_sp1));
+                        __m256i iacc_mat_01_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp1, lhs_mat_01_70_sp1),_mm256_maddubs_epi16(rhs_mat_2367_71_sp1, lhs_mat_01_71_sp1));
+
+                        __m256i iacc_mat_10_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp1, lhs_mat_23_70_sp1),_mm256_maddubs_epi16(rhs_mat_0145_71_sp1, lhs_mat_23_71_sp1));
+                        __m256i iacc_mat_11_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp1, lhs_mat_23_70_sp1),_mm256_maddubs_epi16(rhs_mat_2367_71_sp1, lhs_mat_23_71_sp1));
+
+
+                        __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2),_mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2));
+                        __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2),_mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2));
+
+                        __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2),_mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2));
+                        __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2),_mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2));
+
+                        __m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2),_mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2));
+                        __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2),_mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2));
+
+                        __m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2),_mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2));
+                        __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2),_mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2));
+
+                        __m256i iacc_mat_00_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp2, lhs_mat_01_20_sp2),_mm256_maddubs_epi16(rhs_mat_0145_21_sp2, lhs_mat_01_21_sp2));
+                        __m256i iacc_mat_01_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp2, lhs_mat_01_20_sp2),_mm256_maddubs_epi16(rhs_mat_2367_21_sp2, lhs_mat_01_21_sp2));
+
+                        __m256i iacc_mat_10_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp2, lhs_mat_23_20_sp2),_mm256_maddubs_epi16(rhs_mat_0145_21_sp2, lhs_mat_23_21_sp2));
+                        __m256i iacc_mat_11_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp2, lhs_mat_23_20_sp2),_mm256_maddubs_epi16(rhs_mat_2367_21_sp2, lhs_mat_23_21_sp2));
+
+                        __m256i iacc_mat_00_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp2, lhs_mat_01_30_sp2),_mm256_maddubs_epi16(rhs_mat_0145_31_sp2, lhs_mat_01_31_sp2));
+                        __m256i iacc_mat_01_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp2, lhs_mat_01_30_sp2),_mm256_maddubs_epi16(rhs_mat_2367_31_sp2, lhs_mat_01_31_sp2));
+
+                        __m256i iacc_mat_10_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp2, lhs_mat_23_30_sp2),_mm256_maddubs_epi16(rhs_mat_0145_31_sp2, lhs_mat_23_31_sp2));
+                        __m256i iacc_mat_11_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp2, lhs_mat_23_30_sp2),_mm256_maddubs_epi16(rhs_mat_2367_31_sp2, lhs_mat_23_31_sp2));
+
+                        __m256i iacc_mat_00_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp2, lhs_mat_01_40_sp2),_mm256_maddubs_epi16(rhs_mat_0145_41_sp2, lhs_mat_01_41_sp2));
+                        __m256i iacc_mat_01_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp2, lhs_mat_01_40_sp2),_mm256_maddubs_epi16(rhs_mat_2367_41_sp2, lhs_mat_01_41_sp2));
+
+                        __m256i iacc_mat_10_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp2, lhs_mat_23_40_sp2),_mm256_maddubs_epi16(rhs_mat_0145_41_sp2, lhs_mat_23_41_sp2));
+                        __m256i iacc_mat_11_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp2, lhs_mat_23_40_sp2),_mm256_maddubs_epi16(rhs_mat_2367_41_sp2, lhs_mat_23_41_sp2));
+
+                        __m256i iacc_mat_00_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp2, lhs_mat_01_50_sp2),_mm256_maddubs_epi16(rhs_mat_0145_51_sp2, lhs_mat_01_51_sp2));
+                        __m256i iacc_mat_01_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp2, lhs_mat_01_50_sp2),_mm256_maddubs_epi16(rhs_mat_2367_51_sp2, lhs_mat_01_51_sp2));
+
+                        __m256i iacc_mat_10_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp2, lhs_mat_23_50_sp2),_mm256_maddubs_epi16(rhs_mat_0145_51_sp2, lhs_mat_23_51_sp2));
+                        __m256i iacc_mat_11_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp2, lhs_mat_23_50_sp2),_mm256_maddubs_epi16(rhs_mat_2367_51_sp2, lhs_mat_23_51_sp2));
+
+                        __m256i iacc_mat_00_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp2, lhs_mat_01_60_sp2),_mm256_maddubs_epi16(rhs_mat_0145_61_sp2, lhs_mat_01_61_sp2));
+                        __m256i iacc_mat_01_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp2, lhs_mat_01_60_sp2),_mm256_maddubs_epi16(rhs_mat_2367_61_sp2, lhs_mat_01_61_sp2));
+
+                        __m256i iacc_mat_10_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp2, lhs_mat_23_60_sp2),_mm256_maddubs_epi16(rhs_mat_0145_61_sp2, lhs_mat_23_61_sp2));
+                        __m256i iacc_mat_11_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp2, lhs_mat_23_60_sp2),_mm256_maddubs_epi16(rhs_mat_2367_61_sp2, lhs_mat_23_61_sp2));
+
+                        __m256i iacc_mat_00_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp2, lhs_mat_01_70_sp2),_mm256_maddubs_epi16(rhs_mat_0145_71_sp2, lhs_mat_01_71_sp2));
+                        __m256i iacc_mat_01_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp2, lhs_mat_01_70_sp2),_mm256_maddubs_epi16(rhs_mat_2367_71_sp2, lhs_mat_01_71_sp2));
+
+                        __m256i iacc_mat_10_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp2, lhs_mat_23_70_sp2),_mm256_maddubs_epi16(rhs_mat_0145_71_sp2, lhs_mat_23_71_sp2));
+                        __m256i iacc_mat_11_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp2, lhs_mat_23_70_sp2),_mm256_maddubs_epi16(rhs_mat_2367_71_sp2, lhs_mat_23_71_sp2));
+
+                        // Combine results from both shuffle patterns for each output block
                         __m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
                         __m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
                         __m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
@@ -2830,6 +5491,36 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                         __m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
                         __m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
 
+                        __m256i iacc_mat_00_2 = _mm256_add_epi16(iacc_mat_00_2_sp1, iacc_mat_00_2_sp2);
+                        __m256i iacc_mat_01_2 = _mm256_add_epi16(iacc_mat_01_2_sp1, iacc_mat_01_2_sp2);
+                        __m256i iacc_mat_10_2 = _mm256_add_epi16(iacc_mat_10_2_sp1, iacc_mat_10_2_sp2);
+                        __m256i iacc_mat_11_2 = _mm256_add_epi16(iacc_mat_11_2_sp1, iacc_mat_11_2_sp2);
+
+                        __m256i iacc_mat_00_3 = _mm256_add_epi16(iacc_mat_00_3_sp1, iacc_mat_00_3_sp2);
+                        __m256i iacc_mat_01_3 = _mm256_add_epi16(iacc_mat_01_3_sp1, iacc_mat_01_3_sp2);
+                        __m256i iacc_mat_10_3 = _mm256_add_epi16(iacc_mat_10_3_sp1, iacc_mat_10_3_sp2);
+                        __m256i iacc_mat_11_3 = _mm256_add_epi16(iacc_mat_11_3_sp1, iacc_mat_11_3_sp2);
+
+                        __m256i iacc_mat_00_4 = _mm256_add_epi16(iacc_mat_00_4_sp1, iacc_mat_00_4_sp2);
+                        __m256i iacc_mat_01_4 = _mm256_add_epi16(iacc_mat_01_4_sp1, iacc_mat_01_4_sp2);
+                        __m256i iacc_mat_10_4 = _mm256_add_epi16(iacc_mat_10_4_sp1, iacc_mat_10_4_sp2);
+                        __m256i iacc_mat_11_4 = _mm256_add_epi16(iacc_mat_11_4_sp1, iacc_mat_11_4_sp2);
+
+                        __m256i iacc_mat_00_5 = _mm256_add_epi16(iacc_mat_00_5_sp1, iacc_mat_00_5_sp2);
+                        __m256i iacc_mat_01_5 = _mm256_add_epi16(iacc_mat_01_5_sp1, iacc_mat_01_5_sp2);
+                        __m256i iacc_mat_10_5 = _mm256_add_epi16(iacc_mat_10_5_sp1, iacc_mat_10_5_sp2);
+                        __m256i iacc_mat_11_5 = _mm256_add_epi16(iacc_mat_11_5_sp1, iacc_mat_11_5_sp2);
+
+                        __m256i iacc_mat_00_6 = _mm256_add_epi16(iacc_mat_00_6_sp1, iacc_mat_00_6_sp2);
+                        __m256i iacc_mat_01_6 = _mm256_add_epi16(iacc_mat_01_6_sp1, iacc_mat_01_6_sp2);
+                        __m256i iacc_mat_10_6 = _mm256_add_epi16(iacc_mat_10_6_sp1, iacc_mat_10_6_sp2);
+                        __m256i iacc_mat_11_6 = _mm256_add_epi16(iacc_mat_11_6_sp1, iacc_mat_11_6_sp2);
+
+                        __m256i iacc_mat_00_7 = _mm256_add_epi16(iacc_mat_00_7_sp1, iacc_mat_00_7_sp2);
+                        __m256i iacc_mat_01_7 = _mm256_add_epi16(iacc_mat_01_7_sp1, iacc_mat_01_7_sp2);
+                        __m256i iacc_mat_10_7 = _mm256_add_epi16(iacc_mat_10_7_sp1, iacc_mat_10_7_sp2);
+                        __m256i iacc_mat_11_7 = _mm256_add_epi16(iacc_mat_11_7_sp1, iacc_mat_11_7_sp2);
+
                         // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
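+                        // _mm256_madd_epi16 multiplies each 16-bit sum with the matching per sub block scale vector
+                        // (scale_0145_k / scale_2367_k) and adds adjacent products, widening the accumulators to 32 bit.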
                         iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);
                         iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);
@@ -2841,24 +5532,50 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                         iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);
                         iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);
 
-                        // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
-                        __m256i iacc_row_0_0 = _mm256_blend_epi32(iacc_mat_00_0, _mm256_shuffle_epi32(iacc_mat_01_0, 78), 204);
-                        __m256i iacc_row_1_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_0, 78), iacc_mat_01_0, 204);
-                        __m256i iacc_row_2_0 = _mm256_blend_epi32(iacc_mat_10_0, _mm256_shuffle_epi32(iacc_mat_11_0, 78), 204);
-                        __m256i iacc_row_3_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_0, 78), iacc_mat_11_0, 204);
-                        __m256i iacc_row_0_1 = _mm256_blend_epi32(iacc_mat_00_1, _mm256_shuffle_epi32(iacc_mat_01_1, 78), 204);
-                        __m256i iacc_row_1_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_1, 78), iacc_mat_01_1, 204);
-                        __m256i iacc_row_2_1 = _mm256_blend_epi32(iacc_mat_10_1, _mm256_shuffle_epi32(iacc_mat_11_1, 78), 204);
-                        __m256i iacc_row_3_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_1, 78), iacc_mat_11_1, 204);
+                        iacc_mat_00_2 = _mm256_madd_epi16(iacc_mat_00_2, scale_0145_2);
+                        iacc_mat_01_2 = _mm256_madd_epi16(iacc_mat_01_2, scale_2367_2);
+                        iacc_mat_10_2 = _mm256_madd_epi16(iacc_mat_10_2, scale_0145_2);
+                        iacc_mat_11_2 = _mm256_madd_epi16(iacc_mat_11_2, scale_2367_2);
+
+                        iacc_mat_00_3 = _mm256_madd_epi16(iacc_mat_00_3, scale_0145_3);
+                        iacc_mat_01_3 = _mm256_madd_epi16(iacc_mat_01_3, scale_2367_3);
+                        iacc_mat_10_3 = _mm256_madd_epi16(iacc_mat_10_3, scale_0145_3);
+                        iacc_mat_11_3 = _mm256_madd_epi16(iacc_mat_11_3, scale_2367_3);
+
+                        iacc_mat_00_4 = _mm256_madd_epi16(iacc_mat_00_4, scale_0145_4);
+                        iacc_mat_01_4 = _mm256_madd_epi16(iacc_mat_01_4, scale_2367_4);
+                        iacc_mat_10_4 = _mm256_madd_epi16(iacc_mat_10_4, scale_0145_4);
+                        iacc_mat_11_4 = _mm256_madd_epi16(iacc_mat_11_4, scale_2367_4);
+
+                        iacc_mat_00_5 = _mm256_madd_epi16(iacc_mat_00_5, scale_0145_5);
+                        iacc_mat_01_5 = _mm256_madd_epi16(iacc_mat_01_5, scale_2367_5);
+                        iacc_mat_10_5 = _mm256_madd_epi16(iacc_mat_10_5, scale_0145_5);
+                        iacc_mat_11_5 = _mm256_madd_epi16(iacc_mat_11_5, scale_2367_5);
+
+                        iacc_mat_00_6 = _mm256_madd_epi16(iacc_mat_00_6, scale_0145_6);
+                        iacc_mat_01_6 = _mm256_madd_epi16(iacc_mat_01_6, scale_2367_6);
+                        iacc_mat_10_6 = _mm256_madd_epi16(iacc_mat_10_6, scale_0145_6);
+                        iacc_mat_11_6 = _mm256_madd_epi16(iacc_mat_11_6, scale_2367_6);
+
+                        iacc_mat_00_7 = _mm256_madd_epi16(iacc_mat_00_7, scale_0145_7);
+                        iacc_mat_01_7 = _mm256_madd_epi16(iacc_mat_01_7, scale_2367_7);
+                        iacc_mat_10_7 = _mm256_madd_epi16(iacc_mat_10_7, scale_0145_7);
+                        iacc_mat_11_7 = _mm256_madd_epi16(iacc_mat_11_7, scale_2367_7);
+
+                        __m256i iacc_mat_00 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_00_0, iacc_mat_00_1), _mm256_add_epi32(iacc_mat_00_2, iacc_mat_00_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_00_4, iacc_mat_00_5), _mm256_add_epi32(iacc_mat_00_6, iacc_mat_00_7)));
+                        __m256i iacc_mat_01 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_01_0, iacc_mat_01_1), _mm256_add_epi32(iacc_mat_01_2, iacc_mat_01_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_01_4, iacc_mat_01_5), _mm256_add_epi32(iacc_mat_01_6, iacc_mat_01_7)));
+                        __m256i iacc_mat_10 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_10_0, iacc_mat_10_1), _mm256_add_epi32(iacc_mat_10_2, iacc_mat_10_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_10_4, iacc_mat_10_5), _mm256_add_epi32(iacc_mat_10_6, iacc_mat_10_7)));
+                        __m256i iacc_mat_11 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_11_0, iacc_mat_11_1), _mm256_add_epi32(iacc_mat_11_2, iacc_mat_11_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_11_4, iacc_mat_11_5), _mm256_add_epi32(iacc_mat_11_6, iacc_mat_11_7)));
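+                        // Each iacc_mat_XY now holds eight 32-bit dot products (two lhs rows by four rhs columns)
+                        // summed over the eight sub blocks handled in this iteration; the blend/shuffle step below
+                        // rearranges them into per-row vectors.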
 
-                        __m256i iacc_row_0 = _mm256_add_epi32(iacc_row_0_0, iacc_row_0_1);
-                        __m256i iacc_row_1 = _mm256_add_epi32(iacc_row_1_0, iacc_row_1_1);
-                        __m256i iacc_row_2 = _mm256_add_epi32(iacc_row_2_0, iacc_row_2_1);
-                        __m256i iacc_row_3 = _mm256_add_epi32(iacc_row_3_0, iacc_row_3_1);
+                        // Straighten out to make 4 row vectors
+                        __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
+                        __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
+                        __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
+                        __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
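+                        // _mm256_shuffle_epi32 with 78 (0b01'00'11'10) swaps the two dword pairs of each 128-bit lane,
+                        // and blend mask 204 (0b11001100) takes dwords 2,3 and 6,7 from that shuffled vector, so the
+                        // two column-pair accumulators are interleaved into one vector of eight column results per row.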
 
                        // Load the scale (d) values for all the 4 Q8_K blocks and repeat them across lanes
                         const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);
-                        const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);//GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
+                        const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);
 
                        // Multiply with appropriate scales and accumulate (for both d and dmin) below
                         acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
@@ -2866,10 +5583,36 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                         acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
                         acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
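+                        // _mm256_shuffle_ps with 0/85/170/255 broadcasts the d value of row 0/1/2/3 across all eight
+                        // lanes, so every product is scaled by d_row * d_col before being accumulated.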
 
-                        __m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);
-                        __m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);
-                        __m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);
-                        __m256i iacc_row_min_3 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 255), mins_01);
+                        __m256i lhs_bsums_01_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_0123), lhs_raw_bsums_01_0123, 1);
+                        __m256i lhs_bsums_23_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_0123), lhs_raw_bsums_23_0123, 1);
+                        __m256i lhs_bsums_01_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_4567), lhs_raw_bsums_01_4567, 1);
+                        __m256i lhs_bsums_23_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_4567), lhs_raw_bsums_23_4567, 1);
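+                        // _mm256_castsi128_si256 + _mm256_inserti128_si256 duplicate each 128-bit bsums vector into
+                        // both halves of a 256-bit register so it can be shuffled against the eight-column mins below.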
+
+                        // Take two bsums from two Q8_Ks at a time and multiply them with the corresponding mins values from each Q2_K
+                        __m256i iacc_row_min_0_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 0), mins_01);
+                        __m256i iacc_row_min_1_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 170), mins_01);
+                        __m256i iacc_row_min_2_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 0), mins_01);
+                        __m256i iacc_row_min_3_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 170), mins_01);
+
+                        __m256i iacc_row_min_0_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 85), mins_23);
+                        __m256i iacc_row_min_1_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 255), mins_23);
+                        __m256i iacc_row_min_2_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 85), mins_23);
+                        __m256i iacc_row_min_3_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 255), mins_23);
+
+                        __m256i iacc_row_min_0_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 0), mins_45);
+                        __m256i iacc_row_min_1_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 170), mins_45);
+                        __m256i iacc_row_min_2_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 0), mins_45);
+                        __m256i iacc_row_min_3_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 170), mins_45);
+
+                        __m256i iacc_row_min_0_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 85), mins_67);
+                        __m256i iacc_row_min_1_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 255), mins_67);
+                        __m256i iacc_row_min_2_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 85), mins_67);
+                        __m256i iacc_row_min_3_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 255), mins_67);
+
+                        __m256i iacc_row_min_0 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_0_01, iacc_row_min_0_23), _mm256_add_epi32(iacc_row_min_0_45,iacc_row_min_0_67));
+                        __m256i iacc_row_min_1 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_1_01, iacc_row_min_1_23), _mm256_add_epi32(iacc_row_min_1_45,iacc_row_min_1_67));
+                        __m256i iacc_row_min_2 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_2_01, iacc_row_min_2_23), _mm256_add_epi32(iacc_row_min_2_45,iacc_row_min_2_67));
+                        __m256i iacc_row_min_3 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_3_01, iacc_row_min_3_23), _mm256_add_epi32(iacc_row_min_3_45,iacc_row_min_3_67));
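+                        // These sums are the Q2_K minimum correction: per row/column pair they accumulate
+                        // sum(q8) * min over the sub blocks. Scaled by the column dmin values and the row scale below,
+                        // they land in acc_min_rows, which is subtracted from acc_rows when the output is stored.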
 
                         acc_min_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]);
                         acc_min_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);
@@ -2882,16 +5625,19 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
             // Store the accumulated values
             for (int i = 0; i < 16; i++) {
                 _mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));
             }
         }
     }
+
     for (; y < nr / 4; y++) {
 
         const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);
 
+        // Take a group of eight block_q2_Kx8 structures in each pass of the loop and perform the dot product operation
         for (int64_t x = xstart; x < nc / 8; x++) {
 
-            const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
+            const block_q2_Kx8 * b_ptr = b_ptr_start + (x * b_nb);
 
             // Master FP accumulators
             __m256 acc_rows[4];
@@ -2905,62 +5651,95 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
             }
 
             for (int64_t b = 0; b < nb; b++) {
-
-                // Scale values - Load the eight scale values of block_q4_Kx8
+                // Delta values - Load the eight scale values of block_q2_Kx8
                 const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
 
-                // dmin values - Load the eight dmin values of block_q4_Kx8
+                // dmin values - Load the eight dmin values of block_q2_Kx8
                 const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
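+                // d scales the (quant * sub block scale) dot products while dmin scales the bsums * mins correction
+                // accumulated into acc_min_rows.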
 
-                // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
-                for (int sb = 0; sb < QK_K / 64; sb++) {
+                // Loop to iterate over the sixteen sub blocks of a super block - eight sub blocks are processed per iteration
+                for (int sb = 0; sb < QK_K / 128; sb++) {
 
-                    // Load the eight block_q4_k for two sub blocks quantized values interleaved with each other in chunks of eight bytes - B0,B1 ....B6,B7
-                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
-                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
-                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
-                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
-                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
-                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
-                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
-                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
+                    // Load the quantized values of eight sub blocks from the eight block_q2_K, interleaved with each other in chunks of eight bytes - B0, B1, .... B6, B7
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + sb * 256));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 128 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 160 + sb * 256));
+                    const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 192 + sb * 256));
+                    const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 224 + sb * 256));
 
                     // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
+                    // superblock / sub block / which part of sub block
                     const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
                     const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+
                     const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
                     const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
                     const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240);
                     const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240);
+
                     const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240);
                     const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240);
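+                    // requiredOrder (set up earlier in this function) is assumed to swap the two 128-bit halves, so
+                    // blend mask 240 (0b11110000) leaves each register holding the chunks of columns {0,1,4,5} or
+                    // {2,3,6,7}, matching the 0145/2367 naming used below.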
 
-                    // 4-bit -> 8-bit
-                    // First sub block of the two sub blocks processed in the iteration
-                    const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m4b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)
-                    const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m4b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)
+                    // 2-bit -> 8-bit
+                    // First sub block of the eight sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_00 = _mm256_and_si256(rhs_raw_mat_0145_0, m3b); //B00(0-7) B01(0-7) B04(0-7) B05(0-7)
+                    const __m256i rhs_mat_2367_00 = _mm256_and_si256(rhs_raw_mat_2367_0, m3b); //B02(0-7) B03(0-7) B06(0-7) B07(0-7)
+
+                    const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m3b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)
+                    const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m3b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)
+
+                    // Second sub block of the eight sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_10 = _mm256_and_si256(rhs_raw_mat_0145_2, m3b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)
+                    const __m256i rhs_mat_2367_10 = _mm256_and_si256(rhs_raw_mat_2367_2, m3b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)
+
+                    const __m256i rhs_mat_0145_11 = _mm256_and_si256(rhs_raw_mat_0145_3, m3b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)
+                    const __m256i rhs_mat_2367_11 = _mm256_and_si256(rhs_raw_mat_2367_3, m3b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)
 
-                    const __m256i rhs_mat_0145_01 = _mm256_and_si256(rhs_raw_mat_0145_1, m4b); //B00(8-15) B01(8-15) B04(8-15) B05(8-15)
-                    const __m256i rhs_mat_2367_01 = _mm256_and_si256(rhs_raw_mat_2367_1, m4b); //B02(8-15) B03(8-15) B06(8-15) B07(8-15)
+                    // Third sub block of the eight sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 2), m3b); //B20(0-7) B21(0-7) B24(0-7) B25(0-7)
+                    const __m256i rhs_mat_2367_20 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 2), m3b); //B22(0-7) B23(0-7) B26(0-7) B27(0-7)
 
-                    const __m256i rhs_mat_0145_02 = _mm256_and_si256(rhs_raw_mat_0145_2, m4b); //B00(16-23) B01(16-23) B04(16-23) B05(16-23)
-                    const __m256i rhs_mat_2367_02 = _mm256_and_si256(rhs_raw_mat_2367_2, m4b); //B02(16-23) B03(16-23) B06(16-23) B07(16-23)
+                    const __m256i rhs_mat_0145_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 2), m3b); //B20(8-15) B21(8-15) B24(8-15) B25(8-15)
+                    const __m256i rhs_mat_2367_21 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 2), m3b); //B22(8-15) B23(8-15) B26(8-15) B27(8-15)
 
-                    const __m256i rhs_mat_0145_03 = _mm256_and_si256(rhs_raw_mat_0145_3, m4b); //B00(24-31) B01(24-31) B04(24-31) B05(24-31)
-                    const __m256i rhs_mat_2367_03 = _mm256_and_si256(rhs_raw_mat_2367_3, m4b); //B02(24-31) B03(24-31) B06(24-31) B07(24-31)
+                    // Fourth sub block of the eight sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 2), m3b); //B30(0-7) B31(0-7) B34(0-7) B35(0-7)
+                    const __m256i rhs_mat_2367_30 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 2), m3b); //B32(0-7) B33(0-7) B36(0-7) B37(0-7)
 
-                    // Second sub block of the two sub blocks processed in the iteration
-                    const __m256i rhs_mat_0145_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b); //B10(0-7) B11(0-7) B14(0-7) B15(0-7)
-                    const __m256i rhs_mat_2367_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b); //B12(0-7) B13(0-7) B16(0-7) B17(0-7)
+                    const __m256i rhs_mat_0145_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 2), m3b); //B30(8-15) B31(8-15) B34(8-15) B35(8-15)
+                    const __m256i rhs_mat_2367_31 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 2), m3b); //B32(8-15) B33(8-15) B36(8-15) B37(8-15)
 
-                    const __m256i rhs_mat_0145_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b); //B10(8-15) B11(8-15) B14(8-15) B15(8-15)
-                    const __m256i rhs_mat_2367_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b); //B12(8-15) B13(8-15) B16(8-15) B17(8-15)
+                    // Fifth sub block of the eight sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m3b); //B40(0-7) B41(0-7) B44(0-7) B45(0-7)
+                    const __m256i rhs_mat_2367_40 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m3b); //B42(0-7) B43(0-7) B46(0-7) B47(0-7)
 
-                    const __m256i rhs_mat_0145_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m4b); //B10(16-23) B11(16-23) B14(16-23) B15(16-23)
-                    const __m256i rhs_mat_2367_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m4b); //B12(16-23) B13(16-23) B16(16-23) B17(16-23)
+                    const __m256i rhs_mat_0145_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m3b); //B40(8-15) B41(8-15) B44(8-15) B45(8-15)
+                    const __m256i rhs_mat_2367_41 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m3b); //B42(8-15) B43(8-15) B46(8-15) B47(8-15)
 
-                    const __m256i rhs_mat_0145_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m4b); //B10(24-31) B11(24-31) B14(24-31) B15(24-31)
-                    const __m256i rhs_mat_2367_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m4b); //B12(24-31) B13(24-31) B16(24-31) B17(24-31)
+                    // Sixth sub block of the eight sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 4), m3b); //B50(0-7) B51(0-7) B54(0-7) B55(0-7)
+                    const __m256i rhs_mat_2367_50 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 4), m3b); //B52(0-7) B53(0-7) B56(0-7) B57(0-7)
+
+                    const __m256i rhs_mat_0145_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 4), m3b); //B50(8-15) B51(8-15) B54(8-15) B55(8-15)
+                    const __m256i rhs_mat_2367_51 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 4), m3b); //B52(8-15) B53(8-15) B56(8-15) B57(8-15)
+
+                    // Seventh sub block of the eight sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 6), m3b); //B60(0-7) B61(0-7) B64(0-7) B65(0-7)
+                    const __m256i rhs_mat_2367_60 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 6), m3b); //B62(0-7) B63(0-7) B66(0-7) B67(0-7)
+
+                    const __m256i rhs_mat_0145_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 6), m3b); //B60(8-15) B61(8-15) B64(8-15) B65(8-15)
+                    const __m256i rhs_mat_2367_61 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 6), m3b); //B62(8-15) B63(8-15) B66(8-15) B67(8-15)
+
+                    // Eighth sub block of the eight sub blocks processed in the iteration
+                    const __m256i rhs_mat_0145_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_2, 6), m3b); //B70(0-7) B71(0-7) B74(0-7) B75(0-7)
+                    const __m256i rhs_mat_2367_70 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_2, 6), m3b); //B72(0-7) B73(0-7) B76(0-7) B77(0-7)
+
+                    const __m256i rhs_mat_0145_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_3, 6), m3b); //B70(8-15) B71(8-15) B74(8-15) B75(8-15)
+                    const __m256i rhs_mat_2367_71 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_3, 6), m3b); //B72(8-15) B73(8-15) B76(8-15) B77(8-15)
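+                    // All eight sub blocks processed in the iteration are now unpacked: each shift (0, 2, 4, 6) combined with the 2-bit mask m3b extracts one pair of sub blocks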
 
                     // Shuffle pattern one - right side input
                     const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)
@@ -2969,23 +5748,48 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const __m256i rhs_mat_0145_01_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_01, 136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11)
                     const __m256i rhs_mat_2367_01_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_01, 136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11)
 
-                    const __m256i rhs_mat_0145_02_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_02, 136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19)
-                    const __m256i rhs_mat_2367_02_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_02, 136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19)
-
-                    const __m256i rhs_mat_0145_03_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_03, 136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27)
-                    const __m256i rhs_mat_2367_03_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_03, 136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27)
-
                     const __m256i rhs_mat_0145_10_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_10, 136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3)
                     const __m256i rhs_mat_2367_10_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_10, 136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3)
 
                     const __m256i rhs_mat_0145_11_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_11, 136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11)
                     const __m256i rhs_mat_2367_11_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_11, 136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11)
 
-                    const __m256i rhs_mat_0145_12_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_12, 136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19)
-                    const __m256i rhs_mat_2367_12_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_12, 136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19)
+                    const __m256i rhs_mat_0145_20_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_20, 136); //B20(0-3) B21(0-3) B20(0-3) B21(0-3) B24(0-3) B25(0-3) B24(0-3) B25(0-3)
+                    const __m256i rhs_mat_2367_20_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_20, 136); //B22(0-3) B23(0-3) B22(0-3) B23(0-3) B26(0-3) B27(0-3) B26(0-3) B27(0-3)
+
+                    const __m256i rhs_mat_0145_21_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_21, 136); //B20(8-11) B21(8-11) B20(8-11) B21(8-11) B24(8-11) B25(8-11) B24(8-11) B25(8-11)
+                    const __m256i rhs_mat_2367_21_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_21, 136); //B22(8-11) B23(8-11) B22(8-11) B23(8-11) B26(8-11) B27(8-11) B26(8-11) B27(8-11)
+
+                    const __m256i rhs_mat_0145_30_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_30, 136); //B30(0-3) B31(0-3) B30(0-3) B31(0-3) B34(0-3) B35(0-3) B34(0-3) B35(0-3)
+                    const __m256i rhs_mat_2367_30_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_30, 136); //B32(0-3) B33(0-3) B32(0-3) B33(0-3) B36(0-3) B37(0-3) B36(0-3) B37(0-3)
+
+                    const __m256i rhs_mat_0145_31_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_31, 136); //B30(8-11) B31(8-11) B30(8-11) B31(8-11) B34(8-11) B35(8-11) B34(8-11) B35(8-11)
+                    const __m256i rhs_mat_2367_31_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_31, 136); //B32(8-11) B33(8-11) B32(8-11) B33(8-11) B36(8-11) B37(8-11) B36(8-11) B37(8-11)
+
+                    const __m256i rhs_mat_0145_40_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_40, 136); //B40(0-3) B41(0-3) B40(0-3) B41(0-3) B44(0-3) B45(0-3) B44(0-3) B45(0-3)
+                    const __m256i rhs_mat_2367_40_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_40, 136); //B42(0-3) B43(0-3) B42(0-3) B43(0-3) B46(0-3) B47(0-3) B46(0-3) B47(0-3)
+
+                    const __m256i rhs_mat_0145_41_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_41, 136); //B40(8-11) B41(8-11) B40(8-11) B41(8-11) B44(8-11) B45(8-11) B44(8-11) B45(8-11)
+                    const __m256i rhs_mat_2367_41_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_41, 136); //B42(8-11) B43(8-11) B42(8-11) B43(8-11) B46(8-11) B47(8-11) B46(8-11) B47(8-11)
+
+                    const __m256i rhs_mat_0145_50_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_50, 136); //B50(0-3) B51(0-3) B50(0-3) B51(0-3) B54(0-3) B55(0-3) B54(0-3) B55(0-3)
+                    const __m256i rhs_mat_2367_50_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_50, 136); //B52(0-3) B53(0-3) B52(0-3) B53(0-3) B56(0-3) B57(0-3) B56(0-3) B57(0-3)
+
+                    const __m256i rhs_mat_0145_51_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_51, 136); //B50(8-11) B51(8-11) B50(8-11) B51(8-11) B54(8-11) B55(8-11) B54(8-11) B55(8-11)
+                    const __m256i rhs_mat_2367_51_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_51, 136); //B52(8-11) B53(8-11) B52(8-11) B53(8-11) B56(8-11) B57(8-11) B56(8-11) B57(8-11)
+
+                    const __m256i rhs_mat_0145_60_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_60, 136); //B60(0-3) B61(0-3) B60(0-3) B61(0-3) B64(0-3) B65(0-3) B64(0-3) B65(0-3)
+                    const __m256i rhs_mat_2367_60_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_60, 136); //B62(0-3) B63(0-3) B62(0-3) B63(0-3) B66(0-3) B67(0-3) B66(0-3) B67(0-3)
+
+                    const __m256i rhs_mat_0145_61_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_61, 136); //B60(8-11) B61(8-11) B60(8-11) B61(8-11) B64(8-11) B65(8-11) B64(8-11) B65(8-11)
+                    const __m256i rhs_mat_2367_61_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_61, 136); //B62(8-11) B63(8-11) B62(8-11) B63(8-11) B66(8-11) B67(8-11) B66(8-11) B67(8-11)
+
+                    const __m256i rhs_mat_0145_70_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_70, 136); //B70(0-3) B71(0-3) B70(0-3) B71(0-3) B74(0-3) B75(0-3) B74(0-3) B75(0-3)
+                    const __m256i rhs_mat_2367_70_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_70, 136); //B72(0-3) B73(0-3) B72(0-3) B73(0-3) B76(0-3) B77(0-3) B76(0-3) B77(0-3)
+
+                    const __m256i rhs_mat_0145_71_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_71, 136); //B70(8-11) B71(8-11) B70(8-11) B71(8-11) B74(8-11) B75(8-11) B74(8-11) B75(8-11)
+                    const __m256i rhs_mat_2367_71_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_71, 136); //B72(8-11) B73(8-11) B72(8-11) B73(8-11) B76(8-11) B77(8-11) B76(8-11) B77(8-11)
 
-                    const __m256i rhs_mat_0145_13_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_13, 136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27)
-                    const __m256i rhs_mat_2367_13_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_13, 136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27)
 
                     // Shuffle pattern two - right side input
                     const __m256i rhs_mat_0145_00_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_00, 221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7)
@@ -2994,53 +5798,81 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const __m256i rhs_mat_0145_01_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_01, 221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15)
                     const __m256i rhs_mat_2367_01_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_01, 221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15)
 
-                    const __m256i rhs_mat_0145_02_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_02, 221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23)
-                    const __m256i rhs_mat_2367_02_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_02, 221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23)
-
-                    const __m256i rhs_mat_0145_03_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_03, 221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31)
-                    const __m256i rhs_mat_2367_03_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_03, 221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31)
-
                     const __m256i rhs_mat_0145_10_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_10, 221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7)
                     const __m256i rhs_mat_2367_10_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_10, 221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7)
 
                     const __m256i rhs_mat_0145_11_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_11, 221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15)
                     const __m256i rhs_mat_2367_11_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_11, 221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15)
 
-                    const __m256i rhs_mat_0145_12_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_12, 221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23)
-                    const __m256i rhs_mat_2367_12_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_12, 221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23)
+                    const __m256i rhs_mat_0145_20_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_20, 221); //B20(4-7) B21(4-7) B20(4-7) B21(4-7) B24(4-7) B25(4-7) B24(4-7) B25(4-7)
+                    const __m256i rhs_mat_2367_20_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_20, 221); //B22(4-7) B23(4-7) B22(4-7) B23(4-7) B26(4-7) B27(4-7) B26(4-7) B27(4-7)
 
-                    const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)
-                    const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)
+                    const __m256i rhs_mat_0145_21_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_21, 221); //B20(12-15) B21(12-15) B20(12-15) B21(12-15) B24(12-15) B25(12-15) B24(12-15) B25(12-15)
+                    const __m256i rhs_mat_2367_21_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_21, 221); //B22(12-15) B23(12-15) B22(12-15) B23(12-15) B26(12-15) B27(12-15) B26(12-15) B27(12-15)
 
-                    uint32_t utmp_0[4], utmp_1[4];
+                    const __m256i rhs_mat_0145_30_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_30, 221); //B30(4-7) B31(4-7) B30(4-7) B31(4-7) B34(4-7) B35(4-7) B34(4-7) B35(4-7)
+                    const __m256i rhs_mat_2367_30_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_30, 221); //B32(4-7) B33(4-7) B32(4-7) B33(4-7) B36(4-7) B37(4-7) B36(4-7) B37(4-7)
 
-                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
-                    // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
-                    memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
-                    utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
-                    const uint32_t uaux_0 = utmp_0[1] & kmask1;
-                    utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
-                    utmp_0[2] = uaux_0;
-                    utmp_0[0] &= kmask1;
+                    const __m256i rhs_mat_0145_31_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_31, 221); //B30(12-15) B31(12-15) B30(12-15) B31(12-15) B34(12-15) B35(12-15) B34(12-15) B35(12-15)
+                    const __m256i rhs_mat_2367_31_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_31, 221); //B32(12-15) B33(12-15) B32(12-15) B33(12-15) B36(12-15) B37(12-15) B36(12-15) B37(12-15)
 
-                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures when sb = 1
-                    memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
-                    utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
-                    const uint32_t uaux_1 = utmp_1[1] & kmask1;
-                    utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
-                    utmp_1[2] = uaux_1;
-                    utmp_1[0] &= kmask1;
+                    const __m256i rhs_mat_0145_40_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_40, 221); //B40(4-7) B41(4-7) B40(4-7) B41(4-7) B44(4-7) B45(4-7) B44(4-7) B45(4-7)
+                    const __m256i rhs_mat_2367_40_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_40, 221); //B42(4-7) B43(4-7) B42(4-7) B43(4-7) B46(4-7) B47(4-7) B46(4-7) B47(4-7)
 
-                    // Scales of first sub block in the sb loop
-                    const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
-                    const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
+                    const __m256i rhs_mat_0145_41_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_41, 221); //B40(12-15) B41(12-15) B40(12-15) B41(12-15) B44(12-15) B45(12-15) B44(12-15) B45(12-15)
+                    const __m256i rhs_mat_2367_41_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_41, 221); //B42(12-15) B43(12-15) B42(12-15) B43(12-15) B46(12-15) B47(12-15) B46(12-15) B47(12-15)
 
-                    // Scales of second sub block in the sb loop
-                    const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
-                    const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
+                    const __m256i rhs_mat_0145_50_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_50, 221); //B50(4-7) B51(4-7) B50(4-7) B51(4-7) B54(4-7) B55(4-7) B54(4-7) B55(4-7)
+                    const __m256i rhs_mat_2367_50_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_50, 221); //B52(4-7) B53(4-7) B52(4-7) B53(4-7) B56(4-7) B57(4-7) B56(4-7) B57(4-7)
 
-                    // Mins of first and second sub block of Q4_K block are arranged side by side
-                    const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
+                    const __m256i rhs_mat_0145_51_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_51, 221); //B50(12-15) B51(12-15) B50(12-15) B51(12-15) B54(12-15) B55(12-15) B54(12-15) B55(12-15)
+                    const __m256i rhs_mat_2367_51_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_51, 221); //B52(12-15) B53(12-15) B52(12-15) B53(12-15) B56(12-15) B57(12-15) B56(12-15) B57(12-15)
+
+                    const __m256i rhs_mat_0145_60_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_60, 221); //B60(4-7) B61(4-7) B60(4-7) B61(4-7) B64(4-7) B65(4-7) B64(4-7) B65(4-7)
+                    const __m256i rhs_mat_2367_60_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_60, 221); //B62(4-7) B63(4-7) B62(4-7) B63(4-7) B66(4-7) B67(4-7) B66(4-7) B67(4-7)
+
+                    const __m256i rhs_mat_0145_61_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_61, 221); //B60(12-15) B61(12-15) B60(12-15) B61(12-15) B64(12-15) B65(12-15) B64(12-15) B65(12-15)
+                    const __m256i rhs_mat_2367_61_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_61, 221); //B62(12-15) B63(12-15) B62(12-15) B63(12-15) B66(12-15) B67(12-15) B66(12-15) B67(12-15)
+
+                    const __m256i rhs_mat_0145_70_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_70, 221); //B70(4-7) B71(4-7) B70(4-7) B71(4-7) B74(4-7) B75(4-7) B74(4-7) B75(4-7)
+                    const __m256i rhs_mat_2367_70_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_70, 221); //B72(4-7) B73(4-7) B72(4-7) B73(4-7) B76(4-7) B77(4-7) B76(4-7) B77(4-7)
+
+                    const __m256i rhs_mat_0145_71_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_71, 221); //B70(12-15) B71(12-15) B70(12-15) B71(12-15) B74(12-15) B75(12-15) B74(12-15) B75(12-15)
+                    const __m256i rhs_mat_2367_71_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_71, 221); //B72(12-15) B73(12-15) B72(12-15) B73(12-15) B76(12-15) B77(12-15) B76(12-15) B77(12-15)
+
+
+                    // Scales and Mins of corresponding sub blocks from different Q2_K structures are stored together
+                    // s00 m00  s01 m01   s10 m10  s11 m11  s20 m20  s21 m21   s30 m30  s31 m31  s40 m40  s41 m41   s50 m50  s51 m51  s60 m60  s61 m61   s70 m70  s71 m71
+
+                    // Combine mins and scales for sub-blocks: 0-1, 2-3, 4-5, 6-7 in the sb loop
+                    const __m128i mins_and_scales_01 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + sb * 64));
+                    const __m128i mins_and_scales_23 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 16 + sb * 64));
+                    const __m128i mins_and_scales_45 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 32 + sb * 64));
+                    const __m128i mins_and_scales_67 = _mm_loadu_si128((const __m128i *)(b_ptr[b].scales + 48 + sb * 64));
+
+                    // Extract the scales (lower 4 bits of each byte) from mins_and_scales
+                    const __m128i scales_01 = _mm_and_si128(mins_and_scales_01, m4b_sse);
+                    const __m128i scales_23 = _mm_and_si128(mins_and_scales_23, m4b_sse);
+                    const __m128i scales_45 = _mm_and_si128(mins_and_scales_45, m4b_sse);
+                    const __m128i scales_67 = _mm_and_si128(mins_and_scales_67, m4b_sse);
+
+                    // Extract the mins (upper 4 bits of each byte) from mins_and_scales
+                    const __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_01, 4), m4b_sse));
+                    const __m256i mins_23 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_23, 4), m4b_sse));
+                    const __m256i mins_45 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_45, 4), m4b_sse));
+                    const __m256i mins_67 = _mm256_cvtepu8_epi16(_mm_and_si128(_mm_srli_epi16(mins_and_scales_67, 4), m4b_sse));
+
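+                    // Widen the packed scales to 16 bits, separating the two sub blocks of each pair; scalesmask1_sse/scalesmask2_sse (defined earlier) are assumed to select the bytes of the first and second sub block respectively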
+                    const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_01, scalesmask1_sse));
+                    const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_01, scalesmask2_sse));
+
+                    const __m256i scales_2 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_23, scalesmask1_sse));
+                    const __m256i scales_3 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_23, scalesmask2_sse));
+
+                    const __m256i scales_4 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_45, scalesmask1_sse));
+                    const __m256i scales_5 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_45, scalesmask2_sse));
+
+                    const __m256i scales_6 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_67, scalesmask1_sse));
+                    const __m256i scales_7 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(scales_67, scalesmask2_sse));
 
                     const __m256i scale_0145_0 = _mm256_shuffle_epi32(scales_0, 68);
                     const __m256i scale_2367_0 = _mm256_shuffle_epi32(scales_0, 238);
@@ -3048,62 +5880,130 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);
                     const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);
 
+                    const __m256i scale_0145_2 = _mm256_shuffle_epi32(scales_2, 68);
+                    const __m256i scale_2367_2 = _mm256_shuffle_epi32(scales_2, 238);
+
+                    const __m256i scale_0145_3 = _mm256_shuffle_epi32(scales_3, 68);
+                    const __m256i scale_2367_3 = _mm256_shuffle_epi32(scales_3, 238);
+
+                    const __m256i scale_0145_4 = _mm256_shuffle_epi32(scales_4, 68);
+                    const __m256i scale_2367_4 = _mm256_shuffle_epi32(scales_4, 238);
+
+                    const __m256i scale_0145_5 = _mm256_shuffle_epi32(scales_5, 68);
+                    const __m256i scale_2367_5 = _mm256_shuffle_epi32(scales_5, 238);
+
+                    const __m256i scale_0145_6 = _mm256_shuffle_epi32(scales_6, 68);
+                    const __m256i scale_2367_6 = _mm256_shuffle_epi32(scales_6, 238);
+
+                    const __m256i scale_0145_7 = _mm256_shuffle_epi32(scales_7, 68);
+                    const __m256i scale_2367_7 = _mm256_shuffle_epi32(scales_7, 238);
+
                     // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
                     // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
-                    __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 * sb)));
+                    __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 512 * sb)));
                     __m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);
                     __m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);
-                    __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 256 * sb)));
+                    __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 512 * sb)));
                     __m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);
                     __m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);
-                    __m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 256 * sb)));
-                    __m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);
-                    __m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);
-                    __m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 256 * sb)));
-                    __m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);
-                    __m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);
-                    __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 256 * sb)));
+                    __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 512 * sb)));
                     __m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);
                     __m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);
-                    __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 256 * sb)));
+                    __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 512 * sb)));
                     __m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);
                     __m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);
-                    __m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 256 * sb)));
-                    __m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);
-                    __m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);
-                    __m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 256 * sb)));
-                    __m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);
-                    __m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);
-
-                    // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
-                    __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].bsums + 16 * sb)));
-                    __m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
-                    lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
+                    __m256i lhs_mat_0123_20 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 512 * sb)));
+                    __m256i lhs_mat_01_20 = _mm256_permute2f128_si256(lhs_mat_0123_20, lhs_mat_0123_20, 0);
+                    __m256i lhs_mat_23_20 = _mm256_permute2f128_si256(lhs_mat_0123_20, lhs_mat_0123_20, 17);
+                    __m256i lhs_mat_0123_21 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 512 * sb)));
+                    __m256i lhs_mat_01_21 = _mm256_permute2f128_si256(lhs_mat_0123_21, lhs_mat_0123_21, 0);
+                    __m256i lhs_mat_23_21 = _mm256_permute2f128_si256(lhs_mat_0123_21, lhs_mat_0123_21, 17);
+                    __m256i lhs_mat_0123_30 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 512 * sb)));
+                    __m256i lhs_mat_01_30 = _mm256_permute2f128_si256(lhs_mat_0123_30, lhs_mat_0123_30, 0);
+                    __m256i lhs_mat_23_30 = _mm256_permute2f128_si256(lhs_mat_0123_30, lhs_mat_0123_30, 17);
+                    __m256i lhs_mat_0123_31 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 512 * sb)));
+                    __m256i lhs_mat_01_31 = _mm256_permute2f128_si256(lhs_mat_0123_31, lhs_mat_0123_31, 0);
+                    __m256i lhs_mat_23_31 = _mm256_permute2f128_si256(lhs_mat_0123_31, lhs_mat_0123_31, 17);
+
+                    __m256i lhs_mat_0123_40 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 + 512 * sb)));
+                    __m256i lhs_mat_01_40 = _mm256_permute2f128_si256(lhs_mat_0123_40, lhs_mat_0123_40, 0);
+                    __m256i lhs_mat_23_40 = _mm256_permute2f128_si256(lhs_mat_0123_40, lhs_mat_0123_40, 17);
+                    __m256i lhs_mat_0123_41 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 288 + 512 * sb)));
+                    __m256i lhs_mat_01_41 = _mm256_permute2f128_si256(lhs_mat_0123_41, lhs_mat_0123_41, 0);
+                    __m256i lhs_mat_23_41 = _mm256_permute2f128_si256(lhs_mat_0123_41, lhs_mat_0123_41, 17);
+                    __m256i lhs_mat_0123_50 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 320 + 512 * sb)));
+                    __m256i lhs_mat_01_50 = _mm256_permute2f128_si256(lhs_mat_0123_50, lhs_mat_0123_50, 0);
+                    __m256i lhs_mat_23_50 = _mm256_permute2f128_si256(lhs_mat_0123_50, lhs_mat_0123_50, 17);
+                    __m256i lhs_mat_0123_51 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 352 + 512 * sb)));
+                    __m256i lhs_mat_01_51 = _mm256_permute2f128_si256(lhs_mat_0123_51, lhs_mat_0123_51, 0);
+                    __m256i lhs_mat_23_51 = _mm256_permute2f128_si256(lhs_mat_0123_51, lhs_mat_0123_51, 17);
+                    __m256i lhs_mat_0123_60 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 384 + 512 * sb)));
+                    __m256i lhs_mat_01_60 = _mm256_permute2f128_si256(lhs_mat_0123_60, lhs_mat_0123_60, 0);
+                    __m256i lhs_mat_23_60 = _mm256_permute2f128_si256(lhs_mat_0123_60, lhs_mat_0123_60, 17);
+                    __m256i lhs_mat_0123_61 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 416 + 512 * sb)));
+                    __m256i lhs_mat_01_61 = _mm256_permute2f128_si256(lhs_mat_0123_61, lhs_mat_0123_61, 0);
+                    __m256i lhs_mat_23_61 = _mm256_permute2f128_si256(lhs_mat_0123_61, lhs_mat_0123_61, 17);
+                    __m256i lhs_mat_0123_70 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 448 + 512 * sb)));
+                    __m256i lhs_mat_01_70 = _mm256_permute2f128_si256(lhs_mat_0123_70, lhs_mat_0123_70, 0);
+                    __m256i lhs_mat_23_70 = _mm256_permute2f128_si256(lhs_mat_0123_70, lhs_mat_0123_70, 17);
+                    __m256i lhs_mat_0123_71 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 480 + 512 * sb)));
+                    __m256i lhs_mat_01_71 = _mm256_permute2f128_si256(lhs_mat_0123_71, lhs_mat_0123_71, 0);
+                    __m256i lhs_mat_23_71 = _mm256_permute2f128_si256(lhs_mat_0123_71, lhs_mat_0123_71, 17);
+
+                    // Bsums are loaded for the different Q8_K blocks
+                    __m128i lhs_raw_bsums_01_0123 = _mm_loadu_si128((const __m128i *)((a_ptr[b].bsums + 32 * sb)));
+                    __m128i lhs_raw_bsums_23_0123 = _mm_loadu_si128((const __m128i *)(a_ptr[b].bsums + 8 + 32 * sb));
+                    __m128i lhs_raw_bsums_01_4567 = _mm_loadu_si128((const __m128i *)((a_ptr[b].bsums + 16 + 32 * sb)));
+                    __m128i lhs_raw_bsums_23_4567 = _mm_loadu_si128((const __m128i *)(a_ptr[b].bsums + 24 + 32 * sb));
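+                    // Each 128-bit load holds eight 16-bit bsums: rows 0-1 then rows 2-3 for sub blocks 0-3, followed by the same for sub blocks 4-7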
 
                     // Shuffle pattern one - left side input
                     const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
                     const __m256i lhs_mat_23_00_sp1 = _mm256_shuffle_epi32(lhs_mat_23_00, 160); //A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3) A02(0-3) A03(0-3)
 
-                    const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160);  //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
-                    const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160);  //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)
-
-                    const __m256i lhs_mat_01_02_sp1 = _mm256_shuffle_epi32(lhs_mat_01_02, 160);  //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19)
-                    const __m256i lhs_mat_23_02_sp1 = _mm256_shuffle_epi32(lhs_mat_23_02, 160);  //A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19) A02(16-19) A03(16-19)
-
-                    const __m256i lhs_mat_01_03_sp1 = _mm256_shuffle_epi32(lhs_mat_01_03, 160);  //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27)
-                    const __m256i lhs_mat_23_03_sp1 = _mm256_shuffle_epi32(lhs_mat_23_03, 160); //A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27) A02(24-27) A03(24-27)
+                    const __m256i lhs_mat_01_01_sp1 = _mm256_shuffle_epi32(lhs_mat_01_01, 160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11)
+                    const __m256i lhs_mat_23_01_sp1 = _mm256_shuffle_epi32(lhs_mat_23_01, 160); //A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11) A02(8-11) A03(8-11)
 
                     const __m256i lhs_mat_01_10_sp1 = _mm256_shuffle_epi32(lhs_mat_01_10, 160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3)
                     const __m256i lhs_mat_23_10_sp1 = _mm256_shuffle_epi32(lhs_mat_23_10, 160); //A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3) A12(0-3) A13(0-3)
 
-                    const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160);  //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
-                    const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160);  //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)
+                    const __m256i lhs_mat_01_11_sp1 = _mm256_shuffle_epi32(lhs_mat_01_11, 160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11)
+                    const __m256i lhs_mat_23_11_sp1 = _mm256_shuffle_epi32(lhs_mat_23_11, 160); //A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11) A12(8-11) A13(8-11)
 
-                    const __m256i lhs_mat_01_12_sp1 = _mm256_shuffle_epi32(lhs_mat_01_12, 160);  //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19)
-                    const __m256i lhs_mat_23_12_sp1 = _mm256_shuffle_epi32(lhs_mat_23_12, 160);  //A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19) A12(16-19) A13(16-19)
+                    const __m256i lhs_mat_01_20_sp1 = _mm256_shuffle_epi32(lhs_mat_01_20, 160); //A20(0-3) A20(0-3) A21(0-3) A21(0-3) A20(0-3) A20(0-3) A21(0-3) A21(0-3)
+                    const __m256i lhs_mat_23_20_sp1 = _mm256_shuffle_epi32(lhs_mat_23_20, 160); //A22(0-3) A23(0-3) A22(0-3) A23(0-3) A22(0-3) A23(0-3) A22(0-3) A23(0-3)
 
-                    const __m256i lhs_mat_01_13_sp1 = _mm256_shuffle_epi32(lhs_mat_01_13, 160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27)
-                    const __m256i lhs_mat_23_13_sp1 = _mm256_shuffle_epi32(lhs_mat_23_13, 160); //A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27) A12(24-27) A13(24-27)
+                    const __m256i lhs_mat_01_21_sp1 = _mm256_shuffle_epi32(lhs_mat_01_21, 160); //A20(8-11) A20(8-11) A21(8-11) A21(8-11) A20(8-11) A20(8-11) A21(8-11) A21(8-11)
+                    const __m256i lhs_mat_23_21_sp1 = _mm256_shuffle_epi32(lhs_mat_23_21, 160); //A22(8-11) A23(8-11) A22(8-11) A23(8-11) A22(8-11) A23(8-11) A22(8-11) A23(8-11)
+
+                    const __m256i lhs_mat_01_30_sp1 = _mm256_shuffle_epi32(lhs_mat_01_30, 160); //A30(0-3) A30(0-3) A31(0-3) A31(0-3) A30(0-3) A30(0-3) A31(0-3) A31(0-3)
+                    const __m256i lhs_mat_23_30_sp1 = _mm256_shuffle_epi32(lhs_mat_23_30, 160); //A32(0-3) A33(0-3) A32(0-3) A33(0-3) A32(0-3) A33(0-3) A32(0-3) A33(0-3)
+
+                    const __m256i lhs_mat_01_31_sp1 = _mm256_shuffle_epi32(lhs_mat_01_31, 160); //A30(8-11) A30(8-11) A31(8-11) A31(8-11) A30(8-11) A30(8-11) A31(8-11) A31(8-11)
+                    const __m256i lhs_mat_23_31_sp1 = _mm256_shuffle_epi32(lhs_mat_23_31, 160); //A32(8-11) A33(8-11) A32(8-11) A33(8-11) A32(8-11) A33(8-11) A32(8-11) A33(8-11)
+
+                    const __m256i lhs_mat_01_40_sp1 = _mm256_shuffle_epi32(lhs_mat_01_40, 160); //A40(0-3) A40(0-3) A41(0-3) A41(0-3) A40(0-3) A40(0-3) A41(0-3) A41(0-3)
+                    const __m256i lhs_mat_23_40_sp1 = _mm256_shuffle_epi32(lhs_mat_23_40, 160); //A42(0-3) A43(0-3) A42(0-3) A43(0-3) A42(0-3) A43(0-3) A42(0-3) A43(0-3)
+
+                    const __m256i lhs_mat_01_41_sp1 = _mm256_shuffle_epi32(lhs_mat_01_41, 160); //A40(8-11) A40(8-11) A41(8-11) A41(8-11) A40(8-11) A40(8-11) A41(8-11) A41(8-11)
+                    const __m256i lhs_mat_23_41_sp1 = _mm256_shuffle_epi32(lhs_mat_23_41, 160); //A42(8-11) A43(8-11) A42(8-11) A43(8-11) A42(8-11) A43(8-11) A42(8-11) A43(8-11)
+
+                    const __m256i lhs_mat_01_50_sp1 = _mm256_shuffle_epi32(lhs_mat_01_50, 160); //A50(0-3) A50(0-3) A51(0-3) A51(0-3) A50(0-3) A50(0-3) A51(0-3) A51(0-3)
+                    const __m256i lhs_mat_23_50_sp1 = _mm256_shuffle_epi32(lhs_mat_23_50, 160); //A52(0-3) A53(0-3) A52(0-3) A53(0-3) A52(0-3) A53(0-3) A52(0-3) A53(0-3)
+
+                    const __m256i lhs_mat_01_51_sp1 = _mm256_shuffle_epi32(lhs_mat_01_51, 160); //A50(8-11) A50(8-11) A51(8-11) A51(8-11) A50(8-11) A50(8-11) A51(8-11) A51(8-11)
+                    const __m256i lhs_mat_23_51_sp1 = _mm256_shuffle_epi32(lhs_mat_23_51, 160); //A52(8-11) A53(8-11) A52(8-11) A53(8-11) A52(8-11) A53(8-11) A52(8-11) A53(8-11)
+
+                    const __m256i lhs_mat_01_60_sp1 = _mm256_shuffle_epi32(lhs_mat_01_60, 160); //A60(0-3) A60(0-3) A61(0-3) A61(0-3) A60(0-3) A60(0-3) A61(0-3) A61(0-3)
+                    const __m256i lhs_mat_23_60_sp1 = _mm256_shuffle_epi32(lhs_mat_23_60, 160); //A62(0-3) A63(0-3) A62(0-3) A63(0-3) A62(0-3) A63(0-3) A62(0-3) A63(0-3)
+
+                    const __m256i lhs_mat_01_61_sp1 = _mm256_shuffle_epi32(lhs_mat_01_61, 160); //A60(8-11) A60(8-11) A61(8-11) A61(8-11) A60(8-11) A60(8-11) A61(8-11) A61(8-11)
+                    const __m256i lhs_mat_23_61_sp1 = _mm256_shuffle_epi32(lhs_mat_23_61, 160); //A62(8-11) A63(8-11) A62(8-11) A63(8-11) A62(8-11) A63(8-11) A62(8-11) A63(8-11)
+
+                    const __m256i lhs_mat_01_70_sp1 = _mm256_shuffle_epi32(lhs_mat_01_70, 160); //A70(0-3) A70(0-3) A71(0-3) A71(0-3) A70(0-3) A70(0-3) A71(0-3) A71(0-3)
+                    const __m256i lhs_mat_23_70_sp1 = _mm256_shuffle_epi32(lhs_mat_23_70, 160); //A72(0-3) A73(0-3) A72(0-3) A73(0-3) A72(0-3) A73(0-3) A72(0-3) A73(0-3)
+
+                    const __m256i lhs_mat_01_71_sp1 = _mm256_shuffle_epi32(lhs_mat_01_71, 160); //A70(8-11) A70(8-11) A71(8-11) A71(8-11) A70(8-11) A70(8-11) A71(8-11) A71(8-11)
+                    const __m256i lhs_mat_23_71_sp1 = _mm256_shuffle_epi32(lhs_mat_23_71, 160); //A72(8-11) A73(8-11) A72(8-11) A73(8-11) A72(8-11) A73(8-11) A72(8-11) A73(8-11)
 
                     // Shuffle pattern two- left side input
                     const __m256i lhs_mat_01_00_sp2 = _mm256_shuffle_epi32(lhs_mat_01_00, 245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7)
@@ -3112,44 +6012,147 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const __m256i lhs_mat_01_01_sp2 = _mm256_shuffle_epi32(lhs_mat_01_01, 245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15)
                     const __m256i lhs_mat_23_01_sp2 = _mm256_shuffle_epi32(lhs_mat_23_01, 245); //A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15) A02(12-15) A03(12-15)
 
-                    const __m256i lhs_mat_01_02_sp2 = _mm256_shuffle_epi32(lhs_mat_01_02, 245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23)
-                    const __m256i lhs_mat_23_02_sp2 = _mm256_shuffle_epi32(lhs_mat_23_02, 245); //A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23) A02(20-23) A03(20-23)
-
-                    const __m256i lhs_mat_01_03_sp2 = _mm256_shuffle_epi32(lhs_mat_01_03, 245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31)
-                    const __m256i lhs_mat_23_03_sp2 = _mm256_shuffle_epi32(lhs_mat_23_03, 245); //A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31) A02(28-31) A03(28-31)
-
                     const __m256i lhs_mat_01_10_sp2 = _mm256_shuffle_epi32(lhs_mat_01_10, 245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7)
                     const __m256i lhs_mat_23_10_sp2 = _mm256_shuffle_epi32(lhs_mat_23_10, 245); //A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7) A12(4-7) A13(4-7)
 
                     const __m256i lhs_mat_01_11_sp2 = _mm256_shuffle_epi32(lhs_mat_01_11, 245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15)
                     const __m256i lhs_mat_23_11_sp2 = _mm256_shuffle_epi32(lhs_mat_23_11, 245); //A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15) A12(12-15) A13(12-15)
 
-                    const __m256i lhs_mat_01_12_sp2 = _mm256_shuffle_epi32(lhs_mat_01_12, 245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23)
-                    const __m256i lhs_mat_23_12_sp2 = _mm256_shuffle_epi32(lhs_mat_23_12, 245); //A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23) A12(20-23) A13(20-23)
+                    const __m256i lhs_mat_01_20_sp2 = _mm256_shuffle_epi32(lhs_mat_01_20, 245); //A20(4-7) A20(4-7) A21(4-7) A21(4-7) A20(4-7) A20(4-7) A21(4-7) A21(4-7)
+                    const __m256i lhs_mat_23_20_sp2 = _mm256_shuffle_epi32(lhs_mat_23_20, 245); //A22(4-7) A23(4-7) A22(4-7) A23(4-7) A22(4-7) A23(4-7) A22(4-7) A23(4-7)
 
-                    const __m256i lhs_mat_01_13_sp2 = _mm256_shuffle_epi32(lhs_mat_01_13, 245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31)
-                    const __m256i lhs_mat_23_13_sp2 = _mm256_shuffle_epi32(lhs_mat_23_13, 245); //A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31) A12(28-31) A13(28-31)
+                    const __m256i lhs_mat_01_21_sp2 = _mm256_shuffle_epi32(lhs_mat_01_21, 245); //A20(12-15) A20(12-15) A21(12-15) A21(12-15) A20(12-15) A20(12-15) A21(12-15) A21(12-15)
+                    const __m256i lhs_mat_23_21_sp2 = _mm256_shuffle_epi32(lhs_mat_23_21, 245); //A22(12-15) A23(12-15) A22(12-15) A23(12-15) A22(12-15) A23(12-15) A22(12-15) A23(12-15)
+
+                    const __m256i lhs_mat_01_30_sp2 = _mm256_shuffle_epi32(lhs_mat_01_30, 245); //A30(4-7) A30(4-7) A31(4-7) A31(4-7) A30(4-7) A30(4-7) A31(4-7) A31(4-7)
+                    const __m256i lhs_mat_23_30_sp2 = _mm256_shuffle_epi32(lhs_mat_23_30, 245); //A32(4-7) A33(4-7) A32(4-7) A33(4-7) A32(4-7) A33(4-7) A32(4-7) A33(4-7)
+
+                    const __m256i lhs_mat_01_31_sp2 = _mm256_shuffle_epi32(lhs_mat_01_31, 245); //A30(12-15) A30(12-15) A31(12-15) A31(12-15) A30(12-15) A30(12-15) A31(12-15) A31(12-15)
+                    const __m256i lhs_mat_23_31_sp2 = _mm256_shuffle_epi32(lhs_mat_23_31, 245); //A32(12-15) A33(12-15) A32(12-15) A33(12-15) A32(12-15) A33(12-15) A32(12-15) A33(12-15)
+
+                    const __m256i lhs_mat_01_40_sp2 = _mm256_shuffle_epi32(lhs_mat_01_40, 245); //A40(4-7) A40(4-7) A41(4-7) A41(4-7) A40(4-7) A40(4-7) A41(4-7) A41(4-7)
+                    const __m256i lhs_mat_23_40_sp2 = _mm256_shuffle_epi32(lhs_mat_23_40, 245); //A42(4-7) A43(4-7) A42(4-7) A43(4-7) A42(4-7) A43(4-7) A42(4-7) A43(4-7)
+
+                    const __m256i lhs_mat_01_41_sp2 = _mm256_shuffle_epi32(lhs_mat_01_41, 245); //A40(12-15) A40(12-15) A41(12-15) A41(12-15) A40(12-15) A40(12-15) A41(12-15) A41(12-15)
+                    const __m256i lhs_mat_23_41_sp2 = _mm256_shuffle_epi32(lhs_mat_23_41, 245); //A42(12-15) A43(12-15) A42(12-15) A43(12-15) A42(12-15) A43(12-15) A42(12-15) A43(12-15)
+
+                    const __m256i lhs_mat_01_50_sp2 = _mm256_shuffle_epi32(lhs_mat_01_50, 245); //A50(4-7) A50(4-7) A51(4-7) A51(4-7) A50(4-7) A50(4-7) A51(4-7) A51(4-7)
+                    const __m256i lhs_mat_23_50_sp2 = _mm256_shuffle_epi32(lhs_mat_23_50, 245); //A52(4-7) A53(4-7) A52(4-7) A53(4-7) A52(4-7) A53(4-7) A52(4-7) A53(4-7)
+
+                    const __m256i lhs_mat_01_51_sp2 = _mm256_shuffle_epi32(lhs_mat_01_51, 245); //A50(12-15) A50(12-15) A51(12-15) A51(12-15) A50(12-15) A50(12-15) A51(12-15) A51(12-15)
+                    const __m256i lhs_mat_23_51_sp2 = _mm256_shuffle_epi32(lhs_mat_23_51, 245); //A52(12-15) A53(12-15) A52(12-15) A53(12-15) A52(12-15) A53(12-15) A52(12-15) A53(12-15)
+
+                    const __m256i lhs_mat_01_60_sp2 = _mm256_shuffle_epi32(lhs_mat_01_60, 245); //A60(4-7) A60(4-7) A61(4-7) A61(4-7) A60(4-7) A60(4-7) A61(4-7) A61(4-7)
+                    const __m256i lhs_mat_23_60_sp2 = _mm256_shuffle_epi32(lhs_mat_23_60, 245); //A62(4-7) A63(4-7) A62(4-7) A63(4-7) A62(4-7) A63(4-7) A62(4-7) A63(4-7)
+
+                    const __m256i lhs_mat_01_61_sp2 = _mm256_shuffle_epi32(lhs_mat_01_61, 245); //A60(12-15) A60(12-15) A61(12-15) A61(12-15) A60(12-15) A60(12-15) A61(12-15) A61(12-15)
+                    const __m256i lhs_mat_23_61_sp2 = _mm256_shuffle_epi32(lhs_mat_23_61, 245); //A62(12-15) A63(12-15) A62(12-15) A63(12-15) A62(12-15) A63(12-15) A62(12-15) A63(12-15)
+
+                    const __m256i lhs_mat_01_70_sp2 = _mm256_shuffle_epi32(lhs_mat_01_70, 245); //A70(4-7) A70(4-7) A71(4-7) A71(4-7) A70(4-7) A70(4-7) A71(4-7) A71(4-7)
+                    const __m256i lhs_mat_23_70_sp2 = _mm256_shuffle_epi32(lhs_mat_23_70, 245); //A72(4-7) A73(4-7) A72(4-7) A73(4-7) A72(4-7) A73(4-7) A72(4-7) A73(4-7)
+
+                    const __m256i lhs_mat_01_71_sp2 = _mm256_shuffle_epi32(lhs_mat_01_71, 245); //A70(12-15) A70(12-15) A71(12-15) A71(12-15) A70(12-15) A70(12-15) A71(12-15) A71(12-15)
+                    const __m256i lhs_mat_23_71_sp2 = _mm256_shuffle_epi32(lhs_mat_23_71, 245); //A72(12-15) A73(12-15) A72(12-15) A73(12-15) A72(12-15) A73(12-15) A72(12-15) A73(12-15)
 
                     // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
-                    __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1));
-                    __m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_01_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_01_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1));
-                    __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_0145_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1));
-                    __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp1, lhs_mat_23_03_sp1), _mm256_maddubs_epi16(rhs_mat_2367_02_sp1, lhs_mat_23_02_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1));
-                    __m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1));
-                    __m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_01_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_01_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1));
-                    __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_0145_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1));
-                    __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp1, lhs_mat_23_13_sp1), _mm256_maddubs_epi16(rhs_mat_2367_12_sp1, lhs_mat_23_12_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1));
+                    __m256i iacc_mat_00_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_01_00_sp1),_mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_01_01_sp1));
+                    __m256i iacc_mat_01_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_01_00_sp1),_mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_01_01_sp1));
 
-                    __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2));
-                    __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_01_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_01_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2));
-                    __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_0145_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2));
-                    __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_03_sp2, lhs_mat_23_03_sp2), _mm256_maddubs_epi16(rhs_mat_2367_02_sp2, lhs_mat_23_02_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2));
-                    __m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2));
-                    __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_01_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_01_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2));
-                    __m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_0145_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2));
-                    __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_add_epi16(_mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_13_sp2, lhs_mat_23_13_sp2), _mm256_maddubs_epi16(rhs_mat_2367_12_sp2, lhs_mat_23_12_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2)), _mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2));
+                    __m256i iacc_mat_10_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp1, lhs_mat_23_00_sp1),_mm256_maddubs_epi16(rhs_mat_0145_01_sp1, lhs_mat_23_01_sp1));
+                    __m256i iacc_mat_11_0_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp1, lhs_mat_23_00_sp1),_mm256_maddubs_epi16(rhs_mat_2367_01_sp1, lhs_mat_23_01_sp1));
 
-                    // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
+                    __m256i iacc_mat_00_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_01_10_sp1),_mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_01_11_sp1));
+                    __m256i iacc_mat_01_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_01_10_sp1),_mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_01_11_sp1));
+
+                    __m256i iacc_mat_10_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp1, lhs_mat_23_10_sp1),_mm256_maddubs_epi16(rhs_mat_0145_11_sp1, lhs_mat_23_11_sp1));
+                    __m256i iacc_mat_11_1_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp1, lhs_mat_23_10_sp1),_mm256_maddubs_epi16(rhs_mat_2367_11_sp1, lhs_mat_23_11_sp1));
+
+                    __m256i iacc_mat_00_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp1, lhs_mat_01_20_sp1),_mm256_maddubs_epi16(rhs_mat_0145_21_sp1, lhs_mat_01_21_sp1));
+                    __m256i iacc_mat_01_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp1, lhs_mat_01_20_sp1),_mm256_maddubs_epi16(rhs_mat_2367_21_sp1, lhs_mat_01_21_sp1));
+
+                    __m256i iacc_mat_10_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp1, lhs_mat_23_20_sp1),_mm256_maddubs_epi16(rhs_mat_0145_21_sp1, lhs_mat_23_21_sp1));
+                    __m256i iacc_mat_11_2_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp1, lhs_mat_23_20_sp1),_mm256_maddubs_epi16(rhs_mat_2367_21_sp1, lhs_mat_23_21_sp1));
+
+                    __m256i iacc_mat_00_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp1, lhs_mat_01_30_sp1),_mm256_maddubs_epi16(rhs_mat_0145_31_sp1, lhs_mat_01_31_sp1));
+                    __m256i iacc_mat_01_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp1, lhs_mat_01_30_sp1),_mm256_maddubs_epi16(rhs_mat_2367_31_sp1, lhs_mat_01_31_sp1));
+
+                    __m256i iacc_mat_10_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp1, lhs_mat_23_30_sp1),_mm256_maddubs_epi16(rhs_mat_0145_31_sp1, lhs_mat_23_31_sp1));
+                    __m256i iacc_mat_11_3_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp1, lhs_mat_23_30_sp1),_mm256_maddubs_epi16(rhs_mat_2367_31_sp1, lhs_mat_23_31_sp1));
+
+                    __m256i iacc_mat_00_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp1, lhs_mat_01_40_sp1),_mm256_maddubs_epi16(rhs_mat_0145_41_sp1, lhs_mat_01_41_sp1));
+                    __m256i iacc_mat_01_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp1, lhs_mat_01_40_sp1),_mm256_maddubs_epi16(rhs_mat_2367_41_sp1, lhs_mat_01_41_sp1));
+
+                    __m256i iacc_mat_10_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp1, lhs_mat_23_40_sp1),_mm256_maddubs_epi16(rhs_mat_0145_41_sp1, lhs_mat_23_41_sp1));
+                    __m256i iacc_mat_11_4_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp1, lhs_mat_23_40_sp1),_mm256_maddubs_epi16(rhs_mat_2367_41_sp1, lhs_mat_23_41_sp1));
+
+                    __m256i iacc_mat_00_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp1, lhs_mat_01_50_sp1),_mm256_maddubs_epi16(rhs_mat_0145_51_sp1, lhs_mat_01_51_sp1));
+                    __m256i iacc_mat_01_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp1, lhs_mat_01_50_sp1),_mm256_maddubs_epi16(rhs_mat_2367_51_sp1, lhs_mat_01_51_sp1));
+
+                    __m256i iacc_mat_10_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp1, lhs_mat_23_50_sp1),_mm256_maddubs_epi16(rhs_mat_0145_51_sp1, lhs_mat_23_51_sp1));
+                    __m256i iacc_mat_11_5_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp1, lhs_mat_23_50_sp1),_mm256_maddubs_epi16(rhs_mat_2367_51_sp1, lhs_mat_23_51_sp1));
+
+                    __m256i iacc_mat_00_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp1, lhs_mat_01_60_sp1),_mm256_maddubs_epi16(rhs_mat_0145_61_sp1, lhs_mat_01_61_sp1));
+                    __m256i iacc_mat_01_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp1, lhs_mat_01_60_sp1),_mm256_maddubs_epi16(rhs_mat_2367_61_sp1, lhs_mat_01_61_sp1));
+
+                    __m256i iacc_mat_10_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp1, lhs_mat_23_60_sp1),_mm256_maddubs_epi16(rhs_mat_0145_61_sp1, lhs_mat_23_61_sp1));
+                    __m256i iacc_mat_11_6_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp1, lhs_mat_23_60_sp1),_mm256_maddubs_epi16(rhs_mat_2367_61_sp1, lhs_mat_23_61_sp1));
+
+                    __m256i iacc_mat_00_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp1, lhs_mat_01_70_sp1),_mm256_maddubs_epi16(rhs_mat_0145_71_sp1, lhs_mat_01_71_sp1));
+                    __m256i iacc_mat_01_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp1, lhs_mat_01_70_sp1),_mm256_maddubs_epi16(rhs_mat_2367_71_sp1, lhs_mat_01_71_sp1));
+
+                    __m256i iacc_mat_10_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp1, lhs_mat_23_70_sp1),_mm256_maddubs_epi16(rhs_mat_0145_71_sp1, lhs_mat_23_71_sp1));
+                    __m256i iacc_mat_11_7_sp1 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp1, lhs_mat_23_70_sp1),_mm256_maddubs_epi16(rhs_mat_2367_71_sp1, lhs_mat_23_71_sp1));
+
+
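+                    // Repeat the same per-sub-block dot products using the second shuffle pattern (sp2) of the interleaved lhs/rhs data.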
+                    __m256i iacc_mat_00_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_01_00_sp2),_mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_01_01_sp2));
+                    __m256i iacc_mat_01_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_01_00_sp2),_mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_01_01_sp2));
+
+                    __m256i iacc_mat_10_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_00_sp2, lhs_mat_23_00_sp2),_mm256_maddubs_epi16(rhs_mat_0145_01_sp2, lhs_mat_23_01_sp2));
+                    __m256i iacc_mat_11_0_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_00_sp2, lhs_mat_23_00_sp2),_mm256_maddubs_epi16(rhs_mat_2367_01_sp2, lhs_mat_23_01_sp2));
+
+                    __m256i iacc_mat_00_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_01_10_sp2),_mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_01_11_sp2));
+                    __m256i iacc_mat_01_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_01_10_sp2),_mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_01_11_sp2));
+
+                    __m256i iacc_mat_10_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_10_sp2, lhs_mat_23_10_sp2),_mm256_maddubs_epi16(rhs_mat_0145_11_sp2, lhs_mat_23_11_sp2));
+                    __m256i iacc_mat_11_1_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_10_sp2, lhs_mat_23_10_sp2),_mm256_maddubs_epi16(rhs_mat_2367_11_sp2, lhs_mat_23_11_sp2));
+
+                    __m256i iacc_mat_00_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp2, lhs_mat_01_20_sp2),_mm256_maddubs_epi16(rhs_mat_0145_21_sp2, lhs_mat_01_21_sp2));
+                    __m256i iacc_mat_01_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp2, lhs_mat_01_20_sp2),_mm256_maddubs_epi16(rhs_mat_2367_21_sp2, lhs_mat_01_21_sp2));
+
+                    __m256i iacc_mat_10_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_20_sp2, lhs_mat_23_20_sp2),_mm256_maddubs_epi16(rhs_mat_0145_21_sp2, lhs_mat_23_21_sp2));
+                    __m256i iacc_mat_11_2_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_20_sp2, lhs_mat_23_20_sp2),_mm256_maddubs_epi16(rhs_mat_2367_21_sp2, lhs_mat_23_21_sp2));
+
+                    __m256i iacc_mat_00_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp2, lhs_mat_01_30_sp2),_mm256_maddubs_epi16(rhs_mat_0145_31_sp2, lhs_mat_01_31_sp2));
+                    __m256i iacc_mat_01_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp2, lhs_mat_01_30_sp2),_mm256_maddubs_epi16(rhs_mat_2367_31_sp2, lhs_mat_01_31_sp2));
+
+                    __m256i iacc_mat_10_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_30_sp2, lhs_mat_23_30_sp2),_mm256_maddubs_epi16(rhs_mat_0145_31_sp2, lhs_mat_23_31_sp2));
+                    __m256i iacc_mat_11_3_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_30_sp2, lhs_mat_23_30_sp2),_mm256_maddubs_epi16(rhs_mat_2367_31_sp2, lhs_mat_23_31_sp2));
+
+                    __m256i iacc_mat_00_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp2, lhs_mat_01_40_sp2),_mm256_maddubs_epi16(rhs_mat_0145_41_sp2, lhs_mat_01_41_sp2));
+                    __m256i iacc_mat_01_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp2, lhs_mat_01_40_sp2),_mm256_maddubs_epi16(rhs_mat_2367_41_sp2, lhs_mat_01_41_sp2));
+
+                    __m256i iacc_mat_10_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_40_sp2, lhs_mat_23_40_sp2),_mm256_maddubs_epi16(rhs_mat_0145_41_sp2, lhs_mat_23_41_sp2));
+                    __m256i iacc_mat_11_4_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_40_sp2, lhs_mat_23_40_sp2),_mm256_maddubs_epi16(rhs_mat_2367_41_sp2, lhs_mat_23_41_sp2));
+
+                    __m256i iacc_mat_00_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp2, lhs_mat_01_50_sp2),_mm256_maddubs_epi16(rhs_mat_0145_51_sp2, lhs_mat_01_51_sp2));
+                    __m256i iacc_mat_01_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp2, lhs_mat_01_50_sp2),_mm256_maddubs_epi16(rhs_mat_2367_51_sp2, lhs_mat_01_51_sp2));
+
+                    __m256i iacc_mat_10_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_50_sp2, lhs_mat_23_50_sp2),_mm256_maddubs_epi16(rhs_mat_0145_51_sp2, lhs_mat_23_51_sp2));
+                    __m256i iacc_mat_11_5_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_50_sp2, lhs_mat_23_50_sp2),_mm256_maddubs_epi16(rhs_mat_2367_51_sp2, lhs_mat_23_51_sp2));
+
+                    __m256i iacc_mat_00_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp2, lhs_mat_01_60_sp2),_mm256_maddubs_epi16(rhs_mat_0145_61_sp2, lhs_mat_01_61_sp2));
+                    __m256i iacc_mat_01_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp2, lhs_mat_01_60_sp2),_mm256_maddubs_epi16(rhs_mat_2367_61_sp2, lhs_mat_01_61_sp2));
+
+                    __m256i iacc_mat_10_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_60_sp2, lhs_mat_23_60_sp2),_mm256_maddubs_epi16(rhs_mat_0145_61_sp2, lhs_mat_23_61_sp2));
+                    __m256i iacc_mat_11_6_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_60_sp2, lhs_mat_23_60_sp2),_mm256_maddubs_epi16(rhs_mat_2367_61_sp2, lhs_mat_23_61_sp2));
+
+                    __m256i iacc_mat_00_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp2, lhs_mat_01_70_sp2),_mm256_maddubs_epi16(rhs_mat_0145_71_sp2, lhs_mat_01_71_sp2));
+                    __m256i iacc_mat_01_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp2, lhs_mat_01_70_sp2),_mm256_maddubs_epi16(rhs_mat_2367_71_sp2, lhs_mat_01_71_sp2));
+
+                    __m256i iacc_mat_10_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_0145_70_sp2, lhs_mat_23_70_sp2),_mm256_maddubs_epi16(rhs_mat_0145_71_sp2, lhs_mat_23_71_sp2));
+                    __m256i iacc_mat_11_7_sp2 = _mm256_add_epi16(_mm256_maddubs_epi16(rhs_mat_2367_70_sp2, lhs_mat_23_70_sp2),_mm256_maddubs_epi16(rhs_mat_2367_71_sp2, lhs_mat_23_71_sp2));
+
+                    // Combine results from both shuffle patterns for each output block.
                     __m256i iacc_mat_00_0 = _mm256_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2);
                     __m256i iacc_mat_01_0 = _mm256_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2);
                     __m256i iacc_mat_10_0 = _mm256_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2);
@@ -3160,6 +6163,36 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     __m256i iacc_mat_10_1 = _mm256_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2);
                     __m256i iacc_mat_11_1 = _mm256_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2);
 
+                    __m256i iacc_mat_00_2 = _mm256_add_epi16(iacc_mat_00_2_sp1, iacc_mat_00_2_sp2);
+                    __m256i iacc_mat_01_2 = _mm256_add_epi16(iacc_mat_01_2_sp1, iacc_mat_01_2_sp2);
+                    __m256i iacc_mat_10_2 = _mm256_add_epi16(iacc_mat_10_2_sp1, iacc_mat_10_2_sp2);
+                    __m256i iacc_mat_11_2 = _mm256_add_epi16(iacc_mat_11_2_sp1, iacc_mat_11_2_sp2);
+
+                    __m256i iacc_mat_00_3 = _mm256_add_epi16(iacc_mat_00_3_sp1, iacc_mat_00_3_sp2);
+                    __m256i iacc_mat_01_3 = _mm256_add_epi16(iacc_mat_01_3_sp1, iacc_mat_01_3_sp2);
+                    __m256i iacc_mat_10_3 = _mm256_add_epi16(iacc_mat_10_3_sp1, iacc_mat_10_3_sp2);
+                    __m256i iacc_mat_11_3 = _mm256_add_epi16(iacc_mat_11_3_sp1, iacc_mat_11_3_sp2);
+
+                    __m256i iacc_mat_00_4 = _mm256_add_epi16(iacc_mat_00_4_sp1, iacc_mat_00_4_sp2);
+                    __m256i iacc_mat_01_4 = _mm256_add_epi16(iacc_mat_01_4_sp1, iacc_mat_01_4_sp2);
+                    __m256i iacc_mat_10_4 = _mm256_add_epi16(iacc_mat_10_4_sp1, iacc_mat_10_4_sp2);
+                    __m256i iacc_mat_11_4 = _mm256_add_epi16(iacc_mat_11_4_sp1, iacc_mat_11_4_sp2);
+
+                    __m256i iacc_mat_00_5 = _mm256_add_epi16(iacc_mat_00_5_sp1, iacc_mat_00_5_sp2);
+                    __m256i iacc_mat_01_5 = _mm256_add_epi16(iacc_mat_01_5_sp1, iacc_mat_01_5_sp2);
+                    __m256i iacc_mat_10_5 = _mm256_add_epi16(iacc_mat_10_5_sp1, iacc_mat_10_5_sp2);
+                    __m256i iacc_mat_11_5 = _mm256_add_epi16(iacc_mat_11_5_sp1, iacc_mat_11_5_sp2);
+
+                    __m256i iacc_mat_00_6 = _mm256_add_epi16(iacc_mat_00_6_sp1, iacc_mat_00_6_sp2);
+                    __m256i iacc_mat_01_6 = _mm256_add_epi16(iacc_mat_01_6_sp1, iacc_mat_01_6_sp2);
+                    __m256i iacc_mat_10_6 = _mm256_add_epi16(iacc_mat_10_6_sp1, iacc_mat_10_6_sp2);
+                    __m256i iacc_mat_11_6 = _mm256_add_epi16(iacc_mat_11_6_sp1, iacc_mat_11_6_sp2);
+
+                    __m256i iacc_mat_00_7 = _mm256_add_epi16(iacc_mat_00_7_sp1, iacc_mat_00_7_sp2);
+                    __m256i iacc_mat_01_7 = _mm256_add_epi16(iacc_mat_01_7_sp1, iacc_mat_01_7_sp2);
+                    __m256i iacc_mat_10_7 = _mm256_add_epi16(iacc_mat_10_7_sp1, iacc_mat_10_7_sp2);
+                    __m256i iacc_mat_11_7 = _mm256_add_epi16(iacc_mat_11_7_sp1, iacc_mat_11_7_sp2);
+
                     // Outputs of both shuffle patterns are added in order to sum the dot product outputs of all 32 values in the block
                     iacc_mat_00_0 = _mm256_madd_epi16(iacc_mat_00_0, scale_0145_0);
                     iacc_mat_01_0 = _mm256_madd_epi16(iacc_mat_01_0, scale_2367_0);
@@ -3171,24 +6204,50 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     iacc_mat_10_1 = _mm256_madd_epi16(iacc_mat_10_1, scale_0145_1);
                     iacc_mat_11_1 = _mm256_madd_epi16(iacc_mat_11_1, scale_2367_1);
 
-                    // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step)
-                    __m256i iacc_row_0_0 = _mm256_blend_epi32(iacc_mat_00_0, _mm256_shuffle_epi32(iacc_mat_01_0, 78), 204);
-                    __m256i iacc_row_1_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_0, 78), iacc_mat_01_0, 204);
-                    __m256i iacc_row_2_0 = _mm256_blend_epi32(iacc_mat_10_0, _mm256_shuffle_epi32(iacc_mat_11_0, 78), 204);
-                    __m256i iacc_row_3_0 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_0, 78), iacc_mat_11_0, 204);
-                    __m256i iacc_row_0_1 = _mm256_blend_epi32(iacc_mat_00_1, _mm256_shuffle_epi32(iacc_mat_01_1, 78), 204);
-                    __m256i iacc_row_1_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00_1, 78), iacc_mat_01_1, 204);
-                    __m256i iacc_row_2_1 = _mm256_blend_epi32(iacc_mat_10_1, _mm256_shuffle_epi32(iacc_mat_11_1, 78), 204);
-                    __m256i iacc_row_3_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10_1, 78), iacc_mat_11_1, 204);
+                    iacc_mat_00_2 = _mm256_madd_epi16(iacc_mat_00_2, scale_0145_2);
+                    iacc_mat_01_2 = _mm256_madd_epi16(iacc_mat_01_2, scale_2367_2);
+                    iacc_mat_10_2 = _mm256_madd_epi16(iacc_mat_10_2, scale_0145_2);
+                    iacc_mat_11_2 = _mm256_madd_epi16(iacc_mat_11_2, scale_2367_2);
+
+                    iacc_mat_00_3 = _mm256_madd_epi16(iacc_mat_00_3, scale_0145_3);
+                    iacc_mat_01_3 = _mm256_madd_epi16(iacc_mat_01_3, scale_2367_3);
+                    iacc_mat_10_3 = _mm256_madd_epi16(iacc_mat_10_3, scale_0145_3);
+                    iacc_mat_11_3 = _mm256_madd_epi16(iacc_mat_11_3, scale_2367_3);
+
+                    iacc_mat_00_4 = _mm256_madd_epi16(iacc_mat_00_4, scale_0145_4);
+                    iacc_mat_01_4 = _mm256_madd_epi16(iacc_mat_01_4, scale_2367_4);
+                    iacc_mat_10_4 = _mm256_madd_epi16(iacc_mat_10_4, scale_0145_4);
+                    iacc_mat_11_4 = _mm256_madd_epi16(iacc_mat_11_4, scale_2367_4);
+
+                    iacc_mat_00_5 = _mm256_madd_epi16(iacc_mat_00_5, scale_0145_5);
+                    iacc_mat_01_5 = _mm256_madd_epi16(iacc_mat_01_5, scale_2367_5);
+                    iacc_mat_10_5 = _mm256_madd_epi16(iacc_mat_10_5, scale_0145_5);
+                    iacc_mat_11_5 = _mm256_madd_epi16(iacc_mat_11_5, scale_2367_5);
+
+                    iacc_mat_00_6 = _mm256_madd_epi16(iacc_mat_00_6, scale_0145_6);
+                    iacc_mat_01_6 = _mm256_madd_epi16(iacc_mat_01_6, scale_2367_6);
+                    iacc_mat_10_6 = _mm256_madd_epi16(iacc_mat_10_6, scale_0145_6);
+                    iacc_mat_11_6 = _mm256_madd_epi16(iacc_mat_11_6, scale_2367_6);
+
+                    iacc_mat_00_7 = _mm256_madd_epi16(iacc_mat_00_7, scale_0145_7);
+                    iacc_mat_01_7 = _mm256_madd_epi16(iacc_mat_01_7, scale_2367_7);
+                    iacc_mat_10_7 = _mm256_madd_epi16(iacc_mat_10_7, scale_0145_7);
+                    iacc_mat_11_7 = _mm256_madd_epi16(iacc_mat_11_7, scale_2367_7);
+
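+                    // Accumulate the scaled results of all 8 sub-blocks into a single 32-bit accumulator per output tile.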
+                    __m256i iacc_mat_00 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_00_0, iacc_mat_00_1), _mm256_add_epi32(iacc_mat_00_2, iacc_mat_00_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_00_4, iacc_mat_00_5), _mm256_add_epi32(iacc_mat_00_6, iacc_mat_00_7)));
+                    __m256i iacc_mat_01 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_01_0, iacc_mat_01_1), _mm256_add_epi32(iacc_mat_01_2, iacc_mat_01_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_01_4, iacc_mat_01_5), _mm256_add_epi32(iacc_mat_01_6, iacc_mat_01_7)));
+                    __m256i iacc_mat_10 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_10_0, iacc_mat_10_1), _mm256_add_epi32(iacc_mat_10_2, iacc_mat_10_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_10_4, iacc_mat_10_5), _mm256_add_epi32(iacc_mat_10_6, iacc_mat_10_7)));
+                    __m256i iacc_mat_11 = _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(iacc_mat_11_0, iacc_mat_11_1), _mm256_add_epi32(iacc_mat_11_2, iacc_mat_11_3)), _mm256_add_epi32(_mm256_add_epi32(iacc_mat_11_4, iacc_mat_11_5), _mm256_add_epi32(iacc_mat_11_6, iacc_mat_11_7)));
 
-                    __m256i iacc_row_0 = _mm256_add_epi32(iacc_row_0_0, iacc_row_0_1);
-                    __m256i iacc_row_1 = _mm256_add_epi32(iacc_row_1_0, iacc_row_1_1);
-                    __m256i iacc_row_2 = _mm256_add_epi32(iacc_row_2_0, iacc_row_2_1);
-                    __m256i iacc_row_3 = _mm256_add_epi32(iacc_row_3_0, iacc_row_3_1);
+                    // Straighten out to make 4 row vectors
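+                    // (shuffle control 78 swaps the 64-bit halves of each 128-bit lane; blend mask 204 = 0b11001100 then interleaves 64-bit pairs from the two tile accumulators into per-row vectors)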
+                    __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
+                    __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
+                    __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
+                    __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
 
                     // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes
                     const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);
-                    const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); //GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
+                    const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);
 
                     // Multiply with appropriate scales and accumulate (for both d and dmin) below
                     acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
@@ -3196,10 +6255,36 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
                     acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
 
-                    __m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);
-                    __m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);
-                    __m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);
-                    __m256i iacc_row_min_3 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 255), mins_01);
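+                    // Broadcast the 128-bit bsums of rows 0-1 and rows 2-3 into both halves of 256-bit registers.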
+                    __m256i lhs_bsums_01_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_0123), lhs_raw_bsums_01_0123, 1);
+                    __m256i lhs_bsums_23_0123 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_0123), lhs_raw_bsums_23_0123, 1);
+                    __m256i lhs_bsums_01_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_01_4567), lhs_raw_bsums_01_4567, 1);
+                    __m256i lhs_bsums_23_4567 = _mm256_inserti128_si256(_mm256_castsi128_si256(lhs_raw_bsums_23_4567), lhs_raw_bsums_23_4567, 1);
+
+                    // Take two bsums from two Q8_K blocks at a time and multiply them with the corresponding mins values from each Q2_K
+                    __m256i iacc_row_min_0_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 0), mins_01);
+                    __m256i iacc_row_min_1_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 170), mins_01);
+                    __m256i iacc_row_min_2_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 0), mins_01);
+                    __m256i iacc_row_min_3_01 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 170), mins_01);
+
+                    __m256i iacc_row_min_0_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 85), mins_23);
+                    __m256i iacc_row_min_1_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_0123, 255), mins_23);
+                    __m256i iacc_row_min_2_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 85), mins_23);
+                    __m256i iacc_row_min_3_23 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_0123, 255), mins_23);
+
+                    __m256i iacc_row_min_0_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 0), mins_45);
+                    __m256i iacc_row_min_1_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 170), mins_45);
+                    __m256i iacc_row_min_2_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 0), mins_45);
+                    __m256i iacc_row_min_3_45 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 170), mins_45);
+
+                    __m256i iacc_row_min_0_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 85), mins_67);
+                    __m256i iacc_row_min_1_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_01_4567, 255), mins_67);
+                    __m256i iacc_row_min_2_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 85), mins_67);
+                    __m256i iacc_row_min_3_67 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_23_4567, 255), mins_67);
+
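+                    // Sum the min compensation partial products over all sub-block groups for each row.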
+                    __m256i iacc_row_min_0 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_0_01, iacc_row_min_0_23), _mm256_add_epi32(iacc_row_min_0_45,iacc_row_min_0_67));
+                    __m256i iacc_row_min_1 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_1_01, iacc_row_min_1_23), _mm256_add_epi32(iacc_row_min_1_45,iacc_row_min_1_67));
+                    __m256i iacc_row_min_2 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_2_01, iacc_row_min_2_23), _mm256_add_epi32(iacc_row_min_2_45,iacc_row_min_2_67));
+                    __m256i iacc_row_min_3 = _mm256_add_epi32(_mm256_add_epi32(iacc_row_min_3_01, iacc_row_min_3_23), _mm256_add_epi32(iacc_row_min_3_45,iacc_row_min_3_67));
 
                     acc_min_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_0), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]);
                     acc_min_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]);
@@ -3207,79 +6292,16 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     acc_min_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);
                 }
             }
-
             // Store the accumulated values
             for (int i = 0; i < 4; i++) {
                 _mm256_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm256_sub_ps(acc_rows[i], acc_min_rows[i]));
             }
         }
     }
-
 #else
 
-    float sumf[4][8];
-    float sum_minf[4][8];
-    uint32_t utmp[32];
-    int sumi1;
-    int sumi2;
-    int sumi;
-
-    for (int y = 0; y < nr / 4; y++) {
-        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++) {
-                    sumf[m][j] = 0.0;
-                    sum_minf[m][j] = 0.0;
-                }
-            }
-            for (int l = 0; l < nb; l++) {
-                for (int sb = 0; sb < 8; sb++) {
-                    memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
-                    utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
-                    const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
-                    utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
-                    utmp[sb * 4 + 2] = uaux_0;
-                    utmp[sb * 4 + 0] &= kmask1;
-                }
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;
-                    uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;
-                    for (int m = 0; m < 4; m++) {
-                        for (int j = 0; j < ncols_interleaved; j++) {
-                            sumi1 = 0;
-                            sumi2 = 0;
-                            sumi = 0;
-                            for (int i = 0; i < blocklen; ++i) {
-                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
-                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
-                                sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i]);
-                                sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
-                                sumi1 = sumi1 * scales_0[j];
-                                sumi2 = sumi2 * scales_1[j];
-                                sumi += sumi1 + sumi2;
-                            }
-                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
-                        }
-                    }
-                }
-                for (int sb = 0; sb < 8; sb++) {
-                    uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;
-                    for(int m = 0; m < 4; m++) {
-                        const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
-                        for(int j = 0; j < ncols_interleaved; j++) {
-                            sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
-                        }
-                    }
-                }
-            }
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++) {
-                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
-                }
-            }
-        }
-    }
+    ggml_gemm_q2_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+
 #endif
 }
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index c5271b77572..f6bea3df34a 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -253,6 +253,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_1,
         .nrows                    = 1,
     },
+    [GGML_TYPE_MXFP4] = {
+        .from_float               = quantize_row_mxfp4,
+        .vec_dot                  = ggml_vec_dot_mxfp4_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
+    },
     [GGML_TYPE_Q2_K] = {
         .from_float               = quantize_row_q2_K,
         .vec_dot                  = ggml_vec_dot_q2_K_q8_K,
@@ -1670,6 +1676,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add(params, tensor);
             } break;
+        case GGML_OP_ADD_ID:
+            {
+                ggml_compute_forward_add_id(params, tensor);
+            } break;
         case GGML_OP_ADD1:
             {
                 ggml_compute_forward_add1(params, tensor);
@@ -1924,7 +1934,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+                ggml_compute_forward_flash_attn_ext(params, tensor);
             } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
@@ -2012,6 +2022,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_opt_step_adamw(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            {
+                ggml_compute_forward_opt_step_sgd(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -2111,6 +2126,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_DUP:
         case GGML_OP_CONT:
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
         case GGML_OP_ACC:
             {
@@ -2172,6 +2188,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU_OAI:
                 case GGML_GLU_OP_GEGLU_ERF:
                 case GGML_GLU_OP_GEGLU_QUICK:
                     {
@@ -2313,6 +2330,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             {
                 n_tasks = n_threads;
             } break;
@@ -2673,6 +2691,7 @@ struct ggml_cplan ggml_graph_plan(
                         }
                     } break;
                 case GGML_OP_ADD:
+                case GGML_OP_ADD_ID:
                 case GGML_OP_ADD1:
                     {
                         if (ggml_is_quantized(node->src[0]->type)) {
diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp
index c9daa4c39e8..8dacd36714b 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.cpp
+++ b/ggml/src/ggml-cpu/ggml-cpu.cpp
@@ -35,7 +35,7 @@
 
 // ggml-backend interface
 
-std::vector& ggml_backend_cpu_get_extra_buffers_type() {
+std::vector & ggml_backend_cpu_get_extra_buffer_types() {
     static std::vector bufts = []() {
         std::vector bufts;
 
@@ -57,8 +57,6 @@ std::vector& ggml_backend_cpu_get_extra_buffers_type
         }
 #endif
 
-        bufts.push_back(NULL);
-
         return bufts;
     }();
 
@@ -66,14 +64,20 @@ std::vector& ggml_backend_cpu_get_extra_buffers_type
 }
 
 static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device) {
-    return ggml_backend_cpu_get_extra_buffers_type().data();
+    static std::vector extra_bufts = [] {
+        std::vector bufts = ggml_backend_cpu_get_extra_buffer_types();
+        bufts.push_back(nullptr);
+        return bufts;
+    }();
+
+    return extra_bufts.data();
 
     GGML_UNUSED(device);
 }
 
 static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) {
-    for (auto * extra : ggml_backend_cpu_get_extra_buffers_type()) {
-        if (extra && extra == buft) {
+    for (auto * extra : ggml_backend_cpu_get_extra_buffer_types()) {
+        if (extra == buft) {
             return true;
         }
     }
@@ -210,10 +214,10 @@ ggml_backend_t ggml_backend_cpu_init(void) {
     ctx->abort_callback_data = NULL;
 
     ggml_backend_t cpu_backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_cpu_guid(),
-        /* .interface = */ ggml_backend_cpu_i,
-        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
-        /* .context   = */ ctx,
+        /* .guid    = */ ggml_backend_cpu_guid(),
+        /* .iface   = */ ggml_backend_cpu_i,
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context = */ ctx,
     };
 
     if (cpu_backend == NULL) {
@@ -397,20 +401,13 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
         return true;
     }
 
-    // extra_buffer_op?
-    for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
-        if (extra) {
-            auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context;
-            if (buf_extra && buf_extra->supports_op(dev, op)) {
-                return true;
-            }
-        }
-    }
-
-    // the other case need host buffer.
-    for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (op->src[i] && op->src[i]->buffer && !ggml_backend_buft_is_host(op->src[i]->buffer->buft)) {
-            return false;
+    // check extra buffer types
+    // note: only the first 4 sources are checked for extra buffer types to reduce overhead; increase if necessary
+    for (int i = 0; i < 4; i++) {
+        if (op->src[i] && op->src[i]->buffer &&
+            ggml_backend_cpu_is_extra_buffer_type(op->src[i]->buffer->buft)) {
+            auto * buf_extra = (ggml::cpu::extra_buffer_type *) op->src[i]->buffer->buft->context;
+            return buf_extra->supports_op(dev, op);
         }
     }
 
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
index 910fd0ee4e7..ddd29d002d1 100644
--- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
@@ -22,9 +22,94 @@
 
 #include "kai_common.h"
 
+#include "simd-mappings.h"
+
 #include "kernels.h"
 
 #define NELEMS(x) sizeof(x) / sizeof(*x)
+
+static const size_t INT4_PER_BYTE = 2;
+static const size_t INT4_BITS     = 4;
+static const int Q4_0_ZERO_POINT  = 8;
+const size_t INT4_PER_UINT16      = 4;
+
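+// Dequantize a single row from the qsi4c32pscalef16 packed RHS layout back to f32:
+// each block stores nr_pack f16 scales followed by the interleaved int4 data segments.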
+static void dequantize_row_qsi4c32pscalef16(
+    const void *packed_data,
+    int32_t row_idx,
+    int64_t nc,
+    float *out,
+    size_t nr_pack,
+    size_t packed_row_stride,
+    size_t kr,
+    size_t bl,
+    size_t num_bytes_multiplier
+) {
+    size_t group_idx = row_idx / nr_pack;
+    size_t row_in_group = row_idx % nr_pack;
+    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
+    size_t num_blocks = nc / bl;
+    const uint8_t *block_ptr = packed_group;
+
+    for (size_t b = 0; b < num_blocks; ++b) {
+        uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier));
+        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
+
+        const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier;
+        size_t num_segments = bl / kr;
+        size_t num_bytes_per_segment = kr / INT4_PER_BYTE;
+
+        for (size_t s = 0; s < num_segments; ++s) {
+            const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment;
+            const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment;
+            for (size_t k = 0; k < num_bytes_per_segment; ++k) {
+                uint8_t byte = qbytes[k] ^ 0x88;
+                int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;
+                int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT;
+                out[b * bl + s * num_bytes_per_segment + k] = x0 * scale;
+                out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale;
+            }
+        }
+        block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment;
+    }
+}
+
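+// Dequantize a single row from the ps1s0scalef16 packed layout back to f32:
+// the int4 data is stored four values per uint16, with the per-block f16 scales placed at the end of each packed group.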
+static void dequantize_row_qsi4c32ps1s0scalef16(
+    const void *packed_data,
+    int32_t row_idx,
+    int64_t k,
+    float *out,
+    size_t nr,
+    size_t packed_row_stride,
+    size_t kr,
+    size_t bl,
+    size_t num_bytes_multiplier
+) {
+    const size_t num_blocks = k / bl;
+    const size_t bl4 = bl / INT4_PER_UINT16;
+
+    size_t group_idx = row_idx / nr;
+    size_t row_in_group = row_idx % nr;
+
+    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
+    const uint16_t *qdata = (const uint16_t *)packed_group;
+    const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier));
+
+    for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
+        uint16_t scale_f16 = scales[row_in_group + block_idx * nr];
+        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
+
+        for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) {
+            uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group];
+
+            for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) {
+                int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT;
+                out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale;
+            }
+        }
+    }
+    GGML_UNUSED(kr);
+}
+
 static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
 #if defined(__ARM_FEATURE_SME)
     {
@@ -63,8 +148,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .to_float      = */ dequantize_row_qsi4c32ps1s0scalef16,
         },
         /* .required_cpu       = */ CPU_FEATURE_SME,
         /* .lhs_type           = */ GGML_TYPE_F32,
@@ -107,8 +194,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func             = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
-            /* .pack_func   = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .packed_stride = */ NULL,
+            /* .pack_func     = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .to_float      = */ NULL,
         },
         /* .required_cpu       = */ CPU_FEATURE_SME,
         /* .lhs_type           = */ GGML_TYPE_F32,
@@ -154,8 +243,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu       = */ CPU_FEATURE_DOTPROD,
         /* .lhs_type           = */ GGML_TYPE_F32,
@@ -200,8 +291,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
         /* .lhs_type           = */ GGML_TYPE_F32,
@@ -247,8 +340,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
         /* .lhs_type           = */ GGML_TYPE_F32,
@@ -293,8 +388,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu       = */ CPU_FEATURE_DOTPROD,
         /* .lhs_type           = */ GGML_TYPE_F32,
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h
index 3b268d4a22a..bc8f33405d1 100644
--- a/ggml/src/ggml-cpu/kleidiai/kernels.h
+++ b/ggml/src/ggml-cpu/kleidiai/kernels.h
@@ -71,12 +71,15 @@ struct rhs_packing_info {
         std::function,
         std::function
     > packed_size;
+    size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
     std::variant<
         std::function,
         std::function
     > pack_func;
+    void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride,
+          size_t kr, size_t bl, size_t num_bytes_multiplier);
 };
 
 struct ggml_kleidiai_kernels {
diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
index fafe45e6c5c..dff8fa244a1 100644
--- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
@@ -40,6 +40,17 @@ struct ggml_kleidiai_context {
     ggml_kleidiai_kernels * kernels;
 } static ctx = { CPU_FEATURE_NONE, NULL };
 
+static const char* cpu_feature_to_string(cpu_feature f) {
+    switch (f) {
+        case CPU_FEATURE_NONE:    return "NONE";
+        case CPU_FEATURE_DOTPROD: return "DOTPROD";
+        case CPU_FEATURE_I8MM:    return "I8MM";
+        case CPU_FEATURE_SVE:     return "SVE";
+        case CPU_FEATURE_SME:     return "SME";
+        default:                  return "UNKNOWN";
+    }
+}
+
 static void init_kleidiai_context(void) {
 
     ggml_critical_section_start();
@@ -62,6 +73,11 @@ static void init_kleidiai_context(void) {
             ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
         }
         ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
+#ifndef NDEBUG
+        if (ctx.kernels) {
+            GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
+        }
+#endif
     }
     ggml_critical_section_end();
 }
@@ -102,6 +118,9 @@ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint1
 
 class tensor_traits : public ggml::cpu::tensor_traits {
     bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+        if (op->op != GGML_OP_MUL_MAT) {
+            return false;
+        }
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
         GGML_ASSERT(kernels);
         kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
@@ -135,6 +154,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             } else if (dst->src[0]->type == GGML_TYPE_F16) {
                 return compute_forward_kv_cache(params, dst);
             }
+        } else if (dst->op == GGML_OP_GET_ROWS) {
+            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+                return compute_forward_get_rows(params, dst);
+            }
         }
         return false;
     }
@@ -236,7 +259,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
                 const int64_t m_start      = 0;
 
                 const int64_t n_step      = static_cast(kernel->get_n_step());
-                const int64_t num_threads = KAI_MIN(n / n_step, nth);
+                int64_t num_threads       = KAI_MIN(n / n_step, nth);
+                if (num_threads <= 0) {
+                    num_threads = 1;
+                }
 
                 if (ith < num_threads) {
                     const int64_t num_n_per_thread0   = round_down(n / num_threads, n_step);
@@ -270,6 +296,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
     }
 
     bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
 
@@ -284,7 +312,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         GGML_ASSERT(kernel);
 
         const int ith = params->ith;
-        const int nth = params->nth;
+        const int nth_raw = params->nth;
+        const int nth = nth_raw > 0 ? nth_raw : 1;
 
         const size_t k = ne00;
         const size_t m = ne11;
@@ -302,9 +331,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
         const size_t n_start = ith * num_n_per_thread;
 
-        size_t n_to_process = num_n_per_thread;
-        if ((n_start + n_to_process) > n) {
-            n_to_process = n - n_start;
+        size_t n_to_process = 0;
+        if (n_start < n) {
+            n_to_process = num_n_per_thread;
+            if ((n_start + n_to_process) > n) {
+                n_to_process = n - n_start;
+            }
         }
 
         // Calculate number of columns to be processed per thread
@@ -336,14 +368,57 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const void* lhs_ptr            = (const void*)((const char *)lhs_packed + lhs_packed_offset);
         float *dst_ptr                 = reinterpret_cast(static_cast(dst->data) + dst_offset);
 
-        variant_call(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
-                           sizeof(float), -FLT_MAX, FLT_MAX);
+        if (n_to_process > 0) {
+            variant_call(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
+                               sizeof(float), -FLT_MAX, FLT_MAX);
+        }
+
+        return true;
+    }
+
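+    // Handles GGML_OP_GET_ROWS for a KleidiAI-packed Q4_0 tensor: the requested rows are dequantized back to f32, with the row indices split across threads.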
+    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+        GGML_ASSERT(ctx.kernels);
+
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
+        kernel_info * kernel        = &ctx.kernels->gemm;
+
+        const int64_t nc     = ne00;
+        const int64_t nr     = ggml_nelements(src1);
+
+        const size_t block_rows = kernel->get_nr();
+        const size_t kr         = kernel->get_kr();
+
+        const size_t num_bytes_multiplier = sizeof(uint16_t);
+        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        const int dr = (nr + nth - 1) / nth;
+        const int ir0 = dr * ith;
+        const int ir1 = MIN(ir0 + dr, nr);
+
+        for (int64_t i = ir0; i < ir1; ++i) {
+            GGML_ASSERT(src1->type == GGML_TYPE_I32);
+            int64_t row_idx = ((const int32_t *)src1->data)[i];
+            GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
+
+            float *out = (float *)((char *)dst->data + i * nb1);
+            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
+        }
 
         return true;
     }
 
 public:
     int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
+        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
         GGML_ASSERT(ctx.kernels);
         const size_t n = tensor->ne[1];
         const size_t k = tensor->ne[0];
@@ -351,17 +426,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         size_t kr      = ctx.kernels->gemm.get_kr();
         size_t sr      = ctx.kernels->gemm.get_sr();
 
-#ifndef NDEBUG
-        const size_t repacked_size = variant_call(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
-        GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
-#endif
         struct kai_rhs_pack_qs4cxs1s0_param params;
         params.lhs_zero_point = 1;
         params.rhs_zero_point = 8;
         variant_call(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, ¶ms);
 
         return 0;
-
         GGML_UNUSED(data_size);
     }
 };
@@ -375,8 +445,8 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
 static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
 
-    GGML_UNUSED(buffer);
     return GGML_STATUS_SUCCESS;
+    GGML_UNUSED(buffer);
 }
 
 static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
@@ -418,18 +488,35 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
     GGML_UNUSED(buft);
 }
 
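+// For tensors in the KleidiAI buffer type, the allocation size is the KleidiAI packed RHS size rather than ggml_nbytes.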
+static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(ctx.kernels);
+
+    const size_t n  = tensor->ne[1];
+    const size_t k  = tensor->ne[0];
+    const size_t nr = ctx.kernels->gemm.get_nr();
+    const size_t kr = ctx.kernels->gemm.get_kr();
+
+    return variant_call(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
+
+    GGML_UNUSED(buft);
+}
+
 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT &&
+        if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
             op->src[0]->type == GGML_TYPE_Q4_0 &&
             op->src[0]->buffer &&
             (ggml_n_dims(op->src[0]) == 2) &&
             op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
+            if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
+                return false;
+            }
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                 return false;
             }
-            if (op->src[1]->type == GGML_TYPE_F32 &&
+            if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
                 ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
                 return true;
             }
@@ -438,7 +525,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
     }
 
     ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT) {
+        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
             if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
             }
@@ -469,7 +556,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
                            /* .alloc_buffer     = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
                            /* .get_alignment    = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
                            /* .get_max_size     = */ nullptr,  // defaults to SIZE_MAX
-                           /* .get_alloc_size   = */ nullptr,  // defaults to ggml_nbytes
+                           /* .get_alloc_size   = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
                            /* .is_host          = */ nullptr,
                            },
         /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
index ed61869a550..2be54c31b5f 100644
--- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp
+++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -1541,7 +1541,7 @@ class tinyBLAS_BF16_PPC {
        } else if constexpr(RM == 8 && RN == 4) {
           KERNEL_8x4(ii,jj);
        } else {
-          static_assert(false, "RN/RM values not supported");
+          assert(false && "RN/RM values not supported");
        }
     }
 
@@ -1573,13 +1573,13 @@ class tinyBLAS_BF16_PPC {
     const int nth;
 };
 
-template 
+template 
 class tinyBLAS_Q0_PPC {
   public:
     tinyBLAS_Q0_PPC(int64_t k,
                 const TA *A, int64_t lda,
-                const TB *B, int64_t ldb,
-                TC *C, int64_t ldc,
+                const block_q8_0 *B, int64_t ldb,
+                float *C, int64_t ldc,
                 int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
@@ -1590,8 +1590,7 @@ class tinyBLAS_Q0_PPC {
 
   private:
 
-    template
-    inline void save_res(int ii, int jj, int idx, vector float* fin_res) {
+    inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
        for (int I = 0; I < RM; I++) {
           for (int J = 0; J < RN; J++) {
              *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
@@ -1611,29 +1610,67 @@ class tinyBLAS_Q0_PPC {
           fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
        }
     }
-
-    template
-    void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array& comparray) {
-        int64_t i, j;
-        TA *aoffset = NULL;
-        VA *vecOffset = NULL;
-        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
-        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
-        VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
-        VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
-        VB t1, t2, t3, t4, t5, t6, t7, t8;
+    /* This function processes quantized data from block_q4_0 elements.
+     * First we extract the two int4 values stored in a single int8_t into two signed int8 vectors.
+     * Then we subtract 8 from each of the resulting elements to shift them into the signed int8 range.
+     * We also compute the row sum, which is required to compensate for the above conversion. */
+    inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
         const vector signed char lowMask = vec_splats((signed char)0xF);
         const vector unsigned char v4 = vec_splats((unsigned char)0x4);
         const vector signed char v8 = vec_splats((signed char)0x8);
-        aoffset = const_cast(a);
-        vecOffset = vec;
+        vector signed int vsum = {0};
+        vector signed int vsum2 = {0};
+        c[0] = vec_and(c[1], lowMask);
+        c[1] = vec_sr(c[1], v4);
+        c[0] = vec_sub(c[0], v8);
+        c[1] = vec_sub(c[1], v8);
+        vsum = vec_sum4s(c[0], vsum);
+        vsum2 = vec_sum4s(c[1], vsum2);
+        vsum = vec_add(vsum, vsum2);
+        *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+    }
+
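+    /* Interleaves four 16-byte rows at 4-byte granularity (a 4x4 transpose of 4-byte groups)
+     * and stores the result at vecOffset; when flip is set every byte is XORed with 0x80. */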
+    template <typename V1, typename V2>
+    inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
         vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
         vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
         vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
         vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
-        vector signed int vsum = {0};
-        vector signed int vsum2 = {0};
+        V2 t1, t2, t3, t4, t5, t6, t7, t8;
+        vector unsigned char xor_vector;
+        uint8_t flip_vec = 0x80;
+        xor_vector = vec_splats(flip_vec);
+        t1 = vec_perm(s1, s2, swiz1);
+        t2 = vec_perm(s1, s2, swiz2);
+        t3 = vec_perm(s3, s4, swiz1);
+        t4 = vec_perm(s3, s4, swiz2);
+        t5 = vec_perm(t1, t3, swiz3);
+        t6 = vec_perm(t1, t3, swiz4);
+        t7 = vec_perm(t2, t4, swiz3);
+        t8 = vec_perm(t2, t4, swiz4);
+        if (flip == true) {
+            t5 = vec_xor(t5, xor_vector);
+            t6 = vec_xor(t6, xor_vector);
+            t7 = vec_xor(t7, xor_vector);
+            t8 = vec_xor(t8, xor_vector);
+        }
+        vec_xst(t5, 0, vecOffset);
+        vec_xst(t6, 0, vecOffset+16);
+        vec_xst(t7, 0, vecOffset+32);
+        vec_xst(t8, 0, vecOffset+48);
+    }
 
+    template<int size>
+    void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
+        int64_t i, j;
+        TA *aoffset = NULL;
+        int8_t *vecOffset = NULL;
+        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+        vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
         j = (rows >> 3);
         if (j > 0) {
             do {
@@ -1646,159 +1683,30 @@ class tinyBLAS_Q0_PPC {
                 aoffset7 = aoffset6 + lda;
                 aoffset8 = aoffset7 + lda;
                 aoffset += 8 * lda;
-
                 i = (cols >> 2);
                 if (i > 0) {
                     do {
-                        c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
-                        c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
-                        c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
-                        c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
-                        c5[1] = reinterpret_cast<VB>(vec_xl(0, aoffset5->qs));
-                        c6[1] = reinterpret_cast<VB>(vec_xl(0, aoffset6->qs));
-                        c7[1] = reinterpret_cast<VB>(vec_xl(0, aoffset7->qs));
-                        c8[1] = reinterpret_cast<VB>(vec_xl(0, aoffset8->qs));
-
-                        c1[0] = vec_and(c1[1], lowMask);
-                        c1[1] = vec_sr(c1[1], v4);
-                        c1[0] = vec_sub(c1[0], v8);
-                        c1[1] = vec_sub(c1[1], v8);
-                        vsum = vec_sum4s(c1[0], vsum);
-                        vsum2 = vec_sum4s(c1[1], vsum2);
-                        vsum = vec_add(vsum, vsum2);
-                        comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                        vsum = vec_splats(0);
-                        vsum2 = vec_splats(0);
-
-                        c2[0] = vec_and(c2[1], lowMask);
-                        c2[1] = vec_sr(c2[1], v4);
-                        c2[0] = vec_sub(c2[0], v8);
-                        c2[1] = vec_sub(c2[1], v8);
-                        vsum = vec_sum4s(c2[0], vsum);
-                        vsum2 = vec_sum4s(c2[1], vsum2);
-                        vsum = vec_add(vsum, vsum2);
-                        comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                        vsum = vec_splats(0);
-                        vsum2 = vec_splats(0);
-
-                        c3[0] = vec_and(c3[1], lowMask);
-                        c3[1] = vec_sr(c3[1], v4);
-                        c3[0] = vec_sub(c3[0], v8);
-                        c3[1] = vec_sub(c3[1], v8);
-                        vsum = vec_sum4s(c3[0], vsum);
-                        vsum2 = vec_sum4s(c3[1], vsum2);
-                        vsum = vec_add(vsum, vsum2);
-                        comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                        vsum = vec_splats(0);
-                        vsum2 = vec_splats(0);
-
-                        c4[0] = vec_and(c4[1], lowMask);
-                        c4[1] = vec_sr(c4[1], v4);
-                        c4[0] = vec_sub(c4[0], v8);
-                        c4[1] = vec_sub(c4[1], v8);
-                        vsum = vec_sum4s(c4[0], vsum);
-                        vsum2 = vec_sum4s(c4[1], vsum2);
-                        vsum = vec_add(vsum, vsum2);
-                        comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                        vsum = vec_splats(0);
-                        vsum2 = vec_splats(0);
-
-                        c5[0] = vec_and(c5[1], lowMask);
-                        c5[1] = vec_sr(c5[1], v4);
-                        c5[0] = vec_sub(c5[0], v8);
-                        c5[1] = vec_sub(c5[1], v8);
-                        vsum = vec_sum4s(c5[0], vsum);
-                        vsum2 = vec_sum4s(c5[1], vsum2);
-                        vsum = vec_add(vsum, vsum2);
-                        comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                        vsum = vec_splats(0);
-                        vsum2 = vec_splats(0);
-
-                        c6[0] = vec_and(c6[1], lowMask);
-                        c6[1] = vec_sr(c6[1], v4);
-                        c6[0] = vec_sub(c6[0], v8);
-                        c6[1] = vec_sub(c6[1], v8);
-                        vsum = vec_sum4s(c6[0], vsum);
-                        vsum2 = vec_sum4s(c6[1], vsum2);
-                        vsum = vec_add(vsum, vsum2);
-                        comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                        vsum = vec_splats(0);
-                        vsum2 = vec_splats(0);
-
-                        c7[0] = vec_and(c7[1], lowMask);
-                        c7[1] = vec_sr(c7[1], v4);
-                        c7[0] = vec_sub(c7[0], v8);
-                        c7[1] = vec_sub(c7[1], v8);
-                        vsum = vec_sum4s(c7[0], vsum);
-                        vsum2 = vec_sum4s(c7[1], vsum2);
-                        vsum = vec_add(vsum, vsum2);
-                        comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                        vsum = vec_splats(0);
-                        vsum2 = vec_splats(0);
-
-                        c8[0] = vec_and(c8[1], lowMask);
-                        c8[1] = vec_sr(c8[1], v4);
-                        c8[0] = vec_sub(c8[0], v8);
-                        c8[1] = vec_sub(c8[1], v8);
-                        vsum = vec_sum4s(c8[0], vsum);
-                        vsum2 = vec_sum4s(c8[1], vsum2);
-                        vsum = vec_add(vsum, vsum2);
-                        comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                        vsum = vec_splats(0);
-                        vsum2 = vec_splats(0);
-
-                        t1 = vec_perm(c1[0], c2[0], swiz1);
-                        t2 = vec_perm(c1[0], c2[0], swiz2);
-                        t3 = vec_perm(c3[0], c4[0], swiz1);
-                        t4 = vec_perm(c3[0], c4[0], swiz2);
-                        t5 = vec_perm(t1, t3, swiz3);
-                        t6 = vec_perm(t1, t3, swiz4);
-                        t7 = vec_perm(t2, t4, swiz3);
-                        t8 = vec_perm(t2, t4, swiz4);
-                        vec_xst(t5, 0, vecOffset);
-                        vec_xst(t6, 0, vecOffset+16);
-                        vec_xst(t7, 0, vecOffset+32);
-                        vec_xst(t8, 0, vecOffset+48);
-
-                        t1 = vec_perm(c1[1], c2[1], swiz1);
-                        t2 = vec_perm(c1[1], c2[1], swiz2);
-                        t3 = vec_perm(c3[1], c4[1], swiz1);
-                        t4 = vec_perm(c3[1], c4[1], swiz2);
-                        t5 = vec_perm(t1, t3, swiz3);
-                        t6 = vec_perm(t1, t3, swiz4);
-                        t7 = vec_perm(t2, t4, swiz3);
-                        t8 = vec_perm(t2, t4, swiz4);
-                        vec_xst(t5, 0, vecOffset+64);
-                        vec_xst(t6, 0, vecOffset+80);
-                        vec_xst(t7, 0, vecOffset+96);
-                        vec_xst(t8, 0, vecOffset+112);
-
-                        t1 = vec_perm(c5[0], c6[0], swiz1);
-                        t2 = vec_perm(c5[0], c6[0], swiz2);
-                        t3 = vec_perm(c7[0], c8[0], swiz1);
-                        t4 = vec_perm(c7[0], c8[0], swiz2);
-                        t5 = vec_perm(t1, t3, swiz3);
-                        t6 = vec_perm(t1, t3, swiz4);
-                        t7 = vec_perm(t2, t4, swiz3);
-                        t8 = vec_perm(t2, t4, swiz4);
-                        vec_xst(t5, 0, vecOffset+128);
-                        vec_xst(t6, 0, vecOffset+144);
-                        vec_xst(t7, 0, vecOffset+160);
-                        vec_xst(t8, 0, vecOffset+176);
-
-                        t1 = vec_perm(c5[1], c6[1], swiz1);
-                        t2 = vec_perm(c5[1], c6[1], swiz2);
-                        t3 = vec_perm(c7[1], c8[1], swiz1);
-                        t4 = vec_perm(c7[1], c8[1], swiz2);
-                        t5 = vec_perm(t1, t3, swiz3);
-                        t6 = vec_perm(t1, t3, swiz4);
-                        t7 = vec_perm(t2, t4, swiz3);
-                        t8 = vec_perm(t2, t4, swiz4);
-                        vec_xst(t5, 0, vecOffset+192);
-                        vec_xst(t6, 0, vecOffset+208);
-                        vec_xst(t7, 0, vecOffset+224);
-                        vec_xst(t8, 0, vecOffset+240);
-
+                        c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
+                        c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
+                        c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
+                        c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
+                        c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs));
+                        c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs));
+                        c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
+                        c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
+
+                        process_q4_elements(c1, &comparray[0]);
+                        process_q4_elements(c2, &comparray[1]);
+                        process_q4_elements(c3, &comparray[2]);
+                        process_q4_elements(c4, &comparray[3]);
+                        process_q4_elements(c5, &comparray[4]);
+                        process_q4_elements(c6, &comparray[5]);
+                        process_q4_elements(c7, &comparray[6]);
+                        process_q4_elements(c8, &comparray[7]);
+                        vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
+                        vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
+                        vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
+                        vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
                         aoffset1 += lda;
                         aoffset2 += lda;
                         aoffset3 += lda;
@@ -1821,85 +1729,20 @@ class tinyBLAS_Q0_PPC {
             aoffset3 = aoffset2 + lda;
             aoffset4 = aoffset3 + lda;
             aoffset += 4 * lda;
-
             i = (cols >> 2);
             if (i > 0) {
                 do {
-                    c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
-                    c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
-                    c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
-                    c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
-
-                    c1[0] = vec_and(c1[1], lowMask);
-                    c1[1] = vec_sr(c1[1], v4);
-                    c1[0] = vec_sub(c1[0], v8);
-                    c1[1] = vec_sub(c1[1], v8);
-                    vsum = vec_sum4s(c1[0], vsum);
-                    vsum2 = vec_sum4s(c1[1], vsum2);
-                    vsum = vec_add(vsum, vsum2);
-                    comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                    vsum = vec_splats(0);
-                    vsum2 = vec_splats(0);
-
-                    c2[0] = vec_and(c2[1], lowMask);
-                    c2[1] = vec_sr(c2[1], v4);
-                    c2[0] = vec_sub(c2[0], v8);
-                    c2[1] = vec_sub(c2[1], v8);
-                    vsum = vec_sum4s(c2[0], vsum);
-                    vsum2 = vec_sum4s(c2[1], vsum2);
-                    vsum = vec_add(vsum, vsum2);
-                    comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                    vsum = vec_splats(0);
-                    vsum2 = vec_splats(0);
-
-                    c3[0] = vec_and(c3[1], lowMask);
-                    c3[1] = vec_sr(c3[1], v4);
-                    c3[0] = vec_sub(c3[0], v8);
-                    c3[1] = vec_sub(c3[1], v8);
-                    vsum = vec_sum4s(c3[0], vsum);
-                    vsum2 = vec_sum4s(c3[1], vsum2);
-                    vsum = vec_add(vsum, vsum2);
-                    comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                    vsum = vec_splats(0);
-                    vsum2 = vec_splats(0);
-
-                    c4[0] = vec_and(c4[1], lowMask);
-                    c4[1] = vec_sr(c4[1], v4);
-                    c4[0] = vec_sub(c4[0], v8);
-                    c4[1] = vec_sub(c4[1], v8);
-                    vsum = vec_sum4s(c4[0], vsum);
-                    vsum2 = vec_sum4s(c4[1], vsum2);
-                    vsum = vec_add(vsum, vsum2);
-                    comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                    vsum = vec_splats(0);
-                    vsum2 = vec_splats( 0);
-
-                    t1 = vec_perm(c1[0], c2[0], swiz1);
-                    t2 = vec_perm(c1[0], c2[0], swiz2);
-                    t3 = vec_perm(c3[0], c4[0], swiz1);
-                    t4 = vec_perm(c3[0], c4[0], swiz2);
-                    t5 = vec_perm(t1, t3, swiz3);
-                    t6 = vec_perm(t1, t3, swiz4);
-                    t7 = vec_perm(t2, t4, swiz3);
-                    t8 = vec_perm(t2, t4, swiz4);
-                    vec_xst(t5, 0, vecOffset);
-                    vec_xst(t6, 0, vecOffset+16);
-                    vec_xst(t7, 0, vecOffset+32);
-                    vec_xst(t8, 0, vecOffset+48);
-
-                    t1 = vec_perm(c1[1], c2[1], swiz1);
-                    t2 = vec_perm(c1[1], c2[1], swiz2);
-                    t3 = vec_perm(c3[1], c4[1], swiz1);
-                    t4 = vec_perm(c3[1], c4[1], swiz2);
-                    t5 = vec_perm(t1, t3, swiz3);
-                    t6 = vec_perm(t1, t3, swiz4);
-                    t7 = vec_perm(t2, t4, swiz3);
-                    t8 = vec_perm(t2, t4, swiz4);
-                    vec_xst(t5, 0, vecOffset+64);
-                    vec_xst(t6, 0, vecOffset+80);
-                    vec_xst(t7, 0, vecOffset+96);
-                    vec_xst(t8, 0, vecOffset+112);
-
+                    c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
+                    c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
+                    c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
+                    c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
+
+                    process_q4_elements(c1, &comparray[0]);
+                    process_q4_elements(c2, &comparray[1]);
+                    process_q4_elements(c3, &comparray[2]);
+                    process_q4_elements(c4, &comparray[3]);
+                    vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
+                    vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
                     aoffset1 += lda;
                     aoffset2 += lda;
                     aoffset3 += lda;
@@ -1918,80 +1761,17 @@ class tinyBLAS_Q0_PPC {
             if (i > 0) {
                 do {
                     switch(rows) {
-                        case 3: c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
-                        case 2: c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
-                        case 1: c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
+                        case 3: c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
+                        case 2: c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
+                        case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
                             break;
                     }
-                    c1[0] = vec_and(c1[1], lowMask);
-                    c1[1] = vec_sr(c1[1], v4);
-                    c1[0] = vec_sub(c1[0], v8);
-                    c1[1] = vec_sub(c1[1], v8);
-                    vsum = vec_sum4s(c1[0], vsum);
-                    vsum2 = vec_sum4s(c1[1], vsum2);
-                    vsum = vec_add(vsum, vsum2);
-                    comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                    vsum = vec_splats(0);
-                    vsum2 = vec_splats(0);
-
-                    c2[0] = vec_and(c2[1], lowMask);
-                    c2[1] = vec_sr(c2[1], v4);
-                    c2[0] = vec_sub(c2[0], v8);
-                    c2[1] = vec_sub(c2[1], v8);
-                    vsum = vec_sum4s(c2[0], vsum);
-                    vsum2 = vec_sum4s(c2[1], vsum2);
-                    vsum = vec_add(vsum, vsum2);
-                    comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                    vsum = vec_splats(0);
-                    vsum2 = vec_splats(0);
-
-                    c3[0] = vec_and(c3[1], lowMask);
-                    c3[1] = vec_sr(c3[1], v4);
-                    c3[0] = vec_sub(c3[0], v8);
-                    c3[1] = vec_sub(c3[1], v8);
-                    vsum = vec_sum4s(c3[0], vsum);
-                    vsum2 = vec_sum4s(c3[1], vsum2);
-                    vsum = vec_add(vsum, vsum2);
-                    comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                    vsum = vec_splats(0);
-                    vsum2 = vec_splats(0);
-
-                    c4[0] = vec_and(c4[1], lowMask);
-                    c4[1] = vec_sr(c4[1], v4);
-                    c4[0] = vec_sub(c4[0], v8);
-                    c4[1] = vec_sub(c4[1], v8);
-                    vsum = vec_sum4s(c4[0], vsum);
-                    vsum2 = vec_sum4s(c4[1], vsum2);
-                    vsum = vec_add(vsum, vsum2);
-                    comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-                    vsum = vec_splats(0);
-                    vsum2 = vec_splats(0);
-
-                    t1 = vec_perm(c1[0], c2[0], swiz1);
-                    t2 = vec_perm(c1[0], c2[0], swiz2);
-                    t3 = vec_perm(c3[0], c4[0], swiz1);
-                    t4 = vec_perm(c3[0], c4[0], swiz2);
-                    t5 = vec_perm(t1, t3, swiz3);
-                    t6 = vec_perm(t1, t3, swiz4);
-                    t7 = vec_perm(t2, t4, swiz3);
-                    t8 = vec_perm(t2, t4, swiz4);
-                    vec_xst(t5, 0, vecOffset);
-                    vec_xst(t6, 0, vecOffset+16);
-                    vec_xst(t7, 0, vecOffset+32);
-                    vec_xst(t8, 0, vecOffset+48);
-
-                    t1 = vec_perm(c1[1], c2[1], swiz1);
-                    t2 = vec_perm(c1[1], c2[1], swiz2);
-                    t3 = vec_perm(c3[1], c4[1], swiz1);
-                    t4 = vec_perm(c3[1], c4[1], swiz2);
-                    t5 = vec_perm(t1, t3, swiz3);
-                    t6 = vec_perm(t1, t3, swiz4);
-                    t7 = vec_perm(t2, t4, swiz3);
-                    t8 = vec_perm(t2, t4, swiz4);
-                    vec_xst(t5, 0, vecOffset+64);
-                    vec_xst(t6, 0, vecOffset+80);
-                    vec_xst(t7, 0, vecOffset+96);
-                    vec_xst(t8, 0, vecOffset+112);
+                    process_q4_elements(c1, &comparray[0]);
+                    process_q4_elements(c2, &comparray[1]);
+                    process_q4_elements(c3, &comparray[2]);
+                    process_q4_elements(c4, &comparray[3]);
+                    vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
+                    vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
                     aoffset1 += lda;
                     aoffset2 += lda;
                     aoffset3 += lda;
@@ -2001,146 +1781,40 @@ class tinyBLAS_Q0_PPC {
             }
         }
     }
-
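+    /* Packs block_q8_0 quants from the given rows into the interleaved layout consumed by the
+     * MMA kernels; when flip is set every byte is XORed with 0x80. */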
     template<typename VA, typename VB>
-    void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+    void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
         int64_t i, j;
-        TB *aoffset = NULL;
+        block_q8_0 *aoffset = NULL;
         VA *vecOffset = NULL;
-        TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
-        TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
-        __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
-        VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
-        VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
-        VB t1, t2, t3, t4, t5, t6, t7, t8;
-        vector unsigned char xor_vector;
-        uint8_t flip_vec = 0x80;
-        xor_vector = vec_splats(flip_vec);
-        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
-        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
-        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
-        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
-
-        aoffset = const_cast<TB*>(a);
+        block_q8_0* aoffsets[8];
+        __vector_pair arr[8];
+        VB c[8][2] = {0};
+        VB c1[8] = {0}; VB c2[8] = {0};
+        aoffset = const_cast<block_q8_0*>(a);
         vecOffset = vec;
         j = (rows >> 3);
         if (j > 0) {
             do {
-                aoffset1 = aoffset;
-                aoffset2 = aoffset1 + lda;
-                aoffset3 = aoffset2 + lda;
-                aoffset4 = aoffset3 + lda;
-                aoffset5 = aoffset4 + lda;
-                aoffset6 = aoffset5 + lda;
-                aoffset7 = aoffset6 + lda;
-                aoffset8 = aoffset7 + lda;
+                aoffsets[0] = aoffset;
+                for (int it = 1; it < 8; it++)
+                    aoffsets[it] = aoffsets[it-1] + lda;
                 aoffset += 8 * lda;
 
                 i = (cols >> 3);
                 if (i > 0) {
                 do {
-                    C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
-                    C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
-                    C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
-                    C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
-                    C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs);
-                    C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs);
-                    C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs);
-                    C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs);
-
-                    __builtin_vsx_disassemble_pair(c1, &C1);
-                    __builtin_vsx_disassemble_pair(c2, &C2);
-                    __builtin_vsx_disassemble_pair(c3, &C3);
-                    __builtin_vsx_disassemble_pair(c4, &C4);
-                    __builtin_vsx_disassemble_pair(c5, &C5);
-                    __builtin_vsx_disassemble_pair(c6, &C6);
-                    __builtin_vsx_disassemble_pair(c7, &C7);
-                    __builtin_vsx_disassemble_pair(c8, &C8);
-
-                    t1 = vec_perm(c1[0], c2[0], swiz1);
-                    t2 = vec_perm(c1[0], c2[0], swiz2);
-                    t3 = vec_perm(c3[0], c4[0], swiz1);
-                    t4 = vec_perm(c3[0], c4[0], swiz2);
-                    t5 = vec_perm(t1, t3, swiz3);
-                    t6 = vec_perm(t1, t3, swiz4);
-                    t7 = vec_perm(t2, t4, swiz3);
-                    t8 = vec_perm(t2, t4, swiz4);
-                    if (flip == true) {
-                        t5 = vec_xor(t5, xor_vector);
-                        t6 = vec_xor(t6, xor_vector);
-                        t7 = vec_xor(t7, xor_vector);
-                        t8 = vec_xor(t8, xor_vector);
-                    }
-                    vec_xst(t5, 0, vecOffset);
-                    vec_xst(t6, 0, vecOffset+16);
-                    vec_xst(t7, 0, vecOffset+32);
-                    vec_xst(t8, 0, vecOffset+48);
-
-                    t1 = vec_perm(c1[1], c2[1], swiz1);
-                    t2 = vec_perm(c1[1], c2[1], swiz2);
-                    t3 = vec_perm(c3[1], c4[1], swiz1);
-                    t4 = vec_perm(c3[1], c4[1], swiz2);
-                    t5 = vec_perm(t1, t3, swiz3);
-                    t6 = vec_perm(t1, t3, swiz4);
-                    t7 = vec_perm(t2, t4, swiz3);
-                    t8 = vec_perm(t2, t4, swiz4);
-                    if (flip == true) {
-                        t5 = vec_xor(t5, xor_vector);
-                        t6 = vec_xor(t6, xor_vector);
-                        t7 = vec_xor(t7, xor_vector);
-                        t8 = vec_xor(t8, xor_vector);
-                    }
-                    vec_xst(t5, 0, vecOffset+64);
-                    vec_xst(t6, 0, vecOffset+80);
-                    vec_xst(t7, 0, vecOffset+96);
-                    vec_xst(t8, 0, vecOffset+112);
-
-                    t1 = vec_perm(c5[0], c6[0], swiz1);
-                    t2 = vec_perm(c5[0], c6[0], swiz2);
-                    t3 = vec_perm(c7[0], c8[0], swiz1);
-                    t4 = vec_perm(c7[0], c8[0], swiz2);
-                    t5 = vec_perm(t1, t3, swiz3);
-                    t6 = vec_perm(t1, t3, swiz4);
-                    t7 = vec_perm(t2, t4, swiz3);
-                    t8 = vec_perm(t2, t4, swiz4);
-                    if (flip == true) {
-                        t5 = vec_xor(t5, xor_vector);
-                        t6 = vec_xor(t6, xor_vector);
-                        t7 = vec_xor(t7, xor_vector);
-                        t8 = vec_xor(t8, xor_vector);
-                    }
-                    vec_xst(t5, 0, vecOffset+128);
-                    vec_xst(t6, 0, vecOffset+144);
-                    vec_xst(t7, 0, vecOffset+160);
-                    vec_xst(t8, 0, vecOffset+176);
-
-                    t1 = vec_perm(c5[1], c6[1], swiz1);
-                    t2 = vec_perm(c5[1], c6[1], swiz2);
-                    t3 = vec_perm(c7[1], c8[1], swiz1);
-                    t4 = vec_perm(c7[1], c8[1], swiz2);
-                    t5 = vec_perm(t1, t3, swiz3);
-                    t6 = vec_perm(t1, t3, swiz4);
-                    t7 = vec_perm(t2, t4, swiz3);
-                    t8 = vec_perm(t2, t4, swiz4);
-                    if (flip == true) {
-                        t5 = vec_xor(t5, xor_vector);
-                        t6 = vec_xor(t6, xor_vector);
-                        t7 = vec_xor(t7, xor_vector);
-                        t8 = vec_xor(t8, xor_vector);
+                    for (int it = 0; it < 8; it++) {
+                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
+                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+                        c1[it] = c[it][0];
+                        c2[it] = c[it][1];
                     }
-                    vec_xst(t5, 0, vecOffset+192);
-                    vec_xst(t6, 0, vecOffset+208);
-                    vec_xst(t7, 0, vecOffset+224);
-                    vec_xst(t8, 0, vecOffset+240);
-
-                    aoffset1 += lda;
-                    aoffset2 += lda;
-                    aoffset3 += lda;
-                    aoffset4 += lda;
-                    aoffset5 += lda;
-                    aoffset6 += lda;
-                    aoffset7 += lda;
-                    aoffset8 += lda;
+                    vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
+                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+                    vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
+                    vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
+                    for (int it = 0; it < 8; it++)
+                        aoffsets[it] += lda;
                     vecOffset += 256;
                     i--;
                } while(i > 0);
@@ -2150,129 +1824,53 @@ class tinyBLAS_Q0_PPC {
     }
 
     if (rows & 4) {
-        aoffset1 = aoffset;
-        aoffset2 = aoffset1 + lda;
-        aoffset3 = aoffset2 + lda;
-        aoffset4 = aoffset3 + lda;
-        aoffset += 4 * lda;
-
+            aoffsets[0]  = aoffset;
+            for (int it = 1; it < 4; it++ )
+                aoffsets[it] = aoffsets[it-1] + lda;
+            aoffset += 4 * lda;
         i = (cols >> 3);
             if (i > 0) {
                do {
-                    C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
-                    C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
-                    C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
-                    C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
-
-                    __builtin_vsx_disassemble_pair(c1, &C1);
-                    __builtin_vsx_disassemble_pair(c2, &C2);
-                    __builtin_vsx_disassemble_pair(c3, &C3);
-                    __builtin_vsx_disassemble_pair(c4, &C4);
-
-                    t1 = vec_perm(c1[0], c2[0], swiz1);
-                    t2 = vec_perm(c1[0], c2[0], swiz2);
-                    t3 = vec_perm(c3[0], c4[0], swiz1);
-                    t4 = vec_perm(c3[0], c4[0], swiz2);
-                    t5 = vec_perm(t1, t3, swiz3);
-                    t6 = vec_perm(t1, t3, swiz4);
-                    t7 = vec_perm(t2, t4, swiz3);
-                    t8 = vec_perm(t2, t4, swiz4);
-                    if (flip == true) {
-                       t5 = vec_xor(t5, xor_vector);
-                       t6 = vec_xor(t6, xor_vector);
-                       t7 = vec_xor(t7, xor_vector);
-                       t8 = vec_xor(t8, xor_vector);
+                    for (int it = 0; it < 4; it++) {
+                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
+                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+                        c1[it] = c[it][0];
+                        c2[it] = c[it][1];
                     }
-                    vec_xst(t5, 0, vecOffset);
-                    vec_xst(t6, 0, vecOffset+16);
-                    vec_xst(t7, 0, vecOffset+32);
-                    vec_xst(t8, 0, vecOffset+48);
-
-                    t1 = vec_perm(c1[1], c2[1], swiz1);
-                    t2 = vec_perm(c1[1], c2[1], swiz2);
-                    t3 = vec_perm(c3[1], c4[1], swiz1);
-                    t4 = vec_perm(c3[1], c4[1], swiz2);
-                    t5 = vec_perm(t1, t3, swiz3);
-                    t6 = vec_perm(t1, t3, swiz4);
-                    t7 = vec_perm(t2, t4, swiz3);
-                    t8 = vec_perm(t2, t4, swiz4);
-                    if (flip == true) {
-                       t5 = vec_xor(t5, xor_vector);
-                       t6 = vec_xor(t6, xor_vector);
-                       t7 = vec_xor(t7, xor_vector);
-                       t8 = vec_xor(t8, xor_vector);
+                    vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
+                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+                    for (int it = 0; it < 4; it++) {
+                        aoffsets[it] += lda;
                     }
-                    vec_xst(t5, 0, vecOffset+64);
-                    vec_xst(t6, 0, vecOffset+80);
-                    vec_xst(t7, 0, vecOffset+96);
-                    vec_xst(t8, 0, vecOffset+112);
-
-                    aoffset1 += lda;
-                    aoffset2 += lda;
-                    aoffset3 += lda;
-                    aoffset4 += lda;
                     vecOffset += 128;
                     i--;
                } while(i > 0);
             }
         }
+
         if (rows & 3) {
-            aoffset1 = aoffset;
-            aoffset2 = aoffset1 + lda;
-            aoffset3 = aoffset2 + lda;
+            aoffsets[0]  = aoffset;
+            for (int it = 1; it < 3; it++ )
+                aoffsets[it] = aoffsets[it-1] + lda;
             i = (cols >> 3);
             if (i > 0) {
                 do {
                     switch(rows) {
-                        case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
-                                __builtin_vsx_disassemble_pair(c3, &C3);
-                        case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
-                                __builtin_vsx_disassemble_pair(c2, &C2);
-                        case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
-                                __builtin_vsx_disassemble_pair(c1, &C1);
+                        case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
+                                __builtin_vsx_disassemble_pair(c[2], &arr[2]);
+                                c1[2] = c[2][0]; c2[2] = c[2][1];
+                        case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
+                                __builtin_vsx_disassemble_pair(c[1], &arr[1]);
+                                c1[1] = c[1][0]; c2[1] = c[1][1];
+                        case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
+                                __builtin_vsx_disassemble_pair(c[0], &arr[0]);
+                                c1[0] = c[0][0]; c2[0] = c[0][1];
                                 break;
                     }
-                    t1 = vec_perm(c1[0], c2[0], swiz1);
-                    t2 = vec_perm(c1[0], c2[0], swiz2);
-                    t3 = vec_perm(c3[0], c4[0], swiz1);
-                    t4 = vec_perm(c3[0], c4[0], swiz2);
-                    t5 = vec_perm(t1, t3, swiz3);
-                    t6 = vec_perm(t1, t3, swiz4);
-                    t7 = vec_perm(t2, t4, swiz3);
-                    t8 = vec_perm(t2, t4, swiz4);
-                    if (flip == true) {
-                       t5 = vec_xor(t5, xor_vector);
-                       t6 = vec_xor(t6, xor_vector);
-                       t7 = vec_xor(t7, xor_vector);
-                       t8 = vec_xor(t8, xor_vector);
-                    }
-                    vec_xst(t5, 0, vecOffset);
-                    vec_xst(t6, 0, vecOffset+16);
-                    vec_xst(t7, 0, vecOffset+32);
-                    vec_xst(t8, 0, vecOffset+48);
-
-                    t1 = vec_perm(c1[1], c2[1], swiz1);
-                    t2 = vec_perm(c1[1], c2[1], swiz2);
-                    t3 = vec_perm(c3[1], c4[1], swiz1);
-                    t4 = vec_perm(c3[1], c4[1], swiz2);
-                    t5 = vec_perm(t1, t3, swiz3);
-                    t6 = vec_perm(t1, t3, swiz4);
-                    t7 = vec_perm(t2, t4, swiz3);
-                    t8 = vec_perm(t2, t4, swiz4);
-                    if (flip == true) {
-                       t5 = vec_xor(t5, xor_vector);
-                       t6 = vec_xor(t6, xor_vector);
-                       t7 = vec_xor(t7, xor_vector);
-                       t8 = vec_xor(t8, xor_vector);
-                    }
-                    vec_xst(t5, 0, vecOffset+64);
-                    vec_xst(t6, 0, vecOffset+80);
-                    vec_xst(t7, 0, vecOffset+96);
-                    vec_xst(t8, 0, vecOffset+112);
-
-                    aoffset1 += lda;
-                    aoffset2 += lda;
-                    aoffset3 += lda;
+                    vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
+                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+                    for (int it = 0; it < 3; it++)
+                         aoffsets[it] += lda;
                     vecOffset += 128;
                     i--;
                } while(i > 0);
@@ -2281,159 +1879,42 @@ class tinyBLAS_Q0_PPC {
     }
 
     void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t mc, nc, mp, np;
-        int m_rem = MIN(m - m0, 8);
-        int n_rem = MIN(n - n0, 8);
-        // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance
-        // issues. After resolving them, below code will be enabled.
-        /*if (m_rem >= 16 && n_rem >= 8) {
-            mc = 16;
-            nc = 8;
-            gemm<16,8>(m0, m, n0, n);
-        } else if(m_rem >= 8 && n_rem >= 16) {
-            mc = 8;
-            nc = 16;
-            gemm<8,16>(m0, m, n0, n);
-        }*/
+        int m_rem = MIN(m - m0, 16);
+        int n_rem = MIN(n - n0, 16);
+
+        int mc = 0, nc = 0;
+
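+        // Select the largest kernel tile that fits (8x8, 8x4 or 4x8); smaller remainders are
+        // forwarded to gemm_small with runtime tile sizes.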
         if (m_rem >= 8 && n_rem >= 8) {
-            mc = 8;
-            nc = 8;
-            gemm<8,8>(m0, m, n0, n);
+           mc = 8;
+           nc = 8;
+           gemm<8, 8>(m0, m, n0, n);
         } else if (m_rem >= 4 && n_rem >= 8) {
             mc = 4;
             nc = 8;
-            gemm<4,8>(m0, m, n0, n);
+            gemm<4, 8>(m0, m, n0, n);
         } else if (m_rem >= 8 && n_rem >= 4) {
             mc = 8;
             nc = 4;
-            gemm<8,4>(m0, m, n0, n);
+            gemm<8, 4>(m0, m, n0, n);
         } else if (m_rem >= 4 && n_rem >= 4) {
             mc = 4;
             nc = 4;
-            gemm_small<4, 4>(m0, m, n0, n);
-        } else if ((m_rem < 4) && (n_rem > 4)) {
-            nc = 4;
-            switch(m_rem) {
-                case 1:
-                    mc = 1;
-                    gemm_small<1, 4>(m0, m, n0, n);
-                    break;
-                case 2:
-                    mc = 2;
-                    gemm_small<2, 4>(m0, m, n0, n);
-                    break;
-                case 3:
-                    mc = 3;
-                    gemm_small<3, 4>(m0, m, n0, n);
-                    break;
-                default:
-                    return;
-            }
-        } else if ((m_rem > 4) && (n_rem < 4)) {
-            mc = 4;
-            switch(n_rem) {
-                case 1:
-                    nc = 1;
-                    gemm_small<4, 1>(m0, m, n0, n);
-                    break;
-                case 2:
-                    nc = 2;
-                    gemm_small<4, 2>(m0, m, n0, n);
-                    break;
-                case 3:
-                    nc = 3;
-                    gemm_small<4, 3>(m0, m, n0, n);
-                    break;
-                default:
-                    return;
-            }
+            gemm_small(m0, m, n0, n, mc, nc);
         } else {
-            switch((m_rem << 4) | n_rem) {
-                case 0x43:
-                    mc = 4;
-                    nc = 3;
-                    gemm_small<4, 3>(m0, m, n0, n);
-                    break;
-                case 0x42:
-                    mc = 4;
-                    nc = 2;
-                    gemm_small<4, 2>(m0, m, n0, n);
-                    break;
-                case 0x41:
-                    mc = 4;
-                    nc = 1;
-                    gemm_small<4, 1>(m0, m, n0, n);
-                    break;
-                case 0x34:
-                    mc = 3;
-                    nc = 4;
-                    gemm_small<3, 4>(m0, m, n0, n);
-                    break;
-                case 0x33:
-                    mc = 3;
-                    nc = 3;
-                    gemm_small<3, 3>(m0, m, n0, n);
-                    break;
-                case 0x32:
-                    mc = 3;
-                    nc = 2;
-                    gemm_small<3, 2>(m0, m, n0, n);
-                    break;
-                case 0x31:
-                    mc = 3;
-                    nc = 1;
-                    gemm_small<3, 1>(m0, m, n0, n);
-                    break;
-                case 0x24:
-                    mc = 2;
-                    nc = 4;
-                    gemm_small<2, 4>(m0, m, n0, n);
-                    break;
-                case 0x23:
-                    mc = 2;
-                    nc = 3;
-                    gemm_small<2, 3>(m0, m, n0, n);
-                    break;
-                case 0x22:
-                    mc = 2;
-                    nc = 2;
-                    gemm_small<2, 2>(m0, m, n0, n);
-                    break;
-                case 0x21:
-                    mc = 2;
-                    nc = 1;
-                    gemm_small<2, 1>(m0, m, n0, n);
-                    break;
-                case 0x14:
-                    mc = 1;
-                    nc = 4;
-                    gemm_small<1, 4>(m0, m, n0, n);
-                    break;
-                case 0x13:
-                    mc = 1;
-                    nc = 3;
-                    gemm_small<1, 3>(m0, m, n0, n);
-                    break;
-                case 0x12:
-                    mc = 1;
-                    nc = 2;
-                    gemm_small<1, 2>(m0, m, n0, n);
-                    break;
-                case 0x11:
-                    mc = 1;
-                    nc = 1;
-                    gemm_small<1, 1>(m0, m, n0, n);
-                    break;
-                default:
-                    return;
-            }
+            mc = (m_rem >= 4) ? 4 : m_rem;
+            nc = (n_rem >= 4) ? 4 : n_rem;
+            if (mc == 0 || nc == 0)
+               return;
+            gemm_small(m0, m, n0, n, mc, nc);
         }
-        mp = m0 + (m - m0) / mc * mc;
-        np = n0 + (n - n0) / nc * nc;
+
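+        // Recurse over the row/column remainders that the mc x nc tiling above did not cover.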
+        int64_t mp = m0 + ((m - m0) / mc) * mc;
+        int64_t np = n0 + ((n - n0) / nc) * nc;
         mnpack(mp, m, n0, np);
         mnpack(m0, m, np, n);
     }
 
+
     void KERNEL_4x8(int64_t ii, int64_t jj) {
         vec_t vec_A[8], vec_B[16] = {0};
         acc_t acc_0, acc_1;
@@ -2445,9 +1926,9 @@ class tinyBLAS_Q0_PPC {
             __builtin_mma_xxsetaccz(&acc_0);
             __builtin_mma_xxsetaccz(&acc_1);
             if (std::is_same_v<TA, block_q4_0>) {
-               packNormalInt4<int8_t, vector signed char>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
+               packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
             } else {
-               packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+               packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
             }
             packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x++) {
@@ -2475,8 +1956,8 @@ class tinyBLAS_Q0_PPC {
             compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
             compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
         }
-        save_res<4, 4>(ii, jj, 0, fin_res);
-        save_res<4, 4>(ii, jj+4, 4, fin_res);
+        save_res(ii, jj, 0, fin_res);
+        save_res(ii, jj+4, 4, fin_res);
     }
 
     void KERNEL_8x4(int64_t ii, int64_t jj) {
@@ -2490,9 +1971,9 @@ class tinyBLAS_Q0_PPC {
             __builtin_mma_xxsetaccz(&acc_0);
             __builtin_mma_xxsetaccz(&acc_1);
             if (std::is_same_v<TA, block_q4_0>) {
-               packNormalInt4<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+               packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
             } else {
-               packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+               packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
             }
             packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x++) {
@@ -2519,8 +2000,8 @@ class tinyBLAS_Q0_PPC {
             compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
             compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
         }
-        save_res<4, 4>(ii, jj, 0, fin_res);
-        save_res<4, 4>(ii+4, jj, 4, fin_res);
+        save_res(ii, jj, 0, fin_res);
+        save_res(ii+4, jj, 4, fin_res);
     }
 
     void KERNEL_8x8(int64_t ii, int64_t jj) {
@@ -2536,9 +2017,9 @@ class tinyBLAS_Q0_PPC {
             __builtin_mma_xxsetaccz(&acc_2);
             __builtin_mma_xxsetaccz(&acc_3);
             if (std::is_same_v<TA, block_q4_0>) {
-               packNormalInt4<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+               packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
             } else {
-               packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+               packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
             }
             packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
             for(int x = 0; x < 8; x++) {
@@ -2570,14 +2051,13 @@ class tinyBLAS_Q0_PPC {
             compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
             compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
         }
-        save_res<4, 4>(ii, jj, 0, fin_res);
-        save_res<4, 4>(ii+4, jj, 4, fin_res);
-        save_res<4, 4>(ii, jj+4, 8, fin_res);
-        save_res<4, 4>(ii+4, jj+4, 12, fin_res);
+        save_res(ii, jj, 0, fin_res);
+        save_res(ii+4, jj, 4, fin_res);
+        save_res(ii, jj+4, 8, fin_res);
+        save_res(ii+4, jj+4, 12, fin_res);
     }
 
-    template<int RM, int RN>
-    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
         int64_t ytiles = (m - m0) / RM;
         int64_t xtiles = (n - n0) / RN;
         int64_t tiles = xtiles * ytiles;
@@ -2606,9 +2086,9 @@ class tinyBLAS_Q0_PPC {
                 __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
                 __builtin_mma_xxsetaccz(&acc_0);
                 if (isAblock_q4) {
-                   packNormalInt4<int8_t, vector signed char>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
+                   packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
                 } else {
-                   packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+                   packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
                 }
                 packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
                 for(int x = 0; x < 8; x+=4) {
@@ -2641,7 +2121,7 @@ class tinyBLAS_Q0_PPC {
                     fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
                 }
             }
-            save_res(ii, jj, 0, fin_res);
+            save_res(ii, jj, 0, fin_res, RM, RN);
         }
     }
 
@@ -2654,7 +2134,7 @@ class tinyBLAS_Q0_PPC {
        } else if constexpr(RM == 8 && RN == 8) {
           KERNEL_8x8(ii,jj);
        } else {
-          static_assert(false, "RN/RM values not supported");
+          assert(false && "RN/RM values not supported");
        }
     }
 
@@ -2676,10 +2156,8 @@ class tinyBLAS_Q0_PPC {
     }
 
     const TA *const A;
-    const TB *const B;
-    TC *C;
-    TA *At;
-    TB *Bt;
+    const block_q8_0 *const B;
+    float *C;
     const int64_t k;
     const int64_t lda;
     const int64_t ldb;
@@ -2688,13 +2166,12 @@ class tinyBLAS_Q0_PPC {
     const int nth;
 };
 
-template <typename TA, typename TB, typename TC>
 class tinyBLAS_PPC {
   public:
     tinyBLAS_PPC(int64_t k,
-                const TA *A, int64_t lda,
-                const TB *B, int64_t ldb,
-                TC *C, int64_t ldc,
+                const float *A, int64_t lda,
+                const float *B, int64_t ldb,
+                float *C, int64_t ldc,
                 int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
@@ -2707,247 +2184,139 @@ class tinyBLAS_PPC {
 
     void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
 
-    template<typename VA>
-    void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
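+    /* Transposes a 4x4 block of floats held in src[0..3] and stores the four columns
+     * contiguously starting at vecOffset. */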
+    inline void vector_permute_store_4(vector float *src, float *vecOffset) {
+       vector float t1, t2, t3, t4, t5, t6, t7, t8;
+           t1 = vec_mergeh(src[0], src[1]);
+           t2 = vec_mergeh(src[2], src[3]);
+           t3 = vec_mergel(src[0], src[1]);
+           t4 = vec_mergel(src[2], src[3]);
+
+           t5 = vec_xxpermdi(t1, t2, 0);
+           t6 = vec_xxpermdi(t1, t2, 3);
+           t7 = vec_xxpermdi(t3, t4, 0);
+           t8 = vec_xxpermdi(t3, t4, 3);
+
+           vec_xst(t5, 0, vecOffset);
+           vec_xst(t6, 0, vecOffset + 4);
+           vec_xst(t7, 0, vecOffset + 8);
+           vec_xst(t8, 0, vecOffset + 12);
+       }
+
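+    /* Transposes an 8x4 block of floats held in src[0..7]; each of the four columns is stored
+     * as two 4-float halves (rows 0-3, then rows 4-7) contiguously starting at vecOffset. */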
+    inline void vector_permute_store_8(vector float *src, float *vecOffset) {
+       vector float t1, t2, t3, t4, t5, t6, t7, t8;
+          t1 = vec_mergeh(src[0], src[1]);
+          t2 = vec_mergeh(src[2], src[3]);
+          t3 = vec_mergeh(src[4], src[5]);
+          t4 = vec_mergeh(src[6], src[7]);
+
+          t5 = vec_xxpermdi(t1, t2, 0);
+          t6 = vec_xxpermdi(t3, t4, 0);
+          t7 = vec_xxpermdi(t1, t2, 3);
+          t8 = vec_xxpermdi(t3, t4, 3);
+
+          vec_xst(t5, 0, vecOffset);
+          vec_xst(t6, 0, vecOffset + 4);
+          vec_xst(t7, 0, vecOffset + 8);
+          vec_xst(t8, 0, vecOffset + 12);
+
+          t1 = vec_mergel(src[0], src[1]);
+          t2 = vec_mergel(src[2], src[3]);
+          t3 = vec_mergel(src[4], src[5]);
+          t4 = vec_mergel(src[6], src[7]);
+
+          t5 = vec_xxpermdi(t1, t2, 0);
+          t6 = vec_xxpermdi(t3, t4, 0);
+          t7 = vec_xxpermdi(t1, t2, 3);
+          t8 = vec_xxpermdi(t3, t4, 3);
+
+          vec_xst(t5, 0, vecOffset + 16);
+          vec_xst(t6, 0, vecOffset + 20);
+          vec_xst(t7, 0, vecOffset + 24);
+          vec_xst(t8, 0, vecOffset + 28);
+    }
+
+ void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) {
         int64_t i, j;
-        TA *aoffset = NULL, *boffset = NULL;
-        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
-        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
-        __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
-        VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
-        VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
-        VA t1, t2, t3, t4, t5, t6, t7, t8;
-        aoffset = const_cast<TA*>(a);
+        float * aoffsets[8];
+        float *aoffset = NULL, *boffset = NULL;
+        __vector_pair arr[8];
+        vector float c[8][2] = {0};
+        vector float c1[8] = {0};
+        vector float c2[8] = {0};
+        aoffset = const_cast<float*>(a);
         boffset = vec;
         j = (rows >> 3);
         if (j > 0) {
 
             do {
-                aoffset1 = aoffset;
-                aoffset2 = aoffset1 + lda;
-                aoffset3 = aoffset2 + lda;
-                aoffset4 = aoffset3 + lda;
-                aoffset5 = aoffset4 + lda;
-                aoffset6 = aoffset5 + lda;
-                aoffset7 = aoffset6 + lda;
-                aoffset8 = aoffset7 + lda;
+                aoffsets[0] = aoffset;
+                for (int it = 1; it< 8; it++)
+                    aoffsets[it] = aoffsets[it-1] + lda;
                 aoffset += 8 * lda;
                 i = (cols >> 3);
                 if (i > 0) {
                     do {
-                        C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
-                        C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
-                        C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
-                        C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
-                        C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
-                        C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
-                        C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
-                        C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
-                        __builtin_vsx_disassemble_pair(c1, &C1);
-                        __builtin_vsx_disassemble_pair(c2, &C2);
-                        __builtin_vsx_disassemble_pair(c3, &C3);
-                        __builtin_vsx_disassemble_pair(c4, &C4);
-                        __builtin_vsx_disassemble_pair(c5, &C5);
-                        __builtin_vsx_disassemble_pair(c6, &C6);
-                        __builtin_vsx_disassemble_pair(c7, &C7);
-                        __builtin_vsx_disassemble_pair(c8, &C8);
-
-                        t1 = vec_mergeh(c1[0], c2[0]);
-                        t2 = vec_mergeh(c3[0], c4[0]);
-                        t3 = vec_mergeh(c5[0], c6[0]);
-                        t4 = vec_mergeh(c7[0], c8[0]);
-                        t5 = vec_xxpermdi(t1, t2, 0);
-                        t6 = vec_xxpermdi(t3, t4, 0);
-                        t7 = vec_xxpermdi(t1, t2, 3);
-                        t8 = vec_xxpermdi(t3, t4, 3);
-                        vec_xst(t5, 0, boffset);
-                        vec_xst(t6, 0, boffset+4);
-                        vec_xst(t7, 0, boffset+8);
-                        vec_xst(t8, 0, boffset+12);
-
-                        t1 = vec_mergel(c1[0], c2[0]);
-                        t2 = vec_mergel(c3[0], c4[0]);
-                        t3 = vec_mergel(c5[0], c6[0]);
-                        t4 = vec_mergel(c7[0], c8[0]);
-                        t5 = vec_xxpermdi(t1, t2, 0);
-                        t6 = vec_xxpermdi(t3, t4, 0);
-                        t7 = vec_xxpermdi(t1, t2, 3);
-                        t8 = vec_xxpermdi(t3, t4, 3);
-                        vec_xst(t5, 0, boffset+16);
-                        vec_xst(t6, 0, boffset+20);
-                        vec_xst(t7, 0, boffset+24);
-                        vec_xst(t8, 0, boffset+28);
-
-                        t1 = vec_mergeh(c1[1], c2[1]);
-                        t2 = vec_mergeh(c3[1], c4[1]);
-                        t3 = vec_mergeh(c5[1], c6[1]);
-                        t4 = vec_mergeh(c7[1], c8[1]);
-                        t5 = vec_xxpermdi(t1, t2, 0);
-                        t6 = vec_xxpermdi(t3, t4, 0);
-                        t7 = vec_xxpermdi(t1, t2, 3);
-                        t8 = vec_xxpermdi(t3, t4, 3);
-                        vec_xst(t5, 0, boffset+32);
-                        vec_xst(t6, 0, boffset+36);
-                        vec_xst(t7, 0, boffset+40);
-                        vec_xst(t8, 0, boffset+44);
-
-                        t1 = vec_mergel(c1[1], c2[1]);
-                        t2 = vec_mergel(c3[1], c4[1]);
-                        t3 = vec_mergel(c5[1], c6[1]);
-                        t4 = vec_mergel(c7[1], c8[1]);
-                        t5 = vec_xxpermdi(t1, t2, 0);
-                        t6 = vec_xxpermdi(t3, t4, 0);
-                        t7 = vec_xxpermdi(t1, t2, 3);
-                        t8 = vec_xxpermdi(t3, t4, 3);
-                        vec_xst(t5, 0, boffset+48);
-                        vec_xst(t6, 0, boffset+52);
-                        vec_xst(t7, 0, boffset+56);
-                        vec_xst(t8, 0, boffset+60);
-
-                        aoffset1 += 8*lda;
-                        aoffset2 += 8*lda;
-                        aoffset3 += 8*lda;
-                        aoffset4 += 8*lda;
+                        for (int it = 0; it< 8; it++) {
+                            arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
+                            __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+                            c1[it] = c[it][0];
+                            c2[it] = c[it][1];
+                        }
+
+                        vector_permute_store_8(c1, boffset);
+                        vector_permute_store_8(c2, boffset+32);
+                        for (int it = 0; it < 4; it++)
+                            aoffsets[it] = aoffsets[it] + 8*lda;
                         boffset += 64;
                         i--;
                     } while(i > 0);
                 }
                 if (cols & 4) {
-                    c1[0] = vec_xl(0, aoffset1);
-                    c2[0] = vec_xl(0, aoffset2);
-                    c3[0] = vec_xl(0, aoffset3);
-                    c4[0] = vec_xl(0, aoffset4);
-                    c5[0] = vec_xl(0, aoffset5);
-                    c6[0] = vec_xl(0, aoffset6);
-                    c7[0] = vec_xl(0, aoffset7);
-                    c8[0] = vec_xl(0, aoffset8);
-
-                    t1 = vec_mergeh(c1[0], c2[0]);
-                    t2 = vec_mergeh(c3[0], c4[0]);
-                    t3 = vec_mergeh(c5[0], c6[0]);
-                    t4 = vec_mergeh(c7[0], c8[0]);
-                    t5 = vec_xxpermdi(t1, t2, 0);
-                    t6 = vec_xxpermdi(t3, t4, 0);
-                    t7 = vec_xxpermdi(t1, t2, 3);
-                    t8 = vec_xxpermdi(t3, t4, 3);
-                    vec_xst(t5, 0, boffset);
-                    vec_xst(t6, 0, boffset+4);
-                    vec_xst(t7, 0, boffset+8);
-                    vec_xst(t8, 0, boffset+12);
-
-                    t1 = vec_mergel(c1[0], c2[0]);
-                    t2 = vec_mergel(c3[0], c4[0]);
-                    t3 = vec_mergel(c5[0], c6[0]);
-                    t4 = vec_mergel(c7[0], c8[0]);
-                    t5 = vec_xxpermdi(t1, t2, 0);
-                    t6 = vec_xxpermdi(t3, t4, 0);
-                    t7 = vec_xxpermdi(t1, t2, 3);
-                    t8 = vec_xxpermdi(t3, t4, 3);
-                    vec_xst(t5, 0, boffset+16);
-                    vec_xst(t6, 0, boffset+20);
-                    vec_xst(t7, 0, boffset+24);
-                    vec_xst(t8, 0, boffset+28);
+                    for (int it = 0; it < 8; it++)
+                        c1[it] = vec_xl(0, aoffsets[it]);
+                    vector_permute_store_8(c1, boffset);
                 }
             j--;
             } while(j > 0);
         }
 
         if (rows & 4) {
-            aoffset1 = aoffset;
-            aoffset2 = aoffset1 + lda;
-            aoffset3 = aoffset2 + lda;
-            aoffset4 = aoffset3 + lda;
+            aoffsets[0] = aoffset;
+            for (int it = 1; it < 4; it++)
+                aoffsets[it] = aoffsets[it-1] + lda;
             aoffset += 4 * lda;
             i = (cols >> 3);
             if (i > 0) {
                 do {
-                    C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
-                    C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
-                    C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
-                    C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
-                    __builtin_vsx_disassemble_pair(c1, &C1);
-                    __builtin_vsx_disassemble_pair(c2, &C2);
-                    __builtin_vsx_disassemble_pair(c3, &C3);
-                    __builtin_vsx_disassemble_pair(c4, &C4);
-
-                    t1 = vec_mergeh(c1[0], c2[0]);
-                    t2 = vec_mergeh(c3[0], c4[0]);
-                    t3 = vec_mergel(c1[0], c2[0]);
-                    t4 = vec_mergel(c3[0], c4[0]);
-                    t5 = vec_xxpermdi(t1, t2, 0);
-                    t6 = vec_xxpermdi(t1, t2, 3);
-                    t7 = vec_xxpermdi(t3, t4, 0);
-                    t8 = vec_xxpermdi(t3, t4, 3);
-                    vec_xst(t5, 0, boffset);
-                    vec_xst(t6, 0, boffset+4);
-                    vec_xst(t7, 0, boffset+8);
-                    vec_xst(t8, 0, boffset+12);
-
-                    t1 = vec_mergeh(c1[1], c2[1]);
-                    t2 = vec_mergeh(c3[1], c4[1]);
-                    t3 = vec_mergel(c1[1], c2[1]);
-                    t4 = vec_mergel(c3[1], c4[1]);
-                    t5 = vec_xxpermdi(t1, t2, 0);
-                    t6 = vec_xxpermdi(t1, t2, 3);
-                    t7 = vec_xxpermdi(t3, t4, 0);
-                    t8 = vec_xxpermdi(t3, t4, 3);
-                    vec_xst(t5, 0, boffset+16);
-                    vec_xst(t6, 0, boffset+20);
-                    vec_xst(t7, 0, boffset+24);
-                    vec_xst(t8, 0, boffset+28);
-
-                    aoffset1 += 8*lda;
-                    aoffset2 += 8*lda;
-                    aoffset3 += 8*lda;
-                    aoffset4 += 8*lda;
+                    for (int it = 0; it < 4; it++) {
+                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]);
+                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+                        c1[it] = c[it][0];
+                        c2[it] = c[it][1];
+                    }
+                    vector_permute_store_4(c1, boffset);
+                    vector_permute_store_4(c2, boffset+16);
+                    for (int it = 0; it < 4; it++)
+                        aoffsets[it] += 8*lda;
                     boffset += 32;
                     i--;
                 } while(i > 0);
             }
 
             if (cols & 4) {
-                c1[0] = vec_xl(0, aoffset1);
-                c2[0] = vec_xl(0, aoffset2);
-                c3[0] = vec_xl(0, aoffset3);
-                c4[0] = vec_xl(0, aoffset4);
-
-                t1 = vec_mergeh(c1[0], c2[0]);
-                t2 = vec_mergeh(c3[0], c4[0]);
-                t3 = vec_xxpermdi(t1, t2, 0);
-                t4 = vec_xxpermdi(t1, t2, 3);
-                vec_xst(t3, 0, boffset);
-                vec_xst(t4, 0, boffset+4);
-
-                t1 = vec_mergel(c1[0], c2[0]);
-                t2 = vec_mergel(c3[0], c4[0]);
-                t3 = vec_xxpermdi(t1, t2, 0);
-                t4 = vec_xxpermdi(t1, t2, 3);
-                vec_xst(t3, 0, boffset+8);
-                vec_xst(t4, 0, boffset+12);
+                for (int it = 0; it < 4; it++)
+                    c1[it] = vec_xl(0, aoffsets[it]);
+                vector_permute_store_4(c1, boffset);
             }
         }
         if (rows & 3) {
-            aoffset1 = aoffset;
-            aoffset2 = aoffset1 + lda;
-            aoffset3 = aoffset2 + lda;
+            aoffsets[0] = aoffset;
+            for (int it = 1; it < 3; it++)
+                aoffsets[it] = aoffsets[it-1] + lda;
             if (cols & 4) {
-                c1[0] = vec_xl(0, aoffset1);
-                c2[0] = vec_xl(0, aoffset2);
-                c3[0] = vec_xl(0, aoffset3);
-
-                t1 = vec_mergeh(c1[0], c2[0]);
-                t2 = vec_mergeh(c3[0], c4[0]);
-                t3 = vec_xxpermdi(t1, t2, 0);
-                t4 = vec_xxpermdi(t1, t2, 3);
-                vec_xst(t3, 0, boffset);
-                vec_xst(t4, 0, boffset+4);
-
-                t1 = vec_mergel(c1[0], c2[0]);
-                t2 = vec_mergel(c3[0], c4[0]);
-                t3 = vec_xxpermdi(t1, t2, 0);
-                t4 = vec_xxpermdi(t1, t2, 3);
-                vec_xst(t3, 0, boffset+8);
-                vec_xst(t4, 0, boffset+12);
+                for (int it = 0; it < 3; it++)
+                    c1[it] = vec_xl(0, aoffsets[it]);
+                vector_permute_store_4(c1, boffset);
             }
         }
     }
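The hand-unrolled vec_mergeh/vec_mergel/vec_xxpermdi sequences removed above are folded into the vector_permute_store_4 and vector_permute_store_8 helpers used by the new code. As a rough reference for what one such call computes, here is a portable scalar sketch of the 4x4 transpose-and-pack (names and layout are illustrative, not the VSX implementation):

```cpp
// Scalar sketch: pack four 4-float rows into the destination buffer
// column-major, i.e. a 4x4 transpose. The 8-row helper applies the same
// idea to two groups of four rows.
static void permute_store_4_sketch(const float src[4][4], float * dst) {
    for (int col = 0; col < 4; ++col) {
        for (int row = 0; row < 4; ++row) {
            dst[col * 4 + row] = src[row][col];   // output row = input column
        }
    }
}
```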
@@ -2957,8 +2326,8 @@ class tinyBLAS_PPC {
         acc_t acc_0;
         __builtin_mma_xxsetaccz(&acc_0);
         for (int l = 0; l < k; l+=4) {
-            packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
-            packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
+            packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+            packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
@@ -2973,8 +2342,8 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_0);
         __builtin_mma_xxsetaccz(&acc_1);
         for (int64_t l = 0; l < k; l+=4) {
-            packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
-            packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B);
+            packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
+            packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -2994,8 +2363,8 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_0);
         __builtin_mma_xxsetaccz(&acc_1);
         for (int64_t l = 0; l < k; l+=4) {
-            packTranspose(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A);
-            packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
+            packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
+            packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -3017,8 +2386,8 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_2);
         __builtin_mma_xxsetaccz(&acc_3);
         for (int l = 0; l < k; l+=8) {
-            packTranspose(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A);
-            packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B);
+            packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
+            packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
             for(int x = 0; x < 16; x+=2) {
                 __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
                 __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
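These kernels keep the same structure: pack panels of A and B (now plain float after dropping the TA/TB/TC template parameters) and feed them to __builtin_mma_xvf32gerpp. A scalar sketch of what a single ger step contributes, assuming 4-float inputs:

```cpp
// One xvf32gerpp step is a rank-1 update: the outer product of a 4-float
// slice of A with a 4-float slice of B, accumulated into a 4x4 tile.
static void f32_ger_accumulate_sketch(float acc[4][4], const float a[4], const float b[4]) {
    for (int r = 0; r < 4; ++r) {
        for (int c = 0; c < 4; ++c) {
            acc[r][c] += a[r] * b[c];
        }
    }
}
```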
@@ -3033,155 +2402,37 @@ class tinyBLAS_PPC {
     }
 
     void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
-        int64_t mc, nc, mp, np;
-        int m_rem = MIN(m - m0, 16);
-        int n_rem = MIN(n - n0, 16);
-        if (m_rem >= 16 && n_rem >= 8) {
-            mc = 8;
-            nc = 8;
-            gemm<8,8>(m0, m, n0, n);
-        } else if(m_rem >= 8 && n_rem >= 16) {
-            mc = 8;
-            nc = 8;
-            gemm<8,8>(m0, m, n0, n);
-        } else if (m_rem >= 8 && n_rem >= 8) {
-            mc = 8;
-            nc = 8;
-            gemm<8,8>(m0, m, n0, n);
+        int m_rem = MIN(m - m0, 8);
+        int n_rem = MIN(n - n0, 8);
+        int mc = 0, nc = 0;
+        if (m_rem >= 8 && n_rem >= 8) {
+            mc = 8;
+            nc = 8;
+            gemm<8, 8>(m0, m, n0, n);
         } else if (m_rem >= 4 && n_rem >= 8) {
-            mc = 4;
-            nc = 8;
-            gemm<4,8>(m0, m, n0, n);
+            mc = 4;
+            nc = 8;
+            gemm<4, 8>(m0, m, n0, n);
         } else if (m_rem >= 8 && n_rem >= 4) {
-            mc = 8;
-            nc = 4;
-            gemm<8,4>(m0, m, n0, n);
+            mc = 8;
+            nc = 4;
+            gemm<8, 4>(m0, m, n0, n);
         } else if (m_rem >= 4 && n_rem >= 4) {
-            mc = 4;
-            nc = 4;
-            gemm<4,4>(m0, m, n0, n);
-        } else if ((m_rem < 4) && (n_rem > 4)) {
-            nc = 4;
-            switch(m_rem) {
-                case 1:
-                    mc = 1;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 2:
-                    mc = 2;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 3:
-                    mc = 3;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                default:
-                    return;
-            }
-        } else if ((m_rem > 4) && (n_rem < 4)) {
-            mc = 4;
-            switch(n_rem) {
-                case 1:
-                    nc = 1;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 2:
-                    nc = 2;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 3:
-                    nc = 3;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                default:
-                    return;
-            }
+            mc = 4;
+            nc = 4;
+            gemm<4, 4>(m0, m, n0, n);
         } else {
-            switch((m_rem << 4) | n_rem) {
-                case 0x43:
-                    mc = 4;
-                    nc = 3;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x42:
-                    mc = 4;
-                    nc = 2;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x41:
-                    mc = 4;
-                    nc = 1;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x34:
-                    mc = 3;
-                    nc = 4;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x33:
-                    mc = 3;
-                    nc = 3;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x32:
-                    mc = 3;
-                    nc = 2;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x31:
-                    mc = 3;
-                    nc = 1;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x24:
-                    mc = 2;
-                    nc = 4;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x23:
-                    mc = 2;
-                    nc = 3;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x22:
-                    mc = 2;
-                    nc = 2;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x21:
-                    mc = 2;
-                    nc = 1;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x14:
-                    mc = 1;
-                    nc = 4;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x13:
-                    mc = 1;
-                    nc = 3;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x12:
-                    mc = 1;
-                    nc = 2;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                case 0x11:
-                    mc = 1;
-                    nc = 1;
-                    gemm_small(m0, m, n0, n, mc, nc);
-                    break;
-                default:
-                    return;
-            }
+            mc = (m_rem >= 4) ? 4 : m_rem;
+            nc = (n_rem >= 4) ? 4 : n_rem;
+            if (mc == 0 || nc == 0)
+                return;
+            gemm_small(m0, m, n0, n, mc, nc);
         }
-        mp = m0 + (m - m0) / mc * mc;
-        np = n0 + (n - n0) / nc * nc;
+        int64_t mp = m0 + ((m - m0) / mc) * mc;
+        int64_t np = n0 + ((n - n0) / nc) * nc;
         mnpack(mp, m, n0, np);
         mnpack(m0, m, np, n);
-    }
+     }
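The tile dispatch that used to enumerate every (m_rem, n_rem) pair in a switch now just clamps the tile to the largest supported size and falls back to gemm_small for ragged edges, with the recursion covering the leftover strips. A condensed sketch of that control flow (function names here are placeholders, not the real members):

```cpp
#include <algorithm>
#include <cstdint>

// Sketch of the simplified mnpack() dispatch: pick an 8- or 4-wide tile if
// possible, otherwise a small ragged tile, then recurse on the two strips
// that the full tiles did not cover.
static void mnpack_sketch(int64_t m0, int64_t m, int64_t n0, int64_t n) {
    const int m_rem = (int) std::min<int64_t>(m - m0, 8);
    const int n_rem = (int) std::min<int64_t>(n - n0, 8);

    int mc, nc;
    if      (m_rem >= 8 && n_rem >= 8) { mc = 8; nc = 8; }
    else if (m_rem >= 4 && n_rem >= 8) { mc = 4; nc = 8; }
    else if (m_rem >= 8 && n_rem >= 4) { mc = 8; nc = 4; }
    else if (m_rem >= 4 && n_rem >= 4) { mc = 4; nc = 4; }
    else {
        mc = m_rem >= 4 ? 4 : m_rem;   // 1..3 leftover rows/cols
        nc = n_rem >= 4 ? 4 : n_rem;
        if (mc == 0 || nc == 0) return;
    }
    // ... run the mc x nc kernel over the full tiles of [m0,m) x [n0,n) ...

    const int64_t mp = m0 + (m - m0) / mc * mc;   // first row past the full tiles
    const int64_t np = n0 + (n - n0) / nc * nc;   // first column past the full tiles
    mnpack_sketch(mp, m, n0, np);                 // bottom strip
    mnpack_sketch(m0, m, np, n);                  // right strip
}
```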
 
      void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
         int64_t ytiles = (m - m0) / RM;
@@ -3206,22 +2457,22 @@ class tinyBLAS_PPC {
                  * matrix elements.
                  */
                 if (RM == 1) {
-                    TA* a = const_cast<TA*>(A+(ii)*lda+l);
-                    packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
+                    float* a = const_cast<float*>(A+(ii)*lda+l);
+                    packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
                     vec_A[0] = (vec_t)vec_xl(0,a);
-                    vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
-                    vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
-                    vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
+                    vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
+                    vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
+                    vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
                 } else if (RN == 1) {
-                    packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
-                    TB* b = const_cast<TB*>(B+(jj)*ldb+l);
+                    packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
+                    float* b = const_cast<float*>(B+(jj)*ldb+l);
                     vec_B[0] = (vec_t)vec_xl(0,b);
-                    vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
-                    vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
-                    vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
+                    vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1));
+                    vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2));
+                    vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3));
                 } else {
-                    packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
-                    packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
+                    packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
+                    packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
                 }
                 __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
                 __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -3231,7 +2482,7 @@ class tinyBLAS_PPC {
             __builtin_mma_disassemble_acc(vec_C, &acc_0);
             for (int I = 0; I < RM; I++) {
                 for (int J = 0; J < RN; J++) {
-                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
+                    *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
                 }
             }
        }
@@ -3263,11 +2514,9 @@ class tinyBLAS_PPC {
         }
     }
 
-    const TA *const A;
-    const TB *const B;
-    TC *C;
-    TA *At;
-    TB *Bt;
+    const float *const A;
+    const float *const B;
+    float *C;
     const int64_t k;
     const int64_t lda;
     const int64_t ldb;
@@ -3366,7 +2615,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
 #elif defined(__MMA__)
         if (k % 8)
             return false;
-        tinyBLAS_PPC tb{
+        tinyBLAS_PPC tb{
             k, (const float *)A, lda,
             (const float *)B, ldb,
             (float *)C, ldc,
@@ -3493,7 +2742,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
            return false;
         if (m < 8 && m != 4)
            return false;
-        tinyBLAS_Q0_PPC tb{
+        tinyBLAS_Q0_PPC tb{
             k, (const block_q8_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
@@ -3530,7 +2779,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
            return false;
         if (m < 8 && m != 4)
            return false;
-        tinyBLAS_Q0_PPC tb{
+        tinyBLAS_Q0_PPC tb{
             k, (const block_q4_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index fd77e9a6aba..b72a2556a5f 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -8,6 +8,7 @@
 #include "vec.h"
 
 #include 
+#include <algorithm>
 
 // ggml_compute_forward_dup
 
@@ -1283,6 +1284,7 @@ void ggml_compute_forward_add(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1309,6 +1311,77 @@ void ggml_compute_forward_add(
     }
 }
 
+// ggml_compute_forward_add_id
+
+static void ggml_compute_forward_add_id_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        // src1 indices
+        const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
+
+        GGML_ASSERT(i11 >= 0 && i11 < ne11);
+
+        ggml_vec_add_f32(ne0,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+                (float *) ((char *) src1->data + i11*nb11));
+    }
+}
+
+void ggml_compute_forward_add_id(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add_id_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
+            }
+    }
+}
+
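The new ADD_ID op adds one row of src1, chosen per output row by the int32 indices in src2, to the corresponding row of src0 (used for per-expert biases). A minimal per-row sketch of the F32 path above, assuming contiguous rows (names are illustrative):

```cpp
#include <cstddef>
#include <cstdint>

// dst_row = src0_row + src1[ids[ir]] for one output row ir; ne0 is the row
// length and src1 holds the candidate rows back to back.
static void add_id_row_sketch(float * dst, const float * src0_row,
                              const float * src1, const int32_t * ids,
                              size_t ne0, size_t ir) {
    const float * src1_row = src1 + (size_t) ids[ir] * ne0;
    for (size_t i = 0; i < ne0; ++i) {
        dst[i] = src0_row[i] + src1_row[i];
    }
}
```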
 // ggml_compute_forward_add1
 
 static void ggml_compute_forward_add1_f32(
@@ -1660,6 +1733,7 @@ void ggml_compute_forward_add1(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1787,6 +1861,7 @@ void ggml_compute_forward_acc(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -3614,6 +3689,93 @@ static void ggml_compute_forward_swiglu(
     }
 }
 
+// ggml_compute_forward_swiglu_oai
+
+static void ggml_compute_forward_swiglu_oai_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+    const float alpha = ggml_get_op_params_f32(dst, 2);
+    const float limit = ggml_get_op_params_f32(dst, 3);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+        float * dst_p  = (float *) ((char *) dst->data + i1*(dst->nb[1]));
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        for (int k = 0; k < nc; k++) {
+            const float x = std::min(src0_p[k], limit);
+            const float y = std::clamp(src1_p[k], -limit, limit);
+            const float out_glu = x / (1.f + expf(alpha * (-x)));
+            dst_p[k] = out_glu * (y + 1.f);
+        }
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = dst_p[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_swiglu_oai(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_swiglu_oai_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
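For reference, the element-wise form of the swiglu_oai activation implemented above, with the gate half clamped from above and the linear half clamped to [-limit, limit] before the gated product (a standalone sketch, not the kernel itself):

```cpp
#include <algorithm>
#include <cmath>

// out = (x * sigmoid(alpha * x)) * (y + 1), with x capped at `limit` and
// y clamped to [-limit, limit]; alpha and limit come from the op params.
static inline float swiglu_oai_elem(float x, float y, float alpha, float limit) {
    x = std::min(x, limit);
    y = std::clamp(y, -limit, limit);
    const float glu = x / (1.0f + std::exp(-alpha * x));
    return glu * (y + 1.0f);
}
```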
 // ggml_compute_forward_geglu_erf
 
 static void ggml_compute_forward_geglu_erf_f32(
@@ -4015,6 +4177,9 @@ static void ggml_compute_forward_rms_norm_f32(
 
                 const float scale = 1.0f/sqrtf(mean + eps);
 
+                // if you hit this, likely you got an inf somewhere earlier
+                assert(scale > 0.0f);
+
                 ggml_vec_scale_f32(ne00, y, scale);
             }
         }
@@ -4596,6 +4761,7 @@ void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -4870,6 +5036,7 @@ void ggml_compute_forward_set(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -5131,6 +5298,7 @@ void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -5520,6 +5688,7 @@ static void ggml_compute_forward_soft_max_f32(
 
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
 
     assert(ggml_is_contiguous(dst));
     assert(ggml_are_same_shape(src0, dst));
@@ -5554,6 +5723,9 @@ static void ggml_compute_forward_soft_max_f32(
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
+    // sinks
+    const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
+
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
@@ -5596,9 +5768,18 @@ static void ggml_compute_forward_soft_max_f32(
                 float max = -INFINITY;
                 ggml_vec_max_f32(ne00, &max, wp);
 
+                // if we have sinks, make a correction as if they were included in the softmax
+                if (sk) {
+                    max = MAX(max, sk[i02]);
+                }
+
                 ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
                 assert(sum > 0.0);
 
+                if (sk) {
+                    sum += (ggml_float) expf(sk[i02] - max);
+                }
+
                 sum = 1.0/sum;
                 ggml_vec_scale_f32(ne00, dp, sum);
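The sink is an extra per-head logit that competes for probability mass but produces no output column: it only enters the running max and the normalizer. A small standalone sketch of the corrected softmax (illustrative, not the kernel code):

```cpp
#include <algorithm>
#include <cmath>

// Softmax over n logits with one sink logit s: s joins the max and the
// normalizer, so the retained probabilities sum to less than 1.
static void softmax_with_sink_sketch(float * p, const float * logits, int n, float s) {
    float max = s;
    for (int i = 0; i < n; ++i) max = std::max(max, logits[i]);

    float sum = std::exp(s - max);          // the sink's share of the mass
    for (int i = 0; i < n; ++i) {
        p[i] = std::exp(logits[i] - max);
        sum += p[i];
    }
    const float inv = 1.0f / sum;
    for (int i = 0; i < n; ++i) p[i] *= inv;
}
```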
 
@@ -5833,6 +6014,7 @@ void ggml_compute_forward_clamp(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -7986,12 +8168,14 @@ void ggml_compute_forward_argsort(
 
 static void ggml_compute_forward_flash_attn_ext_f16(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
 
+    const ggml_tensor * q     = dst->src[0];
+    const ggml_tensor * k     = dst->src[1];
+    const ggml_tensor * v     = dst->src[2];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
+
     GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
     GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
     GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
@@ -8186,6 +8370,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             }
         }
 
+        // sinks
+        if (sinks) {
+            const float s = ((float *)((char *) sinks->data))[h];
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                ms = expf(M - s);
+                ggml_vec_scale_f32(DV, VKQ32, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            S = S*ms + vs;
+        }
+
+
         // V /= S
         const float S_inv = 1.0f/S;
         ggml_vec_scale_f32(DV, VKQ32, S_inv);
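In the flash-attention kernel the same correction is applied to the streaming-softmax state at the end of the key loop: the sink logit is merged like one more score that carries no value vector. A sketch of that merge, with M the running max, S the running normalizer and acc the unnormalized output accumulator (names are illustrative):

```cpp
#include <cmath>

// Merge a sink logit s into the streaming softmax state (M, S, acc[0..dv)).
// If s exceeds the running max, the accumulated output is rescaled; either
// way the sink only adds exp(.) to the normalizer, never to the output.
static void merge_sink_sketch(float s, float & M, float & S, float * acc, int dv) {
    float ms = 1.0f;   // rescale factor for acc
    float vs = 1.0f;   // sink's contribution to S
    if (s > M) {
        ms = std::exp(M - s);
        for (int i = 0; i < dv; ++i) acc[i] *= ms;
        M = s;         // not strictly needed after the last merge
    } else {
        vs = std::exp(s - M);
    }
    S = S * ms + vs;
}
```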
@@ -8205,17 +8406,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
 void ggml_compute_forward_flash_attn_ext(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
     switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {
                 // uses F32 accumulators
-                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+                ggml_compute_forward_flash_attn_ext_f16(params, dst);
             } break;
         default:
             {
@@ -9077,6 +9274,10 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
+        case GGML_GLU_OP_SWIGLU_OAI:
+            {
+                ggml_compute_forward_swiglu_oai(params, dst);
+            } break;
         case GGML_GLU_OP_GEGLU_ERF:
             {
                 ggml_compute_forward_geglu_erf(params, dst);
@@ -10129,6 +10330,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
     const float alpha  = adamw_params_ptr[0];
     const float beta1  = adamw_params_ptr[1];
     const float beta2  = adamw_params_ptr[2];
@@ -10136,7 +10338,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const float wd     = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
-
+    const float keep   = 1.f - alpha * wd;
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -10159,7 +10361,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+            w[i00] = w[i00] * keep - alpha * mh / vh;
         }
     }
 }
@@ -10181,3 +10383,63 @@ void ggml_compute_forward_opt_step_adamw(
             }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0       = dst->src[0];
+    const ggml_tensor * src0_grad  = dst->src[1];
+    const ggml_tensor * sgd_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(sgd_params) == 2);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * sgd_params_ptr   = ggml_get_data_f32(sgd_params);
+    const float   alpha            = sgd_params_ptr[0];
+    const float   keep             = 1.f - alpha * sgd_params_ptr[1];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float *       w = (float *) ((char *) src0->data + offset);                   // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset);  // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}
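The new SGD optimizer step mirrors the AdamW path above but only needs the learning rate and the decoupled weight decay, the two entries of its params tensor. Per element it reduces to the following (a sketch of the update rule, not the kernel):

```cpp
// w <- w * (1 - alpha * wd) - alpha * g, with the keep factor hoisted out
// of the loop exactly as in the kernel above.
static void sgd_step_sketch(float * w, const float * g, int n, float alpha, float wd) {
    const float keep = 1.0f - alpha * wd;
    for (int i = 0; i < n; ++i) {
        w[i] = w[i] * keep - alpha * g[i];
    }
}
```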
diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
index 3a32ec20dba..82ea79eaa51 100644
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -29,6 +29,7 @@ extern "C" {
 
 void ggml_compute_forward_dup(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_add_id(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -82,13 +83,7 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru
 void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_flash_attn_ext(
-    const struct ggml_compute_params * params,
-    const struct ggml_tensor * q,
-    const struct ggml_tensor * k,
-    const struct ggml_tensor * v,
-    const struct ggml_tensor * mask,
-    struct ggml_tensor * dst);
+void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_flash_attn_back(
         const struct ggml_compute_params * params,
         const bool masked,
@@ -112,7 +107,7 @@ void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params *
 void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-
+void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 #ifdef __cplusplus
 }
 #endif
diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c
index ee35ab42fda..365cb36d2d7 100644
--- a/ggml/src/ggml-cpu/quants.c
+++ b/ggml/src/ggml-cpu/quants.c
@@ -46,6 +46,10 @@ void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRI
     quantize_row_q8_1_ref(x, y, k);
 }
 
+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
+    quantize_row_mxfp4_ref(x, y, k);
+}
+
 //
 // 2-6 bit quantization in super-blocks
 //
@@ -181,6 +185,37 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
     *s = sumf;
 }
 
+void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j +          0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
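An MXFP4 block packs 32 weights as 4-bit indices into a 16-entry value table plus one shared E8M0 block exponent, so the generic dot product is a table lookup per nibble followed by a single scale. A simplified per-block sketch (the table and the combined scale stand in for kvalues_mxfp4 and the d * GGML_E8M0_TO_FP32_HALF(e) factor above):

```cpp
#include <cstdint>

// One mxfp4 x q8_0 block: low nibbles pair with the first 16 activations,
// high nibbles with the last 16; d_combined = d(q8_0) * scale(E8M0 exponent).
static float dot_mxfp4_q8_0_block_sketch(const uint8_t * qs, const int8_t * y,
                                         const int8_t * lut, float d_combined) {
    int sum = 0;
    for (int j = 0; j < 16; ++j) {
        sum += y[j]      * lut[qs[j] & 0x0F];
        sum += y[j + 16] * lut[qs[j] >> 4];
    }
    return d_combined * (float) sum;
}
```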
 void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h
index dc4342c87f5..d83eb1b144d 100644
--- a/ggml/src/ggml-cpu/quants.h
+++ b/ggml/src/ggml-cpu/quants.h
@@ -19,6 +19,8 @@ void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
 void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 
+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
 void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -39,6 +41,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -67,8 +71,12 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
 void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp
index 72ee93a5abc..f531d21e232 100644
--- a/ggml/src/ggml-cpu/repack.cpp
+++ b/ggml/src/ggml-cpu/repack.cpp
@@ -14,7 +14,6 @@
 #include 
 #include 
 #include 
-#include <cstdlib> // for qsort
 #include   // for GGML_ASSERT
 
 #include "repack.h"
@@ -207,8 +206,9 @@ void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
     const int ncols_interleaved = 4;
     const int blocklen = 4;
 
-    assert (n % qk == 0);
-    assert (nc % ncols_interleaved == 0);
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
 
     UNUSED(s);
     UNUSED(bs);
@@ -308,30 +308,28 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
     UNUSED(ncols_interleaved);
     UNUSED(blocklen);
 
-    {
-        float sumf[8];
-        int sumi;
+    float sumf[8];
+    int sumi;
 
-        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
 
-            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    for (int j = 0; j < ncols_interleaved; j++) {
-                        sumi = 0;
-                        for (int i = 0; i < blocklen; ++i) {
-                            const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
-                            const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
-                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
-                        }
-                        sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
                     }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
                 }
             }
-            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
         }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
     }
 }
 
@@ -413,11 +411,11 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
     }
 }
 
-void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
-    const int qk = QK8_0;
+void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
     const int nb = n / qk;
-    const int ncols_interleaved = 4;
-    const int blocklen = 4;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
 
     assert (n % qk == 0);
     assert (nc % ncols_interleaved == 0);
@@ -432,30 +430,136 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
     UNUSED(ncols_interleaved);
     UNUSED(blocklen);
 
-    {
-        float sumf[4];
-        int sumi;
+    float sumf[8];
+    float sum_minf[8];
+    int sumi1,sumi2,sumi3,sumi4;
+    int sumi;
 
-        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+    const block_q8_K * a_ptr = (const block_q8_K *)vy;
+    for(int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j] = 0.0;
+            sum_minf[j] = 0.0;
+        }
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (4 * blocklen)); k++) {
+                const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ;
+                const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;
+                const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;
+                const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi1 = 0;
+                    sumi2 = 0;
+                    sumi3 = 0;
+                    sumi4 = 0;
+                    sumi = 0;
+                    int offset = ((k / 2) % 2) + j * 2;
+                    for (int i = 0; i < blocklen; ++i){
+                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);
+                        const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3);
+                        const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3);
+                        const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3);
+                        sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]);
+                        sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]);
+                        sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 64]);
+                        sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 96]);
+
+                        sumi1 = sumi1 * (scales_0[offset] & 0xF);
+                        sumi2 = sumi2 * (scales_1[offset] & 0xF);
+                        sumi3 = sumi3 * (scales_2[offset] & 0xF);
+                        sumi4 = sumi4 * (scales_3[offset] & 0xF);
+                        sumi += sumi1 + sumi2 + sumi3 + sumi4;
+                    }
+                    sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
+                }
+            }
+            for(int sb = 0; sb < 8; sb++) {
+                const uint8_t *mins = b_ptr[l].scales + sb * 16;
+                for(int j = 0; j < ncols_interleaved; j++){
+                    sum_minf[j] += ((mins[j * 2] >> 4) * a_ptr[l].bsums[sb * 2] + (mins[(j * 2)+ 1] >> 4) * a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
+        }
+    }
+}
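The q2_K kernels above pull four 2-bit quants out of every packed byte, one for each of four 32-element sub-blocks, before applying the 4-bit sub-block scales. The unpack itself is just:

```cpp
#include <cstdint>

// Each q2_K byte holds four 2-bit values belonging to four different
// sub-blocks of the 256-element super-block.
static inline void unpack_q2_sketch(uint8_t b, int v[4]) {
    v[0] =  b       & 3;
    v[1] = (b >> 2) & 3;
    v[2] = (b >> 4) & 3;
    v[3] = (b >> 6) & 3;
}
```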
 
-            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
-                    for (int j = 0; j < ncols_interleaved; j++) {
-                        sumi = 0;
-                        for (int i = 0; i < blocklen; ++i) {
-                            const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
-                            const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
-                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
-                        }
-                        sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[4];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                        const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
                     }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
                 }
             }
-            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
         }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+}
+
+void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[8];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                        const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                    }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
     }
 }
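The new 8x8 iq4_nl paths reuse the same inner step as the 4x4 ones: each packed byte contributes two codebook lookups (kvalues_iq4nl in the real code) multiplied against activations from the two halves of the q8_0 block. As a one-step sketch:

```cpp
#include <cstdint>

// Low nibble pairs with an activation from the first half of the block,
// high nibble with one from the second half.
static inline int iq4_nl_dot_byte_sketch(uint8_t packed, int8_t y_lo, int8_t y_hi,
                                         const int8_t * codebook /* 16 entries */) {
    return codebook[packed & 0x0F] * (int) y_lo + codebook[packed >> 4] * (int) y_hi;
}
```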
 
@@ -712,6 +816,97 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
     }
 }
 
+void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+    float sumf[4][8];
+    float sum_minf[4][8];
+    int sumi1, sumi2, sumi3, sumi4;
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumf[m][j] = 0.0;
+                    sum_minf[m][j] = 0.0;
+                }
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (4 * blocklen)); k++) {
+
+                    const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ;
+                    const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;
+                    const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;
+                    const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi1 = 0;
+                            sumi2 = 0;
+                            sumi3 = 0;
+                            sumi4 = 0;
+                            sumi = 0;
+                            int offset = ((k / 2) % 2) + j * 2;
+                            for (int i = 0; i < blocklen; ++i){
+                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);
+                                const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3);
+                                const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3);
+                                const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3);
+                                sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i]);
+                                sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 512  + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
+                                sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 512  + (k % 4) * 4 * blocklen + m * blocklen + i + 256]);
+                                sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 512  + (k % 4) * 4 * blocklen + m * blocklen + i + 384]);
+                                sumi1 = sumi1 * (scales_0[offset] & 0xF);
+                                sumi2 = sumi2 * (scales_1[offset] & 0xF);
+                                sumi3 = sumi3 * (scales_2[offset] & 0xF);
+                                sumi4 = sumi4 * (scales_3[offset] & 0xF);
+                                sumi += sumi1 + sumi2 + sumi3 + sumi4;
+                            }
+                            sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
+                        }
+                    }
+                }
+                for(int sb = 0; sb < 8; sb++) {
+                    const uint8_t *mins = b_ptr[l].scales + sb * 16;
+                    for(int m = 0; m < 4; m++) {
+                        const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) *  6);
+                        for(int j = 0; j < ncols_interleaved; j++) {
+                            int mins_prod = ((mins[j * 2] >> 4) * bsums[0] + (mins[(j * 2)+ 1] >> 4) * bsums[1]);
+                            sum_minf[m][j] += (mins_prod) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
+                        }
+                    }
+                }
+            }
+
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
+                }
+            }
+        }
+    }
+}
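
A minimal standalone sketch of the 2-bit unpacking that the generic Q2_K GEMM above performs per quant byte: each byte of b_ptr[l].qs carries four 2-bit values (v0..v3), and each product is weighted by a 4-bit scale nibble. The byte and scale values below are made-up examples, not data from the patch.

    // Unpack one interleaved Q2_K quant byte into its four 2-bit values,
    // mirroring the v0..v3 extraction in the kernel above.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint8_t byte = 0xB4;          // example packed byte: 0b10'11'01'00
        const int v0 =  byte       & 3;     // bits 0-1
        const int v1 = (byte >> 2) & 3;     // bits 2-3
        const int v2 = (byte >> 4) & 3;     // bits 4-5
        const int v3 = (byte >> 6) & 3;     // bits 6-7
        const uint8_t scale_byte = 0x5A;    // example: low nibble = scale, high nibble = min
        printf("v0=%d v1=%d v2=%d v3=%d scale=%d min=%d\n",
               v0, v1, v2, v3, scale_byte & 0xF, scale_byte >> 4);
        return 0;
    }
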
+
+
 void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -768,6 +963,50 @@ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
     }
 }
 
+void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    float sumf[4][8];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                                const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
+                            }
+                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+}
+
 } // extern "C"
 
 static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
@@ -915,6 +1154,50 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in
     return out;
 }
 
+static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_interleave) {
+    block_q2_Kx8 out;
+
+    // Delta(scale) and dmin values of the eight Q2_K structures are copied onto the output interleaved structure
+    for (int i = 0; i < 8; i++) {
+        out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
+    }
+
+    for (int i = 0; i < 8; i++) {
+        out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
+    }
+
+    const int end = QK_K * 2 / blck_size_interleave;
+
+    // Interleave Q2_K quants by taking 8 bytes at a time
+    for (int i = 0; i < end; ++i) {
+        int src_id = i % 8;
+        int src_offset = (i / 8) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
+
+        uint64_t elems;
+        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
+        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
+    }
+
+    // The logic below unpacks and rearranges the scales and mins of the Q2_K blocks.
+    // Each Q2_K structure packs 16 scales and 16 mins into 16 bytes (4 bits per value).
+    // The output Q2_Kx8 structure reserves 128 bytes for scales and mins.
+    // Every 16-byte group holds the scales and mins of the corresponding sub-blocks from the eight Q2_K structures.
+    // For example, the first 16 bytes hold the scales and mins of the first and second sub-blocks of each Q2_K structure.
+
+    for(int i = 0; i < 128; i++){
+
+        // Index for selecting which q2k super block
+        int src1 = (i % 16) / 2;
+        // Index for selecting scale
+        int src2 = ((i / 16) * 2) + (i % 2);
+
+        out.scales[i] = in[src1].scales[src2];
+    }
+    return out;
+
+}
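
The index arithmetic in the scales/mins loop above is easiest to see by printing the mapping it produces. A small host-side sketch that reuses the same src1/src2 formulas (the loop bound is shortened to the first 32 output bytes):

    #include <cstdio>

    int main() {
        for (int i = 0; i < 32; i++) {
            const int src1 = (i % 16) / 2;              // which of the 8 Q2_K super blocks
            const int src2 = ((i / 16) * 2) + (i % 2);  // which scale/min byte inside it
            printf("out.scales[%2d] <- in[%d].scales[%d]\n", i, src1, src2);
        }
        return 0;
    }
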
+
 static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
     GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
     GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
@@ -976,6 +1259,37 @@ static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block
     GGML_UNUSED(data_size);
 }
 
+static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q2_K);
+    GGML_ASSERT(interleave_block == 8);
+    constexpr int nrows_interleaved = 8;
+
+    block_q2_Kx8 * dst = (block_q2_Kx8*)t->data;
+    const block_q2_K * src = (const block_q2_K*) data;
+    block_q2_K dst_tmp[8];
+    int nrow = ggml_nrows(t);
+    int nblocks = t->ne[0] / QK_K;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q2_Kx8(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
 static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
     GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
     GGML_ASSERT(interleave_block == 8);
@@ -1044,15 +1358,16 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s
 
 static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
     GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
-    //GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
     GGML_ASSERT(interleave_block == 4);
 
-    block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
-    const block_iq4_nl * src = (const block_iq4_nl *)data;
+    const block_iq4_nl   * src = (const block_iq4_nl   *)data;
+          block_iq4_nlx4 * dst = (      block_iq4_nlx4 *)t->data;
+
     block_iq4_nl dst_tmp[4];
+
     int nrow = ggml_nrows(t);
     int nrows_interleaved = 4;
-    int nblocks = t->ne[0] / QK4_0;
+    int nblocks = t->ne[0] / QK4_NL;
 
     GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
 
@@ -1074,6 +1389,63 @@ static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_b
     GGML_UNUSED(data_size);
 }
 
+static block_iq4_nlx8 make_block_iq4_nlx8(block_iq4_nl * in, unsigned int blck_size_interleave) {
+    block_iq4_nlx8 out;
+
+    for (int i = 0; i < 8; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end = QK4_NL * 4 / blck_size_interleave;
+
+    if (blck_size_interleave == 8) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 8;
+            int src_offset = (i / 8) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
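
The qs interleave above weaves the eight source rows together in 8-byte chunks. A short sketch of the resulting byte mapping, assuming QK4_NL = 32 as defined in ggml-common.h (so each source block contributes 16 quant bytes):

    #include <cstdio>

    int main() {
        const int QK4_NL = 32;                  // each block stores QK4_NL/2 = 16 qs bytes
        const int blck_size_interleave = 8;
        const int end = QK4_NL * 4 / blck_size_interleave;
        for (int i = 0; i < end; ++i) {
            const int src_id     = i % 8;                         // source block (row)
            const int src_offset = (i / 8) * blck_size_interleave; // first or second 8-byte half
            const int dst_offset = i * blck_size_interleave;
            printf("out.qs[%3d..%3d] <- in[%d].qs[%2d..%2d]\n",
                   dst_offset, dst_offset + 7, src_id, src_offset, src_offset + 7);
        }
        return 0;
    }
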
+
+static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
+    GGML_ASSERT(interleave_block == 8);
+
+    const block_iq4_nl   * src = (const block_iq4_nl   *)data;
+          block_iq4_nlx8 * dst = (      block_iq4_nlx8 *)t->data;
+
+    block_iq4_nl dst_tmp[8];
+
+    int nrow = ggml_nrows(t);
+    int nrows_interleaved = 8;
+    int nblocks = t->ne[0] / QK4_NL;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
+
+    if (t->ne[1] % nrows_interleaved != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_iq4_nlx8(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
 namespace ggml::cpu::repack {
 // repack
 template 
@@ -1096,6 +1468,10 @@ template <> int repack(struct ggml_tensor * t, const void * da
     return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
 }
 
+template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
+}
+
 template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) {
     return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
 }
@@ -1105,6 +1481,10 @@ template <> int repack(struct ggml_tensor * t, const void *
 //    return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
 //}
 
+template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
+}
+
 // gemv
 template 
 void gemv(int, float *, size_t, const void *, const void *, int, int);
@@ -1125,10 +1505,18 @@ template <> void gemv(int n, float * s, size_t
     ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 // gemm
 template 
 void gemm(int, float *, size_t, const void *, const void *, int, int);
@@ -1149,10 +1537,18 @@ template <> void gemm(int n, float * s, size_t
     ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 class tensor_traits_base : public ggml::cpu::tensor_traits {
   public:
     virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
@@ -1422,8 +1818,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
     static const ggml::cpu::repack::tensor_traits q4_0_8x8_q8_0;
     static const ggml::cpu::repack::tensor_traits q4_K_8x8_q8_K;
 
+    // instance for Q2
+    static const ggml::cpu::repack::tensor_traits q2_K_8x8_q8_K;
+
     // instance for IQ4
     static const ggml::cpu::repack::tensor_traits iq4_nl_4x4_q8_0;
+    static const ggml::cpu::repack::tensor_traits iq4_nl_8x8_q8_0;
 
     if (cur->type == GGML_TYPE_Q4_0) {
         if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
@@ -1447,7 +1847,18 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                 return &q4_K_8x8_q8_K;
             }
         }
+    } else if (cur->type == GGML_TYPE_Q2_K) {
+        if (ggml_cpu_has_avx512()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &q2_K_8x8_q8_K;
+            }
+        }
     } else if (cur->type == GGML_TYPE_IQ4_NL) {
+        if (ggml_cpu_has_avx2()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &iq4_nl_8x8_q8_0;
+            }
+        }
         if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
             if (cur->ne[1] % 4 == 0) {
                 return &iq4_nl_4x4_q8_0;
diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h
index 4421e5f8e70..cb32b503d3a 100644
--- a/ggml/src/ggml-cpu/repack.h
+++ b/ggml/src/ggml-cpu/repack.h
@@ -44,7 +44,14 @@ struct block_q4_Kx8 {
 };
 
 static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
+struct block_q2_Kx8 {
+    ggml_half d[8];      // super-block scale for quantized scales
+    ggml_half dmin[8];   // super-block scale for quantized mins
+    uint8_t scales[128];  // scales and mins, quantized with 4 bits
+    uint8_t qs[512];    // 2-bit quants
+};
 
+static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
 struct block_q8_Kx4 {
     float d[4];              // delta
     int8_t qs[QK_K * 4];     // quants
@@ -60,6 +67,13 @@ struct block_iq4_nlx4 {
 
 static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
 
+struct block_iq4_nlx8 {
+    ggml_half d[8];            // deltas for 8 iq4_nl blocks
+    uint8_t   qs[QK4_NL * 4];  // nibbles / quants for 8 iq4_nl blocks
+};
+
+static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding");
+
 #if defined(__cplusplus)
 extern "C" {
 #endif
@@ -71,12 +85,16 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 
 // Native implementations
 void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
@@ -86,12 +104,16 @@ void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
 void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 
 #if defined(__cplusplus)
 } // extern "C"
diff --git a/ggml/src/ggml-cpu/traits.cpp b/ggml/src/ggml-cpu/traits.cpp
index 139fa596414..4f32f10255a 100644
--- a/ggml/src/ggml-cpu/traits.cpp
+++ b/ggml/src/ggml-cpu/traits.cpp
@@ -10,7 +10,7 @@ extra_buffer_type::~extra_buffer_type() {}
 }  // namespace ggml::cpu
 
 bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) {
-    for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
+    for (auto extra : ggml_backend_cpu_get_extra_buffer_types()) {
         if (extra && extra->context) {
             auto buf_extra     = (ggml::cpu::extra_buffer_type *) extra->context;
             auto tensor_traits = buf_extra->get_tensor_traits(op);
@@ -23,7 +23,7 @@ bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct
 }
 
 bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size) {
-    for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
+    for (auto extra : ggml_backend_cpu_get_extra_buffer_types()) {
         if (extra && extra->context) {
             auto buf_extra     = (ggml::cpu::extra_buffer_type *) extra->context;
             auto tensor_traits = buf_extra->get_tensor_traits(op);
diff --git a/ggml/src/ggml-cpu/traits.h b/ggml/src/ggml-cpu/traits.h
index 99a6186b1d6..f4e0990ddfc 100644
--- a/ggml/src/ggml-cpu/traits.h
+++ b/ggml/src/ggml-cpu/traits.h
@@ -33,6 +33,6 @@ class extra_buffer_type {
 }  // namespace ggml::cpu
 
 // implemented in ggml-cpu.cpp.
-std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffers_type();
+std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_types();
 
 #endif
diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp
index a8156011eba..07b377bdd82 100644
--- a/ggml/src/ggml-cpu/vec.cpp
+++ b/ggml/src/ggml-cpu/vec.cpp
@@ -221,6 +221,9 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
     for (int i = np; i < n; ++i) {
         sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
     }
+
+    // if you hit this, you are likely running outside the FP range
+    assert(!isnan(sumf) && !isinf(sumf));
 #else
     for (int i = 0; i < n; ++i) {
         sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h
index d18783a00a1..2250d93cb00 100644
--- a/ggml/src/ggml-cpu/vec.h
+++ b/ggml/src/ggml-cpu/vec.h
@@ -55,7 +55,22 @@ inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x)
 
 inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; }
+
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
+    int i = 0;
+#if defined(__AVX2__)
+    for (; i + 7 < n; i += 8) {
+        __m256 vx = _mm256_loadu_ps(x + i);
+        __m256 vy = _mm256_loadu_ps(y + i);
+        __m256 vz = _mm256_add_ps(vx, vy);
+        _mm256_storeu_ps(z + i, vz);
+    }
+#endif
+    for (; i < n; ++i) {
+        z[i] = x[i] + y[i];
+    }
+}
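
The rewritten ggml_vec_add_f32 follows the usual SIMD pattern: an 8-wide AVX2 main loop plus a scalar tail for the remaining n % 8 elements (and as the full loop when AVX2 is not compiled in). A self-contained sketch of the same pattern with hypothetical array contents:

    #include <cstdio>
    #if defined(__AVX2__)
    #include <immintrin.h>
    #endif

    static void vec_add_f32(int n, float * z, const float * x, const float * y) {
        int i = 0;
    #if defined(__AVX2__)
        for (; i + 7 < n; i += 8) {   // 8 floats per iteration
            _mm256_storeu_ps(z + i, _mm256_add_ps(_mm256_loadu_ps(x + i), _mm256_loadu_ps(y + i)));
        }
    #endif
        for (; i < n; ++i) {
            z[i] = x[i] + y[i];       // scalar tail (and full loop without AVX2)
        }
    }

    int main() {
        float x[11], y[11], z[11];
        for (int i = 0; i < 11; ++i) { x[i] = i; y[i] = 10.0f * i; }
        vec_add_f32(11, z, x, y);     // 11 is deliberately not a multiple of 8
        printf("z[10] = %.1f (expected 110.0)\n", z[10]);
        return 0;
    }
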
+
 inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
     for (int i = 0; i < n; ++i) {
         z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i]));
@@ -992,9 +1007,9 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
 
 inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
     for (int i = 0; i < n; ++i) {
-        float v = GGML_CPU_FP16_TO_FP32(x[i]);
-        float w = GGML_CPU_FP16_TO_FP32(g[i]);
-        y[i] = GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
+        float xi = GGML_CPU_FP16_TO_FP32(x[i]);
+        float gi = GGML_CPU_FP16_TO_FP32(g[i]);
+        y[i] = GGML_CPU_FP32_TO_FP16((xi/(1.0f + expf(-xi))) * gi);
     }
 }
 
diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt
index c9ff4aa321b..bce07ac3628 100644
--- a/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ggml/src/ggml-cuda/CMakeLists.txt
@@ -102,12 +102,12 @@ if (CUDAToolkit_FOUND)
     if (GGML_STATIC)
         if (WIN32)
             # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
-            target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
+            target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
         else ()
-            target_link_libraries(ggml-cuda PRIVATE  CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+            target_link_libraries(ggml-cuda PRIVATE  CUDA::cudart_static CUDA::cublas_static)
         endif()
     else()
-        target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt)
+        target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
     endif()
 
     if (GGML_CUDA_NO_VMM)
@@ -120,6 +120,10 @@ if (CUDAToolkit_FOUND)
 
     set(CUDA_FLAGS -use_fast_math -extended-lambda)
 
+    if (GGML_CUDA_DEBUG)
+        list(APPEND CUDA_FLAGS -lineinfo)
+    endif()
+
     if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
         # Options are:
         # - none (not recommended)
diff --git a/ggml/src/ggml-cuda/add-id.cu b/ggml/src/ggml-cuda/add-id.cu
new file mode 100644
index 00000000000..8bed62ac9d2
--- /dev/null
+++ b/ggml/src/ggml-cuda/add-id.cu
@@ -0,0 +1,58 @@
+#include "add-id.cuh"
+
+static __global__ void add_id_kernel(
+        const float * src0, const float * src1, const int32_t * src2, float * dst,
+        int64_t ne0, int64_t ne1,
+        size_t nb01, size_t nb02,
+        size_t nb11,
+        size_t nb21
+    ) {
+
+    const int64_t i1 = blockIdx.x;
+    const int64_t i2 = blockIdx.y;
+
+    const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21);
+
+    const size_t nb1 = ne0 * sizeof(float);
+    const size_t nb2 = ne1 * nb1;
+
+    float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2);
+    const float * src0_row = (const float *)((char *)src0 +  i1*nb01 + i2*nb02);
+    const float * src1_row = (const float *)((char *)src1 + i11*nb11);
+
+    for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+        dst_row[i0] = src0_row[i0] + src1_row[i0];
+    }
+}
+
+void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb20 == sizeof(int32_t));
+
+    const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+    const int32_t * src2_d = (const int32_t *)src2->data;
+    float * dst_d = (float *)dst->data;
+
+    int threads = std::min((int)ne00, 768); // cols
+    dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
+    add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
+        src0_d, src1_d, src2_d, dst_d,
+        ne0, ne1,
+        nb01, nb02,
+        nb11,
+        nb21
+    );
+}
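
What the new ADD_ID op computes, as a host-side reference: for each token i2 and expert slot i1, the bias row selected by src2[i2][i1] is added to the corresponding row of src0. The shapes and values in this sketch are made up; only the indexing mirrors the kernel above.

    #include <cstdio>

    int main() {
        const int ne0  = 4;  // row width
        const int ne01 = 2;  // experts used per token
        const int ne02 = 3;  // tokens
        const int n_bias = 5;
        float src0[3][2][4] = {};                  // activations (all zero here)
        float src1[5][4];                          // per-expert bias rows
        int   src2[3][2] = {{0,4},{1,3},{2,2}};    // selected expert ids
        float dst[3][2][4];
        for (int e = 0; e < n_bias; ++e)
            for (int c = 0; c < ne0; ++c) src1[e][c] = 100.0f * e + c;

        for (int i2 = 0; i2 < ne02; ++i2)
            for (int i1 = 0; i1 < ne01; ++i1)
                for (int i0 = 0; i0 < ne0; ++i0)
                    dst[i2][i1][i0] = src0[i2][i1][i0] + src1[src2[i2][i1]][i0];

        printf("dst[1][1][2] = %.1f (bias row 3 -> 302.0)\n", dst[1][1][2]);
        return 0;
    }
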
diff --git a/ggml/src/ggml-cuda/add-id.cuh b/ggml/src/ggml-cuda/add-id.cuh
new file mode 100644
index 00000000000..30b1721ac32
--- /dev/null
+++ b/ggml/src/ggml-cuda/add-id.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 1a2708ec9df..2b14b30ac90 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "ggml.h"
+#include "ggml-impl.h"
 #include "ggml-cuda.h"
 
 #include 
@@ -56,7 +57,7 @@
 #define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 0x803)  // Tonga, Fiji, Polaris, minimum for fast fp16
 #define GGML_CUDA_CC_VEGA       (GGML_CUDA_CC_OFFSET_AMD + 0x900)  // Vega56/64, minimum for fp16 dual issue
 #define GGML_CUDA_CC_VEGA20     (GGML_CUDA_CC_OFFSET_AMD + 0x906)  // MI50/Radeon VII, minimum for dp4a
-#define GGML_CUDA_CC_CDNA       (GGML_CUDA_CC_OFFSET_AMD + 0x908)  // MI100, minimum for MFMA, acc registers
+#define GGML_CUDA_CC_CDNA1      (GGML_CUDA_CC_OFFSET_AMD + 0x908)  // MI100, minimum for MFMA, acc registers
 #define GGML_CUDA_CC_CDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x910)  // MI210, minimum acc register renaming
 #define GGML_CUDA_CC_CDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x942)  // MI300
 
@@ -72,8 +73,9 @@
 #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
 #define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
 #define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
-#define GGML_CUDA_CC_IS_GCN(cc)   (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
-#define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
+#define GGML_CUDA_CC_IS_GCN(cc)   (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
+#define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
+#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)
 
 // Moore Threads
 #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
@@ -85,6 +87,10 @@
 #define GGML_CUDA_CC_IS_QY2(cc)      (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
 #define GGML_CUDA_CC_IS_NG(cc)       (cc >= GGML_CUDA_CC_NG)
 
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+#    define GGML_CUDA_USE_CUB
+#endif  // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+
 #ifdef __CUDA_ARCH_LIST__
 constexpr bool ggml_cuda_has_arch_impl(int) {
     return false;
@@ -175,7 +181,7 @@ static const char * cu_get_error_str(CUresult err) {
 #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
 #endif
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 #    define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes)                                                       \
         do {                                                                                                   \
             static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false };                         \
@@ -190,7 +196,7 @@ static const char * cu_get_error_str(CUresult err) {
         do {                                             \
             GGML_UNUSED(nbytes);                         \
         } while (0)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 
 #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
 #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
@@ -210,9 +216,9 @@ typedef float2 dfloat2;
 #define GGML_USE_VMM
 #endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
 
-#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
 #define FP16_AVAILABLE
-#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
 
 #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 #define FAST_FP16_AVAILABLE
@@ -226,13 +232,21 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
-#define NEW_MMA_AVAILABLE
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
+#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
+#define AMD_MFMA_AVAILABLE
+#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
+
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
+#define TURING_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
+
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define AMPERE_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #define CP_ASYNC_AVAILABLE
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
 #if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
 #define FLASH_ATTN_AVAILABLE
@@ -254,7 +268,7 @@ static bool fast_fp16_hardware_available(const int cc) {
 
 // Any FP16 tensor core instructions are available for ggml code.
 static bool fp16_mma_available(const int cc) {
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
+#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
     return false;
 #else
     if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
@@ -270,7 +284,7 @@ static bool fp16_mma_available(const int cc) {
     } else {
         return false;
     }
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
+#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
@@ -288,35 +302,47 @@ static bool fp32_mma_hardware_available(const int cc) {
     return GGML_CUDA_CC_IS_CDNA(cc);
 }
 
+static bool amd_mfma_available(const int cc) {
+#if !defined(GGML_HIP_NO_MMQ_MFMA)
+    return GGML_CUDA_CC_IS_CDNA(cc);
+#else
+    return false;
+#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
+}
+
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
-static bool new_mma_available(const int cc) {
+static bool turing_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
 
+static bool ampere_mma_available(const int cc) {
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+}
+
 static bool cp_async_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
 }
 
 static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
+#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
     return 64;
 #else
     return 32;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
+#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
 }
 
 [[noreturn]]
 static __device__ void no_device_code(
     const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
     printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
            file_name, line, function_name, arch);
     GGML_UNUSED(arch_list);
 #else
     printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
            file_name, line, function_name, arch, arch_list);
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
     __trap();
 
     GGML_UNUSED(no_device_code); // suppress unused function warning
@@ -353,7 +379,7 @@ struct ggml_cuda_unroll<1> {
 
 template
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
     return __reduce_add_sync(0xffffffff, x);
 #else
 #pragma unroll
@@ -361,7 +387,7 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
         x += __shfl_xor_sync(0xffffffff, x, offset, width);
     }
     return x;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 }
 
 template
@@ -398,24 +424,18 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
-// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
-template
-static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockIdx.x;
-    const int col = threadIdx.x;
-
-    float sum = 0.0f;
-    for (int i = col; i < ncols; i += blockDim.x) {
-        sum += x[row * ncols + i];
-    }
-
-    sum = warp_reduce_sum(sum);
-
-    if (col != 0) {
-        return;
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ int warp_reduce_all(int x) {
+#ifdef GGML_USE_HIP
+#pragma unroll
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
     }
-
-    dst[row] = norm ? sum / ncols : sum;
+    return x;
+#else
+    static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
+    return __all_sync(0xffffffff, x);
+#endif // GGML_USE_HIP
 }
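
warp_reduce_all() folds a per-lane predicate into a warp-wide logical AND: on CUDA it maps directly to __all_sync, on HIP it is built from shuffles. A trivial host-side analogue of the result every lane receives, with a made-up lane pattern:

    #include <array>
    #include <cstdio>

    int main() {
        std::array<int, 32> lanes;
        lanes.fill(1);
        lanes[17] = 0;                       // one lane fails the predicate
        int all = 1;
        for (int v : lanes) all = all && v;  // logical AND across the warp
        printf("warp_reduce_all -> %d\n", all);
        return 0;
    }
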
 
 template
@@ -430,11 +450,11 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
 #ifdef FP16_AVAILABLE
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+#if !defined(GGML_USE_HIP) && CUDART_VERSION < CUDART_HMAX
     return __float2half(fmaxf(__half2float(a), __half2float(b)));
 #else
     return __hmax(a, b);
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+#endif // !defined(GGML_USE_HIP) && CUDART_VERSION < CUDART_HMAX
 
 #else
    NO_DEVICE_CODE;
@@ -444,25 +464,21 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b
 }
 
 static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
-#if defined(GGML_USE_HIP) && HIP_VERSION >= 50700000
+#if defined(GGML_USE_HIP)
     return half2(__hmax(a.x, b.x), __hmax(a.y, b.y));
-#elif !defined(GGML_USE_HIP) && CUDART_VERSION >= CUDART_HMAX
+#elif CUDART_VERSION >= CUDART_HMAX
     return __hmax2(a, b);
-#elif !defined(GGML_USE_HIP)
+#else
     half2 ret;
     reinterpret_cast(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
     reinterpret_cast(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
     return ret;
-#else
-    GGML_UNUSED(a);
-    GGML_UNUSED(b);
-    NO_DEVICE_CODE;
 #endif
 }
 
 template
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)
 #pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
@@ -471,7 +487,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 #else
    GGML_UNUSED(x);
    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)
 }
 
 #if CUDART_VERSION < CUDART_HMASK
@@ -483,7 +499,7 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
 #endif // CUDART_VERSION < CUDART_HMASK
 
 static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
 #if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
 #elif defined(RDNA3) || defined(RDNA4)
@@ -509,7 +525,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
 #endif
     return c;
 
-#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#else // defined(GGML_USE_HIP)
 
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
     return __dp4a(a, b, c);
@@ -519,7 +535,25 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
     return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
 
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
+}
+
+static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
+#if CUDART_VERSION >= 12080
+    const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
+    return (float) e;
+#else
+    uint32_t bits;
+    if (x == 0) {
+        bits = 0x00400000;
+    } else {
+        bits = (uint32_t) x << 23;
+    }
+
+    float result;
+    memcpy(&result, &bits, sizeof(float));
+    return result;
+#endif // CUDART_VERSION >= 12080
 }
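
E8M0 is a scale format that stores only an 8-bit biased exponent, so the encoded value is 2^(x - 127). The fallback above builds that directly by shifting x into the fp32 exponent field; x == 0 would otherwise land in the subnormal range, hence the hard-coded 2^-127 bit pattern. A host-only sketch of the same decode:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    static float e8m0_to_fp32(uint8_t x) {
        // place x in the fp32 exponent field; special-case x == 0 as 2^-127
        uint32_t bits = (x == 0) ? 0x00400000u : (uint32_t)x << 23;
        float result;
        memcpy(&result, &bits, sizeof(result));
        return result;
    }

    int main() {
        printf("e8m0(127) = %g (2^0)\n",    e8m0_to_fp32(127));
        printf("e8m0(130) = %g (2^3)\n",    e8m0_to_fp32(130));
        printf("e8m0(1)   = %g (2^-126)\n", e8m0_to_fp32(1));
        return 0;
    }
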
 
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
@@ -580,6 +614,13 @@ struct ggml_cuda_type_traits {
     static constexpr int qi = QI8_0;
 };
 
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
+    static constexpr int qk = QK_MXFP4;
+    static constexpr int qr = QR_MXFP4;
+    static constexpr int qi = QI_MXFP4;
+};
+
 template<>
 struct ggml_cuda_type_traits {
     static constexpr int qk = QK_K;
@@ -765,7 +806,7 @@ struct ggml_tensor_extra_gpu {
 };
 
 
-#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
+#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)) || defined(GGML_MUSA_GRAPHS)
 #define USE_CUDA_GRAPH
 #endif
 
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index eeaa14bf579..8f0efdcc126 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -6,24 +6,33 @@
 #define CUDA_Q8_0_NE_ALIGN 2048
 
 template 
-static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
-    const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02,
+        const int64_t s01, const int64_t s02, const int64_t s03) {
+    const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
 
-    if (i >= k) {
+    if (i00 >= ne00) {
         return;
     }
 
-    const int64_t ib = i/qk; // block index
-    const int64_t iqs = (i%qk)/qr; // quant index
-    const int64_t iybs = i - i%qk; // y block start index
+    const int64_t i01 = blockIdx.y;
+    const int64_t i02 = blockIdx.z % ne02;
+    const int64_t i03 = blockIdx.z / ne02;
+
+    const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
+
+    const int64_t ib = ibx0 + i00/qk; // block index
+    const int64_t iqs = (i00%qk)/qr; // quant index
+    const int64_t iybs = i00 - i00%qk; // y block start index
     const int64_t y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
     dfloat2 v;
     dequantize_kernel(vx, ib, iqs, v);
 
-    y[iybs + iqs + 0]        = v.x;
-    y[iybs + iqs + y_offset] = v.y;
+    const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
+    y[iy0 + 0]        = ggml_cuda_cast<dst_t>(v.x);
+    y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
 }
 
 template 
@@ -456,10 +465,36 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
     }
 }
 
+template<typename dst_t>
+static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t  * q4 = x[ib].qs + 4*il;
+    const float d = ggml_cuda_e8m0_to_fp32(x[ib].e);
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;
+        y[j+16] = d * kvalues_mxfp4[q4[j] >>  4]*0.5f;
+    }
+}
+
 template 
-static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
-    const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
-    dequantize_block<<>>(vx, y, k);
+static void dequantize_block_cuda(const void * vx, dst_t * y,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
+    const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
+    dequantize_block<<>>
+        (vx, y, ne00, ne01, ne02, s01, s02, s03);
+}
+
+template 
+static void dequantize_block_cont_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
+    dequantize_block_cuda(vx, y, k, 1, 1, 1, k/qk, k/qk, k/qk, stream);
 }
 
 static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
@@ -571,6 +606,12 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
     dequantize_block_iq4_xs<<>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template 
 static __global__ void convert_unary(
         const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
@@ -589,7 +630,7 @@ static __global__ void convert_unary(
 
     const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
     const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
-    y[iy] = float(x[ix]);
+    y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
 }
 
 template 
@@ -624,14 +665,14 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         case GGML_TYPE_Q4_1:
             return dequantize_row_q4_1_cuda;
         case GGML_TYPE_Q5_0:
-            return dequantize_block_cuda;
+            return dequantize_block_cont_cuda;
         case GGML_TYPE_Q5_1:
-            return dequantize_block_cuda;
+            return dequantize_block_cont_cuda;
         case GGML_TYPE_Q8_0:
             if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
                 return dequantize_block_q8_0_f16_cuda;
             }
-            return dequantize_block_cuda;
+            return dequantize_block_cont_cuda;
         case GGML_TYPE_Q2_K:
             return dequantize_row_q2_K_cuda;
         case GGML_TYPE_Q3_K:
@@ -660,6 +701,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq4_xs_cuda;
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_cuda;
         case GGML_TYPE_F32:
             return convert_unary_cont_cuda;
         case GGML_TYPE_BF16:
@@ -676,11 +719,11 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
         case GGML_TYPE_Q4_1:
             return dequantize_row_q4_1_cuda;
         case GGML_TYPE_Q5_0:
-            return dequantize_block_cuda;
+            return dequantize_block_cont_cuda;
         case GGML_TYPE_Q5_1:
-            return dequantize_block_cuda;
+            return dequantize_block_cont_cuda;
         case GGML_TYPE_Q8_0:
-            return dequantize_block_cuda;
+            return dequantize_block_cont_cuda;
         case GGML_TYPE_Q2_K:
             return dequantize_row_q2_K_cuda;
         case GGML_TYPE_Q3_K:
@@ -709,6 +752,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq4_xs_cuda;
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cont_cuda;
         case GGML_TYPE_BF16:
@@ -722,6 +767,16 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
             return convert_unary_cuda;
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda;
         case GGML_TYPE_BF16:
             return convert_unary_cuda;
         default:
@@ -733,6 +788,16 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
             return convert_unary_cuda;
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cuda;
         default:
@@ -744,6 +809,16 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F16:
             return convert_unary_cuda;
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda;
         case GGML_TYPE_BF16:
             return convert_unary_cuda;
         default:
diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh
index f04214be175..c62e8a1b104 100644
--- a/ggml/src/ggml-cuda/convert.cuh
+++ b/ggml/src/ggml-cuda/convert.cuh
@@ -29,3 +29,16 @@ typedef to_t_nc_cuda_t to_bf16_nc_cuda_t;
 to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
 to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
 to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
+
+template<typename dst_t, typename src_t>
+ __host__ __device__ inline dst_t ggml_cuda_cast(src_t x) {
+    if constexpr (std::is_same_v<dst_t, src_t>) {
+        return x;
+    } else if constexpr(std::is_same_v<dst_t, nv_bfloat16>) {
+        return __float2bfloat16(float(x));
+    } else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
+        return __bfloat162float(x);
+    } else {
+        return float(x);
+    }
+}
diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh
new file mode 100644
index 00000000000..e621cb9811a
--- /dev/null
+++ b/ggml/src/ggml-cuda/cpy-utils.cuh
@@ -0,0 +1,217 @@
+#pragma once
+
+#include "ggml-common.h"
+#include "convert.cuh"
+
+static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
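
best_index_int8() assumes a sorted codebook: it binary-searches for the bracketing pair and returns whichever neighbour is closer. The sketch below runs the same logic on the CPU; the table values are the IQ4_NL codebook from ggml-common.h, copied here only so the example is self-contained, and the probe values are arbitrary.

    #include <cstdint>
    #include <cstdio>

    static int best_index_int8(int n, const int8_t * val, float x) {
        if (x <= val[0]) return 0;
        if (x >= val[n-1]) return n-1;
        int ml = 0, mu = n-1;
        while (mu - ml > 1) {                 // binary search for the bracket
            int mav = (ml + mu) / 2;
            if (x < val[mav]) mu = mav; else ml = mav;
        }
        return x - val[mu-1] < val[mu] - x ? mu-1 : mu;  // nearest of the two
    }

    int main() {
        static const int8_t kvalues_iq4nl[16] = {
            -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
        };
        const float xs[] = {-120.0f, -16.5f, 0.0f, 60.0f, 200.0f};
        for (float x : xs) {
            const int idx = best_index_int8(16, kvalues_iq4nl, x);
            printf("x = %6.1f -> index %2d (value %d)\n", x, idx, kvalues_iq4nl[idx]);
        }
        return 0;
    }
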
+
+static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) {
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK4_0; ++j) {
+        const float v = x[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    const float d  = vmax / -8;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    y->d = d;
+
+    for (int j = 0; j < QK4_0/2; ++j) {
+        const float x0 = x[0       + j]*id;
+        const float x1 = x[QK4_0/2 + j]*id;
+
+        const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
+        const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
+
+        y->qs[j]  = xi0;
+        y->qs[j] |= xi1 << 4;
+    }
+}
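
Q4_0 is a symmetric 4-bit format: codes 0..15 reconstruct as d * (q - 8), and d is derived from the largest-magnitude input so that value lands at one end of the range. A small round-trip sketch with made-up inputs:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    int main() {
        const float x[4] = {0.9f, -0.3f, 0.05f, -0.7f};
        float amax = 0.0f, vmax = 0.0f;
        for (float v : x) if (std::fabs(v) > amax) { amax = std::fabs(v); vmax = v; }
        const float d  = vmax / -8;
        const float id = d ? 1.0f / d : 0.0f;
        for (float v : x) {
            const int q = std::min(15, (int)(v * id + 8.5f));  // 4-bit code
            printf("x=% .3f  q=%2d  deq=% .4f\n", v, q, d * (q - 8));
        }
        return 0;
    }
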
+
+static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) {
+    float vmin = FLT_MAX;
+    float vmax = -FLT_MAX;
+
+    for (int j = 0; j < QK4_1; ++j) {
+        const float v = x[j];
+        if (v < vmin) vmin = v;
+        if (v > vmax) vmax = v;
+    }
+
+    const float d  = (vmax - vmin) / ((1 << 4) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    y->dm.x = d;
+    y->dm.y = vmin;
+
+    for (int j = 0; j < QK4_1/2; ++j) {
+        const float x0 = (x[0       + j] - vmin)*id;
+        const float x1 = (x[QK4_1/2 + j] - vmin)*id;
+
+        const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
+        const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
+
+        y->qs[j]  = xi0;
+        y->qs[j] |= xi1 << 4;
+    }
+}
+
+static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) {
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK5_0; ++j) {
+        const float v = x[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    const float d  = vmax / -16;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    y->d = d;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_0/2; ++j) {
+        const float x0 = x[0       + j]*id;
+        const float x1 = x[QK5_0/2 + j]*id;
+
+        const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
+        const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
+
+        y->qs[j]  = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+    }
+    memcpy(y->qh, &qh, sizeof(qh));
+}
+
+static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) {
+    float min = x[0];
+    float max = x[0];
+
+    for (int j = 1; j < QK5_1; ++j) {
+        const float v = x[j];
+        min = v < min ? v : min;
+        max = v > max ? v : max;
+    }
+
+    const float d  = (max - min) / 31;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    y->dm.x = d;
+    y->dm.y = min;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_1/2; ++j) {
+        const float x0 = (x[0       + j] - min)*id;
+        const float x1 = (x[QK5_1/2 + j] - min)*id;
+
+        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+        y->qs[j]  = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+    }
+    memcpy(y->qh, &qh, sizeof(qh));
+}
+
+static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) {
+    float amax = 0.0f; // absolute max
+
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = x[j];
+        amax = fmaxf(amax, fabsf(v));
+    }
+
+    const float d = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    y->d = d;
+
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = x[j]*id;
+        y->qs[j] = roundf(x0);
+    }
+}
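
Q8_0 quantization above is the simplest case: d = amax/127 and each element becomes round(x/d). A round-trip sketch with made-up inputs:

    #include <cmath>
    #include <cstdio>

    int main() {
        const float x[4] = {0.02f, -1.27f, 0.5f, -0.9f};
        float amax = 0.0f;
        for (float v : x) amax = std::fmax(amax, std::fabs(v));
        const float d  = amax / 127.0f;
        const float id = d ? 1.0f / d : 0.0f;
        for (float v : x) {
            const int q = (int)std::roundf(v * id);   // signed 8-bit code
            printf("x=% .3f  q=%4d  deq=% .4f\n", v, q, d * q);
        }
        return 0;
    }
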
+
+static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) {
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK4_NL; ++j) {
+        const float v = x[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    float d = vmax / kvalues_iq4nl[0];
+    const float id = d ? 1.0f/d : 0.0f;
+
+    float sumqx = 0, sumq2 = 0;
+    for (int j = 0; j < QK4_NL/2; ++j) {
+        const float x0 = x[0        + j]*id;
+        const float x1 = x[QK4_NL/2 + j]*id;
+        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
+        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
+        y->qs[j] = xi0 | (xi1 << 4);
+        const float v0 = kvalues_iq4nl[xi0];
+        const float v1 = kvalues_iq4nl[xi1];
+        const float w0 = x[0        + j]*x[0        + j];
+        const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j];
+        sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j];
+        sumq2 += w0*v0*v0 + w1*v1*v1;
+    }
+
+    y->d = sumq2 > 0 ? sumqx/sumq2 : d;
+}
+
+// Wrapper functions for cpy.cu compatibility
+static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
+    quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti);
+}
+
+static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
+    quantize_f32_q4_1_block((const float *)cxi, (block_q4_1 *)cdsti);
+}
+
+static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
+    quantize_f32_q5_0_block((const float *)cxi, (block_q5_0 *)cdsti);
+}
+
+static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
+    quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti);
+}
+
+static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
+    quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti);
+}
+
+static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
+    quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
+}
+
+template<typename src_t, typename dst_t>
+static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
+    *(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
+}
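
All of the block quantizers above follow the same pattern: derive a per-block scale d, store it, then map each value through the inverse scale id and pack the resulting integers. For reference, here is a small host-side sketch of the Q8_0 case (illustrative only, not part of the patch; it assumes a block size of 32 and a plain array instead of the real block_q8_0 layout):

    // Host-side sketch of the scheme used by quantize_f32_q8_0_block (assumes QK8_0 == 32).
    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    int main() {
        const int QK8_0 = 32;
        float x[QK8_0];
        for (int j = 0; j < QK8_0; ++j) {
            x[j] = 0.1f*j - 1.5f; // arbitrary test data
        }

        // Scale: the absolute maximum is mapped to the int8 range [-127, 127].
        float amax = 0.0f;
        for (int j = 0; j < QK8_0; ++j) {
            amax = fmaxf(amax, fabsf(x[j]));
        }
        const float d  = amax / ((1 << 7) - 1);
        const float id = d ? 1.0f/d : 0.0f; // guard against an all-zero block

        int8_t qs[QK8_0];
        for (int j = 0; j < QK8_0; ++j) {
            qs[j] = (int8_t) roundf(x[j]*id);
        }

        // The reconstruction d*qs[j] is within d/2 of the original value.
        float max_err = 0.0f;
        for (int j = 0; j < QK8_0; ++j) {
            max_err = fmaxf(max_err, fabsf(d*qs[j] - x[j]));
        }
        printf("d = %f, max reconstruction error = %f\n", d, max_err);
        return 0;
    }
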
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 2c55d2149b2..f9bb025643c 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -1,51 +1,17 @@
 #include "cpy.cuh"
 #include "dequantize.cuh"
-#ifdef GGML_USE_MUSA
+#include "cpy-utils.cuh"
+#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
 #include "ggml-musa/mudnn.cuh"
-#endif // GGML_USE_MUSA
+#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
 
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
-static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    float * dsti = (float *) cdsti;
-
-    *dsti = *xi;
-}
-
-static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti;
-
-    *dsti = *xi;
-}
-
-static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    half * dsti = (half *) cdsti;
-
-    *dsti = __float2half(*xi);
-}
-
-static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
-    const half * xi = (const half *) cxi;
-    half * dsti = (half *) cdsti;
-
-    *dsti = *xi;
-}
-
-static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
-    const half * xi = (const half *) cxi;
-    float * dsti = (float *) cdsti;
-
-    *dsti = *xi;
-}
-
 template <cpy_kernel_t cpy_1>
-static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
-                                   const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                                   const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                   const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
+                               const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                               const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                               const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
@@ -71,29 +37,6 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
-static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q8_0 * dsti = (block_q8_0 *) cdsti;
-
-    float amax = 0.0f; // absolute max
-
-    for (int j = 0; j < QK8_0; j++) {
-        const float v = xi[j];
-        amax = fmaxf(amax, fabsf(v));
-    }
-
-    const float d = amax / ((1 << 7) - 1);
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dsti->d = d;
-
-    for (int j = 0; j < QK8_0; ++j) {
-        const float x0 = xi[j]*id;
-
-        dsti->qs[j] = roundf(x0);
-    }
-}
-
 static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
     float * cdstf = (float *)(cdsti);
 
@@ -106,139 +49,6 @@ static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
     }
 }
 
-static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q4_0 * dsti = (block_q4_0 *) cdsti;
-
-    float amax = 0.0f;
-    float vmax = 0.0f;
-
-    for (int j = 0; j < QK4_0; ++j) {
-        const float v = xi[j];
-        if (amax < fabsf(v)) {
-            amax = fabsf(v);
-            vmax = v;
-        }
-    }
-
-    const float d  = vmax / -8;
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dsti->d = d;
-
-    for (int j = 0; j < QK4_0/2; ++j) {
-        const float x0 = xi[0       + j]*id;
-        const float x1 = xi[QK4_0/2 + j]*id;
-
-        const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
-        const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
-
-        dsti->qs[j]  = xi0;
-        dsti->qs[j] |= xi1 << 4;
-    }
-}
-
-static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q4_1 * dsti = (block_q4_1 *) cdsti;
-
-    float vmin = FLT_MAX;
-    float vmax = -FLT_MAX;
-
-    for (int j = 0; j < QK4_1; ++j) {
-        const float v = xi[j];
-
-        if (v < vmin) vmin = v;
-        if (v > vmax) vmax = v;
-    }
-
-    const float d  = (vmax - vmin) / ((1 << 4) - 1);
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dsti->dm.x = d;
-    dsti->dm.y = vmin;
-
-    for (int j = 0; j < QK4_1/2; ++j) {
-        const float x0 = (xi[0       + j] - vmin)*id;
-        const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
-
-        const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
-        const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
-
-        dsti->qs[j]  = xi0;
-        dsti->qs[j] |= xi1 << 4;
-    }
-}
-
-static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q5_0 * dsti = (block_q5_0 *) cdsti;
-
-    float amax = 0.0f;
-    float vmax = 0.0f;
-
-    for (int j = 0; j < QK5_0; ++j) {
-        const float v = xi[j];
-        if (amax < fabsf(v)) {
-            amax = fabsf(v);
-            vmax = v;
-        }
-    }
-
-    const float d  = vmax / -16;
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dsti->d = d;
-
-    uint32_t qh = 0;
-    for (int j = 0; j < QK5_0/2; ++j) {
-        const float x0 = xi[0       + j]*id;
-        const float x1 = xi[QK5_0/2 + j]*id;
-
-        const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
-        const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
-
-        dsti->qs[j]  = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
-        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
-        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
-    }
-    memcpy(dsti->qh, &qh, sizeof(qh));
-}
-
-static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q5_1 * dsti = (block_q5_1 *) cdsti;
-
-    float min = xi[0];
-    float max = xi[0];
-
-    for (int j = 1; j < QK5_1; ++j) {
-        const float v = xi[j];
-        min = v < min ? v : min;
-        max = v > max ? v : max;
-    }
-
-    const float d  = (max - min) / 31;
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dsti->dm.x = d;
-    dsti->dm.y = min;
-
-    uint32_t qh = 0;
-    for (int j = 0; j < QK5_1/2; ++j) {
-        const float x0 = (xi[0       + j] - min)*id;
-        const float x1 = (xi[QK5_1/2 + j] - min)*id;
-
-        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
-        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
-
-        dsti->qs[j]  = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
-        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
-        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
-    }
-    memcpy(dsti->qh, &qh, sizeof(qh));
-}
-
 template<dequantize_kernel_t dequant, int qk>
 static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
     float * cdstf = (float *)(cdsti);
@@ -252,53 +62,6 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
     }
 }
 
-static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
-    if (x <= val[0]) return 0;
-    if (x >= val[n-1]) return n-1;
-    int ml = 0, mu = n-1;
-    while (mu-ml > 1) {
-        int mav = (ml+mu)/2;
-        if (x < val[mav]) mu = mav; else ml = mav;
-    }
-    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
-}
-
-static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
-
-    float amax = 0.0f;
-    float vmax = 0.0f;
-
-    for (int j = 0; j < QK4_NL; ++j) {
-        const float v = xi[j];
-        if (amax < fabsf(v)) {
-            amax = fabsf(v);
-            vmax = v;
-        }
-    }
-
-    float d = vmax / kvalues_iq4nl[0];
-    const float id = d ? 1.0f/d : 0.0f;
-
-    float sumqx = 0, sumq2 = 0;
-    for (int j = 0; j < QK4_NL/2; ++j) {
-        const float x0 = xi[0        + j]*id;
-        const float x1 = xi[QK4_NL/2 + j]*id;
-        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
-        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
-        dsti->qs[j] = xi0 | (xi1 << 4);
-        const float v0 = kvalues_iq4nl[xi0];
-        const float v1 = kvalues_iq4nl[xi1];
-        const float w0 = xi[0        + j]*xi[0        + j];
-        const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j];
-        sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j];
-        sumq2 += w0*v0*v0 + w1*v1*v1;
-    }
-
-    dsti->d = sumq2 > 0 ? sumqx/sumq2 : d;
-}
-
 template <cpy_kernel_t cpy_blck, int qk>
 static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -358,7 +121,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int
 // Copy destination pointers to GPU to be available when pointer indirection is in use
 
 void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
     if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
         CUDA_CHECK(cudaStreamSynchronize(stream));
         if (cuda_graph->dest_ptrs_d != nullptr) {
@@ -376,43 +139,14 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
 #endif
 }
 
-static void ggml_cpy_f16_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
-static void ggml_cpy_f32_f32_cuda(
+template<typename src_t, typename dst_t>
+static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
-static void ggml_cpy_f32_bf16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
-static void ggml_cpy_f32_f16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+    cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
@@ -544,16 +278,6 @@ static void ggml_cpy_f32_iq4_nl_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
-static void ggml_cpy_f16_f16_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
-
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
-}
-
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
@@ -590,7 +314,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
 
     char ** dest_ptrs_d = nullptr;
     int graph_cpynode_index = -1;
-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
     if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
         dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
         graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
@@ -600,20 +324,20 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
 #endif
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
-#ifdef GGML_USE_MUSA
+#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
         if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
             CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
         } else
-#endif // GGML_USE_MUSA
+#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
         {
             CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -640,14 +364,22 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
+        ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
+        ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
     }
-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
     if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
         ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
     }
@@ -667,11 +399,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         return nullptr;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+        return (void*) cpy_flt<cpy_1_flt<float, float>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
+        return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+        return (void*) cpy_flt<cpy_1_flt<float, half>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -695,9 +427,17 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_f32_f16<cpy_1_f16_f16>;
+        return (void*) cpy_flt<cpy_1_flt<half, half>>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+        return (void*) cpy_flt<cpy_1_flt<half, float>>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
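
The net effect of this refactor is that every float-to-float copy combination now goes through one element functor and one kernel instead of a dedicated kernel per type pair. The standalone host sketch below mirrors that structure (illustrative only; cpy_1_host and cpy_host are made-up names for the example, and the real code dispatches a CUDA kernel through ggml_cpy_flt_cuda):

    // Host analogue of the cpy_1_flt/cpy_flt template pair (illustrative sketch).
    #include <stddef.h>
    #include <stdint.h>
    #include <stdio.h>

    template<typename src_t, typename dst_t>
    static void cpy_1_host(const char * cxi, char * cdsti) {
        *(dst_t *) cdsti = (dst_t) *(const src_t *) cxi; // per-element conversion
    }

    template<void (*cpy_1)(const char *, char *)>
    static void cpy_host(const char * cx, char * cdst, int64_t ne, size_t ts_src, size_t ts_dst) {
        for (int64_t i = 0; i < ne; ++i) { // the CUDA kernel assigns one i per thread instead
            cpy_1(cx + i*ts_src, cdst + i*ts_dst);
        }
    }

    int main() {
        const float src[4] = {1.0f, 2.5f, -3.0f, 4.25f};
        double      dst[4];
        // One instantiation per type pair, mirroring the ggml_cpy_flt_cuda calls above.
        cpy_host<cpy_1_host<float, double>>((const char *) src, (char *) dst, 4, sizeof(float), sizeof(double));
        printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]);
        return 0;
    }
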
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 075f14a49e9..d4ed938391b 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -15,6 +15,8 @@ typedef void (* fattn_kernel_t)(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -23,31 +25,13 @@ typedef void (* fattn_kernel_t)(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int nb31,
-        const int nb32,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3);
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33);
 
 typedef half (*vec_dot_KQ_f16_t)(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
@@ -518,10 +502,63 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+template <int ncols1>
+__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
+static __global__ void flash_attn_mask_to_KV_max(
+        const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
+    const int ne31     = gridDim.x;
+    const int tid      = threadIdx.x;
+    const int sequence = blockIdx.y;
+    const int jt       = blockIdx.x;
+
+    mask += sequence*s33 + jt*ncols1*s31;
+
+    __shared__ int buf_iw[WARP_SIZE];
+    if (tid < WARP_SIZE) {
+        buf_iw[tid] = 1;
+    }
+    __syncthreads();
+
+    int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
+    for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
+        int all_inf = 1;
+
+#pragma unroll
+        for (int j = 0; j < ncols1; ++j) {
+            const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
+            all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
+        }
+
+        all_inf = warp_reduce_all(all_inf);
+        if (tid % WARP_SIZE == 0) {
+            buf_iw[tid / WARP_SIZE] = all_inf;
+        }
+        __syncthreads();
+        all_inf = buf_iw[tid % WARP_SIZE];
+        __syncthreads();
+        all_inf = warp_reduce_all(all_inf);
+
+        if (!all_inf) {
+            break;
+        }
+    }
+
+    // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
+    // If the break was triggered, it's the lower edge of the last tile that still contains non-masked values.
+    // In either case, adding FATTN_KQ_STRIDE gives the exclusive upper bound of the KV range that needs to be processed.
+    KV_max_sj += FATTN_KQ_STRIDE;
+
+    if (threadIdx.x != 0) {
+        return;
+    }
+
+    KV_max[sequence*ne31 + jt] = KV_max_sj;
+}
+
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
     constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
@@ -535,8 +572,8 @@ static __global__ void flash_attn_stream_k_fixup(
     const int iter_k = ne11 / FATTN_KQ_STRIDE;
     const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -545,14 +582,15 @@ static __global__ void flash_attn_stream_k_fixup(
         return;
     }
 
-    const int channel = kbc0 / (iter_k*iter_j);
-    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
+    const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
 
     if (jt*ncols1 + j >= ne01) {
         return;
     }
 
-    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
 
     // Load the partial result that needs a fixup:
     float dst_val = 0.0f;
@@ -571,7 +609,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+        const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
@@ -609,24 +647,39 @@ static __global__ void flash_attn_stream_k_fixup(
 }
 
 template<int D> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !defined(GGML_USE_HIP)
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !defined(GGML_USE_HIP)
 static __global__ void flash_attn_combine_results(
         const float  * __restrict__ VKQ_parts,
         const float2 * __restrict__ VKQ_meta,
         float * __restrict__ dst,
         const int parallel_blocks) {
-    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
-    VKQ_meta  += parallel_blocks   * gridDim.z*blockIdx.x;
-    dst       +=                 D * gridDim.z*blockIdx.x;
+    // Dimension 0: threadIdx.x
+    // Dimension 1: blockIdx.x
+    // Dimension 2: blockIdx.y
+    // Dimension 3: blockIdx.z
+    // Memory layout is permuted with [0, 2, 1, 3]
+
+    const int ne01 = gridDim.x;
+    const int ne02 = gridDim.y;
+
+    const int col      = blockIdx.x;
+    const int head     = blockIdx.y;
+    const int sequence = blockIdx.z;
+
+    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+
+    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+    VKQ_meta  += j_dst_unrolled * parallel_blocks;
+    dst       += j_dst_unrolled *                 D;
 
     const int tid = threadIdx.x;
     __builtin_assume(tid < D);
 
     extern __shared__ float2 meta[];
     for (int i = tid; i < 2*parallel_blocks; i += D) {
-        ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
     }
 
     __syncthreads();
@@ -644,11 +697,11 @@ static __global__ void flash_attn_combine_results(
         const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint32_t *) &KQ_max_scale) &= ftz_mask;
 
-        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
     }
 
-    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
+    dst[tid] = VKQ_numerator / VKQ_denominator;
 }
 
 [[noreturn]]
@@ -688,7 +741,8 @@ void launch_fattn(
 
     GGML_ASSERT(V || is_mla);
 
-    const ggml_tensor * mask = dst->src[3];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
 
     ggml_tensor * KQV = dst;
 
@@ -705,8 +759,6 @@ void launch_fattn(
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
-    GGML_ASSERT(Q->ne[3] == 1);
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id  = ggml_cuda_get_device();
@@ -715,6 +767,7 @@ void launch_fattn(
 
     ggml_cuda_pool_alloc<half>   K_f16(pool);
     ggml_cuda_pool_alloc<half>   V_f16(pool);
+    ggml_cuda_pool_alloc<int>    KV_max(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
@@ -729,40 +782,84 @@ void launch_fattn(
     size_t nb23 = V ? V->nb[3] : nb13;
 
     if (need_f16_K && K->type != GGML_TYPE_F16) {
-        GGML_ASSERT(ggml_is_contiguously_allocated(K));
-        K_f16.alloc(ggml_nelements(K));
-        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
-        to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
-        K_data = (char *) K_f16.ptr;
-
         const size_t bs = ggml_blck_size(K->type);
         const size_t ts = ggml_type_size(K->type);
 
-        nb11 = nb11*bs*sizeof(half)/ts;
-        nb12 = nb12*bs*sizeof(half)/ts;
-        nb13 = nb13*bs*sizeof(half)/ts;
+        K_f16.alloc(ggml_nelements(K));
+        if (ggml_is_contiguously_allocated(K)) {
+            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+            to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+
+            nb11 = nb11*bs*sizeof(half)/ts;
+            nb12 = nb12*bs*sizeof(half)/ts;
+            nb13 = nb13*bs*sizeof(half)/ts;
+        } else {
+            GGML_ASSERT(K->nb[0] == ts);
+            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
+            const int64_t s01 = nb11 / ts;
+            const int64_t s02 = nb12 / ts;
+            const int64_t s03 = nb13 / ts;
+            to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
+
+            nb11 = K->ne[0] * sizeof(half);
+            nb12 = K->ne[1] * nb11;
+            nb13 = K->ne[2] * nb12;
+        }
+        K_data = (char *) K_f16.ptr;
     }
 
     if (V && need_f16_V && V->type != GGML_TYPE_F16) {
-        GGML_ASSERT(ggml_is_contiguously_allocated(V));
-        V_f16.alloc(ggml_nelements(V));
-        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
-        to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
-        V_data = (char *) V_f16.ptr;
-
         const size_t bs = ggml_blck_size(V->type);
         const size_t ts = ggml_type_size(V->type);
 
-        nb21 = nb21*bs*sizeof(half)/ts;
-        nb22 = nb22*bs*sizeof(half)/ts;
-        nb23 = nb23*bs*sizeof(half)/ts;
+        V_f16.alloc(ggml_nelements(V));
+        if (ggml_is_contiguously_allocated(V)) {
+            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+            to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+            V_data = (char *) V_f16.ptr;
+
+            nb21 = nb21*bs*sizeof(half)/ts;
+            nb22 = nb22*bs*sizeof(half)/ts;
+            nb23 = nb23*bs*sizeof(half)/ts;
+        } else {
+            GGML_ASSERT(V->nb[0] == ts);
+            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
+            const int64_t s01 = nb21 / ts;
+            const int64_t s02 = nb22 / ts;
+            const int64_t s03 = nb23 / ts;
+            to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
+
+            nb21 = V->ne[0] * sizeof(half);
+            nb22 = V->ne[1] * nb21;
+            nb23 = V->ne[2] * nb22;
+        }
+        V_data = (char *) V_f16.ptr;
     }
 
-    int parallel_blocks = 1;
-
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
+    // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
+    // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
+    //     multiple sequences of possibly different lengths.
+    if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+        const int s31 = mask->nb[1] / sizeof(half2);
+        const int s33 = mask->nb[3] / sizeof(half2);
+
+        const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
+        const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
+
+        const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
+        const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
+
+        KV_max.alloc(ne_KV_max);
+        flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
+            ((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
+        CUDA_CHECK(cudaGetLastError());
+    }
+
+    int parallel_blocks = 1;
+
     const dim3 block_dim(warp_size, nwarps, 1);
     int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
     CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
@@ -849,16 +946,15 @@ void launch_fattn(
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
+        sinks ? ((const char *) sinks->data) : nullptr,
+        KV_max.ptr,
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
-        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
-        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
-        Q->nb[1], Q->nb[2], Q->nb[3],
-        nb11, nb12, nb13,
+        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
+        K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
         nb21, nb22, nb23,
-        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
     );
     CUDA_CHECK(cudaGetLastError());
 
@@ -869,11 +965,11 @@ void launch_fattn(
 
             flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
-        const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
         flash_attn_combine_results<DV>
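
The idea behind flash_attn_mask_to_KV_max is easiest to see in scalar form: starting from the last FATTN_KQ_STRIDE-wide tile of the mask, skip every tile whose entries are all -inf and record the exclusive end of the last tile that still has unmasked entries, so the main kernel can stop its KV loop there. Below is a CPU sketch of the same scan (illustrative only; kv_max_for_rows is a made-up name, the mask is assumed row-major, and masked positions are marked with -inf as in the kernel):

    #include <cmath>
    #include <cstdio>

    // Returns the exclusive KV upper bound for a group of `nrows` mask rows: the end of
    // the last `stride`-wide tile that contains at least one finite (non -inf) entry.
    // `mask` is assumed row-major with `ne_kv` columns, `ne_kv` a multiple of `stride`.
    static int kv_max_for_rows(const float * mask, int nrows, int ne_kv, int stride) {
        for (int tile_end = ne_kv; tile_end > 0; tile_end -= stride) {
            for (int j = 0; j < nrows; ++j) {
                for (int k = tile_end - stride; k < tile_end; ++k) {
                    if (!std::isinf(mask[j*ne_kv + k])) {
                        return tile_end; // this tile still has work; everything after it is fully masked
                    }
                }
            }
        }
        return 0; // all tiles fully masked for these rows
    }

    int main() {
        const float inf = INFINITY;
        // 2 rows x 8 KV positions, causal-style mask: the last tile (positions 4..7) is fully masked.
        const float mask[2*8] = {
            0.0f, 0.0f, -inf, -inf, -inf, -inf, -inf, -inf,
            0.0f, 0.0f, 0.0f, -inf, -inf, -inf, -inf, -inf,
        };
        printf("KV_max = %d\n", kv_max_for_rows(mask, 2, 8, 4)); // prints 4
        return 0;
    }
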
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 709589854f0..39731baaeb7 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -392,7 +392,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
     }
 }
 
-template
+template
 static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const float2 * const __restrict__ Q_f2,
         const half2  * const __restrict__ K_h2,
@@ -408,7 +409,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const int stride_K,
         const int stride_V,
         const int stride_mask,
-        const int jt,
         half2        * const __restrict__ tile_Q,
         half2        * const __restrict__ tile_K,
         half2        * const __restrict__ tile_V,
@@ -418,7 +418,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         float        * const __restrict__ KQ_max,
         float        * const __restrict__ KQ_rowsum,
         const int kb0) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
     typedef fattn_mma_f16_config c;
 
 #ifdef CP_ASYNC_AVAILABLE
@@ -455,7 +455,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         cp_async_wait_all();
         __syncthreads();
         flash_attn_ext_f16_load_tile
-            (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
+            (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
     } else {
         constexpr bool use_cp_async = nstages == 1;
         if (ncols2 > 1 || mask_h2) {
@@ -471,7 +471,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         if (nstages <= 1) {
             constexpr bool use_cp_async = nstages == 1;
             flash_attn_ext_f16_load_tile
-                (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
+                (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
             if (use_cp_async) {
                 cp_async_wait_all();
             }
@@ -715,7 +715,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                     (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
             }
             flash_attn_ext_f16_load_tile
-                (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
+                (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
         }
     }
 
@@ -732,7 +732,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         if (nstages <= 1 && i0_start < reusable_cutoff) {
             constexpr bool use_cp_async = nstages == 1;
             flash_attn_ext_f16_load_tile
-                (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
+                (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
             if (use_cp_async) {
                 cp_async_wait_all();
             }
@@ -771,13 +771,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
     GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
-    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
-    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+    GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K);
     GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
     GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
     GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
     NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
 }
 
 template
@@ -786,6 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const half2  * const __restrict__ K_h2,
         const half2  * const __restrict__ V_h2,
         const half2  * const __restrict__ mask_h2,
+        const float  * const __restrict__ sinks_f,
         float2       * const __restrict__ dstk,
         float2       * const __restrict__ dstk_fixup,
         const float scale,
@@ -801,7 +801,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int jt,
         const int kb0_start,
         const int kb0_stop) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     typedef fattn_mma_f16_config c;
@@ -920,21 +920,22 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                 (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
         }
         flash_attn_ext_f16_load_tile
-            (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
+            (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
     }
 
     // Iterate over ne11 == previous tokens:
-    for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
+    int kb0 = kb0_start;
+    for (; kb0 < kb0_stop-1; ++kb0) {
         constexpr bool last_iter = false;
         flash_attn_ext_f16_iter
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     }
     { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
         constexpr bool last_iter = true;
         flash_attn_ext_f16_iter
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     }
 
     // With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -957,6 +958,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
     }
 
+    // If attention sinks are used, potentially re-scale if KQ_max is small.
+    // Also add the sink as a value to KQ_rowsum; this is done after synchronization of KQ_rowsum,
+    //     so it is done unconditionally for every thread.
+    if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
+        float KQ_max_scale[cols_per_thread];
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+            static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
+            const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
+            const float sink = sinks_f[jc % ncols2];
+
+            const float KQ_max_new = fmaxf(KQ_max[col], sink);
+            const float KQ_max_diff = KQ_max[col] - KQ_max_new;
+            KQ_max_scale[col] = expf(KQ_max_diff);
+            KQ_max[col] = KQ_max_new;
+
+            *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
+            const float KQ_max_add = expf(sink - KQ_max_new);
+            KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
+        }
+
+        if (ntiles == 1) {
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+#pragma unroll
+            for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
+#pragma unroll
+                for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+                    VKQ_C[i].x[l] *= KQ_max_scale_h2;
+                }
+            }
+        } else {
+#pragma unroll
+            for (int col = 0; col < cols_per_thread; ++col) {
+                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+                for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
+#pragma unroll
+                    for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
+                        VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+                    }
+                }
+            }
+        }
+    }
+
     // Combine VKQ accumulator values if np > 1.
     // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
     // So also write VKQ accumulators to shared memory in column-major format if np == 1.
@@ -1196,7 +1243,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
     GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
     NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
 }
 
 template
@@ -1206,6 +1253,8 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -1214,32 +1263,14 @@ static __global__ void flash_attn_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int nb31,
-        const int nb32,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
-#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
+#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1274,8 +1305,8 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
 
     // kbc == k block continuous, current index in continuous ijk space.
-    int       kbc      = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    int       kbc      = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 
     // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
     // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1284,33 +1315,42 @@ static __global__ void flash_attn_ext_f16(
     // kb0 == k start index when in the output tile.
     int kb0_start = kbc % iter_k;
     int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
+
     while (kbc < kbc_stop && kb0_stop == iter_k) {
-        const int channel = kbc / (iter_k*iter_j);
-        const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+        const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+        const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+
+        const int head0 = zt * ncols2;
 
-        const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
-        const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+        const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02* head0);
+        const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
         const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-            (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
-        float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
+            (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+        float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
 
-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+        const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
 
         const int kb0_start_kernel = kb0_start * kb_niter;
-        const int kb0_stop_kernel  = kb0_stop  * kb_niter;
+        int       kb0_stop_kernel  = kb0_stop  * kb_niter;
+
+        if (KV_max) {
+            kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
+        }
 
         constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         if (kb0_start == 0) {
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
             flash_attn_ext_f16_process_tile
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
             flash_attn_ext_f16_process_tile
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         }
 
@@ -1325,40 +1365,49 @@ static __global__ void flash_attn_ext_f16(
         return;
     }
 
-    const int channel = kbc / (iter_k*iter_j);
-    const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+    const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+    const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+
+    const int head0 = zt * ncols2;
 
-    const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
-    const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+    const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02* head0);
+    const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
     const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-        (const half2  *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
-    float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
+        (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
+    float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
 
-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+    const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
 
     const int kb0_start_kernel = kb0_start * kb_niter;
-    const int kb0_stop_kernel  = kb0_stop  * kb_niter;
+    int       kb0_stop_kernel  = kb0_stop  * kb_niter;
+
+    if (KV_max) {
+        kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
+    }
 
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
     flash_attn_ext_f16_process_tile
-        (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+        (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
-    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
-    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
-    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
-    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
-    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
-    GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
-    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
-    GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
+    GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+    GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
+    GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
 }
 
 template 
@@ -1408,24 +1457,24 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
         constexpr bool use_logit_softcap = false;
         fattn_kernel = flash_attn_ext_f16;
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
         if (!shared_memory_limit_raised[id]) {
             CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
             shared_memory_limit_raised[id] = true;
         }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
     } else {
         constexpr bool use_logit_softcap = true;
         fattn_kernel = flash_attn_ext_f16;
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
         if (!shared_memory_limit_raised[id]) {
             CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
             shared_memory_limit_raised[id] = true;
         }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
     }
 
     launch_fattn
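
Note on the new work scheduling: the reworked mma kernel flattens (sequence, head group, column tile, KV chunk) into one index kbc and unpacks it with integer division, offsetting Q/K/V/dst by head0 = zt*ncols2. A minimal host-side C++ sketch of that unpacking, with made-up values for iter_k, iter_j, ne02, ncols2 and the sequence count (none of these come from a real launch), checks that the decomposition round-trips:

    #include <cassert>
    #include <cstdio>

    int main() {
        // Illustrative sizes only, not taken from an actual kernel launch:
        const int iter_k = 4;   // KV chunks per column tile
        const int iter_j = 3;   // column tiles per head group
        const int ne02   = 8;   // number of Q heads
        const int ncols2 = 2;   // heads handled together by one CUDA block
        const int n_seq  = 2;   // number of sequences

        const int head_groups = ne02 / ncols2;

        for (int kbc = 0; kbc < iter_k*iter_j*head_groups*n_seq; ++kbc) {
            // Same arithmetic as the patched flash_attn_ext_f16:
            const int sequence = kbc / (iter_k*iter_j*head_groups);
            const int zt = (kbc - iter_k*iter_j*head_groups*sequence) / (iter_k*iter_j);           // head group
            const int jt = (kbc - iter_k*iter_j*head_groups*sequence - iter_k*iter_j*zt) / iter_k; // column tile
            const int head0 = zt * ncols2; // first head this block works on

            // Recombining the coordinates must give back kbc's tile index:
            assert(kbc/iter_k == (sequence*head_groups + zt)*iter_j + jt);
            assert(head0 < ne02);
        }
        printf("kbc decomposition round-trips\n");
        return 0;
    }
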
diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu
index 0c967f178e7..1e23f8f79c2 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -5,14 +5,16 @@
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
 template // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !defined(GGML_USE_HIP)
 __launch_bounds__(nwarps*WARP_SIZE, 2)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !defined(GGML_USE_HIP)
 static __global__ void flash_attn_tile_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -21,31 +23,13 @@ static __global__ void flash_attn_tile_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int nb31,
-        const int nb32,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
@@ -62,15 +46,18 @@ static __global__ void flash_attn_tile_ext_f16(
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
+    const float2 * Q_f2   = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half2  * K_h2   = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half2  * V_h2   = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh  = (const half   *) (mask  + nb33*(sequence % ne33)                          + nb31*ic0);
+    const float  * sinksf = (const float  *) (sinks);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -106,7 +93,8 @@ static __global__ void flash_attn_tile_ext_f16(
 
     __syncthreads();
 
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         half kqmax_new[ncols/nwarps];
@@ -123,7 +111,7 @@ static __global__ void flash_attn_tile_ext_f16(
             for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
                 const int k_KQ = k_KQ_0 + threadIdx.x;
 
-                KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+                KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
             }
         }
 
@@ -217,7 +205,7 @@ static __global__ void flash_attn_tile_ext_f16(
             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
 
-                KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
+                KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
             }
         }
 
@@ -255,6 +243,33 @@ static __global__ void flash_attn_tile_ext_f16(
         __syncthreads();
     }
 
+    // Attention sink: adjust running max and sum once per head
+    if (sinksf && blockIdx.y == 0) {
+        const half sink = __float2half(sinksf[head]);
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j));
+            kqmax[j0/nwarps] = kqmax_new_j;
+
+            const half val = hexp(sink - kqmax[j0/nwarps]);
+            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
+            if (threadIdx.x == 0) {
+                kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
+            }
+        }
+    }
+
+    float2 * dst2 = (float2 *) dst;
+
 #pragma unroll
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -266,36 +281,35 @@ static __global__ void flash_attn_tile_ext_f16(
         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
         kqsum_j = warp_reduce_sum((float)kqsum_j);
 
+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
-        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
-            const int i0 = i00 + 2*threadIdx.x;
+        for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+            const int i0 = i00 + threadIdx.x;
 
-            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
             if (gridDim.y == 1) {
                 dst_val /= __half2half2(kqsum_j);
             }
-            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] =  __low2float(dst_val);
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
+            dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
         }
 
         if (gridDim.y != 1 && threadIdx.x == 0) {
-            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
-    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
     GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
     GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
     GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
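
The sink handling added after the KV loop treats the sink as one extra per-head logit in the online softmax: the running max is raised to max(m, sink), the accumulator and row sum are rescaled by exp(m_old - m_new), and exp(sink - m_new) is added to the sum exactly once. A small scalar sketch in plain C++ (no CUDA, no warps; the logits and sink value are invented, and V is taken as the identity so the accumulator just holds the rescaled exponentials) checks that this matches a softmax computed with the sink included from the start:

    #include <cassert>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<float> logits = {0.3f, 2.0f, -1.5f, 0.7f}; // stand-ins for KQ values
        const float sink = 1.2f;                                     // per-head attention sink

        // Online pass over the logits, as the KV loop does:
        float m = -INFINITY, s = 0.0f;
        std::vector<float> acc(logits.size(), 0.0f); // stand-in for the VKQ accumulator
        for (size_t k = 0; k < logits.size(); ++k) {
            const float m_new = std::max(m, logits[k]);
            const float scale = std::exp(m - m_new);
            for (float & a : acc) a *= scale;
            s = s*scale + std::exp(logits[k] - m_new);
            acc[k] += std::exp(logits[k] - m_new);
            m = m_new;
        }

        // Sink adjustment, mirroring the block added after the KV loop:
        {
            const float m_new = std::max(m, sink);
            const float scale = std::exp(m - m_new);
            for (float & a : acc) a *= scale;
            s = s*scale + std::exp(sink - m_new);
            m = m_new;
        }

        // Reference: softmax over {logits..., sink}; the sink gets probability mass but no V row.
        float ref_s = std::exp(sink - m);
        for (float x : logits) ref_s += std::exp(x - m);
        for (size_t k = 0; k < logits.size(); ++k) {
            const float p_online = acc[k]/s;
            const float p_ref    = std::exp(logits[k] - m)/ref_s;
            assert(std::fabs(p_online - p_ref) < 1e-6f);
        }
        printf("sink-adjusted online softmax matches the direct computation\n");
        return 0;
    }
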
diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu
index 908c76dbdd2..c58194937d7 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f32.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -5,14 +5,16 @@
 #define FATTN_KQ_STRIDE_TILE_F32 32
 
 template // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !defined(GGML_USE_HIP)
 __launch_bounds__(nwarps*WARP_SIZE, 2)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !defined(GGML_USE_HIP)
 static __global__ void flash_attn_tile_ext_f32(
         const char * __restrict__ Q,
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -21,31 +23,13 @@ static __global__ void flash_attn_tile_ext_f32(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int nb31,
-        const int nb32,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
@@ -54,18 +38,17 @@ static __global__ void flash_attn_tile_ext_f32(
     return;
 #endif // FP16_MMA_AVAILABLE
     if (use_logit_softcap && !(D == 128 || D == 256)) {
-        GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
-        GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
-        GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+        GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
+        GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
+        GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
         GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
-        GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
-        GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
-        GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
-        GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-        GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-        GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+        GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+        GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+        GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+        GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
+        GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
+        GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
         NO_DEVICE_CODE;
         return;
     }
@@ -74,15 +57,18 @@ static __global__ void flash_attn_tile_ext_f32(
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
+    const float2 * Q_f2   = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half2  * K_h2   = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half2  * V_h2   = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh  = (const half   *) (mask  + nb33*(sequence % ne33)                          + nb31*ic0);
+    const float  * sinksf = (const float  *) (sinks);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 
@@ -116,7 +102,8 @@ static __global__ void flash_attn_tile_ext_f32(
 
     __syncthreads();
 
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         float kqmax_new[ncols/nwarps];
@@ -131,7 +118,7 @@ static __global__ void flash_attn_tile_ext_f32(
 
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
-                const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
+                const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
                 KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] =  __low2float(tmp);
                 KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
             }
@@ -227,8 +214,9 @@ static __global__ void flash_attn_tile_ext_f32(
             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
 
-                KV_tmp2[k*(D/2) + i].x =  __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
-                KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+                const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
+                KV_tmp2[k*(D/2) + i].x =  __low2float(tmp);
+                KV_tmp2[k*(D/2) + i].y = __high2float(tmp);
             }
         }
 
@@ -265,6 +253,35 @@ static __global__ void flash_attn_tile_ext_f32(
         __syncthreads();
     }
 
+
+    // Attention sink: adjust running max and sum once per head
+    if (sinksf && blockIdx.y == 0) {
+        const float sink = sinksf[head];
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
+            kqmax[j0/nwarps] = kqmax_new_j;
+
+            const float val = expf(sink - kqmax[j0/nwarps]);
+            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
+            if (threadIdx.x == 0) {
+                kqsum[j0/nwarps] += val;
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
+                VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
+            }
+        }
+    }
+
+    float2 * dst2 = (float2 *) dst;
+
 #pragma unroll
     for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
         const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -276,37 +293,36 @@ static __global__ void flash_attn_tile_ext_f32(
         float kqsum_j = kqsum[j_VKQ_0/nwarps];
         kqsum_j = warp_reduce_sum(kqsum_j);
 
+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
-        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
-            const int i0 = i00 + 2*threadIdx.x;
+        for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+            const int i0 = i00 + threadIdx.x;
 
-            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
             if (gridDim.y == 1) {
                 dst_val.x /= kqsum_j;
                 dst_val.y /= kqsum_j;
             }
-            const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
-            dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
+            dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
         }
 
         if (gridDim.y != 1 && threadIdx.x == 0) {
-            dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
-    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
-    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
+    GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
-    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32);
     GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
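
Both tile kernels also promote the K/V row index to 64 bits before multiplying by the stride (int64_t(k_VKQ_0 + i_KQ)*stride_KV2). A short host-side sketch with assumed sizes (the row count and stride below are purely illustrative, chosen so the product exceeds INT32_MAX) shows what the promotion protects against:

    #include <climits>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // Assumed sizes, for illustration only:
        const int64_t row    = 3'000'000; // k_VKQ_0 + i_KQ for a very long KV cache
        const int64_t stride = 1024;      // stride_KV2, elements per K/V row

        const int64_t off = row * stride; // what the patched kernels compute in 64 bits
        printf("element offset  : %lld\n", (long long) off);
        printf("fits in int32_t : %s\n", off <= INT32_MAX ? "yes" : "no"); // prints "no" here
        return 0;
    }
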
diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index e78fb181919..b05f682cd3b 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -1,6 +1,12 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
+// Currently llvm with the amdgcn target does not support unrolling loops
+// that contain a break that cannot be resolved at compile time.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
 template // D == head size
 #ifndef GGML_USE_HIP
 __launch_bounds__(D, 1)
@@ -10,6 +16,8 @@ static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -18,31 +26,13 @@ static __global__ void flash_attn_vec_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int nb31,
-        const int nb32,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
@@ -65,14 +55,17 @@ static __global__ void flash_attn_vec_ext_f16(
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.z              + nb01*ic0;
-    K += nb12*(blockIdx.z / gqa_ratio);
-    V += nb22*(blockIdx.z / gqa_ratio);
+    Q += nb03*sequence + nb02* head              + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);
 
-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float * sinksf = (const float *) (sinks);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -84,11 +77,12 @@ static __global__ void flash_attn_vec_ext_f16(
     half2 * KQ2 = (half2 *) KQ;
 
     half kqmax[ncols];
+    half kqsum[ncols];
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqmax[j] = -HALF_MAX_HALF;
+        kqsum[j] = 0.0f;
     }
-    half kqsum[ncols] = {0.0f};
 
     __shared__ half kqmax_shared[ncols][WARP_SIZE];
     __shared__ half kqsum_shared[ncols][WARP_SIZE];
@@ -187,37 +181,22 @@ static __global__ void flash_attn_vec_ext_f16(
 
     half2 VKQ[ncols] = {{0.0f, 0.0f}};
 
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    K     += blockIdx.y*D * nb11;
+    V     += blockIdx.y*D * nb21;
+    maskh += blockIdx.y*D;
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
+             // Increment pointers after each loop iteration:
+             K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
+
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         if (mask) {
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
-                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
+                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
             }
-
             __syncthreads();
-
-            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
-            // In such cases, skip the KV slice.
-            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
-#ifndef GGML_USE_HIP
-            bool skip = true;
-#pragma unroll
-            for (int j = 0; j < ncols; ++j) {
-#pragma unroll
-                for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-                    const int i = i0 + threadIdx.x;
-
-                    const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
-                    skip = skip && isinf(tmp.x) && isinf(tmp.y);
-                }
-            }
-            if (__all_sync(0xFFFFFFFF, skip)) {
-                __syncthreads();
-                continue;
-            }
-#endif // GGML_USE_HIP
         }
 
         // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
@@ -240,7 +219,7 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
-                half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
+                half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
                 sum = warp_reduce_sum((float)sum);
 
                 if (use_logit_softcap) {
@@ -296,8 +275,8 @@ static __global__ void flash_attn_vec_ext_f16(
             }
 
             half2 V_k;
-            reinterpret_cast(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
-            reinterpret_cast(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
+            reinterpret_cast(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
+            reinterpret_cast(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
@@ -307,6 +286,39 @@ static __global__ void flash_attn_vec_ext_f16(
         __syncthreads();
     }
 
+    if (sinksf && blockIdx.y == 0) {
+        const half sink = __float2half(sinksf[head]);
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const half val = hexp(sink - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale;
+
+            if (tid == 0) {
+                kqsum[j] += val;
+            }
+
+            VKQ[j] *= __half2half2(KQ_max_scale);
+        }
+
+        __syncthreads();
+    }
+
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqsum[j] = warp_reduce_sum((float)kqsum[j]);
@@ -330,29 +342,30 @@ static __global__ void flash_attn_vec_ext_f16(
         if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
     }
 
     if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
-    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
-    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
-    GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
+    GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
-    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
-    GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-    GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
-    GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
-    GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+    GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
+    GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
 
 template 
 void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
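
Together with the clang -Wpass-failed suppression above, the vector kernel's KV loop was restructured: K, V and the mask pointer are advanced in the for-loop increment and the data-dependent skip of fully-masked slices was removed, so the body indexes only with local offsets. A simplified host-side sketch of the pointer-advance pattern (plain arrays stand in for the K cache; all sizes are made up) shows it touches the same elements as the old k_VKQ_0-based indexing:

    #include <cassert>
    #include <vector>

    int main() {
        const int rows = 64, stride = 8, step = 16; // step plays the role of gridDim.y*D
        std::vector<float> K(rows*stride);
        for (int i = 0; i < rows*stride; ++i) K[i] = float(i % 7);

        // Variant 1: offsets recomputed from k0 inside the body.
        float sum1 = 0.0f;
        for (int k0 = 0; k0 < rows; k0 += step) {
            for (int i = 0; i < step; ++i) {
                sum1 += K[(k0 + i)*stride];
            }
        }

        // Variant 2: advance the base pointer in the loop increment, as the patch does
        // with `K += gridDim.y*D*nb11`; the body only uses the local index i.
        float sum2 = 0.0f;
        const float * Kp = K.data();
        for (int k0 = 0; k0 < rows; k0 += step, Kp += step*stride) {
            for (int i = 0; i < step; ++i) {
                sum2 += Kp[i*stride];
            }
        }

        assert(sum1 == sum2); // same elements in the same order
        return 0;
    }
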
diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index b2f1724c955..d6d0bfb744b 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -1,6 +1,12 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
+// Currently llvm with the amdgcn target does not support unrolling loops
+// that contain a break that cannot be resolved at compile time.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
 template // D == head size
 #ifndef GGML_USE_HIP
 __launch_bounds__(D, 1)
@@ -10,6 +16,8 @@ static __global__ void flash_attn_vec_ext_f32(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -18,31 +26,13 @@ static __global__ void flash_attn_vec_ext_f32(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int nb31,
-        const int nb32,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
@@ -53,12 +43,11 @@ static __global__ void flash_attn_vec_ext_f32(
         GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
         GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
         GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+        GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+        GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
         GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
         GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-        GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-        GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+        GGML_UNUSED(nb23);
         NO_DEVICE_CODE;
         return;
     }
@@ -77,14 +66,17 @@ static __global__ void flash_attn_vec_ext_f32(
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.z              + nb01*ic0;
-    K += nb12*(blockIdx.z / gqa_ratio);
-    V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
+    Q += nb03*sequence + nb02* head              + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);
 
-    const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float * sinksf = (const float *) (sinks);
 
-    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
@@ -98,11 +90,12 @@ static __global__ void flash_attn_vec_ext_f32(
     }
 
     float kqmax[ncols];
+    float kqsum[ncols];
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqmax[j] = -FLT_MAX/2.0f;
+        kqsum[j] = 0.0f;
     }
-    float kqsum[ncols] = {0.0f};
 
     __shared__ float kqmax_shared[ncols][WARP_SIZE];
     __shared__ float kqsum_shared[ncols][WARP_SIZE];
@@ -194,36 +187,22 @@ static __global__ void flash_attn_vec_ext_f32(
 
     float VKQ[ncols] = {0.0f};
 
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    K     += blockIdx.y*D * nb11;
+    V     += blockIdx.y*D * nb21;
+    maskh += blockIdx.y*D;
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
+             // Increment pointers after each loop iteration:
+             K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
+
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         if (mask) {
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
-                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
+                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
             }
-
             __syncthreads();
-
-            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
-            // In such cases, skip the KV slice.
-            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
-#ifndef GGML_USE_HIP
-            bool skip = true;
-#pragma unroll
-            for (int j = 0; j < ncols; ++j) {
-#pragma unroll
-                for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
-                    const int i = i0 + threadIdx.x;
-
-                    skip = skip && isinf(maskf_shared[j*D + i]);
-                }
-            }
-            if (__all_sync(0xFFFFFFFF, skip)) {
-                __syncthreads();
-                continue;
-            }
-#endif // GGML_USE_HIP
         }
 
         float kqmax_new_arr[ncols];
@@ -242,7 +221,7 @@ static __global__ void flash_attn_vec_ext_f32(
 
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
-                float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
+                float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
                 sum = warp_reduce_sum(sum);
 
                 if (use_logit_softcap) {
@@ -293,7 +272,7 @@ static __global__ void flash_attn_vec_ext_f32(
                 break;
             }
 
-            const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
+            const float V_ki = dequantize_1_v(V + k*nb21, tid);
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 VKQ[j] += V_ki*KQ[j*D + k];
@@ -303,6 +282,39 @@ static __global__ void flash_attn_vec_ext_f32(
         __syncthreads();
     }
 
+    if (sinksf && blockIdx.y == 0) {
+        const float sink = sinksf[head];
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            float kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const float val = expf(sink - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale;
+
+            if (tid == 0) {
+                kqsum[j] += val;
+            }
+
+            VKQ[j] *= KQ_max_scale;
+        }
+
+        __syncthreads();
+    }
+
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqsum[j] = warp_reduce_sum(kqsum[j]);
@@ -326,12 +338,11 @@ static __global__ void flash_attn_vec_ext_f32(
         if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
+        dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
     }
 
     if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+        dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -340,15 +351,17 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-    GGML_UNUSED(nb31); GGML_UNUSED(nb32);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+    GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
 
 template 
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
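
The new KV_max argument lets the host pass, per (sequence, Q column block), an upper bound on how far into the KV cache a block must iterate; when it is non-null the loop runs to KV_max[sequence*gridDim.x + blockIdx.x] instead of ne11, which takes over from the removed per-tile "fully masked, skip" check. A minimal sketch of that lookup on the host side (the bounds below are hypothetical, e.g. what a causal mask might induce):

    #include <cstdio>
    #include <vector>

    int main() {
        const int ne11     = 4096; // total KV length
        const int n_seq    = 2;    // sequences
        const int n_blocks = 3;    // Q column blocks per sequence (role of gridDim.x)

        // Hypothetical per-(sequence, block) bounds:
        const std::vector<int> KV_max = {  512, 1024, 1536,    // sequence 0
                                          2048, 3072, 4096 };  // sequence 1

        for (int seq = 0; seq < n_seq; ++seq) {
            for (int blk = 0; blk < n_blocks; ++blk) {
                // Mirrors the kernel: KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11
                const int k_VKQ_max = KV_max.empty() ? ne11 : KV_max[seq*n_blocks + blk];
                printf("seq %d, block %d: iterate over %d of %d KV positions\n",
                       seq, blk, k_VKQ_max, ne11);
            }
        }
        return 0;
    }
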
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index c95ca7b1f28..6bc7943ccd5 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -7,7 +7,7 @@
 #include "fattn-wmma-f16.cuh"
 
 #ifdef FP16_MMA_AVAILABLE
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !defined(GGML_USE_HIP)
 #include <mma.h>
 #ifdef GGML_USE_MUSA
 namespace wmma = mtmusa::wmma;
@@ -15,10 +15,9 @@ namespace wmma = mtmusa::wmma;
 namespace wmma = nvcuda::wmma;
 #endif // GGML_USE_MUSA
 #elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
-#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
 #include <rocwmma/rocwmma.hpp>
 namespace wmma = rocwmma;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !defined(GGML_USE_HIP)
 #endif // FP16_MMA_AVAILABLE
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
@@ -29,6 +28,8 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -37,31 +38,13 @@ static __global__ void flash_attn_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int nb31,
-        const int nb32,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -95,17 +78,20 @@ static __global__ void flash_attn_ext_f16(
     constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
 
+    const int sequence = blockIdx.z / ne02;
+    const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q    + nb02* blockIdx.z              + nb01*ic0);
-    const half  * K_h   = (const half  *) (K    + nb12*(blockIdx.z / gqa_ratio));
-    const half  * V_h   = (const half  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-    const half  * maskh = (const half  *) (mask + nb32*(blockIdx.z % ne32)      + nb31*ic0);
-    const half2 * mask2 = (const half2 *)  maskh;
+    const float * Q_f    = (const float *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half  * K_h    = (const half  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half  * V_h    = (const half  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
+    const half2 * mask2  = (const half2 *)  maskh;
+    const float * sinksf = (const float *) sinks;
 
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
     const half2 slope2 = make_half2(slopef, slopef);
 
@@ -181,7 +167,8 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();
 
     // Iterate over ne11 == previous tokens:
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
         // Calculate tile of KQ:
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
@@ -193,7 +180,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                 frag_a_K K_a;
-                wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+                wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -340,7 +327,7 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
                 frag_a_V v_a;
-                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
                     wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -394,13 +381,59 @@ static __global__ void flash_attn_ext_f16(
         __syncthreads();
     }
 
+    // Apply attention sinks
+    if (sinksf && blockIdx.y == 0) {
+        const float sinkf = sinksf[head];
+        const half  sinkh = __float2half(sinkf);
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (std::is_same::value) {
+                float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);
+
+                const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
+                KQ_max_f[j0/nwarps] = kqmax_new;
+
+                KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);
+
+                const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    const int i = i0 + threadIdx.x;
+                    if (i0 + warp_size > D/2 && i >= D/2) break;
+                    VKQ2[j*(D_padded/2) + i] *= scale_h2;
+                }
+            } else {
+                half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
+                half kqmax_new = fmaxf(kqmax_old, sinkh);
+                KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);
+
+                const half  KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
+                const half2 KQ_max_scale   = __half2half2(KQ_max_scale_h);
+
+                KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
+                const half val = hexp(sinkh - kqmax_new);
+                KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);
+
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    const int i = i0 + threadIdx.x;
+                    if (i0 + warp_size > D/2 && i >= D/2) break;
+                    VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;
+                }
+            }
+        }
+
+        __syncthreads();
+    }
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j_VKQ = j0 + threadIdx.y;
         if (ic0 + j_VKQ >= ne01) {
             return;
         }
-        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
 
         float KQ_rowsum_j;
         if (std::is_same::value) {
@@ -409,6 +442,8 @@ static __global__ void flash_attn_ext_f16(
             KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
         }
 
+        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
 #pragma unroll
         for (int i0 = 0; i0 < D; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
@@ -419,7 +454,7 @@ static __global__ void flash_attn_ext_f16(
             if (gridDim.y == 1) {
                 dst_val /= KQ_rowsum_j;
             }
-            dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
+            dst[j_dst_unrolled*D + i] = dst_val;
         }
 
         if (gridDim.y == 1 || threadIdx.x != 0) {
@@ -433,19 +468,19 @@ static __global__ void flash_attn_ext_f16(
             dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
         }
         dst_meta_val.y = KQ_rowsum_j;
-        dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
+        dst_meta[j_dst_unrolled] = dst_meta_val;
     }
 #else
-    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
     GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
     GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
     GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+    GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31);
+    GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
 }
@@ -561,7 +596,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
         return;
     }
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !defined(GGML_USE_HIP)
     if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) {
         constexpr int cols_per_block = 8;
         switch (Q->ne[0]) {
@@ -583,7 +618,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
         }
         return;
     }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !defined(GGML_USE_HIP)
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 16;
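
All kernels now write their partial results and metadata with the same flattened index, j_dst_unrolled = ((sequence*ne01 + row)*ne02 + head)*gridDim.y + blockIdx.y, so the combine/fixup pass sees one uniform layout over (sequence, row, head, parallel block). A tiny check that this mapping is a bijection over those coordinates (the extents below are arbitrary):

    #include <cassert>
    #include <vector>

    int main() {
        // Arbitrary extents for the four coordinates:
        const int n_seq = 2, ne01 = 3, ne02 = 4, gridDim_y = 5;

        std::vector<bool> seen(n_seq*ne01*ne02*gridDim_y, false);
        for (int sequence = 0; sequence < n_seq; ++sequence)
        for (int row = 0; row < ne01; ++row)
        for (int head = 0; head < ne02; ++head)
        for (int by = 0; by < gridDim_y; ++by) {
            const int j_dst_unrolled = ((sequence*ne01 + row)*ne02 + head)*gridDim_y + by;
            assert(!seen[j_dst_unrolled]); // every (sequence, row, head, by) gets a distinct slot
            seen[j_dst_unrolled] = true;
        }
        return 0;
    }
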
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 6bc0096cc65..22e90d0e7b3 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -269,33 +269,23 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
 }
 
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV  = dst;
-    const ggml_tensor * Q    = dst->src[0];
-    const ggml_tensor * K    = dst->src[1];
-    const ggml_tensor * V    = dst->src[2];
-    const ggml_tensor * mask = dst->src[3];
+    const ggml_tensor * KQV   = dst;
+    const ggml_tensor * Q     = dst->src[0];
+    const ggml_tensor * K     = dst->src[1];
+    const ggml_tensor * V     = dst->src[2];
+    const ggml_tensor * mask  = dst->src[3];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    if (GGML_CUDA_CC_IS_AMD(cc)) {
 #if defined(GGML_HIP_ROCWMMA_FATTN)
-        if (fp16_mma_available(cc)) {
-            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
-            return;
-        }
-#endif // defined(GGML_HIP_ROCWMMA_FATTN)
-
-        // On AMD the tile kernels perform poorly, use the vec kernel instead:
-        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-        }
+    if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
+        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
         return;
     }
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
 
     if (!fast_fp16_available(cc)) {
         if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
@@ -325,7 +315,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
-    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
+    const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
+    const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
+        (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
@@ -337,7 +329,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
-    if (fp16_mma_available(cc) && !new_mma_available(cc)) {
+    if (fp16_mma_available(cc) && !turing_mma_available(cc)) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
         return;
     }
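
The dispatch change above widens the batch-size-1 MMA preference: on Ada Lovelace and newer, MMA is only chosen for multi-sequence batches or very long contexts with a large GQA ratio. Restated as a pure function for clarity (a sketch; the real code reads these quantities from the ggml tensors directly):

#include <cstdint>

struct fattn_shapes {
    int64_t q_ne[4]; // Q->ne
    int64_t k_ne[4]; // K->ne
};

static bool prefer_mma_for_bs1(const fattn_shapes & s, bool turing_mma_available, bool gqa_opt_applies,
                               bool mma_needs_data_conversion, int cc, int cc_ada_lovelace) {
    // Matches the mma_faster_for_rtx4000 / mma_faster_for_bs1 logic above.
    const bool mma_faster_for_rtx4000 = s.q_ne[3] > 1 || (s.q_ne[2] > 4*s.k_ne[2] && s.k_ne[1] >= 8192);
    return turing_mma_available && gqa_opt_applies && !mma_needs_data_conversion &&
           (cc < cc_ada_lovelace || mma_faster_for_rtx4000);
}
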
diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu
index f77b2629a19..68d3254fbe4 100644
--- a/ggml/src/ggml-cuda/getrows.cu
+++ b/ggml/src/ggml-cuda/getrows.cu
@@ -1,5 +1,6 @@
 #include "getrows.cuh"
 #include "dequantize.cuh"
+#include "convert.cuh"
 
 template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static __global__ void k_get_rows(
@@ -34,8 +35,8 @@ static __global__ void k_get_rows(
     dfloat2 v;
     dequantize_kernel(src0_row, ib, iqs, v);
 
-    dst_row[iybs + iqs + 0]        = float(v.x);
-    dst_row[iybs + iqs + y_offset] = float(v.y);
+    dst_row[iybs + iqs + 0]        = ggml_cuda_cast<dst_t>(v.x);
+    dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
 }
 
 template<typename src0_t, typename dst_t>
@@ -62,7 +63,7 @@ static __global__ void k_get_rows_float(
     dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
     const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
 
-    dst_row[i00] = float(src0_row[i00]);
+    dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
 }
 
 template
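
The float(...) conversions above are replaced by ggml_cuda_cast<dst_t>(...) from the newly included convert.cuh, which lets k_get_rows/k_get_rows_float emit half and bfloat16 outputs directly. A rough sketch of such a helper, assuming it routes through float for the 16-bit destination types (the actual definition in convert.cuh may differ):

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <type_traits>

template <typename dst_t, typename src_t>
static __device__ __forceinline__ dst_t ggml_cuda_cast_sketch(src_t x) {
    if constexpr (std::is_same_v<dst_t, half>) {
        return __float2half(float(x));            // round to fp16
    } else if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
        return __float2bfloat16(float(x));        // round to bf16
    } else {
        return dst_t(float(x));                   // e.g. float -> float
    }
}
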
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 72406f0af36..d6402a8daac 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -4,6 +4,7 @@
 
 #include "ggml-cuda/common.cuh"
 #include "ggml-cuda/acc.cuh"
+#include "ggml-cuda/add-id.cuh"
 #include "ggml-cuda/arange.cuh"
 #include "ggml-cuda/argmax.cuh"
 #include "ggml-cuda/argsort.cuh"
@@ -21,17 +22,21 @@
 #include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
+#include "ggml-cuda/mmf.cuh"
 #include "ggml-cuda/mmq.cuh"
-#include "ggml-cuda/mmv.cuh"
+#include "ggml-cuda/mmvf.cuh"
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/opt-step-sgd.cuh"
 #include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
 #include "ggml-cuda/quantize.cuh"
 #include "ggml-cuda/rope.cuh"
+#include "ggml-cuda/roll.cuh"
 #include "ggml-cuda/scale.cuh"
+#include "ggml-cuda/softcap.cuh"
 #include "ggml-cuda/softmax.cuh"
 #include "ggml-cuda/ssm-conv.cuh"
 #include "ggml-cuda/ssm-scan.cuh"
@@ -43,6 +48,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml-cuda/set-rows.cuh"
 #include "ggml.h"
 
 #include 
@@ -54,6 +60,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -124,7 +131,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
     return err;
 }
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
 static int ggml_cuda_parse_id(char devName[]) {
     // A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp
     // these values are not stable so this is susceptible to breakage
@@ -171,33 +178,9 @@ static int ggml_cuda_parse_id(char devName[]) {
     archNum += archMinor;
     return archNum;
 }
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
 
 static ggml_cuda_device_info ggml_cuda_init() {
-#ifdef __HIP_PLATFORM_AMD__
-    // Workaround for a rocBLAS bug when using multiple graphics cards:
-    // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
-    {
-        int major_version = 0;
-        size_t version_length = 0;
-        if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) {
-            std::vector version(version_length+1, '\0');
-            if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) {
-                version.resize(::strlen(version.data()));
-                int parsed_value = 0;
-                if (std::from_chars(version.data(), version.data() + version.size(), parsed_value).ec == std::errc()) {
-                    major_version = parsed_value;
-                }
-            }
-        }
-        if (major_version < 4) {
-            GGML_LOG_DEBUG(GGML_CUDA_NAME " calling rocblas_initialize as a workaround for a rocBLAS bug\n");
-            rocblas_initialize();
-            CUDA_CHECK(cudaDeviceSynchronize());
-        }
-    }
-#endif
-
     ggml_cuda_device_info info = {};
 
     cudaError_t err = cudaGetDeviceCount(&info.device_count);
@@ -247,7 +230,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.devices[id].nsm        = prop.multiProcessorCount;
         info.devices[id].smpb       = prop.sharedMemPerBlock;
         info.devices[id].warp_size  = prop.warpSize;
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
         info.devices[id].smpbo = prop.sharedMemPerBlock;
 
         info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName);
@@ -277,7 +260,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
         GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
                         id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
     }
 
     for (int id = 0; id < info.device_count; ++id) {
@@ -1848,6 +1831,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
     ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
 
+    bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
+    bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
+
     // Handle src0
     src0_ptr = (const cuda_t *) src0->data;
 
@@ -1866,6 +1852,8 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
         s11 = ne10;
         s12 = ne11*s11;
         s13 = ne12*s12;
+
+        is_src1_cont_2 = true;
     }
 
     // Setup destination buffer
@@ -1914,15 +1902,19 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+    if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
+        // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
+        const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
+        const int64_t smb = ne12 == 1 ? s13       : s12;
+
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
-                       src1_ptr, cu_data_type_b, s11,       s12,       // strideB
-                beta,     dst_t, cu_data_type,   ne0,       ne1*ne0,   // strideC
+                alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma,     // strideA
+                       src1_ptr, cu_data_type_b, s11,       smb,     // strideB
+                beta,     dst_t, cu_data_type,   ne0,       ne1*ne0, // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1994,7 +1986,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
         && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
 
-    bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+    bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+    bool use_mul_mat_f     = !ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
@@ -2014,14 +2008,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
             }
 
             const int cc            = ggml_cuda_info().devices[id].cc;
+            const int warp_size     = ggml_cuda_info().devices[id].warp_size;
             use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            use_mul_mat_vec         = use_mul_mat_vec           && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+            use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+            use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
         }
     } else {
         const int cc            = ggml_cuda_info().devices[ctx.device].cc;
+        const int warp_size     = ggml_cuda_info().devices[ctx.device].warp_size;
         use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        use_mul_mat_vec         = use_mul_mat_vec           && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+        use_mul_mat_f           = use_mul_mat_f             && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+        use_mul_mat_vec_f       = use_mul_mat_vec_f         && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
     }
 
@@ -2034,15 +2032,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
     //TODO update for generic tensor parallelism
-    const int cc                     = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int cc                 = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     bool use_batched_cublas_f16  = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
     bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
     bool use_batched_cublas_f32  = src0->type == GGML_TYPE_F32;
 
-    if (!split && use_mul_mat_vec) {
+    if (!split && use_mul_mat_vec_f) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
-        ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
+        ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);
+    } else if (!split && use_mul_mat_f) {
+        ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_vec_q) {
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
@@ -2051,8 +2051,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
-    } else if (use_mul_mat_vec) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
+    } else if (use_mul_mat_vec_f) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr);
     } else if (use_mul_mat_vec_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
     } else if (use_mul_mat_q) {
@@ -2080,7 +2080,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             if (ggml_is_quantized(src0->type)) {
                 ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
             } else {
-                ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
+                ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
             }
             return;
         }
@@ -2230,6 +2230,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GET_ROWS_BACK:
             ggml_cuda_op_get_rows_back(ctx, dst);
             break;
+        case GGML_OP_SET_ROWS:
+            ggml_cuda_op_set_rows(ctx, dst);
+            break;
         case GGML_OP_DUP:
             ggml_cuda_dup(ctx, dst);
             break;
@@ -2243,6 +2246,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ADD1: // TODO: more efficient implementation
             ggml_cuda_op_add(ctx, dst);
             break;
+        case GGML_OP_ADD_ID:
+            ggml_cuda_op_add_id(ctx, dst);
+            break;
         case GGML_OP_SUB:
             ggml_cuda_op_sub(ctx, dst);
             break;
@@ -2299,6 +2305,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_UNARY_OP_EXP:
                     ggml_cuda_op_exp(ctx, dst);
                     break;
+                case GGML_UNARY_OP_ELU:
+                    ggml_cuda_op_elu(ctx, dst);
+                    break;
                 default:
                     return false;
             }
@@ -2314,6 +2323,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_GLU_OP_SWIGLU:
                     ggml_cuda_op_swiglu(ctx, dst);
                     break;
+                case GGML_GLU_OP_SWIGLU_OAI:
+                    ggml_cuda_op_swiglu_oai(ctx, dst);
+                    break;
                 case GGML_GLU_OP_GEGLU_ERF:
                     ggml_cuda_op_geglu_erf(ctx, dst);
                     break;
@@ -2411,6 +2423,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ROPE_BACK:
             ggml_cuda_op_rope_back(ctx, dst);
             break;
+        case GGML_OP_ROLL:
+            ggml_cuda_op_roll(ctx, dst);
+            break;
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
@@ -2465,6 +2480,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_OPT_STEP_ADAMW:
             ggml_cuda_opt_step_adamw(ctx, dst);
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            ggml_cuda_opt_step_sgd(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2583,6 +2601,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
     cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
 
+    const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
+    const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
+    const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
+    const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
+    const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -2604,9 +2628,18 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #endif
         }
 
-        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
-            // disable CUDA graphs for batch size > 1 for now.
-            // Changes in batch size or context size can cause changes to the grid size of some kernels.
+        if (node->op == GGML_OP_ADD &&
+            node->src[1] && node->src[1]->ne[1] > 1 &&
+            (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
+            (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
+            strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
+            strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
+            strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
+            // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
+            // by means of matching node names. See
+            // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
+            // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
+            // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
             use_cuda_graph = false;
 #ifndef NDEBUG
             GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
@@ -2752,6 +2785,67 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 }
 #endif
 
+static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
+#ifndef NDEBUG
+    const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
+    GGML_ASSERT(unary_ops.size() == num_unary);
+#endif
+
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul      = cgraph->nodes[node_idx+1];
+
+        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+
+        //rms norm only supports F32
+        if (mul->src[0]->type != GGML_TYPE_F32 ||
+            mul->src[1]->type != GGML_TYPE_F32 ||
+            mul->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        //if rms norm is the B operand, then we don't handle broadcast
+        if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+            return false;
+        }
+
+        //rms_norm kernel assumes contiguous rows
+        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+            return false;
+        }
+
+        return true;
+    }
+
+    if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
+     && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
+        const ggml_tensor *scale  = cgraph->nodes[node_idx];
+        const ggml_tensor *tanh   = cgraph->nodes[node_idx+1];
+        const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
+
+        GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(scale->type == GGML_TYPE_F32);
+
+        if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {
+            return false;
+        }
+
+        // Check for bias
+        if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
+            return false;
+        }
+
+        return true;
+    }
+
+    return false;
+}
+
 static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
     bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
     // flag used to determine whether it is an integrated_gpu
@@ -2761,6 +2855,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
         // With the use of CUDA graphs, the execution will be performed by the graph launch.
         if (!use_cuda_graph || cuda_graph_update_required) {
+
             for (int i = 0; i < cgraph->n_nodes; i++) {
                 ggml_tensor * node = cgraph->nodes[i];
 
@@ -2768,6 +2863,20 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                     continue;
                 }
 
+                static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
+                if (!disable_fusion) {
+                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
+                        ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+                        i++;
+                        continue;
+                    }
+
+                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
+                        i += 2;
+                        ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
+                        continue;
+                    }
+                }
 #ifndef NDEBUG
                 assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
                 for (int j = 0; j < GGML_MAX_SRC; j++) {
@@ -3112,6 +3221,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_EXP:
+                case GGML_UNARY_OP_ELU:
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
@@ -3122,6 +3232,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU_OAI:
                 case GGML_GLU_OP_GEGLU_ERF:
                 case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]);
@@ -3172,6 +3283,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     case GGML_TYPE_Q5_0:
                     case GGML_TYPE_Q5_1:
                     case GGML_TYPE_Q8_0:
+                    case GGML_TYPE_MXFP4:
                     case GGML_TYPE_Q2_K:
                     case GGML_TYPE_Q3_K:
                     case GGML_TYPE_Q4_K:
@@ -3216,17 +3328,21 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             {
                 return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
+                       op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
+                       op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
+                       op->src[0]->type == GGML_TYPE_F32 &&
+                       op->src[1]->type == GGML_TYPE_I64;
+            } break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;
                 ggml_type src1_type = op->src[1]->type;
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+                if ((src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_F16) &&
+                    (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F16)
+                ) {
                     return true;
                 }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
@@ -3262,12 +3378,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
                     return true;
                 }
-                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
-                    return true;
-                }
-                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
-                    return true;
-                }
                 if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
                     return true;
                 }
@@ -3320,6 +3430,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
@@ -3348,7 +3459,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return op->src[0]->ne[1] % 128 == 0;
         }
         case GGML_OP_CONT:
-            return op->src[0]->type != GGML_TYPE_BF16;
+            return true;
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX:
@@ -3358,6 +3469,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
             return max_bias == 0.0f;
         }
+        case GGML_OP_ROLL:
+            if(op->src[0]->type == GGML_TYPE_F32) {
+                return true;
+            }
+            return false;
         case GGML_OP_ROPE:
         case GGML_OP_ROPE_BACK: {
             return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
@@ -3389,19 +3505,18 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                 const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-                if (!new_mma_available(cc)) {
+                if (!turing_mma_available(cc)) {
                     return false;
                 }
                 const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
                 return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
             }
-            if (op->src[0]->ne[0] == 192) {
+            // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
+            if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
+                    && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
                 return false;
             }
-            // TODO: support broadcast
-            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
-            //       the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
-            if (op->src[0]->ne[3] != 1) {
+            if (op->src[0]->ne[0] == 192) {
                 return false;
             }
             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
@@ -3416,12 +3531,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
+            if (op->src[3] && op->src[3]->ne[2] != 1) {
+                return false;
+            }
             return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
                 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
         }
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return true;
         default:
             return false;
@@ -3661,10 +3780,10 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
     }
 
     ggml_backend_t cuda_backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_cuda_guid(),
-        /* .interface = */ ggml_backend_cuda_interface,
-        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
-        /* .context   = */ ctx,
+        /* .guid    = */ ggml_backend_cuda_guid(),
+        /* .iface   = */ ggml_backend_cuda_interface,
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
+        /* .context = */ ctx,
     };
 
     return cuda_backend;
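
Two of the additions above interact: GGML_OP_ADD nodes with batch size > 1 still disable CUDA graph capture, except for a short list of known-safe nodes matched by name (Gemma3n's per-layer projection and the MoE bias adds). The predicate can be restated compactly as below (illustrative only; the real check is inlined in check_node_graph_compatibility_and_refresh_copy_ops):

#include <cstring>
#include <string>

// Returns true if a batched ADD node is exempt from the "disable CUDA graphs" rule above.
static bool is_exempt_batched_add(const char * node_name, const char * src0_name, const char * src1_name) {
    if (src0_name && std::strcmp(src0_name, "inp_per_layer_selected") == 0) return true;
    if (src1_name && std::strcmp(src1_name, "per_layer_proj")         == 0) return true;
    static const std::string prefixes[] = {
        "ffn_moe_gate_biased", "ffn_moe_up_biased", "ffn_moe_down_biased",
    };
    for (const std::string & p : prefixes) {
        if (std::strncmp(node_name, p.c_str(), p.size()) == 0) return true;
    }
    return false;
}
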
diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu
index 86a54e42bb7..16bb9bec97d 100644
--- a/ggml/src/ggml-cuda/im2col.cu
+++ b/ggml/src/ggml-cuda/im2col.cu
@@ -1,65 +1,76 @@
 #include "im2col.cuh"
 
+#define MAX_GRIDDIM_Z 65535
+
 template <typename T>
 static  __global__ void im2col_kernel(
-        const float * x, T * dst, int64_t batch_offset,
-        int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
+        const float * x, T * dst,
+        int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,
+        int64_t IC_IH_IW, int64_t IH_IW, int64_t N_OH, int64_t KH_KW, int64_t IC_KH_KW,
         int s0, int s1, int p0, int p1, int d0, int d1) {
     const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
-    if (i >= pelements) {
+    if (i >= IC_KH_KW) {
         return;
     }
 
-    const int64_t  ksize = OW * (KH > 1 ? KW : 1);
-    const int64_t  kx = i / ksize;
-    const int64_t  kd = kx * ksize;
-    const int64_t  ky = (i - kd) / OW;
-    const int64_t  ix = i % OW;
+    const int64_t iic = i / (KH_KW);
+    const int64_t rem = i - iic * KH_KW;
+    const int64_t ikh = rem / KW;
+    const int64_t ikw = rem - ikh * KW;
 
-    const int64_t  oh = blockIdx.y;
-    const int64_t  batch = blockIdx.z / IC;
-    const int64_t  ic = blockIdx.z % IC;
+    const int64_t  iow = blockIdx.y;
+    for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) {
+        const int64_t  in = iz / OH;
+        const int64_t  ioh = iz - in * OH;
 
-    const int64_t iiw = ix * s0 + kx * d0 - p0;
-    const int64_t iih = oh * s1 + ky * d1 - p1;
+        const int64_t iiw = iow * s0 + ikw * d0 - p0;
+        const int64_t iih = ioh * s1 + ikh * d1 - p1;
 
-    const int64_t offset_dst =
-        ((batch * OH + oh) * OW + ix) * CHW +
-        (ic * (KW * KH) + ky * KW + kx);
+        const int64_t offset_dst =
+            ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;
 
-    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-        dst[offset_dst] = 0.0f;
-    } else {
-        const int64_t offset_src = ic * offset_delta + batch * batch_offset;
-        dst[offset_dst] = x[offset_src + iih * IW + iiw];
+        if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+            dst[offset_dst] = 0.0f;
+        } else {
+            const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
+            dst[offset_dst] = x[offset_src + iih * IW + iiw];
+        }
     }
+
+    GGML_UNUSED(IC);
+    GGML_UNUSED(KH);
 }
 
+// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
 template <typename T>
 static void im2col_cuda(const float * x, T* dst,
     int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
-    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
-    const int parallel_elements = OW * KW * KH;
-    const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
-    dim3 block_nums(num_blocks, OH, batch * IC);
-    im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+    const int64_t IC_KH_KW = IC * KH * KW;
+    const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+    const int64_t N_OH = N * OH;
+    const int64_t KH_KW = KW*KH;
+    dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z));
+    im2col_kernel<<<block_nums, MIN(IC_KH_KW, CUDA_IM2COL_BLOCK_SIZE), 0, stream>>>(x, dst, IC, IW, IH, OH, OW, KW, KH,
+                                                                                     IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW,
+                                                                                     s0, s1, p0, p1, d0, d1);
 }
 
 static void im2col_cuda_f16(const float * x, half * dst,
     int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
-    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
 
-    im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
+    im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
 }
 
 static void im2col_cuda_f32(const float * x, float * dst,
     int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
-    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
     int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
 
-    im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
+    im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
 }
 
 void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -91,13 +102,13 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t OH = is_2D ? dst->ne[2] : 1;
     const int64_t OW =         dst->ne[1];
 
-    const size_t  delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
-    const int64_t batch        = src1->ne[is_2D ? 3 : 2];
-    const size_t  batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
+    const int64_t IC_IH_IW = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const int64_t N        = src1->ne[is_2D ? 3 : 2];
+    const int64_t IH_IW    = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
 
     if(dst->type == GGML_TYPE_F16) {
-        im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
+        im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
     } else {
-        im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
+        im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
     }
 }
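
The rewritten im2col kernel parallelizes over IC*KH*KW in x, OW in y and N*OH in z (chunked by MAX_GRIDDIM_Z), but the output layout is unchanged: [N, OH, OW, IC*KH*KW] with zero padding outside the input. A plain CPU reference of that layout, useful for checking the index arithmetic (illustrative, not part of the patch):

#include <cstdint>
#include <vector>

// x is [N, IC, IH, IW] row-major; the result is [N, OH, OW, IC*KH*KW].
static std::vector<float> im2col_ref(const std::vector<float> & x,
        int64_t N, int64_t IC, int64_t IH, int64_t IW,
        int64_t KH, int64_t KW, int s0, int s1, int p0, int p1, int d0, int d1) {
    const int64_t OH = (IH + 2*p1 - d1*(KH - 1) - 1)/s1 + 1;
    const int64_t OW = (IW + 2*p0 - d0*(KW - 1) - 1)/s0 + 1;
    std::vector<float> dst(N*OH*OW*IC*KH*KW, 0.0f);
    for (int64_t in = 0; in < N; ++in)
    for (int64_t ioh = 0; ioh < OH; ++ioh)
    for (int64_t iow = 0; iow < OW; ++iow)
    for (int64_t iic = 0; iic < IC; ++iic)
    for (int64_t ikh = 0; ikh < KH; ++ikh)
    for (int64_t ikw = 0; ikw < KW; ++ikw) {
        const int64_t iih = ioh*s1 + ikh*d1 - p1;
        const int64_t iiw = iow*s0 + ikw*d0 - p0;
        // Same destination offset as the kernel: ((in*OH + ioh)*OW + iow)*IC_KH_KW + iic*KH_KW + ikh*KW + ikw
        const int64_t odst = ((in*OH + ioh)*OW + iow)*(IC*KH*KW) + (iic*KH + ikh)*KW + ikw;
        if (iih >= 0 && iih < IH && iiw >= 0 && iiw < IW) {
            dst[odst] = x[((in*IC + iic)*IH + iih)*IW + iiw];
        }
    }
    return dst;
}
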
diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu
index 4b238a3998b..347abc18660 100644
--- a/ggml/src/ggml-cuda/mean.cu
+++ b/ggml/src/ggml-cuda/mean.cu
@@ -1,4 +1,14 @@
 #include "mean.cuh"
+#include "reduce_rows.cuh"
+
+#ifdef GGML_CUDA_USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif  // GGML_CUDA_USE_CUB
+
+template <typename T> __global__ void divide_by_count(T * result, size_t count) {
+    *result /= static_cast<T>(count);
+}
 
 void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0   = dst->src[0];
@@ -13,7 +23,51 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
-    const dim3 block_dims(WARP_SIZE, 1, 1);
+// Special case for reducing vectors
+#ifdef GGML_CUDA_USE_CUB
+#ifdef USE_CUDA_GRAPH
+    cudaStreamCaptureStatus iscapturing;
+    CUDA_CHECK(cudaStreamIsCapturing(stream, &iscapturing));
+#endif // USE_CUDA_GRAPH
+    if ((nrows == 1) &&
+#ifdef USE_CUDA_GRAPH
+            // CUDA_GRAPHS_DISABLED
+            ((ncols > 65536) &&
+             ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
+              ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
+              ctx.cuda_graph->disable_due_to_failed_graph_capture)) ||
+        // CUDA_GRAPHS ENABLED
+        ((ncols > 32768) &&
+         !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
+           ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
+           ctx.cuda_graph->disable_due_to_failed_graph_capture))) {
+#else
+        (ncols > 65536)) {
+#endif // USE_CUDA_GRAPH
+        // Single row - use device-wide reduction
+        size_t           tmp_size = 0;
+        ggml_cuda_pool & pool     = ctx.pool();
+
+        DeviceReduce::Sum(nullptr, tmp_size, src0_d, dst_d, ncols, stream);
+
+        ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+        DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, src0_d, dst_d, ncols, stream);
+
+        // Divide by ncols
+        divide_by_count<<<1, 1, 0, stream>>>(dst_d, ncols);
+        return;
+    }
+#endif // GGML_CUDA_USE_CUB
+
     const dim3 block_nums(nrows, 1, 1);
-    reduce_rows_f32<true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+
+    const int id  = ggml_cuda_get_device();
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    if ((nrows / nsm) < 2) {
+        const dim3 block_dims(512, 1, 1);
+        reduce_rows_f32<true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    } else {
+        const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
+        reduce_rows_f32<true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    }
 }
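
For long single rows the mean is now computed as a device-wide CUB sum followed by a one-thread division, and the generic path picks its block size from the row count and SM count. A standalone sketch of the single-row path (assumes CUB is available; the real code allocates the temporary from the ggml CUDA pool and gates on the graph-capture state as shown above):

#include <cub/cub.cuh>

static __global__ void divide_by_count_sketch(float * result, size_t count) {
    *result /= (float) count;
}

static void mean_single_row(const float * src, float * dst, int64_t ncols, cudaStream_t stream) {
    size_t tmp_size = 0;
    cub::DeviceReduce::Sum(nullptr, tmp_size, src, dst, (int) ncols, stream); // query temp storage size
    void * tmp = nullptr;
    cudaMallocAsync(&tmp, tmp_size, stream);                                  // the patch uses ggml_cuda_pool_alloc instead
    cub::DeviceReduce::Sum(tmp, tmp_size, src, dst, (int) ncols, stream);     // dst <- sum of the row
    divide_by_count_sketch<<<1, 1, 0, stream>>>(dst, (size_t) ncols);         // dst <- dst / ncols
    cudaFreeAsync(tmp, stream);
}
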
diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index 2af63355a19..83ee16b27d0 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -12,7 +12,8 @@
 // The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
 // All matrix tiles have ne physical 32 bit elements per warp.
 //
-// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
+// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
+// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes, unaligned pointers are considered undefined behavior.
 
 #include "common.cuh"
 
@@ -22,13 +23,13 @@
 static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
     int ret = 0;
 
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
     asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
         : "=r"(ret) : "r"(x));
 #else
     GGML_UNUSED(x);
     NO_DEVICE_CODE;
-#endif // defined(NEW_MMA_AVAILABLE)
+#endif // defined(TURING_MMA_AVAILABLE)
     return ret;
 }
 
@@ -66,7 +67,44 @@ namespace ggml_cuda_mma {
     struct tile {
         static constexpr int I  = I_;
         static constexpr int J  = J_;
-        static constexpr int ne = I * J / WARP_SIZE;
+
+#if defined(GGML_USE_HIP)
+        static constexpr int ne = I * J / 64;
+        T x[ne] = {0};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+                return threadIdx.x % 16;
+            } else if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else if constexpr (I == 32 && J == 4) {
+                return threadIdx.x % 32;
+            } else if constexpr (I == 16 && J == 16) {
+                return 4 * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 32 && J == 32) {
+                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+                return (2 * ((threadIdx.x / 16) % 2) + l);
+            } else if constexpr (I == 16 && J == 8) {
+                return 2 * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 32 && J == 4) {
+                return 2 * (threadIdx.x / 32) + l;
+            } else if constexpr (I == 16 && J == 16) {
+                return threadIdx.x % 16;
+            } else if constexpr (I == 32 && J == 32) {
+                return threadIdx.x % 32;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+#else
+        static constexpr int ne = I * J / 32;
         T x[ne] = {0};
 
         static __device__ __forceinline__ int get_i(const int l) {
@@ -94,6 +132,7 @@ namespace ggml_cuda_mma {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
             }
         }
+#endif // defined(GGML_USE_HIP)
     };
 
     template <int I_, int J_>
@@ -128,6 +167,38 @@ namespace ggml_cuda_mma {
         }
     };
 
+    template <int I_, int J_>
+    struct tile<I_, J_, nv_bfloat162> {
+        static constexpr int I  = I_;
+        static constexpr int J  = J_;
+        static constexpr int ne = I * J / WARP_SIZE;
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return l * 8 + threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l % 2) * 8 + threadIdx.x / 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return l * 4 + threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l / 2) * 4 + threadIdx.x % 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+    };
+
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
         tile<I, J/2, half2> ret;
@@ -148,16 +219,29 @@ namespace ggml_cuda_mma {
 
     template <int I, int J, typename T>
     static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
+#if defined(AMD_MFMA_AVAILABLE)
+        if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+#pragma unroll
+            for (int l = 0; l < t.ne; ++l) {
+                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+            }
+        } else {
+            int64_t * xi = (int64_t *) t.x;
+            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
+            xi[0] = xs[0];
+        }
+#else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
             t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
         }
+#endif // defined(AMD_MFMA_AVAILABLE)
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int *) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
         asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@@ -165,13 +249,13 @@ namespace ggml_cuda_mma {
             : "l"(xs));
 #else
         load_generic(t, xs0, stride);
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int *) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
         asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
@@ -180,13 +264,13 @@ namespace ggml_cuda_mma {
 #else
         load_generic(xs0, stride);
         GGML_UNUSED(t);
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#if defined(TURING_MMA_AVAILABLE)
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
@@ -194,13 +278,13 @@ namespace ggml_cuda_mma {
             : "l"(xs));
 #else
         load_generic(t, xs0, stride);
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix_trans(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
@@ -211,12 +295,12 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(xs0);
         GGML_UNUSED(stride);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
             : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -235,12 +319,12 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
             : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
@@ -265,12 +349,12 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int       * Dxi = (int       *) D.x;
@@ -292,12 +376,12 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int       * Dxi = (int       *) D.x;
@@ -328,12 +412,29 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
+#ifdef AMPERE_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // AMPERE_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int       * Dxi = (int       *) D.x;
@@ -355,12 +456,29 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) {
+#ifdef AMPERE_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // AMPERE_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
         int       * Dxi = (int       *) D.x;
@@ -391,6 +509,62 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
+#if defined(AMD_MFMA_AVAILABLE)
+        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
+        int32x4_t * acc = (int32x4_t *) D.x;
+#if defined(CDNA3)
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
+                                                       ((int64_t *) B.x)[0],
+                                                       acc[0],
+                                                       0, 0, 0);
+#elif defined(CDNA2) || defined(CDNA)
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
+                                                      B.x[0],
+                                                      acc[0],
+                                                      0, 0, 0);
+        acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
+                                                      B.x[1],
+                                                      acc[0],
+                                                      0, 0, 0);
+#endif // defined(CDNA3)
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
+#if defined(AMD_MFMA_AVAILABLE)
+        using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
+        int32x16_t * acc = (int32x16_t *) D.x;
+#if defined(CDNA3)
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
+                                                       ((int64_t *) B.x)[0],
+                                                       acc[0],
+                                                       0, 0, 0);
+#elif defined(CDNA2) || defined(CDNA)
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
+                                                     B.x[0],
+                                                     acc[0],
+                                                     0, 0, 0);
+        acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
+                                                     B.x[1],
+                                                     acc[0],
+                                                     0, 0, 0);
+#endif // defined(CDNA3)
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE
     }
 }
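For reference, a minimal scalar sketch of the math the newly added Ampere bf16 overload performs, assuming the tile<8, 8, nv_bfloat162> B operand stores one N index per tile row (the "row.col" PTX layout); this is an illustrative host-side helper with plain float standing in for bf16, not ggml code:

    // Scalar reference for the m16n8k16 bf16 mma added above:
    // D (16x8, f32) += A (16x16, row-major) * B^T, where tile row j of B holds
    // column j of the 16x8 right-hand operand.
    static void mma_m16n8k16_ref(float D[16][8], const float A[16][16], const float B[8][16]) {
        for (int i = 0; i < 16; ++i) {
            for (int j = 0; j < 8; ++j) {
                float acc = D[i][j];              // the PTX instruction accumulates into D
                for (int k = 0; k < 16; ++k) {
                    acc += A[i][k] * B[j][k];
                }
                D[i][j] = acc;
            }
        }
    }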
diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu
new file mode 100644
index 00000000000..1437367e871
--- /dev/null
+++ b/ggml/src/ggml-cuda/mmf.cu
@@ -0,0 +1,431 @@
+#include "ggml.h"
+#include "common.cuh"
+#include "mma.cuh"
+#include "mmf.cuh"
+
+using namespace ggml_cuda_mma;
+
+#define MMF_ROWS_PER_BLOCK 32
+
+template <typename T, int rows_per_block, int cols_per_block, int nwarps>
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f(
+        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+        const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst,
+        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    typedef tile<16, 8, T>     tile_A;
+    typedef tile< 8, 8, T>     tile_B;
+    typedef tile<16, 8, float> tile_C;
+
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int tile_k_padded = warp_size + 4;
+    constexpr int ntA = rows_per_block / tile_A::I;
+    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
+
+    const int row0        = blockIdx.x * rows_per_block;
+    const int channel_dst = blockIdx.y;
+    const int channel_x   = channel_dst / channel_ratio;
+    const int channel_y   = channel_dst;
+    const int sample_dst  = blockIdx.z;
+    const int sample_x    = sample_dst / sample_ratio;
+    const int sample_y    = sample_dst;
+
+    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row0*stride_row ;
+    y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;
+    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
+
+    const float2 * y2 = (const float2 *) y;
+
+    extern __shared__ char data_mmv[];
+
+    tile_C C[ntA][ntB];
+
+    T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded);
+
+    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
+        tile_A A[ntA][warp_size / tile_A::J];
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int i = 0; i < tile_A::I; ++i) {
+                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row  + col];
+            }
+#pragma unroll
+            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
+                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
+            }
+        }
+
+#pragma unroll
+        for (int itB = 0; itB < ntB; ++itB) {
+            if constexpr (std::is_same_v<T, float>) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j = j0 + itB*tile_B::I;
+
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
+                }
+            } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j = j0 + itB*tile_B::I;
+
+                    const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                }
+            } else {
+                static_assert(std::is_same_v<T, void>, "unsupported type");
+            }
+#pragma unroll
+            for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+                tile_B B;
+                load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+                for (int itA = 0; itA < ntA; ++itA) {
+                    mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+                }
+            }
+        }
+    }
+
+    float * buf_iw = (float *) data_mmv;
+    constexpr int kiw = nwarps*rows_per_block + 4;
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+#pragma unroll
+    for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int l = 0; l < tile_C::ne; ++l) {
+                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
+                const int j = itB*tile_C::J + tile_C::get_j(l);
+                buf_iw[j*kiw + i] = C[itA][itB].x[l];
+            }
+        }
+    }
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
+            return;
+        }
+
+        float sum = 0.0f;
+        static_assert(rows_per_block == warp_size, "need loop/check");
+#pragma unroll
+        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
+            const int i = i0 + threadIdx.x;
+
+            sum += buf_iw[j*kiw + i];
+        }
+        dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
+    }
+#else
+    NO_DEVICE_CODE;
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(ids); GGML_UNUSED(dst);
+    GGML_UNUSED(ncols); GGML_UNUSED(nchannels_y); GGML_UNUSED(stride_row); GGML_UNUSED(stride_col_y); GGML_UNUSED(stride_col_dst);
+    GGML_UNUSED(channel_ratio); GGML_UNUSED(stride_channel_x); GGML_UNUSED(stride_channel_y); GGML_UNUSED(stride_channel_dst);
+    GGML_UNUSED(sample_ratio); GGML_UNUSED(stride_sample_x); GGML_UNUSED(stride_sample_y); GGML_UNUSED(stride_sample_dst);
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+}
+
+template <typename T, int cols_per_block>
+static void mul_mat_f_cuda(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols_x, const int64_t nrows_x,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    typedef tile<16, 8, T>     tile_A;
+    typedef tile< 8, 8, T>     tile_B;
+    typedef tile<16, 8, float> tile_C;
+
+    GGML_ASSERT(!ids && "mul_mat_id not implemented");
+
+    GGML_ASSERT(ncols_x      % 2 == 0);
+    GGML_ASSERT(stride_row   % 2 == 0);
+    GGML_ASSERT(stride_col_y % 2 == 0);
+    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
+    const int64_t channel_ratio = nchannels_dst / nchannels_x;
+    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
+
+    const int device = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
+
+    int64_t nwarps_best     = 1;
+    int64_t niter_best      = (ncols_x + warp_size*2 - 1) / (warp_size*2);
+    int64_t max_block_size  = 256;
+    for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
+        const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
+        if (niter < niter_best) {
+            niter_best  = niter;
+            nwarps_best = nwarps;
+        }
+    }
+
+    constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
+    const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
+    const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
+    const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
+    const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst);
+    const dim3 block_dims(warp_size, nwarps_best, 1);
+    switch (nwarps_best) {
+        case 1: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 1><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 2: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 2><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 3: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 3><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 4: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 4><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 5: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 5><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 6: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 6><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 7: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 7><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 8: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 8><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
+template <typename T>
+static void mul_mat_f_switch_cols_per_block(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    switch (ncols_dst) {
+        case  1: {
+            mul_mat_f_cuda<T,  1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  2: {
+            mul_mat_f_cuda<T,  2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  3: {
+            mul_mat_f_cuda<T,  3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  4: {
+            mul_mat_f_cuda<T,  4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  5: {
+            mul_mat_f_cuda<T,  5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  6: {
+            mul_mat_f_cuda<T,  6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  7: {
+            mul_mat_f_cuda<T,  7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  8: {
+            mul_mat_f_cuda<T,  8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case  9: {
+            mul_mat_f_cuda<T,  9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 10: {
+            mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 11: {
+            mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 12: {
+            mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 13: {
+            mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 14: {
+            mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 15: {
+            mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        case 16: {
+            mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
+void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);
+    GGML_ASSERT(         dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const size_t ts_src0 = ggml_type_size(src0->type);
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    const size_t ts_dst  = ggml_type_size(dst->type);
+
+    GGML_ASSERT(ne13 == ne3);
+
+    GGML_ASSERT(        nb00       == ts_src0);
+    GGML_ASSERT(        nb10       == ts_src1);
+    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+    GGML_ASSERT(        nb0        == ts_dst);
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+    const float   * src1_d =       (const float   *) src1->data;
+    const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
+    float         *  dst_d =       (float         *)  dst->data;
+
+    const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s11 = src1->nb[1] / ts_src1;
+    const int64_t s1  =  dst->nb[1] / ts_dst;
+    const int64_t s02 = src0->nb[2] / ts_src0;
+    const int64_t s12 = src1->nb[2] / ts_src1;
+    const int64_t s2  =  dst->nb[2] / ts_dst;
+    const int64_t s03 = src0->nb[3] / ts_src0;
+    const int64_t s13 = src1->nb[3] / ts_src1;
+    const int64_t s3  =  dst->nb[3] / ts_dst;
+
+    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
+    const int64_t ncols_dst          = ids ? ne2  : ne1;
+    const int64_t nchannels_y        = ids ? ne11 : ne12;
+    const int64_t nchannels_dst      = ids ? ne1  : ne2;
+    const int64_t stride_channel_dst = ids ? s1   : s2;
+    const int64_t stride_channel_y   = ids ? s11  : s12;
+
+    GGML_ASSERT(!ids || ncols_dst == 1);
+
+    switch (src0->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0->data;
+            constexpr int vals_per_T = 1;
+            mul_mat_f_switch_cols_per_block(
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
+                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03/vals_per_T, s13,              s3,                 ctx.stream());
+        } break;
+        case GGML_TYPE_F16: {
+            const half2 * src0_d = (const half2 *) src0->data;
+            constexpr int vals_per_T = 2;
+            mul_mat_f_switch_cols_per_block(
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
+                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03/vals_per_T, s13,              s3,                 ctx.stream());
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
+            constexpr int vals_per_T = 2;
+            mul_mat_f_switch_cols_per_block(
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
+                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03/vals_per_T, s13,              s3,                 ctx.stream());
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
+}
+
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) {
+    if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
+        return false;
+    }
+    if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
+        return false;
+    }
+    if (ne11 > 16) {
+        return false;
+    }
+    switch (type) {
+        case GGML_TYPE_F32:
+            return ampere_mma_available(cc);
+        case GGML_TYPE_F16:
+            return turing_mma_available(cc);
+        case GGML_TYPE_BF16:
+            return ampere_mma_available(cc);
+        default:
+            return false;
+    }
+}
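A host-side sketch of the nwarps selection in mul_mat_f_cuda above, copied into a standalone function for illustration; the inputs in main() are hypothetical:

    #include <cstdint>
    #include <cstdio>

    // Pick the smallest warp count that minimizes the number of K-loop iterations,
    // using the same formula as mul_mat_f_cuda (max block size 256 threads).
    static int64_t pick_nwarps(int64_t ncols_x, int64_t warp_size, int64_t max_block_size = 256) {
        int64_t nwarps_best = 1;
        int64_t niter_best  = (ncols_x + warp_size*2 - 1) / (warp_size*2);
        for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
            const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
            if (niter < niter_best) {
                niter_best  = niter;
                nwarps_best = nwarps;
            }
        }
        return nwarps_best;
    }

    int main() {
        // K = 4096 on a 32-thread warp -> 8 warps (8 iterations); K = 256 -> 4 warps (1 iteration).
        printf("%lld %lld\n", (long long) pick_nwarps(4096, 32), (long long) pick_nwarps(256, 32));
        return 0;
    }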
diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
new file mode 100644
index 00000000000..785f9f211c3
--- /dev/null
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11);
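For context, a hypothetical gating sketch showing how the two declarations are meant to be used together; the actual call site lives in ggml-cuda.cu and is not part of this excerpt, so the helper name below is illustrative:

    #include "common.cuh"
    #include "mmf.cuh"

    // Hypothetical helper: try the MMF path and report whether it was taken.
    static bool try_mul_mat_f(ggml_backend_cuda_context & ctx,
                              const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
        const int device    = ggml_cuda_get_device();
        const int cc        = ggml_cuda_info().devices[device].cc;
        const int warp_size = ggml_cuda_info().devices[device].warp_size;

        if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_BF16) {
            return false;  // mmf.cu only implements F32/F16/BF16 weights
        }
        if (!ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1])) {
            return false;  // shape/arch heuristics say another kernel is preferable
        }
        ggml_cuda_mul_mat_f(ctx, src0, src1, /*ids=*/nullptr, dst);
        return true;
    }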
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 2db5b4ab0f0..384ee7615f7 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -20,6 +20,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
         case GGML_TYPE_Q8_0:
             mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
             break;
+        case GGML_TYPE_MXFP4:
+            mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
+            break;
         case GGML_TYPE_Q2_K:
             mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
             break;
@@ -109,7 +112,8 @@ void ggml_cuda_mul_mat_q(
     const int64_t s03 = src0->nb[3] / ts_src0;
     const int64_t s3  =  dst->nb[3] / ts_dst;
 
-    const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
+    const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
+                            || GGML_CUDA_CC_IS_CDNA(cc);
 
     if (!ids) {
         const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
@@ -250,8 +254,9 @@ void ggml_cuda_op_mul_mat_q(
     // The stream-k decomposition is only faster for recent NVIDIA GPUs.
     // Also its fixup needs to allocate a temporary buffer in the memory pool.
     // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
-    const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) &&
-        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
+    const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
+                            || GGML_CUDA_CC_IS_CDNA(cc))
+                            && src1_ncols == ne11;
     const mmq_args args = {
         src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
         ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
@@ -280,6 +285,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -304,7 +310,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return false;
     }
 
-    if (new_mma_available(cc)) {
+    if (turing_mma_available(cc)) {
         return true;
     }
 
@@ -320,5 +326,21 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }
 
+    if (amd_mfma_available(cc)) {
+        // As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT)
+        // performs better but is currently suffering from a crash on this architecture.
+        // TODO: Revisit when hipblaslt is fixed on CDNA3
+        if (GGML_CUDA_CC_IS_CDNA3(cc)) {
+            return true;
+        }
+        if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
+            return true;
+        }
+        if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) {
+            return true;
+        }
+        return false;
+    }
+
     return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }
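The CDNA branch added to ggml_cuda_should_use_mmq above encodes a batch-size/quant-type trade-off between MMQ and the BLAS path; a standalone sketch of just that branch, with a minimal stand-in enum so it compiles outside ggml:

    #include <cstdint>

    enum class quant { q4_0, q4_1, q5_0, q5_1, q4_k, q5_k, other };

    // Mirrors the amd_mfma branch from the patch: CDNA3 always keeps MMQ (rocblas/hipblaslt
    // issues as of ROCm 7.0), older CDNA keeps MMQ for small batches or for the 4/5-bit
    // legacy quants, and Q4_K/Q5_K stay on MMQ up to a batch of 256.
    static bool mfma_prefers_mmq(bool is_cdna3, quant type, int64_t ne11) {
        if (is_cdna3) {
            return true;
        }
        if (ne11 <= 128 || type == quant::q4_0 || type == quant::q4_1 ||
                           type == quant::q5_0 || type == quant::q5_1) {
            return true;
        }
        if (ne11 <= 256 && (type == quant::q4_k || type == quant::q5_k)) {
            return true;
        }
        return false;
    }
    // e.g. Q6_K at ne11 == 192 on CDNA2 falls through to the BLAS path, while Q4_K at the
    // same batch size stays on MMQ.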
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 9696a320462..96129bd831f 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -58,6 +58,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
             return MMQ_Q8_1_DS_LAYOUT_DS4;
         case GGML_TYPE_Q8_0:
             return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_MXFP4:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
         case GGML_TYPE_Q2_K:
             return MMQ_Q8_1_DS_LAYOUT_D2S6;
         case GGML_TYPE_Q3_K:
@@ -90,7 +92,7 @@ struct tile_x_sizes {
 };
 
 static int get_mmq_x_max_host(const int cc) {
-    return new_mma_available(cc) ? 128 :
+    return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 :
         GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
 #ifdef GGML_CUDA_FORCE_MMQ
             128                     : 64;
@@ -100,13 +102,13 @@ static int get_mmq_x_max_host(const int cc) {
 }
 
 static constexpr __device__ int get_mmq_x_max_device() {
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     return 128;
-#else // NEW_MMA_AVAILABLE
+#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-    return 128;
-#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
+    return 64;
+#else // defined(GGML_USE_HIP)
 
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 #ifdef GGML_CUDA_FORCE_MMQ
@@ -115,12 +117,11 @@ static constexpr __device__ int get_mmq_x_max_device() {
     return MMQ_DP4A_MAX_BATCH_SIZE;
 #endif // GGML_CUDA_FORCE_MMQ
 #else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
-
     return 64;
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(GGML_USE_HIP)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 
 static int get_mmq_y_host(const int cc) {
@@ -129,7 +130,7 @@ static int get_mmq_y_host(const int cc) {
 }
 
 static constexpr __device__ int get_mmq_y_device() {
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP)
 #if defined(RDNA1)
     return 64;
 #else
@@ -141,19 +142,28 @@ static constexpr __device__ int get_mmq_y_device() {
 #else
     return 64;
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
 }
 
-#define MMQ_DP4A_TXS_Q4_0    tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0   + mmq_y/QI4_0,     0}
-#define MMQ_DP4A_TXS_Q4_1    tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_1   + mmq_y/QI4_1,     0}
-#define MMQ_DP4A_TXS_Q8_0    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
-#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0}
-#define MMQ_DP4A_TXS_Q8_1    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
-#define MMQ_DP4A_TXS_Q2_K    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE         + mmq_y,           0}
-#define MMQ_DP4A_TXS_Q3_K    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y,                                     mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q4_K    tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_K,                     mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q5_K    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K   + mmq_y/QI5_K,     mmq_y*WARP_SIZE/8 + mmq_y/8}
-#define MMQ_DP4A_TXS_Q6_K    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K   + mmq_y/QI6_K,     mmq_y*WARP_SIZE/8 + mmq_y/8}
+// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.
+// The K dimension of the tiles has either
+// 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K)
+// 32-bit elements for the quantized data (scales not included).
+// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K.
+// The final tile size in the K direction is padded to avoid shared memory bank conflicts;
+// in terms of 32-bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.
+#define MMQ_TILE_NE_K 32
+
+#define MMQ_DP4A_TXS_Q4_0    tile_x_sizes{mmq_y*MMQ_TILE_NE_K   + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0   + mmq_y/QI4_0,     0}
+#define MMQ_DP4A_TXS_Q4_1    tile_x_sizes{mmq_y*MMQ_TILE_NE_K   + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1   + mmq_y/QI4_1,     0}
+#define MMQ_DP4A_TXS_Q8_0    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0}
+#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0}
+#define MMQ_DP4A_TXS_Q8_1    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0}
+#define MMQ_DP4A_TXS_Q2_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K         + mmq_y,           0}
+#define MMQ_DP4A_TXS_Q3_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y,                                         mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K   + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K,                     mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K   + mmq_y/QI5_K,     mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K    tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K   + mmq_y/QI6_K,     mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
 
 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
     switch (type) {
@@ -162,6 +172,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
         case GGML_TYPE_Q5_0:    return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_Q5_1:    return MMQ_DP4A_TXS_Q8_1;
         case GGML_TYPE_Q8_0:    return MMQ_DP4A_TXS_Q8_0;
+        case GGML_TYPE_MXFP4:   return MMQ_DP4A_TXS_Q8_1;
         case GGML_TYPE_Q2_K:    return MMQ_DP4A_TXS_Q2_K;
         case GGML_TYPE_Q3_K:    return MMQ_DP4A_TXS_Q3_K;
         case GGML_TYPE_Q4_K:    return MMQ_DP4A_TXS_Q4_K;
@@ -179,11 +190,11 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
     }
 }
 
-#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0                 + 4)
-#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0                 + 4)
-#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE                         + 4)
-#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2                       + 4)
-#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K     + WARP_SIZE/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0                   + 4)
+#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0                   + 4)
+#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K                           + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2                         + 4)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K   + MMQ_TILE_NE_K/8 + 7)
 
 static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
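A worked instance of the padding rule described in the MMQ_TILE_NE_K comment above, checked with standalone constants that mirror the Q8_0 macros (values assume QI8_0 == 8, i.e. QK8_0/(4*QR8_0), from ggml's quantization headers):

    // MMQ_MMA_TILE_X_K_Q8_0 = 2*32 + 2*32/8 + 4 = 76 ints per row, and 76 % 8 == 4 satisfies
    // the mma padding rule; the dp4a qs row stride 2*32 + 1 = 65 is odd, satisfying K % 2 == 1
    // (MMQ_DP4A_TXS_Q8_0.qs == mmq_y*65). Standalone mirror of the macros, not an include of mmq.cuh.
    constexpr int kMmqTileNeK      = 32;
    constexpr int kQI8_0           = 8;
    constexpr int kMmaTileXK_Q8_0  = 2*kMmqTileNeK + 2*kMmqTileNeK/kQI8_0 + 4;  // 76
    constexpr int kDp4aQsRowStride = 2*kMmqTileNeK + 1;                         // 65

    static_assert(kMmaTileXK_Q8_0 % 8 == 4, "mma padding rule");
    static_assert(kDp4aQsRowStride % 2 == 1, "dp4a padding rule");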
@@ -198,6 +209,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_Q5_0:    return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q5_1:    return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q8_0:    return MMQ_MMA_TILE_X_K_Q8_0;
+        case GGML_TYPE_MXFP4:   return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q2_K:    return MMQ_MMA_TILE_X_K_Q2_K;
         case GGML_TYPE_Q3_K:    return MMQ_MMA_TILE_X_K_Q3_K;
         case GGML_TYPE_Q4_K:    return MMQ_MMA_TILE_X_K_Q8_1;
@@ -215,42 +227,76 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     }
 }
 
-#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
+// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
+#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
 
 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
-    return new_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
+    if (amd_mfma_available(cc)) {
+        return mmq_x >= 128 ? 32 : 16;
+    } else if (turing_mma_available(cc) && mmq_x >= 48) {
+        return 16;
+    } else {
+        return 8;
+    }
 }
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE)
+static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
+    return mmq_x >= 128 ? 32 : 16;
+}
+#elif defined(TURING_MMA_AVAILABLE)
 static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
     return mmq_x >= 48 ? 16 : 8;
 }
 #else
-static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
+static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) {
     return 8;
 }
-#endif // NEW_MMA_AVAILABLE
+#endif // AMD_MFMA_AVAILABLE
+
+#if defined(GGML_USE_HIP)
+static int mmq_get_nwarps_host(const int cc, const int warp_size) {
+    return amd_mfma_available(cc) ? 8 : 256/warp_size;
+}
+#else
+static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
+    return 256/warp_size;
+}
+#endif // (GGML_USE_HIP)
+
+static constexpr __device__ int mmq_get_nwarps_device() {
+#if defined(AMD_MFMA_AVAILABLE)
+    return 8;
+#else
+    return 256/ggml_cuda_get_physical_warp_size();
+#endif // AMD_MFMA_AVAILABLE
+}
 
 // ------------------------------------------------------------
 
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    float * x_df = (float *) (x_qs + 2*WARP_SIZE);
+    float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-    const int kbx  = threadIdx.x / QI4_0;
-    const int kqsx = threadIdx.x % QI4_0;
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx  = txi / QI4_0;
+    const int kqsx = txi % QI4_0;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
@@ -259,20 +305,21 @@ template  static __device__ __forceinlin
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b2(bxi->qs, kqsx);
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0]     = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
 #else
-        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
-    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
-        int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -280,17 +327,19 @@ template  static __device__ __forceinlin
 
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0       + kbxd] = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0           + kbxd] = bxi->d;
 #else
-        x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
-template <int mmq_y, int nwarps>
+template <int mmq_y>
 static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -299,7 +348,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
     const half2 * y_ds = (const half2 *) y;
 
 // #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -307,7 +356,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
                 const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
@@ -320,32 +369,37 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
                     u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
                 }
 
-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
-                    (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u,
-                     x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+                    (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
+                     x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }
 
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-    const int kbx  = threadIdx.x / QI4_1;
-    const int kqsx = threadIdx.x % QI4_1;
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx  = txi / QI4_1;
+    const int kqsx = txi % QI4_1;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
@@ -354,20 +408,21 @@ template  static __device__ __forceinlin
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b4(bxi->qs, kqsx);
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0]     = (qs0 >> 0) & 0x0F0F0F0F;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
-        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
-    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
-        int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -375,17 +430,19 @@ template  static __device__ __forceinlin
 
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef NEW_MMA_AVAILABLE
-        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1       + kbxd] = bxi->dm;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1           + kbxd] = bxi->dm;
 #else
-        x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
-#endif // NEW_MMA_AVAILABLE
+        x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
-template <int mmq_y, int nwarps>
+template <int mmq_y>
 static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -394,7 +451,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
     const half2 * y_ds = (const half2 *) y;
 
 // #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -402,7 +459,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
                 const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
@@ -415,32 +472,37 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
                     u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
                 }
 
-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
-                    (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u,
-                     x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+                    (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
+                     x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }
 
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-    const int kbx  = threadIdx.x / QI5_0;
-    const int kqsx = threadIdx.x % QI5_0;
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx  = txi / QI5_0;
+    const int kqsx = txi % QI5_0;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
@@ -449,7 +511,7 @@ template  static __device__ __forceinlin
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
 
         const int ql = get_int_b2(bxi->qs, kqsx);
-        const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
+        const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx);
 
         int qs0 = (ql >>  0)   & 0x0F0F0F0F;
         qs0    |= (qh <<  4)   & 0x00000010;  // 0 ->  4
@@ -465,21 +527,22 @@ template  static __device__ __forceinlin
         qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
         qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
 #else
-        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
-        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
-    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
-        int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -487,32 +550,37 @@ template  static __device__ __forceinlin
 
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0       + kbxd] = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0           + kbxd] = bxi->d;
 #else
-        x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-    const int kbx  = threadIdx.x / QI5_1;
-    const int kqsx = threadIdx.x % QI5_1;
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx  = txi / QI5_1;
+    const int kqsx = txi % QI5_1;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
@@ -521,7 +589,7 @@ template  static __device__ __forceinlin
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
 
         const int ql = get_int_b4(bxi->qs, kqsx);
-        const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
+        const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx);
 
         int qs0 = (ql >>  0) & 0x0F0F0F0F;
         qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4
@@ -535,21 +603,22 @@ template  static __device__ __forceinlin
         qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
         qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
 #else
-        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
-        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
-    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
-        int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -557,32 +626,38 @@ template  static __device__ __forceinlin
 
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef NEW_MMA_AVAILABLE
-        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1       + kbxd] = bxi->dm;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1           + kbxd] = bxi->dm;
 #else
-        x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
-#endif // NEW_MMA_AVAILABLE
+        x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    float * x_df = (float *) (x_tile + 2*WARP_SIZE);
+    float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-    const int kbx  = threadIdx.x / QI8_0;
-    const int kqsx = threadIdx.x % QI8_0;
+    // MMQ_ITER_K / (4 * QR8_0) == 64 threads per row would be required, but NVIDIA warps have only 32 threads.
+    constexpr int threads_per_row = 32;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx  = txi / QI8_0;
+    const int kqsx = txi % QI8_0;
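+    // With a 64-wide warp (CDNA) each warp loads nrows == 2 tile rows per iteration and txi is the
+    // lane index within a row; with a 32-wide warp nrows == 1 and txi == threadIdx.x.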
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
@@ -590,21 +665,22 @@ template  static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
 
-#ifdef NEW_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0         + threadIdx.x] = get_int_b2(bxi[0].qs,               kqsx);
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0             + txi] = get_int_b2(bxi[0].qs,                   kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
 #else
-        x_qs[i*(2*WARP_SIZE + 1)     + 0         + threadIdx.x] = get_int_b2(bxi[0].qs,               kqsx);
-        x_qs[i*(2*WARP_SIZE + 1)     + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0             + txi] = get_int_b2(bxi[0].qs,                   kqsx);
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
-    const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
+    constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) {
-        int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -612,17 +688,84 @@ template  static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0             + kbxd] = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0                 + kbxd] = bxi->d;
+#else
+        x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+    }
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx  = txi / QI_MXFP4;
+    const int kqsx = txi % QI_MXFP4;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;
+
+        const int aux_q4 = get_int_b1(bxi->qs, kqsx);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+        const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
+
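+        // v.x holds the looked-up low nibbles (values 0..15 of the block), v.y the high nibbles
+        // (values 16..31); storing them QI_MXFP4 ints apart keeps k in value order.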
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0]        = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
+#else
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0]        = v.x;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+    }
+
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
+
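+        // bxi->e is the block's shared E8M0 exponent; the 0.5f below presumably compensates for
+        // kvalues_mxfp4 storing the FP4 values at twice their nominal magnitude.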
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_1                 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
 #else
-        x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -631,7 +774,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const float * y_df = (const float *) y;
 
 // #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -639,21 +782,76 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl
-                    (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE],
-                     x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl
+                    (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K],
+                     x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]);
             }
         }
     }
 }
 
-template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
+template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+#if defined(AMD_MFMA_AVAILABLE)
+    typedef tile<16,  8, int> tile_A;
+    typedef tile<16,  8, int> tile_B;
+    typedef tile<16, 16, int> tile_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+    const half2 * y_ds = (const half2 *) y;
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+        }
 
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B;
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
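+            // With the 16x16 MFMA C tile every accumulator a lane holds lies in the same column j,
+            // so a single per-lane dB scale is sufficient.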
+            float dB;
+            const int j = j0 + tile_C::get_j(0);
+            if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+                dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+            } else {
+                dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+                mma(C, A[n], B);
+
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_A::I + tile_C::get_i(l);
+                    const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB;
+                }
+            }
+        }
+    }
+#else
     typedef tile<16, 8, int> tile_A;
     typedef tile< 8, 8, int> tile_B;
     typedef tile<16, 8, int> tile_C;
@@ -662,23 +860,23 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
-    const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
+    const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
     const half2 * y_ds = (const half2 *) y;
 
-    tile_A A[ntx][WARP_SIZE/QI8_0];
-    float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
+    tile_A A[ntx][MMQ_TILE_NE_K/QI8_0];
+    float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
             const int k0 = k00 + k01;
 
             load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
@@ -689,7 +887,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
             const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
 
 #pragma unroll
-            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
                 const int k0 = k00 + k01;
 
                 dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
@@ -700,7 +898,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
             tile_B B;
             float dB[tile_C::ne/2];
 
@@ -729,11 +927,14 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
             }
         }
     }
+#endif // defined(AMD_MFMA_AVAILABLE)
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -742,7 +943,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
     const half2 * y_ds = (const half2 *) y;
 
 // #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -750,45 +951,95 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl
-                    (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
-                    x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl
+                    (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                    x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+#if defined(AMD_MFMA_AVAILABLE)
+    typedef tile<16,  8, int> tile_A;
+    typedef tile<16,  8, int> tile_B;
+    typedef tile<16, 16, int> tile_C;
 
-    typedef tile<16, 8, int> tile_A;
-    typedef tile< 8, 8, int> tile_B;
-    typedef tile<16, 8, int> tile_C;
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_dm = (const half2 *) y;
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B;
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            const int j = j0 + tile_C::get_j(0);
+            const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+                mma(C, A[n], B);
+
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_A::I + tile_C::get_i(l);
+                    float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
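+                    // dmA = (d, m) of the quantized x block, dsB = (d8, s) of the Q8_1 y block:
+                    // d*d8 scales the integer dot product and m*s adds the per-block minimum term.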
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l];
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y;
+                }
+            }
+        }
+    }
+#else
+    typedef tile<16,  8, int> tile_A;
+    typedef tile< 8,  8, int> tile_B;
+    typedef tile<16,  8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
-    const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
+    const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_dm = (const half2 *) y;
 
-    tile_A   A[ntx][WARP_SIZE/QI8_1];
-    float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];
+    tile_A   A[ntx][MMQ_TILE_NE_K/QI8_1];
+    float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
             const int k0 = k00 + k01;
 
             load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
@@ -799,7 +1050,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
             const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
 
 #pragma unroll
-            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
                 const int k0 = k00 + k01;
 
                 dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
@@ -810,7 +1061,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
             tile_B   B;
             float2 dsB[tile_C::ne/2];
 
@@ -836,11 +1087,15 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
             }
         }
     }
+#endif // defined(AMD_MFMA_AVAILABLE)
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
+// Used for Q3_K, IQ2_S, and IQ2_XS
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
     const int   * x_qs = (const int   *) x;
@@ -849,7 +1104,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
     const float * y_df = (const float *) y;
 
 // #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -857,23 +1112,73 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl(
-                    &x_qs[i*(2*WARP_SIZE + 1) + k0],
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl(
+                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0],
                     &y_qs[j*MMQ_TILE_Y_K + k01],
-                    &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
+                    &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
                     y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
+// Used for Q3_K, IQ2_S, and IQ2_XS:
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE)
+    typedef tile<16,  8, int> tile_A;
+    typedef tile<16,  8, int> tile_B;
+    typedef tile<16, 16, int> tile_C;
+    typedef tile<64,  2, int> tile_load;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B[1];
+            load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            const int j = j0 + tile_C::get_j(0);
+            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+                mma(C, A[n], B[0]);
+
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
+                }
+            }
+        }
+    }
+#elif defined(TURING_MMA_AVAILABLE)
 
     typedef tile<16, 4, int> tile_A;
     typedef tile<16, 8, int> tile_A_8;
@@ -884,10 +1189,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
-    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
+    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
@@ -899,7 +1204,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
             const int k0 = k00 + k01;
 
             load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
@@ -910,7 +1215,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
             const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
 #pragma unroll
-            for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
+            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
                 const int k0 = k00 + k01;
 
                 dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
@@ -921,7 +1226,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
             tile_B B[2];
             float dB[tile_C::ne/2];
 
@@ -952,26 +1257,29 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 #else
     GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
     NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // AMD_MFMA_AVAILABLE
 }
 
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-    const int kqsx = threadIdx.x % QI2_K;
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
+    constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) {
-        int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -987,11 +1295,11 @@ template  static __device__ __forceinlin
 
             const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
 #else
-            x_qs[i*(2*WARP_SIZE + 1)     + k] = x_qs_k;
-#endif // NEW_MMA_AVAILABLE
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         }
 
         const int sc_m = bxi->scales[kqsx];
@@ -1002,17 +1310,19 @@ template  static __device__ __forceinlin
         const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
 #endif // FAST_FP16_AVAILABLE
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
 #else
-        x_dm[i*(WARP_SIZE + 1)       + kqsx] = x_dm_ik;
-#endif // NEW_MMA_AVAILABLE
+        x_dm[i*(MMQ_TILE_NE_K + 1)   + kqsx] = x_dm_ik;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -1029,7 +1339,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
     }
 
 #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1037,13 +1347,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
                 constexpr int ns = 2;
-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
-                    &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
-                    &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                    &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
                     &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
             }
         }
@@ -1052,7 +1362,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
     // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
     // As a workaround 2 separate loops are used instead.
 #pragma unroll
-    for (int k01 = WARP_SIZE/2; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+    for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1060,23 +1370,89 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
                 constexpr int ns = 1;
-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
-                    &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
-                    &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                    &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
                     &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
             }
         }
     }
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE)
+    typedef tile<16,  8, int> tile_A;
+    typedef tile<16,  8, int> tile_B;
+    typedef tile<16, 16, int> tile_C;
+    typedef tile<64,  2, int> tile_load;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B[1];
+            load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            const int j = j0 + tile_C::get_j(0);
+            const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
+            const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
+                                              : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
+                                                             : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
+
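+            // For the last quarter of the tile the Q8_1 partial sums are not stored (D2S6 layout),
+            // so Cm = ones * B reconstructs the column sums needed to apply the block minima dm.y.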
+            tile_C Cm;
+            if (k01 >= MMQ_TILE_NE_K * 3/4) {
+                tile_A A1;
+                A1.x[0] = 0x01010101;
+                A1.x[1] = 0x01010101;
+                mma(Cm, A1, B[0]);
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C Cd;
+                mma(Cd, A[n], B[0]);
+
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+                    const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
+                    float tmp = Cd.x[l]*dm.x;
+                    if (k01 >= MMQ_TILE_NE_K * 3/4) {
+                        tmp -= Cm.x[l]*dm.y;
+                    }
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
+                }
+            }
+        }
+    }
+#elif defined(TURING_MMA_AVAILABLE)
 
     typedef tile<16, 4, int> tile_A;
     typedef tile<16, 8, int> tile_A_8;
@@ -1087,10 +1463,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
-    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
+    const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
@@ -1103,7 +1479,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
             const int k0 = k00 + k01;
 
             load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
@@ -1117,7 +1493,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
             const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
 #pragma unroll
-            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
+            for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) {
                 const int k0 = k00 + k01;
 
                 const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
@@ -1140,7 +1516,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         }
 
 #pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
             tile_B B[2];
 
             // Here load_generic is faster than load_ldmatrix.
@@ -1148,7 +1524,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
             load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
 
             tile_C Cm[2];
-            if (k01 >= WARP_SIZE * 3/4) {
+            if (k01 >= MMQ_TILE_NE_K * 3/4) {
                 tile_A A1;
                 A1.x[0] = 0x01010101;
                 A1.x[1] = 0x01010101;
@@ -1166,16 +1542,16 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 #pragma unroll
                 for (int l = 0; l < tile_C::ne; ++l) {
                     float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
-                    if (k01 >= WARP_SIZE * 3/4) {
+                    if (k01 >= MMQ_TILE_NE_K * 3/4) {
                         tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
                     }
-                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? dB[l%2].x : dB[l%2].y);
                 }
             }
         }
 
 #pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) {
             float2 sB[tile_C::ne/2];
 
 #pragma unroll
@@ -1198,27 +1574,31 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 #else
     GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
     NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // AMD_MFMA_AVAILABLE
 }
 
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_df + txs.dm);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-    const int kqsx = threadIdx.x % QI3_K;
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) {
-        int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1238,17 +1618,18 @@ template  static __device__ __forceinlin
 
             const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
 #else
-            x_qs[i*(2*WARP_SIZE + 1)     + k] = x_qs_k;
-#endif // NEW_MMA_AVAILABLE
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         }
     }
 
+    constexpr int rows_per_warp = warp_size / 4;
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
-        int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8);
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+        int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1256,7 +1637,7 @@ template  static __device__ __forceinlin
 
         const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
 
-        const int ksc = threadIdx.x % (WARP_SIZE/8);
+        const int ksc = threadIdx.x % 4;
 
         const int ksc_low = ksc % (QI3_K/8);
         const int shift_low = 4 * (ksc / (QI3_K/8));
@@ -1268,23 +1649,23 @@ template  static __device__ __forceinlin
 
         const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
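+        // sc now holds four of the block's sixteen 6-bit Q3_K scales, re-centred around zero by subtracting 32.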
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         const int8_t * sc8 = (const int8_t *) ≻
         const float d = bxi->d;
 
 #pragma unroll
         for (int l = 0; l < int(sizeof(int)); ++l) {
-            x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
+            x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l];
         }
 #else
-        x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
-#endif // NEW_MMA_AVAILABLE
+        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
-#ifndef NEW_MMA_AVAILABLE
+#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
-        int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1294,12 +1675,14 @@ template  static __device__ __forceinlin
 
         x_df[i] = bxi->d;
     }
-#endif // NEW_MMA_AVAILABLE
+#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -1309,7 +1692,7 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
     const float * y_df = (const float *) y;
 
 // #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1317,13 +1700,13 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
-                const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4;
+                const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4;
 
-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
-                    &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq(
+                    &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
                     x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
@@ -1340,72 +1723,85 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co
            ((scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030);  // upper 2 bits
 }
 
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+    half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_dm + txs.dm);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
         }
 
         const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
-        const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
+        const int qs0 = get_int_b4(bxi->qs, txi);
 
-#ifdef NEW_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
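+        // Low and high nibbles are stored 8 ints apart so the shared-memory k index runs over the
+        // dequantized values in order.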
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
-        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
-#ifdef NEW_MMA_AVAILABLE
-
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+    constexpr int rows_per_warp = warp_size / 2;
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
-        int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+#if defined(AMD_MFMA_AVAILABLE)
+        // On AMD an 'if' bound is needed instead of '%' because warp_size == 64:
+        // wrapping with '%' would redo rows and cost throughput (MI300X).
+        // On NVIDIA the '%' is kept; H100 loses about 100 t/s with the 'if' variant.
+        int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
+        if (i < mmq_y) {
+#else
+        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
+        {
+#endif // defined(AMD_MFMA_AVAILABLE)
+            if (need_check) {
+                i = min(i, i_max);
+            }
 
-        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+            const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
 
-        const int * scales = (const int *) bxi->scales;
-        const int ksc = threadIdx.x % (WARP_SIZE/16);
+            const int * scales = (const int *) bxi->scales;
+            const int ksc = threadIdx.x % 2;
 
-        const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
-        const int  m32 = unpack_scales_q45_K(scales, ksc + 2);
+            const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+            const int  m32 = unpack_scales_q45_K(scales, ksc + 2);
 
-        const uint8_t * sc8 = (const uint8_t *) &sc32;
-        const uint8_t *  m8 = (const uint8_t *)  &m32;
+            const uint8_t * sc8 = (const uint8_t *) &sc32;
+            const uint8_t *  m8 = (const uint8_t *)  &m32;
 
-        const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
+            const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
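+            // (d, -m): the minimum is stored negated so the Q8_1 kernels can add the correction term.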
 
-#pragma unroll
-        for (int l = 0; l < int(sizeof(int)); ++l) {
-            x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+#pragma unroll
+            for (int l = 0; l < int(sizeof(int)); ++l) {
+                x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+            }
         }
     }
-
 #else
-
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) {
-        int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1415,30 +1811,32 @@ template  static __device__ __forceinlin
 
         x_dm[i] = bxi->dm;
     }
-
+    constexpr int rows_per_warp = warp_size / 4;
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
-        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8);
 
         const int * scales = (const int *) bxi->scales;
 
-        const int ksc = threadIdx.x % (WARP_SIZE/8);
+        const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
         const int scales8 = unpack_scales_q45_K(scales, ksc);
 
-        x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
+        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
     }
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 
-template <int mmq_x, int mmq_y, int nwarps>
+template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -1448,7 +1846,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
     const half2 * y_ds = (const half2 *) y;
 
 // #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1456,97 +1854,110 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
-                const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16);
+                const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16);
 
-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
-                    &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq(
+                    &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
                     x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }
 
-template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
+    half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_dm + txs.dm);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
         }
 
         const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
-        const int ky = QR5_K*threadIdx.x;
+        const int ky = QR5_K*txi;
 
-        const int ql = get_int_b4(bxi->qs, threadIdx.x);
+        const int ql = get_int_b4(bxi->qs, txi);
         const int ql0 = (ql >> 0) & 0x0F0F0F0F;
         const int ql1 = (ql >> 4) & 0x0F0F0F0F;
 
-        const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4));
-        const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010;
-        const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010;
+        const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4));
+        const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+        const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010;
 
-        const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
-        const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
+        const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
+        const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
 #else
-        x_qs[i*(2*WARP_SIZE + 1)     + kq0] = ql0 | qh0;
-        x_qs[i*(2*WARP_SIZE + 1)     + kq1] = ql1 | qh1;
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
-#ifdef NEW_MMA_AVAILABLE
-
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+    constexpr int rows_per_warp = warp_size / 2;
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
-        int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+#if defined(AMD_MFMA_AVAILABLE)
+        // On AMD an 'if' bound is needed instead of '%' because warp_size == 64:
+        // wrapping with '%' would redo rows and cost throughput (MI300X).
+        // On NVIDIA the '%' is kept; H100 loses about 100 t/s with the 'if' variant.
+        int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
+        if (i < mmq_y) {
+#else
+        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
+        {
+#endif // defined(AMD_MFMA_AVAILABLE)
+            if (need_check) {
+                i = min(i, i_max);
+            }
 
-        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+            const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
 
-        const int * scales = (const int *) bxi->scales;
-        const int ksc = threadIdx.x % (WARP_SIZE/16);
+            const int * scales = (const int *) bxi->scales;
+            const int ksc = threadIdx.x % 2;
 
-        const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
-        const int  m32 = unpack_scales_q45_K(scales, ksc + 2);
+            const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+            const int  m32 = unpack_scales_q45_K(scales, ksc + 2);
 
-        const uint8_t * sc8 = (const uint8_t *) &sc32;
-        const uint8_t *  m8 = (const uint8_t *)  &m32;
+            const uint8_t * sc8 = (const uint8_t *) &sc32;
+            const uint8_t *  m8 = (const uint8_t *)  &m32;
 
-        const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
+            const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
 
 #pragma unroll
-        for (int l = 0; l < int(sizeof(int)); ++l) {
-            x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+            for (int l = 0; l < int(sizeof(int)); ++l) {
+                x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+            }
         }
     }
-
 #else
-
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) {
-        int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1557,9 +1968,10 @@ template  static __device__ __forceinlin
         x_dm[i] = bxi->dm;
     }
 
+    constexpr int rows_per_warp = warp_size / 4;
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
-        int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1569,17 +1981,19 @@ template  static __device__ __forceinlin
 
         const int * scales = (const int *) bxi->scales;
 
-        const int ksc = threadIdx.x % (WARP_SIZE/8);
+        const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
         const int scales8 = unpack_scales_q45_K(scales, ksc);
 
-        x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
+        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
     }
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 
-template 
+template 
 static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -1589,7 +2003,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
     const half2 * y_ds = (const half2 *) y;
 
 // #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1597,36 +2011,42 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
-                const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
+                const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16);
 
-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
-                    &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq(
+                    &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
                     x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }
 
-template  static __device__ __forceinline__ void load_tiles_q6_K(
+template  static __device__ __forceinline__ void load_tiles_q6_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    float * x_df = (float *) (x_qs + WARP_SIZE*2);
-    int   * x_sc = (int   *) (x_df + WARP_SIZE/QI6_K);
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+    int   * x_sc = (int   *) (x_df + MMQ_TILE_NE_K/QI6_K);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_df + txs.dm);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
@@ -1634,67 +2054,67 @@ template  static __device__ __forceinlin
 
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
 
-        const int ql = get_int_b2(bxi->ql, threadIdx.x);
+        const int ql = get_int_b2(bxi->ql, txi);
         const int ql0 = (ql >> 0) & 0x0F0F0F0F;
         const int ql1 = (ql >> 4) & 0x0F0F0F0F;
 
-        const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4));
-        const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030;
-        const int qh1 =  (qh >> ((threadIdx.x & 0x08) >> 2))       & 0x30303030;
+        const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4));
+        const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030;
+        const int qh1 =  (qh >> ((txi & 0x08) >> 2))       & 0x30303030;
 
-        const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
-        const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
+        const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
+        const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
 #else
-        x_qs[i*(2*WARP_SIZE + 1)     + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
-        x_qs[i*(2*WARP_SIZE + 1)     + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
-    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K;  // == 1 if QK_K == 256
-    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
-
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
-        int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+        int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
 
-#ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q6_K       + kbxd] = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q6_K]           = bxi->d;
 #else
-        x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
+    constexpr int rows_per_warp = warp_size / 4;
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
-        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
 
-#ifdef NEW_MMA_AVAILABLE
-        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
 #else
-        x_sc[i*(WARP_SIZE/8) + i/8   + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
-#endif // NEW_MMA_AVAILABLE
+        x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
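
As a side note, the per-weight arithmetic behind the q6_K unpack above can be sketched on the host as follows; the kernel processes four such bytes at a time with packed integer ops, and `__vsubss4(..., 0x20202020)` is the packed equivalent of the `- 32` below.

```cpp
// One q6_K weight: 4 low bits from ql, 2 high bits from qh, re-centered by -32.
#include <cstdint>
#include <cstdio>

int8_t unpack_q6(uint8_t ql_nibble, uint8_t qh_2bits) {
    const uint8_t q = (ql_nibble & 0x0F) | ((qh_2bits & 0x03) << 4); // 0..63
    return (int8_t) (q - 32);                                        // -32..31
}

int main() {
    std::printf("%d %d %d\n", unpack_q6(0x0, 0x0), unpack_q6(0xF, 0x3), unpack_q6(0x0, 0x2));
    // prints: -32 31 0
    return 0;
}
```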
 
-template 
+template 
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
     const int   * x_qs = (const int   *) x;
@@ -1704,7 +2124,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
     const float * y_df = (const float *) y;
 
 // #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1712,23 +2132,74 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
-                const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]);
+                const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]);
 
-                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
-                    &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
-                    x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+                sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq(
+                    &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
+                    x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
     }
 }
 
-template 
+template 
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE)
+    typedef tile<16,  8, int> tile_A;
+    typedef tile<16,  8, int> tile_B;
+    typedef tile<16, 16, int> tile_C;
+    typedef tile<64,  2, int> tile_load;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = granularity;
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+    const int   * x_sc = (const int   *) x_df + MMQ_TILE_NE_K/QI6_K;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+    for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+        const int k0 = k00 + k01;
+
+        tile_A A[ntx];
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+            tile_B B[1];
+            load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+            const int j = j0 + tile_C::get_j(0);
+            const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+                mma(C, A[n], B[0]);
+
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+                    const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
+                }
+            }
+        }
+    }
+#elif defined(TURING_MMA_AVAILABLE)
 
     typedef tile<16, 4, int> tile_A;
     typedef tile< 8, 4, int> tile_B;
@@ -1738,11 +2209,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     constexpr int rows_per_warp = 2 * granularity;
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
-    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
-    const int   * x_sc = (const int   *) x_df + WARP_SIZE/QI6_K;
+    const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+    const int   * x_sc = (const int   *) x_df + MMQ_TILE_NE_K/QI6_K;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
@@ -1755,7 +2226,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
             const int k0 = k00 + k01;
 
             load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0),         MMQ_MMA_TILE_X_K_Q6_K);
@@ -1763,7 +2234,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         }
 
 #pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) {
             const int k0 = k00 + k01;
 
 #pragma unroll
@@ -1793,7 +2264,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         float tmp[ntx][tile_C::ne] = {{0.0f}};
 
 #pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
             tile_B B[2];
             float dB[tile_C::ne/2];
 
@@ -1832,27 +2303,32 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #else
     GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
     NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 }
 
-template  static __device__ __forceinline__ void load_tiles_iq4_nl(
+template  static __device__ __forceinline__ void load_tiles_iq4_nl(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-    const int kbx  = threadIdx.x / QI4_NL;
-    const int kqsx = threadIdx.x % QI4_NL;
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx  = txi / QI4_NL;
+    const int kqsx = txi % QI4_NL;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
@@ -1861,23 +2337,25 @@ template  static __device__ __forceinlin
         const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
 
         const int aux_q4 = get_int_b2(bxi->qs, kqsx);
-        const int2 v = get_int_from_table_16(aux_q4);
-        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
-#ifdef NEW_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
+        const int k0 = kbx * (2 * QI4_NL) + kqsx;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0]      = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
 #else
-        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
-        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0]      = v.x;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
-    const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
     const int kbxd = threadIdx.x % blocks_per_tile_x_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
-        int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1885,31 +2363,35 @@ template  static __device__ __forceinlin
 
         const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
 
-#ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0             + kbxd] = __half2float(bxi->d);
 #else
-        x_df[i*(WARP_SIZE/4) + i/4   + kbxd] = __half2float(bxi->d);
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
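
The `get_int_from_table_16` call above conceptually maps each packed 4-bit index to an entry of a 16-value non-linear table. A host-side sketch is below; the table values mirror ggml's `kvalues_iq4nl` but should be treated as illustrative here, and the real helper returns the low-nibble and high-nibble results as the two halves of an `int2` rather than interleaved.

```cpp
// Each 4-bit index in a packed 32-bit word selects one entry of a 16-value table.
#include <cstdint>
#include <cstdio>

static const int8_t kvalues_iq4nl_sketch[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
};

void dequant_8_nibbles(uint32_t packed, int8_t out[8]) {
    for (int i = 0; i < 8; ++i) {
        out[i] = kvalues_iq4nl_sketch[(packed >> (4*i)) & 0x0F];
    }
}

int main() {
    int8_t v[8];
    dequant_8_nibbles(0xF0F0F0F0u, v);
    std::printf("%d %d\n", v[0], v[1]); // nibble 0x0 -> -127, nibble 0xF -> 113
    return 0;
}
```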
 
-template  static __device__ __forceinline__ void load_tiles_iq2_xxs(
+template  static __device__ __forceinline__ void load_tiles_iq2_xxs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-    const int kqsx = threadIdx.x % (QI2_XXS/2);
+    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) {
-        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2);
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1932,42 +2414,46 @@ template  static __device__ __forceinlin
             const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
             const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
 #else
-            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid0;
-            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid1;
-#endif // NEW_MMA_AVAILABLE
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         }
 
         const int ls = aux32 >> 28;
         const float d = bxi->d;
-#ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0   + kqsx] = (ls*d + d/2)/4;
 #else
-        x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = (ls*d + d/2)/4;
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
-template  static __device__ __forceinline__ void load_tiles_iq2_xs(
+template  static __device__ __forceinline__ void load_tiles_iq2_xs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-    const int kqsx = threadIdx.x % (QI2_XS/2);
+    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) {
-        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2);
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -1986,44 +2472,48 @@ template  static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
             const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
-            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
-            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // NEW_MMA_AVAILABLE
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         }
 
         const int ls = bxi->scales[kqsx];
         const float d = bxi->d;
-#ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
-        x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
 #else
-        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
-        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
+        x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
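
The scale conversion used above, `(ls*d + d/2)/4`, unpacks two 4-bit sub-block scales from one byte and is equivalent to `d*(ls + 0.5)/4`. A minimal host-side sketch with made-up values:

```cpp
// One byte stores two 4-bit sub-block scales; each becomes (ls*d + d/2)/4.
#include <cstdio>

int main() {
    const float d = 0.125f;            // illustrative super-block scale
    const unsigned char packed = 0x7A; // low nibble = 10, high nibble = 7

    const int ls_lo = packed & 0x0F;
    const int ls_hi = packed >> 4;

    const float s_lo = (ls_lo*d + d/2)/4;  // == d*(ls_lo + 0.5f)/4
    const float s_hi = (ls_hi*d + d/2)/4;

    std::printf("%f %f\n", s_lo, s_hi);    // 0.328125 0.234375
    return 0;
}
```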
 
-template  static __device__ __forceinline__ void load_tiles_iq2_s(
+template  static __device__ __forceinline__ void load_tiles_iq2_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-    const int kqsx = threadIdx.x % (QI2_S/2);
+    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) {
-        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2);
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -2049,44 +2539,48 @@ template  static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
             const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
-            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
-            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // NEW_MMA_AVAILABLE
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         }
 
         const int ls = bxi->scales[kqsx];
         const float d = bxi->d;
-#ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
-        x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K                   + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
 #else
-        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
-        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
+        x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
-template  static __device__ __forceinline__ void load_tiles_iq3_xxs(
+template  static __device__ __forceinline__ void load_tiles_iq3_xxs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-    const int kqsx = threadIdx.x % (QI3_XXS/2);
+    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) {
-        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2);
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -2107,42 +2601,46 @@ template  static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
             const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
 #else
-            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
-            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // NEW_MMA_AVAILABLE
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         }
 
         const int ls = aux32 >> 28;
         const float d = bxi->d;
-#ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0     + kqsx] = (ls*d + d/2)/2;
 #else
-        x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = (ls*d + d/2)/2;
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/4) + i/4   + kqsx] = (ls*d + d/2)/2;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
-template  static __device__ __forceinline__ void load_tiles_iq3_s(
+template  static __device__ __forceinline__ void load_tiles_iq3_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-    const int kqsx = threadIdx.x % (QI3_S/2);
+    constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) {
-        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2);
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -2170,42 +2668,46 @@ template  static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
             const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
 #else
-            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+0)] = grid_l;
-            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+1)] = grid_h;
-#endif // NEW_MMA_AVAILABLE
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         }
 
         const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
         const float d = bxi->d;
-#ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0     + kqsx] = ls*d;
 #else
-        x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = ls*d;
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/4) + i/4   + kqsx] = ls*d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
 
-template  static __device__ __forceinline__ void load_tiles_iq1_s(
+template  static __device__ __forceinline__ void load_tiles_iq1_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
+    half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_ds = (half2 *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-    const int kqsx = threadIdx.x % QI1_S;
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) {
-        int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S;
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+        int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
 
         if (need_check) {
             i = min(i, i_max);
@@ -2225,66 +2727,71 @@ template  static __device__ __forceinlin
             const int grid0 = (grid >> 0) & 0x0F0F0F0F;
             const int grid1 = (grid >> 4) & 0x0F0F0F0F;
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
 #else
-            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+0)] = grid0;
-            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+1)] = grid1;
-#endif // NEW_MMA_AVAILABLE
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
+            x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         }
 
         const float  d1q   = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
         const float  delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
 
-#ifdef NEW_MMA_AVAILABLE
-        x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_ds[i*MMQ_MMA_TILE_X_K_Q8_1     + kqsx] = make_half2(d1q, d1q*delta);
 #else
-        x_ds[i*(WARP_SIZE/4) + i/4   + kqsx] = make_half2(d1q, d1q*delta);
-#endif // NEW_MMA_AVAILABLE
+        x_ds[i*(MMQ_TILE_NE_K/4) + i/4   + kqsx] = make_half2(d1q, d1q*delta);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
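
The `delta` expression above is a branchless select on bit 15 of `qh`: it evaluates to `-1 + IQ1S_DELTA` when the bit is clear and `-1 - IQ1S_DELTA` when it is set. A host-side sketch follows; the value 0.125f for IQ1S_DELTA is assumed here purely for illustration.

```cpp
// Branchless sign select on bit 15 of qh, compared against the branchy form.
#include <cstdint>
#include <cstdio>

int main() {
    const float IQ1S_DELTA = 0.125f;   // assumed value, for illustration only
    const uint16_t qhs[2] = { 0x0000, 0x8000 };

    for (int k = 0; k < 2; ++k) {
        const uint16_t qh = qhs[k];
        const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
        const float ref   = (qh & 0x8000) ? -1.0f - IQ1S_DELTA : -1.0f + IQ1S_DELTA;
        std::printf("qh=0x%04x delta=%f ref=%f\n", (unsigned) qh, delta, ref);
    }
    return 0;
}
```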
 
-template  static __device__ __forceinline__ void load_tiles_iq4_xs(
+template  static __device__ __forceinline__ void load_tiles_iq4_xs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-#ifdef NEW_MMA_AVAILABLE
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     int   * x_qs = (int   *)  x_tile;
-    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
-    const int kbx  = 0;           // threadIdx.x / QI4_XS
-    const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int kqsx = threadIdx.x % threads_per_row;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx;
+        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
 
         const int aux_q4 = get_int_b4(bxi->qs, kqsx);
-        const int2 v = get_int_from_table_16(aux_q4);
-        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
-#ifdef NEW_MMA_AVAILABLE
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
+        const int k0 = 8 * (kqsx / 4) + kqsx % 4;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
 #else
-        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
-        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
-#endif // NEW_MMA_AVAILABLE
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 
+    constexpr int rows_per_warp = warp_size / 8;
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
-        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4);
 
         if (need_check) {
             i = min(i, i_max);
@@ -2297,18 +2804,21 @@ template  static __device__ __forceinlin
         const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
             | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
 
-#ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0   + threadIdx.x % 8] = d * (ls - 32);
 #else
-        x_df[i*(WARP_SIZE/4) + i/4   + threadIdx.x % 8] = d * (ls - 32);
-#endif // NEW_MMA_AVAILABLE
+        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
     }
 }
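
The 6-bit sub-block scale reconstructed above takes its 4 low bits from a nibble of `scales_l` and its 2 high bits from `scales_h`, then re-centers with `- 32` before scaling by `d`. A host-side sketch with illustrative inputs:

```cpp
// Rebuild one iq4_xs sub-block scale from its split low/high bit fields.
#include <cstdint>
#include <cstdio>

float iq4_xs_subscale(const uint8_t * scales_l, uint16_t scales_h, float d, int sb) {
    const int lo = (scales_l[sb/2] >> (4*(sb % 2))) & 0x0F;   // 4 low bits
    const int hi = (scales_h >> (2*sb))             & 0x03;   // 2 high bits
    const int ls = lo | (hi << 4);                            // 0..63
    return d * (ls - 32);                                     // signed scale
}

int main() {
    const uint8_t  scales_l[4] = { 0x21, 0x43, 0x65, 0x87 };  // illustrative
    const uint16_t scales_h    = 0x1B1B;                      // illustrative
    std::printf("%f\n", iq4_xs_subscale(scales_l, scales_h, 0.5f, 0)); // sub-block 0
    return 0;
}
```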
 
-template
+template
 static __device__ __forceinline__ void mmq_write_back_dp4a(
         const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
         const int stride, const int i_max, const int j_max) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
@@ -2318,32 +2828,40 @@ static __device__ __forceinline__ void mmq_write_back_dp4a(
         }
 
 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
 
             if (need_check && i > i_max) {
                 continue;
             }
 
-            dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+            dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
         }
     }
 }
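
The write-back above maps each thread's flat `sum` array onto a strided set of output rows and columns: `sum[(j0/nwarps)*(mmq_y/warp_size) + i0/warp_size]` corresponds to column `j0 + threadIdx.y` and row `i0 + threadIdx.x`. A host-side sketch of that mapping; the tile sizes and thread coordinates are illustrative assumptions.

```cpp
// Enumerate which (row, column) outputs one thread's accumulators cover.
#include <cstdio>

int main() {
    const int mmq_x = 16, mmq_y = 64, nwarps = 8, warp_size = 32; // assumed sizes
    const int tx = 5, ty = 3;                                     // one hypothetical thread

    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
            const int idx = (j0/nwarps)*(mmq_y/warp_size) + i0/warp_size;
            std::printf("sum[%d] -> dst row %d, col %d\n", idx, i0 + tx, j0 + ty);
        }
    }
    return 0;
}
```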
 
-template
+template
 static __device__ __forceinline__ void mmq_write_back_mma(
         const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
         const int stride, const int i_max, const int j_max) {
-    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int nwarps = mmq_get_nwarps_device();
+
+#if defined(AMD_MFMA_AVAILABLE)
+    constexpr int tileC_IJ = mmq_get_granularity_device(0);
+    typedef tile tile_C;
+    constexpr int rows_per_warp = granularity;
+#else
+    typedef tile<16, 8, int> tile_C;
     constexpr int rows_per_warp = 2 * granularity;
+#endif
     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
     const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
-#ifdef NEW_MMA_AVAILABLE
+#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
     static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
-#endif // NEW_MMA_AVAILABLE
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
@@ -2371,179 +2889,189 @@ static __device__ __forceinline__ void mmq_write_back_mma(
 
 // -------------------------------------------------------------------------------------------------------------------------------------
 
-template 
+template 
 struct mmq_type_traits;
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_Q4_0_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_0;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_0;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a;
 };
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_Q4_1_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_1;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_1;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a;
 };
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_Q5_0_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_0;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_0;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
 };
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_Q5_1_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_1;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_1;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a;
 };
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_Q8_0_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q8_0;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q8_0;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
+};
+
+template 
+struct mmq_type_traits {
+    static constexpr int              vdr          = VDR_MXFP4_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_mxfp4;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
 };
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_Q2_K_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q2_K;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q2_K_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q2_K;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q2_K_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a;
 };
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_Q3_K_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q3_K;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q3_K;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a;
 };
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_Q4_K_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_K;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_K;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a;
 };
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_Q5_K_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_K;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_K;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a;
 };
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_Q6_K_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q6_K;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q6_K_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q6_K;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q6_K_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a;
 };
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_IQ2_XXS_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xxs;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xxs;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
 };
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_IQ2_XS_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xs;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xs;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a;
 };
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_IQ2_S_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_s;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_s;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a;
 };
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_IQ3_XXS_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_xxs;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_xxs;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
 };
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_IQ3_S_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_s;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_s;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
 };
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_IQ1_S_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq1_s;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq1_s;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a;
 };
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_IQ4_NL_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_nl;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_nl;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
 };
 
-template 
-struct mmq_type_traits {
+template 
+struct mmq_type_traits {
     static constexpr int              vdr          = VDR_IQ4_XS_Q8_1_MMQ;
-    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_xs;
-    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
-    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_xs;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
 };
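
The specializations above follow a compile-time trait-dispatch pattern: one `mmq_type_traits` specialization per quantization type exposes `constexpr` function pointers (tile loader, MMA dot, dp4a dot) that the kernel selects at template instantiation time with no runtime cost. A standalone sketch of the pattern; the names below (`qtype`, `traits`, `run_*`) are invented for illustration and are not the real API.

```cpp
// Minimal trait-based compile-time dispatch: one specialization per enum value.
#include <cstdio>

enum class qtype { Q4_0, Q8_0 };

using kernel_fn = void (*)();

void run_q4_0() { std::printf("q4_0 path\n"); }
void run_q8_0() { std::printf("q8_0 path\n"); }

template <qtype T> struct traits;
template <> struct traits<qtype::Q4_0> { static constexpr kernel_fn run = run_q4_0; };
template <> struct traits<qtype::Q8_0> { static constexpr kernel_fn run = run_q8_0; };

template <qtype T>
void dispatch() {
    constexpr kernel_fn run = traits<T>::run; // resolved at compile time
    run();
}

int main() {
    dispatch<qtype::Q4_0>();
    dispatch<qtype::Q8_0>();
    return 0;
}
```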
 
-template 
+template 
 static __device__ __forceinline__ void mul_mat_q_process_tile(
         const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
         const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
         const int stride_row_x, const int ncols_y, const int stride_col_dst,
         const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
 
+    constexpr int              warp_size  = ggml_cuda_get_physical_warp_size();
+    constexpr int              nwarps     = mmq_get_nwarps_device();
     constexpr int              qk         = ggml_cuda_type_traits::qk;
     constexpr int              mmq_y      = get_mmq_y_device();
-    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles;
+    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles;
 
     extern __shared__ int data_mul_mat_q[];
     int * tile_y = data_mul_mat_q + mmq_x;
-    int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
+    int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
 
-#ifdef NEW_MMA_AVAILABLE
-    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits::vec_dot_mma;
-    constexpr mmq_write_back_t write_back = mmq_write_back_mma;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits::vec_dot_mma;
+    constexpr mmq_write_back_t write_back = mmq_write_back_mma;
 #else
-    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits::vec_dot_dp4a;
-    constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
-#endif // NEW_MMA_AVAILABLE
+    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits::vec_dot_dp4a;
+    constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
 
     constexpr int blocks_per_iter = MMQ_ITER_K / qk;
 
-    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+    float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
 
     for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
         load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
@@ -2551,8 +3079,8 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
         {
             const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
 #pragma unroll
-            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
-                int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+                int l = l0 + threadIdx.y*warp_size + threadIdx.x;
 
                 tile_y[l] = by0[l];
             }
@@ -2567,8 +3095,8 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
         {
             const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
 #pragma unroll
-            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
-                int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+                int l = l0 + threadIdx.y*warp_size + threadIdx.x;
 
                 tile_y[l] = by0[l];
             }
@@ -2576,7 +3104,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
 
         __syncthreads();
 
-        vec_dot(tile_x, tile_y, sum, WARP_SIZE);
+        vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K);
 
         __syncthreads();
     }
@@ -2591,18 +3119,18 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
 
 // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
 
-template <ggml_type type, int mmq_x, int nwarps, bool need_check>
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+template <ggml_type type, int mmq_x, bool need_check>
+#if defined(GGML_USE_HIP)
 #if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
-    __launch_bounds__(WARP_SIZE*nwarps, 2)
+    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
 #endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
 #else
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
-    __launch_bounds__(WARP_SIZE*nwarps, 1)
+    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1)
 #else
-    __launch_bounds__(WARP_SIZE*nwarps, 2)
+    __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP)
 static __global__ void mul_mat_q(
         const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
         const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
@@ -2616,6 +3144,9 @@ static __global__ void mul_mat_q(
         return;
     }
 
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
     constexpr int qk    = ggml_cuda_type_traits<type>::qk;
     constexpr int mmq_y = get_mmq_y_device();
 
@@ -2627,10 +3158,10 @@ static __global__ void mul_mat_q(
     // For MoE the correct indices are loaded from ids_dst.
     extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
-        const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+        const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
 
-        if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+        if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
             break;
         }
 
@@ -2638,8 +3169,8 @@ static __global__ void mul_mat_q(
     }
     __syncthreads();
 
-    // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
-#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+    // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
+#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
     {
         const int wt = blockIdx.z / nchannels_y;
         const int zt = blockIdx.z - wt*nchannels_y;
@@ -2667,10 +3198,10 @@ static __global__ void mul_mat_q(
 
             // __syncthreads(); // There is no previous tile that could cause a race condition.
 #pragma unroll
-            for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
-                const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+            for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+                const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
 
-                if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+                if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
                     break;
                 }
 
@@ -2688,12 +3219,12 @@ static __global__ void mul_mat_q(
         const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
 
         constexpr bool fixup = false;
-        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+        mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
             (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
              tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
         return;
     }
-#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+#endif // (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
 
     const     int64_t blocks_per_ne00 = ncols_x / qk;
     constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
@@ -2745,10 +3276,10 @@ static __global__ void mul_mat_q(
 
             __syncthreads();
 #pragma unroll
-            for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
-                const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+            for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+                const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
 
-                if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+                if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
                     break;
                 }
 
@@ -2766,7 +3297,7 @@ static __global__ void mul_mat_q(
         const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
 
         constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
-        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+        mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
             (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
              tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
 
@@ -2812,10 +3343,10 @@ static __global__ void mul_mat_q(
         // The memory layout for the fixup buffer is always contiguous, therefore reset ids:
         __syncthreads();
 #pragma unroll
-        for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
-            const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+            const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
 
-            if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
+            if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
                 break;
             }
 
@@ -2833,13 +3364,13 @@ static __global__ void mul_mat_q(
     const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
 
     constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
-    mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+    mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
         (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
          tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
 }
 
 
-template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+template <ggml_type type, int mmq_x, bool need_check>
 static __global__ void mul_mat_q_stream_k_fixup(
         const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
         const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
@@ -2849,7 +3380,10 @@ static __global__ void mul_mat_q_stream_k_fixup(
     constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
     const     int64_t blocks_per_ne00 = ncols_x / qk;
 
-    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+    float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
 
     const int ntx  = (ncols_dst + mmq_x - 1) / mmq_x;
     const int nty  = (nrows_x   + mmq_y - 1) / mmq_y;
@@ -2893,10 +3427,10 @@ static __global__ void mul_mat_q_stream_k_fixup(
             const int j = j0 + threadIdx.y;
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
-                sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
+                sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
             }
         }
 
@@ -2937,14 +3471,14 @@ static __global__ void mul_mat_q_stream_k_fixup(
             }
 
 #pragma unroll
-            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
                 if (need_check && i > i_max) {
                     continue;
                 }
 
-                dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+                dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
             }
         }
         return;
@@ -2955,7 +3489,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
     const int col_high = expert_bounds[zt + 1];
     const int col_diff = col_high - col_low;
 
-    for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) {
+    for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
         ids_dst_shared[j] = ids_dst[col_low + j];
     }
     __syncthreads();
@@ -2975,14 +3509,14 @@ static __global__ void mul_mat_q_stream_k_fixup(
         }
 
 #pragma unroll
-        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
 
             if (need_check && i > i_max) {
                 continue;
             }
 
-            dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+            dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
         }
     }
 }
@@ -2996,13 +3530,13 @@ struct mmq_args {
 };
 
 template <ggml_type type>
-static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc) {
+static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) {
     const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
     const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
     const size_t nbs_ids = mmq_x*sizeof(int);
-    const size_t nbs_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+    const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
     const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
-    return nbs_ids + nbs_x + GGML_PAD(nbs_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
+    return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
 }
 
 template <ggml_type type, int mmq_x>
@@ -3010,14 +3544,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
     const int id = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[id].cc;
     const int nsm = ggml_cuda_info().devices[id].nsm;
+    const int warp_size = ggml_cuda_info().devices[id].warp_size;
+    const int nwarps = mmq_get_nwarps_host(cc, warp_size);
     const int mmq_y = get_mmq_y_host(cc);
 
-    const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
+    const dim3 block_dims(warp_size, nwarps, 1);
 
-    const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
+    const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps);
 
-    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, false>), nbytes_shared);
-    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, true>),  nbytes_shared);
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false>), nbytes_shared);
+    CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);
 
     const int nty  = (args.nrows_x   + mmq_y - 1) / mmq_y;
     const int ntx  = (args.ncols_dst + mmq_x - 1) / mmq_x;
@@ -3032,14 +3568,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
     if (!args.use_stream_k) {
         if (args.nrows_x % mmq_y == 0) {
             constexpr bool need_check = false;
-            mul_mat_q<<>>
+            mul_mat_q<<>>
                 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
                  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
                  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
                  sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
         } else {
             constexpr bool need_check = true;
-            mul_mat_q<<>>
+            mul_mat_q<<>>
                 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
                  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
                  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
@@ -3059,8 +3595,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
     if (args.nrows_x % mmq_y == 0) {
         constexpr bool need_check = false;
-
-        mul_mat_q<<>>
+        mul_mat_q<<>>
             (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
              args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
              channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
@@ -3070,13 +3605,12 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
             return;
         }
 
-        mul_mat_q_stream_k_fixup<<>>
+        mul_mat_q_stream_k_fixup<<>>
             (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
              args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
     } else {
         constexpr bool need_check = true;
-
-        mul_mat_q<<>>
+        mul_mat_q<<>>
             (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
              args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
              channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
@@ -3086,7 +3620,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
             return;
         }
 
-        mul_mat_q_stream_k_fixup<<>>
+        mul_mat_q_stream_k_fixup<<>>
             (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
              args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
     }
@@ -3094,9 +3628,11 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
 template <ggml_type type>
 void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
-    const int    id    = ggml_cuda_get_device();
-    const int    cc    = ggml_cuda_info().devices[id].cc;
-    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+    const int    id     = ggml_cuda_get_device();
+    const int    cc     = ggml_cuda_info().devices[id].cc;
+    const size_t smpbo  = ggml_cuda_info().devices[id].smpbo;
+    const int warp_size = ggml_cuda_info().devices[id].warp_size;
+    const int nwarps    = mmq_get_nwarps_host(cc, warp_size);
 
     const int mmq_x_max = get_mmq_x_max_host(cc);
     const int mmq_y = get_mmq_y_host(cc);
@@ -3107,7 +3643,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
     for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
         const int granularity = mmq_get_granularity_host(mmq_x, cc);
 
-        if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc) > smpbo) {
+        if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) {
             continue;
         }
 
@@ -3183,6 +3719,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
 extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
+extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
 extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
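Reviewer note: for anyone unfamiliar with the stream-k scheme the kernel comment links to (https://arxiv.org/abs/2301.03598), the sketch below is a minimal host-side C++ model of the idea, not the kernel itself; the names (Partial, fixup, n_workers) are invented for illustration. The flattened (tile, k-block) work is divided evenly over a fixed pool of workers, tiles that end up shared by several workers have their partial sums parked in a temporary buffer, and a final fixup pass folds those partials into the output, which is the role played by tmp_fixup and mul_mat_q_stream_k_fixup above.

#include <cstdio>
#include <vector>

// Toy model of stream-k partitioning: the flattened (tile, k-block) work is
// split evenly across a fixed pool of workers. Tiles shared by more than one
// worker are reconciled by a separate fixup pass.
int main() {
    const int ntiles = 5, k_blocks = 8, n_workers = 3;
    const int total  = ntiles * k_blocks;

    std::vector<float> dst(ntiles, 0.0f);
    struct Partial { int tile; float sum; };
    std::vector<Partial> fixup; // partial results for tiles split between workers

    for (int w = 0; w < n_workers; ++w) {
        const int begin = (total *  w)      / n_workers;
        const int end   = (total * (w + 1)) / n_workers;
        int i = begin;
        while (i < end) {
            const int tile      = i / k_blocks;
            const int tile_end  = (tile + 1) * k_blocks;
            const int chunk_end = tile_end < end ? tile_end : end;
            float partial = 0.0f;
            for (; i < chunk_end; ++i) {
                partial += 1.0f; // stand-in for one k-block worth of dot products
            }
            const bool owns_whole_tile = (begin <= tile * k_blocks) && (chunk_end == tile_end);
            if (owns_whole_tile) {
                dst[tile] = partial;              // no other worker touches this tile
            } else {
                fixup.push_back({tile, partial}); // reconciled later, avoids races
            }
        }
    }
    for (const Partial & p : fixup) {
        dst[p.tile] += p.sum;
    }
    for (int t = 0; t < ntiles; ++t) {
        printf("tile %d: %.1f (expected %.1f)\n", t, dst[t], (float) k_blocks);
    }
    return 0;
}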
diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu
new file mode 100644
index 00000000000..16100b68045
--- /dev/null
+++ b/ggml/src/ggml-cuda/mmvf.cu
@@ -0,0 +1,511 @@
+#include "ggml.h"
+#include "common.cuh"
+#include "convert.cuh"
+#include "mmvf.cuh"
+
+template <typename T, typename type_acc, int ncols_dst, int block_size>
+static __global__ void mul_mat_vec_f(
+        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+        const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
+        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+    const int row         = blockIdx.x;
+    const int channel_dst = blockIdx.y;
+    const int channel_x   = ids ? ids[channel_dst]          : channel_dst / channel_ratio;
+    const int channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
+    const int sample_dst  = blockIdx.z;
+    const int sample_x    = sample_dst / sample_ratio;
+    const int sample_y    = sample_dst;
+    const int tid         = threadIdx.x;
+
+    constexpr int warp_size   = ggml_cuda_get_physical_warp_size();
+
+    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
+    y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;
+    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
+
+    const float2 * y2 = (const float2 *) y;
+
+    extern __shared__ char data_mmv[];
+    float * buf_iw = (float *) data_mmv;
+
+    if (block_size > warp_size) {
+        if (tid < warp_size) {
+            buf_iw[tid] = 0.0f;
+        }
+        __syncthreads();
+    }
+
+    float sumf[ncols_dst] = {0.0f};
+
+    if constexpr (std::is_same_v<T, float>) {
+        const float2 * x2 = (const float2 *) x;
+
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+            const float2 tmpx = x2[col2];
+
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                sumf[j] += tmpx.x*tmpy.x;
+                sumf[j] += tmpx.y*tmpy.y;
+            }
+        }
+    } else if constexpr (std::is_same_v<T, half>) {
+        const half2 * x2 = (const half2 *) x;
+
+        if (std::is_same_v<type_acc, float>) {
+            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+                const float2 tmpx = __half22float2(x2[col2]);
+
+#pragma unroll
+                for (int j = 0; j < ncols_dst; ++j) {
+                    const float2 tmpy = y2[j*stride_col_y2 + col2];
+                    sumf[j] += tmpx.x * tmpy.x;
+                    sumf[j] += tmpx.y * tmpy.y;
+                }
+            }
+        } else {
+#ifdef FP16_AVAILABLE
+            half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
+
+            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+                const half2 tmpx = x2[col2];
+
+#pragma unroll
+                for (int j = 0; j < ncols_dst; ++j) {
+                    const float2 tmpy = y2[j*stride_col_y2 + col2];
+                    sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
+                }
+            }
+
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
+            }
+#else
+            NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+        }
+    } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
+        const int * x2 = (const int *) x;
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+            const int tmpx = x2[col2];
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
+                sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+            }
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "unsupported type");
+    }
+
+#pragma unroll
+    for (int j = 0; j < ncols_dst; ++j) {
+        sumf[j] = warp_reduce_sum(sumf[j]);
+
+        if (block_size > warp_size) {
+            buf_iw[tid/warp_size] = sumf[j];
+            __syncthreads();
+            if (tid < warp_size) {
+                sumf[j] = buf_iw[tid];
+                sumf[j] = warp_reduce_sum(sumf[j]);
+            }
+            if (j < ncols_dst) {
+                __syncthreads();
+            }
+        }
+    }
+
+    if (tid >= ncols_dst) {
+        return;
+    }
+
+    dst[tid*stride_col_dst + row] = sumf[tid];
+}
+
+template <typename T, typename type_acc, int ncols_dst>
+static void launch_mul_mat_vec_f_cuda(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols, const int64_t nrows,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    GGML_ASSERT(ncols        % 2 == 0);
+    GGML_ASSERT(stride_row   % 2 == 0);
+    GGML_ASSERT(stride_col_y % 2 == 0);
+    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
+    const int64_t channel_ratio = nchannels_dst / nchannels_x;
+    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
+
+    const int device = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
+
+    int64_t block_size_best = warp_size;
+    int64_t niter_best      = (ncols + 2*warp_size - 1) / (2*warp_size);
+    int64_t max_block_size  = 256;
+    if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
+        max_block_size = 128;
+    }
+    for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
+        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
+        if (niter < niter_best) {
+            niter_best      = niter;
+            block_size_best = block_size;
+        }
+    }
+
+    const int nbytes_shared = warp_size*sizeof(float);
+    const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
+    const dim3 block_dims(block_size_best, 1, 1);
+    switch (block_size_best) {
+        case   32: {
+            mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case   64: {
+            mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case   96: {
+            mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case  128: {
+            mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case  160: {
+            mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case  192: {
+            mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case  224: {
+            mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case  256: {
+            mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
+template <typename T, typename type_acc>
+static void mul_mat_vec_f_cuda_switch_ncols_dst(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    switch (ncols_dst) {
+        case 1:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 1>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 2:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 2>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 3:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 3>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 4:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 4>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 5:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 5>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 6:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 6>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 7:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 7>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 8:
+            launch_mul_mat_vec_f_cuda<T, type_acc, 8>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
+template <typename T>
+static void mul_mat_vec_f_cuda(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        enum ggml_prec prec, cudaStream_t stream) {
+    if constexpr(std::is_same_v<T, half>) {
+        if (prec == GGML_PREC_DEFAULT) {
+            mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
+                (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            return;
+        }
+    }
+    mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
+        (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+         nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+         stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+}
+
+void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);
+    GGML_ASSERT(         dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const size_t ts_src0 = ggml_type_size(src0->type);
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    const size_t ts_dst  = ggml_type_size(dst->type);
+
+    GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
+    GGML_ASSERT(ne13 == ne3);
+
+    GGML_ASSERT(        nb00       == ts_src0);
+    GGML_ASSERT(        nb10       == ts_src1);
+    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+    GGML_ASSERT(        nb0        == ts_dst);
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+    const float   * src1_d =       (const float   *) src1->data;
+    const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
+    float         *  dst_d =       (float         *)  dst->data;
+
+    const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s11 = src1->nb[1] / ts_src1;
+    const int64_t s1  =  dst->nb[1] / ts_dst;
+    const int64_t s02 = src0->nb[2] / ts_src0;
+    const int64_t s12 = src1->nb[2] / ts_src1;
+    const int64_t s2  =  dst->nb[2] / ts_dst;
+    const int64_t s03 = src0->nb[3] / ts_src0;
+    const int64_t s13 = src1->nb[3] / ts_src1;
+    const int64_t s3  =  dst->nb[3] / ts_dst;
+
+    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
+    const int64_t ncols_dst          = ids ? ne2  : ne1;
+    const int64_t nchannels_y        = ids ? ne11 : ne12;
+    const int64_t nchannels_dst      = ids ? ne1  : ne2;
+    const int64_t stride_channel_dst = ids ? s1   : s2;
+    const int64_t stride_channel_y   = ids ? s11  : s12;
+
+    GGML_ASSERT(!ids || ncols_dst == 1);
+
+    switch (src0->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0->data;
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
+        } break;
+        case GGML_TYPE_F16: {
+            const half * src0_d = (const half *) src0->data;
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
+}
+
+void ggml_cuda_op_mul_mat_vec_f(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne0  =  dst->ne[0];
+    const int64_t row_diff = row_high - row_low;
+
+    const int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+
+    // ggml_cuda_op provides single, contiguous matrices
+    const int64_t stride_row         = ne00;
+    const int64_t stride_col_y       = ne10;
+    const int64_t stride_col_dst     = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
+    const int64_t nchannels_x        = 1;
+    const int64_t nchannels_y        = 1;
+    const int64_t nchannels_dst      = 1;
+    const int64_t stride_channel_x   = 0;
+    const int64_t stride_channel_y   = 0;
+    const int64_t stride_channel_dst = 0;
+    const int64_t nsamples_x         = 1;
+    const int64_t nsamples_dst       = 1;
+    const int64_t stride_sample_x    = 0;
+    const int64_t stride_sample_y    = 0;
+    const int64_t stride_sample_dst  = 0;
+
+    switch (src0->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0_dd_i;
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+        } break;
+        case GGML_TYPE_F16: {
+            const half * src0_d = (const half *) src0_dd_i;
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
+
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(src1);
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src1_ddq_i);
+    GGML_UNUSED(src1_ncols);
+    GGML_UNUSED(src1_padded_row_size);
+}
+
+bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
+    if (src0_ne[0] % 2 != 0) {
+        return false;
+    }
+    switch (type) {
+        case GGML_TYPE_F32:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                if (ampere_mma_available(cc)) {
+                    return ne11 <= 3;
+                }
+                if (cc >= GGML_CUDA_CC_TURING) {
+                    return ne11 <= 4;
+                }
+                return ne11 <= 3;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (fp32_mma_hardware_available(cc)) {
+                    return ne11 <= 3;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        case GGML_TYPE_F16:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+                if (ampere_mma_available(cc)) {
+                    return src0_small && ne11 == 1;
+                }
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return src0_small && ne11 <= 4;
+                }
+                if (fp16_mma_hardware_available(cc)) {
+                    return src0_small && ne11 <= 3;
+                }
+                return ne11 <= 8;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (fp16_mma_hardware_available(cc)) {
+                    if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+                        return ne11 <= 5;
+                    }
+                    return ne11 <= 2;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        case GGML_TYPE_BF16:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+                if (ampere_mma_available(cc)) {
+                    return src0_small && ne11 == 1;
+                }
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return src0_small && ne11 <= 4;
+                }
+                if (bf16_mma_hardware_available(cc)) {
+                    return src0_small && ne11 <= 3;
+                }
+                return ne11 <= 8;
+            } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+                if (bf16_mma_hardware_available(cc)) {
+                    return ne11 <= 3;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        default:
+            return false;
+    }
+}
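As a plain C++ reference for what the new mul_mat_vec_f kernel computes per channel/sample (assuming contiguous rows and ignoring the ids-based MoE indexing and the stride/broadcast parameters), each output column j of dst is the dot product of a row of x with column j of y; mul_mat_vec_f_ref is an invented name for this sketch.

#include <cstdio>
#include <vector>

// Reference: dst[j, row] = sum over col of x[row, col] * y[j, col].
static void mul_mat_vec_f_ref(const std::vector<float> & x, const std::vector<float> & y,
                              std::vector<float> & dst, int ncols, int nrows, int ncols_dst) {
    for (int row = 0; row < nrows; ++row) {
        for (int j = 0; j < ncols_dst; ++j) {
            float sum = 0.0f;
            for (int col = 0; col < ncols; ++col) {
                sum += x[row*ncols + col] * y[j*ncols + col];
            }
            dst[j*nrows + row] = sum;
        }
    }
}

int main() {
    const int ncols = 4, nrows = 2, ncols_dst = 2;
    const std::vector<float> x = {1, 2, 3, 4,   5, 6, 7, 8}; // nrows x ncols
    const std::vector<float> y = {1, 1, 1, 1,   0, 1, 0, 1}; // ncols_dst x ncols
    std::vector<float> dst(ncols_dst*nrows, 0.0f);
    mul_mat_vec_f_ref(x, y, dst, ncols, nrows, ncols_dst);
    for (float v : dst) {
        printf("%.1f ", v); // 10.0 26.0 6.0 14.0
    }
    printf("\n");
    return 0;
}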
diff --git a/ggml/src/ggml-cuda/mmvf.cuh b/ggml/src/ggml-cuda/mmvf.cuh
new file mode 100644
index 00000000000..1da460992e7
--- /dev/null
+++ b/ggml/src/ggml-cuda/mmvf.cuh
@@ -0,0 +1,11 @@
+#include "common.cuh"
+
+void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
+void ggml_cuda_op_mul_mat_vec_f(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream);
+
+bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index dc7adf509fa..5c8e5c4a7ee 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -13,6 +13,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
         case GGML_TYPE_Q5_0:    return vec_dot_q5_0_q8_1;
         case GGML_TYPE_Q5_1:    return vec_dot_q5_1_q8_1;
         case GGML_TYPE_Q8_0:    return vec_dot_q8_0_q8_1;
+        case GGML_TYPE_MXFP4:   return vec_dot_mxfp4_q8_1;
         case GGML_TYPE_Q2_K:    return vec_dot_q2_K_q8_1;
         case GGML_TYPE_Q3_K:    return vec_dot_q3_K_q8_1;
         case GGML_TYPE_Q4_K:    return vec_dot_q4_K_q8_1;
@@ -38,6 +39,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
         case GGML_TYPE_Q5_0:    return VDR_Q5_0_Q8_1_MMVQ;
         case GGML_TYPE_Q5_1:    return VDR_Q5_1_Q8_1_MMVQ;
         case GGML_TYPE_Q8_0:    return VDR_Q8_0_Q8_1_MMVQ;
+        case GGML_TYPE_MXFP4:   return VDR_MXFP4_Q8_1_MMVQ;
         case GGML_TYPE_Q2_K:    return VDR_Q2_K_Q8_1_MMVQ;
         case GGML_TYPE_Q3_K:    return VDR_Q3_K_Q8_1_MMVQ;
         case GGML_TYPE_Q4_K:    return VDR_Q4_K_Q8_1_MMVQ;
@@ -384,6 +386,13 @@ static void mul_mat_vec_q_switch_type(
                  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                  stream);
             break;
+        case GGML_TYPE_MXFP4:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
         case GGML_TYPE_Q2_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
                 (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index 0020dbcec5f..bddcca51b7b 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -104,10 +104,12 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
     }
 }
 
-template <int block_size>
+template <int block_size, bool do_multiply = false>
 static __global__ void rms_norm_f32(
         const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
-        const int64_t stride_sample, const float eps) {
+        const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
+        const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
+        const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
     const int nrows     = gridDim.x;
     const int nchannels = gridDim.y;
 
@@ -119,6 +121,13 @@ static __global__ void rms_norm_f32(
     x   += sample*stride_sample + channel*stride_channel + row*stride_row;
     dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
+    if constexpr (do_multiply) {
+        const int mul_row = row % mul_nrows;
+        const int mul_channel = channel % mul_nchannels;
+        const int mul_sample = sample % mul_nsamples;
+        mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
+    }
+
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += block_size) {
@@ -145,7 +154,12 @@ static __global__ void rms_norm_f32(
     const float scale = rsqrtf(mean + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        dst[col] = scale * x[col];
+        if constexpr (do_multiply) {
+            const int mul_col = col % mul_ncols;
+            dst[col] = scale * x[col] * mul[mul_col];
+        } else {
+            dst[col] = scale * x[col];
+        }
     }
 }
 
@@ -310,10 +324,30 @@ static void rms_norm_f32_cuda(
     const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    }
+}
+
+static void rms_norm_mul_f32_cuda(
+        const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
+        const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
+        const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
+        const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
+    if (mul == nullptr) {
+        rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
+        return;
+    }
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
     }
 }
 
@@ -407,6 +441,59 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }
 
+void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
+    const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
+    float eps = 0.0f;
+
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    const float * src0_d = (const float *) rms_norm_src->data;
+    const float * mul_d = nullptr;
+    const ggml_tensor * mul_src = nullptr;
+
+    if (mul_tensor->src[0] == dst) {
+        mul_d = (float *) mul_tensor->src[1]->data;
+        mul_src = mul_tensor->src[1];
+    } else if(mul_tensor->src[1] == dst) {
+        mul_d = (float *) mul_tensor->src[0]->data;
+        mul_src = mul_tensor->src[0];
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    float * dst_d = (float *) mul_tensor->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
+    GGML_ASSERT(eps >= 0.0f);
+
+    const int64_t ne00 = rms_norm_src->ne[0];
+    const int64_t ne01 = rms_norm_src->ne[1];
+    const int64_t ne02 = rms_norm_src->ne[2];
+    const int64_t ne03 = rms_norm_src->ne[3];
+
+    const size_t ts0 = ggml_type_size(rms_norm_src->type);
+    GGML_ASSERT(rms_norm_src->nb[0] == ts0);
+    const int64_t s01 = rms_norm_src->nb[1] / ts0;
+    const int64_t s02 = rms_norm_src->nb[2] / ts0;
+    const int64_t s03 = rms_norm_src->nb[3] / ts0;
+
+    const size_t ts_mul = ggml_type_size(mul_src->type);
+    GGML_ASSERT(mul_src->nb[0] == ts_mul);
+    const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
+    const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
+    const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
+
+    const int mul_ncols     = mul_src->ne[0];
+    const int mul_nrows     = mul_src->ne[1];
+    const int mul_nchannels = mul_src->ne[2];
+    const int mul_nsamples  = mul_src->ne[3];
+
+    rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
+}
+
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * grad  = dst->src[0]; // gradients
     const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
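A minimal host-side C++ sketch of the fused RMS_NORM + MUL path added above (single row only; rms_norm_mul_ref is an invented name, and the modulo indexing stands in for the kernel's broadcast over mul_ncols/mul_nrows/mul_nchannels/mul_nsamples):

#include <cmath>
#include <cstdio>
#include <vector>

// Normalize a row by the RMS of its elements, then scale element-wise by a
// (broadcastable) mul vector, matching scale * x[col] * mul[mul_col] above.
static void rms_norm_mul_ref(const std::vector<float> & x, const std::vector<float> & mul,
                             std::vector<float> & dst, float eps) {
    float sumsq = 0.0f;
    for (float v : x) {
        sumsq += v*v;
    }
    const float scale = 1.0f / std::sqrt(sumsq/x.size() + eps);
    for (size_t i = 0; i < x.size(); ++i) {
        dst[i] = scale * x[i] * mul[i % mul.size()]; // modulo mirrors the kernel's broadcast
    }
}

int main() {
    const std::vector<float> x   = {1.0f, -2.0f, 3.0f, -4.0f};
    const std::vector<float> mul = {0.5f, 2.0f};
    std::vector<float> dst(x.size());
    rms_norm_mul_ref(x, mul, dst, 1e-6f);
    for (float v : dst) {
        printf("%.4f ", v);
    }
    printf("\n");
    return 0;
}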
diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh
index 706a5660a68..7ea7bd4df3c 100644
--- a/ggml/src/ggml-cuda/norm.cuh
+++ b/ggml/src/ggml-cuda/norm.cuh
@@ -6,6 +6,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);
+
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/opt-step-sgd.cu b/ggml/src/ggml-cuda/opt-step-sgd.cu
new file mode 100644
index 00000000000..460b16de447
--- /dev/null
+++ b/ggml/src/ggml-cuda/opt-step-sgd.cu
@@ -0,0 +1,49 @@
+#include "ggml-impl.h"
+#include "opt-step-sgd.cuh"
+
+#include <cstdint>
+
+static __global__ void opt_step_sgd_f32(
+    float * __restrict__ x, const float * __restrict__ g,
+    const float * __restrict__ pars, const int64_t k) {
+
+    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    x[i] = x[i] * (1.0f - pars[0] * pars[1]) - pars[0] * g[i];
+}
+
+static void opt_step_sgd_f32_cuda(
+    float * x, const float * g, const float * __restrict__ pars, const int64_t k, cudaStream_t stream) {
+
+    const dim3 block_dims(CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+    const dim3 block_nums((k + CUDA_OPT_STEP_SGD_BLOCK_SIZE - 1) / CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+    opt_step_sgd_f32<<>>(x, g, pars, k);
+}
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0      = dst->src[0];
+    const ggml_tensor * src0_grad = dst->src[1];
+    const ggml_tensor * params    = dst->src[2];
+
+    GGML_ASSERT(src0->type      == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);
+    GGML_ASSERT(params->type    == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad));
+    GGML_ASSERT(ggml_is_contiguous(params));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(params) == 2);
+
+    float       * src0_d      = (float       *) src0->data;
+    const float * src0_grad_d = (const float *) src0_grad->data;
+    const float * params_d    = (const float *) params->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne = ggml_nelements(src0);
+
+    opt_step_sgd_f32_cuda(src0_d, src0_grad_d, params_d, ne, stream);
+}
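The update applied by opt_step_sgd_f32 is plain SGD with decoupled weight decay, with pars[0] as the learning rate and pars[1] as the weight decay factor. A host-side C++ reference of the same update (sgd_step_ref is an invented name for this sketch):

#include <cstdio>
#include <vector>

// x[i] = x[i]*(1 - lr*wd) - lr*g[i], as in opt_step_sgd_f32 above.
static void sgd_step_ref(std::vector<float> & x, const std::vector<float> & g, float lr, float wd) {
    for (size_t i = 0; i < x.size(); ++i) {
        x[i] = x[i]*(1.0f - lr*wd) - lr*g[i];
    }
}

int main() {
    std::vector<float> x = {1.0f, -0.5f};
    const std::vector<float> g = {0.1f, 0.2f};
    sgd_step_ref(x, g, /*lr=*/0.01f, /*wd=*/0.1f);
    printf("%.5f %.5f\n", x[0], x[1]); // 0.99800 -0.50150
    return 0;
}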
diff --git a/ggml/src/ggml-cuda/opt-step-sgd.cuh b/ggml/src/ggml-cuda/opt-step-sgd.cuh
new file mode 100644
index 00000000000..f97ab7d9bed
--- /dev/null
+++ b/ggml/src/ggml-cuda/opt-step-sgd.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_OPT_STEP_SGD_BLOCK_SIZE 256
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/reduce_rows.cuh b/ggml/src/ggml-cuda/reduce_rows.cuh
new file mode 100644
index 00000000000..6bee204136b
--- /dev/null
+++ b/ggml/src/ggml-cuda/reduce_rows.cuh
@@ -0,0 +1,53 @@
+#include "common.cuh"
+
+// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
+template <bool norm>
+static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float     sum        = 0.0f;
+    const int num_unroll = 8;
+    float     temp[num_unroll];
+    float     sum_temp[num_unroll] = { 0.0f };
+    for (int i = col; i < ncols;) {
+        for (int j = 0; j < num_unroll; ++j) {
+            if (i < ncols) {
+                temp[j] = x[row * ncols + i];
+            } else {
+                temp[j] = 0;
+            }
+            i += blockDim.x;
+        }
+        for (int j = 0; j < num_unroll; ++j) {
+            sum_temp[j] += temp[j];
+        }
+    }
+    for (int j = 0; j < num_unroll; ++j) {
+        sum += sum_temp[j];
+    }
+
+    // sum up partial sums
+    sum = warp_reduce_sum(sum);
+    if (blockDim.x > WARP_SIZE) {
+        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
+        __shared__ float s_sum[32];
+        const int        warp_id = threadIdx.x / WARP_SIZE;
+        const int        lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = sum;
+        }
+        __syncthreads();
+        sum = 0.0f;
+        if (lane_id < (blockDim.x / WARP_SIZE)) {
+            sum = s_sum[lane_id];
+        }
+        sum = warp_reduce_sum(sum);
+    }
+
+    if (col != 0) {
+        return;
+    }
+
+    dst[row] = norm ? sum / ncols : sum;
+}
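For reference, the reduction computed by reduce_rows_f32 boils down to a per-row sum (norm = false) or mean (norm = true); a host-side C++ sketch (reduce_rows_ref is an invented name), without the unrolling and warp/shared-memory reduction the kernel adds for performance:

#include <cstdio>
#include <vector>

// Per-row sum or mean over a row-major nrows x ncols matrix.
template <bool norm>
static void reduce_rows_ref(const std::vector<float> & x, std::vector<float> & dst, int nrows, int ncols) {
    for (int row = 0; row < nrows; ++row) {
        float sum = 0.0f;
        for (int col = 0; col < ncols; ++col) {
            sum += x[row*ncols + col];
        }
        dst[row] = norm ? sum / ncols : sum;
    }
}

int main() {
    const std::vector<float> x = {1, 2, 3,  4, 5, 6}; // 2 rows x 3 cols
    std::vector<float> sums(2), means(2);
    reduce_rows_ref<false>(x, sums,  2, 3);
    reduce_rows_ref<true >(x, means, 2, 3);
    printf("sums: %.1f %.1f  means: %.1f %.1f\n", sums[0], sums[1], means[0], means[1]);
    return 0;
}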
diff --git a/ggml/src/ggml-cuda/roll.cu b/ggml/src/ggml-cuda/roll.cu
new file mode 100644
index 00000000000..a339dfc1ae0
--- /dev/null
+++ b/ggml/src/ggml-cuda/roll.cu
@@ -0,0 +1,67 @@
+#include "ggml-cuda/common.cuh"
+#include "roll.cuh"
+
+static __forceinline__ __device__ int64_t wrap_index(const int64_t idx, const int64_t ne) {
+    if (idx < 0) {
+        return idx + ne;
+    }
+    if (idx >= ne) {
+        return idx - ne;
+    }
+    return idx;
+}
+
+static __global__ void roll_f32_cuda(const float * __restrict__ src,
+                                     float * __restrict__ dst,
+                                     const int64_t ne00,
+                                     const int64_t ne01,
+                                     const int64_t ne02,
+                                     const int64_t ne03,
+                                     const int     s0,
+                                     const int     s1,
+                                     const int     s2,
+                                     const int     s3) {
+    const int64_t idx        = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+    const int64_t n_elements = ne00 * ne01 * ne02 * ne03;
+
+    if (idx >= n_elements) {
+        return;
+    }
+
+    const int64_t i0 = idx % ne00;
+    const int64_t i1 = (idx / ne00) % ne01;
+    const int64_t i2 = (idx / (ne00 * ne01)) % ne02;
+    const int64_t i3 = (idx / (ne00 * ne01 * ne02)) % ne03;
+
+    const int64_t d0 = wrap_index(i0 - s0, ne00);
+    const int64_t d1 = wrap_index(i1 - s1, ne01);
+    const int64_t d2 = wrap_index(i2 - s2, ne02);
+    const int64_t d3 = wrap_index(i3 - s3, ne03);
+
+    dst[i3 * (ne00 * ne01 * ne02) + i2 * (ne01 * ne00) + i1 * ne00 + i0] =
+        src[d3 * (ne00 * ne01 * ne02) + d2 * (ne01 * ne00) + d1 * ne00 + d0];
+}
+
+void ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    int s0 = dst->op_params[0];
+    int s1 = dst->op_params[1];
+    int s2 = dst->op_params[2];
+    int s3 = dst->op_params[3];
+
+    const ggml_tensor * src0   = dst->src[0];
+    const float *       src0_d = (const float *) dst->src[0]->data;
+    float *             dst_d  = (float *) dst->data;
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst));
+
+    cudaStream_t stream = ctx.stream();
+
+    int64_t sz         = (ne00 * ne01 * ne02 * ne03);
+    int64_t num_blocks = (sz + CUDA_ROLL_BLOCK_SIZE - 1) / CUDA_ROLL_BLOCK_SIZE;
+
+    roll_f32_cuda<<<num_blocks, CUDA_ROLL_BLOCK_SIZE, 0, stream>>>(
+        src0_d, dst_d, ne00, ne01, ne02, ne03, s0, s1, s2, s3);
+}
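
The kernel maps each destination index back to the source with `wrap_index(i - s, ne)` per dimension, i.e. a cyclic shift by `s` along that dimension. A one-dimensional host sketch of the same convention, valid for shifts with |s| ≤ ne exactly like `wrap_index` (illustrative helper, not part of the patch):

```cpp
#include <cstdint>

// 1-D reference of the roll semantics: dst[i] = src[(i - s) wrapped into [0, ne)],
// i.e. contents move forward by s positions with wrap-around.
static void roll_1d_ref(const float * src, float * dst, int64_t ne, int s) {
    for (int64_t i = 0; i < ne; ++i) {
        int64_t j = i - s;
        if (j < 0)   { j += ne; }
        if (j >= ne) { j -= ne; }
        dst[i] = src[j];
    }
}
```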
diff --git a/ggml/src/ggml-cuda/roll.cuh b/ggml/src/ggml-cuda/roll.cuh
new file mode 100644
index 00000000000..322d55436e2
--- /dev/null
+++ b/ggml/src/ggml-cuda/roll.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_ROLL_BLOCK_SIZE 256
+
+void ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu
new file mode 100644
index 00000000000..b4115a43c2a
--- /dev/null
+++ b/ggml/src/ggml-cuda/set-rows.cu
@@ -0,0 +1,268 @@
+#include "set-rows.cuh"
+#include "cpy-utils.cuh"
+
+typedef void (*set_rows_kernel_t)(const char * src, char * dst);
+
+// Generic quantized set_rows kernel template
+template<typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
+static __global__ void k_set_rows_quant(
+        const float * __restrict__ src0, const int64_t * __restrict__ src1, block_type * __restrict__ dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+        const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t s10, const int64_t s11, const int64_t s12,
+        const int64_t s1, const int64_t s2, const int64_t s3) {
+
+    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+    const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
+
+    if (i >= ne_total) {
+        return;
+    }
+
+    const int64_t i_base = i * qk;
+    const int64_t i03 = i_base / (ne00 * ne01 * ne02);
+    const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
+    const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
+    const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
+
+    const int64_t i12 = i03 % ne12;
+    const int64_t i11 = i02 % ne11;
+    const int64_t i10 = i01;
+
+    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
+
+    const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
+    block_type * dst_row_ptr = dst + (dst_row*s1 + i02*s2 + i03*s3) / sizeof(block_type);
+
+    const float * src_block = src0_row + i00;
+    block_type * dst_block = dst_row_ptr + i00 / qk;
+
+    quantize_func(src_block, dst_block);
+
+    GGML_UNUSED(ne10);
+    GGML_UNUSED(ne13);
+}
+
+// Template dispatch function for quantized set_rows
+template<typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
+static void set_rows_cuda_quant(
+        const float * src0_d, const int64_t * src1_d, block_type * dst_d,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+        const size_t nb01, const size_t nb02, const size_t nb03,
+        const size_t nb10, const size_t nb11, const size_t nb12,
+        const size_t nb1, const size_t nb2, const size_t nb3,
+        cudaStream_t stream) {
+
+    GGML_ASSERT(ne00 % qk == 0);
+    const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
+    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
+    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
+    const dim3 grid_size(num_blocks);
+
+    const int64_t s01 = nb01/sizeof(float);
+    const int64_t s02 = nb02/sizeof(float);
+    const int64_t s03 = nb03/sizeof(float);
+    const int64_t s10 = nb10/sizeof(int64_t);
+    const int64_t s11 = nb11/sizeof(int64_t);
+    const int64_t s12 = nb12/sizeof(int64_t);
+    const int64_t s1  = nb1;
+    const int64_t s2  = nb2;
+    const int64_t s3  = nb3;
+
+    if (ne_total > 0) {
+        k_set_rows_quant<block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
+            src0_d, src1_d, dst_d,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            s01, s02, s03,
+            s10, s11, s12,
+            s1, s2, s3);
+    }
+}
+
+template<typename src_t, typename dst_t>
+static __global__ void k_set_rows(
+        const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+        const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t s10, const int64_t s11, const int64_t s12,
+        const int64_t s1, const int64_t s2, const int64_t s3) {
+
+    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
+
+    if (i >= ne_total) {
+        return;
+    }
+
+    const int64_t i03 = i / (ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
+    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
+    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
+
+    const int64_t i12 = i03 % ne12;
+    const int64_t i11 = i02 % ne11;
+    const int64_t i10 = i01;
+
+    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
+
+    const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
+    dst_t * dst_row_ptr    = dst + dst_row*s1 + i02*s2 + i03*s3;
+
+    dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
+
+    GGML_UNUSED(ne10);
+    GGML_UNUSED(ne13);
+}
+
+template<typename src_t, typename dst_t>
+static void set_rows_cuda(
+        const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+        const size_t nb01, const size_t nb02, const size_t nb03,
+        const size_t nb10, const size_t nb11, const size_t nb12,
+        const size_t nb1, const size_t nb2, const size_t nb3,
+        cudaStream_t stream) {
+
+    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
+    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
+    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
+    const dim3 grid_size(num_blocks);
+
+
+    const int64_t s01 = nb01/sizeof(src_t);
+    const int64_t s02 = nb02/sizeof(src_t);
+    const int64_t s03 = nb03/sizeof(src_t);
+    const int64_t s10 = nb10/sizeof(int64_t);
+    const int64_t s11 = nb11/sizeof(int64_t);
+    const int64_t s12 = nb12/sizeof(int64_t);
+    const int64_t s1  = nb1/sizeof(dst_t);
+    const int64_t s2  = nb2/sizeof(dst_t);
+    const int64_t s3  = nb3/sizeof(dst_t);
+
+    if (ne_total > 0) {
+        k_set_rows<<<grid_size, block_size, 0, stream>>>(
+            src0_d, src1_d, dst_d,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            s01, s02, s03,
+            s10, s11, s12,
+            s1, s2, s3);
+    }
+}
+
+
+void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I64);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const float * src0_d   = (const float *)src0->data;
+    const int64_t * src1_d = (const int64_t *)src1->data;
+
+    cudaStream_t stream = ctx.stream();
+
+
+
+    if (dst->type == GGML_TYPE_F32) {
+        set_rows_cuda(
+            src0_d, src1_d, (float*)dst->data,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            nb01, nb02, nb03,
+            nb10, nb11, nb12,
+            nb1, nb2, nb3,
+            stream
+        );
+    } else if (dst->type == GGML_TYPE_F16) {
+        set_rows_cuda(
+            src0_d, src1_d, (half*)dst->data,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            nb01, nb02, nb03,
+            nb10, nb11, nb12,
+            nb1, nb2, nb3,
+            stream
+        );
+    } else if (dst->type == GGML_TYPE_BF16) {
+        set_rows_cuda(
+            src0_d, src1_d, (nv_bfloat16*)dst->data,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            nb01, nb02, nb03,
+            nb10, nb11, nb12,
+            nb1, nb2, nb3,
+            stream
+        );
+    } else if (dst->type == GGML_TYPE_Q4_0) {
+        set_rows_cuda_quant<block_q4_0, QK4_0, quantize_f32_q4_0_block>(
+            src0_d, src1_d, (block_q4_0*)dst->data,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            nb01, nb02, nb03,
+            nb10, nb11, nb12,
+            nb1, nb2, nb3,
+            stream
+        );
+    } else if (dst->type == GGML_TYPE_Q4_1) {
+        set_rows_cuda_quant<block_q4_1, QK4_1, quantize_f32_q4_1_block>(
+            src0_d, src1_d, (block_q4_1*)dst->data,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            nb01, nb02, nb03,
+            nb10, nb11, nb12,
+            nb1, nb2, nb3,
+            stream
+        );
+    } else if (dst->type == GGML_TYPE_Q5_0) {
+        set_rows_cuda_quant<block_q5_0, QK5_0, quantize_f32_q5_0_block>(
+            src0_d, src1_d, (block_q5_0*)dst->data,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            nb01, nb02, nb03,
+            nb10, nb11, nb12,
+            nb1, nb2, nb3,
+            stream
+        );
+    } else if (dst->type == GGML_TYPE_Q5_1) {
+        set_rows_cuda_quant<block_q5_1, QK5_1, quantize_f32_q5_1_block>(
+            src0_d, src1_d, (block_q5_1*)dst->data,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            nb01, nb02, nb03,
+            nb10, nb11, nb12,
+            nb1, nb2, nb3,
+            stream
+        );
+    } else if (dst->type == GGML_TYPE_Q8_0) {
+        set_rows_cuda_quant<block_q8_0, QK8_0, quantize_f32_q8_0_block>(
+            src0_d, src1_d, (block_q8_0*)dst->data,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            nb01, nb02, nb03,
+            nb10, nb11, nb12,
+            nb1, nb2, nb3,
+            stream
+        );
+    } else if (dst->type == GGML_TYPE_IQ4_NL) {
+        set_rows_cuda_quant<block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
+            src0_d, src1_d, (block_iq4_nl*)dst->data,
+            ne00, ne01, ne02, ne03,
+            ne10, ne11, ne12, ne13,
+            nb01, nb02, nb03,
+            nb10, nb11, nb12,
+            nb1, nb2, nb3,
+            stream
+        );
+    } else {
+        GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
+    }
+}
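
Both kernels implement the same scatter: row `i01` of batch `(i02, i03)` in `src0` is written into the destination row named by the I64 index tensor `src1`, which is broadcast over the two outer dimensions via the modulo on `i11`/`i12`; only the element conversion differs (plain cast vs. block quantization). A simplified, fully contiguous F32→F32 host sketch of that mapping, as an illustration rather than the strided general case the kernels handle:

```cpp
#include <cstdint>

// Simplified SET_ROWS reference: src0 is [ne00, ne01, ne02, ne03] (contiguous),
// src1 is [ne01, ne11, ne12] row indices, dst is [ne00, n_dst_rows, ne02, ne03].
static void set_rows_f32_ref(const float * src0, const int64_t * src1, float * dst,
                             int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
                             int64_t ne11, int64_t ne12, int64_t n_dst_rows) {
    for (int64_t i03 = 0; i03 < ne03; ++i03) {
        for (int64_t i02 = 0; i02 < ne02; ++i02) {
            for (int64_t i01 = 0; i01 < ne01; ++i01) {
                const int64_t i11 = i02 % ne11;  // broadcast index tensor over dim 2
                const int64_t i12 = i03 % ne12;  // broadcast index tensor over dim 3
                const int64_t dst_row = src1[i01 + ne01 * (i11 + ne11 * i12)];
                const float * src_row = src0 + ne00 * (i01 + ne01 * (i02 + ne02 * i03));
                float       * dst_ptr = dst  + ne00 * (dst_row + n_dst_rows * (i02 + ne02 * i03));
                for (int64_t i00 = 0; i00 < ne00; ++i00) {
                    dst_ptr[i00] = src_row[i00];
                }
            }
        }
    }
}
```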
diff --git a/ggml/src/ggml-cuda/set-rows.cuh b/ggml/src/ggml-cuda/set-rows.cuh
new file mode 100644
index 00000000000..c140c0873c8
--- /dev/null
+++ b/ggml/src/ggml-cuda/set-rows.cuh
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "common.cuh"
+
+#define CUDA_SET_ROWS_BLOCK_SIZE 256
+
+void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/softcap.cu b/ggml/src/ggml-cuda/softcap.cu
new file mode 100644
index 00000000000..40dfe45d65c
--- /dev/null
+++ b/ggml/src/ggml-cuda/softcap.cu
@@ -0,0 +1,34 @@
+#include "softcap.cuh"
+
+static __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = tanhf(scale * x[i]) * softcap;
+}
+
+static void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE;
+    softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, scale, softcap, k);
+}
+
+// fused GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src) {
+    const ggml_tensor * src0 = src->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    float scale;
+    float softcap;
+    memcpy(&scale,   (float *) src->op_params + 0, sizeof(float));
+    memcpy(&softcap, (float *) dst->op_params + 0, sizeof(float));
+
+    softcap_f32_cuda(src0_d, dst_d, scale, softcap, ggml_nelements(src0), stream);
+}
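
The fused op is simply `dst[i] = softcap * tanh(scale * x[i])`, the logit soft-capping pattern that squashes values smoothly into `(-softcap, softcap)`; `scale` is read from the first SCALE node (`src`) and `softcap` from the second (`dst`), as the two memcpy calls above show. A trivial host reference of the element-wise math:

```cpp
#include <cmath>

// CPU reference of the fused GGML_OP_SCALE -> tanh -> GGML_OP_SCALE chain.
static void softcap_f32_ref(const float * x, float * dst, float scale, float softcap, int k) {
    for (int i = 0; i < k; ++i) {
        dst[i] = tanhf(scale * x[i]) * softcap;
    }
}
```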
diff --git a/ggml/src/ggml-cuda/softcap.cuh b/ggml/src/ggml-cuda/softcap.cuh
new file mode 100644
index 00000000000..6d34fb2bee4
--- /dev/null
+++ b/ggml/src/ggml-cuda/softcap.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_SOFTCAP_BLOCK_SIZE 256
+
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src);
diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu
index 14543e978cf..eeacde0bdb1 100644
--- a/ggml/src/ggml-cuda/softmax.cu
+++ b/ggml/src/ggml-cuda/softmax.cu
@@ -45,7 +45,7 @@ struct soft_max_params {
 #endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
-        const float * x, const T * mask, float * dst, const soft_max_params p) {
+        const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
     const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
 
     const int tid  = threadIdx.x;
@@ -77,7 +77,7 @@ static __global__ void soft_max_f32(
     // shared memory buffer to cache values between iterations:
     float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
 
-    float max_val = -INFINITY;
+    float max_val = sinks ? sinks[i02] : -INFINITY;
 
 #pragma unroll
     for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -143,6 +143,10 @@ static __global__ void soft_max_f32(
         tmp = warp_reduce_sum(tmp);
     }
 
+    if (sinks) {
+        tmp += expf(sinks[i02] - max_val);
+    }
+
     const float inv_sum = 1.0f / tmp;
 
 #pragma unroll
@@ -183,7 +187,7 @@ static __global__ void soft_max_back_f32(
 }
 
 template<int... Ns, typename T>
-static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
+static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
                              const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
 {
     const int id       = ggml_cuda_get_device();
@@ -196,7 +200,7 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst
         if (p.ncols == ncols) {
             CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
             soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, mask, dst, p);
+                (x, mask, sinks, dst, p);
             return true;
         }
         return false;
@@ -209,12 +213,12 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst
 
     //default case
     CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
-    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
+    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
 }
 
 
 template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
     const int64_t ncols_x = params.ncols;
 
@@ -230,10 +234,10 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
 
 
     if (nbytes_shared <= smpbo) {
-        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
+        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
     }
 }
 
@@ -249,9 +253,11 @@ static void soft_max_back_f32_cuda(
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
 
     const float * src0_d = (const float *) src0->data;
     const void  * src1_d = src1 ? (const void *) src1->data : nullptr;
+    const void  * src2_d = src2 ? (const void *) src2->data : nullptr;
     float       *  dst_d = (float *) dst->data;
 
     cudaStream_t stream = ctx.stream();
@@ -309,9 +315,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     params.m1 = m1;
 
     if (use_f16) {
-        soft_max_f32_cuda(src0_d, (const half  *) src1_d, dst_d, params, stream);
+        soft_max_f32_cuda(src0_d, (const half  *) src1_d, (const float *) src2_d, dst_d, params, stream);
     } else {
-        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
     }
 }
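
The new optional `src2` carries one "attention sink" logit per head: it seeds the running maximum and adds `exp(sink - max)` to the denominator, but produces no output column, so the visible probabilities sum to slightly less than one. A scalar reference of that idea for a single row, ignoring the mask, scale and ALiBi handling of the full kernel (illustrative helper, not part of the patch):

```cpp
#include <cmath>

// Softmax over one row with an optional sink logit that only enlarges the normalizer.
static void soft_max_with_sink_ref(const float * x, float * dst, int ncols, const float * sink) {
    float max_val = sink ? *sink : -INFINITY;
    for (int i = 0; i < ncols; ++i) {
        max_val = fmaxf(max_val, x[i]);
    }
    float denom = sink ? expf(*sink - max_val) : 0.0f;
    for (int i = 0; i < ncols; ++i) {
        dst[i] = expf(x[i] - max_val);
        denom += dst[i];
    }
    const float inv = 1.0f / denom;
    for (int i = 0; i < ncols; ++i) {
        dst[i] *= inv;
    }
}
```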
 
diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu
index c9184398b42..dc9a7d58d05 100644
--- a/ggml/src/ggml-cuda/ssm-scan.cu
+++ b/ggml/src/ggml-cuda/ssm-scan.cu
@@ -1,87 +1,117 @@
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+#define USE_CUB
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+
+#ifdef USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // USE_CUB
+
 #include "ssm-scan.cuh"
 
-template <size_t splitD, size_t N>
-__global__ void __launch_bounds__(splitD, 2)
-    ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
-                 const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
+// We would like to keep pragma unroll for cases where L_template is not 0,
+// so we suppress the clang transformation warning.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+template <size_t splitD, size_t N, size_t L_template>
+__global__ void __launch_bounds__(splitD, 1)
+    ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
+                 const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
                  const int32_t * __restrict__ src6, float * __restrict__ dst,
                  const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
                  const int src2_nb1, const int src2_nb2, const int src3_nb1,
                  const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
-                 const int64_t s_off, const int64_t d_inner, const int64_t L) {
-
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-    const int bidx = blockIdx.x;  // split along B (sequences)
-    const int bidy = blockIdx.y;  // split along D (d_inner)
-    const int tid  = threadIdx.x;
-    const int wid  = tid / 32;
-    const int wtid = tid % 32;
-
-    extern __shared__ float smem[];
-    const int               stride_sA  = N + 1;
-    const int               stride_ss0 = N + 1;
-    float *                 smem_A     = smem;
-    float *                 smem_s0    = smem_A + splitD * stride_sA;
-
-    const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
-    const float * x_block  = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
-    const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
-    const float * A_block  = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
-    const float * B_block  = (const float *) ((const char *) src4 + (bidx * src4_nb3));
-    const float * C_block  = (const float *) ((const char *) src5 + (bidx * src5_nb3));
-    float *       y_block  = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
-    float *       s_block  = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
-
-    const int stride_s0 = src0_nb2 / sizeof(float);
-    const int stride_x  = src1_nb2 / sizeof(float);
+                 const int64_t s_off, const int64_t d_inner, const int64_t L_param)
+{
+    const size_t L = L_template == 0 ? L_param : L_template;
+    const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
+    const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
+    const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
+    const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
+    const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3));
+    const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3));
+    float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));
+    float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2);
+
+    const int stride_x = src1_nb2 / sizeof(float);
     const int stride_dt = src2_nb1 / sizeof(float);
-    const int stride_A  = src3_nb1 / sizeof(float);
-    const int stride_B  = src4_nb2 / sizeof(float);
-    const int stride_C  = src5_nb2 / sizeof(float);
-    const int stride_s  = stride_s0;
-    const int stride_y  = d_inner;
+    const int stride_B = src4_nb2 / sizeof(float);
+    const int stride_C = src5_nb2 / sizeof(float);
+    const int stride_y = d_inner;
 
-    // can N not be 16? for example 32?
-    if (N == 16) {
-#pragma unroll
-        for (size_t i = 0; i < splitD / 4; i += 2) {
-            float value = A_block[(wid * warp_size + i) * stride_A + wtid];
-            // todo: bank conflict
-            // I am always confused with how to use the swizzling method to solve
-            // bank conflit. Hoping somebody can tell me.
-            smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
-        }
+    float regA[N];
+    float regs0[N];
+
+    __shared__ float smemB[N];
+    __shared__ float smemC[N];
+
+#ifdef USE_CUB
+    using BlockLoad = cub::BlockLoad<float, splitD, N>;
+    using BlockStore = cub::BlockStore<float, splitD, N>;
+
+    union CubTempStorage {
+        typename BlockLoad::TempStorage load_temp;
+        typename BlockStore::TempStorage store_temp;
+    };
+    __shared__ CubTempStorage cub_temp_storage;
+
+    BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
+    BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
+#else
+    const int stride_s0 = src0_nb2 / sizeof(float);
+    const int stride_A = src3_nb1 / sizeof(float);
 #pragma unroll
-        for (size_t i = 0; i < splitD / 4; i += 2) {
-            float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
-            smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
-        }
+    for (size_t n = 0; n < N; ++n)
+    {
+        regA[n] = A_block[threadIdx.x * stride_A + n];
+        regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
     }
+#endif
 
-    __syncthreads();
+#pragma unroll
+    for (size_t i = 0; i < L; i++)
+    {
+        if (threadIdx.x < N)
+        {
+            smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
+            smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
+        }
+        __syncthreads();
 
-    for (int64_t i = 0; i < L; i++) {
-        float dt_soft_plus = dt_block[i * stride_dt + tid];
-        if (dt_soft_plus <= 20.0f) {
-            dt_soft_plus = log1pf(exp(dt_soft_plus));
+        float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
+        if (dt_soft_plus <= 20.0f)
+        {
+            dt_soft_plus = log1pf(expf(dt_soft_plus));
         }
-        float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
+        float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;
+
         float sumf = 0.0f;
 #pragma unroll
-        for (size_t j = 0; j < N; j++) {
-            float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
-                          (B_block[i * stride_B + j] * x_dt);
-            sumf += state * C_block[i * stride_C + j];
-            if (i == L - 1) {
-                s_block[tid * stride_s + j] = state;
-            } else {
-                smem_s0[tid * stride_ss0 + j] = state;
-            }
+        for (size_t n = 0; n < N; n++)
+        {
+            float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
+            sumf += state * smemC[n];
+            regs0[n] = state;
         }
-        __syncthreads();
-        y_block[i * stride_y + tid] = sumf;
+        y_block[i * stride_y + threadIdx.x] = sumf;
     }
+
+#ifdef USE_CUB
+    BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
+#else
+    const int stride_s = stride_s0;
+#pragma unroll
+    for (size_t n = 0; n < N; ++n)
+    {
+        s_block[threadIdx.x * stride_s + n] = regs0[n];
+    }
+#endif
 }
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
 
 // assumes as many threads as d_state
 template 
@@ -201,11 +231,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
                               const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                               const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                               cudaStream_t stream) {
+    const int threads = 128;
     // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
     if (src3_nb1 == sizeof(float)) {
         // Mamba-2
         if (d_state == 128) {
-            const int threads = 128;
             GGML_ASSERT(d_state % threads == 0);
             // NOTE: can be any power of two between 4 and 64
             const int splitH = 16;
@@ -229,7 +259,6 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
             GGML_ABORT("doesn't support d_state!=(128 or 256).");
         }
     } else {
-        const int threads = 128;
         // Mamba-1
         GGML_ASSERT(n_head % threads == 0);
         GGML_ASSERT(head_dim == 1);
@@ -237,10 +266,63 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
         const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
         const int  smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
         if (d_state == 16) {
-            ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
-                src0, src1, src2, src3, src4, src5, src6, dst,
+            switch (n_tok)
+            {
+            case 1:
+                ssm_scan_f32<128, 16, 1><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 2:
+                ssm_scan_f32<128, 16, 2><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 3:
+                ssm_scan_f32<128, 16, 3><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 4:
+                ssm_scan_f32<128, 16, 4><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 5:
+                ssm_scan_f32<128, 16, 5><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
                 src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                 src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 6:
+                ssm_scan_f32<128, 16, 6><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 7:
+                ssm_scan_f32<128, 16, 7><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            case 8:
+                ssm_scan_f32<128, 16, 8><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            default:
+                ssm_scan_f32<128, 16, 0><<<blocks, threads, smem_size, stream>>>(
+                    src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+                break;
+            }
         } else {
             GGML_ABORT("doesn't support d_state!=16.");
         }
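
For orientation, the recurrence the Mamba-1 path evaluates per channel is `state ← state · exp(softplus(dt) · A) + B · (x · softplus(dt))` with output `y = Σ state · C`; the rewrite keeps the state in registers (`regs0`) instead of shared memory and specializes on `L_template` so short sequences get a fully unrolled token loop. A scalar host sketch of one channel with state size N, under the simplifying assumption of contiguous per-token inputs (the full kernel handles the strided layout):

```cpp
#include <cmath>

// Scalar reference of the selective-scan recurrence for one channel over L tokens.
//   state_n <- state_n * exp(softplus(dt) * A_n) + B_n * (x * softplus(dt))
//   y       <- sum_n state_n * C_n
template <int N>
static void ssm_scan_ref(float state[N], const float * A, const float * B, const float * C,
                         const float * x, const float * dt, float * y, int L) {
    for (int i = 0; i < L; ++i) {
        const float dts  = dt[i] <= 20.0f ? log1pf(expf(dt[i])) : dt[i];  // guarded softplus
        const float x_dt = x[i] * dts;
        float sum = 0.0f;
        for (int n = 0; n < N; ++n) {
            state[n] = state[n] * expf(dts * A[n]) + B[i * N + n] * x_dt;
            sum     += state[n] * C[i * N + n];
        }
        y[i] = sum;
    }
}
```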
diff --git a/ggml/src/ggml-cuda/sum.cu b/ggml/src/ggml-cuda/sum.cu
index eb3d7cdba98..c56257b4406 100644
--- a/ggml/src/ggml-cuda/sum.cu
+++ b/ggml/src/ggml-cuda/sum.cu
@@ -1,19 +1,15 @@
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
-#define USE_CUB
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+#include "sum.cuh"
+#include "sumrows.cuh"
 
-#ifdef USE_CUB
+#ifdef GGML_CUDA_USE_CUB
 #include <cub/cub.cuh>
 using namespace cub;
-#endif // USE_CUB
-
-#include "sumrows.cuh"
-#include "sum.cuh"
+#endif  // GGML_CUDA_USE_CUB
 
 #include <cstdint>
 
 void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
-#ifdef USE_CUB
+#ifdef GGML_CUDA_USE_CUB
     size_t tmp_size = 0;
     DeviceReduce::Sum(nullptr,       tmp_size, x, dst, ne, stream);
     ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
@@ -23,7 +19,7 @@ void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int
     // For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14.
     sum_rows_f32_cuda(x, dst, ne, 1, stream);
     GGML_UNUSED(pool);
-#endif // USE_CUB
+#endif // GGML_CUDA_USE_CUB
 }
 
 void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
diff --git a/ggml/src/ggml-cuda/sumrows.cu b/ggml/src/ggml-cuda/sumrows.cu
index 2eee08fa073..4025771aadb 100644
--- a/ggml/src/ggml-cuda/sumrows.cu
+++ b/ggml/src/ggml-cuda/sumrows.cu
@@ -1,9 +1,17 @@
+#include "reduce_rows.cuh"
 #include "sumrows.cuh"
 
 void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const int  id  = ggml_cuda_get_device();
+    const int  nsm = ggml_cuda_info().devices[id].nsm;
     const dim3 block_nums(nrows, 1, 1);
-    reduce_rows_f32<false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    if ((nrows / nsm) < 2) {
+        const dim3 block_dims(512, 1, 1);
+        reduce_rows_f32<false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    } else {
+        const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
+        reduce_rows_f32<false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    }
 }
 
 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -19,8 +27,17 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
-    const dim3 block_dims(WARP_SIZE, 1, 1);
     const dim3 block_nums(nrows, 1, 1);
 
-    reduce_rows_f32<false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    const int id  = ggml_cuda_get_device();
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    if ((nrows / nsm) < 2) {
+        // Increase num threads to 512 for small nrows to better hide the latency
+        const dim3 block_dims(512, 1, 1);
+        reduce_rows_f32<false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    } else {
+        // Enough active SMs to hide latency, use smaller blocks to allow better scheduling
+        const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
+        reduce_rows_f32<false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+    }
 }
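
The new launch heuristic picks large 512-thread blocks when there are too few rows to occupy every SM (latency is then hidden inside a single row), and small 32- or 128-thread blocks otherwise so that more row-blocks can be resident at once. The decision written out as a host helper, purely for clarity (the helper name is ours, not part of the code):

```cpp
#include <cstdint>

// Mirror of the block-size choice used by sum_rows_f32_cuda / ggml_cuda_op_sum_rows.
static int sum_rows_block_threads(int64_t nrows, int nsm, int64_t ncols) {
    if (nrows / nsm < 2) {
        return 512;                  // few rows: widen the block to hide latency
    }
    return ncols < 1024 ? 32 : 128;  // many rows: small blocks schedule better
}
```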
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu
new file mode 100644
index 00000000000..c14624c52ca
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_MXFP4);
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index f9c7b83c40d..5aff8a876af 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -83,6 +83,10 @@ static __device__ __forceinline__ float op_log(float x) {
     return logf(x);
 }
 
+static __device__ __forceinline__ float op_elu(float x) {
+    return (x > 0.f) ? x : expm1f(x);
+}
+
 template <float (*op)(float), typename T>
 static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -196,6 +200,9 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_log>(ctx, dst);
 }
 
+void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_elu>(ctx, dst);
+}
 /* gated ops */
 
 template <float (*op)(float), typename T>
@@ -293,6 +300,81 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     ggml_cuda_op_unary_gated(ctx, dst);
 }
 
+// swiglu_oai
+
+template <typename T>
+static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
+    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    // perform base op and multiply with gate (either offset in same tensor or a separate one)
+    const int64_t j0 = (i / n) * o0 + (i % n);
+    const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+
+    float xi = x[j0];
+    float gi = g[j1];
+    xi = fminf(xi, limit);
+    gi = fmaxf(fminf(gi, limit), -limit);
+
+    float out_glu = xi / (1.0f + expf(-xi * alpha));
+    out_glu = out_glu * (1.0f + gi);
+
+    dst[i] = out_glu;
+}
+
+template <typename T>
+static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
+    const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
+    swiglu_oai_kernel<<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
+}
+
+void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    void * src0_d = src0->data;
+    void * src1_d = src1 ? src1->data : src0->data;
+    const int64_t src0_o = src0->nb[1];
+    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+    void * dst_d = dst->data;
+    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+        GGML_ASSERT(src1->ne[0] == nc);
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    //const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+    const float alpha = ggml_get_op_params_f32(dst, 2);
+    const float limit = ggml_get_op_params_f32(dst, 3);
+
+    float * src0_p = (float *) src0_d;
+    float * src1_p = (float *) src1_d;
+
+    if (!src1) {
+        src0_p += swapped ? nc : 0;
+        src1_p += swapped ? 0 : nc;
+    }
+
+    swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
+}
+
 /* silu_back */
 
 static __device__ __forceinline__ float op_silu_back(float grad, float x) {
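
`swiglu_oai` clamps both halves, applies a sigmoid with a tunable `alpha` on the value path, and scales by `1 + gate`. The element-wise math, with the half-selection/offset bookkeeping of the wrapper omitted (illustrative helper, not part of the patch):

```cpp
#include <cmath>

// CPU reference of one swiglu_oai element: x is the value half, g the gate half.
static float swiglu_oai_ref(float x, float g, float alpha, float limit) {
    x = fminf(x, limit);                                // cap the value path from above
    g = fmaxf(fminf(g, limit), -limit);                 // clamp the gate symmetrically
    const float glu = x / (1.0f + expf(-x * alpha));    // x * sigmoid(alpha * x)
    return glu * (1.0f + g);
}
```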
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index 289d690e5cf..da3caf1d896 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -59,12 +59,16 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index ba195e1d100..d8f9aa5ba62 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -1,8 +1,20 @@
 #pragma once
 
 #include "common.cuh"
+
 #include <cstdint>
 
+static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
+    const uint8_t * x8 = (const uint8_t *) x;
+
+    int x32  = x8[4*i32 + 0] <<  0;
+    x32     |= x8[4*i32 + 1] <<  8;
+    x32     |= x8[4*i32 + 2] << 16;
+    x32     |= x8[4*i32 + 3] << 24;
+
+    return x32;
+}
+
 static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
     const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
 
@@ -16,6 +28,20 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
     return ((const int *) x)[i32]; // assume at least 4 byte alignment
 }
 
+static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
+    const int      q0_32  = (q4 >> 0) & 0x0F0F0F0F;
+    const int8_t * q0_8   = (const int8_t *) &q0_32;
+    const char4    val0_8 = make_char4(
+        table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]);
+
+    const int      q1_32  = (q4 >> 4) & 0x0F0F0F0F;
+    const int8_t * q1_8   = (const int8_t *) &q1_32;
+    const char4    val1_8 = make_char4(
+        table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
+
+    return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
+}
+
 // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
 // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
 
@@ -211,6 +237,30 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_
     return d8_1*sumf;
 }
 
+#define VDR_MXFP4_Q8_1_MMVQ 2
+#define VDR_MXFP4_Q8_1_MMQ  4
+
+static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;
+
+    const int * q8 = (const int *) bq8_1->qs + iqs;
+
+    int sumi = 0;
+#pragma unroll
+    for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
+        const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+
+        sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
+        sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
+    }
+
+    const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds);
+    return d * sumi;
+}
+
 #define VDR_Q2_K_Q8_1_MMVQ 1
 #define VDR_Q2_K_Q8_1_MMQ  4
 
@@ -1068,20 +1118,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
     return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
 }
 
-static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
-    const int      q0_32  = (q4 >> 0) & 0x0F0F0F0F;
-    const int8_t * q0_8   = (const int8_t *) &q0_32;
-    const char4    val0_8 = make_char4(
-        kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
-
-    const int      q1_32  = (q4 >> 4) & 0x0F0F0F0F;
-    const int8_t * q1_8   = (const int8_t *) &q1_32;
-    const char4    val1_8 = make_char4(
-        kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
-
-    return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
-}
-
 #define VDR_IQ4_NL_Q8_1_MMVQ 2
 #define VDR_IQ4_NL_Q8_1_MMQ  4
 
@@ -1096,7 +1132,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
 #pragma unroll
     for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
         const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
-        const int2 v = get_int_from_table_16(aux_q4);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
 
         sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
         sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
@@ -1118,7 +1154,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
 #pragma unroll
     for (int j = 0; j < 4; ++j) {
         const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
-        const int2 v = get_int_from_table_16(aux_q4);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
 
         const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
         const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
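
Passing the value table into `get_int_from_table_16` lets IQ4_NL (`kvalues_iq4nl`) and the new MXFP4 path (`kvalues_mxfp4`) share the nibble-unpacking helper; the CUDA version repacks the eight looked-up int8 values into two 32-bit words so they can feed `dp4a`. A scalar sketch of the same unpacking:

```cpp
#include <cstdint>

// Scalar equivalent of get_int_from_table_16: one 32-bit word carries 8 packed
// 4-bit indices (the low nibble of each byte first, then the high nibbles),
// each mapped through a 16-entry signed value table.
static void unpack_nibbles_16(uint32_t q4, const int8_t table[16], int8_t out[8]) {
    for (int b = 0; b < 4; ++b) {
        out[b]     = table[(q4 >> (8 * b))     & 0x0F];   // low nibble of byte b
        out[b + 4] = table[(q4 >> (8 * b + 4)) & 0x0F];   // high nibble of byte b
    }
}
```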
diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h
index 1746b073203..3b3086778ee 100644
--- a/ggml/src/ggml-cuda/vendors/cuda.h
+++ b/ggml/src/ggml-cuda/vendors/cuda.h
@@ -6,6 +6,10 @@
 #include 
 #include 
 
+#if CUDART_VERSION >= 12050
+#include 
+#endif // CUDART_VERSION >= 12050
+
 #if CUDART_VERSION < 11020
 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
 #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 184d445f5c0..6e9c67aca09 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -1,14 +1,10 @@
 #pragma once
 
-#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
+#define HIP_DISABLE_WARP_SYNC_BUILTINS 1
 #include <hip/hip_runtime.h>
 #include <hipblas/hipblas.h>
 #include <hip/hip_fp16.h>
-#include <hip/hip_bfloat16.h>
-#ifdef __HIP_PLATFORM_AMD__
-// for rocblas_initialize()
-#include "rocblas/rocblas.h"
-#endif // __HIP_PLATFORM_AMD__
+#include <hip/hip_bf16.h>
 
 #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
@@ -139,7 +135,7 @@
 #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
 #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
 
-#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 70000000
+#if HIP_VERSION >= 60500000
 #define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
 #define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
 #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
@@ -151,7 +147,11 @@
 #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
 #define cublasComputeType_t hipblasDatatype_t
 #define cudaDataType_t hipblasDatatype_t
-#endif
+#endif // HIP_VERSION >= 60500000
+
+#if !defined(__HIP_PLATFORM_AMD__)
+#error "The HIP backend supports only AMD targets"
+#endif // !defined(__HIP_PLATFORM_AMD__)
 
 #define __CUDA_ARCH__ 1300
 
@@ -160,15 +160,26 @@
 #endif
 
 #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
-#define CDNA
+#define CDNA // For the entire family
+#endif
+
+#if defined(__gfx942__)
+#define CDNA3
+#endif
+
+#if defined(__gfx90a__)
+#define CDNA2
+#endif
+
+#if defined(__gfx908__)
+#define CDNA1
 #endif
 
 #if defined(__GFX12__)
 #define RDNA4
 #endif
 
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
-    defined(__gfx1150__) || defined(__gfx1151__)
+#if defined(__GFX11__)
 #define RDNA3
 #endif
 
@@ -185,7 +196,8 @@
     #define __has_builtin(x) 0
 #endif
 
-typedef hip_bfloat16 nv_bfloat16;
+typedef __hip_bfloat16 nv_bfloat16;
+typedef __hip_bfloat162 nv_bfloat162;
 
 typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
 typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
@@ -236,17 +248,3 @@ static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigne
     }
     return c;
 }
-
-#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
-// __shfl_xor() for half2 was added in ROCm 5.6
-static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
-    typedef union half2_b32 {
-        half2 val;
-        int   b32;
-    } half2_b32_t;
-    half2_b32_t tmp;
-    tmp.val = var;
-    tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
-    return tmp.val;
-}
-#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h
index 937779a90af..8c55a2e4e56 100644
--- a/ggml/src/ggml-cuda/vendors/musa.h
+++ b/ggml/src/ggml-cuda/vendors/musa.h
@@ -13,7 +13,7 @@
 #define CUBLAS_OP_N MUBLAS_OP_N
 #define CUBLAS_OP_T MUBLAS_OP_T
 #define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
-#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
+#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH
 #define CUDA_R_16F  MUSA_R_16F
 #define CUDA_R_16BF MUSA_R_16BF
 #define CUDA_R_32F  MUSA_R_32F
@@ -29,7 +29,7 @@
 #define cublasSgemm mublasSgemm
 #define cublasStatus_t mublasStatus_t
 #define cublasOperation_t mublasOperation_t
-#define cublasGetStatusString mublasStatus_to_string
+#define cublasGetStatusString mublasGetStatusString
 #define cudaDataType_t musaDataType_t
 #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
 #define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
@@ -137,4 +137,5 @@
 #define cudaStreamEndCapture musaStreamEndCapture
 #define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
 
-typedef mt_bfloat16 nv_bfloat16;
+typedef __mt_bfloat16 nv_bfloat16;
+typedef __mt_bfloat162 nv_bfloat162;
diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt
index e29df98560e..d327b90cceb 100644
--- a/ggml/src/ggml-hip/CMakeLists.txt
+++ b/ggml/src/ggml-hip/CMakeLists.txt
@@ -46,8 +46,8 @@ if (GGML_HIP_ROCWMMA_FATTN)
     endif()
 endif()
 
-if (${hip_VERSION} VERSION_LESS 5.5)
-    message(FATAL_ERROR "At least ROCM/HIP V5.5 is required")
+if (${hip_VERSION} VERSION_LESS 6.1)
+    message(FATAL_ERROR "At least ROCM/HIP V6.1 is required")
 endif()
 
 message(STATUS "HIP and hipBLAS found")
@@ -113,10 +113,18 @@ if (GGML_HIP_ROCWMMA_FATTN)
     add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
 endif()
 
+if (NOT GGML_HIP_MMQ_MFMA)
+    add_compile_definitions(GGML_HIP_NO_MMQ_MFMA)
+endif()
+
 if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0)
     add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
 endif()
 
+if (GGML_HIP_EXPORT_METRICS)
+    set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps")
+endif()
+
 if (NOT GGML_CUDA_FA)
     add_compile_definitions(GGML_CUDA_NO_FA)
 endif()
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 4972558c98b..19a7adb2d10 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -73,6 +73,22 @@ static inline int ggml_up(int n, int m) {
     return (n + m - 1) & ~(m - 1);
 }
 
+// TODO: move to ggml.h?
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
+    if (a->type != b->type) {
+        return false;
+    }
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (a->ne[i] != b->ne[i]) {
+            return false;
+        }
+        if (a->nb[i] != b->nb[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
 //
 // logging
 //
@@ -394,6 +410,67 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
 #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
 #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
 
+static inline float ggml_e8m0_to_fp32(uint8_t x) {
+    uint32_t bits;  // Stores the raw bit representation of the float
+
+    // Handle special case for minimum exponent (denormalized float)
+    if (x == 0) {
+        // Bit pattern for 2^(-127):
+        // - Sign bit: 0 (positive)
+        // - Exponent: 0 (denormalized number)
+        // - Mantissa: 0x400000 (0.5 in fractional form)
+        // Value = 0.5 * 2^(-126) = 2^(-127)
+        bits = 0x00400000;
+    }
+    // note: disabled as we don't need to handle NaNs
+    //// Handle special case for NaN (all bits set)
+    //else if (x == 0xFF) {
+    //    // Standard quiet NaN pattern:
+    //    // - Sign bit: 0
+    //    // - Exponent: all 1s (0xFF)
+    //    // - Mantissa: 0x400000 (quiet NaN flag)
+    //    bits = 0x7FC00000;
+    //}
+    // Normalized values (most common case)
+    else {
+        // Construct normalized float by shifting exponent into position:
+        // - Exponent field: 8 bits (positions 30-23)
+        // - Mantissa: 0 (implicit leading 1)
+        // Value = 2^(x - 127)
+        bits = (uint32_t) x << 23;
+    }
+
+    float result;  // Final float value
+                   // Safely reinterpret bit pattern as float without type-punning issues
+    memcpy(&result, &bits, sizeof(float));
+    return result;
+}
+
+// Equal to ggml_e8m0_to_fp32/2
+// Useful with MXFP4 quantization since the E2M1 values are doubled
+static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
+    uint32_t bits;
+
+    // For x < 2: use precomputed denormal patterns
+    if (x < 2) {
+        // 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
+        bits = 0x00200000 << x;
+    }
+    // For x >= 2: normalized exponent adjustment
+    else {
+        // 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)
+        bits = (uint32_t)(x - 1) << 23;
+    }
+    // Note: NaNs are not handled here
+
+    float result;
+    memcpy(&result, &bits, sizeof(float));
+    return result;
+}
+
+#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
+#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
+
 /**
  * Converts brain16 to float32.
  *
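
A quick sanity check of the E8M0 helpers: the byte is a bare biased exponent, so `ggml_e8m0_to_fp32(x)` equals `2^(x-127)` (with `x == 0` decoding to the denormal `2^-127`), and the `_half` variant is exactly half of that, which pairs with MXFP4 scales whose element values are stored doubled. A small test sketch, assuming the two inline functions above are in scope:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Sanity checks for the E8M0 decoding helpers from ggml-impl.h.
static void check_e8m0(void) {
    assert(ggml_e8m0_to_fp32(127) == 1.0f);                 // 2^(127-127)
    assert(ggml_e8m0_to_fp32(128) == 2.0f);                 // 2^1
    assert(ggml_e8m0_to_fp32(0)   == ldexpf(1.0f, -127));   // denormal case
    for (int x = 0; x < 255; ++x) {                         // 255 (NaN) intentionally skipped
        assert(ggml_e8m0_to_fp32_half((uint8_t) x) == 0.5f * ggml_e8m0_to_fp32((uint8_t) x));
    }
}
```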
diff --git a/ggml/src/ggml-kompute/CMakeLists.txt b/ggml/src/ggml-kompute/CMakeLists.txt
deleted file mode 100644
index c9109d5e8ee..00000000000
--- a/ggml/src/ggml-kompute/CMakeLists.txt
+++ /dev/null
@@ -1,166 +0,0 @@
-
-find_package(Vulkan COMPONENTS glslc REQUIRED)
-find_program(glslc_executable NAMES glslc HINTS Vulkan::glslc)
-
-if (NOT glslc_executable)
-    message(FATAL_ERROR "glslc not found")
-endif()
-
-ggml_add_backend_library(ggml-kompute
-                         ggml-kompute.cpp
-                         ../../include/ggml-kompute.h
-                        )
-
-target_link_libraries(ggml-kompute PRIVATE ggml-base kompute)
-target_include_directories(ggml-kompute PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
-
-add_compile_definitions(VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1)
-
-function(compile_shader)
-    set(options)
-    set(oneValueArgs)
-    set(multiValueArgs SOURCES)
-    cmake_parse_arguments(compile_shader "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
-    foreach(source ${compile_shader_SOURCES})
-        get_filename_component(filename ${source} NAME)
-        set(spv_file ${filename}.spv)
-        add_custom_command(
-            OUTPUT ${spv_file}
-            DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${source}
-            ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/common.comp
-            ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_getrows.comp
-            ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n_pre.comp
-            ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n.comp
-            COMMAND ${glslc_executable} --target-env=vulkan1.2 -o ${spv_file} ${CMAKE_CURRENT_SOURCE_DIR}/${source}
-            COMMENT "Compiling ${source} to ${spv_file}"
-            )
-
-        get_filename_component(RAW_FILE_NAME ${spv_file} NAME)
-        set(FILE_NAME "shader${RAW_FILE_NAME}")
-        string(REPLACE ".comp.spv" ".h" HEADER_FILE ${FILE_NAME})
-        string(TOUPPER ${HEADER_FILE} HEADER_FILE_DEFINE)
-        string(REPLACE "." "_" HEADER_FILE_DEFINE "${HEADER_FILE_DEFINE}")
-        set(OUTPUT_HEADER_FILE "${HEADER_FILE}")
-        message(STATUS "${HEADER_FILE} generating ${HEADER_FILE_DEFINE}")
-        if(CMAKE_GENERATOR MATCHES "Visual Studio")
-            add_custom_command(
-                OUTPUT ${OUTPUT_HEADER_FILE}
-                COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
-                COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
-                COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
-                COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
-                COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
-                COMMAND ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
-                COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
-                COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
-                DEPENDS ${spv_file} xxd
-                COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd"
-                )
-        else()
-            add_custom_command(
-                OUTPUT ${OUTPUT_HEADER_FILE}
-                COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
-                COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
-                COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
-                COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
-                COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
-                COMMAND ${CMAKE_BINARY_DIR}/bin/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
-                COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
-                COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
-                DEPENDS ${spv_file} xxd
-                COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/xxd"
-                )
-        endif()
-    endforeach()
-endfunction()
-
-if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
-    message(STATUS "Kompute found")
-    set(KOMPUTE_OPT_LOG_LEVEL Error CACHE STRING "Kompute log level")
-    add_subdirectory(kompute)
-
-    # Compile our shaders
-    compile_shader(SOURCES
-        kompute-shaders/op_scale.comp
-        kompute-shaders/op_scale_8.comp
-        kompute-shaders/op_add.comp
-        kompute-shaders/op_addrow.comp
-        kompute-shaders/op_mul.comp
-        kompute-shaders/op_silu.comp
-        kompute-shaders/op_relu.comp
-        kompute-shaders/op_gelu.comp
-        kompute-shaders/op_softmax.comp
-        kompute-shaders/op_norm.comp
-        kompute-shaders/op_rmsnorm.comp
-        kompute-shaders/op_diagmask.comp
-        kompute-shaders/op_mul_mat_mat_f32.comp
-        kompute-shaders/op_mul_mat_f16.comp
-        kompute-shaders/op_mul_mat_q8_0.comp
-        kompute-shaders/op_mul_mat_q4_0.comp
-        kompute-shaders/op_mul_mat_q4_1.comp
-        kompute-shaders/op_mul_mat_q4_k.comp
-        kompute-shaders/op_mul_mat_q6_k.comp
-        kompute-shaders/op_getrows_f32.comp
-        kompute-shaders/op_getrows_f16.comp
-        kompute-shaders/op_getrows_q4_0.comp
-        kompute-shaders/op_getrows_q4_1.comp
-        kompute-shaders/op_getrows_q6_k.comp
-        kompute-shaders/op_rope_norm_f16.comp
-        kompute-shaders/op_rope_norm_f32.comp
-        kompute-shaders/op_rope_neox_f16.comp
-        kompute-shaders/op_rope_neox_f32.comp
-        kompute-shaders/op_cpy_f16_f16.comp
-        kompute-shaders/op_cpy_f16_f32.comp
-        kompute-shaders/op_cpy_f32_f16.comp
-        kompute-shaders/op_cpy_f32_f32.comp
-    )
-
-    # Create a custom target for our generated shaders
-    add_custom_target(generated_shaders DEPENDS
-        shaderop_scale.h
-        shaderop_scale_8.h
-        shaderop_add.h
-        shaderop_addrow.h
-        shaderop_mul.h
-        shaderop_silu.h
-        shaderop_relu.h
-        shaderop_gelu.h
-        shaderop_softmax.h
-        shaderop_norm.h
-        shaderop_rmsnorm.h
-        shaderop_diagmask.h
-        shaderop_mul_mat_mat_f32.h
-        shaderop_mul_mat_f16.h
-        shaderop_mul_mat_q8_0.h
-        shaderop_mul_mat_q4_0.h
-        shaderop_mul_mat_q4_1.h
-        shaderop_mul_mat_q4_k.h
-        shaderop_mul_mat_q6_k.h
-        shaderop_getrows_f32.h
-        shaderop_getrows_f16.h
-        shaderop_getrows_q4_0.h
-        shaderop_getrows_q4_1.h
-        shaderop_getrows_q6_k.h
-        shaderop_rope_norm_f16.h
-        shaderop_rope_norm_f32.h
-        shaderop_rope_neox_f16.h
-        shaderop_rope_neox_f32.h
-        shaderop_cpy_f16_f16.h
-        shaderop_cpy_f16_f32.h
-        shaderop_cpy_f32_f16.h
-        shaderop_cpy_f32_f32.h
-    )
-
-    # Create a custom command that depends on the generated_shaders
-    add_custom_command(
-        OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
-        COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
-        DEPENDS generated_shaders
-        COMMENT "Ensuring shaders are generated before compiling ggml-kompute.cpp"
-    )
-
-    # Add the stamp to the main sources to ensure dependency tracking
-    target_sources(ggml-kompute PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp)
-else()
-    message(WARNING "Kompute not found")
-endif()
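For context on the deleted `compile_shader()` helper: it compiled each `.comp` source to SPIR-V with `glslc` and embedded the blob via `xxd -i` as an `unsigned char` array in a generated `shaderop_*.h` header, which the backend source removed below then reinterprets as 32-bit SPIR-V words (see its `getSpirvShader()`). A minimal standalone C sketch of that reinterpretation, using a made-up `op_add_comp_spv` array in place of a generated one:

```c
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Stand-in for an xxd-generated array; real headers wrap the data in the
// kp::shader_data namespace and contain a full compiled shader.
static const unsigned char op_add_comp_spv[] = {
    0x03, 0x02, 0x23, 0x07, // SPIR-V magic number 0x07230203, little-endian
    0x00, 0x00, 0x01, 0x00, // version word (example value)
};
static const unsigned int op_add_comp_spv_len = sizeof(op_add_comp_spv);

int main(void) {
    // The byte count must be a whole number of 32-bit words, as getSpirvShader() checks.
    if (op_add_comp_spv_len % sizeof(uint32_t) != 0) {
        fprintf(stderr, "SPIR-V blob size must be a multiple of 4\n");
        return 1;
    }
    size_t n_words = op_add_comp_spv_len / sizeof(uint32_t);
    uint32_t *words = malloc(n_words * sizeof(uint32_t));
    memcpy(words, op_add_comp_spv, op_add_comp_spv_len); // reinterpret bytes as SPIR-V words
    printf("first word: 0x%08x (expect SPIR-V magic 0x07230203)\n", (unsigned) words[0]);
    free(words);
    return 0;
}
```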
diff --git a/ggml/src/ggml-kompute/ggml-kompute.cpp b/ggml/src/ggml-kompute/ggml-kompute.cpp
deleted file mode 100644
index 50579227183..00000000000
--- a/ggml/src/ggml-kompute/ggml-kompute.cpp
+++ /dev/null
@@ -1,2251 +0,0 @@
-#include "ggml-impl.h"
-#include "ggml-backend.h"
-#include "ggml-backend-impl.h"
-#include "ggml-kompute.h"
-
-// These are generated at build time by cmake custom command
-#include "shaderop_scale.h"
-#include "shaderop_scale_8.h"
-#include "shaderop_add.h"
-#include "shaderop_addrow.h"
-#include "shaderop_mul.h"
-#include "shaderop_silu.h"
-#include "shaderop_relu.h"
-#include "shaderop_gelu.h"
-#include "shaderop_softmax.h"
-#include "shaderop_norm.h"
-#include "shaderop_rmsnorm.h"
-#include "shaderop_diagmask.h"
-#include "shaderop_mul_mat_f16.h"
-#include "shaderop_mul_mat_q8_0.h"
-#include "shaderop_mul_mat_q4_0.h"
-#include "shaderop_mul_mat_q4_1.h"
-#include "shaderop_mul_mat_q4_k.h"
-#include "shaderop_mul_mat_q6_k.h"
-#include "shaderop_mul_mat_mat_f32.h"
-#include "shaderop_getrows_f32.h"
-#include "shaderop_getrows_f16.h"
-#include "shaderop_getrows_q4_0.h"
-#include "shaderop_getrows_q4_1.h"
-#include "shaderop_getrows_q6_k.h"
-#include "shaderop_rope_norm_f16.h"
-#include "shaderop_rope_norm_f32.h"
-#include "shaderop_rope_neox_f16.h"
-#include "shaderop_rope_neox_f32.h"
-#include "shaderop_cpy_f16_f16.h"
-#include "shaderop_cpy_f16_f32.h"
-#include "shaderop_cpy_f32_f16.h"
-#include "shaderop_cpy_f32_f32.h"
-
-#include <algorithm>
-#include <array>
-#include <cassert>
-#include <cstdint>
-#include <cstdio>
-#include <cstring>
-#include <iostream>
-#include <memory>
-#include <mutex>
-#include <stdexcept>
-#include <string>
-#include <unordered_map>
-#include <utility>
-#include <vector>
-
-#include <kompute/Kompute.hpp>
-#include <vulkan/vulkan.hpp>
-
-#ifdef __linux__
-#include <cstdlib> // for setenv
-#endif
-
-#define QK4_0 32
-#define QR4_0 2
-#define QK4_1 32
-#define QK_NL 16
-
-typedef ggml_fp16_t half;
-
-static std::string ggml_kompute_format_name(int device) {
-    return "Kompute" + std::to_string(device);
-}
-
-struct ggml_kompute_context {
-    int device;
-    std::string name;
-    std::shared_ptr<vk::DescriptorPool> pool;
-
-    ggml_kompute_context(int device)
-        : device(device), name(ggml_kompute_format_name(device)) {}
-};
-
-// FIXME: It would be good to consolidate the kompute manager and the kompute context into one object
-// and consolidate the init functions and simplify object lifetime management. As it currently stands,
-// we *have* to have the kompute manager no matter what for device discovery, but the kompute context
-// is only created when a device is set and vulkan is explicitly turned on.
-static ggml_kompute_context *s_kompute_context = nullptr;
-
-class kompute_manager {
-    kp::Manager *s_mgr = nullptr;
-
-public:
-    kp::Manager *operator()() {
-        if (s_mgr && !s_mgr->hasInstance()) {
-            destroy();
-        }
-        if (!s_mgr) {
-            s_mgr = new kp::Manager;
-        }
-        return s_mgr;
-    }
-
-    void destroy() {
-        delete s_mgr;
-        s_mgr = nullptr;
-    }
-};
-
-static kompute_manager komputeManager;
-
-struct ggml_vk_memory {
-    void *data = nullptr;
-    size_t size = 0;
-    vk::DeviceMemory *primaryMemory = nullptr;
-    vk::Buffer *primaryBuffer = nullptr;
-    vk::DeviceMemory *stagingMemory = nullptr;
-    vk::Buffer *stagingBuffer = nullptr;
-};
-
-#ifdef __linux__
-__attribute__((constructor))
-static void enable_sam() {
-    setenv("RADV_PERFTEST", "sam", false);
-}
-#endif
-
-static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physical_device) {
-    vk::PhysicalDeviceFeatures availableFeatures;
-    physical_device.getFeatures(&availableFeatures);
-
-    if (!availableFeatures.shaderInt16)
-        return false;
-
-    vk::PhysicalDeviceVulkan11Features availableFeatures11;
-    vk::PhysicalDeviceVulkan12Features availableFeatures12;
-
-    availableFeatures11.pNext = &availableFeatures12;
-    availableFeatures12.pNext = nullptr;
-
-    vk::PhysicalDeviceFeatures2 features2;
-    features2.pNext = &availableFeatures11;
-
-    physical_device.getFeatures2(&features2);
-
-    if (!availableFeatures11.uniformAndStorageBuffer16BitAccess ||
-        !availableFeatures11.storageBuffer16BitAccess) {
-        return false;
-    }
-
-    if (!availableFeatures12.storageBuffer8BitAccess ||
-        !availableFeatures12.uniformAndStorageBuffer8BitAccess ||
-        !availableFeatures12.shaderFloat16 ||
-        !availableFeatures12.shaderInt8) {
-        return false;
-    }
-
-    return true;
-}
-
-static const char * ggml_vk_getVendorName(uint32_t vendorID) {
-    switch (vendorID) {
-        case 0x10DE:
-            return "nvidia";
-        case 0x1002:
-            return "amd";
-        case 0x8086:
-            return "intel";
-        default:
-            return "unknown";
-    }
-}
-
-static std::vector<ggml_vk_device> ggml_vk_available_devices_internal(size_t memoryRequired) {
-    std::vector results;
-    if (!komputeManager()->hasVulkan() || !komputeManager()->hasInstance())
-        return results;
-
-    std::vector<vk::PhysicalDevice> physical_devices;
-    try {
-        physical_devices = komputeManager()->listDevices();
-    } catch (vk::SystemError & err) {
-        std::cerr << __func__ << ": ignoring Vulkan exception: " << err.what() << "\n";
-        return results;
-    }
-
-    uint32_t deviceCount = physical_devices.size();
-    if (deviceCount == 0)
-        return results;
-
-    std::unordered_map<std::string, size_t> count_by_name;
-
-    for (uint32_t i = 0; i < deviceCount; i++) {
-        const auto & physical_device = physical_devices[i];
-
-        VkPhysicalDeviceProperties dev_props = physical_device.getProperties();
-        VkPhysicalDeviceMemoryProperties memoryProperties = physical_device.getMemoryProperties();
-        const uint32_t major = VK_VERSION_MAJOR(dev_props.apiVersion);
-        const uint32_t minor = VK_VERSION_MINOR(dev_props.apiVersion);
-        if (major < 1 || minor < 2)
-            continue;
-
-        if (!ggml_vk_checkPhysicalDeviceFeatures(physical_device))
-            continue;
-
-        size_t heapSize = 0;
-        for (uint32_t j = 0; j < memoryProperties.memoryHeapCount; ++j) {
-            VkMemoryHeap heap = memoryProperties.memoryHeaps[j];
-            if (heap.flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) {
-                heapSize = heap.size;
-                break;
-            }
-        }
-
-        if (heapSize < memoryRequired)
-            continue;
-
-        auto ext_props = physical_device.enumerateDeviceExtensionProperties();
-        bool has_maintenance4 = false;
-
-        // Check if maintenance4 is supported
-        for (const auto & properties : ext_props) {
-            if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
-                has_maintenance4 = true;
-            }
-        }
-
-        vk::PhysicalDeviceSubgroupProperties subgroup_props;
-        vk::PhysicalDeviceProperties2 dev_props2;
-        vk::PhysicalDeviceMaintenance3Properties dev_props3;
-        vk::PhysicalDeviceMaintenance4Properties dev_props4;
-        dev_props2.pNext = &dev_props3;
-        dev_props3.pNext = &subgroup_props;
-        if (has_maintenance4) {
-            subgroup_props.pNext = &dev_props4;
-        }
-        physical_device.getProperties2(&dev_props2);
-
-        if (subgroup_props.subgroupSize < 32)
-            continue;
-
-        ggml_vk_device d;
-        d.index = i;
-        d.type = dev_props.deviceType;
-        d.heapSize = heapSize;
-        d.vendor = strdup(ggml_vk_getVendorName(dev_props.vendorID));
-        d.subgroupSize = subgroup_props.subgroupSize;
-        d.bufferAlignment = dev_props.limits.minStorageBufferOffsetAlignment;
-
-        if (has_maintenance4) {
-            d.maxAlloc = std::min(dev_props3.maxMemoryAllocationSize, dev_props4.maxBufferSize);
-        } else {
-            d.maxAlloc = dev_props3.maxMemoryAllocationSize;
-        }
-
-        std::string name(dev_props.deviceName);
-        size_t n_idx = ++count_by_name[name];
-        if (n_idx > 1) {
-            name += " (" + std::to_string(n_idx) + ")";
-        }
-        d.name = strdup(name.c_str());
-
-        results.push_back(d);
-    }
-
-    std::stable_sort(results.begin(), results.end(),
-        [](const ggml_vk_device& lhs, const ggml_vk_device& rhs) -> bool {
-            if (lhs.type != rhs.type) {
-                if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return true;
-                if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return false;
-
-                if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return true;
-                if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return false;
-            }
-            return lhs.heapSize < rhs.heapSize;
-        }
-    );
-
-    return results;
-}
-
-static std::vector<ggml_vk_device>& ggml_vk_available_devices() {
-    static std::vector<ggml_vk_device> devices = ggml_vk_available_devices_internal(0);
-    return devices;
-}
-
-static void ggml_vk_filterByVendor(std::vector<ggml_vk_device>& devices, const std::string& targetVendor) {
-    devices.erase(
-        std::remove_if(devices.begin(), devices.end(),
-            [&targetVendor](const ggml_vk_device& device) {
-                return device.vendor != targetVendor;
-            }),
-        devices.end()
-    );
-}
-
-static void ggml_vk_filterByName(std::vector<ggml_vk_device>& devices, const std::string& targetName) {
-    devices.erase(
-        std::remove_if(devices.begin(), devices.end(),
-            [&targetName](const ggml_vk_device& device) {
-                return device.name != targetName;
-            }),
-        devices.end()
-    );
-}
-
-static bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const std::string & name) {
-    if (name.empty())
-        return false;
-
-    auto devices = ggml_vk_available_devices_internal(memoryRequired);
-    if (name == "amd" || name == "nvidia" || name == "intel") {
-        ggml_vk_filterByVendor(devices, name);
-    } else if (name != "gpu") {
-        ggml_vk_filterByName(devices, name);
-    }
-
-    if (devices.empty())
-        return false;
-
-    *device = devices.front();
-    return true;
-}
-
-bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const char * name) {
-    return ggml_vk_get_device(device, memoryRequired, std::string(name));
-}
-
-bool ggml_vk_has_vulkan() {
-    return komputeManager()->hasVulkan();
-}
-
-bool ggml_vk_has_device() {
-    return komputeManager()->hasDevice();
-}
-
-ggml_vk_device ggml_vk_current_device() {
-    if (!komputeManager()->hasDevice())
-        return ggml_vk_device();
-
-    auto devices = ggml_vk_available_devices();
-    ggml_vk_filterByName(devices, komputeManager()->physicalDevice()->getProperties().deviceName.data());
-    GGML_ASSERT(!devices.empty());
-    return devices.front();
-}
-
-static
-void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t size) {
-    std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
-        vk::DescriptorPoolSize(
-          vk::DescriptorType::eStorageBuffer,
-          4 * size // Descriptor count is number of possible tensors to pass into an algorithm
-          )
-    };
-
-    vk::DescriptorPoolCreateInfo descriptorPoolInfo(
-      vk::DescriptorPoolCreateFlags(),
-      size, // Max sets
-      static_cast<uint32_t>(descriptorPoolSizes.size()),
-      descriptorPoolSizes.data());
-
-    ctx->pool = std::make_shared<vk::DescriptorPool>();
-    vk::Result r = komputeManager()->device()->createDescriptorPool(
-      &descriptorPoolInfo, nullptr, ctx->pool.get());
-    if (r != vk::Result::eSuccess)
-        std::cerr << "Error allocating descriptor pool" << vk::to_string(r);
-}
-
-static
-void ggml_vk_free_descriptor_pool(struct ggml_kompute_context * ctx) {
-    if (ctx->pool) {
-        komputeManager()->device()->destroy(
-          *ctx->pool,
-          (vk::Optional<const vk::AllocationCallbacks>)nullptr);
-        ctx->pool = nullptr;
-    }
-}
-
-static
-vk::Buffer *ggml_vk_allocate_buffer(size_t size) {
-    vk::BufferCreateInfo bufferCreateInfo;
-    bufferCreateInfo.size = size;
-    bufferCreateInfo.usage = vk::BufferUsageFlagBits::eStorageBuffer |
-                             vk::BufferUsageFlagBits::eTransferSrc |
-                             vk::BufferUsageFlagBits::eTransferDst;
-    bufferCreateInfo.sharingMode = vk::SharingMode::eExclusive;
-
-    vk::Buffer *vkBuffer = new vk::Buffer;
-    vk::Result r = komputeManager()->device()->createBuffer(&bufferCreateInfo, nullptr, vkBuffer);
-    if (r != vk::Result::eSuccess)
-        std::cerr << "Error allocating buffer " << vk::to_string(r) << std::endl;
-    return vkBuffer;
-}
-
-static
-vk::DeviceMemory *ggml_vk_allocate(size_t size, vk::MemoryPropertyFlags flags, vk::MemoryRequirements requirements, bool *isHostVisible) {
-
-    uint32_t memoryTypeIndex = -1;
-    bool memoryTypeIndexFound = false;
-    vk::PhysicalDeviceMemoryProperties memoryProperties = komputeManager()->physicalDevice()->getMemoryProperties();
-    for (uint32_t i = 0; i < memoryProperties.memoryTypeCount; i++) {
-        const vk::MemoryType &memoryType = memoryProperties.memoryTypes[i];
-        const vk::MemoryHeap &memoryHeap = memoryProperties.memoryHeaps[memoryType.heapIndex];
-        if (memoryHeap.size < size) {
-            continue;
-        }
-
-        if (requirements.memoryTypeBits & (1 << i)) {
-            if (((memoryProperties.memoryTypes[i]).propertyFlags &
-                 flags) == flags) {
-                memoryTypeIndex = i;
-                memoryTypeIndexFound = true;
-                if (isHostVisible && (memoryProperties.memoryTypes[i].propertyFlags & vk::MemoryPropertyFlagBits::eHostVisible)) {
-                    *isHostVisible = true;
-                }
-                break;
-            }
-        }
-    }
-    if (!memoryTypeIndexFound) {
-        throw std::runtime_error(
-          "Memory type index for buffer creation not found");
-    }
-
-    vk::MemoryAllocateInfo allocInfo;
-    allocInfo.allocationSize = size;
-    allocInfo.memoryTypeIndex = memoryTypeIndex;
-    vk::DeviceMemory *vkDeviceMemory =  new vk::DeviceMemory;
-    vk::Result r = komputeManager()->device()->allocateMemory(&allocInfo, nullptr, vkDeviceMemory);
-    if (r != vk::Result::eSuccess) {
-        std::cerr << "Error allocating memory " << vk::to_string(r) << std::endl;
-        throw std::runtime_error("Error allocating vulkan memory.");
-    }
-    return vkDeviceMemory;
-}
-
-static size_t ggml_vk_aligned_offset(ggml_backend_buffer_t buffer, size_t offset) {
-    size_t minStorageBufferOffsetAlignment = ggml_backend_buffer_get_alignment(buffer);
-
-    // If offset is already aligned, return it directly
-    if (offset % minStorageBufferOffsetAlignment == 0) {
-        return offset;
-    }
-
-    // Otherwise, return the largest multiple of minStorageBufferOffsetAlignment less than offset
-    return (offset / minStorageBufferOffsetAlignment) * minStorageBufferOffsetAlignment;
-}
-
-static ggml_vk_memory ggml_vk_allocate(size_t size) {
-    ggml_vk_memory memory;
-    bool isHostVisible = false;
-    {
-        memory.primaryBuffer = ggml_vk_allocate_buffer(size);
-        vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.primaryBuffer);
-        vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eDeviceLocal;
-        memory.primaryMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
-        komputeManager()->device()->bindBufferMemory(*memory.primaryBuffer, *memory.primaryMemory, 0);
-        if (isHostVisible) {
-            vk::Result r = komputeManager()->device()->mapMemory(*memory.primaryMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
-            if (r != vk::Result::eSuccess)
-                std::cerr << "Error mapping memory" << vk::to_string(r);
-        }
-    }
-
-    if (!isHostVisible) {
-        memory.stagingBuffer = ggml_vk_allocate_buffer(size);
-        vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.stagingBuffer);
-        vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eHostVisible |
-                                                      vk::MemoryPropertyFlagBits::eHostCoherent |
-                                                      vk::MemoryPropertyFlagBits::eHostCached;
-        memory.stagingMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible);
-        komputeManager()->device()->bindBufferMemory(*memory.stagingBuffer, *memory.stagingMemory, 0);
-        vk::Result r = komputeManager()->device()->mapMemory(*memory.stagingMemory, 0, size, vk::MemoryMapFlags(), &memory.data);
-        if (r != vk::Result::eSuccess)
-            std::cerr << "Error mapping memory" << vk::to_string(r);
-    }
-
-    memory.size = size;
-    return memory;
-}
-
-static void ggml_vk_free_memory(ggml_vk_memory &memory)
-{
-    komputeManager()->device()->destroy(
-      *memory.primaryBuffer,
-      (vk::Optional<const vk::AllocationCallbacks>)nullptr);
-    if (memory.stagingBuffer) {
-        komputeManager()->device()->destroy(
-          *memory.stagingBuffer,
-          (vk::Optional<const vk::AllocationCallbacks>)nullptr);
-    }
-    komputeManager()->device()->freeMemory(
-      *memory.primaryMemory,
-      (vk::Optional<const vk::AllocationCallbacks>)nullptr);
-    if (memory.stagingMemory) {
-        komputeManager()->device()->freeMemory(
-          *memory.stagingMemory,
-          (vk::Optional<const vk::AllocationCallbacks>)nullptr);
-    }
-}
-
-static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft);
-
-static
-ggml_vk_memory * ggml_vk_find_tensor(const struct ggml_tensor * t, uint64_t & offset) {
-    ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
-
-    // compatibility with ggml-backend
-    GGML_ASSERT(buffer && buffer->buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name);
-
-    ggml_vk_memory * buf_ctx = static_cast<ggml_vk_memory *>(buffer->context);
-
-    const intptr_t ioffs = intptr_t(t->data) - intptr_t(buf_ctx->data);
-
-    GGML_ASSERT(ioffs >= 0 && ioffs + int64_t(ggml_nbytes(t)) <= int64_t(buffer->size));
-
-    offset = uint64_t(ioffs);
-    return buf_ctx;
-}
-
-static
-const std::shared_ptr<kp::Tensor> ggml_vk_get_tensor(const struct ggml_tensor * t, uint32_t * alignedOffset = nullptr) {
-    uint64_t originalOffset = 0;
-    auto * res = ggml_vk_find_tensor(t, originalOffset);
-    if (!res) {
-        static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
-        return nullTensor;
-    }
-
-    // Create a tensor whose memory will be composed of our buffers at the correct offset
-    const size_t nelements = ggml_nelements(t);
-    size_t nbytes = ggml_nbytes(t);
-
-    size_t vulkanOffset = ggml_vk_aligned_offset(t->buffer, originalOffset);
-    if (alignedOffset) {
-        *alignedOffset = originalOffset - vulkanOffset;
-        nbytes += *alignedOffset;
-    }
-
-    return komputeManager()->tensor(
-        t->data,
-        nelements,
-        nbytes, kp::Tensor::TensorDataTypes::eFloat,
-        res->primaryMemory, res->primaryBuffer,
-        res->stagingMemory, res->stagingBuffer,
-        vulkanOffset);
-}
-
-static std::vector<uint32_t> getSpirvShader(const unsigned char* rawData, size_t size) {
-    if (size % sizeof(uint32_t) != 0) {
-        throw std::runtime_error("Invalid size: must be divisible by sizeof(uint32_t)");
-    }
-
-    const uint32_t* data_ptr = reinterpret_cast<const uint32_t*>(rawData);
-    size_t count = size / sizeof(uint32_t);
-    return std::vector<uint32_t>(data_ptr, data_ptr + count);
-}
-
-inline static
-uint32_t safe_divide(uint32_t a, uint32_t b) {
-    if (b <= 1) {
-        return a;
-    }
-    if ((a % b) != 0) {
-        fprintf(stderr, "((%u %% %u) == %u) != 0\n", a, b, a % b);
-        GGML_ABORT("safe_divide result would've had remainder");
-    }
-    return a / b;
-}
-
-static void ggml_vk_add(
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
-    int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
-    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
-    int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
-    int32_t ne0,
-    int32_t nb0,  int32_t nb1,  int32_t nb2,  int32_t nb3
-) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_add_comp_spv,
-        kp::shader_data::op_add_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00;
-        int32_t nb00, nb01, nb02, nb03;
-        int32_t ne10, ne11, ne12, ne13;
-        int32_t nb10, nb11, nb12, nb13;
-        int32_t ne0;
-        int32_t nb0, nb1, nb2, nb3;
-    } const pushConsts {
-        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00,
-        nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, ne13,
-        nb10, nb11, nb12, nb13,
-        ne0,
-        nb0, nb1, nb2, nb3
-    };
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__)) {
-        s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
-        s_algo->setPushConstants({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record(s_algo);
-}
-
-static void ggml_vk_addrow(kp::Sequence& seq,
-                 const std::shared_ptr<kp::Tensor>& inA,
-                 const std::shared_ptr<kp::Tensor>& inB,
-                 const std::shared_ptr<kp::Tensor>& out,
-                 uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-                 uint32_t size, uint32_t row = 0) {
-
-    const static auto spirv = getSpirvShader(kp::shader_data::op_addrow_comp_spv,
-        kp::shader_data::op_addrow_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        uint32_t row;
-    } const pushConsts {
-        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        row
-    };
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__))
-        s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
-    else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({size});
-        s_algo->setPushConstants({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record(s_algo);
-}
-
-static void ggml_vk_mul(
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
-    int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03,
-    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
-    int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13,
-    int32_t ne0,
-    int32_t nb0,  int32_t nb1,  int32_t nb2,  int32_t nb3
-) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_comp_spv,
-        kp::shader_data::op_mul_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00;
-        int32_t nb00, nb01, nb02, nb03;
-        int32_t ne10, ne11, ne12, ne13;
-        int32_t nb10, nb11, nb12, nb13;
-        int32_t ne0;
-        int32_t nb0, nb1, nb2, nb3;
-    } const pushConsts {
-        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00,
-        nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, ne13,
-        nb10, nb11, nb12, nb13,
-        ne0,
-        nb0, nb1, nb2, nb3
-    };
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__)) {
-        s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
-        s_algo->setPushConstants({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record(s_algo);
-}
-
-static void ggml_vk_scale(kp::Sequence& seq,
-                   const std::shared_ptr<kp::Tensor>& in,
-                   const std::shared_ptr<kp::Tensor>& out,
-                   uint32_t inOff, uint32_t outOff,
-                   uint32_t size, float scale) {
-    const static auto spirv_1 = getSpirvShader(
-        kp::shader_data::op_scale_comp_spv, kp::shader_data::op_scale_comp_spv_len
-    );
-    const static auto spirv_8 = getSpirvShader(
-        kp::shader_data::op_scale_8_comp_spv, kp::shader_data::op_scale_8_comp_spv_len
-    );
-
-    struct PushConstants {
-        uint32_t inOff, outOff;
-        float scale;
-    } const pushConsts {
-        safe_divide(inOff, 4), safe_divide(outOff, 4),
-        scale
-    };
-
-    const auto * spirv = &spirv_1;
-    std::string name(__func__);
-    if (size % 8 == 0) {
-        size /= 8;
-        name += "_8";
-        spirv = &spirv_8;
-    }
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(name)) {
-        s_algo = komputeManager()->algorithm(name, s_kompute_context->pool.get(), {in, out}, *spirv, {size}, {}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({in, out});
-        s_algo->setWorkgroup({size});
-        s_algo->setPushConstants({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record(s_algo);
-}
-
-static void ggml_vk_xxlu(
-    const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& in,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inOff, uint32_t outOff,
-    uint32_t size
-) {
-    struct PushConstants {
-        uint32_t inOff, outOff;
-    } const pushConsts {
-        safe_divide(inOff, 4), safe_divide(outOff, 4),
-    };
-
-    auto name = std::string(__func__) + "_" + suffix;
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(name)) {
-        s_algo = komputeManager()->algorithm(name, s_kompute_context->pool.get(), {in, out}, spirv, {size}, {}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({in, out});
-        s_algo->setWorkgroup({size});
-        s_algo->setPushConstants({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record(s_algo);
-}
-
-template <typename... Args>
-static void ggml_vk_silu(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_silu_comp_spv,
-        kp::shader_data::op_silu_comp_spv_len);
-
-    ggml_vk_xxlu(spirv, "silu", std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_relu(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_relu_comp_spv,
-        kp::shader_data::op_relu_comp_spv_len);
-
-    ggml_vk_xxlu(spirv, "relu", std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_gelu(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_gelu_comp_spv,
-        kp::shader_data::op_gelu_comp_spv_len);
-
-    ggml_vk_xxlu(spirv, "gelu", std::forward<Args>(args)...);
-}
-
-static void ggml_vk_soft_max(
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
-    float scale, float max_bias, float m0, float m1,
-    uint32_t n_head_log2
-) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
-        kp::shader_data::op_softmax_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, ne01, ne02;
-        float scale, max_bias, m0, m1;
-        uint32_t n_head_log2;
-        int32_t mask;
-    } pushConsts {
-        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00, ne01, ne02,
-        scale, max_bias, m0, m1,
-        n_head_log2,
-        bool(inB)
-    };
-
-    auto & inB_ = inB ? inB : inA;
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__)) {
-        // FIXME: The softmax kernel needs to be fixed to use the subgroupsize which can vary by device
-        const uint32_t local_x = 32;
-        s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB_, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({inA, inB_, out});
-        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
-        s_algo->setPushConstants({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record(s_algo);
-}
-
-static void ggml_vk_norm_(
-    const std::vector<uint32_t>& spirv, const char * suffix, kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& in,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inOff, uint32_t outOff,
-    int32_t ne00, int32_t nb01,
-    int32_t nrows, float epsilon
-) {
-    GGML_ASSERT(nb01%sizeof(float) == 0);
-    GGML_ASSERT(ne00%sizeof(float) == 0);
-
-    struct PushConstants {
-        uint32_t inOff, outOff;
-        uint32_t ne00, nb01;
-        float eps;
-    } pushConsts {
-        safe_divide(inOff, 4), safe_divide(outOff, 4),
-        (uint32_t)ne00, (uint32_t)nb01, epsilon
-    };
-
-    auto name = std::string(__func__) + "_" + suffix;
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(name)) {
-        s_algo = komputeManager()->algorithm(name, s_kompute_context->pool.get(), {in, out}, spirv, {(uint32_t)nrows}, {}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({in, out});
-        s_algo->setWorkgroup({(uint32_t)nrows});
-        s_algo->setPushConstants({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record(s_algo);
-}
-
-template <typename... Args>
-static void ggml_vk_norm(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_norm_comp_spv,
-        kp::shader_data::op_norm_comp_spv_len);
-
-    ggml_vk_norm_(spirv, "norm", std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_rms_norm(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_rmsnorm_comp_spv,
-        kp::shader_data::op_rmsnorm_comp_spv_len);
-
-    ggml_vk_norm_(spirv, "rms", std::forward<Args>(args)...);
-}
-
-static void ggml_vk_diag_mask_inf(kp::Sequence& seq,
-                           const std::shared_ptr<kp::Tensor>& in,
-                           const std::shared_ptr<kp::Tensor>& out,
-                           uint32_t inOff, uint32_t outOff,
-                           uint32_t n_past,
-                           int32_t ne00, int32_t ne01, int32_t ne02) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_diagmask_comp_spv,
-        kp::shader_data::op_diagmask_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inOff, outOff;
-        uint32_t n_past;
-        int32_t ne00, ne01;
-    } pushConsts {
-        safe_divide(inOff, 4), safe_divide(outOff, 4),
-        n_past,
-        ne00, ne01
-    };
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__))
-        s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne00), unsigned(ne01), unsigned(ne02)}, {}, {pushConsts});
-    else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({in, out});
-        s_algo->setWorkgroup({unsigned(ne00), unsigned(ne01), unsigned(ne02)});
-        s_algo->setPushConstants({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record(s_algo);
-}
-
-static void ggml_vk_mul_mat_f16(
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t ne01, int32_t ne02,
-    uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
-    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
-    uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13,
-    int32_t ne0, int32_t ne1,
-    uint32_t r2, uint32_t r3
-) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_f16_comp_spv,
-        kp::shader_data::op_mul_mat_f16_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, ne01, ne02;
-        uint32_t nb00, nb01, nb02, nb03;
-        int32_t ne10, ne11, ne12;
-        uint32_t nb10, nb11, nb12, nb13;
-        int32_t ne0, ne1;
-        uint32_t r2, r3;
-    } pushConsts {
-        safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00, ne01, ne02,
-        nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12,
-        nb10, nb11, nb12, nb13,
-        ne0, ne1,
-        r2, r3
-    };
-
-    const unsigned ny = unsigned((ne11 + 4 - 1)/4);
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__)) {
-        const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
-        s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), ny, unsigned(ne12*ne13)}, {local_x}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({unsigned(ne01), ny, unsigned(ne12*ne13)});
-        s_algo->setPushConstants({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record(s_algo);
-}
-
-static void ggml_vk_mul_mat_mat_f32(kp::Sequence& seq,
-                         const std::shared_ptr<kp::Tensor>& inA,
-                         const std::shared_ptr<kp::Tensor>& inB,
-                         const std::shared_ptr<kp::Tensor>& out,
-                         uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-                         int32_t ne00, int32_t ne01, int32_t ne02,
-                         uint32_t nb01, uint32_t nb02,
-                         int32_t ne11, int32_t ne12,
-                         uint32_t nb11, uint32_t nb12,
-                         uint32_t nb1, uint32_t nb2) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_mat_f32_comp_spv,
-        kp::shader_data::op_mul_mat_mat_f32_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, ne01, ne02, ne11, ne12;
-        uint32_t nb01, nb02;
-        uint32_t nb11, nb12;
-        uint32_t nb1, nb2;
-    } pushConsts {
-        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00, ne01, ne02, ne11, ne12,
-        nb01, nb02, nb11, nb12,
-        nb1, nb2
-    };
-
-    const uint32_t local_x = ggml_vk_current_device().subgroupSize;
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__)) {
-        s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(),
-        {inA, inB, out}, spirv,
-        {unsigned(ne01),
-         unsigned(ne11),
-         unsigned(std::max(ne12, ne02))
-         },
-        {local_x},
-        {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({unsigned(ne01),
-                              unsigned(ne11),
-                              unsigned(std::max(ne12, ne02)),
-                              });
-        s_algo->setPushConstants({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record(s_algo);
-}
-
-static void ggml_vk_mul_mat_impl(
-    const std::vector<uint32_t>& spirv, const char * suffix, uint32_t block_size, kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t ne01, int32_t ne02,
-    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
-    int32_t ne0, int32_t ne1,
-    uint32_t nb01, uint32_t nb02, uint32_t nb03,
-    uint32_t nb11, uint32_t nb12, uint32_t nb13,
-    uint32_t r2, uint32_t r3
-) {
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, ne01, ne02;
-        int32_t ne10, ne12;
-        int32_t ne0, ne1;
-        uint32_t nb01, nb02, nb03;
-        uint32_t nb11, nb12, nb13;
-        uint32_t r2, r3;
-    } pushConsts {
-        safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00, ne01, ne02,
-        ne10, ne12,
-        ne0, ne1,
-        nb01, nb02, nb03,
-        nb11, nb12, nb13,
-        r2, r3
-    };
-
-    auto name = std::string(__func__) + "_" + suffix;
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(name)) {
-        const uint32_t local_x = (ggml_vk_current_device().subgroupSize * 2) / 8;
-        s_algo = komputeManager()->algorithm(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)});
-        s_algo->setPushConstants({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record(s_algo);
-}
-
-template <typename... Args>
-static void ggml_vk_mul_mat_q4_0(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_0_comp_spv,
-        kp::shader_data::op_mul_mat_q4_0_comp_spv_len);
-
-    ggml_vk_mul_mat_impl(spirv, "q4_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_mul_mat_q4_1(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_1_comp_spv,
-        kp::shader_data::op_mul_mat_q4_1_comp_spv_len);
-
-    ggml_vk_mul_mat_impl(spirv, "q4_1", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_mul_mat_q8_0(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q8_0_comp_spv,
-        kp::shader_data::op_mul_mat_q8_0_comp_spv_len);
-
-    ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
-}
-
-static void ggml_vk_mul_mat_q4_k(
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t ne01, int32_t ne02,
-    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
-    int32_t ne0, int32_t ne1,
-    uint32_t nb01, uint32_t nb02, uint32_t nb03,
-    uint32_t nb11, uint32_t nb12, uint32_t nb13,
-    uint32_t r2, uint32_t r3
-) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
-        kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
-        uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
-        uint32_t r2, r3;
-    } pushConsts {
-        inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00, ne10, ne0, ne1, ne01, ne02, ne12,
-        nb01, nb02, nb03, nb11, nb12, nb13,
-        r2, r3
-    };
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__)) {
-        s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
-        s_algo->setPushConstants({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record(s_algo);
-}
-
-static void ggml_vk_mul_mat_q6_k(
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t ne01, int32_t ne02,
-    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
-    int32_t ne0, int32_t ne1,
-    uint32_t nb01, uint32_t nb02, uint32_t nb03,
-    uint32_t nb11, uint32_t nb12, uint32_t nb13,
-    uint32_t r2, uint32_t r3
-) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
-        kp::shader_data::op_mul_mat_q6_k_comp_spv_len);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
-        uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
-        uint32_t r2, r3;
-    } pushConsts {
-        inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00, ne10, ne0, ne1, ne01, ne02, ne12,
-        nb01, nb02, nb03, nb11, nb12, nb13,
-        r2, r3
-    };
-
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(__func__)) {
-        const uint32_t local_x = 2;
-        const uint32_t local_y = ggml_vk_current_device().subgroupSize;
-        s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)}, {local_x, local_y}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(__func__);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)});
-        s_algo->setPushConstants({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record(s_algo);
-}
-
-static void ggml_vk_get_rows(
-    const std::vector<uint32_t>& spirv,
-    const char * suffix,
-    unsigned element_size, unsigned qk,
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-    int32_t ne00, int32_t nb01, int32_t nb1,
-    uint32_t size
-) {
-    GGML_ASSERT(nb01%element_size == 0);
-    GGML_ASSERT(nb1%sizeof(float) == 0);
-    if (qk) GGML_ASSERT(ne00%qk == 0);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, outOff;
-        int32_t ne00, nb01, nb1;
-    } pushConsts {
-        safe_divide(inAOff, element_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
-        ne00, nb01, nb1
-    };
-
-    auto name = std::string(__func__) + "_" + suffix;
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(name)) {
-        s_algo = komputeManager()->algorithm(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts});
-    } else {
-        s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({inA, inB, out});
-        s_algo->setWorkgroup({size});
-        s_algo->setPushConstants({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record(s_algo);
-}
-
-template <typename... Args>
-static void ggml_vk_get_rows_f32(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
-        kp::shader_data::op_getrows_f32_comp_spv_len);
-
-    ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_get_rows_f16(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
-        kp::shader_data::op_getrows_f16_comp_spv_len);
-
-    ggml_vk_get_rows(spirv, "f16", sizeof(half), 0, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_get_rows_q4_0(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_0_comp_spv,
-        kp::shader_data::op_getrows_q4_0_comp_spv_len);
-
-    ggml_vk_get_rows(spirv, "q4_0", 1/*We access blocks unaligned*/, QK4_0, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_get_rows_q4_1(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_1_comp_spv,
-        kp::shader_data::op_getrows_q4_1_comp_spv_len);
-
-    ggml_vk_get_rows(spirv, "q4_1", 1/*We access blocks unaligned*/, QK4_1, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_get_rows_q6_k(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q6_k_comp_spv,
-        kp::shader_data::op_getrows_q6_k_comp_spv_len);
-    ggml_vk_get_rows(spirv, "q6_k", 1/*We access blocks unaligned*/, QK_NL, std::forward<Args>(args)...);
-}
-
-static void ggml_vk_rope(
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& inA,
-    const std::shared_ptr<kp::Tensor>& inB,
-    const std::shared_ptr<kp::Tensor>& inC,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inAOff, uint32_t inBOff, uint32_t inCOff, uint32_t outOff,
-    ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
-    float freq_base, float freq_scale, bool has_freq_factors, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
-    int32_t ne01, int32_t ne02, int32_t ne03,
-    uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
-    int32_t ne0,
-    uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
-) {
-    GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
-
-    static const auto spirv_norm_f16 = getSpirvShader(
-        kp::shader_data::op_rope_norm_f16_comp_spv, kp::shader_data::op_rope_norm_f16_comp_spv_len
-    );
-    static const auto spirv_norm_f32 = getSpirvShader(
-        kp::shader_data::op_rope_norm_f32_comp_spv, kp::shader_data::op_rope_norm_f32_comp_spv_len
-    );
-    static const auto spirv_neox_f16 = getSpirvShader(
-        kp::shader_data::op_rope_neox_f16_comp_spv, kp::shader_data::op_rope_neox_f16_comp_spv_len
-    );
-    static const auto spirv_neox_f32 = getSpirvShader(
-        kp::shader_data::op_rope_neox_f32_comp_spv, kp::shader_data::op_rope_neox_f32_comp_spv_len
-    );
-
-    int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
-
-    GGML_ASSERT(nb03 % type_size == 0);
-    GGML_ASSERT(nb02 % type_size == 0);
-    GGML_ASSERT(nb01 % type_size == 0);
-    GGML_ASSERT(nb00 % type_size == 0);
-    GGML_ASSERT(nb3  % type_size == 0);
-    GGML_ASSERT(nb2  % type_size == 0);
-    GGML_ASSERT(nb1  % type_size == 0);
-    GGML_ASSERT(nb0  % type_size == 0);
-
-    struct PushConstants {
-        uint32_t inAOff, inBOff, inCOff, outOff;
-        int32_t n_dims, mode, n_ctx_orig;
-        float freq_base, freq_scale;
-        bool has_freq_factors;
-        float ext_factor, attn_factor, beta_fast, beta_slow;
-        uint32_t nb00, nb01, nb02, nb03;
-        int32_t ne0;
-        uint32_t nb0, nb1, nb2, nb3;
-    } pushConsts {
-        safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(inCOff, type_size), safe_divide(outOff, type_size),
-        n_dims, mode, n_ctx_orig,
-        freq_base, freq_scale,
-        has_freq_factors,
-        ext_factor, attn_factor, beta_fast, beta_slow,
-        nb00, nb01, nb02, nb03,
-        ne0,
-        nb0, nb1, nb2, nb3
-    };
-
-    auto & inC_ = inC ? inC : inA;
-    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-    const bool is_f16 = src0t == GGML_TYPE_F16;
-
-    auto name = std::string(__func__) + (is_neox ? "_neox" : "_norm") + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(name)) {
-        auto & spirv = is_neox ? is_f16 ? spirv_neox_f16 : spirv_neox_f32 : is_f16 ? spirv_norm_f16 : spirv_norm_f32;
-        s_algo = komputeManager()->algorithm(
-            name, s_kompute_context->pool.get(), {inA, inB, inC_, out}, spirv,
-            {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
-        );
-    } else {
-        s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({inA, inB, inC_, out});
-        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
-        s_algo->setPushConstants({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record(s_algo);
-}
-
-static void ggml_vk_cpy(
-    const std::vector<uint32_t>& spirv,
-    uint32_t in_element_size, uint32_t out_element_size,
-    kp::Sequence& seq,
-    const std::shared_ptr<kp::Tensor>& in,
-    const std::shared_ptr<kp::Tensor>& out,
-    uint32_t inOff, uint32_t outOff,
-    int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03,
-    uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
-    int32_t ne0, int32_t ne1, int32_t ne2,
-    uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3
-) {
-    struct PushConstants {
-        uint32_t inOff, outOff;
-        int32_t ne00, ne01, ne02;
-        uint32_t nb00, nb01, nb02, nb03;
-        int32_t ne0, ne1, ne2;
-        uint32_t nb0, nb1, nb2, nb3;
-    } pushConsts {
-        safe_divide(inOff, in_element_size), safe_divide(outOff, out_element_size),
-        ne00, ne01, ne02,
-        nb00, nb01, nb02, nb03,
-        ne0, ne1, ne2,
-        nb0, nb1, nb2, nb3
-    };
-
-    std::string name = std::string(__func__)
-                       + "_i_" + std::to_string(in_element_size)
-                       + "_o_" + std::to_string(out_element_size);
-    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
-    if (!komputeManager()->hasAlgorithm(name))
-        s_algo = komputeManager()->algorithm(name, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts});
-    else {
-        s_algo = komputeManager()->getAlgorithm(name);
-        s_algo->setTensors({in, out});
-        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
-        s_algo->setPushConstants({pushConsts});
-        s_algo->updateDescriptors(s_kompute_context->pool.get());
-    }
-    seq.record(s_algo);
-}
-
-template <typename... Args>
-static void ggml_vk_cpy_f32_f16(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f16_comp_spv,
-        kp::shader_data::op_cpy_f32_f16_comp_spv_len);
-    ggml_vk_cpy(spirv, 4, 2, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_cpy_f32_f32(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f32_comp_spv,
-        kp::shader_data::op_cpy_f32_f32_comp_spv_len);
-    ggml_vk_cpy(spirv, 4, 4, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_cpy_f16_f16(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f16_comp_spv,
-        kp::shader_data::op_cpy_f16_f16_comp_spv_len);
-    ggml_vk_cpy(spirv, 2, 2, std::forward<Args>(args)...);
-}
-
-template <typename... Args>
-static void ggml_vk_cpy_f16_f32(Args&&... args) {
-    const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f32_comp_spv,
-        kp::shader_data::op_cpy_f16_f32_comp_spv_len);
-    ggml_vk_cpy(spirv, 2, 4, std::forward<Args>(args)...);
-}
-
-static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
-    int64_t n = ggml_nelements(op);
-    switch (op->op) {
-        case GGML_OP_UNARY:
-            if (n % 4 != 0) return false;
-            switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_GELU:
-                    if (n % 8 != 0) return false;
-                    // fall through
-                case GGML_UNARY_OP_RELU:
-                case GGML_UNARY_OP_SILU:
-                    return ggml_is_contiguous(op->src[0]);
-                default:
-                    ;
-            }
-            break;
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_TRANSPOSE:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_ADD:
-        case GGML_OP_MUL:
-        case GGML_OP_SCALE:
-        case GGML_OP_SOFT_MAX:
-        case GGML_OP_RMS_NORM:
-        case GGML_OP_NORM:
-            return true;
-        case GGML_OP_ROPE:
-            {
-                const int mode = ((const int32_t *) op->op_params)[2];
-                if (mode & GGML_ROPE_TYPE_MROPE) {
-                    return false;
-                }
-                if (mode & GGML_ROPE_TYPE_VISION) {
-                    return false;
-                }
-                return true;
-            }
-        case GGML_OP_DUP:
-        case GGML_OP_CPY:
-        case GGML_OP_CONT:
-            switch (op->src[0]->type) {
-                case GGML_TYPE_F32:
-                case GGML_TYPE_F16:
-                    break;
-                default:
-                    return false;
-            }
-            switch (op->type) {
-                case GGML_TYPE_F32:
-                case GGML_TYPE_F16:
-                    break;
-                default:
-                    return false;
-            }
-            return true;
-        case GGML_OP_DIAG_MASK_INF:
-            return op->ne[3] == 1;
-        case GGML_OP_GET_ROWS:
-            switch (op->src[0]->type) {
-                case GGML_TYPE_F32:
-                case GGML_TYPE_F16:
-                case GGML_TYPE_Q4_0:
-                case GGML_TYPE_Q4_1:
-                case GGML_TYPE_Q6_K:
-                    return op->ne[2] == 1 && op->ne[3] == 1;
-                default:
-                    ;
-            }
-            return false;
-        case GGML_OP_MUL_MAT:
-            if (op->src[1]->type != GGML_TYPE_F32 || ggml_is_transposed(op->src[0]) || ggml_is_transposed(op->src[1]))
-                return false;
-
-            switch (op->src[0]->type) {
-                case GGML_TYPE_F32:
-                    return op->ne[3] == 1;
-                case GGML_TYPE_Q6_K:
-                case GGML_TYPE_F16:
-                case GGML_TYPE_Q8_0:
-                case GGML_TYPE_Q4_0:
-                case GGML_TYPE_Q4_1:
-                case GGML_TYPE_Q4_K:
-                    return true;
-                default:
-                    ;
-            }
-        default:
-            ;
-    }
-    return false;
-
-    GGML_UNUSED(dev);
-}
-
-static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
-    const int n_seq = 8;
-
-    // FIXME: Figure out if we can somehow optimize the size of the pool... right now we're setting
-    // it to the size of the graph, but I think it can be made smaller?
-    ggml_vk_allocate_descriptor_pool(ctx, gf->n_nodes);
-
-    std::vector<std::shared_ptr<kp::Sequence>> sequences(n_seq);
-
-    for (auto& sequence : sequences) {
-        sequence = komputeManager()->sequence();
-    }
-    for (int seq_idx = 0; seq_idx < n_seq; ++seq_idx) {
-        const int n_nodes_per_seq = (gf->n_nodes + n_seq - 1) / n_seq;
-
-        auto& seq = *sequences[seq_idx];
-
-        const int node_start = (seq_idx + 0) * n_nodes_per_seq;
-        const int node_end   = std::min((seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq, gf->n_nodes);
-
-        bool any_commands_recorded = false;
-
-        for (int i = node_start; i < node_end; ++i) {
-            struct ggml_tensor * src0 = gf->nodes[i]->src[0];
-            struct ggml_tensor * src1 = gf->nodes[i]->src[1];
-            struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2);
-            struct ggml_tensor * dst = gf->nodes[i];
-            GGML_ASSERT(dst->data != nullptr);
-
-            if (ggml_is_empty(dst)) {
-                continue;
-            }
-
-            switch (dst->op) {
-                case GGML_OP_NONE:
-                case GGML_OP_RESHAPE:
-                case GGML_OP_VIEW:
-                case GGML_OP_TRANSPOSE:
-                case GGML_OP_PERMUTE:
-                    continue; // noop -> next node
-                default:
-                    break;
-            }
-
-            any_commands_recorded = true;
-
-            const int32_t ne00 = src0 ? src0->ne[0] : 0;
-            const int32_t ne01 = src0 ? src0->ne[1] : 0;
-            const int32_t ne02 = src0 ? src0->ne[2] : 0;
-            const int32_t ne03 = src0 ? src0->ne[3] : 0;
-
-            const uint32_t nb00 = src0 ? src0->nb[0] : 0;
-            const uint32_t nb01 = src0 ? src0->nb[1] : 0;
-            const uint32_t nb02 = src0 ? src0->nb[2] : 0;
-            const uint32_t nb03 = src0 ? src0->nb[3] : 0;
-
-            const int32_t ne10 = src1 ? src1->ne[0] : 0;
-            const int32_t ne11 = src1 ? src1->ne[1] : 0;
-            const int32_t ne12 = src1 ? src1->ne[2] : 0;
-            const int32_t ne13 = src1 ? src1->ne[3] : 0;
-
-            const uint32_t nb10 = src1 ? src1->nb[0] : 0;
-            const uint32_t nb11 = src1 ? src1->nb[1] : 0;
-            const uint32_t nb12 = src1 ? src1->nb[2] : 0;
-            const uint32_t nb13 = src1 ? src1->nb[3] : 0;
-
-            const int32_t ne0 = dst ? dst->ne[0] : 0;
-            const int32_t ne1 = dst ? dst->ne[1] : 0;
-            const int32_t ne2 = dst ? dst->ne[2] : 0;
-//            const int32_t ne3 = dst ? dst->ne[3] : 0;
-
-            const uint32_t nb0 = dst ? dst->nb[0] : 0;
-            const uint32_t nb1 = dst ? dst->nb[1] : 0;
-            const uint32_t nb2 = dst ? dst->nb[2] : 0;
-            const uint32_t nb3 = dst ? dst->nb[3] : 0;
-
-            const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
-            const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
-            const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
-
-            const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
-            uint32_t off_src0 = 0;
-            uint32_t off_src1 = 0;
-            uint32_t off_src2 = 0;
-            uint32_t off_dst  = 0;
-            const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
-            const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
-            const std::shared_ptr<kp::Tensor>& id_src2 = src2 ? ggml_vk_get_tensor(src2, &off_src2) : nullTensor;
-            const std::shared_ptr<kp::Tensor>& id_dst  = dst  ? ggml_vk_get_tensor(dst,  &off_dst)  : nullTensor;
-
-            switch (dst->op) {
-                case GGML_OP_ADD:
-                    {
-                        if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-                            // src1 is a row
-                            ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00);
-                        } else {
-                            ggml_vk_add(
-                                seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                ne00, ne01, ne02, ne03,
-                                nb00, nb01, nb02, nb03,
-                                ne10, ne11, ne12, ne13,
-                                nb10, nb11, nb12, nb13,
-                                ne0,
-                                nb0, nb1, nb2, nb3
-                            );
-                        }
-                    } break;
-                case GGML_OP_MUL:
-                    {
-                        ggml_vk_mul(
-                            seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                            ne00, ne01, ne02, ne03,
-                            nb00, nb01, nb02, nb03,
-                            ne10, ne11, ne12, ne13,
-                            nb10, nb11, nb12, nb13,
-                            ne0,
-                            nb0, nb1, nb2, nb3
-                        );
-                    } break;
-                case GGML_OP_SCALE:
-                    {
-                        float scale; memcpy(&scale, dst->op_params, sizeof(float));
-
-                        ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale);
-                    } break;
-                case GGML_OP_UNARY:
-                    {
-                        int64_t n = ggml_nelements(dst);
-                        GGML_ASSERT(n % 4 == 0);
-                        switch (ggml_get_unary_op(gf->nodes[i])) {
-                            case GGML_UNARY_OP_SILU:
-                                {
-                                    ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
-                                } break;
-                            case GGML_UNARY_OP_RELU:
-                                {
-                                    ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, n/4);
-                                } break;
-                            case GGML_UNARY_OP_GELU:
-                                {
-                                    GGML_ASSERT(n % 8 == 0);
-                                    ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, n/8);
-                                } break;
-                            default:
-                                {
-                                    fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
-                                    GGML_ABORT("fatal error");
-                                }
-                        }
-                    } break;
-                case GGML_OP_SOFT_MAX:
-                    {
-                        float scale;
-                        float max_bias;
-
-                        memcpy(&scale,    (float *)dst->op_params + 0, sizeof(float));
-                        memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
-
-#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
-#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
-                        GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
-
-                        const int64_t nrows_x = ggml_nrows(src0);
-                        const int64_t nrows_y = src0->ne[1];
-
-                        const uint32_t n_head      = nrows_x/nrows_y;
-                        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-                        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-                        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-                        ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale, max_bias, m0, m1, n_head_log2);
-                    } break;
-                case GGML_OP_DIAG_MASK_INF:
-                    {
-                        const int n_past = ((int32_t *)(dst->op_params))[0];
-                        ggml_vk_diag_mask_inf(seq, id_src0, id_dst, off_src0, off_dst, n_past, ne00, ne01, ne02);
-                    } break;
-                case GGML_OP_NORM:
-                    {
-                        float eps;
-                        memcpy(&eps, dst->op_params, sizeof(float));
-                        ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
-                    } break;
-                case GGML_OP_RMS_NORM:
-                    {
-                        GGML_ASSERT(ne00 % 4 == 0);
-
-                        float eps;
-                        memcpy(&eps, dst->op_params, sizeof(float));
-                        ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
-                    } break;
-                case GGML_OP_MUL_MAT:
-                    {
-                        GGML_ASSERT(ne00 == ne10);
-
-                        GGML_ASSERT(ne12 % ne02 == 0);
-                        GGML_ASSERT(ne13 % ne03 == 0);
-
-                        const uint32_t r2 = ne12/ne02;
-                        const uint32_t r3 = ne13/ne03;
-
-                        if (src1t != GGML_TYPE_F32) {
-                            fprintf(stderr, "%s: %s: Unsupported src1 type: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
-                            goto not_implemented;
-                        }
-
-                        if (ggml_is_transposed(src0) ||
-                            ggml_is_transposed(src1)) {
-                            fprintf(stderr, "%s: %s: matmul on tranposed tensor not supported: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
-                            goto not_implemented;
-                        }
-
-                        switch (src0t) {
-                            case GGML_TYPE_F32:
-                                ggml_vk_mul_mat_mat_f32(
-                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, nb01, nb02, ne11, ne12, nb11, nb12, nb1, nb2
-                                );
-                                break;
-                            case GGML_TYPE_F16:
-                                ggml_vk_mul_mat_f16(
-                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-                                    ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
-                                    ne0, ne1, r2, r3
-                                );
-                                break;
-                            case GGML_TYPE_Q8_0:
-                                ggml_vk_mul_mat_q8_0(
-                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
-                                    nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
-                                );
-                                break;
-                            case GGML_TYPE_Q4_0:
-                                ggml_vk_mul_mat_q4_0(
-                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
-                                    nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
-                                );
-                                break;
-                            case GGML_TYPE_Q4_1:
-                                ggml_vk_mul_mat_q4_1(
-                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
-                                    nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
-                                );
-                                break;
-                            case GGML_TYPE_Q4_K:
-                                ggml_vk_mul_mat_q4_k(
-                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
-                                    nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
-                                );
-                                break;
-                            case GGML_TYPE_Q6_K:
-                                ggml_vk_mul_mat_q6_k(
-                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
-                                    nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
-                                );
-                                break;
-                            default: {
-                                fprintf(stderr, "%s: %s: Unsupported quantization: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
-                                goto not_implemented;
-                            }
-                        }
-
-                    } break;
-                case GGML_OP_GET_ROWS:
-                    {
-                        if (src0t == GGML_TYPE_F32) {
-                            ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
-                        } else if (src0t == GGML_TYPE_F16) {
-                            ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
-                        } else if (src0t == GGML_TYPE_Q4_0) {
-                            ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
-                        } else if (src0t == GGML_TYPE_Q4_1) {
-                            ggml_vk_get_rows_q4_1(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
-                        } else if (src0t == GGML_TYPE_Q6_K) {
-                            ggml_vk_get_rows_q6_k(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
-                        } else {
-                            fprintf(stderr, "%s: %s: Unsupported quantization: %u\n", __func__, ggml_op_name(dst->op), src0t);
-                            goto not_implemented;
-                        }
-                    } break;
-                case GGML_OP_ROPE:
-                    {
-                        GGML_ASSERT(ne10 == ne02);
-                        GGML_ASSERT(src0t == dstt);
-                        // const int n_past = ((int32_t *) dst->op_params)[0];
-                        const int n_dims     = ((int32_t *) dst->op_params)[1];
-                        const int mode       = ((int32_t *) dst->op_params)[2];
-                        // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
-                        const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-
-                        const bool has_freq_factors = dst->src[2] != nullptr;
-
-                        float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-                        memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
-                        memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
-                        memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
-                        memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
-                        memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
-                        memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-                        ggml_vk_rope(
-                            seq, id_src0, id_src1, id_src2, id_dst, off_src0, off_src1, off_src2, off_dst, src0t, n_dims, mode, n_ctx_orig,
-                            freq_base, freq_scale, has_freq_factors, ext_factor, attn_factor, beta_fast, beta_slow,
-                            ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
-                        );
-                    } break;
-                case GGML_OP_DUP:
-                case GGML_OP_CPY:
-                case GGML_OP_CONT:
-                    {
-                        switch (src0t) {
-                            case GGML_TYPE_F32:
-                                {
-                                    switch (dstt) {
-                                        case GGML_TYPE_F16: ggml_vk_cpy_f32_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
-                                        case GGML_TYPE_F32: ggml_vk_cpy_f32_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
-                                        default: goto not_implemented;
-                                    }
-                                } break;
-                            case GGML_TYPE_F16:
-                                {
-                                    switch (dstt) {
-                                        case GGML_TYPE_F16: ggml_vk_cpy_f16_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
-                                        case GGML_TYPE_F32: ggml_vk_cpy_f16_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break;
-                                    default: goto not_implemented;
-                                } break;
-                            default: goto not_implemented;
-                            }
-                        }
-                    } break;
-                default: goto not_implemented;
-            }
-            continue;
-            not_implemented: {}
-            fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
-            //GGML_ABORT("fatal error");
-        }
-
-        // Evaluate sequence
-        if (any_commands_recorded) {
-            seq.evalAsync();
-        }
-    }
-
-    // Wait for all sequences to finish
-    for (auto& sequence : sequences) {
-        if (sequence->isRunning())
-            sequence->evalAwait();
-    }
-
-    ggml_vk_free_descriptor_pool(ctx);
-}
-
-template<>
-kp::Tensor::TensorDataTypes
-kp::TensorT<float>::dataType()
-{
-    return TensorDataTypes::eFloat;
-}
-
-template<>
-kp::Tensor::TensorDataTypes
-kp::TensorT<uint32_t>::dataType()
-{
-    return TensorDataTypes::eUnsignedInt;
-}
-
-////////////////////////////////////////////////////////////////////////////////
-
-// backend interface
-
-struct ggml_backend_kompute_buffer_type_context {
-    int         device;
-    int         device_ref = 0;
-    uint64_t    buffer_alignment;
-    uint64_t    max_alloc;
-    std::string name;
-
-    ggml_backend_kompute_buffer_type_context(int device, uint64_t buffer_alignment, uint64_t max_alloc)
-        : device(device), buffer_alignment(buffer_alignment), max_alloc(max_alloc), name(ggml_kompute_format_name(device)) {}
-};
-
-static void ggml_backend_kompute_device_ref(ggml_backend_buffer_type_t buft) {
-    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
-
-    if (!ctx->device_ref) {
-        komputeManager()->initializeDevice(
-            ctx->device, {}, {
-                "VK_KHR_shader_float16_int8", "VK_KHR_8bit_storage",
-                "VK_KHR_16bit_storage", "VK_KHR_shader_non_semantic_info"
-            }
-        );
-    }
-
-    assert(ggml_vk_has_device());
-    ctx->device_ref++;
-}
-
-static void ggml_backend_kompute_device_unref(ggml_backend_buffer_type_t buft) {
-    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
-
-    assert(ctx->device_ref > 0);
-
-    ctx->device_ref--;
-
-    if (!ctx->device_ref) {
-        komputeManager.destroy();
-    }
-}
-
-static void ggml_backend_kompute_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    auto * memory = (ggml_vk_memory *)buffer->context;
-    if (ggml_vk_has_device()) {
-        ggml_vk_free_memory(*memory);
-    }
-    delete memory;
-}
-
-static void * ggml_backend_kompute_buffer_get_base(ggml_backend_buffer_t buffer) {
-    return ((ggml_vk_memory *)buffer->context)->data;
-}
-
-static void ggml_backend_kompute_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-    GGML_UNUSED(buffer);
-
-    const auto res = ggml_vk_get_tensor(tensor);
-    GGML_ASSERT(res);
-
-    memcpy((char *)tensor->data + offset, data, size);
-
-    komputeManager()->sequence()->eval<kp::OpTensorSyncDevice>({res});
-}
-
-static void ggml_backend_kompute_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
-    GGML_UNUSED(buffer);
-
-    const auto res = ggml_vk_get_tensor(tensor);
-    GGML_ASSERT(res);
-
-    komputeManager()->sequence()->eval<kp::OpTensorSyncLocal>({res});
-
-    memcpy(data, (const char *)tensor->data + offset, size);
-}
-
-static void ggml_backend_kompute_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
-    auto * memory = (ggml_vk_memory *)buffer->context;
-    memset(memory->data, value, buffer->size);
-
-    if (memory->stagingBuffer)
-        komputeManager()->sequence()->eval<kp::OpBufferSyncDevice>(memory->primaryBuffer, memory->stagingBuffer, memory->size);
-}
-
-static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
-    /* .free_buffer     = */ ggml_backend_kompute_buffer_free_buffer,
-    /* .get_base        = */ ggml_backend_kompute_buffer_get_base,
-    /* .init_tensor     = */ NULL,
-    /* .memset_tensor   = */ NULL,
-    /* .set_tensor      = */ ggml_backend_kompute_buffer_set_tensor,
-    /* .get_tensor      = */ ggml_backend_kompute_buffer_get_tensor,
-    /* .cpy_tensor      = */ NULL,
-    /* .clear           = */ ggml_backend_kompute_buffer_clear,
-    /* .reset           = */ NULL,
-};
-
-// default buffer type
-
-static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
-    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
-    return ctx->name.c_str();
-}
-
-static ggml_backend_buffer_t ggml_backend_kompute_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    ggml_backend_kompute_device_ref(buft);
-    auto * ctx = new ggml_vk_memory(ggml_vk_allocate(size));
-    return ggml_backend_buffer_init(buft, ggml_backend_kompute_buffer_i, ctx, size);
-}
-
-static size_t ggml_backend_kompute_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
-    return ctx->buffer_alignment;
-}
-
-static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
-    auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buft->context);
-    return ctx->max_alloc;
-}
-
-static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = {
-    /* .get_name         = */ ggml_backend_kompute_buffer_type_get_name,
-    /* .alloc_buffer     = */ ggml_backend_kompute_buffer_type_alloc_buffer,
-    /* .get_alignment    = */ ggml_backend_kompute_buffer_type_get_alignment,
-    /* .get_max_size     = */ ggml_backend_vk_buffer_type_get_max_size,
-    /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
-    /* .is_host          = */ NULL,
-};
-
-ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
-    static std::mutex mutex;
-    std::lock_guard<std::mutex> lock(mutex);
-
-    auto devices = ggml_vk_available_devices();
-    int32_t device_count = (int32_t) devices.size();
-    GGML_ASSERT(device < device_count);
-    GGML_ASSERT(devices.size() <= GGML_KOMPUTE_MAX_DEVICES);
-
-    static ggml_backend_buffer_type
-        ggml_backend_kompute_buffer_types[GGML_KOMPUTE_MAX_DEVICES];
-
-    static bool ggml_backend_kompute_buffer_type_initialized = false;
-
-    if (!ggml_backend_kompute_buffer_type_initialized) {
-        for (int32_t i = 0; i < device_count; i++) {
-            ggml_backend_kompute_buffer_types[i] = {
-                /* .iface    = */ ggml_backend_kompute_buffer_type_interface,
-                /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), i),
-                /* .context  = */ new ggml_backend_kompute_buffer_type_context{ i, devices[i].bufferAlignment, devices[i].maxAlloc },
-            };
-        }
-        ggml_backend_kompute_buffer_type_initialized = true;
-    }
-
-    return &ggml_backend_kompute_buffer_types[device];
-}
-
-// backend
-
-static const char * ggml_backend_kompute_name(ggml_backend_t backend) {
-    auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
-    return ctx->name.c_str();
-}
-
-static void ggml_backend_kompute_free(ggml_backend_t backend) {
-    auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
-
-    assert(ctx == s_kompute_context);
-    s_kompute_context = nullptr;
-    if (ctx != nullptr) {
-        delete ctx;
-    }
-
-    delete backend;
-}
-
-static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
-    auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
-    ggml_vk_graph_compute(ctx, cgraph);
-    return GGML_STATUS_SUCCESS;
-}
-
-static struct ggml_backend_i kompute_backend_i = {
-    /* .get_name                = */ ggml_backend_kompute_name,
-    /* .free                    = */ ggml_backend_kompute_free,
-    /* .set_tensor_async        = */ NULL,
-    /* .get_tensor_async        = */ NULL,
-    /* .cpy_tensor_async        = */ NULL,
-    /* .synchronize             = */ NULL,
-    /* .graph_plan_create       = */ NULL,
-    /* .graph_plan_free         = */ NULL,
-    /* .graph_plan_update       = */ NULL,
-    /* .graph_plan_compute      = */ NULL,
-    /* .graph_compute           = */ ggml_backend_kompute_graph_compute,
-    /* .event_record            = */ NULL,
-    /* .event_wait              = */ NULL,
-};
-
-static ggml_guid_t ggml_backend_kompute_guid() {
-    static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49, 0xfb, 0x35, 0xfa, 0x9b, 0x18, 0x31, 0x1d, 0xca };
-    return &guid;
-}
-
-ggml_backend_t ggml_backend_kompute_init(int device) {
-    GGML_ASSERT(s_kompute_context == nullptr);
-    s_kompute_context = new ggml_kompute_context(device);
-
-    ggml_backend_t kompute_backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_kompute_guid(),
-        /* .interface = */ kompute_backend_i,
-        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), device),
-        /* .context   = */ s_kompute_context,
-    };
-
-    return kompute_backend;
-}
-
-bool ggml_backend_is_kompute(ggml_backend_t backend) {
-    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
-}
-
-static size_t ggml_backend_kompute_get_device_count() {
-    auto devices = ggml_vk_available_devices();
-    return devices.size();
-}
-
-static void ggml_backend_kompute_get_device_description(int device, char * description, size_t description_size) {
-    auto devices = ggml_vk_available_devices();
-    GGML_ASSERT((size_t) device < devices.size());
-    snprintf(description, description_size, "%s", devices[device].name);
-}
-
-static void ggml_backend_kompute_get_device_memory(int device, size_t * free, size_t * total) {
-    auto devices = ggml_vk_available_devices();
-    GGML_ASSERT((size_t) device < devices.size());
-    *total = devices[device].heapSize;
-    *free = devices[device].heapSize;
-}
-
-//////////////////////////
-
-struct ggml_backend_kompute_device_context {
-    int device;
-    std::string name;
-    std::string description;
-};
-
-static const char * ggml_backend_kompute_device_get_name(ggml_backend_dev_t dev) {
-    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
-    return ctx->name.c_str();
-}
-
-static const char * ggml_backend_kompute_device_get_description(ggml_backend_dev_t dev) {
-    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
-    return ctx->description.c_str();
-}
-
-static void ggml_backend_kompute_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
-    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
-    ggml_backend_kompute_get_device_memory(ctx->device, free, total);
-}
-
-static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_buffer_type(ggml_backend_dev_t dev) {
-    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
-    return ggml_backend_kompute_buffer_type(ctx->device);
-}
-
-static bool ggml_backend_kompute_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-    if (buft->iface.get_name != ggml_backend_kompute_buffer_type_get_name) {
-        return false;
-    }
-
-    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
-    ggml_backend_kompute_buffer_type_context * buft_ctx = (ggml_backend_kompute_buffer_type_context *)buft->context;
-
-    return buft_ctx->device == ctx->device;
-}
-
-static enum ggml_backend_dev_type ggml_backend_kompute_device_get_type(ggml_backend_dev_t dev) {
-    GGML_UNUSED(dev);
-    return GGML_BACKEND_DEVICE_TYPE_GPU;
-}
-
-static void ggml_backend_kompute_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
-    props->name        = ggml_backend_kompute_device_get_name(dev);
-    props->description = ggml_backend_kompute_device_get_description(dev);
-    props->type        = ggml_backend_kompute_device_get_type(dev);
-    ggml_backend_kompute_device_get_memory(dev, &props->memory_free, &props->memory_total);
-    props->caps = {
-        /* async                  = */ false,
-        /* host_buffer            = */ false,
-        /* .buffer_from_host_ptr  = */ false,
-        /* events                 = */ false,
-    };
-}
-
-static ggml_backend_t ggml_backend_kompute_device_init(ggml_backend_dev_t dev, const char * params) {
-    GGML_UNUSED(params);
-    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
-    return ggml_backend_kompute_init(ctx->device);
-}
-
-static bool ggml_backend_kompute_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
-    const int min_batch_size = 32;
-
-    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
-           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
-
-    GGML_UNUSED(dev);
-}
-
-static const struct ggml_backend_device_i ggml_backend_kompute_device_i = {
-    /* .get_name             = */ ggml_backend_kompute_device_get_name,
-    /* .get_description      = */ ggml_backend_kompute_device_get_description,
-    /* .get_memory           = */ ggml_backend_kompute_device_get_memory,
-    /* .get_type             = */ ggml_backend_kompute_device_get_type,
-    /* .get_props            = */ ggml_backend_kompute_device_get_props,
-    /* .init_backend         = */ ggml_backend_kompute_device_init,
-    /* .get_buffer_type      = */ ggml_backend_kompute_device_get_buffer_type,
-    /* .get_host_buffer_type = */ NULL,
-    /* .buffer_from_host_ptr = */ NULL,
-    /* .supports_op          = */ ggml_backend_kompute_device_supports_op,
-    /* .supports_buft        = */ ggml_backend_kompute_device_supports_buft,
-    /* .offload_op           = */ ggml_backend_kompute_device_offload_op,
-    /* .event_new            = */ NULL,
-    /* .event_free           = */ NULL,
-    /* .event_synchronize    = */ NULL,
-};
-
-static const char * ggml_backend_kompute_reg_get_name(ggml_backend_reg_t reg) {
-    GGML_UNUSED(reg);
-    return "Kompute";
-}
-
-static size_t ggml_backend_kompute_reg_get_device_count(ggml_backend_reg_t reg) {
-    GGML_UNUSED(reg);
-    return ggml_backend_kompute_get_device_count();
-}
-
-static ggml_backend_dev_t ggml_backend_kompute_reg_get_device(ggml_backend_reg_t reg, size_t device) {
-    static std::vector<ggml_backend_dev_t> devices;
-
-    static bool initialized = false;
-
-    {
-        static std::mutex mutex;
-        std::lock_guard<std::mutex> lock(mutex);
-        if (!initialized) {
-            for (size_t i = 0; i < ggml_backend_kompute_get_device_count(); i++) {
-                ggml_backend_kompute_device_context * ctx = new ggml_backend_kompute_device_context;
-                char desc[256];
-                ggml_backend_kompute_get_device_description(i, desc, sizeof(desc));
-                ctx->device = i;
-                ctx->name = "Kompute" + std::to_string(i);
-                ctx->description = desc;
-                devices.push_back(new ggml_backend_device {
-                    /* .iface   = */ ggml_backend_kompute_device_i,
-                    /* .reg     = */ reg,
-                    /* .context = */ ctx,
-                });
-            }
-            initialized = true;
-        }
-    }
-
-    GGML_ASSERT(device < devices.size());
-    return devices[device];
-}
-
-static const struct ggml_backend_reg_i ggml_backend_kompute_reg_i = {
-    /* .get_name         = */ ggml_backend_kompute_reg_get_name,
-    /* .get_device_count = */ ggml_backend_kompute_reg_get_device_count,
-    /* .get_device       = */ ggml_backend_kompute_reg_get_device,
-    /* .get_proc_address = */ NULL,
-};
-
-ggml_backend_reg_t ggml_backend_kompute_reg() {
-    static ggml_backend_reg reg = {
-        /* .api_version = */ GGML_BACKEND_API_VERSION,
-        /* .iface       = */ ggml_backend_kompute_reg_i,
-        /* .context     = */ nullptr,
-    };
-
-    return &reg;
-}
-
-GGML_BACKEND_DL_IMPL(ggml_backend_kompute_reg)
diff --git a/ggml/src/ggml-kompute/kompute-shaders/common.comp b/ggml/src/ggml-kompute/kompute-shaders/common.comp
deleted file mode 100644
index dbe4cf804e6..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/common.comp
+++ /dev/null
@@ -1,112 +0,0 @@
-#extension GL_EXT_shader_16bit_storage: require
-#extension GL_EXT_shader_8bit_storage: require
-#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
-#extension GL_EXT_shader_explicit_arithmetic_types_int8: require
-#extension GL_EXT_shader_explicit_arithmetic_types_int16: require
-#extension GL_EXT_shader_explicit_arithmetic_types_int64: require
-#extension GL_EXT_control_flow_attributes: enable
-#extension GL_KHR_shader_subgroup_arithmetic : require
-#extension GL_EXT_debug_printf : enable
-
-#define QK4_0 32
-#define QK4_1 32
-
-#define GELU_COEF_A 0.044715
-#define SQRT_2_OVER_PI 0.79788456080286535587989211986876
-#define TWOPI_F 6.283185307179586f
-
-#define QK_K 256
-#define K_SCALE_SIZE 12
-
-#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
-#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
-#define u8BufToU32(buf, idx) (((uint32_t u8BufToU16(buf, idx + 2) << 8 | buf[idx + 1]) << 8) | buf[idx])
-#define u8BufToFloat(buf, idx) uintBitsToFloat u8BufToU32(buf, idx)
-
-#define sizeof_block_q4_0 0x12
-struct block_q4_0 {
-    float16_t d;
-    uint8_t qs[QK4_0 / 2];
-};
-mat4 dequantize_q4_0(const block_q4_0 xb, uint il) {
-    const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
-    const float d2 = d1 / 256.f;
-    const float md = -8.f * xb.d;
-    const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
-    const uint16_t mask1 = mask0 << 8;
-
-    mat4 reg;
-    for (int i=0;i<8;i++) {
-        uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
-        reg[i/2][2*(i%2)+0] = d1 * (b & mask0) + md;
-        reg[i/2][2*(i%2)+1] = d2 * (b & mask1) + md;
-    }
-    return reg;
-}
-
-#define sizeof_block_q4_1 0x14
-struct block_q4_1 {
-    float16_t d;
-    float16_t m;
-    uint8_t qs[QK4_1 / 2];
-};
-mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
-    const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
-    const float d2 = d1 / 256.f;
-    const float  m = xb.m;
-    const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
-    const uint16_t mask1 = mask0 << 8;
-
-    mat4 reg;
-    for (int i=0;i<8;i++) {
-        uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
-        reg[i/2][2*(i%2)+0] = ((b & mask0) * d1) + m;
-        reg[i/2][2*(i%2)+1] = ((b & mask1) * d2) + m;
-    }
-    return reg;
-}
-
-#define sizeof_block_q4_k 144
-struct block_q4_k {
-    float16_t d;
-    float16_t dmin;
-    uint8_t scales[K_SCALE_SIZE];
-    uint8_t qs[QK_K/2];
-};
-
-#define sizeof_block_q6_k 210
-struct block_q6_k {
-    uint8_t ql[QK_K/2];      // quants, lower 4 bits
-    uint8_t qh[QK_K/4];      // quants, upper 2 bits
-    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
-    float16_t d;             // super-block scale
-};
-mat4 dequantize_q6_k(const block_q6_k xb, uint il) {
-    const float16_t d_all = xb.d;
-
-    const uint qlIndex = 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
-    const uint qhIndex = 32*(il/8) + 16*(il&1);
-    float16_t sc = xb.scales[(il%2) + 2 * ((il/2))];
-    il = (il/2) & 3;
-
-    const uint16_t  kmask1 = il>1 ? uint16_t(il>2 ? 192 : 48) : uint16_t(il>0 ? 12 : 3);
-    const uint16_t  kmask2 = il>1 ? uint8_t(0xF0)             : uint8_t(0x0F);
-    const float16_t coef   = il>1 ? float16_t(1.f/16.f)       : float16_t(1.f);
-    const float16_t ml = float16_t(d_all * sc * 32.f);
-    const float16_t dl = float16_t(d_all * sc * coef);
-    mat4 reg;
-    for (int i = 0; i < 16; ++i) {
-        const float16_t q = (il&1) != 0 ? ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 2))
-                                        : ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 4));
-        reg[i/4][i%4] = dl * q - ml;
-    }
-    return reg;
-}
-
-
-#define QK8_0 32
-// struct block_q8_0 {
-//     float16_t d;         // delta
-//     int8_t    qs[QK8_0]; // quants
-// };
-#define sizeof_block_q8_0 34
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_add.comp b/ggml/src/ggml-kompute/kompute-shaders/op_add.comp
deleted file mode 100644
index b7b76a79dbd..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_add.comp
+++ /dev/null
@@ -1,58 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1024) in;
-
-layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
-layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
-layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int nb00;
-    int nb01;
-    int nb02;
-    int nb03;
-    int ne10;
-    int ne11;
-    int ne12;
-    int ne13;
-    int nb10;
-    int nb11;
-    int nb12;
-    int nb13;
-    int ne0;
-    int nb0;
-    int nb1;
-    int nb2;
-    int nb3;
-  //int offs; // TODO: needed for GGML_OP_ACC, see metal code
-} pcs;
-
-// general-purpose kernel for addition of two tensors
-// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
-// cons: not very efficient
-void main() {
-    const uint i03 = gl_WorkGroupID.z;
-    const uint i02 = gl_WorkGroupID.y;
-    const uint i01 = gl_WorkGroupID.x;
-
-    const uint i13 = i03 % pcs.ne13;
-    const uint i12 = i02 % pcs.ne12;
-    const uint i11 = i01 % pcs.ne11;
-
-    int offs = 0; // TMP (see above)
-
-    uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + offs) / 4);
-    uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11       ) / 4);
-    uint dst_off  = uint((i03*pcs.nb3  + i02*pcs.nb2  + i01*pcs.nb1  + offs) / 4);
-
-    for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) {
-        const uint i10 = i0 % pcs.ne10;
-        out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] + inB[pcs.inBOff + src1_off + i10];
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp b/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp
deleted file mode 100644
index 2376a6b8f03..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp
+++ /dev/null
@@ -1,25 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
-layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
-layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    uint row;
-} pcs;
-
-void main() {
-    const uint baseIndex = gl_WorkGroupID.x * 4;
-
-    for (uint x = 0; x < 4; x++) {
-        const uint i = baseIndex + x;
-        out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i % pcs.row) + pcs.inBOff];
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp
deleted file mode 100644
index d57247d2dcc..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp
+++ /dev/null
@@ -1,52 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define IN_TYPE float16_t
-#define IN_TYPE_SIZE 2
-#define OUT_TYPE float16_t
-#define OUT_TYPE_SIZE 2
-
-layout(local_size_x = 1024) in;
-
-layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
-layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inOff;
-    uint outOff;
-    int ne00;
-    int ne01;
-    int ne02;
-    uint nb00;
-    uint nb01;
-    uint nb02;
-    uint nb03;
-    int ne0;
-    int ne1;
-    int ne2;
-    uint nb0;
-    uint nb1;
-    uint nb2;
-    uint nb3;
-} pcs;
-
-void main() {
-    const uint i03 = gl_WorkGroupID.z;
-    const uint i02 = gl_WorkGroupID.y;
-    const uint i01 = gl_WorkGroupID.x;
-
-    const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
-
-    const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
-    const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
-    const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
-    const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
-
-    const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
-
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
-        out_[dst_data+i00] = OUT_TYPE(in_[src]);
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp
deleted file mode 100644
index b568bcd7b26..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp
+++ /dev/null
@@ -1,52 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define IN_TYPE float16_t
-#define IN_TYPE_SIZE 2
-#define OUT_TYPE float
-#define OUT_TYPE_SIZE 4
-
-layout(local_size_x = 1024) in;
-
-layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
-layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inOff;
-    uint outOff;
-    int ne00;
-    int ne01;
-    int ne02;
-    uint nb00;
-    uint nb01;
-    uint nb02;
-    uint nb03;
-    int ne0;
-    int ne1;
-    int ne2;
-    uint nb0;
-    uint nb1;
-    uint nb2;
-    uint nb3;
-} pcs;
-
-void main() {
-    const uint i03 = gl_WorkGroupID.z;
-    const uint i02 = gl_WorkGroupID.y;
-    const uint i01 = gl_WorkGroupID.x;
-
-    const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
-
-    const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
-    const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
-    const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
-    const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
-
-    const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
-
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
-        out_[dst_data+i00] = OUT_TYPE(in_[src]);
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp
deleted file mode 100644
index 99b22834308..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp
+++ /dev/null
@@ -1,52 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define IN_TYPE float
-#define IN_TYPE_SIZE 4
-#define OUT_TYPE float16_t
-#define OUT_TYPE_SIZE 2
-
-layout(local_size_x = 1024) in;
-
-layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
-layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inOff;
-    uint outOff;
-    int ne00;
-    int ne01;
-    int ne02;
-    uint nb00;
-    uint nb01;
-    uint nb02;
-    uint nb03;
-    int ne0;
-    int ne1;
-    int ne2;
-    uint nb0;
-    uint nb1;
-    uint nb2;
-    uint nb3;
-} pcs;
-
-void main() {
-    const uint i03 = gl_WorkGroupID.z;
-    const uint i02 = gl_WorkGroupID.y;
-    const uint i01 = gl_WorkGroupID.x;
-
-    const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
-
-    const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
-    const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
-    const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
-    const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
-
-    const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
-
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
-        out_[dst_data+i00] = OUT_TYPE(in_[src]);
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp
deleted file mode 100644
index 2fc998492b7..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp
+++ /dev/null
@@ -1,52 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define IN_TYPE float
-#define IN_TYPE_SIZE 4
-#define OUT_TYPE float
-#define OUT_TYPE_SIZE 4
-
-layout(local_size_x = 1024) in;
-
-layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
-layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inOff;
-    uint outOff;
-    int ne00;
-    int ne01;
-    int ne02;
-    uint nb00;
-    uint nb01;
-    uint nb02;
-    uint nb03;
-    int ne0;
-    int ne1;
-    int ne2;
-    uint nb0;
-    uint nb1;
-    uint nb2;
-    uint nb3;
-} pcs;
-
-void main() {
-    const uint i03 = gl_WorkGroupID.z;
-    const uint i02 = gl_WorkGroupID.y;
-    const uint i01 = gl_WorkGroupID.x;
-
-    const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
-
-    const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
-    const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
-    const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
-    const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
-
-    const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
-
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
-        out_[dst_data+i00] = OUT_TYPE(in_[src]);
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp b/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp
deleted file mode 100644
index 291c3fc1897..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp
+++ /dev/null
@@ -1,30 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inOff;
-    uint outOff;
-    uint n_past;
-    int ne00;
-    int ne01;
-} pcs;
-
-void main() {
-    const uint i02 = gl_WorkGroupID.z;
-    const uint i01 = gl_WorkGroupID.y;
-    const uint i00 = gl_WorkGroupID.x;
-
-    const uint index = i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00;
-
-    if (i00 > pcs.n_past + i01) {
-        out_[index + pcs.outOff] = uintBitsToFloat(0xFF800000);
-    } else {
-        out_[index + pcs.outOff] = in_[index + pcs.inOff];
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp b/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp
deleted file mode 100644
index 9d8c53710af..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp
+++ /dev/null
@@ -1,22 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
-layout(push_constant) uniform PushConstants {
-    uint inOff;
-    uint outOff;
-} pcs;
-
-void main() {
-    const uint baseIndex = gl_WorkGroupID.x * 8;
-
-    for (uint x = 0; x < 8; x++) {
-        const uint i = baseIndex + x;
-        const float y = in_[i + pcs.inOff];
-        out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(clamp(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y), -15.0, 15.0)));
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp
deleted file mode 100644
index 1a5581b23a9..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp
+++ /dev/null
@@ -1,17 +0,0 @@
-void main() {
-    const uint i = gl_WorkGroupID.x;
-    const int r = inB[i + pcs.inBOff];
-
-    int z = 0;
-    for (uint ind = gl_LocalInvocationID.x; ind < pcs.ne00/16; ind += gl_WorkGroupSize.x) {
-        const uint inIndex = (r * pcs.nb01 + pcs.inAOff) + ind/NL * SIZE_OF_BLOCK;
-        const mat4 result = dequantize_block(inIndex, ind%NL);
-        for (uint j = 0; j < 4; ++j) {
-            for (uint k = 0; k < 4; ++k) {
-                const uint outIndex = i * pcs.nb1/BYTES_FOR_TYPE + pcs.outOff + z;
-                out_[outIndex] = result[j][k];
-                ++z;
-            }
-        }
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp
deleted file mode 100644
index 48c93610811..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp
+++ /dev/null
@@ -1,31 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { int inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int nb01;
-    int nb1;
-} pcs;
-
-void dequantize_row_f16(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
-    for (int j = 0; j < k; j++) {
-        out_[y + j] = inA[x + j];
-    }
-}
-
-void main() {
-    const uint i = gl_WorkGroupID.x;
-    const int r = inB[i + pcs.inBOff];
-
-    dequantize_row_f16(r*pcs.nb01/2/*bytes for float16*/ + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp
deleted file mode 100644
index 9d7acdaf8a8..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp
+++ /dev/null
@@ -1,31 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { float inA[]; };
-layout (binding = 1) readonly buffer tensorInB { int inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int nb01;
-    int nb1;
-} pcs;
-
-void dequantize_row_f32(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
-    for (int j = 0; j < k; j++) {
-        out_[y + j] = inA[x + j];
-    }
-}
-
-void main() {
-    const uint i = gl_WorkGroupID.x;
-    const int r = inB[i + pcs.inBOff];
-
-    dequantize_row_f32(r*pcs.nb01/4 + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
-}
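
The deleted get_rows kernels (the generic template above plus the f16/f32 variants) all do the same gather: for each output row i they read the row index r from the index tensor and copy ne00 elements of source row r to destination row i; the variants differ only in element type and the byte-stride divisor. A minimal CPU sketch of the f32 case, with illustrative names only:

    #include <cstdio>
    #include <vector>

    // gather rows of a row-major (n_rows x ne00) matrix according to ids
    static void get_rows_f32(const std::vector<float> & src, int ne00,
                             const std::vector<int> & ids, std::vector<float> & dst) {
        dst.resize(ids.size() * ne00);
        for (size_t i = 0; i < ids.size(); ++i) {
            const int r = ids[i];                       // row selected by the index tensor
            for (int j = 0; j < ne00; ++j) {
                dst[i*ne00 + j] = src[(size_t)r*ne00 + j];
            }
        }
    }

    int main() {
        std::vector<float> src = {0,1,2, 10,11,12, 20,21,22};  // 3 rows of 3
        std::vector<int>   ids = {2, 0};
        std::vector<float> dst;
        get_rows_f32(src, 3, ids, dst);
        for (float v : dst) printf("%g ", v);  // 20 21 22 0 1 2
        printf("\n");
        return 0;
    }
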
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp
deleted file mode 100644
index 32b2e891e8f..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp
+++ /dev/null
@@ -1,38 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define NL 2
-#define BYTES_FOR_TYPE 4 /*bytes for float*/
-#define SIZE_OF_BLOCK sizeof_block_q4_0
-
-layout(local_size_x = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { int inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int nb01;
-    int nb1;
-} pcs;
-
-block_q4_0 get_unaligned_block_q4_0(uint index) {
-    block_q4_0 fres;
-    fres.d = u8BufToFloat16(inA, index);
-    [[unroll]] for (uint it = 0; it != QK4_0 / 2; it++) {
-        fres.qs[it] = inA[index+2+it];
-    }
-    return fres;
-}
-
-mat4 dequantize_block(uint index, uint il) {
-    const block_q4_0 block = get_unaligned_block_q4_0(index);
-    return dequantize_q4_0(block, il);
-}
-
-#include "op_getrows.comp"
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp
deleted file mode 100644
index 87f2fbe17bb..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp
+++ /dev/null
@@ -1,39 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define NL 2
-#define BYTES_FOR_TYPE 4 /*bytes for float*/
-#define SIZE_OF_BLOCK sizeof_block_q4_1
-
-layout(local_size_x = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { int inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int nb01;
-    int nb1;
-} pcs;
-
-block_q4_1 get_unaligned_block_q4_1(uint index) {
-    block_q4_1 fres;
-    fres.d = u8BufToFloat16(inA, index);
-    fres.m = u8BufToFloat16(inA, index+2);
-    [[unroll]] for (uint it = 0; it != QK4_1 / 2; it++) {
-        fres.qs[it] = inA[index+4+it];
-    }
-    return fres;
-}
-
-mat4 dequantize_block(uint index, uint il) {
-    const block_q4_1 block = get_unaligned_block_q4_1(index);
-    return dequantize_q4_1(block, il);
-}
-
-#include "op_getrows.comp"
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp
deleted file mode 100644
index 9ce3545d1ec..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp
+++ /dev/null
@@ -1,44 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define NL 16
-#define BYTES_FOR_TYPE 4 /*bytes for float*/
-#define SIZE_OF_BLOCK sizeof_block_q6_k
-
-layout(local_size_x = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { int inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int nb01;
-    int nb1;
-} pcs;
-
-block_q6_k get_unaligned_block_q6_k(uint index) {
-    block_q6_k fres;
-    [[unroll]] for (uint it = 0; it != QK_K / 2; it++) {
-        fres.ql[it] = inA[index + it];
-    }
-    [[unroll]] for (uint it = 0; it != QK_K / 4; it++) {
-        fres.qh[it] = inA[index + QK_K/2 + it];
-    }
-    [[unroll]] for (uint it = 0; it != QK_K / 16; it++) {
-        fres.scales[it] = int8_t(inA[index + QK_K/2 + QK_K/4 + it]);
-    }
-    fres.d = u8BufToFloat16(inA, index + QK_K/2 + QK_K/4 + QK_K/16);
-    return fres;
-}
-
-mat4 dequantize_block(uint index, uint il) {
-    const block_q6_k block = get_unaligned_block_q6_k(index);
-    return dequantize_q6_k(block, il);
-}
-
-#include "op_getrows.comp"
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp
deleted file mode 100644
index c92647c4db1..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp
+++ /dev/null
@@ -1,52 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1024) in;
-
-layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
-layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
-layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int nb00;
-    int nb01;
-    int nb02;
-    int nb03;
-    int ne10;
-    int ne11;
-    int ne12;
-    int ne13;
-    int nb10;
-    int nb11;
-    int nb12;
-    int nb13;
-    int ne0;
-    int nb0;
-    int nb1;
-    int nb2;
-    int nb3;
-} pcs;
-
-void main() {
-    const uint i03 = gl_WorkGroupID.z;
-    const uint i02 = gl_WorkGroupID.y;
-    const uint i01 = gl_WorkGroupID.x;
-
-    const uint i13 = i03 % pcs.ne13;
-    const uint i12 = i02 % pcs.ne12;
-    const uint i11 = i01 % pcs.ne11;
-
-    uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01) / 4);
-    uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11) / 4);
-    uint dst_off  = uint((i03*pcs.nb3  + i02*pcs.nb2  + i01*pcs.nb1)  / 4);
-
-    for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) {
-        const uint i10 = i0 % pcs.ne10;
-        out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] * inB[pcs.inBOff + src1_off + i10];
-    }
-}
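
The broadcasting rule of the deleted op_mul kernel is that every src1 index is taken modulo its extent (i13 % ne13, i12 % ne12, i11 % ne11, i10 % ne10), so a smaller src1 repeats across src0. A minimal 2-D CPU sketch of the same rule, with the real kernel's byte strides and offsets omitted:

    #include <cstdio>
    #include <vector>

    // elementwise multiply of a (ne01 x ne00) tensor by a (ne11 x ne10) tensor
    // with ggml-style modulo broadcasting, row-major storage
    static void mul_broadcast(const std::vector<float> & a, int ne01, int ne00,
                              const std::vector<float> & b, int ne11, int ne10,
                              std::vector<float> & out) {
        out.resize((size_t)ne01 * ne00);
        for (int i1 = 0; i1 < ne01; ++i1) {
            const int j1 = i1 % ne11;               // broadcast along rows
            for (int i0 = 0; i0 < ne00; ++i0) {
                const int j0 = i0 % ne10;           // broadcast along columns
                out[(size_t)i1*ne00 + i0] = a[(size_t)i1*ne00 + i0] * b[(size_t)j1*ne10 + j0];
            }
        }
    }

    int main() {
        std::vector<float> a = {1,2,3, 4,5,6};   // 2 x 3
        std::vector<float> b = {10, 20, 30};     // 1 x 3, broadcast over rows
        std::vector<float> out;
        mul_broadcast(a, 2, 3, b, 1, 3, out);
        for (float v : out) printf("%g ", v);    // 10 40 90 40 100 180
        printf("\n");
        return 0;
    }
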
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp
deleted file mode 100644
index 0ab1b2fc20e..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp
+++ /dev/null
@@ -1,69 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#extension GL_KHR_shader_subgroup_arithmetic : require
-
-layout(local_size_x_id = 0) in;
-
-layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { float inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int ne01;
-    int ne02;
-    uint nb00;
-    uint nb01;
-    uint nb02;
-    uint nb03;
-    int ne10;
-    int ne11;
-    int ne12;
-    uint nb10;
-    uint nb11;
-    uint nb12;
-    uint nb13;
-    int ne0;
-    int ne1;
-    uint r2;
-    uint r3;
-} pcs;
-
-#define N_F16_F32 4
-
-void main() {
-    const uint r0 = gl_WorkGroupID.x;
-    const uint rb = gl_WorkGroupID.y*N_F16_F32;
-    const uint im = gl_WorkGroupID.z;
-
-    const uint i12 = im%pcs.ne12;
-    const uint i13 = im/pcs.ne12;
-
-    const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb03;
-
-    const uint x = offset0 / 2 + pcs.inAOff; // Based from inA
-
-    for (uint row = 0; row < N_F16_F32; ++row) {
-        uint r1 = rb + row;
-        if (r1 >= pcs.ne11) {
-            break;
-        }
-
-        const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
-
-        float sumf = 0;
-        for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
-            sumf += float(inA[x+i]) * float(inB[y+i]);
-        }
-
-        const float all_sum = subgroupAdd(sumf);
-        if (subgroupElect()) {
-            out_[im*pcs.ne1*pcs.ne0 + r1*pcs.ne0 + r0 + pcs.outOff] = all_sum;
-        }
-    }
-}
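
Per output element, the deleted op_mul_mat_f16 kernel is a dot product between one f16 row of A and one f32 row of B: each subgroup lane accumulates a strided partial sum and subgroupAdd() combines the lanes. The serial sketch below reproduces that decomposition (lane count is illustrative) to show it equals the plain dot product; f16 storage and batching are omitted.

    #include <cstdio>
    #include <vector>

    int main() {
        const int SUBGROUP_SIZE = 8;                 // illustrative lane count
        std::vector<float> a = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
        std::vector<float> b = {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6};

        float partial[SUBGROUP_SIZE] = {0};
        for (int lane = 0; lane < SUBGROUP_SIZE; ++lane) {
            for (size_t i = lane; i < a.size(); i += SUBGROUP_SIZE) {
                partial[lane] += a[i] * b[i];        // strided accumulation per lane
            }
        }
        float all_sum = 0.0f;                        // stands in for subgroupAdd()
        for (int lane = 0; lane < SUBGROUP_SIZE; ++lane) all_sum += partial[lane];

        float ref = 0.0f;                            // plain dot product for comparison
        for (size_t i = 0; i < a.size(); ++i) ref += a[i] * b[i];

        printf("subgroup-style sum = %g, reference = %g\n", all_sum, ref);
        return 0;
    }
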
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp
deleted file mode 100644
index d1ca4ad6c25..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp
+++ /dev/null
@@ -1,51 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#extension GL_KHR_shader_subgroup_arithmetic : require
-#extension GL_EXT_debug_printf : enable
-
-// device subgroup size
-layout (local_size_x_id = 0) in;
-
-layout(binding = 0) readonly buffer tensorInA { float inA[]; };
-layout(binding = 1) readonly buffer tensorInB { float inB[]; };
-layout(binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout(push_constant) uniform parameter {
-  uint inAOff;
-  uint inBOff;
-  uint outOff;
-  int ne00;
-  int ne01;
-  int ne02;
-  int ne11;
-  int ne12;
-  uint nb01;
-  uint nb02;
-  uint nb11;
-  uint nb12;
-  uint nb1;
-  uint nb2;
-}
-pcs;
-
-
-void main() {
-  uvec3 gid = gl_WorkGroupID;
-
-  uint bc_ab = pcs.ne12 > pcs.ne02 ? gid.z / (pcs.ne12 / pcs.ne02) : gid.z;
-  uint bc_ba = pcs.ne02 > pcs.ne12 ? gid.z / (pcs.ne02 / pcs.ne12) : gid.z;
-
-  const uint x = (gid.x*pcs.nb01 + bc_ab*pcs.nb02) / 4 + pcs.inAOff; // Based from inA
-  const uint y = (gid.y*pcs.nb11 + bc_ba*pcs.nb12) / 4 + pcs.inBOff; // based from inB
-  float sum = 0.0f;
-  for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
-      sum += float(inA[x+i]) * float(inB[y+i]);
-  }
-
-  const float all_sum = subgroupAdd(sum);
-  if (subgroupElect()) {
-    out_[gid.z*(pcs.nb2/4) + gid.y*(pcs.nb1/4) + gid.x + pcs.outOff] = all_sum;
-  }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp
deleted file mode 100644
index b0cea8bbe67..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp
+++ /dev/null
@@ -1,33 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define BLOCKS_IN_QUANT QK4_0
-#define SIZE_OF_BLOCK sizeof_block_q4_0
-#define N_ROWS 4
-
-#include "op_mul_mv_q_n_pre.comp"
-
-// The q4_0 version of this function
-float block_q_n_dot_y(uint block_index, uint yb, uint il) {
-    vec2 acc = vec2(0.0, 0.0);
-    const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
-    float d = float(u8BufToFloat16(inA, index));
-    float sumy = 0.0f;
-    for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
-        const uint16_t b = u8BufToU16(inA, index + 2 + il + i);
-
-        const float yl0 = inB[yb + i];
-        const float yl1 = inB[yb + i + 1];
-        const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
-        const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];
-
-        sumy += yl0 + yl1 + yl8 + yl9;
-
-        acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
-        acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
-    }
-    return d * (sumy * -8.f + acc[0] + acc[1]);
-}
-
-#include "op_mul_mv_q_n.comp"
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp
deleted file mode 100644
index 8582c61a3be..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp
+++ /dev/null
@@ -1,35 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define BLOCKS_IN_QUANT QK4_1
-#define SIZE_OF_BLOCK sizeof_block_q4_1
-#define N_ROWS 4
-
-#include "op_mul_mv_q_n_pre.comp"
-
-// The q4_1 version of this function
-float block_q_n_dot_y(uint block_index, uint yb, uint il) {
-    vec2 acc = vec2(0.0, 0.0);
-    const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
-    float d = float(u8BufToFloat16(inA, index));
-    float m = float(u8BufToFloat16(inA, index+2));
-
-    float sumy = 0.0f;
-    for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
-        const uint16_t b = u8BufToU16(inA, index + 4 + il + i);
-
-        const float yl0 = inB[yb + i];
-        const float yl1 = inB[yb + i + 1];
-        const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
-        const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];
-
-        sumy += yl0 + yl1 + yl8 + yl9;
-
-        acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
-        acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
-    }
-    return d * (acc[0] + acc[1]) + sumy * m;
-}
-
-#include "op_mul_mv_q_n.comp"
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp
deleted file mode 100644
index a5752a3a006..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp
+++ /dev/null
@@ -1,140 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define N_DST 4
-#define SIZE_OF_BLOCK sizeof_block_q4_k
-
-layout(local_size_x = 4) in;
-layout(local_size_y = 8) in;
-layout(local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; };
-layout (binding = 1) readonly buffer tensorInB { float inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int ne10;
-    int ne0;
-    int ne1;
-    int ne01;
-    int ne02;
-    int ne12;
-    uint nb01;
-    uint nb02;
-    uint nb03;
-    uint nb11;
-    uint nb12;
-    uint nb13;
-    uint r2;
-    uint r3;
-} pcs;
-
-void main() {
-    const uint16_t kmask1 = uint16_t(0x3f3f);
-    const uint16_t kmask2 = uint16_t(0x0f0f);
-    const uint16_t kmask3 = uint16_t(0xc0c0);
-
-    const uint ix = gl_SubgroupInvocationID/8;  // 0...3
-    const uint it = gl_SubgroupInvocationID%8;  // 0...7
-    const uint iq = it/4;     // 0 or 1
-    const uint ir = it%4;     // 0...3
-
-    const uint nb = pcs.ne00/QK_K;
-
-    const uint r0 = gl_WorkGroupID.x;
-    const uint r1 = gl_WorkGroupID.y;
-    const uint im = gl_WorkGroupID.z;
-
-    const uint first_row = r0 * N_DST;
-    const uint ib_row = first_row * nb;
-
-    const uint i12 = im%pcs.ne12;
-    const uint i13 = im/pcs.ne12;
-
-    const uint offset0 = first_row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
-    const uint offset1 =        r1*pcs.nb11 + (i12       )*pcs.nb12 + (i13       )*pcs.nb13;
-
-    const uint xblk = offset0 + pcs.inAOff;
-    const uint y = (offset1 / 4) + pcs.inBOff;
-
-    float yl[16];
-    float yh[16];
-    float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f};
-    float all_sum = 0.f;
-
-    uint y4 = y + ix * QK_K + 64 * iq + 8 * ir;
-
-    for (uint ib = ix; ib < nb; ib += 4) {
-        const uint blk_idx = ib + xblk;
-
-        float sumy[4] = {0.f, 0.f, 0.f, 0.f};
-        for (int i = 0; i < 8; ++i) {
-            yl[i+0] = inB[y4+i+  0]; sumy[0] += yl[i+0];
-            yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8];
-            yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0];
-            yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8];
-        }
-
-        for (int row = 0; row < N_DST; row++) {
-            uint row_idx = row * (pcs.nb01 / SIZE_OF_BLOCK);
-
-            uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
-            uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);
-            uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4);
-            uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6);
-            uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8);
-
-            uint16_t sc16[4];
-            sc16[0] = sc_0 & kmask1;
-            sc16[1] = sc_2 & kmask1;
-            sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2);
-            sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2);
-
-            float acc1[4] = {0.f, 0.f, 0.f, 0.f};
-            float acc2[4] = {0.f, 0.f, 0.f, 0.f};
-            for (int i = 0; i < 8; i += 2) {
-                uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i);
-                uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i);
-                acc1[0] += yl[i+0] * (q1 & 0x000F);
-                acc1[1] += yl[i+1] * (q1 & 0x0F00);
-                acc1[2] += yl[i+8] * (q1 & 0x00F0);
-                acc1[3] += yl[i+9] * (q1 & 0xF000);
-                acc2[0] += yh[i+0] * (q2 & 0x000F);
-                acc2[1] += yh[i+1] * (q2 & 0x0F00);
-                acc2[2] += yh[i+8] * (q2 & 0x00F0);
-                acc2[3] += yh[i+9] * (q2 & 0xF000);
-            }
-
-            uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF);
-            uint8_t sc8_1 = uint8_t(sc16[0] >> 8 );
-            uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF);
-            uint8_t sc8_3 = uint8_t(sc16[1] >> 8 );
-            uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF);
-            uint8_t sc8_5 = uint8_t(sc16[2] >> 8 );
-            uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF);
-            uint8_t sc8_7 = uint8_t(sc16[3] >> 8 );
-
-            float dall = float(inA[blk_idx + row_idx].d);
-            float dmin = float(inA[blk_idx + row_idx].dmin);
-            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 +
-                               (acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f +
-                               (acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 +
-                               (acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) -
-                dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7);
-        }
-
-        y4 += 4 * QK_K;
-    }
-
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = subgroupAdd(sumf[row]);
-        if (subgroupElect()) {
-            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum;
-        }
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp
deleted file mode 100644
index d331d1a7057..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp
+++ /dev/null
@@ -1,106 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#define SIZE_OF_BLOCK sizeof_block_q6_k
-
-layout(local_size_x_id = 0) in;
-layout(local_size_y_id = 1) in;
-layout(local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { float inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int ne10;
-    int ne0;
-    int ne1;
-    int ne01;
-    int ne02;
-    int ne12;
-    uint nb01;
-    uint nb02;
-    uint nb03;
-    uint nb11;
-    uint nb12;
-    uint nb13;
-    uint r2;
-    uint r3;
-} pcs;
-
-void main() {
-    const uint8_t kmask1 = uint8_t(0x03);
-    const uint8_t kmask2 = uint8_t(0x0C);
-    const uint8_t kmask3 = uint8_t(0x30);
-    const uint8_t kmask4 = uint8_t(0xC0);
-
-    const uint nb = pcs.ne00/QK_K;
-
-    const uint r0 = gl_WorkGroupID.x;
-    const uint r1 = gl_WorkGroupID.y;
-    const uint im = gl_WorkGroupID.z;
-
-    const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);
-
-    const uint i12 = im%pcs.ne12;
-    const uint i13 = im/pcs.ne12;
-
-    const uint x = row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
-    const uint yy = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
-
-    float sumf = 0;
-
-    // bits of invocation ID for gl_SubgroupSize=32:
-    //  x   x   x   x   x
-    //  4   3   2   1   0
-    // (     tid     ) ix
-    //  ip (   il    )
-
-    const uint block_stride = gl_SubgroupSize / 16;         // number of blocks each subgroup processes
-    const uint tid  = gl_SubgroupInvocationID/block_stride; // first block_stride groups have tid=0
-    const uint ix   = gl_SubgroupInvocationID%block_stride; // first block is 0..block_stride-1
-    const uint ip   = tid/8;        // first or second half of block (0 or 1)
-    const uint il   = tid%8;        // each half has 8 parts, one per scale
-    const uint n    = 4;            // 4 scales at a time (and 4 sums)
-    const uint l0   = n*il;         // offset into half-block, 0..28
-    const uint is   = 8*ip + l0/16; // 0, 1, 8, 9
-
-    const uint y_offset = 128*ip + l0;
-    const uint q_offset_l = 64*ip + l0;
-    const uint q_offset_h = 32*ip + l0;
-
-    for (uint i = ix; i < nb; i += block_stride) {
-
-        const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff;
-
-        const uint qlIndex = q_offset_l;
-        const uint q2Index = qlIndex + QK_K/8;
-        const uint qhIndex = q_offset_h;
-        const uint y = yy + i * QK_K + y_offset;
-
-        float sums[4] = {0.0f, 0.0f, 0.0f, 0.0f};
-        for (uint l = 0; l < n; ++l) {
-            const uint8_t currentQ1 = inA[baseIndex + qlIndex + l];
-            const uint8_t currentQ2 = inA[baseIndex + q2Index + l];
-            const uint8_t currentQh = inA[baseIndex + QK_K/2 + qhIndex + l];
-
-            sums[0] += inB[y+l+ 0] * (int8_t((currentQ1 & 0xF) | ((currentQh & kmask1) << 4)) - 32);
-            sums[1] += inB[y+l+32] * (int8_t((currentQ2 & 0xF) | ((currentQh & kmask2) << 2)) - 32);
-            sums[2] += inB[y+l+64] * (int8_t((currentQ1  >> 4) | ((currentQh & kmask3) << 0)) - 32);
-            sums[3] += inB[y+l+96] * (int8_t((currentQ2  >> 4) | ((currentQh & kmask4) >> 2)) - 32);
-        }
-
-        float d = u8BufToFloat16(inA, baseIndex + QK_K/2 + QK_K/4 + QK_K/16);
-        sumf += d * (sums[0] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + is]) + sums[1] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 2 + is]) + sums[2] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 4 + is]) + sums[3] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 6 + is]));
-    }
-
-    const float tot = subgroupAdd(sumf);
-    if (subgroupElect()) {
-        out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp
deleted file mode 100644
index 34d015e90b8..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp
+++ /dev/null
@@ -1,73 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-#include "op_mul_mv_q_n_pre.comp"
-
-#define SIZE_OF_D 2
-
-#define N_DST 4 // each SIMD group works on 4 rows
-#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
-#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
-
-#define NB_Q8_0 8
-
-void main() {
-    // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
-    if (gl_SubgroupInvocationID > 31)
-        return;
-
-    const int nr  = N_DST;
-    const int nsg = N_SIMDGROUP;
-    const int nw  = N_SIMDWIDTH;
-
-    const int nb = pcs.ne00/QK8_0;
-    const uint r0 = gl_WorkGroupID.x;
-    const uint r1 = gl_WorkGroupID.y;
-    const uint im = gl_WorkGroupID.z;
-
-    const uint first_row = (r0 * nsg + gl_SubgroupID) * nr;
-
-    const uint i12 = im%pcs.ne12;
-    const uint i13 = im/pcs.ne12;
-
-    const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
-
-    const uint x = offset0*sizeof_block_q8_0 + pcs.inAOff; // Based from inA
-    const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff; // based from inB
-
-    float yl[NB_Q8_0];
-    float sumf[N_DST]={0.f, 0.f, 0.f, 0.f};
-
-    const uint ix = gl_SubgroupInvocationID.x/4;
-    const uint il = gl_SubgroupInvocationID.x%4;
-
-    uint yb = y + ix * QK8_0 + NB_Q8_0*il;
-
-    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
-    for (uint ib = ix; ib < nb; ib += nw/4) {
-        for (int i = 0; i < NB_Q8_0; ++i) {
-            yl[i] = inB[yb + i];
-        }
-
-        for (int row = 0; row < nr; row++) {
-            const uint block_offset = (ib+row*nb) * sizeof_block_q8_0;
-            float sumq = 0.f;
-            for (int iq = 0; iq < NB_Q8_0; ++iq) {
-                const int8_t qs_iq = int8_t(inA[x + block_offset + SIZE_OF_D + NB_Q8_0*il + iq]);
-                sumq += qs_iq * yl[iq];
-            }
-            const float16_t d = u8BufToFloat16(inA, x + block_offset);
-            sumf[row] += sumq*d;
-        }
-
-        yb += NB_Q8_0 * nw;
-    }
-
-    for (int row = 0; row < nr; ++row) {
-        const float tot = subgroupAdd(sumf[row]);
-        if (subgroupElect() && first_row + row < pcs.ne01) {
-            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row] = tot;
-        }
-    }
-}
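
For the q8_0 path above, the weights are blocks of 32 signed 8-bit quants with one fp16 scale d, and each output element is the sum over blocks of d * sum(q_i * y_i). A minimal CPU sketch with a plain float scale standing in for the packed fp16 layout:

    #include <cstdio>
    #include <vector>

    struct block_q8_0_ref {          // simplified stand-in for the real block layout
        float  d;                    // per-block scale (fp16 on device)
        signed char qs[32];          // quantized values
    };

    static float dot_q8_0(const std::vector<block_q8_0_ref> & blocks, const std::vector<float> & y) {
        float sum = 0.0f;
        for (size_t b = 0; b < blocks.size(); ++b) {
            float acc = 0.0f;
            for (int i = 0; i < 32; ++i) {
                acc += (float)blocks[b].qs[i] * y[b*32 + i];
            }
            sum += blocks[b].d * acc;
        }
        return sum;
    }

    int main() {
        block_q8_0_ref blk = {0.5f, {}};
        std::vector<float> y(32, 1.0f);
        for (int i = 0; i < 32; ++i) blk.qs[i] = (signed char)(i - 16);
        printf("dot = %g\n", dot_q8_0({blk}, y));   // 0.5 * (496 - 512) = -8
        return 0;
    }
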
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp
deleted file mode 100644
index a6517cc1f19..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp
+++ /dev/null
@@ -1,52 +0,0 @@
-void main() {
-    // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
-    if (gl_SubgroupInvocationID > 31)
-        return;
-
-    const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT);
-
-    const uint r0 = gl_WorkGroupID.x;
-    const uint r1 = gl_WorkGroupID.y;
-    const uint im = gl_WorkGroupID.z;
-
-    const uint first_row = (r0 * gl_NumSubgroups + gl_SubgroupID) * N_ROWS;
-
-    const uint i12 = im%pcs.ne12;
-    const uint i13 = im/pcs.ne12;
-
-    // pointers to src0 rows
-    uint ax[N_ROWS];
-    for (int row = 0; row < N_ROWS; ++row) {
-        const uint offset0 = (first_row + row)*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
-
-        ax[row] = offset0 + pcs.inAOff;
-    }
-
-    const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
-
-    float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f};
-
-    const uint ix = gl_SubgroupInvocationID/2;
-    const uint il = (BLOCKS_IN_QUANT/4)*(gl_SubgroupInvocationID%2);
-
-    uint yb = y + ix * BLOCKS_IN_QUANT + il;
-
-    //debugPrintfEXT("gl_NumSubgroups=%d, gl_SubgroupID=%d, gl_SubgroupInvocationID=%d, glSubgroupSize=%d, gl_WorkGroupSize.x=%d, gl_WorkGroupSize.y=%d, gl_WorkGroupSize.z=%d\n",
-    //    gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize,
-    //    gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z);
-
-    for (uint ib = ix; ib < nb; ib += 16) {
-        for (int row = 0; row < N_ROWS; row++) {
-            sumf[row] += block_q_n_dot_y(ax[row] + ib, yb, il);
-        }
-
-        yb += BLOCKS_IN_QUANT * 16;
-    }
-
-    for (int row = 0; row < N_ROWS; ++row) {
-        const float tot = subgroupAdd(sumf[row]);
-        if (first_row + row < pcs.ne01 && subgroupElect()) {
-            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = tot;
-        }
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp
deleted file mode 100644
index a9a2f22180f..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp
+++ /dev/null
@@ -1,28 +0,0 @@
-layout(local_size_x_id = 0) in;
-layout(local_size_y = 8) in;
-layout(local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
-layout (binding = 1) readonly buffer tensorInB { float inB[]; };
-layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int  ne00;
-    int  ne01;
-    int  ne02;
-    int  ne10;
-    int  ne12;
-    int  ne0;
-    int  ne1;
-    uint nb01;
-    uint nb02;
-    uint nb03;
-    uint nb11;
-    uint nb12;
-    uint nb13;
-    uint r2;
-    uint r3;
-} pcs;
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp b/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp
deleted file mode 100644
index ad0c3c01b9d..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp
+++ /dev/null
@@ -1,84 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 256) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inOff;
-    uint outOff;
-    uint ne00;
-    uint nb01;
-    float eps;
-} pcs;
-
-shared float sum[gl_WorkGroupSize.x];
-
-void main() {
-    const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_
-    // MEAN
-    // parallel sum
-    sum[gl_LocalInvocationID.x] = 0.0;
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        sum[gl_LocalInvocationID.x] += in_[x+i00];
-    }
-
-    // reduce
-    barrier();
-    memoryBarrierShared();
-    [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
-        if (gl_LocalInvocationID.x < i) {
-            sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
-        }
-        barrier();
-        memoryBarrierShared();
-    }
-
-    // broadcast
-    if (gl_LocalInvocationID.x == 0) {
-        sum[0] /= float(pcs.ne00);
-    }
-    barrier();
-    memoryBarrierShared();
-    const float mean = sum[0];
-
-    // recenter
-    const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        out_[y+i00] = in_[x+i00] - mean;
-    }
-
-    // VARIANCE
-    // parallel sum
-    sum[gl_LocalInvocationID.x] = 0.0;
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        sum[gl_LocalInvocationID.x] += out_[y+i00] * out_[y+i00];
-    }
-
-    // reduce
-    barrier();
-    memoryBarrierShared();
-    [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
-        if (gl_LocalInvocationID.x < i) {
-            sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
-        }
-        barrier();
-        memoryBarrierShared();
-    }
-
-    // broadcast
-    if (gl_LocalInvocationID.x == 0) {
-        sum[0] /= float(pcs.ne00);
-    }
-    barrier();
-    memoryBarrierShared();
-    const float variance = sum[0];
-
-    const float scale = 1.0f/sqrt(variance + pcs.eps);
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        out_[y+i00] *= scale;
-    }
-}
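
The deleted op_norm kernel is a standard two-pass layer norm over each row: compute the mean, recenter, compute the variance of the centered values, then scale by 1/sqrt(variance + eps). A minimal single-row CPU sketch (the shared-memory parallel reduction is replaced by plain loops):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    static void layer_norm(std::vector<float> & x, float eps) {
        float mean = 0.0f;
        for (float v : x) mean += v;
        mean /= x.size();
        float var = 0.0f;
        for (float & v : x) { v -= mean; var += v*v; }   // recenter, accumulate variance
        var /= x.size();
        const float scale = 1.0f / std::sqrt(var + eps);
        for (float & v : x) v *= scale;
    }

    int main() {
        std::vector<float> row = {1, 2, 3, 4};
        layer_norm(row, 1e-5f);
        for (float v : row) printf("% .4f ", v);   // zero mean, unit variance (up to eps)
        printf("\n");
        return 0;
    }
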
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp b/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp
deleted file mode 100644
index 52a601fe6da..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp
+++ /dev/null
@@ -1,21 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
-layout(push_constant) uniform PushConstants {
-    uint inOff;
-    uint outOff;
-} pcs;
-
-void main() {
-    const uint baseIndex = gl_WorkGroupID.x * 4;
-
-    for (uint x = 0; x < 4; x++) {
-        const uint i = baseIndex + x;
-        out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]);
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp
deleted file mode 100644
index da658c1601e..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp
+++ /dev/null
@@ -1,53 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 512) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inOff;
-    uint outOff;
-    uint ne00;
-    uint nb01;
-    float eps;
-} pcs;
-
-shared float sum[gl_WorkGroupSize.x];
-
-void main() {
-    const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_
-
-    // parallel sum
-    sum[gl_LocalInvocationID.x] = 0.0;
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        sum[gl_LocalInvocationID.x] += in_[x+i00] * in_[x+i00];
-    }
-
-    // reduce
-    barrier();
-    memoryBarrierShared();
-    [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
-        if (gl_LocalInvocationID.x < i) {
-            sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
-        }
-        barrier();
-        memoryBarrierShared();
-    }
-
-    // broadcast
-    if (gl_LocalInvocationID.x == 0) {
-        sum[0] /= float(pcs.ne00);
-    }
-    barrier();
-    memoryBarrierShared();
-
-    const float scale = 1.0f/sqrt(sum[0] + pcs.eps);
-
-    const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_
-    for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
-        out_[y+i00] = in_[x+i00] * scale;
-    }
-}
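
op_rmsnorm differs from op_norm only in skipping the mean subtraction: each row is scaled by 1/sqrt(mean(x^2) + eps). A minimal single-row CPU sketch:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    static void rms_norm(std::vector<float> & x, float eps) {
        float ss = 0.0f;
        for (float v : x) ss += v*v;                       // sum of squares
        const float scale = 1.0f / std::sqrt(ss / x.size() + eps);
        for (float & v : x) v *= scale;
    }

    int main() {
        std::vector<float> row = {1, 2, 3, 4};
        rms_norm(row, 1e-5f);
        for (float v : row) printf("% .4f ", v);
        printf("\n");
        return 0;
    }
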
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp
deleted file mode 100644
index 63659cbfe55..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp
+++ /dev/null
@@ -1,52 +0,0 @@
-#version 450
-
-#include "rope_common.comp"
-
-layout(binding = 0) buffer restrict readonly  tensorInA { float16_t inA[]; };
-layout(binding = 1) buffer restrict readonly  tensorInB { int       inB[]; };
-layout(binding = 2) buffer restrict readonly  tensorInC { float     inC[]; };
-layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; };
-
-void main() {
-    const uint i3 = gl_WorkGroupID.z;
-    const uint i2 = gl_WorkGroupID.y;
-    const uint i1 = gl_WorkGroupID.x;
-
-    float corr_dims[2];
-    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
-
-    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
-
-    float theta_base = float(inB[pcs.inBOff + i2]);
-    float inv_ndims = -1.f/pcs.n_dims;
-
-    float cos_theta;
-    float sin_theta;
-
-    for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
-        if (i0 < pcs.n_dims) {
-            uint ic = i0/2;
-
-            float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
-
-            const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
-
-            rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
-
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 2) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + ic*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
-
-            const float x0 = float(inA[src]);
-            const float x1 = float(inA[src+pcs.n_dims/2]);
-
-            out_[dst_data]              = float16_t(x0*cos_theta - x1*sin_theta);
-            out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
-        } else {
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
-
-            out_[dst_data]   = inA[src];
-            out_[dst_data+1] = inA[src+1];
-        }
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp
deleted file mode 100644
index 4df56204d72..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp
+++ /dev/null
@@ -1,52 +0,0 @@
-#version 450
-
-#include "rope_common.comp"
-
-layout(binding = 0) buffer restrict readonly  tensorInA { float inA[]; };
-layout(binding = 1) buffer restrict readonly  tensorInB { int       inB[]; };
-layout(binding = 2) buffer restrict readonly  tensorInC { float inC[]; };
-layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; };
-
-void main() {
-    const uint i3 = gl_WorkGroupID.z;
-    const uint i2 = gl_WorkGroupID.y;
-    const uint i1 = gl_WorkGroupID.x;
-
-    float corr_dims[2];
-    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
-
-    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
-
-    float theta_base = float(inB[pcs.inBOff + i2]);
-    float inv_ndims = -1.f/pcs.n_dims;
-
-    float cos_theta;
-    float sin_theta;
-
-    for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
-        if (i0 < pcs.n_dims) {
-            uint ic = i0/2;
-
-            float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
-
-            const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
-
-            rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
-
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 4) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + ic*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
-
-            const float x0 = inA[src];
-            const float x1 = inA[src+pcs.n_dims/2];
-
-            out_[dst_data]              = x0*cos_theta - x1*sin_theta;
-            out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
-        } else {
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
-
-            out_[dst_data]   = inA[src];
-            out_[dst_data+1] = inA[src+1];
-        }
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp
deleted file mode 100644
index a3c0eda8bd3..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp
+++ /dev/null
@@ -1,52 +0,0 @@
-#version 450
-
-#include "rope_common.comp"
-
-layout(binding = 0) buffer restrict readonly  tensorInA { float16_t inA[]; };
-layout(binding = 1) buffer restrict readonly  tensorInB { int       inB[]; };
-layout(binding = 2) buffer restrict readonly  tensorInC { float     inC[]; };
-layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; };
-
-void main() {
-    const uint i3 = gl_WorkGroupID.z;
-    const uint i2 = gl_WorkGroupID.y;
-    const uint i1 = gl_WorkGroupID.x;
-
-    float corr_dims[2];
-    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
-
-    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
-
-    float theta_base = float(inB[pcs.inBOff + i2]);
-    float inv_ndims = -1.f/pcs.n_dims;
-
-    float cos_theta;
-    float sin_theta;
-
-    for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
-        if (i0 < pcs.n_dims) {
-            uint ic = i0/2;
-
-            float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
-
-            const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
-
-            rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
-
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
-
-            const float x0 = float(inA[src]);
-            const float x1 = float(inA[src+1]);
-
-            out_[dst_data]   = float16_t(x0*cos_theta - x1*sin_theta);
-            out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta);
-        } else {
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
-
-            out_[dst_data]   = inA[src];
-            out_[dst_data+1] = inA[src+1];
-        }
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp
deleted file mode 100644
index b7963ae7253..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp
+++ /dev/null
@@ -1,52 +0,0 @@
-#version 450
-
-#include "rope_common.comp"
-
-layout(binding = 0) buffer restrict readonly  tensorInA { float inA[]; };
-layout(binding = 1) buffer restrict readonly  tensorInB { int   inB[]; };
-layout(binding = 2) buffer restrict readonly  tensorInC { float inC[]; };
-layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; };
-
-void main() {
-    const uint i3 = gl_WorkGroupID.z;
-    const uint i2 = gl_WorkGroupID.y;
-    const uint i1 = gl_WorkGroupID.x;
-
-    float corr_dims[2];
-    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
-
-    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
-
-    float theta_base = float(inB[pcs.inBOff + i2]);
-    float inv_ndims = -1.f/pcs.n_dims;
-
-    float cos_theta;
-    float sin_theta;
-
-    for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
-        if (i0 < pcs.n_dims) {
-            uint ic = i0/2;
-
-            float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
-
-            const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
-
-            rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
-
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
-
-            const float x0 = inA[src];
-            const float x1 = inA[src+1];
-
-            out_[dst_data]   = x0*cos_theta - x1*sin_theta;
-            out_[dst_data+1] = x0*sin_theta + x1*cos_theta;
-        } else {
-            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
-            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
-
-            out_[dst_data]   = inA[src];
-            out_[dst_data+1] = inA[src+1];
-        }
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp b/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp
deleted file mode 100644
index bdae2673820..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp
+++ /dev/null
@@ -1,19 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inOff;
-    uint outOff;
-    float scale;
-} pcs;
-
-void main() {
-    const uint i = gl_WorkGroupID.x;
-    out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp b/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp
deleted file mode 100644
index ada69754b2c..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp
+++ /dev/null
@@ -1,23 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inOff;
-    uint outOff;
-    float scale;
-} pcs;
-
-void main() {
-    const uint baseIndex = gl_WorkGroupID.x * 8;
-
-    for (uint x = 0; x < 8; x++) {
-        const uint i = baseIndex + x;
-        out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
-    }
-}
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp b/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp
deleted file mode 100644
index 0fb8e4b7405..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp
+++ /dev/null
@@ -1,22 +0,0 @@
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x = 1) in;
-
-layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
-layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
-layout(push_constant) uniform PushConstants {
-    uint inOff;
-    uint outOff;
-} pcs;
-
-void main() {
-    const uint baseIndex = gl_WorkGroupID.x * 4;
-
-    for (uint x = 0; x < 4; x++) {
-        const uint i = baseIndex + x;
-        const float y = in_[i + pcs.inOff];
-        out_[i + pcs.outOff] = y / (1.0 + exp(-y));
-    }
-}
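
The deleted SiLU kernel applies y = x * sigmoid(x) = x / (1 + exp(-x)) to four elements per workgroup; element-wise it is just:

    #include <cmath>
    #include <cstdio>

    // SiLU (sigmoid-weighted linear unit), the same expression as the deleted kernel
    static float silu(float x) {
        return x / (1.0f + std::exp(-x));
    }

    int main() {
        for (float x : {-2.0f, 0.0f, 2.0f}) {
            printf("silu(% .1f) = % .6f\n", x, silu(x));
        }
        return 0;
    }
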
diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp b/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp
deleted file mode 100644
index 4165295bf4b..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp
+++ /dev/null
@@ -1,72 +0,0 @@
-// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4)
-
-#version 450
-
-#include "common.comp"
-
-layout(local_size_x_id = 0) in;
-
-layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
-layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
-layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
-
-layout(push_constant) uniform PushConstants {
-    uint inAOff;
-    uint inBOff;
-    uint outOff;
-    int ne00;
-    int ne01;
-    int ne02;
-    float scale;
-    float max_bias;
-    float m0;
-    float m1;
-    uint n_head_log2;
-    int mask;
-} pcs;
-
-void main() {
-    if (gl_SubgroupInvocationID > 31)
-        return;
-
-    const uint i03 = gl_WorkGroupID.z;
-    const uint i02 = gl_WorkGroupID.y;
-    const uint i01 = gl_WorkGroupID.x;
-
-    const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00;
-    const uint psrc0 = extra_off + pcs.inAOff; // Based from inA
-    const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
-    const uint pdst = extra_off + pcs.outOff; // Based from out_
-
-    float slope = 1.0f;
-
-    // ALiBi
-    if (pcs.max_bias > 0.0f) {
-        int64_t h = i02;
-
-        float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1;
-        int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1;
-
-        slope = pow(base, float(exp));
-    }
-
-    // parallel max
-    float localMax = uintBitsToFloat(0xFF800000);
-    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
-        localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f));
-    }
-    float max_ = subgroupMax(localMax);
-
-    // parallel sum
-    float localSum = 0.0f;
-    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
-        const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_);
-        localSum += exp_psrc0;
-        out_[pdst + i00] = exp_psrc0;
-    }
-
-    const float sum = subgroupAdd(localSum);
-    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
-        out_[pdst + i00] /= sum;
-    }
-}
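
The deleted softmax kernel performs, per row: scale the logits, optionally add the mask weighted by an ALiBi slope, subtract the row max for numerical stability, exponentiate, and normalize by the sum. A minimal CPU sketch; the slope helper mirrors the branch in the kernel, and the parameter values in main() are purely illustrative:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // ALiBi slope for head h, mirroring the branch in the deleted kernel
    static float alibi_slope(int h, int n_head_log2, float m0, float m1) {
        const float base = h < n_head_log2 ? m0 : m1;
        const int   e    = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
        return std::pow(base, (float)e);
    }

    // numerically stable row softmax: scale, optional slope-weighted mask,
    // subtract the row max, exponentiate, normalize
    static void softmax_row(std::vector<float> & x, const float * mask, float scale, float slope) {
        float maxv = -INFINITY;
        for (size_t i = 0; i < x.size(); ++i) {
            x[i] = x[i]*scale + (mask ? slope*mask[i] : 0.0f);
            maxv = std::max(maxv, x[i]);
        }
        float sum = 0.0f;
        for (float & v : x) { v = std::exp(v - maxv); sum += v; }
        for (float & v : x) v /= sum;
    }

    int main() {
        std::vector<float> logits = {1.0f, 2.0f, 3.0f};
        softmax_row(logits, nullptr, 1.0f, 1.0f);
        for (float v : logits) printf("%.4f ", v);         // ~0.0900 0.2447 0.6652
        printf("\nslope(h=0) = %g\n", alibi_slope(0, 4, 0.5f, 0.25f));
        return 0;
    }
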
diff --git a/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp b/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp
deleted file mode 100644
index 0fca640dcc2..00000000000
--- a/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp
+++ /dev/null
@@ -1,71 +0,0 @@
-#include "common.comp"
-
-#define GGML_ROPE_TYPE_NEOX 2
-
-// TODO: use a local size of 32 or more (Metal uses 1024)
-layout(local_size_x = 1) in;
-
-layout (push_constant) uniform parameter {
-    uint inAOff;
-    uint inBOff;
-    uint inCOff;
-    uint outOff;
-    int n_dims;
-    int mode;
-    int n_ctx_orig;
-    float freq_base;
-    float freq_scale;
-    bool has_freq_factors;
-    float ext_factor;
-    float attn_factor;
-    float beta_fast;
-    float beta_slow;
-    uint nb00;
-    uint nb01;
-    uint nb02;
-    uint nb03;
-    int ne0;
-    uint nb0;
-    uint nb1;
-    uint nb2;
-    uint nb3;
-} pcs;
-
-float rope_yarn_ramp(const float low, const float high, const float i0) {
-    const float y = (i0 / 2 - low) / max(0.001f, high - low);
-    return 1.0f - min(1.0f, max(0.0f, y));
-}
-
-// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
-// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
-void rope_yarn(
-    float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
-    out float cos_theta, out float sin_theta
-) {
-    // Get n-d rotational scaling corrected for extrapolation
-    float theta_interp = freq_scale * theta_extrap;
-    float theta = theta_interp;
-    if (ext_factor != 0.0f) {
-        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
-        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-
-        // Get n-d magnitude scaling corrected for interpolation
-        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
-    }
-    cos_theta = cos(theta) * mscale;
-    sin_theta = sin(theta) * mscale;
-}
-
-// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
-// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
-float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
-    return n_dims * log(n_ctx_orig / (n_rot * TWOPI_F)) / (2 * log(base));
-}
-
-void rope_yarn_corr_dims(
-    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2]
-) {
-    // start and end correction dims
-    dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
-    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
-}
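
rope_common.comp carried the YaRN helpers shared by the four deleted RoPE kernels: the correction-dimension formula, the extrapolation/interpolation ramp, and the blend that yields the scaled cos/sin applied to each dimension pair. A minimal CPU transcription of those helpers, rotating a single pair; the parameter values in main() are illustrative and not taken from any model:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    static const float TWOPI_F = 6.283185307179586f;

    // ramp that blends interpolation and extrapolation across dimensions
    static float rope_yarn_ramp(float low, float high, float i0) {
        const float y = (i0/2 - low) / std::max(0.001f, high - low);
        return 1.0f - std::min(1.0f, std::max(0.0f, y));
    }

    // correction factor: solve n_rot = 2*pi * x * base^((2*max_pos)/n_dims) for x
    static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
        return n_dims * std::log(n_ctx_orig / (n_rot * TWOPI_F)) / (2.0f * std::log(base));
    }

    // produce the scaled cos/sin for one dimension pair, as in the deleted shader
    static void rope_yarn(float theta_extrap, float freq_scale, const float corr_dims[2],
                          float i0, float ext_factor, float mscale,
                          float & cos_theta, float & sin_theta) {
        const float theta_interp = freq_scale * theta_extrap;
        float theta = theta_interp;
        if (ext_factor != 0.0f) {
            const float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
            theta = theta_interp * (1.0f - ramp_mix) + theta_extrap * ramp_mix;
            mscale *= 1.0f + 0.1f * std::log(1.0f / freq_scale);   // magnitude correction
        }
        cos_theta = std::cos(theta) * mscale;
        sin_theta = std::sin(theta) * mscale;
    }

    int main() {
        const int   n_dims = 128, n_ctx_orig = 4096;
        const float freq_base = 10000.0f, beta_fast = 32.0f, beta_slow = 1.0f;

        const float corr_dims[2] = {
            std::max(0.0f,          std::floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))),
            std::min(n_dims - 1.0f, std::ceil (rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))),
        };

        // rotate the pair (x0, x1) at sequence position 42, dimension index i0 = 0
        float cos_t = 0.0f, sin_t = 0.0f;
        const float theta_base = 42.0f;                            // position id
        rope_yarn(theta_base, 1.0f, corr_dims, 0.0f, 0.0f, 1.0f, cos_t, sin_t);
        const float x0 = 1.0f, x1 = 0.0f;
        printf("rotated pair: (% .4f, % .4f)\n", x0*cos_t - x1*sin_t, x0*sin_t + x1*cos_t);
        return 0;
    }
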
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 752d55c2166..fc6526d6d5d 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -23,6 +23,9 @@
 #define N_R0_Q8_0 4
 #define N_SG_Q8_0 2
 
+#define N_R0_MXFP4 2
+#define N_SG_MXFP4 2
+
 #define N_R0_Q2_K 4
 #define N_SG_Q2_K 2
 
@@ -126,8 +129,18 @@ typedef struct {
     uint64_t nb2;
     uint64_t nb3;
     uint64_t offs;
+    uint64_t o1[8];
 } ggml_metal_kargs_bin;
 
+typedef struct {
+    int64_t ne0;
+    int64_t ne1;
+    size_t nb01;
+    size_t nb02;
+    size_t nb11;
+    size_t nb21;
+} ggml_metal_kargs_add_id;
+
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
@@ -240,7 +253,7 @@ typedef struct {
     float    max_bias;
     float    m0;
     float    m1;
-    uint16_t n_head_log2;
+    int32_t  n_head_log2;
     float    logit_softcap;
 } ggml_metal_kargs_flash_attn_ext;
 
@@ -377,8 +390,16 @@ typedef struct {
 typedef struct {
     int32_t  ne00;
     int32_t  ne00_4;
-    uint64_t nb01;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
     float    eps;
+    int32_t  nef1[3];
+    int32_t  nef2[3];
+    int32_t  nef3[3];
+    uint64_t nbf1[3];
+    uint64_t nbf2[3];
+    uint64_t nbf3[3];
 } ggml_metal_kargs_rms_norm;
 
 typedef struct {
@@ -435,6 +456,8 @@ typedef struct{
     uint64_t nb1;
     int32_t  i00;
     int32_t  i10;
+    float    alpha;
+    float    limit;
 } ggml_metal_kargs_glu;
 
 typedef struct {
@@ -484,7 +507,7 @@ typedef struct {
     float    max_bias;
     float    m0;
     float    m1;
-    uint32_t n_head_log2;
+    int32_t  n_head_log2;
 } ggml_metal_kargs_soft_max;
 
 typedef struct {
@@ -519,6 +542,7 @@ typedef struct {
     int64_t  n_group;
     int64_t  n_seq_tokens;
     int64_t  n_seqs;
+    int64_t  s_off;
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 83a0739809a..cb8eff4a772 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -55,6 +55,12 @@
     bool has_residency_sets;
     bool has_bfloat;
     bool use_bfloat;
+    bool use_fusion;
+
+    int debug_fusion;
+
+    // how many times a given op was fused
+    uint64_t fuse_cnt[GGML_OP_COUNT];
 
     size_t max_size;
 
@@ -69,6 +75,9 @@
     /*.has_residency_sets      =*/ false,
     /*.has_bfloat              =*/ false,
     /*.use_bfloat              =*/ false,
+    /*.use_fusion              =*/ true,
+    /*.debug_fusion            =*/ 0,
+    /*.fuse_cnt                =*/ { 0 },
     /*.max_size                =*/ 0,
     /*.name                    =*/ "",
 };
@@ -83,16 +92,14 @@
 
     if (ctx->mtl_device == nil) {
         ctx->mtl_device = MTLCreateSystemDefaultDevice();
-    }
 
-    if (ctx->mtl_device) {
         ctx->has_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
         ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
 
         ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
 
 #if defined(GGML_METAL_HAS_RESIDENCY_SETS)
-        ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL;
+        ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
 #endif
 
         ctx->has_bfloat  = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
@@ -103,6 +110,14 @@
 #else
         ctx->use_bfloat = false;
 #endif
+        ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
+
+        {
+            const char * val = getenv("GGML_METAL_FUSION_DEBUG");
+            ctx->debug_fusion = val ? atoi(val) : 0;
+        }
+
+        memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
 
         ctx->max_size = ctx->mtl_device.maxBufferLength;
 
@@ -122,6 +137,18 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     ctx->mtl_device_ref_count--;
 
     if (ctx->mtl_device_ref_count == 0) {
+        if (ctx->debug_fusion > 0) {
+            fprintf(stderr, "%s: fusion stats:\n", __func__);
+            for (int i = 0; i < GGML_OP_COUNT; i++) {
+                if (ctx->fuse_cnt[i] == 0) {
+                    continue;
+                }
+
+                // note: cannot use ggml_log here
+                fprintf(stderr, "%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
+            }
+        }
+
         if (ctx->mtl_lock) {
             [ctx->mtl_lock release];
             ctx->mtl_lock = nil;
@@ -147,13 +174,28 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
 
 enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ADD,
-    GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
+    GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
     GGML_METAL_KERNEL_TYPE_SUB,
-    GGML_METAL_KERNEL_TYPE_SUB_ROW,
+    GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
     GGML_METAL_KERNEL_TYPE_MUL,
-    GGML_METAL_KERNEL_TYPE_MUL_ROW,
+    GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
     GGML_METAL_KERNEL_TYPE_DIV,
-    GGML_METAL_KERNEL_TYPE_DIV_ROW,
+    GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
+    GGML_METAL_KERNEL_TYPE_ADD_ID,
     GGML_METAL_KERNEL_TYPE_REPEAT_F32,
     GGML_METAL_KERNEL_TYPE_REPEAT_F16,
     GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -173,6 +215,12 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_SILU,
     GGML_METAL_KERNEL_TYPE_SILU_4,
     GGML_METAL_KERNEL_TYPE_ELU,
+    GGML_METAL_KERNEL_TYPE_ABS,
+    GGML_METAL_KERNEL_TYPE_SGN,
+    GGML_METAL_KERNEL_TYPE_STEP,
+    GGML_METAL_KERNEL_TYPE_HARDSWISH,
+    GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
+    GGML_METAL_KERNEL_TYPE_EXP,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
@@ -187,6 +235,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
@@ -212,6 +261,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
     GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
+    GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
+    GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
     GGML_METAL_KERNEL_TYPE_L2_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
     GGML_METAL_KERNEL_TYPE_NORM,
@@ -237,6 +288,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@@ -261,6 +313,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
@@ -302,6 +358,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
@@ -324,6 +381,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
@@ -348,6 +406,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
@@ -530,6 +589,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_REGLU,
     GGML_METAL_KERNEL_TYPE_GEGLU,
     GGML_METAL_KERNEL_TYPE_SWIGLU,
+    GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,
     GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
     GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -1129,13 +1189,28 @@ @implementation GGMLMetalClass
         // simd_sum and simd_max requires MTLGPUFamilyApple7
 
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                             add,                             true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,                         add_row,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,                      add_fuse_2,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,                      add_fuse_3,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,                      add_fuse_4,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,                      add_fuse_5,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,                      add_fuse_6,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,                      add_fuse_7,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,                      add_fuse_8,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,                      add_row_c4,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,               add_row_c4_fuse_2,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,               add_row_c4_fuse_3,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,               add_row_c4_fuse_4,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,               add_row_c4_fuse_5,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,               add_row_c4_fuse_6,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,               add_row_c4_fuse_7,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,               add_row_c4_fuse_8,               true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB,                             sub,                             true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW,                         sub_row,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,                      sub_row_c4,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                             mul,                             true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                         mul_row,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,                      mul_row_c4,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                             div,                             true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                         div_row,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,                      div_row_c4,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID,                          add_id,                          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32,                      repeat_f32,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16,                      repeat_f16,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32,                      repeat_i32,                      true);
@@ -1155,6 +1230,12 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                            silu,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,                          silu_4,                          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU,                             elu,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS,                             abs,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN,                             sgn,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP,                            step,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH,                       hardswish,                       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID,                     hardsigmoid,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP,                             exp,                             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,                    soft_max_f16,                    has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,                  soft_max_f16_4,                  has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,                    soft_max_f32,                    has_simdgroup_reduction);
@@ -1169,6 +1250,7 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,                   get_rows_q5_0,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,                   get_rows_q5_1,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,                   get_rows_q8_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4,                  get_rows_mxfp4,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,                   get_rows_q2_K,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,                   get_rows_q3_K,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,                   get_rows_q4_K,                   true);
@@ -1194,6 +1276,8 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,                   set_rows_q5_1,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,                 set_rows_iq4_nl,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                        rms_norm,                        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,                    rms_norm_mul,                    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,                rms_norm_mul_add,                has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM,                         l2_norm,                         has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                      group_norm,                      has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                            norm,                            true);
@@ -1219,6 +1303,7 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,                 mul_mv_q5_0_f32,                 has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,                 mul_mv_q5_1_f32,                 has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,                 mul_mv_q8_0_f32,                 has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,                mul_mv_mxfp4_f32,                has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,         mul_mv_ext_f16_f32_r1_2,         has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,         mul_mv_ext_f16_f32_r1_3,         has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,         mul_mv_ext_f16_f32_r1_4,         has_simdgroup_reduction);
@@ -1243,6 +1328,10 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,        mul_mv_ext_q8_0_f32_r1_3,        has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,        mul_mv_ext_q8_0_f32_r1_4,        has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,        mul_mv_ext_q8_0_f32_r1_5,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2,       mul_mv_ext_mxfp4_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3,       mul_mv_ext_mxfp4_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4,       mul_mv_ext_mxfp4_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5,       mul_mv_ext_mxfp4_f32_r1_5,       has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,        mul_mv_ext_q4_K_f32_r1_2,        has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,        mul_mv_ext_q4_K_f32_r1_3,        has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,        mul_mv_ext_q4_K_f32_r1_4,        has_simdgroup_reduction);
@@ -1284,6 +1373,7 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,              mul_mv_id_q5_0_f32,              has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,              mul_mv_id_q5_1_f32,              has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,              mul_mv_id_q8_0_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,             mul_mv_id_mxfp4_f32,             has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,              mul_mv_id_q2_K_f32,              has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,              mul_mv_id_q3_K_f32,              has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,              mul_mv_id_q4_K_f32,              has_simdgroup_reduction);
@@ -1306,6 +1396,8 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,                 mul_mm_q5_0_f32,                 has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,                 mul_mm_q5_1_f32,                 has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,                 mul_mm_q8_0_f32,                 has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,                mul_mm_mxfp4_f32,                has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,                 mul_mm_q2_K_f32,                 has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,                 mul_mm_q3_K_f32,                 has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,                 mul_mm_q4_K_f32,                 has_simdgroup_mm);
@@ -1330,6 +1422,7 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,              mul_mm_id_q5_0_f16,              has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,              mul_mm_id_q5_1_f16,              has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,              mul_mm_id_q8_0_f16,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,             mul_mm_id_mxfp4_f16,             has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,              mul_mm_id_q2_K_f16,              has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,              mul_mm_id_q3_K_f16,              has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,              mul_mm_id_q4_K_f16,              has_simdgroup_mm);
@@ -1512,6 +1605,7 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU,                           reglu,                           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU,                           geglu,                           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU,                          swiglu,                          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,                      swiglu_oai,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF,                       geglu_erf,                       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,                     geglu_quick,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        true);
@@ -1688,6 +1782,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_ABS:
+                case GGML_UNARY_OP_SGN:
+                case GGML_UNARY_OP_STEP:
+                case GGML_UNARY_OP_HARDSWISH:
+                case GGML_UNARY_OP_HARDSIGMOID:
+                case GGML_UNARY_OP_EXP:
                     return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
                 default:
                     return false;
@@ -1697,6 +1797,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU_OAI:
                 case GGML_GLU_OP_GEGLU_ERF:
                 case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
@@ -1714,6 +1815,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+        case GGML_OP_ADD_ID:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ACC:
         case GGML_OP_REPEAT:
@@ -1875,9 +1977,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     }
 }
 
-static bool ggml_metal_encode_node(
+static int ggml_metal_encode_node(
                         ggml_backend_t   backend,
                                    int   idx,
+                                   int   idx_end,
         id<MTLComputeCommandEncoder>   encoder,
             struct ggml_metal_mem_pool * mem_pool) {
     struct ggml_backend_metal_context        * ctx     = backend->context;
@@ -1885,7 +1988,10 @@ static bool ggml_metal_encode_node(
 
     struct ggml_cgraph * gf = ctx->gf;
 
-    struct ggml_tensor * node = ggml_graph_node(gf, idx);
+    enum ggml_op ops[8];
+
+    struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx;
+    struct ggml_tensor *  node  = nodes[0];
 
     //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
 
@@ -1895,7 +2001,7 @@ static bool ggml_metal_encode_node(
     struct ggml_tensor * dst  = node;
 
     if (ggml_is_empty(dst)) {
-        return true;
+        return 1;
     }
 
     switch (dst->op) {
@@ -1906,7 +2012,7 @@ static bool ggml_metal_encode_node(
         case GGML_OP_PERMUTE:
             {
                 // noop -> next node
-            } return true;
+            } return 1;
         default:
             {
             } break;
@@ -1961,6 +2067,7 @@ static bool ggml_metal_encode_node(
 
     const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
     const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+    const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
     const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;
 
     size_t offs_src0 = 0;
@@ -1973,6 +2080,8 @@ static bool ggml_metal_encode_node(
     id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
     id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
 
+    int n_fuse = 1;
+
 #if 0
     GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
     if (src0) {
@@ -2044,37 +2153,15 @@ static bool ggml_metal_encode_node(
                 GGML_ASSERT(src0t == GGML_TYPE_F32);
                 GGML_ASSERT(src1t == GGML_TYPE_F32);
 
+                GGML_ASSERT(ggml_is_contiguous_rows(src0));
+                GGML_ASSERT(ggml_is_contiguous_rows(src1));
+
                 const size_t offs = 0;
 
                 bool bcast_row = false;
 
                 id<MTLComputePipelineState> pipeline = nil;
 
-                if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-                    GGML_ASSERT(ggml_is_contiguous(src0));
-
-                    // src1 is a row
-                    GGML_ASSERT(ne11 == 1);
-
-                    switch (dst->op) {
-                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
-                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
-                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
-                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    }
-
-                    bcast_row = true;
-                } else {
-                    switch (dst->op) {
-                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
-                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
-                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
-                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
-                        default: GGML_ABORT("fatal error");
-                    }
-                }
-
                 ggml_metal_kargs_bin args = {
                     /*.ne00 =*/ ne00,
                     /*.ne01 =*/ ne01,
@@ -2101,12 +2188,119 @@ static bool ggml_metal_encode_node(
                     /*.nb2  =*/ nb2,
                     /*.nb3  =*/ nb3,
                     /*.offs =*/ offs,
+                    /*.o1   =*/ { offs_src1 },
                 };
 
+                // c[0] = add(a,    b[0])
+                // c[1] = add(c[0], b[1])
+                // c[2] = add(c[1], b[2])
+                // ...
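+                //
+                // all fused b[i] must live in the same Metal buffer as b[0]; only their byte offsets
+                // differ and are passed to the kernel via args.o1[]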
+                if (ctx_dev->use_fusion) {
+                    ops[0] = GGML_OP_ADD;
+                    ops[1] = GGML_OP_ADD;
+                    ops[2] = GGML_OP_ADD;
+                    ops[3] = GGML_OP_ADD;
+                    ops[4] = GGML_OP_ADD;
+                    ops[5] = GGML_OP_ADD;
+                    ops[6] = GGML_OP_ADD;
+                    ops[7] = GGML_OP_ADD;
+
+                    size_t offs_fuse;
+                    id id_fuse;
+
+                    // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes
+                    //       across splits. idx_end indicates the last node in the current split
+                    for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
+                        if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
+                            break;
+                        }
+
+                        // b[0] === b[1] === ...
+                        if (!ggml_are_same_layout(nodes[n_fuse]->src[1], nodes[n_fuse + 1]->src[1])) {
+                            break;
+                        }
+
+                        // only fuse nodes if src1 is in the same Metal buffer
+                        id_fuse = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse);
+                        if (id_fuse != id_src1) {
+                            break;
+                        }
+
+                        ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
+
+                        args.o1[n_fuse + 1] = offs_fuse;
+                    }
+
+                    ++n_fuse;
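+                    // n_fuse includes the current node: 1 means no fusion, up to 8 chained ADDs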
+
+                    if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
+                        GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
+                    }
+                }
+
+                if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+
+                    // src1 is a row
+                    GGML_ASSERT(ne11 == 1);
+
+                    switch (dst->op) {
+                        case GGML_OP_ADD:
+                            {
+                                switch (n_fuse) {
+                                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4       ].pipeline; break;
+                                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
+                                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
+                                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
+                                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
+                                    case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
+                                    case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
+                                    case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
+                                    default: GGML_ABORT("fatal error");
+                                }
+                            } break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
+
+                    bcast_row = true;
+                } else {
+                    switch (dst->op) {
+                        case GGML_OP_ADD:
+                            {
+                                switch (n_fuse) {
+                                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD       ].pipeline; break;
+                                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
+                                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
+                                    case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
+                                    case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
+                                    case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
+                                    case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
+                                    case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
+                                    default: GGML_ABORT("fatal error");
+                                }
+                            } break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
+                }
+
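+                // when fusing, write the result directly into the output of the last fused node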
+                if (n_fuse > 1) {
+                    id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+                }
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBytes:&args length:sizeof(args) atIndex:0];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_src1 offset:0         atIndex:2];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
                 if (bcast_row) {
@@ -2114,11 +2308,47 @@ static bool ggml_metal_encode_node(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } else {
-                    const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
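+                    // grow the threadgroup until each thread covers at most ~16 elements of the row,
+                    // capped by the pipeline's thread limit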
+                    int nth = 32;
+
+                    while (16*nth < ne0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                        nth *= 2;
+                    }
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 }
             } break;
+        case GGML_OP_ADD_ID:
+            {
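+                // adds to each row of src0 the row of src1 selected by the corresponding i32 index in src2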
+                GGML_ASSERT(src0t == GGML_TYPE_F32);
+                GGML_ASSERT(src1t == GGML_TYPE_F32);
+                GGML_ASSERT(src2t == GGML_TYPE_I32);
+                GGML_ASSERT(dstt  == GGML_TYPE_F32);
+
+                GGML_ASSERT(ggml_is_contiguous_rows(src0));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ID].pipeline;
+
+                ggml_metal_kargs_add_id args = {
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb11 =*/ nb11,
+                    /*.nb21 =*/ nb21,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:4];
+
+                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
         case GGML_OP_REPEAT:
             {
                 id<MTLComputePipelineState> pipeline;
@@ -2239,12 +2469,13 @@ static bool ggml_metal_encode_node(
                     /*.nb2  =*/ pnb2,
                     /*.nb3  =*/ pnb3,
                     /*.offs =*/ offs,
+                    /*.o1   =*/ { offs_src1 },
                 };
 
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBytes:&args length:sizeof(args) atIndex:0];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_src1 offset:0         atIndex:2];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
                 const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
@@ -2439,6 +2670,78 @@ static bool ggml_metal_encode_node(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+                case GGML_UNARY_OP_ABS:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_SGN:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_STEP:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_HARDSWISH:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_HARDSIGMOID:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_EXP:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
                 default:
                 {
                     GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
@@ -2465,6 +2768,9 @@ static bool ggml_metal_encode_node(
                     case GGML_GLU_OP_SWIGLU:
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
                         break;
+                    case GGML_GLU_OP_SWIGLU_OAI:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_OAI].pipeline;
+                        break;
                     case GGML_GLU_OP_GEGLU_ERF:
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
                         break;
@@ -2475,7 +2781,9 @@ static bool ggml_metal_encode_node(
                         GGML_ABORT("fatal error");
                 }
 
-                const int32_t swp = ((const int32_t *) dst->op_params)[1];
+                const int32_t swp = ggml_get_op_params_i32(dst, 1);
+                const float alpha = ggml_get_op_params_f32(dst, 2);
+                const float limit = ggml_get_op_params_f32(dst, 3);
 
                 const int32_t i00 = swp ? ne0 : 0;
                 const int32_t i10 = swp ? 0 : ne0;
@@ -2489,6 +2797,8 @@ static bool ggml_metal_encode_node(
                     /*.nb1  =*/ nb1,
                     /*.i00  =*/ src1 ? 0 : i00,
                     /*.i10  =*/ src1 ? 0 : i10,
+                    /*.alpha=*/ alpha,
+                    /*.limit=*/ limit
                 };
 
                 [encoder setComputePipelineState:pipeline];
@@ -2674,7 +2984,7 @@ static bool ggml_metal_encode_node(
                 id<MTLBuffer> h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
                 if (!h_src0) {
                     GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
-                    return false;
+                    return 0;
                 }
 
                 offs_src0 = 0;
@@ -2747,8 +3057,13 @@ static bool ggml_metal_encode_node(
                 } else {
                     [encoder setBuffer:h_src0 offset:offs_src0  atIndex:1];
                 }
-                [encoder setBuffer:id_dst offset:offs_dst       atIndex:2];
-                [encoder setBytes:&args   length:sizeof(args)   atIndex:3];
+                if (id_src2) {
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                } else {
+                    [encoder setBuffer:h_src0 offset:offs_src0  atIndex:2];
+                }
+                [encoder setBuffer:id_dst offset:offs_dst       atIndex:3];
+                [encoder setBytes:&args   length:sizeof(args)   atIndex:4];
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
@@ -2896,6 +3211,7 @@ static bool ggml_metal_encode_node(
                     /*.n_group      =*/ n_group,
                     /*.n_seq_tokens =*/ n_seq_tokens,
                     /*.n_seqs       =*/ n_seqs,
+                    /*.s_off        =*/ ggml_nelements(src1) * sizeof(float),
                     /*.nb01         =*/ nb01,
                     /*.nb02         =*/ nb02,
                     /*.nb03         =*/ nb03,
@@ -2924,12 +3240,22 @@ static bool ggml_metal_encode_node(
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:7];
                 [encoder setBytes:&args    length:sizeof(args) atIndex:8];
 
+                // One shared memory bucket for each simd group in the threadgroup
+                // NOTE: Metal kernels require the buffer size to be a multiple of 16 bytes
+                //  https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+                if (d_state >= 32) {
+                    GGML_ASSERT((int64_t)(d_state / 32) <= 32);
+                    const int64_t shmem_size = 32;
+                    GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
+                    [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
+                }
+
                 if (ne30 == 1) {
                     // Mamba-2
-                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
                 } else {
                     GGML_ASSERT(d_inner == 1);
-                    [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
                 }
             } break;
         case GGML_OP_RWKV_WKV6:
@@ -3035,6 +3361,7 @@ static bool ggml_metal_encode_node(
                        src0t == GGML_TYPE_Q5_0 ||
                        src0t == GGML_TYPE_Q5_1 ||
                        src0t == GGML_TYPE_Q8_0 ||
+                       src0t == GGML_TYPE_MXFP4 ||
                        src0t == GGML_TYPE_IQ4_NL ||
                        false) && (ne11 >= 2 && ne11 <= 8)
                      ) ||
@@ -3127,6 +3454,14 @@ static bool ggml_metal_encode_node(
                                 case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
                                 default: GGML_ABORT("not implemented");
                             } break;
+                        case GGML_TYPE_MXFP4:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
                         case GGML_TYPE_Q4_K:
                             switch (r1ptg) {
                                 case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
@@ -3225,6 +3560,7 @@ static bool ggml_metal_encode_node(
                         case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32   ].pipeline; break;
                         case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32   ].pipeline; break;
                         case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32   ].pipeline; break;
+                        case GGML_TYPE_MXFP4:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32  ].pipeline; break;
                         case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32   ].pipeline; break;
                         case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32   ].pipeline; break;
                         case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32   ].pipeline; break;
@@ -3367,6 +3703,13 @@ static bool ggml_metal_encode_node(
                                 nr0 = N_R0_Q8_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
                             } break;
+                        case GGML_TYPE_MXFP4:
+                            {
+                                nsg = N_SG_MXFP4;
+                                nr0 = N_R0_MXFP4;
+                                smem = 32*sizeof(float);
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline;
+                            } break;
                         case GGML_TYPE_Q2_K:
                             {
                                 nsg = N_SG_Q2_K;
@@ -3500,8 +3843,6 @@ static bool ggml_metal_encode_node(
         case GGML_OP_MUL_MAT_ID:
             {
                 // src2 = ids
-                const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
-
                 GGML_ASSERT(src2t == GGML_TYPE_I32);
 
                 GGML_ASSERT(!ggml_is_transposed(src0));
@@ -3550,7 +3891,7 @@ static bool ggml_metal_encode_node(
                     id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
                     if (!h_src1) {
                         GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
-                        return false;
+                        return 0;
                     }
 
                     const int64_t neh0 = ne0;
@@ -3566,7 +3907,7 @@ static bool ggml_metal_encode_node(
                     id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
                     if (!h_dst) {
                         GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
-                        return false;
+                        return 0;
                     }
 
                     // tokens per expert
@@ -3574,7 +3915,7 @@ static bool ggml_metal_encode_node(
                     id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
                     if (!h_tpe) {
                         GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
-                        return false;
+                        return 0;
                     }
 
                     // id map
@@ -3583,7 +3924,7 @@ static bool ggml_metal_encode_node(
                     id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
                     if (!h_ids) {
                         GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
-                        return false;
+                        return 0;
                     }
 
                     {
@@ -3627,6 +3968,7 @@ static bool ggml_metal_encode_node(
                             case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16   ].pipeline; break;
                             case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16   ].pipeline; break;
                             case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16   ].pipeline; break;
+                            case GGML_TYPE_MXFP4:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16  ].pipeline; break;
                             case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16   ].pipeline; break;
                             case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16   ].pipeline; break;
                             case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16   ].pipeline; break;
@@ -3762,6 +4104,13 @@ static bool ggml_metal_encode_node(
                                 nr0 = N_R0_Q8_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
                             } break;
+                        case GGML_TYPE_MXFP4:
+                            {
+                                nsg = N_SG_MXFP4;
+                                nr0 = N_R0_MXFP4;
+                                smem = 32*sizeof(float);
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline;
+                            } break;
                         case GGML_TYPE_Q2_K:
                             {
                                 nsg = N_SG_Q2_K;
@@ -3914,6 +4263,7 @@ static bool ggml_metal_encode_node(
                     case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0   ].pipeline; break;
                     case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1   ].pipeline; break;
                     case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0   ].pipeline; break;
+                    case GGML_TYPE_MXFP4:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4  ].pipeline; break;
                     case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K   ].pipeline; break;
                     case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K   ].pipeline; break;
                     case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K   ].pipeline; break;
@@ -4015,12 +4365,95 @@ static bool ggml_metal_encode_node(
         case GGML_OP_RMS_NORM:
             {
                 GGML_ASSERT(ne00 % 4 == 0);
-                GGML_ASSERT(ggml_is_contiguous_1(src0));
+                GGML_ASSERT(ggml_is_contiguous_rows(src0));
 
                 float eps;
                 memcpy(&eps, dst->op_params, sizeof(float));
 
-                id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+                ggml_metal_kargs_rms_norm args = {
+                    /*.ne00   =*/ ne00,
+                    /*.ne00_4 =*/ ne00/4,
+                    /*.nb1    =*/ nb1,
+                    /*.nb2    =*/ nb2,
+                    /*.nb3    =*/ nb3,
+                    /*.eps    =*/ eps,
+                    /*.nef1   =*/ { ne01 },
+                    /*.nef2   =*/ { ne02 },
+                    /*.nef3   =*/ { ne03 },
+                    /*.nbf1   =*/ { nb01 },
+                    /*.nbf2   =*/ { nb02 },
+                    /*.nbf3   =*/ { nb03 },
+                };
+
+                size_t offs_fuse[2] = { 0, 0 };
+                id<MTLBuffer> id_fuse[2] = { id_src0, id_src0 };
+
+                // d[0] = rms_norm(a)
+                // d[1] = mul(d[0], b)
+                // d[2] = add(d[1], c)
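+                //
+                // when the pattern is detected, a single RMS_NORM_MUL(_ADD) kernel computes the whole
+                // chain, skipping the intermediate tensors d[0] and d[1]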
+                if (ctx_dev->use_fusion) {
+                    ops[0] = GGML_OP_RMS_NORM;
+                    ops[1] = GGML_OP_MUL;
+                    ops[2] = GGML_OP_ADD;
+
+                    for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
+                        if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse] != nodes[n_fuse + 1]->src[0]) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse + 1]->src[1]->ne[0] != node->ne[0]) {
+                            break;
+                        }
+
+                        if (!ggml_is_contiguous_rows(nodes[n_fuse + 1]->src[1])) {
+                            break;
+                        }
+
+                        if (nodes[n_fuse + 1]->type != GGML_TYPE_F32) {
+                            break;
+                        }
+
+                        ctx_dev->fuse_cnt[nodes[n_fuse + 1]->op]++;
+
+                        id_fuse[n_fuse] = ggml_metal_get_buffer(nodes[n_fuse + 1]->src[1], &offs_fuse[n_fuse]);
+
+                        args.nef1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[1];
+                        args.nef2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[2];
+                        args.nef3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->ne[3];
+
+                        args.nbf1[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[1];
+                        args.nbf2[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[2];
+                        args.nbf3[n_fuse + 1] = nodes[n_fuse + 1]->src[1]->nb[3];
+                    }
+
+                    ++n_fuse;
+
+                    if (ctx_dev->debug_fusion > 1 && n_fuse > 1) {
+                        if (n_fuse == 2) {
+                            GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
+                        }
+                        if (n_fuse == 3) {
+                            GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
+                        }
+                    }
+                }
+
+                if (n_fuse > 1) {
+                    id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+                }
+
+                id<MTLComputePipelineState> pipeline;
+
+                switch (n_fuse) {
+                    case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM        ].pipeline; break;
+                    case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL    ].pipeline; break;
+                    case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
+                    default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
+                }
 
                 int nth = 32; // SIMD width
 
@@ -4031,23 +4464,16 @@ static bool ggml_metal_encode_node(
                 nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
                 nth = MIN(nth, ne00/4);
 
-                ggml_metal_kargs_rms_norm args = {
-                    /*.ne00   =*/ ne00,
-                    /*.ne00_4 =*/ ne00/4,
-                    /*.nb01   =*/ nb01,
-                    /*.eps    =*/ eps,
-                };
-
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBytes:&args length:sizeof(args) atIndex:0];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                [encoder setBytes:&args length:sizeof(args)       atIndex:0];
+                [encoder setBuffer:id_src0    offset:offs_src0    atIndex:1];
+                [encoder setBuffer:id_fuse[0] offset:offs_fuse[0] atIndex:2];
+                [encoder setBuffer:id_fuse[1] offset:offs_fuse[1] atIndex:3];
+                [encoder setBuffer:id_dst     offset:offs_dst     atIndex:4];
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
-                const int64_t nrows = ggml_nrows(src0);
-
-                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_L2_NORM:
             {
@@ -4648,11 +5074,14 @@ static bool ggml_metal_encode_node(
                 GGML_ASSERT(ne11 == ne21);
                 GGML_ASSERT(ne12 == ne22);
 
-                struct ggml_tensor * src3 = node->src[3];
+                struct ggml_tensor * src3 = node->src[3]; // mask
+                struct ggml_tensor * src4 = node->src[4]; // sinks
 
                 size_t offs_src3 = 0;
+                size_t offs_src4 = 0;
 
                 id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+                id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
 
                 GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
                 GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
@@ -4668,8 +5097,6 @@ static bool ggml_metal_encode_node(
                 const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
                 const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
 
-                const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
-
                 float scale;
                 float max_bias;
                 float logit_softcap;
@@ -5057,7 +5484,12 @@ static bool ggml_metal_encode_node(
                 } else {
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
                 }
-                [encoder setBuffer:id_dst offset:offs_dst       atIndex:5];
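+                // the attention sinks (src4) are optional; when absent, bind src0 as a placeholder
+                // so that buffer index 5 is always populated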
+                if (id_src4) {
+                    [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
+                } else {
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
+                }
+                [encoder setBuffer:id_dst offset:offs_dst       atIndex:6];
 
                 if (!use_vec_kernel) {
                     // half8x8 kernel
@@ -5442,7 +5874,7 @@ static bool ggml_metal_encode_node(
             }
     }
 
-    return true;
+    return n_fuse;
 }
 
 static enum ggml_status ggml_metal_graph_compute(
@@ -5948,20 +6380,26 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
         struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
         ggml_metal_mem_pool_reset(mem_pool);
 
-        for (int idx = node_start; idx < node_end; ++idx) {
+        for (int idx = node_start; idx < node_end;) {
             if (should_capture) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }
 
-            const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
+            const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
+            if (idx + res > node_end) {
+                GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
+                        "https://github.com/ggml-org/llama.cpp/pull/14849");
+            }
 
             if (should_capture) {
                 [encoder popDebugGroup];
             }
 
-            if (!res) {
+            if (res == 0) {
                 break;
             }
+
+            idx += res;
         }
 
         [encoder endEncoding];
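// Editorial sketch (not part of this patch): a minimal stand-alone C++ model of the new
// dispatch loop above. encode_node() is a hypothetical stand-in for ggml_metal_encode_node()
// and only reports how many consecutive graph nodes it "fused" (0 on failure); the real
// function also records the Metal commands. The point is that the loop now advances by
// the return value instead of always stepping by one.
#include <cstdio>
#include <cstdlib>

// hypothetical stand-in: fuse pairs of nodes when possible, never cross node_end
static int encode_node(int idx, int node_end) {
    return (idx + 2 <= node_end) ? 2 : 1;
}

int main() {
    const int node_start = 0, node_end = 7;
    for (int idx = node_start; idx < node_end; ) {
        const int res = encode_node(idx, node_end);
        if (res == 0) break;                        // encoding failed for this node
        if (idx + res > node_end) {                 // a fused run must stay inside this encoder
            std::fprintf(stderr, "fusion crossed encoder boundary\n");
            std::abort();
        }
        std::printf("encoded nodes [%d, %d)\n", idx, idx + res);
        idx += res;                                 // skip the nodes consumed by the fused kernel
    }
    return 0;
}
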
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 239ec31fbcb..b35a3bbdc31 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -35,6 +35,10 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
     -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
 };
 
+constexpr constant static float kvalues_mxfp4_f[16] = {
+    0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
+};
+
 static inline int best_index_int8(int n, constant float * val, float x) {
     if (x <= val[0]) return 0;
     if (x >= val[n-1]) return n-1;
@@ -46,6 +50,18 @@ static inline int best_index_int8(int n, constant float * val, float x) {
     return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
 }
 
+static inline float e8m0_to_fp32(uint8_t x) {
+    uint32_t bits;
+
+    if (x == 0) {
+        bits = 0x00400000;
+    } else {
+        bits = (uint32_t) x << 23;
+    }
+
+    return as_type<float>(bits);
+}
+
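// Editorial sketch (not part of this patch): a host-side check that the bit construction in
// e8m0_to_fp32() matches the straightforward "power of two" reading of an E8M0 scale,
// i.e. e8m0(x) = 2^(x - 127), with x == 0 landing on the sub-normal 2^-127.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

static float e8m0_to_fp32_bits(uint8_t x) {          // same construction as the kernel
    const uint32_t bits = (x == 0) ? 0x00400000u : (uint32_t)x << 23;
    float f;
    std::memcpy(&f, &bits, sizeof(f));               // host equivalent of Metal's as_type<float>
    return f;
}

int main() {
    for (int x = 0; x < 255; ++x) {                  // 255 is the reserved NaN code in the MX spec
        const float ref = std::ldexp(1.0f, x - 127); // 2^(x-127); x == 0 gives the sub-normal 2^-127
        if (e8m0_to_fp32_bits((uint8_t)x) != ref) {
            std::printf("mismatch at %d\n", x);
            return 1;
        }
    }
    std::puts("bit trick matches 2^(x-127) for all finite codes");
    return 0;
}
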
 // NOTE: this is not dequantizing - we are simply fitting the template
 template <typename type4x4>
 void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -242,6 +258,27 @@ void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
     }
 }
 
+void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
+#pragma METAL fp math_mode(safe)
+    float amax = 0.0f; // absolute max
+
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = src[j];
+        amax = MAX(amax, fabs(v));
+    }
+
+    const float d = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = src[j]*id;
+
+        dst.qs[j] = round(x0);
+    }
+}
+
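// Editorial sketch (not part of this patch): the Q8_0 scheme used by quantize_q8_0 above,
// as a plain C++ reference. A block of QK8_0 floats is stored as one scale d = amax/127
// plus int8 codes q = round(x/d); dequantization is simply x ~= d*q. The struct name is
// illustrative; the real block_q8_0 stores d as fp16.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

constexpr int QK8_0 = 32;

struct BlockQ8_0 { float d; int8_t qs[QK8_0]; };     // 'd' is half precision in the real format

static BlockQ8_0 quantize_q8_0_ref(const float *src) {
    float amax = 0.0f;                               // absolute max of the block
    for (int j = 0; j < QK8_0; ++j) amax = std::fmax(amax, std::fabs(src[j]));
    BlockQ8_0 b;
    b.d = amax / 127.0f;
    const float id = b.d ? 1.0f / b.d : 0.0f;
    for (int j = 0; j < QK8_0; ++j) b.qs[j] = (int8_t)std::lround(src[j] * id);
    return b;
}

int main() {
    std::vector<float> x(QK8_0);
    for (int j = 0; j < QK8_0; ++j) x[j] = 0.1f * (j - 16);
    const BlockQ8_0 b = quantize_q8_0_ref(x.data());
    float max_err = 0.0f;
    for (int j = 0; j < QK8_0; ++j) max_err = std::fmax(max_err, std::fabs(b.d * b.qs[j] - x[j]));
    std::printf("max round-trip error: %g (bounded by d/2 = %g)\n", max_err, b.d / 2);
    return 0;
}
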
 void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
 #pragma METAL fp math_mode(safe)
     float amax = 0.0f; // absolute max
@@ -462,25 +499,34 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
     }
 }
 
-void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
-#pragma METAL fp math_mode(safe)
-    float amax = 0.0f; // absolute max
+template <typename type4x4>
+void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
+    device const uint8_t * q2 = (device const uint8_t *)xb->qs;
 
-    for (int j = 0; j < QK8_0; j++) {
-        const float v = src[j];
-        amax = MAX(amax, fabs(v));
+    const float d = e8m0_to_fp32(xb->e);
+    const uint8_t shr = il >= 1 ? 4 : 0;
+
+    for (int i = 0; i < 4; ++i) {
+        reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
+        reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
+        reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
+        reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
     }
+}
 
-    const float d = amax / ((1 << 7) - 1);
-    const float id = d ? 1.0f/d : 0.0f;
+template <typename type4>
+void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
+    device const uint8_t * q2 = (device const uint8_t *)xb->qs;
 
-    dst.d = d;
+    const float d = e8m0_to_fp32(xb->e);
+    const short il4 = il%4;
 
-    for (int j = 0; j < QK8_0; ++j) {
-        const float x0 = src[j]*id;
+    const uint8_t shr = il >= 4 ? 4 : 0;
 
-        dst.qs[j] = round(x0);
-    }
+    reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
+    reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
+    reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
+    reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
 }
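// Editorial sketch (not part of this patch): how a 17-byte MXFP4 block (1 E8M0 scale byte
// plus 16 packed nibbles) expands to 32 floats, mirroring the nibble/shift layout of
// dequantize_mxfp4 above. The block field names are assumed from the kernel code.
#include <cstdint>
#include <cstdio>
#include <cstring>

constexpr int QK_MXFP4 = 32;

struct BlockMXFP4 { uint8_t e; uint8_t qs[QK_MXFP4 / 2]; };

static const float kvalues_mxfp4[16] = {
    0.f, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0.f, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
};

static float e8m0(uint8_t x) {                       // 2^(x-127); x == 0 stays sub-normal
    const uint32_t bits = (x == 0) ? 0x00400000u : (uint32_t)x << 23;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

static void dequantize_mxfp4_ref(const BlockMXFP4 &b, float out[QK_MXFP4]) {
    const float d = e8m0(b.e);
    for (int j = 0; j < 16; ++j) {
        out[j]      = d * kvalues_mxfp4[b.qs[j] & 0x0F]; // low nibbles: first 16 values
        out[j + 16] = d * kvalues_mxfp4[b.qs[j] >> 4];   // high nibbles: last 16 values
    }
}

int main() {
    BlockMXFP4 b = { 127, {} };                      // E8M0 code 127 -> scale 2^0 = 1
    for (int j = 0; j < 16; ++j) b.qs[j] = (uint8_t)(j | ((15 - j) << 4));
    float out[QK_MXFP4];
    dequantize_mxfp4_ref(b, out);
    std::printf("out[1]=%g out[17]=%g\n", out[1], out[17]); // 0.5 and -4 (codes 1 and 14)
    return 0;
}
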
 
 template 
@@ -832,7 +878,8 @@ enum ggml_sort_order {
 // general-purpose kernel for addition, subtraction, multiplication and division of two tensors
 // pros: works for non-contiguous tensors, supports broadcast across all dims
 // cons: not very efficient
-kernel void kernel_add(
+template <short F>
+kernel void kernel_add_fuse_impl(
         constant ggml_metal_kargs_bin & args,
         device const char * src0,
         device const char * src1,
@@ -848,16 +895,39 @@ kernel void kernel_add(
     const int i12 = i02%args.ne12;
     const int i11 = i01%args.ne11;
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
+    device       float * dst_ptr  = (device       float *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs);
+
+    device const float * src1_ptr[F];
+    for (short j = 0; j < F; ++j) {
+        src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
+    }
 
     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         const int i10 = i0%args.ne10;
-        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10));
+
+        float res = src0_ptr[i0];
+
+#pragma unroll
+        for (short j = 0; j < F; ++j) {
+            res += src1_ptr[j][i10];
+        }
+
+        dst_ptr[i0] = res;
     }
 }
 
+typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
+
+template [[host_name("kernel_add")]]        kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
+template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
+template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
+template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
+template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
+template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
+template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
+template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
+
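// Editorial sketch (not part of this patch): what the F-way fused add computes. Instead of
// launching F separate ADD kernels for a chain x + b0 + b1 + ... + b(F-1), one launch reads
// src0 once and accumulates F broadcast rows whose offsets are passed in o1[], so the
// intermediate sums are never written back to global memory. Offsets are in elements here;
// the kernel passes them in bytes.
#include <cstdio>

template <int F>
void add_fuse_ref(const float *src0, const float *src1, const long o1[F],
                  float *dst, int ne0) {
    for (int i0 = 0; i0 < ne0; ++i0) {
        float res = src0[i0];
        for (int j = 0; j < F; ++j) {
            res += src1[o1[j] + i0];                 // j-th fused operand, offset o1[j]
        }
        dst[i0] = res;
    }
}

int main() {
    const float x[4]  = {1, 2, 3, 4};
    const float b[8]  = {10, 10, 10, 10, 100, 100, 100, 100};
    const long  o1[2] = {0, 4};                      // two bias rows packed in one buffer
    float y[4];
    add_fuse_ref<2>(x, b, o1, y, 4);
    std::printf("%g %g %g %g\n", y[0], y[1], y[2], y[3]);   // 111 112 113 114
    return 0;
}
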
 kernel void kernel_sub(
         constant ggml_metal_kargs_bin & args,
         device const char * src0,
@@ -875,7 +945,7 @@ kernel void kernel_sub(
     const int i11 = i01%args.ne11;
 
     device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
     device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
 
     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
@@ -900,9 +970,9 @@ kernel void kernel_mul(
     const int i12 = i02%args.ne12;
     const int i11 = i01%args.ne11;
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1;
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
 
     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         const int i10 = i0%args.ne10;
@@ -926,9 +996,9 @@ kernel void kernel_div(
     const int i12 = i02%args.ne12;
     const int i11 = i01%args.ne11;
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1;
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
 
     for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         const int i10 = i0%args.ne10;
@@ -936,6 +1006,32 @@ kernel void kernel_div(
     }
 }
 
+kernel void kernel_add_id(
+        constant ggml_metal_kargs_add_id & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i1 = tgpig.x;
+    const int i2 = tgpig.y;
+
+    const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
+
+    const size_t nb1 = args.ne0 * sizeof(float);
+    const size_t nb2 = args.ne1 * nb1;
+
+    device       float * dst_row  = (device       float *)((device char *)dst + i1*nb1 + i2*nb2);
+    device const float * src0_row = (device const float *)((device char *)src0 +  i1*args.nb01 + i2*args.nb02);
+    device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        dst_row[i0] = src0_row[i0] + src1_row[i0];
+    }
+}
+
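// Editorial sketch (not part of this patch): the indexing performed by kernel_add_id, which
// is typically used for per-expert biases in MoE layers. Row (i1, i2) of src0 gets the src1
// row whose index is stored in ids[i1, i2] added to it. Function and variable names here
// are illustrative only.
#include <cstdio>

static void add_id_ref(const float *src0, const float *src1, const int *ids,
                       float *dst, int ne0, int ne1, int ne2) {
    for (int i2 = 0; i2 < ne2; ++i2) {
        for (int i1 = 0; i1 < ne1; ++i1) {
            const int expert = ids[i2 * ne1 + i1];   // which src1 row to add
            const float *x = src0 + (i2 * ne1 + i1) * ne0;
            const float *b = src1 + expert * ne0;
            float       *y = dst  + (i2 * ne1 + i1) * ne0;
            for (int i0 = 0; i0 < ne0; ++i0) y[i0] = x[i0] + b[i0];
        }
    }
}

int main() {
    const float src0[4] = {1, 2, 3, 4};              // ne0=2, ne1=2, ne2=1
    const float src1[4] = {10, 10, 20, 20};          // two "expert" bias rows
    const int   ids[2]  = {1, 0};                    // row 0 uses expert 1, row 1 uses expert 0
    float dst[4];
    add_id_ref(src0, src1, ids, dst, 2, 2, 1);
    std::printf("%g %g | %g %g\n", dst[0], dst[1], dst[2], dst[3]);   // 21 22 | 13 14
    return 0;
}
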
 template<typename T>
 kernel void kernel_repeat(
         constant ggml_metal_kargs_repeat & args,
@@ -970,46 +1066,145 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat
 
 // assumption: src1 is a row
 // broadcast src1 into src0
-kernel void kernel_add_row(
+template <short F>
+kernel void kernel_add_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {
+
     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] + src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 *  dst_row = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res += src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }
 
-kernel void kernel_sub_row(
+typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
+
+template [[host_name("kernel_add_row_c4")]]        kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
+template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
+template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
+template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
+template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
+template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
+template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
+template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
+
+template <short F>
+kernel void kernel_sub_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {
+
     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] - src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 *  dst_row = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res -= src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }
 
-kernel void kernel_mul_row(
+typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
+
+template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
+
+template <short F>
+kernel void kernel_mul_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {
+
     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] * src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 *  dst_row = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res *= src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }
 
-kernel void kernel_div_row(
+typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
+
+template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
+
+template <short F>
+kernel void kernel_div_row_c4_fuse_impl(
         constant ggml_metal_kargs_bin & args,
-        device const float4 * src0,
-        device const float4 * src1,
-        device       float4 * dst,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
         uint tpig[[thread_position_in_grid]]) {
+
     const uint nb = args.ne00/4;
-    dst[tpig] = src0[tpig] / src1[tpig % nb];
+    const uint i  = tpig % nb;
+
+    device const float4 * src0_row = (device const float4 *) (src0);
+    device       float4 *  dst_row = (device       float4 *) (dst);
+
+    device const float4 * src1_row[F];
+    for (short j = 0; j < F; ++j) {
+        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
+    }
+
+    float4 res = src0_row[tpig];
+
+#pragma unroll(F)
+    for (short j = 0; j < F; ++j) {
+        res /= src1_row[j][i];
+    }
+
+    dst_row[tpig] = res;
 }
 
+typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
+
+template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
+
 kernel void kernel_scale(
         device const float * src0,
         device       float * dst,
@@ -1199,6 +1394,51 @@ kernel void kernel_neg(
     dst[tpig] = -src0[tpig];
 }
 
+kernel void kernel_abs(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = fabs(src0[tpig]);
+}
+
+kernel void kernel_sgn(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = (x > 0.0f) ? 1.0f : ((x < 0.0f) ? -1.0f : 0.0f);
+}
+
+kernel void kernel_step(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] > 0.0f ? 1.0f : 0.0f;
+}
+
+kernel void kernel_hardswish(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
+}
+
+kernel void kernel_hardsigmoid(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
+}
+
+kernel void kernel_exp(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = exp(src0[tpig]);
+}
+
 kernel void kernel_reglu(
         device const char * src0,
         device const char * src1,
@@ -1263,6 +1503,32 @@ kernel void kernel_swiglu(
     }
 }
 
+kernel void kernel_swiglu_oai(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        float x0 = src0_row[i0];
+        float x1 = src1_row[i0];
+
+        x0 = min(x0, args.limit);
+        x1 = max(min(x1, args.limit), -args.limit);
+
+        float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
+        out_glu = out_glu * (1.0f + x1);
+
+        dst_row[i0] = out_glu;
+    }
+}
+
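// Editorial sketch (not part of this patch): the element-wise math of kernel_swiglu_oai in
// plain C++. The gate input is clamped from above only, the linear input is clamped to
// [-limit, limit], and the SiLU uses an alpha-scaled sigmoid with a "+1" on the linear term.
#include <cmath>
#include <cstdio>

static float swiglu_oai_ref(float x0, float x1, float alpha, float limit) {
    x0 = std::fmin(x0, limit);                              // gate: clamp from above only
    x1 = std::fmax(std::fmin(x1, limit), -limit);           // linear: clamp both sides
    const float glu = x0 / (1.0f + std::exp(-x0 * alpha));  // x0 * sigmoid(alpha * x0)
    return glu * (1.0f + x1);
}

int main() {
    std::printf("%f\n", swiglu_oai_ref(2.0f, 0.5f, 1.702f, 7.0f));
    return 0;
}
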
 kernel void kernel_geglu_erf(
         device const char * src0,
         device const char * src1,
@@ -1366,6 +1632,7 @@ template
 kernel void kernel_soft_max(
         device const  char * src0,
         device const  char * src1,
+        device const  char * src2,
         device        char * dst,
         constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
@@ -1384,6 +1651,7 @@ kernel void kernel_soft_max(
 
     device const float * psrc0 =                (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
     device const     T * pmask = src1 != src0 ? (device const T *    ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+    device const float * psrc2 = src2 != src0 ? (device const float *) (src2)                                                 : nullptr;
     device       float * pdst  =                (device       float *) (dst  + i01*args.nb1  + i02*args.nb2  + i03*args.nb3);
 
     float slope = 1.0f;
@@ -1399,7 +1667,7 @@ kernel void kernel_soft_max(
     }
 
     // parallel max
-    float lmax = -INFINITY;
+    float lmax = psrc2 ? psrc2[i02] : -INFINITY;
 
     for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
         lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
@@ -1455,6 +1723,10 @@ kernel void kernel_soft_max(
         sum = simd_sum(sum);
     }
 
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max_val);
+    }
+
     const float inv_sum = 1.0f/sum;
 
     for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
@@ -1466,6 +1738,7 @@ template
 kernel void kernel_soft_max_4(
         device const  char * src0,
         device const  char * src1,
+        device const  char * src2,
         device        char * dst,
         constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
@@ -1484,6 +1757,7 @@ kernel void kernel_soft_max_4(
 
     device const float4 * psrc4 =                (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
     device const      T * pmask = src1 != src0 ? (device const T *     ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+    device const float *  psrc2 = src2 != src0 ? (device const float * ) (src2)                                                 : nullptr;
     device       float4 * pdst4 =                (device       float4 *) (dst  + i01*args.nb1  + i02*args.nb2  + i03*args.nb3);
 
     float slope = 1.0f;
@@ -1498,7 +1772,7 @@ kernel void kernel_soft_max_4(
     }
 
     // parallel max
-    float4 lmax4 = -INFINITY;
+    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
 
     for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
         lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
@@ -1557,6 +1831,10 @@ kernel void kernel_soft_max_4(
         sum = simd_sum(sum);
     }
 
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max_val);
+    }
+
     const float inv_sum = 1.0f/sum;
 
     for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
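// Editorial sketch (not part of this patch): what the psrc2 ("sink") additions to the
// soft_max kernels do, scalar version without the scale/mask terms. The sink acts as one
// extra logit per head: it seeds the running max and contributes to the normalizer, but no
// output element is written for it, so the visible probabilities sum to less than 1.
#include <cmath>
#include <cstdio>

static void soft_max_with_sink(const float *x, int n, float sink, float *y) {
    float m = sink;                                  // lmax starts at the sink value
    for (int i = 0; i < n; ++i) m = std::fmax(m, x[i]);
    float sum = std::exp(sink - m);                  // the sink joins the denominator only
    for (int i = 0; i < n; ++i) {
        y[i] = std::exp(x[i] - m);
        sum += y[i];
    }
    for (int i = 0; i < n; ++i) y[i] /= sum;
}

int main() {
    const float x[3] = {1.0f, 2.0f, 3.0f};
    float y[3];
    soft_max_with_sink(x, 3, 2.0f, y);
    std::printf("p = %.3f %.3f %.3f, total = %.3f\n", y[0], y[1], y[2], y[0] + y[1] + y[2]);
    return 0;
}
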
@@ -1655,10 +1933,16 @@ kernel void kernel_ssm_scan_f32(
         device const void * src5,
         device const void * src6,
         device      float * dst,
+        threadgroup float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3   tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
     const int64_t i1 = 0;
     const int64_t ir = tgpig.x; // current head
     const int64_t i3 = tgpig.y; // current seq
@@ -1673,41 +1957,88 @@ kernel void kernel_ssm_scan_f32(
     const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;
 
-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;
 
     device const int32_t * ids = (device const int32_t *) src6;
 
-    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);
+    const int64_t i = i0 + i1*nc;
+    float s0 = s0_buff[i];
+    float s  = s_buff[i];
+
+    device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31);
+    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+    device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+    device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device       float * y_block  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
 
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
-        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+        device const float * x  = (device const float *) ((device const char *) x_block + i2*args.nb12);    // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21);   // {nh, nt, ns}
+        device const float * B  = (device const float *) ((device const char *) B_block + i2*args.nb42);    // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) C_block + i2*args.nb52);    // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
 
         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
-        float sumf = 0.0f;
 
-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            const int64_t i = i0 + i1*nc;
-            const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
-        }
+        const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
+        s = state;
 
-        y[0] = sumf;
+        // Parallel sum: this relies on the kernel being dispatched with
+        // (d_state, 1, 1) threads per threadgroup, subdivided into `sgptg`
+        // SIMD groups. The goal is to compute
+        // y = sum({state * C[i] for i in range(d_state)}).
+        // To parallelize this, each SIMD group first reduces its threads with
+        // simd_sum and writes its partial sum into its slot of the shared
+        // threadgroup buffer; the partial sums are then summed to produce the
+        // final result.
+
+        // Computed for each thread
+        float sumf = state * C[i0];
+
+        // Sum the threads in the simd group => simd sum
+        sumf = simd_sum(sumf);
+
+        if (sgptg > 1) {
+
+            // Once per simd group, place the group sum into the shared buffer
+            if (tiisg == 0) {
+                shared[sgitg] = sumf;
+            }
+
+            // Wait for all threads in the threadgroup to reach this point. This
+            // ensures that all elements of the shared buffer are populated with the
+            // sum of the individual simd groups.
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            // In SIMD group 0, threads with index < sgptg read the per-group
+            // partial sums from the shared buffer and reduce them
+            sumf = 0.0f;
+            if (sgitg == 0) {
+                if (tiisg < sgptg) {
+                    sumf = shared[tiisg];
+                }
+                sumf = simd_sum(sumf);
+                if (tiisg == 0) {
+                    y[0] = sumf;
+                }
+            }
+        } else if (tiisg == 0) {
+            y[0] = sumf;
+        }
 
         // recurse
         s0 = s;
     }
+
+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }
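// Editorial sketch (not part of this patch): the recurrence the kernel above parallelizes.
// For one head/channel, each of the d_state state elements is updated independently and the
// output is their C-weighted sum; the kernel assigns one thread per state element, reduces
// within each SIMD group with simd_sum, and combines the per-group partials through the
// `shared` threadgroup buffer, as described in the comments above.
#include <cmath>
#include <cstdio>
#include <vector>

static void ssm_scan_ref(std::vector<float> &s,             // [d_state] running state
                         const float *x, const float *dt,   // [n_t]
                         const float *A,                    // [d_state]
                         const float *B, const float *C,    // [n_t][d_state]
                         float *y, int d_state, int n_t) {
    for (int t = 0; t < n_t; ++t) {
        const float dt_sp = dt[t] <= 20.0f ? std::log1p(std::exp(dt[t])) : dt[t];
        const float x_dt  = x[t] * dt_sp;
        float sum = 0.0f;
        for (int i = 0; i < d_state; ++i) {                 // one GPU thread per i
            s[i] = s[i] * std::exp(dt_sp * A[i]) + B[t * d_state + i] * x_dt;
            sum += s[i] * C[t * d_state + i];               // reduced with simd_sum on the GPU
        }
        y[t] = sum;
    }
}

int main() {
    const int d_state = 4, n_t = 2;
    std::vector<float> s(d_state, 0.0f);
    const float x[2] = {1.0f, -0.5f}, dt[2] = {0.1f, 0.2f};
    const float A[4] = {-1, -2, -3, -4};
    const float B[8] = {1, 1, 1, 1, 1, 1, 1, 1}, C[8] = {1, 0, 1, 0, 0, 1, 0, 1};
    float y[2];
    ssm_scan_ref(s, x, dt, A, B, C, y, d_state, n_t);
    std::printf("y = %f %f\n", y[0], y[1]);
    return 0;
}
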
 
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
-// TODO: optimize (e.g. by parallelizing over d_state)
 kernel void kernel_ssm_scan_f32_group(
         device const void * src0,
         device const void * src1,
@@ -1717,10 +2048,16 @@ kernel void kernel_ssm_scan_f32_group(
         device const void * src5,
         device const void * src6,
         device      float * dst,
+        threadgroup float * shared [[threadgroup(0)]],
         constant ggml_metal_kargs_ssm_scan & args,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgptg[[simdgroups_per_threadgroup]],
+        uint3   tgpg[[threadgroups_per_grid]]) {
+
+    const int64_t i0 = tpitg.x;
     const int64_t i1 = tgpig.x;
     const int64_t ir = tgpig.y; // current head
     const int64_t i3 = tgpig.z; // current seq
@@ -1735,38 +2072,81 @@ kernel void kernel_ssm_scan_f32_group(
     const int64_t ng  = args.n_group;
     const int64_t n_t = args.n_seq_tokens;
 
-    const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
+    const int64_t s_off = args.s_off;
 
     device const int32_t * ids = (device const int32_t *) src6;
 
-    device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);
+    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
+    device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);
+    const int64_t i = i0 + i1*nc;
+    float s0 = s0_buff[i];
+    float s  = s_buff[i];
+
+    device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
+    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
+    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
+    device const float * B_block  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
+    device const float * C_block  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
+    device       float * y_block  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
 
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
-        device const float * B  = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
+        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12);    // {dim, nh, nt, ns}
+        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21);    // {nh, nt, ns}
+        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42);    // {d_state, ng, nt, ns}
+        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52);    // {d_state, ng, nt, ns}
+        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
 
         const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
         const float x_dt = x[0] * dt_soft_plus;
         const float dA = exp(dt_soft_plus * A[0]);
-        float sumf = 0.0f;
 
-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            const int64_t i = i0 + i1*nc;
-            const float state = (s0[i] * dA) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
+        const float state = (s0 * dA) + (B[i0] * x_dt);
+        s = state;
+
+        // Parallel sum: this relies on the kernel being dispatched with
+        // (d_state, 1, 1) threads per threadgroup, subdivided into `sgptg`
+        // SIMD groups. The goal is to compute
+        // y = sum({state * C[i] for i in range(d_state)}).
+        // To parallelize this, each SIMD group first reduces its threads with
+        // simd_sum and writes its partial sum into its slot of the shared
+        // threadgroup buffer; the partial sums are then summed to produce the
+        // final result.
+
+        // Computed for each thread
+        float sumf = state * C[i0];
+
+        // Sum the threads in the simd group => simd sum
+        sumf = simd_sum(sumf);
+
+        // Once per simd group, place the group sum into the shared buffer
+        if (tiisg == 0) {
+            shared[sgitg] = sumf;
         }
 
-        y[0] = sumf;
+        // Wait for all threads in the threadgroup to reach this point. This
+        // ensures that all elements of the shared buffer are populated with the
+        // sum of the individual simd groups.
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // In SIMD group 0, threads with index < sgptg read the per-group
+        // partial sums from the shared buffer and reduce them
+        sumf = 0.0f;
+        if (sgitg == 0) {
+            if (tiisg < sgptg) {
+                sumf = shared[tiisg];
+            }
+            sumf = simd_sum(sumf);
+            if (tiisg == 0) {
+                y[0] = sumf;
+            }
+        }
 
         // recurse
         s0 = s;
     }
+
+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }
 
 kernel void kernel_rwkv_wkv6_f32(
@@ -2071,26 +2451,39 @@ kernel void kernel_norm(
     }
 }
 
-kernel void kernel_rms_norm(
+// F == 1 : rms_norm (no fuse)
+// F == 2 : rms_norm + mul
+// F == 3 : rms_norm + mul + add
+template <short F>
+kernel void kernel_rms_norm_fuse_impl(
         constant ggml_metal_kargs_rms_norm & args,
         device const char * src0,
+        device const char * src1_0,
+        device const char * src1_1,
         device       char * dst,
         threadgroup float * shmem_f32 [[threadgroup(0)]],
-        uint   tgpig[[threadgroup_position_in_grid]],
-        ushort tpitg[[thread_position_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort   ntg[[threads_per_threadgroup]]) {
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
     if (sgitg == 0) {
         shmem_f32[tiisg] = 0.0f;
     }
 
-    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+    const int i01 = tgpig.x;
+    const int i02 = tgpig.y;
+    const int i03 = tgpig.z;
+
+    device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
+
+    device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
+    device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
 
     float sumf = 0.0f;
 
     // parallel sum
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
         sumf += dot(x[i00], x[i00]);
     }
     sumf = simd_sum(sumf);
@@ -2109,12 +2502,26 @@ kernel void kernel_rms_norm(
     const float mean  = sumf/args.ne00;
     const float scale = 1.0f/sqrt(mean + args.eps);
 
-    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
-        y[i00] = x[i00] * scale;
+    device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
+    for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
+        if (F == 1) {
+            y[i00] = (x[i00]*scale);
+        }
+        if (F == 2) {
+            y[i00] = (x[i00]*scale)*f0[i00];
+        }
+        if (F == 3) {
+            y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
+        }
     }
 }
 
+typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t;
+
+template [[host_name("kernel_rms_norm")]]         kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>;
+template [[host_name("kernel_rms_norm_mul")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
+template [[host_name("kernel_rms_norm_mul_add")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;
+
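// Editorial sketch (not part of this patch): the three fusion levels of the templated
// RMS-norm kernel, written out as a scalar C++ reference. F=1 is the plain norm, F=2 also
// multiplies by a weight row, F=3 additionally adds a bias row, all in one pass over x.
#include <cmath>
#include <cstdio>

static void rms_norm_fuse_ref(const float *x, const float *w, const float *b,
                              float *y, int n, float eps, int F) {
    float ss = 0.0f;
    for (int i = 0; i < n; ++i) ss += x[i] * x[i];
    const float scale = 1.0f / std::sqrt(ss / n + eps);     // 1/sqrt(mean(x^2) + eps)
    for (int i = 0; i < n; ++i) {
        float v = x[i] * scale;
        if (F >= 2) v *= w[i];                              // fused GGML_OP_MUL
        if (F >= 3) v += b[i];                              // fused GGML_OP_ADD
        y[i] = v;
    }
}

int main() {
    const float x[4] = {1, 2, 3, 4}, w[4] = {2, 2, 2, 2}, b[4] = {1, 1, 1, 1};
    float y[4];
    rms_norm_fuse_ref(x, w, b, y, 4, 1e-6f, 3);
    std::printf("%f %f %f %f\n", y[0], y[1], y[2], y[3]);
    return 0;
}
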
 kernel void kernel_l2_norm(
         constant ggml_metal_kargs_l2_norm & args,
         device const char * src0,
@@ -2809,6 +3216,11 @@ template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]]   kernel mul_mv_ext_q4
 template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0,   32, dequantize_q8_0_t4>;
 template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0,   32, dequantize_q8_0_t4>;
 
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_2")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_mxfp4,  32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_mxfp4,  32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4,  32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4,  32, dequantize_mxfp4_t4>;
+
 template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
 template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
 template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
@@ -3795,6 +4207,7 @@ kernel void kernel_flash_attn_ext(
         device const char * k,
         device const char * v,
         device const char * mask,
+        device const char * sinks,
         device       char * dst,
         threadgroup  half * shmem_f16 [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
@@ -4110,6 +4523,35 @@ kernel void kernel_flash_attn_ext(
             }
         }
 
+        if (sinks != q && sgitg == 0) {
+            for (ushort j = 0; j < Q; ++j) {
+                const float m = M[j];
+                const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
+
+                M[j] = simd_max(max(M[j], s));
+
+                const float ms = exp(m - M[j]);
+                const float vs = exp(s - M[j]);
+
+                S[j] = S[j]*ms + simd_sum(vs);
+
+                if (tiisg == j) {
+                    ss[j*TS + 2*C + j] = ms;
+                }
+            }
+
+            // O = diag(ms)*O
+            {
+                s8x8_t ms;
+                simdgroup_load(ms, ss + 2*C, TS, 0, false);
+
+                #pragma unroll(DV8)
+                for (short i = 0; i < DV8; ++i) {
+                    simdgroup_multiply(lo[i], ms, lo[i]);
+                }
+            }
+        }
+
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
         for (short j = tiisg; j < Q; j += NW) {
             ss[j*TS + 0] = S[j];
@@ -4321,6 +4763,7 @@ kernel void kernel_flash_attn_ext_vec(
         device const char * k,
         device const char * v,
         device const char * mask,
+        device const char * sinks,
         device       char * dst,
         threadgroup  half * shmem_f16 [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
@@ -4538,6 +4981,23 @@ kernel void kernel_flash_attn_ext_vec(
             }
         }
 
+        if (sinks != q && sgitg == 0) {
+            const float m = M;
+            const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
+
+            M = simd_max(max(M, s));
+
+            const float ms = exp(m - M);
+            const float vs = exp(s - M);
+
+            S = S*ms + simd_sum(vs);
+
+#pragma unroll(DV4/NL)
+            for (short ii = 0; ii < DV4; ii += NL) {
+                lo[ii/NL] *= ms;
+            }
+        }
+
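// Editorial sketch (not part of this patch): how an attention "sink" folds into the
// streaming softmax state kept by the flash-attention kernels (M = running max,
// S = running denominator, O = running weighted sum of V). The sink behaves like one extra
// key whose value vector is zero: it rescales O and adds exp(sink - M) to S.
#include <cmath>
#include <cstdio>

struct OnlineSoftmax { float M = -1e30f, S = 0.0f, O = 0.0f; };

static void accumulate(OnlineSoftmax &st, float logit, float value) {
    const float m_new = std::fmax(st.M, logit);
    const float ms = std::exp(st.M - m_new);                // rescale the previous state
    st.S = st.S * ms + std::exp(logit - m_new);
    st.O = st.O * ms + std::exp(logit - m_new) * value;
    st.M = m_new;
}

int main() {
    OnlineSoftmax st;
    accumulate(st, 1.0f, 2.0f);                             // regular key/value pairs
    accumulate(st, 3.0f, -1.0f);
    accumulate(st, 2.0f, 0.0f);                             // the sink: logit only, zero value
    std::printf("output = %f\n", st.O / st.S);
    return 0;
}
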
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
         if (tiisg == 0) {
             ss[0] = (s_t) S;
@@ -6643,6 +7103,95 @@ kernel void kernel_mul_mv_iq4_xs_f32(
     kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
+template
+void kernel_mul_mv_mxfp4_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+    const int nb = args.ne00/QK_MXFP4;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * nsg + sgitg) * nr0;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
+    device const float       * y = (device const float       *) (src1 + offset1);
+
+    const short ix = tiisg/2;  // 0...15
+    const short it = tiisg%2;  // 0 or 1
+
+    shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float4 yl[4];
+    float sumf[nr0]={0.f};
+
+    device const float * yb = y + ix * QK_MXFP4 + it * 8;
+
+    for (int ib = ix; ib < nb; ib += 16) {
+        device const float4 * y4 = (device const float4 *)yb;
+        yl[0] = y4[0];
+        yl[1] = y4[4];
+        yl[2] = y4[1];
+        yl[3] = y4[5];
+
+#pragma unroll(nr0)
+        for (short row = 0; row < nr0; row++) {
+            device const block_mxfp4 & xb = x[row*nb + ib];
+            device const uint8_t     * q2 = (device const uint8_t *)(xb.qs + 8*it);
+
+            float4 acc1 = yl[0]*float4(shmem_f32[q2[0] &  0x0F], shmem_f32[q2[1] &  0x0F], shmem_f32[q2[2] &  0x0F], shmem_f32[q2[3] &  0x0F]);
+            float4 acc2 = yl[1]*float4(shmem_f32[q2[0] >> 4   ], shmem_f32[q2[1] >> 4   ], shmem_f32[q2[2] >> 4   ], shmem_f32[q2[3] >> 4   ]);
+            float4 acc3 = yl[2]*float4(shmem_f32[q2[4] &  0x0F], shmem_f32[q2[5] &  0x0F], shmem_f32[q2[6] &  0x0F], shmem_f32[q2[7] &  0x0F]);
+            float4 acc4 = yl[3]*float4(shmem_f32[q2[4] >> 4   ], shmem_f32[q2[5] >> 4   ], shmem_f32[q2[6] >> 4   ], shmem_f32[q2[7] >> 4   ]);
+
+            acc1 = (acc1 + acc3) + (acc2 + acc4);
+
+            sumf[row] += e8m0_to_fp32(xb.e) * ((acc1[0] + acc1[1]) + (acc1[2] + acc1[3]));
+        }
+
+        yb += 16 * QK_MXFP4;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = sum_all;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_mxfp4_f32")]]
+kernel void kernel_mul_mv_mxfp4_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
 template
 kernel void kernel_get_rows_q(
         constant ggml_metal_kargs_get_rows & args,
@@ -7178,6 +7727,7 @@ template [[host_name("kernel_get_rows_q4_1")]]    kernel get_rows_q_t kernel_get
 template [[host_name("kernel_get_rows_q5_0")]]    kernel get_rows_q_t kernel_get_rows_q;
 template [[host_name("kernel_get_rows_q5_1")]]    kernel get_rows_q_t kernel_get_rows_q;
 template [[host_name("kernel_get_rows_q8_0")]]    kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_mxfp4")]]   kernel get_rows_q_t kernel_get_rows_q;
 template [[host_name("kernel_get_rows_q2_K")]]    kernel get_rows_q_t kernel_get_rows_q;
 template [[host_name("kernel_get_rows_q3_K")]]    kernel get_rows_q_t kernel_get_rows_q;
 template [[host_name("kernel_get_rows_q4_K")]]    kernel get_rows_q_t kernel_get_rows_q;
@@ -7230,6 +7780,7 @@ template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mul_mm_t kernel_mul_m
 template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mul_mm_t kernel_mul_mm;
 template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mul_mm_t kernel_mul_mm;
 template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mul_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_mxfp4_f32")]]   kernel mul_mm_t kernel_mul_mm;
 template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mul_mm_t kernel_mul_mm;
 template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mul_mm_t kernel_mul_mm;
 template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mul_mm_t kernel_mul_mm;
@@ -7261,6 +7812,7 @@ template [[host_name("kernel_mul_mm_id_q4_1_f16")]]    kernel mul_mm_id kernel_m
 template [[host_name("kernel_mul_mm_id_q5_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id;
 template [[host_name("kernel_mul_mm_id_q5_1_f16")]]    kernel mul_mm_id kernel_mul_mm_id;
 template [[host_name("kernel_mul_mm_id_q8_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_mxfp4_f16")]]   kernel mul_mm_id kernel_mul_mm_id;
 template [[host_name("kernel_mul_mm_id_q2_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id;
 template [[host_name("kernel_mul_mm_id_q3_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id;
 template [[host_name("kernel_mul_mm_id_q4_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id;
@@ -7406,6 +7958,8 @@ template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t
 template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
 template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
 
+template [[host_name("kernel_mul_mv_id_mxfp4_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+
 template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
 template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
 template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt
index 971314debc7..02904526ade 100644
--- a/ggml/src/ggml-musa/CMakeLists.txt
+++ b/ggml/src/ggml-musa/CMakeLists.txt
@@ -34,8 +34,12 @@ if (MUSAToolkit_FOUND)
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
     file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_MUSA ${SRCS})
-    file(GLOB   SRCS "../ggml-musa/*.cu")
-    list(APPEND GGML_SOURCES_MUSA ${SRCS})
+
+    if (GGML_MUSA_MUDNN_COPY)
+        file(GLOB   SRCS "../ggml-musa/*.cu")
+        list(APPEND GGML_SOURCES_MUSA ${SRCS})
+        add_compile_definitions(GGML_MUSA_MUDNN_COPY)
+    endif()
 
     if (GGML_CUDA_FA_ALL_QUANTS)
         file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
@@ -72,6 +76,10 @@ if (MUSAToolkit_FOUND)
     add_compile_definitions(GGML_USE_MUSA)
     add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
 
+    if (GGML_MUSA_GRAPHS)
+        add_compile_definitions(GGML_MUSA_GRAPHS)
+    endif()
+
     if (GGML_CUDA_FORCE_MMQ)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()
@@ -97,10 +105,16 @@ if (MUSAToolkit_FOUND)
     endif()
 
     if (GGML_STATIC)
-        # TODO: mudnn has not provided static libraries yet
         target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static)
+        # TODO: mudnn has not provided static libraries yet
+        # if (GGML_MUSA_MUDNN_COPY)
+        #     target_link_libraries(ggml-musa PRIVATE mudnn_static)
+        # endif()
     else()
-        target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas mudnn)
+        target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas)
+        if (GGML_MUSA_MUDNN_COPY)
+            target_link_libraries(ggml-musa PRIVATE mudnn)
+        endif()
     endif()
 
     if (GGML_CUDA_NO_VMM)
diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt
index ec5d8cf5955..9a7ccbcff06 100644
--- a/ggml/src/ggml-opencl/CMakeLists.txt
+++ b/ggml/src/ggml-opencl/CMakeLists.txt
@@ -55,6 +55,7 @@ endfunction()
 
 set(GGML_OPENCL_KERNELS
     add
+    add_id
     argsort
     clamp
     cpy
@@ -81,7 +82,11 @@ set(GGML_OPENCL_KERNELS
     mul_mv_q4_0_f32_1d_8x_flat
     mul_mv_q4_0_f32_1d_16x_flat
     mul_mv_q6_k
+    mul_mv_mxfp4_f32
     mul_mv_id_q4_0_f32_8x_flat
+    mul_mv_id_mxfp4_f32
+    mul_mm_f32_f32_l4_lm
+    mul_mm_f16_f32_l4_lm
     mul
     norm
     relu
@@ -105,6 +110,11 @@ set(GGML_OPENCL_KERNELS
     pad
     repeat
     mul_mat_f16_f32
+    conv2d
+    conv2d_f16_f32
+    flash_attn_f32_f16
+    flash_attn_f16
+    flash_attn_f32
 )
 
 foreach (K ${GGML_OPENCL_KERNELS})
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index 58830b733a8..8a2ac7e377d 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -25,6 +25,7 @@
 #include 
 #include 
 #include 
+#include <map>
 #include 
 #include 
 #include 
@@ -33,6 +34,7 @@
 #undef MAX
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
 #define UNUSED(x) (void)(x)
 
@@ -333,6 +335,7 @@ struct ggml_backend_opencl_context {
     size_t max_alloc_size;
     bool fp16_support;
     bool has_vector_subgroup_broadcast;
+    bool disable_fusion;
     ggml_cl_compiler_version adreno_cl_compiler_version;
 
     int adreno_wave_size;
@@ -343,6 +346,7 @@ struct ggml_backend_opencl_context {
     cl_command_queue queue;
 
     cl_program program_add;
+    cl_program program_add_id;
     cl_program program_clamp;
     cl_program program_cpy;
     cl_program program_cvt;
@@ -362,6 +366,7 @@ struct ggml_backend_opencl_context {
     cl_program program_mul_mv_q4_0_f32_1d_8x_flat;
     cl_program program_mul_mv_q4_0_f32_1d_16x_flat;
     cl_program program_mul_mv_q6_K;
+    cl_program program_mul_mv_mxfp4_f32;
     cl_program program_mul_mv_f16_f16;
     cl_program program_mul_mv_f16_f32_1row;
     cl_program program_mul_mv_f16_f32_l4;
@@ -390,13 +395,20 @@ struct ggml_backend_opencl_context {
     cl_program program_tanh;
     cl_program program_upscale;
     cl_program program_concat;
+    cl_program program_conv_2d_f16;
+    cl_program program_conv_2d_f32;
+    cl_program program_conv_2d_f16_f32;
     cl_program program_tsembd;
     cl_program program_mul_mv_id_q4_0_f32_8x_flat;
-
-    cl_kernel kernel_add, kernel_add_row;
-    cl_kernel kernel_mul, kernel_mul_row;
-    cl_kernel kernel_div, kernel_div_row;
-    cl_kernel kernel_sub, kernel_sub_row;
+    cl_program program_mul_mv_id_mxfp4_f32;
+    cl_program program_mul_mm_f32_f32_l4_lm;
+    cl_program program_mul_mm_f16_f32_l4_lm;
+
+    cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16;
+    cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
+    cl_kernel kernel_div, kernel_div_row, kernel_div_f16, kernel_div_row_f16;
+    cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16;
+    cl_kernel kernel_add_id;
     cl_kernel kernel_scale;
     cl_kernel kernel_silu, kernel_silu_4;
     cl_kernel kernel_gelu, kernel_gelu_4;
@@ -405,14 +417,22 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_relu;
     cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
     cl_kernel kernel_clamp;
-    cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
+    cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
               kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
     cl_kernel kernel_norm;
-    cl_kernel kernel_rms_norm;
+    cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
     cl_kernel kernel_group_norm;
     cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
     cl_kernel kernel_soft_max, kernel_soft_max_4;
     cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16_q1;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q1;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16;
+    std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_q1;
+    std::map<std::pair<int, int>, int>       kernels_flash_attn_bm;
+    std::map<std::pair<int, int>, int>       kernels_flash_attn_bn;
     cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
     cl_kernel kernel_set_rows_f32, kernel_set_rows_f16;
     cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
@@ -430,6 +450,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_convert_block_q4_0_noshuffle;
     cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
     cl_kernel kernel_mul_mv_q6_K_f32;
+    cl_kernel kernel_mul_mv_mxfp4_f32;
     cl_kernel kernel_im2col_f32, kernel_im2col_f16;
     cl_kernel kernel_argsort_f32_i32;
     cl_kernel kernel_sum_rows_f32;
@@ -441,8 +462,14 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_upscale_bilinear;
     cl_kernel kernel_concat_f32_contiguous;
     cl_kernel kernel_concat_f32_non_contiguous;
+    cl_kernel kernel_conv_2d_f16;
+    cl_kernel kernel_conv_2d_f32;
+    cl_kernel kernel_conv_2d_f16_f32;
     cl_kernel kernel_timestep_embedding;
     cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
+    cl_kernel kernel_mul_mv_id_mxfp4_f32;
+    cl_kernel kernel_mul_mm_f32_f32_l4_lm;
+    cl_kernel kernel_mul_mm_f16_f32_l4_lm;
 
     std::vector<ProfilingInfo> profiling_info;
 
@@ -563,6 +590,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_transpose_32;
     cl_kernel kernel_transpose_32_16;
     cl_kernel kernel_transpose_16;
+    cl_kernel kernel_transpose_16_4x1;
 
     cl_mem A_s_d_max;            // max scale buffer size for transpose
     cl_mem A_q_d_max;            // max weight buffer size for transpose
@@ -588,6 +616,7 @@ struct ggml_backend_opencl_context {
         if (ref_count == 0) {
 #ifdef GGML_OPENCL_PROFILING
             write_profiling_info();
+            profiling_info.clear();
 #endif
         }
     }
@@ -662,8 +691,26 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_add =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_add     = clCreateKernel(backend_ctx->program_add, "kernel_add", &err), err));
-        CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program_add, "kernel_add_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_add         = clCreateKernel(backend_ctx->program_add, "kernel_add", &err), err));
+        CL_CHECK((backend_ctx->kernel_add_row     = clCreateKernel(backend_ctx->program_add, "kernel_add_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_add_f16     = clCreateKernel(backend_ctx->program_add, "kernel_add_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_add_row_f16 = clCreateKernel(backend_ctx->program_add, "kernel_add_row_f16", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // add_id
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "add_id.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("add_id.cl");
+#endif
+        backend_ctx->program_add_id =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_add_id = clCreateKernel(backend_ctx->program_add_id, "kernel_add_id", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -773,6 +820,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         CL_CHECK((backend_ctx->kernel_geglu           = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
         CL_CHECK((backend_ctx->kernel_reglu           = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
         CL_CHECK((backend_ctx->kernel_swiglu          = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
+        CL_CHECK((backend_ctx->kernel_swiglu_oai      = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_oai", &err), err));
         CL_CHECK((backend_ctx->kernel_geglu_erf       = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf", &err), err));
         CL_CHECK((backend_ctx->kernel_geglu_quick     = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick", &err), err));
         CL_CHECK((backend_ctx->kernel_geglu_f16       = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
@@ -937,6 +985,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // mul_mv_mxfp4_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mv_mxfp4_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mv_mxfp4_f32.cl");
+#endif
+        backend_ctx->program_mul_mv_mxfp4_f32 =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mv_mxfp4_f32 = clCreateKernel(backend_ctx->program_mul_mv_mxfp4_f32, "kernel_mul_mv_mxfp4_f32", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
     // mul_mv_f16_f16
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1033,6 +1097,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // mul_mm_f32_f32_l4_lm
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mm_f32_f32_l4_lm.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mm_f32_f32_l4_lm.cl");
+#endif
+        backend_ctx->program_mul_mm_f32_f32_l4_lm =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mm_f32_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f32_f32_l4_lm, "kernel_mul_mm_f32_f32_l4_lm", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // mul_mm_f16_f32_l4_lm
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mm_f16_f32_l4_lm.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mm_f16_f32_l4_lm.cl");
+#endif
+        backend_ctx->program_mul_mm_f16_f32_l4_lm =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_l4_lm, "kernel_mul_mm_f16_f32_l4_lm", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
     // mul
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1045,8 +1141,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_mul =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_mul     = clCreateKernel(backend_ctx->program_mul, "kernel_mul", &err), err));
-        CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_mul         = clCreateKernel(backend_ctx->program_mul, "kernel_mul", &err), err));
+        CL_CHECK((backend_ctx->kernel_mul_row     = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_mul_f16     = clCreateKernel(backend_ctx->program_mul, "kernel_mul_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_mul_row_f16 = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row_f16", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -1094,7 +1192,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_rms_norm =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_rms_norm     = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_rms_norm_mul = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm_mul", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -1218,6 +1317,73 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // flash_attn
+    {
+        #ifdef GGML_OPENCL_EMBED_KERNELS
+                const std::string kernel_src_f16 {
+                    #include "flash_attn_f16.cl.h"
+                };
+                const std::string kernel_src_f32 {
+                    #include "flash_attn_f32.cl.h"
+                };
+                const std::string kernel_src_f32_f16 {
+                    #include "flash_attn_f32_f16.cl.h"
+                };
+        #else
+                const std::string kernel_src_f16 = read_file("flash_attn_f16.cl");
+                const std::string kernel_src_f32 = read_file("flash_attn_f32.cl");
+                const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl");
+        #endif
+
+        if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) {
+            const struct { int dk; int dv; int bm; int bn; } fa_dims[] = {
+                { 64,  64, 64, 64}, { 80,  80, 64, 32}, { 96,  96, 64, 32},
+                {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16},
+                {192, 192, 16, 16}, {256, 256, 16, 16},
+            };
+
+            for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) {
+                const int dk = fa_dims[i].dk;
+                const int dv = fa_dims[i].dv;
+                const int bm = fa_dims[i].bm;
+                const int bn = fa_dims[i].bn;
+                std::string OPTS = compile_opts +
+                    " -D DK=" + std::to_string(dk) +
+                    " -D DV=" + std::to_string(dv) +
+                    " -D BLOCK_M=" + std::to_string(bm) +
+                    " -D BLOCK_N=" + std::to_string(bn);
+
+                cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS);
+                cl_kernel k_f16, k_f16_q1;
+                CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err));
+                CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err));
+                backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16;
+                backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1;
+                CL_CHECK(clReleaseProgram(prog_f16));
+
+                cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS);
+                cl_kernel k_f32, k_f32_q1;
+                CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err));
+                CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err));
+                backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32;
+                backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1;
+                CL_CHECK(clReleaseProgram(prog_f32));
+
+                cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS);
+                cl_kernel k_f32_f16, k_f32_f16_q1;
+                CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err));
+                CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err));
+                backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16;
+                backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1;
+                CL_CHECK(clReleaseProgram(prog_f32_f16));
+
+                backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm;
+                backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn;
+            }
+            GGML_LOG_CONT(".");
+        }
+    }
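Editorial note: the loop above compiles one program per (DK, DV) entry of `fa_dims` and stores the resulting kernels keyed by that pair, together with the BLOCK_M/BLOCK_N tile sizes the program was built with. A minimal sketch of how a dispatch path could look a variant up (illustrative only, not part of this patch; `pick_fa_kernel_f16` is a hypothetical helper, shown for the f16 maps — the f32 and f32_f16 maps are used analogously):

    // Returns the compiled f16 flash-attention kernel for the given head sizes,
    // or nullptr when (dk, dv) is not one of the supported fa_dims entries.
    static cl_kernel pick_fa_kernel_f16(ggml_backend_opencl_context * ctx,
                                        int dk, int dv, int * bm, int * bn) {
        const std::pair<int, int> key = {dk, dv};
        auto it = ctx->kernels_flash_attn_f16.find(key);
        if (it == ctx->kernels_flash_attn_f16.end()) {
            return nullptr; // head sizes not in fa_dims -> op not supported
        }
        *bm = ctx->kernels_flash_attn_bm.at(key); // BLOCK_M the program was compiled with
        *bn = ctx->kernels_flash_attn_bn.at(key); // BLOCK_N the program was compiled with
        return it->second;
    }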
+
     // argsort
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1243,11 +1409,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
 #else
         const std::string kernel_src = read_file("div.cl");
 #endif
+        std::string compile_opts = std::string("-cl-std=") + opencl_c_std +
+                               " -cl-mad-enable -cl-finite-math-only ";
+
         backend_ctx->program_div =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_div     = clCreateKernel(backend_ctx->program_div, "kernel_div", &err), err));
-        CL_CHECK((backend_ctx->kernel_div_row = clCreateKernel(backend_ctx->program_div, "kernel_div_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_div         = clCreateKernel(backend_ctx->program_div, "kernel_div", &err), err));
+        CL_CHECK((backend_ctx->kernel_div_row     = clCreateKernel(backend_ctx->program_div, "kernel_div_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_div_f16     = clCreateKernel(backend_ctx->program_div, "kernel_div_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_div_row_f16 = clCreateKernel(backend_ctx->program_div, "kernel_div_row_f16", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -1263,8 +1434,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_sub =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_sub     = clCreateKernel(backend_ctx->program_sub, "kernel_sub", &err), err));
-        CL_CHECK((backend_ctx->kernel_sub_row = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_sub         = clCreateKernel(backend_ctx->program_sub, "kernel_sub", &err), err));
+        CL_CHECK((backend_ctx->kernel_sub_row     = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row", &err), err));
+        CL_CHECK((backend_ctx->kernel_sub_f16     = clCreateKernel(backend_ctx->program_sub, "kernel_sub_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_sub_row_f16 = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row_f16", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -1478,6 +1651,47 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+     // conv2d
+     {
+        #ifdef GGML_OPENCL_EMBED_KERNELS
+                const std::string kernel_src {
+                    #include "conv2d.cl.h"
+                };
+                const std::string kernel_src_f16_f32 {
+                    #include "conv2d_f16_f32.cl.h"
+                };
+        #else
+                const std::string kernel_src = read_file("conv2d.cl");
+                const std::string kernel_src_f16_f32 = read_file("conv2d_f16_f32.cl");
+        #endif
+                if (!kernel_src.empty()) {
+                    backend_ctx->program_conv_2d_f16 =
+                        build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), (std::string(compile_opts) + " -DUSE_FP16=1").c_str());
+                    CL_CHECK((backend_ctx->kernel_conv_2d_f16 = clCreateKernel(backend_ctx->program_conv_2d_f16, "kernel_conv_2d", &err), err));
+                    GGML_LOG_CONT(".");
+                    backend_ctx->program_conv_2d_f32 =
+                        build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+                    CL_CHECK((backend_ctx->kernel_conv_2d_f32 = clCreateKernel(backend_ctx->program_conv_2d_f32, "kernel_conv_2d", &err), err));
+                    GGML_LOG_CONT(".");
+                } else {
+                    GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n");
+                    backend_ctx->program_conv_2d_f16 = nullptr;
+                    backend_ctx->kernel_conv_2d_f16 = nullptr;
+                    backend_ctx->program_conv_2d_f32 = nullptr;
+                    backend_ctx->kernel_conv_2d_f32 = nullptr;
+                }
+                if (!kernel_src_f16_f32.empty()) {
+                    backend_ctx->program_conv_2d_f16_f32 =
+                        build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16_f32.c_str(), compile_opts);
+                    CL_CHECK((backend_ctx->kernel_conv_2d_f16_f32 = clCreateKernel(backend_ctx->program_conv_2d_f16_f32, "kernel_conv_2d", &err), err));
+                    GGML_LOG_CONT(".");
+                } else {
+                    GGML_LOG_WARN("ggml_opencl: conv2d_f16_f32 kernel source not found or empty. This op will not be available.\n");
+                    backend_ctx->program_conv_2d_f16_f32 = nullptr;
+                    backend_ctx->kernel_conv_2d_f16_f32 = nullptr;
+                }
+    }
+
     // mul_mv_id_q4_0_f32_8x_flat
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1494,6 +1708,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // mul_mv_id_mxfp4_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mv_id_mxfp4_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mv_id_mxfp4_f32.cl");
+#endif
+        backend_ctx->program_mul_mv_id_mxfp4_f32 =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mv_id_mxfp4_f32 = clCreateKernel(backend_ctx->program_mul_mv_id_mxfp4_f32, "kernel_mul_mv_id_mxfp4_f32", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
     // Adreno kernels
 #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
     // transpose
@@ -1511,6 +1741,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_16", &err), err));
         CL_CHECK((backend_ctx->kernel_transpose_32    = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32", &err), err));
         CL_CHECK((backend_ctx->kernel_transpose_16    = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16", &err), err));
+        CL_CHECK((backend_ctx->kernel_transpose_16_4x1    = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_4x1", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -1949,8 +2180,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
 
     backend_ctx->adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version);
     backend_ctx->has_vector_subgroup_broadcast =
-        backend_ctx->adreno_cl_compiler_version.major >= 47 ||
-        backend_ctx->adreno_cl_compiler_version.major == 17;
+        (backend_ctx->adreno_cl_compiler_version.type == E031 && backend_ctx->adreno_cl_compiler_version.major >= 47) ||
+        (backend_ctx->adreno_cl_compiler_version.type == DX   && backend_ctx->adreno_cl_compiler_version.major >= 17);
     GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n",
         backend_ctx->has_vector_subgroup_broadcast ? "true" : "false");
 
@@ -2063,6 +2294,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
     CL_CHECK((backend_ctx->B_d_max   = clCreateBuffer(context, 0, max_B_d_bytes,   NULL, &err), err));
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS
 
+    backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
+
     dev_ctx->backend_ctx = backend_ctx.release();
     return dev_ctx->backend_ctx;
 }
@@ -2232,7 +2465,45 @@ static void sync_with_other_backends(ggml_backend_t backend) {
     sync_with_other_backends(backend_ctx);
 }
 
+static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul      = cgraph->nodes[node_idx+1];
+
+        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+
+        // rms_norm only supports f32
+        if (mul->src[0]->type != GGML_TYPE_F32 ||
+            mul->src[1]->type != GGML_TYPE_F32 ||
+            mul->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        // if rms_norm is the B operand, then we don't handle broadcast
+        if (rms_norm == mul->src[1] &&
+            !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+            return false;
+        }
+
+        // rms_norm assumes contiguous rows
+        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
+
 static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -2245,6 +2516,12 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
             continue;
         }
 
+        if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+            ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
+            i++;
+            continue;
+        }
+
         bool ok = ggml_cl_compute_forward(backend, node);
         if (!ok) {
             GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
@@ -2280,6 +2557,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             {
                 // TODO: add support
                 // ref: https://github.com/ggml-org/llama.cpp/pull/14274
+#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
                 if (op->src[0]->type != GGML_TYPE_F32) {
                     return false;
                 }
@@ -2314,11 +2592,23 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                 default:
                     return false;
             }
-        case GGML_OP_ADD:
         case GGML_OP_SCALE:
+            return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
+        case GGML_OP_ADD:
+            if (op->type == GGML_TYPE_F16) {
+                const bool src0_ok = op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32;
+                const bool src1_ok = op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32;
+                if (src0_ok && src1_ok) {
+                    return true;
+                }
+            }
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_SUB:
+            return (op->src[0]->type == op->src[1]->type) &&
+                   (op->src[0]->type == op->type) &&
+                   (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
+        case GGML_OP_ADD_ID:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
@@ -2341,6 +2631,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU_OAI:
                 case GGML_GLU_OP_GEGLU_ERF:
                 case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
@@ -2360,6 +2651,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                    op->src[0]->ne[3] == 1 && op->ne[3] == 1;
         case GGML_OP_UPSCALE:
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
+        case GGML_OP_CONV_2D:
+            return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
+                   (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
+                   (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
         case GGML_OP_CONCAT:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_TIMESTEP_EMBEDDING:
@@ -2371,13 +2666,14 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                 return true;
             } else if (op->src[0]->type == GGML_TYPE_F32) {
                 return op->src[1]->type == GGML_TYPE_F32;
-            } else if (op->src[0]->type == GGML_TYPE_Q4_0 ||
+            } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_MXFP4 ||
                        op->src[0]->type == GGML_TYPE_Q6_K) {
                 return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
             }
             return false;
         case GGML_OP_MUL_MAT_ID:
-            if (op->src[0]->type == GGML_TYPE_Q4_0) {
+            if (op->src[0]->type == GGML_TYPE_Q4_0 ||
+                op->src[0]->type == GGML_TYPE_MXFP4) {
                 if (op->src[1]->type == GGML_TYPE_F32) {
                     return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
                 }
@@ -2416,6 +2712,45 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SUM_ROWS:
             return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                if (op->src[4]) {
+                    return false;
+                }
+
+                const ggml_tensor * q = op->src[0];
+                const ggml_tensor * k = op->src[1];
+                const ggml_tensor * v = op->src[2];
+
+                const int dk = q->ne[0];
+                const int dv = v->ne[0];
+
+                const struct { int dk; int dv; } supported_dims[] = {
+                    { 64,  64}, { 80,  80}, { 96,  96},
+                    {112, 112}, {128, 128}, {192, 128},
+                    {192, 192}, {256, 256},
+                };
+
+                bool dims_supported = false;
+                for (size_t i = 0; i < sizeof(supported_dims)/sizeof(supported_dims[0]); ++i) {
+                    if (supported_dims[i].dk == dk && supported_dims[i].dv == dv) {
+                        dims_supported = true;
+                        break;
+                    }
+                }
+                if (!dims_supported) {
+                    return false;
+                }
+
+                const bool is_f32_f32 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F32 &&
+                                        v->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
+                const bool is_f16_f16 = q->type == GGML_TYPE_F16 && k->type == GGML_TYPE_F16 &&
+                                        v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16;
+                const bool is_f32_f16 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 &&
+                                        v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F32;
+
+                return is_f32_f32 || is_f16_f16 || is_f32_f16;
+            }
         default:
             return false;
     }
@@ -2450,10 +2785,10 @@ ggml_backend_t ggml_backend_opencl_init(void) {
     ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev);
 
     ggml_backend_t backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_opencl_guid(),
-        /* .interface = */ ggml_backend_opencl_i,
-        /* .device    = */ dev,
-        /* .context   = */ backend_ctx
+        /* .guid    = */ ggml_backend_opencl_guid(),
+        /* .iface   = */ ggml_backend_opencl_i,
+        /* .device  = */ dev,
+        /* .context = */ backend_ctx
     };
 
     return backend;
@@ -2763,7 +3098,10 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
         // cl_mem qT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, q_size_bytes, NULL, &err);
         CL_CHECK(err);
 
-        // size_t d_size_bytes = M * (K / 32) / 2 * sizeof(float);
+        bool K_tile_trans = true;
+        if ((K / 32) % 4 != 0) {
+            K_tile_trans = false;
+        }
         size_t d_size_bytes = M * (K / 32) * 2;
         region.origin = 0;
         region.size = d_size_bytes;
@@ -2804,10 +3142,15 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
         qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
         CL_CHECK(err);
 
-        img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT };
         memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+        if (K_tile_trans) {
+            img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT };
+            img_desc_1d.image_width = M * K / 32 / 4;
+        } else {
+            img_fmt_1d = { CL_R, CL_HALF_FLOAT };
+            img_desc_1d.image_width = M * K / 32;
+        }
         img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
-        img_desc_1d.image_width = M * K / 32 / 4;
         img_desc_1d.buffer = extra->d;
         d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
         CL_CHECK(err);
@@ -2843,6 +3186,10 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
         int width_s = K / 32 / 4;
 
         kernel = backend_ctx->kernel_transpose_16;
+        if (!K_tile_trans) {
+            kernel = backend_ctx->kernel_transpose_16_4x1;
+            width_s = K / 32;
+        }
         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D));
         CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s));
@@ -3543,35 +3890,35 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
     GGML_ASSERT(dst);
     GGML_ASSERT(dst->extra);
 
-    const int  ne00 = src0 ? src0->ne[0] : 0;
-    const int  ne01 = src0 ? src0->ne[1] : 0;
-    const int  ne02 = src0 ? src0->ne[2] : 0;
-    const int  ne03 = src0 ? src0->ne[3] : 0;
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
 
-    const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
-    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
-    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
-    const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
+    const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
 
-    const int  ne10 = src1 ? src1->ne[0] : 0;
-    const int  ne11 = src1 ? src1->ne[1] : 0;
-    const int  ne12 = src1 ? src1->ne[2] : 0;
-    const int  ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
 
-    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
-    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
-    const cl_ulong nb12 = src1 ? src1->nb[2] : 0;
-    const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+    const cl_ulong nb10 = src1->nb[0];
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3];
 
-    const int  ne0  = dst ? dst->ne[0] : 0;
-    const int  ne1  = dst ? dst->ne[1] : 0;
-    const int  ne2  = dst ? dst->ne[2] : 0;
-    const int  ne3  = dst ? dst->ne[3] : 0;
+    const int ne0  = dst->ne[0];
+    const int ne1  = dst->ne[1];
+    const int ne2  = dst->ne[2];
+    const int ne3  = dst->ne[3];
 
-    const cl_ulong nb0  = dst ? dst->nb[0] : 0;
-    const cl_ulong nb1  = dst ? dst->nb[1] : 0;
-    const cl_ulong nb2  = dst ? dst->nb[2] : 0;
-    const cl_ulong nb3  = dst ? dst->nb[3] : 0;
+    const cl_ulong nb0  = dst->nb[0];
+    const cl_ulong nb1  = dst->nb[1];
+    const cl_ulong nb2  = dst->nb[2];
+    const cl_ulong nb3  = dst->nb[3];
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
@@ -3583,59 +3930,114 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
     cl_ulong offset1 = extra1->offset + src1->view_offs;
     cl_ulong offsetd = extrad->offset + dst->view_offs;
 
-    bool bcast_row = false;
     cl_kernel kernel;
 
-    if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-        GGML_ASSERT(ggml_is_contiguous(src0));
+    const bool bcast_row = ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0;
 
-        // src1 is a row
+    if (bcast_row) {
+        GGML_ASSERT(ggml_is_contiguous(src0));
         GGML_ASSERT(ne11 == 1);
+    }
 
-        bcast_row = true;
-        int ne = ne00 / 4;
-        kernel = backend_ctx->kernel_add_row;
-
-        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
-        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
-        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
-        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));
+    if (dst->type == GGML_TYPE_F32) {
+        GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32);
+        if (bcast_row) {
+            kernel = backend_ctx->kernel_add_row;
+            const int ne = ne00 / 4;
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));
+        } else {
+            kernel = backend_ctx->kernel_add;
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne13));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
+            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
+            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &ne2));
+            CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &ne3));
+            CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
+            CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
+            CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
+            CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
+        }
+    } else if (dst->type == GGML_TYPE_F16) {
+        GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
+        GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
+        const int type_src0 = (src0->type == GGML_TYPE_F32);
+        const int type_src1 = (src1->type == GGML_TYPE_F32);
+        if (bcast_row) {
+            kernel = backend_ctx->kernel_add_row_f16;
+            const int ne = ne00 / 4;
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));
+            CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),      &type_src0));
+            CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int),      &type_src1));
+        } else {
+            kernel = backend_ctx->kernel_add_f16;
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne13));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
+            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
+            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &ne2));
+            CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &ne3));
+            CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
+            CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
+            CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
+            CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
+            CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int),      &type_src0));
+            CL_CHECK(clSetKernelArg(kernel, 31, sizeof(int),      &type_src1));
+        }
     } else {
-        kernel = backend_ctx->kernel_add;
-
-        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
-        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
-        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
-        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
-        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
-        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
-        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
-        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
-        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
-        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));
-        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
-        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
-        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
-        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
-        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne10));
-        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne11));
-        CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne12));
-        CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne13));
-        CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
-        CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
-        CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
-        CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
-        CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne0));
-        CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &ne1));
-        CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &ne2));
-        CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &ne3));
-        CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
-        CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
-        CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
-        CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
+        GGML_ASSERT(false && "unsupported data types for add");
     }
 
     if (bcast_row) {
@@ -3645,20 +4047,20 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
 
         size_t * local_work_size_ptr = local_work_size;
         if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
-            local_work_size_ptr = nullptr;  // Let driver choose the work-group sizes.
+            local_work_size_ptr = nullptr;
         }
 
-        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
+        backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size_ptr, dst);
     } else {
         unsigned int nth = MIN(64, ne0);
-        size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03};
+        size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
         size_t local_work_size[] = {nth, 1, 1};
 
         backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
     }
 }
 
-static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cl_add_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
     GGML_ASSERT(src1);
@@ -3666,35 +4068,108 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
     GGML_ASSERT(dst);
     GGML_ASSERT(dst->extra);
 
-    const int ne00 = src0 ? src0->ne[0] : 0;
-    const int ne01 = src0 ? src0->ne[1] : 0;
-    const int ne02 = src0 ? src0->ne[2] : 0;
-    const int ne03 = src0 ? src0->ne[3] : 0;
+    const ggml_tensor * src2 = dst->src[2];
+    GGML_ASSERT(src2);
+    GGML_ASSERT(src2->extra);
 
-    const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
-    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
-    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
-    const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
 
-    const int ne10 = src1 ? src1->ne[0] : 0;
-    const int ne11 = src1 ? src1->ne[1] : 0;
-    const int ne12 = src1 ? src1->ne[2] : 0;
-    const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+    GGML_ASSERT(ggml_is_contiguous_rows(src0));
 
-    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
-    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
-    const cl_ulong nb12 = src1 ? src1->nb[2] : 0;
-    const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+
+    const cl_ulong nb11 = src1->nb[1];
+
+    const cl_ulong nb21 = src2->nb[1];
+
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offset2 = extra2->offset + src2->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel = backend_ctx->kernel_add_id;
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb21));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne0));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne1));
+
+    int nth = MIN(ne00, (int) backend_ctx->get_kernel_workgroup_size(kernel));
+    size_t global_work_size[] = { (size_t)ne01*nth, (size_t)ne02, 1 };
+    size_t local_work_size[] = { (size_t)nth, 1, 1 };
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+}
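Editorial note: for orientation, a sketch of what the device-side `kernel_add_id` consuming these arguments could look like. This is an assumption inferred from the host-side `clSetKernelArg` order above, not the actual contents of `add_id.cl`: each work-group handles one (expert-slot, token) row, reads the selected expert index from src2, and adds the matching bias row from src1.

    // Sketch only (assumed shape of add_id.cl, inferred from the host code above).
    kernel void kernel_add_id(
            global char * src0, ulong offset0,
            global char * src1, ulong offset1,
            global char * src2, ulong offset2,
            global char * dst,  ulong offsetd,
            ulong nb01, ulong nb02,   // src0 row/plane strides
            ulong nb11,               // src1 (bias) row stride
            ulong nb21,               // src2 (ids) row stride
            int ne0, int ne1) {
        const int i1 = get_group_id(0);   // expert slot within the token
        const int i2 = get_group_id(1);   // token index

        // expert id selected for this (slot, token)
        const int id = ((global const int *)(src2 + offset2 + i2*nb21))[i1];

        global const float * s0 = (global const float *)(src0 + offset0 + i1*nb01 + i2*nb02);
        global const float * s1 = (global const float *)(src1 + offset1 + (ulong)id*nb11);
        global       float * d  = (global float *)(dst + offsetd) + ((ulong)i2*ne1 + i1)*ne0;

        for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
            d[i0] = s0[i0] + s1[i0];
        }
    }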
 
-    const int ne0  = dst ? dst->ne[0] : 0;
-    const int ne1  = dst ? dst->ne[1] : 0;
-    const int ne2  = dst ? dst->ne[2] : 0;
-    const int ne3  = dst ? dst->ne[3] : 0;
+static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    GGML_ASSERT(src0->type == src1->type);
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
 
-    const cl_ulong nb0  = dst ? dst->nb[0] : 0;
-    const cl_ulong nb1  = dst ? dst->nb[1] : 0;
-    const cl_ulong nb2  = dst ? dst->nb[2] : 0;
-    const cl_ulong nb3  = dst ? dst->nb[3] : 0;
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3]; UNUSED(ne13);
+
+    const cl_ulong nb10 = src1->nb[0];
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13);
+
+    const int ne0  = dst->ne[0];
+    const int ne1  = dst->ne[1];
+    const int ne2  = dst->ne[2];
+    const int ne3  = dst->ne[3];
+
+    const cl_ulong nb0  = dst->nb[0];
+    const cl_ulong nb1  = dst->nb[1];
+    const cl_ulong nb2  = dst->nb[2];
+    const cl_ulong nb3  = dst->nb[3];
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
@@ -3717,7 +4192,12 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
 
         bcast_row = true;
         int ne = ne00 / 4;
-        kernel = backend_ctx->kernel_mul_row;
+
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_mul_row;
+        } else {
+            kernel = backend_ctx->kernel_mul_row_f16;
+        }
 
         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3727,7 +4207,11 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
         CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
         CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));
     } else {
-        kernel = backend_ctx->kernel_mul;
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_mul;
+        } else {
+            kernel = backend_ctx->kernel_mul_f16;
+        }
 
         CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
         CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
@@ -3789,6 +4273,10 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
     GGML_ASSERT(dst);
     GGML_ASSERT(dst->extra);
 
+    GGML_ASSERT(src0->type == src1->type);
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+
     const int ne00 = src0->ne[0];
     const int ne01 = src0->ne[1];
     const int ne02 = src0->ne[2];
@@ -3837,7 +4325,12 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
 
         bcast_row = true;
         int ne = ne00 / 4;
-        kernel = backend_ctx->kernel_div_row;
+
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_div_row;
+        } else {
+            kernel = backend_ctx->kernel_div_row_f16;
+        }
 
         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3847,7 +4340,11 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
         CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
         CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));
     } else {
-        kernel = backend_ctx->kernel_div;
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_div;
+        } else {
+            kernel = backend_ctx->kernel_div_f16;
+        }
 
         CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
         CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
@@ -3897,6 +4394,10 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
     GGML_ASSERT(dst);
     GGML_ASSERT(dst->extra);
 
+    GGML_ASSERT(src0->type == src1->type);
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+
     const int ne00 = src0->ne[0];
     const int ne01 = src0->ne[1];
     const int ne02 = src0->ne[2];
@@ -3945,7 +4446,12 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
 
         bcast_row = true;
         int ne = ne00 / 4;
-        kernel = backend_ctx->kernel_sub_row;
+
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_sub_row;
+        } else {
+            kernel = backend_ctx->kernel_sub_row_f16;
+        }
 
         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -3955,7 +4461,11 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
         CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
         CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));
     } else {
-        kernel = backend_ctx->kernel_sub;
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_sub;
+        } else {
+            kernel = backend_ctx->kernel_sub_f16;
+        }
 
         CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
         CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
@@ -4403,6 +4913,117 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }
 
+static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor) {
+    GGML_ASSERT(mul_tensor);
+    GGML_ASSERT(rms_norm_tensor);
+
+    // src0 is the src of rms_norm, src1 is the other src of mul (one being rms_norm)
+    const ggml_tensor * src0 = rms_norm_tensor->src[0];
+    const ggml_tensor * src1;
+    if (mul_tensor->src[0] == rms_norm_tensor) {
+        src1 = mul_tensor->src[1];
+    } else if (mul_tensor->src[1] == rms_norm_tensor) {
+        src1 = mul_tensor->src[0];
+    } else {
+        GGML_ASSERT(false && "Invalid args for rms_norm and mul");
+    }
+    const ggml_tensor * dst = mul_tensor;
+
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    float eps;
+    memcpy(&eps, rms_norm_tensor->op_params, sizeof(float));
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3];
+
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];
+
+    GGML_ASSERT(ne00 % 4 == 0);
+
+    size_t sgs;
+    if (backend_ctx->gpu_family == ADRENO) {
+        sgs = 64;
+    } else if (backend_ctx->gpu_family == INTEL) {
+        sgs = 32;
+    } else {
+        GGML_ASSERT(false && "Unsupported GPU");
+    }
+
+    cl_kernel kernel = backend_ctx->kernel_rms_norm_mul;
+
+    int nth = sgs;
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    while (nth < ne00 && nth < max_workgroup_size) {
+        nth *= 2;
+    }
+    nth = MIN(nth, max_workgroup_size);
+    nth = MIN(nth, ne00);
+
+    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),        &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong),      &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),        &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong),      &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),        &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong),      &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),           &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),           &ne01));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),           &ne02));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),           &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),      &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),      &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong),      &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),           &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),           &ne11));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),           &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),           &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),      &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),      &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),      &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong),      &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong),      &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong),      &nb3));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float),         &eps));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL));
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+}
+
 static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -4945,6 +5566,133 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
 }
 
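+// Flash attention (FLASH_ATTN_EXT). A kernel specialized for the (d_head_q, d_head_v)
+// pair and the Q/K precision (f32, f16, or mixed f32 Q with f16 K) is selected from the
+// prebuilt kernel maps, with a dedicated single-query (n_q == 1) variant for decoding.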
+static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
+    const ggml_tensor * v = dst->src[2];
+    const ggml_tensor * mask = dst->src[3];
+    GGML_ASSERT(q->extra);
+    GGML_ASSERT(k->extra);
+    GGML_ASSERT(v->extra);
+    GGML_ASSERT(dst->extra);
+    if (mask) {
+        GGML_ASSERT(mask->extra);
+    }
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    const int n_q = q->ne[1];
+    const int n_kv = k->ne[1];
+    const int d_head_q = q->ne[0];
+    const int d_head_v = v->ne[0];
+    const int n_head = q->ne[2];
+    const int n_head_kv = k->ne[2];
+    const int n_batch = q->ne[3];
+
+    cl_kernel kernel = NULL;
+
+    const bool is_f16 = q->type == GGML_TYPE_F16;
+    const bool is_mixed = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16;
+    const std::pair<int, int> dk_dv = {d_head_q, d_head_v};
+
+    if (n_q == 1) {
+        if (is_mixed) {
+            kernel = backend_ctx->kernels_flash_attn_f32_f16_q1.at(dk_dv);
+        } else if (is_f16) {
+            kernel = backend_ctx->kernels_flash_attn_f16_q1.at(dk_dv);
+        } else {
+            kernel = backend_ctx->kernels_flash_attn_f32_q1.at(dk_dv);
+        }
+    } else {
+        if (is_mixed) {
+            kernel = backend_ctx->kernels_flash_attn_f32_f16.at(dk_dv);
+        } else if (is_f16) {
+            kernel = backend_ctx->kernels_flash_attn_f16.at(dk_dv);
+        } else {
+            kernel = backend_ctx->kernels_flash_attn_f32.at(dk_dv);
+        }
+    }
+    GGML_ASSERT(kernel != NULL);
+
+    ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *)q->extra;
+    ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *)k->extra;
+    ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;
+    ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;
+    ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;
+
+    cl_ulong offset_q = extra_q->offset + q->view_offs;
+    cl_ulong offset_k = extra_k->offset + k->view_offs;
+    cl_ulong offset_v = extra_v->offset + v->view_offs;
+    cl_ulong offset_o = extra_o->offset + dst->view_offs;
+    cl_mem   mask_buffer = extra_mask ? extra_mask->data_device : NULL;
+    cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;
+
+    const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];
+    const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];
+    const cl_ulong v_nb1 = v->nb[1], v_nb2 = v->nb[2], v_nb3 = v->nb[3];
+    const cl_ulong o_nb1 = dst->nb[1], o_nb2 = dst->nb[2], o_nb3 = dst->nb[3];
+    const cl_ulong mask_nb1 = mask ? mask->nb[1] : 0;
+    const cl_ulong mask_nb2 = mask ? mask->nb[2] : 0;
+    const cl_ulong mask_nb3 = mask ? mask->nb[3] : 0;
+    const int mask_ne2 = mask ? mask->ne[2] : 0;
+    const int mask_ne3 = mask ? mask->ne[3] : 0;
+
+    float scale, max_bias, logit_softcap;
+    const float * params = (const float *)dst->op_params;
+    scale         = params[0];
+    max_bias      = params[1];
+    logit_softcap = params[2];
+
+    const int is_causal = (mask == NULL && n_q > 1 && n_q == n_kv);
+
+    const int n_head_log2_val = n_head > 0 ? 1u << (int)floorf(log2f((float)n_head)) : 0;
+    const float n_head_log2_f = n_head_log2_val > 0 ? (float)n_head_log2_val : 1.0f;
+    const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f);
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra_q->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra_k->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset_k));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extra_v->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_v));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem),   &extra_o->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_o));
+    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(float),    &scale));
+    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int),      &n_q));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),     &n_kv));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),     &is_causal));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),     &n_head));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &q_nb1)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &q_nb2)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &q_nb3));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &k_nb1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &k_nb2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &k_nb3));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &v_nb1)); CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &v_nb2)); CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &v_nb3));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &o_nb1)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &o_nb2)); CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &o_nb3));
+    CL_CHECK(clSetKernelArg(kernel, 25, sizeof(float),    &max_bias));
+    CL_CHECK(clSetKernelArg(kernel, 26, sizeof(float),    &m0));
+    CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float),    &m1));
+    CL_CHECK(clSetKernelArg(kernel, 28, sizeof(int),      &n_head_log2_val));
+    CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float),    &logit_softcap));
+    CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int),      &n_head_kv));
+    CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_mem),   &mask_buffer));
+    CL_CHECK(clSetKernelArg(kernel, 32, sizeof(cl_ulong), &offset_mask));
+    CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_ulong), &mask_nb1));
+    CL_CHECK(clSetKernelArg(kernel, 34, sizeof(cl_ulong), &mask_nb2));
+    CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));
+    CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int),      &mask_ne2));
+    CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int),      &mask_ne3));
+
+    if (n_q == 1) {
+        const size_t wg_size = 64;
+        size_t local_work_size[] = { wg_size, 1 };
+        size_t global_work_size[] = { wg_size, (size_t)(n_head * n_batch) };
+        backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
+    } else {
+        const int block_m = backend_ctx->kernels_flash_attn_bm.at(dk_dv);
+        const size_t wg_size = block_m;
+        size_t local_work_size[] = { wg_size, 1 };
+        size_t global_work_size[] = { (size_t)((n_q + block_m - 1) / block_m) * wg_size, (size_t)(n_head * n_batch) };
+        backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
+    }
+}
+
 static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
@@ -4997,6 +5745,82 @@ static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_ten
     backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
 }
 
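+// Direct 2D convolution. The output is tiled into BS_K x BS_NPQ blocks (K = output
+// channels, NPQ = N*OW*OH) and the Cin*KH*KW reduction is processed in BS_CRS-deep
+// slices; kernel weights and input patches are staged in local memory (shmem_size).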
+static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_TENSOR_BINARY_OP_LOCALS;
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    const cl_uint Cout = ne03; const cl_uint Cin = ne02; const cl_uint N = ne13;
+    const cl_uint KW = ne00; const cl_uint KH = ne01; const cl_uint W = ne10; const cl_uint H = ne11; const cl_uint OW = ne0; const cl_uint OH = ne1;
+
+    const cl_uint s0 = dst->op_params[0]; const cl_uint s1 = dst->op_params[1];
+    const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3];
+    const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5];
+
+    const cl_uint cl_nb01 = nb01/ggml_type_size(src0->type); const cl_uint cl_nb02 = nb02/ggml_type_size(src0->type); const cl_uint cl_nb03 = nb03/ggml_type_size(src0->type);
+    const cl_uint cl_nb11 = nb11/ggml_type_size(src1->type); const cl_uint cl_nb12 = nb12/ggml_type_size(src1->type); const cl_uint cl_nb13 = nb13/ggml_type_size(src1->type);
+    const cl_uint cl_nb1 = nb1/ggml_type_size(dst->type); const cl_uint cl_nb2 = nb2/ggml_type_size(dst->type); const cl_uint cl_nb3 = nb3/ggml_type_size(dst->type);
+
+    const int64_t NPQ = (int64_t)N * OW * OH;
+
+    const uint32_t BS_K = 64;
+    const uint32_t BS_NPQ = 64;
+    const uint32_t BS_CRS = 16;
+    const uint32_t VEC_SIZE = 4;
+
+    const uint32_t TS_K = 4;
+    const uint32_t TS_NPQ = 8;
+
+    const uint32_t WG_K = BS_K / TS_K;
+    const uint32_t WG_NPQ = BS_NPQ / TS_NPQ;
+
+    auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (block_size + work_size - 1) / block_size; };
+    const uint32_t NB_K = splitWork(Cout, BS_K);
+    const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ);
+
+    cl_kernel kernel;
+    size_t shmem_size;
+
+    if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        kernel = backend_ctx->kernel_conv_2d_f16;
+        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_half4));
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        kernel = backend_ctx->kernel_conv_2d_f32;
+        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_float) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+        kernel = backend_ctx->kernel_conv_2d_f16_f32;
+        shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4));
+    } else {
+        GGML_ASSERT(false && "Unsupported data type combination for conv2d");
+    }
+
+    cl_uint idx = 0;
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, idx++, shmem_size, NULL));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cout)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cin)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &N));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KH)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &W)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OH));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p1));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d1));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb01)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb02)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb03));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13));
+    CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3));
+
+    size_t global_work_size[] = { (size_t)NB_K * WG_K, (size_t)NB_NPQ * WG_NPQ, 1 };
+    size_t local_work_size[] = { (size_t)WG_K, (size_t)WG_NPQ, 1 };
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
+}
+
 static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -5010,18 +5834,6 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
-     if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 &&
-        src0->ne[1] > 32 &&   // M > 32
-        src1->ne[1] > 32 &&   // N > 32
-        src0->ne[0] > 32 &&   // K > 32
-        src0->ne[2] == 1 && src0->ne[3] == 1 &&
-        src1->ne[2] == 1 && src1->ne[3] == 1 &&
-        ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
-        backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) {
-        ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst);
-        return;
-    }
-
     ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
     ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
@@ -5368,6 +6180,101 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
     } // if (ne01 && ne1)
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS
 
+    // GEMM using local memory
+    // The kernel's reduction block size BK is 16, so ne00 must be a multiple of 16.
+    if (ggml_is_contiguous(src0) &&
+        ggml_is_contiguous(src1) &&
+        src1t == GGML_TYPE_F32 &&
+        ne00 % 16 == 0 &&
+        ne11 > 1) {
+        switch(src0t) {
+            case GGML_TYPE_F32: {
+                kernel = backend_ctx->kernel_mul_mm_f32_f32_l4_lm;
+                nth0 = 128; // calculated as (BM*BN)/(TM*TN)
+
+                int batch_stride_a = ne00*ne01;
+                int batch_stride_b = ne10*ne11;
+                int batch_stride_d = ne0*ne1;
+
+                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+                CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+                CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne11));
+                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10)); // stride_a
+                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_b
+                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne01)); // stride_d
+                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &batch_stride_a));
+                CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_b));
+                CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_d));
+                CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2));
+                CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3));
+
+                // 64 is the block tile size (BM and BN); update this if BM/BN in the kernel change.
+                size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
+                size_t local_work_size[] = {(size_t)nth0, 1, 1};
+
+                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+                return;
+            }
+            case GGML_TYPE_F16: {
+                kernel = backend_ctx->kernel_mul_mm_f16_f32_l4_lm;
+                nth0 = 128; // calculated as (BM*BN)/(TM*TN)
+
+                int batch_stride_a = ne00*ne01;
+                int batch_stride_b = ne10*ne11;
+                int batch_stride_d = ne0*ne1;
+
+                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+                CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+                CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne11));
+                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10)); // stride_a
+                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_b
+                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne01)); // stride_d
+                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &batch_stride_a));
+                CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_b));
+                CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_d));
+                CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2));
+                CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3));
+
+                // 64 is the block tile size (BM and BN); update this if BM/BN in the kernel change.
+                size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
+                size_t local_work_size[] = {(size_t)nth0, 1, 1};
+
+                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+                return;
+            }
+            default:
+                break;
+        }
+    }
+
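+    // Fall back to the tiled f16 x f32 matmul kernel when the local-memory GEMM path above does not apply.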
+    if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 &&
+        src0->ne[1] > 32 &&   // M > 32
+        src1->ne[1] > 32 &&   // N > 32
+        src0->ne[0] > 32 &&   // K > 32
+        src0->ne[2] == 1 && src0->ne[3] == 1 &&
+        src1->ne[2] == 1 && src1->ne[3] == 1 &&
+        ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
+        backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) {
+        ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst);
+        return;
+    }
+
     if (!ggml_is_transposed(src0) &&
         !ggml_is_transposed(src1) &&
         src1t == GGML_TYPE_F32 &&
@@ -5640,11 +6547,47 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
             CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &r2));
             CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r3));
             break;
+        case GGML_TYPE_MXFP4: {
+            kernel = backend_ctx->kernel_mul_mv_mxfp4_f32;
+
+            if (backend_ctx->gpu_family == INTEL) {
+                nth0 = 16;
+                nth1 = 2;
+                ndst = nth1*2;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                nth0 = 64;
+                nth1 = 2;
+                ndst = nth1*2;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb13));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r3));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float)*nth0,nullptr));
+            break;
+        }
         default:
             GGML_ASSERT(false && "not implemented");
     }
 
-    if (src0t == GGML_TYPE_Q4_0 ||
+    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 ||
         src0t == GGML_TYPE_Q4_1 ||
         src0t == GGML_TYPE_Q8_0 ||
         src0t == GGML_TYPE_Q2_K) {
@@ -5693,10 +6636,12 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
     ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
     ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
 
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
     cl_ulong offset1 = extra1->offset + src1->view_offs;
     cl_ulong offset2 = extra2->offset + src2->view_offs;
     cl_ulong offsetd = extrad->offset + dst->view_offs;
@@ -5711,7 +6656,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
     const int ne03 = src0->ne[3];
 
     const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb01 = src0->nb[1];
     const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
 
     const int ne10 = src1->ne[0];
     const int ne11 = src1->ne[1];
@@ -5720,6 +6667,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
 
     const cl_ulong nb11 = src1->nb[1];
     const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3];
 
     const int ne20 = src2->ne[0];
     const int ne21 = src2->ne[1];
@@ -5787,6 +6735,49 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
 
             break;
         }
+        case GGML_TYPE_MXFP4: {
+            kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32;
+
+            if (backend_ctx->gpu_family == INTEL) {
+                sgs  = 16;
+                nsg  = 2;
+                ndst = 2;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                sgs  = 64;
+                nsg  = 2;
+                ndst = 2;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb01));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne20));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne21));
+            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb21));
+            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &r3));
+            CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs,nullptr));
+
+            break;
+        }
         default:
             GGML_ASSERT(false && "not implemented");;
     }
@@ -6027,17 +7018,24 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
         GGML_ASSERT(src1->extra);
     }
 
+    const ggml_tensor * src2 = dst->src[2];
+    if (src2) {
+        GGML_ASSERT(src2->extra);
+    }
+
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
     ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
 
     ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
+    ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr;
 
     cl_ulong offset0 = extra0->offset + src0->view_offs;
     cl_ulong offsetd = extrad->offset + dst->view_offs;
 
     cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
+    cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;
 
     const int ne00 = src0->ne[0];
     const int ne01 = src0->ne[1];
@@ -6105,25 +7103,27 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
     CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
     CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   extra1 ? &extra1->data_device : &extra0->data_device));
     CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
-    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
-    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
-    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
-    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));
-    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));
-    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));
-    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
-    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne13));
-    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
-    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
-    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
-    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
-    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
-    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
-    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float),    &scale));
-    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float),    &max_bias));
-    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float),    &m0));
-    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float),    &m1));
-    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &n_head_log2));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   extra2 ? &extra2->data_device : &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float),    &scale));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float),    &max_bias));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float),    &m0));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float),    &m1));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &n_head_log2));
 
     size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
     size_t local_work_size[] = {(size_t)nth, 1, 1};
@@ -6530,6 +7530,9 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
                 kernel = backend_ctx->kernel_swiglu_f16;
             }
             break;
+        case GGML_GLU_OP_SWIGLU_OAI:
+            kernel = backend_ctx->kernel_swiglu_oai;
+            break;
         case GGML_GLU_OP_GEGLU_ERF:
             if (dst->type == GGML_TYPE_F32) {
                 kernel = backend_ctx->kernel_geglu_erf;
@@ -6565,7 +7568,10 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
 
     const cl_ulong nb1  = dst->nb[1];
 
-    const int swp = ((const int32_t *) dst->op_params)[1];
+    const int   swp   = ggml_get_op_params_i32(dst, 1);
+    const float alpha = ggml_get_op_params_f32(dst, 2);
+    const float limit = ggml_get_op_params_f32(dst, 3);
+
     const int ne00_off = src1 ? 0 : (swp ? ne0 : 0);
     const int ne10_off = src1 ? 0 : (swp ? 0 : ne0);
 
@@ -6582,6 +7588,11 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
     CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne00_off));
     CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10_off));
 
+    if (ggml_get_glu_op(dst) == GGML_GLU_OP_SWIGLU_OAI) {
+        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &limit));
+        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &alpha));
+    }
+
     const size_t nrows = ggml_nrows(src0);
     size_t nth = 512;
     size_t global_work_size[] = {nrows*nth, 1, 1};
@@ -6638,6 +7649,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_add;
             break;
+        case GGML_OP_ADD_ID:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_add_id;
+            break;
         case GGML_OP_MUL:
             if (!any_on_device) {
                 return false;
@@ -6751,6 +7768,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             ggml_cl_upscale(backend, tensor->src[0], tensor);
             return true;
+        case GGML_OP_CONV_2D:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_conv_2d;
+            break;
         case GGML_OP_CONCAT:
             if (!any_on_device) {
                 return false;
@@ -6826,6 +7849,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_sum_rows;
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            if (!any_on_device) {
+                return false;
+            }
+            ggml_cl_flash_attn(backend, tensor->src[0], tensor->src[1], tensor);
+            return true;
         default:
             return false;
     }
diff --git a/ggml/src/ggml-opencl/kernels/add.cl b/ggml/src/ggml-opencl/kernels/add.cl
index f73f3c01343..509bf17344e 100644
--- a/ggml/src/ggml-opencl/kernels/add.cl
+++ b/ggml/src/ggml-opencl/kernels/add.cl
@@ -81,3 +81,110 @@ kernel void kernel_add_row(
     uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
     dst[gid] = src0[gid] + src1[idx1];
 }
+
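+// General f16 add with broadcasting. type_src0/type_src1 select how each operand is
+// read: 1 means it is stored as float and converted to half, otherwise it is read as
+// half. The result is always stored as half.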
+kernel void kernel_add_f16(
+        global char * src0,
+        ulong  offset0,
+        global char * src1,
+        ulong  offset1,
+        global char * dst,
+        ulong  offsetd,
+        int   ne00,
+        int   ne01,
+        int   ne02,
+        int   ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int   ne10,
+        int   ne11,
+        int   ne12,
+        int   ne13,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int   ne0,
+        int   ne1,
+        int   ne2,
+        int   ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
+        int type_src0,
+        int type_src1
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst = dst + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int i13 = i03 % ne13;
+    int i12 = i02 % ne12;
+    int i11 = i01 % ne11;
+
+    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const int i10 = i0 % ne10;
+
+        half v0, v1;
+        if (type_src0 == 1) {
+            v0 = convert_half(*((global float *)(src0_ptr + i0*nb00)));
+        } else {
+            v0 = *((global half *)(src0_ptr + i0*nb00));
+        }
+
+        if (type_src1 == 1) {
+            v1 = convert_half(*((global float *)(src1_ptr + i10*nb10)));
+        } else {
+            v1 = *((global half *)(src1_ptr + i10*nb10));
+        }
+
+        *((global half *)(dst_ptr + i0*nb0)) = v0 + v1;
+    }
+}
+
+kernel void kernel_add_row_f16(
+        global char * src0,
+        ulong  offset0,
+        global char * src1,
+        ulong  offset1,
+        global half4 * dst,
+        ulong  offsetd,
+        int ne,
+        int type_src0,
+        int type_src1
+) {
+    dst = (global half4*)((global char*)dst + offsetd);
+
+    // This performs better than using %.
+    uint gid = get_global_id(0);
+    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
+
+    half4 v0, v1;
+    if (type_src0 == 1) {
+        global float4* src0_f32 = (global float4*)((global char*)src0 + offset0);
+        v0 = convert_half4(src0_f32[gid]);
+    } else {
+        global half4* src0_f16 = (global half4*)((global char*)src0 + offset0);
+        v0 = src0_f16[gid];
+    }
+
+    if (type_src1 == 1) {
+        global float4* src1_f32 = (global float4*)((global char*)src1 + offset1);
+        v1 = convert_half4(src1_f32[idx1]);
+    } else {
+        global half4* src1_f16 = (global half4*)((global char*)src1 + offset1);
+        v1 = src1_f16[idx1];
+    }
+
+    dst[gid] = v0 + v1;
+}
diff --git a/ggml/src/ggml-opencl/kernels/add_id.cl b/ggml/src/ggml-opencl/kernels/add_id.cl
new file mode 100644
index 00000000000..e9c6d55e6a2
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/add_id.cl
@@ -0,0 +1,42 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+//------------------------------------------------------------------------------
+// add_id
+//------------------------------------------------------------------------------
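+// dst[i1,i2,:] = src0[i1,i2,:] + src1[row,:], where row is the int index stored in
+// src2 at position (i1, i2). One workgroup computes one output row.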
+kernel void kernel_add_id(
+    global char * src0,
+    ulong         offset0,
+    global char * src1,
+    ulong         offset1,
+    global char * src2,
+    ulong         offset2,
+    global char * dst,
+    ulong         offsetd,
+    ulong         nb01,
+    ulong         nb02,
+    ulong         nb11,
+    ulong         nb21,
+    int           ne0,
+    int           ne1
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    src2 = (global char*)((global char*)src2 + offset2);
+    dst  = (global char*)((global char*)dst  + offsetd);
+
+    int i1 = get_group_id(0);
+    int i2 = get_group_id(1);
+
+    const int i11 = *((global const int *) (src2 + i1*sizeof(int) + i2*nb21));
+
+    const size_t nb1 = ne0 * sizeof(float);
+    const size_t nb2 = ne1 * nb1;
+
+    global float * dst_row  = (global float *)((global char *)dst  + i1*nb1 + i2*nb2);
+    global float * src0_row = (global float *)((global char *)src0 + i1*nb01 + i2*nb02);
+    global float * src1_row = (global float *)((global char *)src1 + i11*nb11);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        dst_row[i0] = src0_row[i0] + src1_row[i0];
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/conv2d.cl b/ggml/src/ggml-opencl/kernels/conv2d.cl
new file mode 100644
index 00000000000..e339c90cff5
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/conv2d.cl
@@ -0,0 +1,185 @@
+#ifdef USE_FP16
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#define T_FLOAT half
+#define T_FLOAT4 half4
+#define VSTORE_T_FLOAT4(data, offset, p) vstore_half4_rte(data, offset, p)
+#else
+#define T_FLOAT float
+#define T_FLOAT4 float4
+#define VSTORE_T_FLOAT4(data, offset, p) vstore4(data, offset, p)
+#endif
+
+#if defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+#define REQD_SUBGROUP_SIZE_128
+#endif
+
+#define T_ACCUM float4
+#define VEC_SIZE 4
+
+#define BS_K 64
+#define BS_NPQ 64
+#define BS_CRS 16
+
+#define TS_K 4
+#define TS_NPQ 8
+
+#define WG_K (BS_K / TS_K)
+#define WG_NPQ (BS_NPQ / TS_NPQ)
+
+#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
+#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)
+
+static inline uint splitWork(uint work_size, uint block_size){
+    return (work_size + block_size - 1) / block_size;
+}
+
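+// Each workgroup computes a BS_K x BS_NPQ output tile and each work-item accumulates a
+// TS_K x TS_NPQ register tile. Kernel weights (Ash) and input patches (Bsh) are staged
+// in local memory in BS_CRS-deep slices of the Cin*KH*KW reduction dimension.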
+REQD_SUBGROUP_SIZE_128
+kernel void kernel_conv_2d(
+    global void* p_knl,
+    ulong off_knl,
+    global void* p_src,
+    ulong off_src,
+    global void* p_dst,
+    ulong off_dst,
+    local void* shared,
+    uint Cout, uint Cin, uint N,
+    uint KW, uint KH, uint W, uint H, uint OW, uint OH,
+    uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
+    uint nb01, uint nb02, uint nb03,
+    uint nb11, uint nb12, uint nb13,
+    uint nb1, uint nb2, uint nb3
+) {
+    global T_FLOAT* knl_data = (global T_FLOAT*) ((global char*)p_knl + off_knl);
+    global T_FLOAT* src_data = (global T_FLOAT*) ((global char*)p_src + off_src);
+    global T_FLOAT* dst_data = (global T_FLOAT*) ((global char*)p_dst + off_dst);
+
+    const uint K = Cout;
+    const uint CRS = Cin*KH*KW;
+    const uint NPQ = N*OH*OW;
+
+    const uint lid_k = get_local_id(0);
+    const uint lid_npq = get_local_id(1);
+    const uint tid = lid_npq * WG_K + lid_k;
+
+    const uint B_idx_K = get_group_id(0);
+    const uint B_idx_NPQ = get_group_id(1);
+
+    const uint offset_k = B_idx_K * BS_K;
+    const uint offset_npq = B_idx_NPQ * BS_NPQ;
+
+    local T_FLOAT* Ash = (local T_FLOAT*)shared;
+    local T_FLOAT4* Bsh = (local T_FLOAT4*) &Ash[BS_K * BS_CRS];
+
+    T_ACCUM regC[TS_K][TS_NPQ_VEC];
+    for (int i = 0; i < TS_K; ++i) {
+        for (int j = 0; j < TS_NPQ_VEC; ++j) {
+            regC[i][j] = (T_ACCUM)(0.0f);
+        }
+    }
+
+    const uint NB_CRS = splitWork(CRS, BS_CRS);
+
+    for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {
+        const uint offset_crs = B_idx_CRS * BS_CRS;
+
+        for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {
+            const uint k_l = i / BS_CRS;
+            const uint crs_l = i % BS_CRS;
+            const uint k_g = offset_k + k_l;
+            const uint crs_g = offset_crs + crs_l;
+
+            if (k_g < K && crs_g < CRS) {
+                const uint Cin_idx = crs_g / (KW*KH);
+                const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW;
+                const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW;
+                const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;
+                Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];
+            } else {
+                Ash[k_l * BS_CRS + crs_l] = (T_FLOAT)0.0f;
+            }
+        }
+
+        for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {
+            const uint crs_l = i / BS_NPQ_VEC;
+            const uint npq_l_vec = i % BS_NPQ_VEC;
+            const uint crs_g = offset_crs + crs_l;
+
+            T_FLOAT4 val = (T_FLOAT4)(0.0f);
+            if (crs_g < CRS) {
+                const uint Cin_idx = crs_g / (KW * KH);
+                const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW;
+                const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW;
+                for (int v = 0; v < VEC_SIZE; ++v) {
+                    const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v;
+                    if (npq_g < NPQ) {
+                        const uint N_idx = npq_g / (OH * OW);
+                        const uint pq_idx = npq_g % (OH * OW);
+                        const uint OH_idx = pq_idx / OW;
+                        const uint OW_idx = pq_idx % OW;
+                        const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1);
+                        const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0);
+
+                        if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {
+                            const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;
+                            ((T_FLOAT*)&val)[v] = src_data[src_idx];
+                        }
+                    }
+                }
+            }
+            Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        #pragma unroll
+        for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
+            T_FLOAT regA[TS_K];
+            for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+                regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];
+            }
+
+            for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
+                T_FLOAT4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];
+                for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+                    regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), convert_float4(regB), regC[k_l_reg][npq_l_vec_reg]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+        const uint k_g = offset_k + lid_k * TS_K + k_l_reg;
+        if (k_g >= K) continue;
+
+        for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
+            const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE;
+
+            const uint N_idx = npq_g_base / (OH * OW);
+            const uint pq_idx = npq_g_base % (OH * OW);
+            const uint OH_idx = pq_idx / OW;
+            const uint OW_idx = pq_idx % OW;
+
+            if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {
+                const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;
+                VSTORE_T_FLOAT4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);
+            } else {
+                T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];
+                for (int v = 0; v < VEC_SIZE; ++v) {
+                    const uint npq_g = npq_g_base + v;
+                    if (npq_g < NPQ) {
+                        const uint N_idx_s = npq_g / (OH*OW);
+                        const uint pq_idx_s = npq_g % (OH*OW);
+                        const uint OH_idx_s = pq_idx_s / OW;
+                        const uint OW_idx_s = pq_idx_s % OW;
+                        const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;
+                        dst_data[dst_idx_s] = (T_FLOAT)(((float*)&res)[v]);
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl b/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl
new file mode 100644
index 00000000000..cb05637f33a
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl
@@ -0,0 +1,176 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#if defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+#define REQD_SUBGROUP_SIZE_128
+#endif
+
+#define T_ACCUM float4
+#define VEC_SIZE 4
+
+#define BS_K 64
+#define BS_NPQ 64
+#define BS_CRS 16
+
+#define TS_K 4
+#define TS_NPQ 8
+
+#define WG_K (BS_K / TS_K)
+#define WG_NPQ (BS_NPQ / TS_NPQ)
+
+#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
+#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)
+
+static inline uint splitWork(uint work_size, uint block_size){
+    return (work_size + block_size - 1) / block_size;
+}
+
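+// Mixed-precision variant: kernel weights are half, input and output are float, and
+// accumulation is in float. The blocking scheme matches conv2d.cl.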
+REQD_SUBGROUP_SIZE_128
+kernel void kernel_conv_2d(
+    global void* p_knl,
+    ulong off_knl,
+    global void* p_src,
+    ulong off_src,
+    global void* p_dst,
+    ulong off_dst,
+    local void* shared,
+    uint Cout, uint Cin, uint N,
+    uint KW, uint KH, uint W, uint H, uint OW, uint OH,
+    uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
+    uint nb01, uint nb02, uint nb03,
+    uint nb11, uint nb12, uint nb13,
+    uint nb1, uint nb2, uint nb3
+) {
+    global half* knl_data = (global half*) ((global char*)p_knl + off_knl);
+    global float* src_data = (global float*) ((global char*)p_src + off_src);
+    global float* dst_data = (global float*) ((global char*)p_dst + off_dst);
+
+    const uint K = Cout;
+    const uint CRS = Cin*KH*KW;
+    const uint NPQ = N*OH*OW;
+
+    const uint lid_k = get_local_id(0);
+    const uint lid_npq = get_local_id(1);
+    const uint tid = lid_npq * WG_K + lid_k;
+
+    const uint B_idx_K = get_group_id(0);
+    const uint B_idx_NPQ = get_group_id(1);
+
+    const uint offset_k = B_idx_K * BS_K;
+    const uint offset_npq = B_idx_NPQ * BS_NPQ;
+
+    local half* Ash = (local half*)shared;
+    local float4* Bsh = (local float4*) &Ash[BS_K * BS_CRS];
+
+    T_ACCUM regC[TS_K][TS_NPQ_VEC];
+    for (int i = 0; i < TS_K; ++i) {
+        for (int j = 0; j < TS_NPQ_VEC; ++j) {
+            regC[i][j] = (T_ACCUM)(0.0f);
+        }
+    }
+
+    const uint NB_CRS = splitWork(CRS, BS_CRS);
+
+    for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {
+        const uint offset_crs = B_idx_CRS * BS_CRS;
+
+        for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {
+            const uint k_l = i / BS_CRS;
+            const uint crs_l = i % BS_CRS;
+            const uint k_g = offset_k + k_l;
+            const uint crs_g = offset_crs + crs_l;
+
+            if (k_g < K && crs_g < CRS) {
+                const uint Cin_idx = crs_g / (KW*KH);
+                const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW;
+                const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW;
+                const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;
+                Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];
+            } else {
+                Ash[k_l * BS_CRS + crs_l] = (half)0.0f;
+            }
+        }
+
+        for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {
+            const uint crs_l = i / BS_NPQ_VEC;
+            const uint npq_l_vec = i % BS_NPQ_VEC;
+            const uint crs_g = offset_crs + crs_l;
+
+            float4 val = (float4)(0.0f);
+            if (crs_g < CRS) {
+                const uint Cin_idx = crs_g / (KW * KH);
+                const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW;
+                const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW;
+                for (int v = 0; v < VEC_SIZE; ++v) {
+                    const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v;
+                    if (npq_g < NPQ) {
+                        const uint N_idx = npq_g / (OH * OW);
+                        const uint pq_idx = npq_g % (OH * OW);
+                        const uint OH_idx = pq_idx / OW;
+                        const uint OW_idx = pq_idx % OW;
+                        const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1);
+                        const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0);
+
+                        if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {
+                            const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;
+                            ((float*)&val)[v] = src_data[src_idx];
+                        }
+                    }
+                }
+            }
+            Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        #pragma unroll
+        for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
+            half regA[TS_K];
+            for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+                regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];
+            }
+
+            for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
+                float4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];
+                for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+                    regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), regB, regC[k_l_reg][npq_l_vec_reg]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
+        const uint k_g = offset_k + lid_k * TS_K + k_l_reg;
+        if (k_g >= K) continue;
+
+        for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
+            const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE;
+
+            const uint N_idx = npq_g_base / (OH * OW);
+            const uint pq_idx = npq_g_base % (OH * OW);
+            const uint OH_idx = pq_idx / OW;
+            const uint OW_idx = pq_idx % OW;
+
+            if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {
+                const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;
+                vstore4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);
+            } else {
+                T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];
+                for (int v = 0; v < VEC_SIZE; ++v) {
+                    const uint npq_g = npq_g_base + v;
+                    if (npq_g < NPQ) {
+                        const uint N_idx_s = npq_g / (OH*OW);
+                        const uint pq_idx_s = npq_g % (OH*OW);
+                        const uint OH_idx_s = pq_idx_s / OW;
+                        const uint OW_idx_s = pq_idx_s % OW;
+                        const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;
+                        dst_data[dst_idx_s] = ((float*)&res)[v];
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/div.cl b/ggml/src/ggml-opencl/kernels/div.cl
index d453ad99be4..6d9b4ade9fe 100644
--- a/ggml/src/ggml-opencl/kernels/div.cl
+++ b/ggml/src/ggml-opencl/kernels/div.cl
@@ -70,3 +70,69 @@ kernel void kernel_div_row(
     uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
     dst[gid] = src0[gid] / src1[idx1];
 }
+
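+// Broadcasting f16 division: one work-group per (i01, i02, i03) row of src0.
+// src1 indices are reduced modulo its extents (ne10..ne13), so a smaller src1
+// is broadcast across src0.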
+kernel void kernel_div_f16(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        int ne13,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int i13 = i03 % ne13;
+    int i12 = i02 % ne12;
+    int i11 = i01 % ne11;
+
+    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const int i10 = i0 % ne10;
+        *((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) / *((global half *)(src1_ptr + i10*nb10));
+    }
+}
+
+kernel void kernel_div_row_f16(
+        global half4 * src0,
+        ulong offset0,
+        global half4 * src1,
+        ulong offset1,
+        global half4 * dst,
+        ulong offsetd,
+        int ne
+) {
+    src0 = (global half4*)((global char*)src0 + offset0);
+    src1 = (global half4*)((global char*)src1 + offset1);
+    dst = (global half4*)((global char*)dst + offsetd);
+
+    // This performs better than using %.
+    uint gid = get_global_id(0);
+    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
+    dst[gid] = src0[gid] / src1[idx1];
+}
diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl
new file mode 100644
index 00000000000..fea06867e10
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl
@@ -0,0 +1,343 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define ACC_TYPE float
+#define ACC_TYPE4 float4
+#define DATA_TYPE half
+#define DATA_TYPE4 half4
+#define CONVERT_ACC4(x) convert_float4(x)
+#define CONVERT_DATA4(x) convert_half4(x)
+
+#define DK_VEC (DK/4)
+#define DV_VEC (DV/4)
+#define WG_SIZE (BLOCK_M)
+#define Q1_WG_SIZE 64
+
+inline float get_alibi_slope(
+    const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
+) {
+    if (max_bias <= 0.0f) {
+        return 1.0f;
+    }
+    const float base = h < n_head_log2 ? m0 : m1;
+    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+    return pow(base, exph);
+}
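+// ALiBi slope for head h: m0^(h+1) for the first n_head_log2 heads and
+// m1^(2*(h - n_head_log2) + 1) for the rest; a slope of 1.0f disables the
+// bias when max_bias <= 0.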
+__kernel void flash_attn_f16(
+    const global void * q_void, ulong q_offset,
+    const global void * k_void, ulong k_offset,
+    const global void * v_void, ulong v_offset,
+    global void * o_void, ulong o_offset,
+    const float scale,
+    const int n_q,
+    const int n_kv,
+    const int is_causal,
+    const int n_head,
+    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
+    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
+    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
+    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
+    const float max_bias,
+    const float m0,
+    const float m1,
+    const int n_head_log2,
+    const float logit_softcap,
+    const int n_head_kv,
+    const global void* mask_void,
+    const ulong mask_offset,
+    const ulong mask_nb1,
+    const ulong mask_nb2,
+    const ulong mask_nb3,
+    const int mask_ne2,
+    const int mask_ne3
+) {
+    const int tid = get_local_id(0);
+    const int block_q_idx = get_group_id(0);
+    const int head_batch_idx = get_global_id(1);
+
+    const int my_query_row = block_q_idx * BLOCK_M + tid;
+
+    const int batch_idx = head_batch_idx / n_head;
+    const int head_idx = head_batch_idx % n_head;
+
+    const int gqa_ratio = n_head / n_head_kv;
+    const int head_kv_idx = head_idx / gqa_ratio;
+
+    const global char* q_base = (const global char*)q_void + q_offset;
+    const global char* k_base = (const global char*)k_void + k_offset;
+    const global char* v_base = (const global char*)v_void + v_offset;
+    global char* o_base = (global char*)o_void + o_offset;
+
+    const global char* mask_base = NULL;
+    if (mask_void != NULL) {
+        const int mask_head_idx = head_idx % mask_ne2;
+        const int mask_batch_idx = batch_idx % mask_ne3;
+        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
+    }
+
+    ACC_TYPE4 q_priv[DK_VEC];
+    if (my_query_row < n_q) {
+        const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
+        const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
+        #pragma unroll
+        for (int i = 0; i < DK_VEC; ++i) {
+            q_priv[i] = CONVERT_ACC4(q_ptr[i]);
+        }
+    }
+
+    ACC_TYPE4 o_acc[DV_VEC];
+    #pragma unroll
+    for (int i = 0; i < DV_VEC; ++i) {
+        o_acc[i] = (ACC_TYPE4)(0.0f);
+    }
+    ACC_TYPE m_i = -INFINITY;
+    ACC_TYPE l_i = 0.0f;
+
+    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
+
+    __local DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
+    __local DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
+
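+    // Stream over the KV sequence in tiles of BLOCK_N rows staged in local
+    // memory. Softmax is computed online: m_i tracks the running row maximum,
+    // l_i the running denominator, and o_acc is rescaled by exp(m_old - m_new)
+    // whenever the maximum grows, so no second pass over the scores is needed.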
+    for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
+        for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
+            const int row = i / DK_VEC;
+            const int col = i % DK_VEC;
+            const int k_row_idx = k_start + row;
+            if (k_row_idx < n_kv) {
+                const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
+                l_k[row][col] = ((__global DATA_TYPE4*)(k_base + k_row_offset))[col];
+            }
+        }
+        for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
+            const int row = i / DV_VEC;
+            const int col = i % DV_VEC;
+            const int v_row_idx = k_start + row;
+            if (v_row_idx < n_kv) {
+                const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
+                l_v[row][col] = ((__global DATA_TYPE4*)(v_base + v_row_offset))[col];
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        if (my_query_row >= n_q) {
+            continue;
+        }
+
+        for (int j = 0; j < BLOCK_N; j += 2) {
+            const int k_row0 = k_start + j;
+            const int k_row1 = k_start + j + 1;
+
+            ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
+            ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
+            #pragma unroll
+            for (int k = 0; k < DK_VEC; k++) {
+                dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
+                dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
+            }
+            ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
+            ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
+
+            if (is_causal) {
+                if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
+                if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
+            }
+
+            if (k_row0 >= n_kv) score0 = -INFINITY;
+            if (k_row1 >= n_kv) score1 = -INFINITY;
+
+            if (mask_base != NULL) {
+                const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
+                if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
+                if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
+            }
+
+            if (logit_softcap > 0.0f) {
+                score0 = logit_softcap * tanh(score0 / logit_softcap);
+                score1 = logit_softcap * tanh(score1 / logit_softcap);
+            }
+
+            const ACC_TYPE m_new = max(m_i, max(score0, score1));
+            const ACC_TYPE p0 = exp(score0 - m_new);
+            const ACC_TYPE p1 = exp(score1 - m_new);
+            const ACC_TYPE scale_prev = exp(m_i - m_new);
+
+            #pragma unroll
+            for (int i = 0; i < DV_VEC; ++i) {
+                o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
+            }
+            l_i = l_i * scale_prev + p0 + p1;
+            m_i = m_new;
+        }
+    }
+
+    if (my_query_row < n_q) {
+        const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
+        global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
+        if (l_i > 0.0f) {
+            const ACC_TYPE l_inv = 1.0f / l_i;
+            #pragma unroll
+            for (int i = 0; i < DV_VEC; ++i) {
+                o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
+            }
+        } else {
+            #pragma unroll
+            for (int i = 0; i < DV_VEC; ++i) {
+                o_row[i] = (DATA_TYPE4)(0.0f);
+            }
+        }
+    }
+}
+
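+// Single-query (n_q == 1) specialization: the whole work-group cooperates on
+// one query row. A first pass reduces the row maximum across the work-group,
+// a second pass accumulates the exp-weighted V sums, and the result is then
+// combined component-by-component through local memory before thread 0 writes
+// the output row.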
+__kernel void flash_attn_f16_q1(
+    const global void * q_void, ulong q_offset,
+    const global void * k_void, ulong k_offset,
+    const global void * v_void, ulong v_offset,
+    global void * o_void, ulong o_offset,
+    const float scale,
+    const int n_q,
+    const int n_kv,
+    const int is_causal,
+    const int n_head,
+    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
+    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
+    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
+    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
+    const float max_bias,
+    const float m0,
+    const float m1,
+    const int n_head_log2,
+    const float logit_softcap,
+    const int n_head_kv,
+    const global void* mask_void,
+    const ulong mask_offset,
+    const ulong mask_nb1,
+    const ulong mask_nb2,
+    const ulong mask_nb3,
+    const int mask_ne2,
+    const int mask_ne3
+) {
+    const int tid = get_local_id(0);
+    const int head_batch_idx = get_global_id(1);
+
+    const int batch_idx = head_batch_idx / n_head;
+    const int head_idx = head_batch_idx % n_head;
+
+    const int gqa_ratio = n_head / n_head_kv;
+    const int head_kv_idx = head_idx / gqa_ratio;
+
+    const global char* q_base = (const global char*)q_void + q_offset;
+    const global char* k_base = (const global char*)k_void + k_offset;
+    const global char* v_base = (const global char*)v_void + v_offset;
+    global char* o_base = (global char*)o_void + o_offset;
+
+    const global char* mask_base = NULL;
+    if (mask_void != NULL) {
+        const int mask_head_idx = head_idx % mask_ne2;
+        const int mask_batch_idx = batch_idx % mask_ne3;
+        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
+    }
+
+    ACC_TYPE4 q_priv[DK_VEC];
+    const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
+    const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
+    #pragma unroll
+    for (int i = 0; i < DK_VEC; ++i) {
+        q_priv[i] = CONVERT_ACC4(q_ptr[i]);
+    }
+
+    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
+
+    ACC_TYPE m_i = -INFINITY;
+    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
+        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
+        const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
+        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
+        #pragma unroll
+        for (int k = 0; k < DK_VEC; k++) {
+            dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
+        }
+        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
+        if (mask_base != NULL) {
+            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
+            score += slope * (ACC_TYPE)mask_ptr[k_idx];
+        }
+        if (logit_softcap > 0.0f) {
+            score = logit_softcap * tanh(score / logit_softcap);
+        }
+        m_i = max(m_i, score);
+    }
+
+    __local ACC_TYPE local_m[Q1_WG_SIZE];
+    local_m[tid] = m_i;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    #pragma unroll
+    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    const ACC_TYPE m_final = local_m[0];
+
+    ACC_TYPE4 o_acc[DV_VEC];
+    #pragma unroll
+    for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
+    ACC_TYPE l_i = 0.0f;
+
+    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
+        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
+        const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
+        const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
+        const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
+        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
+        #pragma unroll
+        for (int k = 0; k < DK_VEC; k++) {
+            dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
+        }
+        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
+        if (mask_base != NULL) {
+            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
+            score += slope * (ACC_TYPE)mask_ptr[k_idx];
+        }
+        if (logit_softcap > 0.0f) {
+            score = logit_softcap * tanh(score / logit_softcap);
+        }
+        const ACC_TYPE p = exp(score - m_final);
+        l_i += p;
+        #pragma unroll
+        for (int i = 0; i < DV_VEC; i++) {
+            o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
+        }
+    }
+
+    __local ACC_TYPE local_l[Q1_WG_SIZE];
+    __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
+    local_l[tid] = l_i;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    #pragma unroll
+    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) local_l[tid] += local_l[tid + s];
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
+    global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
+    const ACC_TYPE l_final = local_l[0];
+
+    if (l_final > 0.0f) {
+        const ACC_TYPE l_inv = 1.0f / l_final;
+        for (int i = 0; i < DV_VEC; i++) {
+            local_o_comp[tid] = o_acc[i];
+            barrier(CLK_LOCAL_MEM_FENCE);
+            #pragma unroll
+            for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
+                if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
+                barrier(CLK_LOCAL_MEM_FENCE);
+            }
+            if (tid == 0) {
+                o_row[i] = CONVERT_DATA4(local_o_comp[0] * l_inv);
+            }
+        }
+    } else if (tid == 0) {
+        #pragma unroll
+        for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl
new file mode 100644
index 00000000000..2d657327d64
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl
@@ -0,0 +1,343 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define ACC_TYPE float
+#define ACC_TYPE4 float4
+#define DATA_TYPE float
+#define DATA_TYPE4 float4
+#define CONVERT_ACC4(x) (x)
+#define CONVERT_DATA4(x) (x)
+
+#define DK_VEC (DK/4)
+#define DV_VEC (DV/4)
+#define WG_SIZE (BLOCK_M)
+#define Q1_WG_SIZE 64
+
+inline float get_alibi_slope(
+    const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
+) {
+    if (max_bias <= 0.0f) {
+        return 1.0f;
+    }
+    const float base = h < n_head_log2 ? m0 : m1;
+    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+    return pow(base, exph);
+}
+__kernel void flash_attn_f32(
+    const global void * q_void, ulong q_offset,
+    const global void * k_void, ulong k_offset,
+    const global void * v_void, ulong v_offset,
+    global void * o_void, ulong o_offset,
+    const float scale,
+    const int n_q,
+    const int n_kv,
+    const int is_causal,
+    const int n_head,
+    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
+    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
+    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
+    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
+    const float max_bias,
+    const float m0,
+    const float m1,
+    const int n_head_log2,
+    const float logit_softcap,
+    const int n_head_kv,
+    const global void* mask_void,
+    const ulong mask_offset,
+    const ulong mask_nb1,
+    const ulong mask_nb2,
+    const ulong mask_nb3,
+    const int mask_ne2,
+    const int mask_ne3
+) {
+    const int tid = get_local_id(0);
+    const int block_q_idx = get_group_id(0);
+    const int head_batch_idx = get_global_id(1);
+
+    const int my_query_row = block_q_idx * BLOCK_M + tid;
+
+    const int batch_idx = head_batch_idx / n_head;
+    const int head_idx = head_batch_idx % n_head;
+
+    const int gqa_ratio = n_head / n_head_kv;
+    const int head_kv_idx = head_idx / gqa_ratio;
+
+    const global char* q_base = (const global char*)q_void + q_offset;
+    const global char* k_base = (const global char*)k_void + k_offset;
+    const global char* v_base = (const global char*)v_void + v_offset;
+    global char* o_base = (global char*)o_void + o_offset;
+
+    const global char* mask_base = NULL;
+    if (mask_void != NULL) {
+        const int mask_head_idx = head_idx % mask_ne2;
+        const int mask_batch_idx = batch_idx % mask_ne3;
+        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
+    }
+
+    ACC_TYPE4 q_priv[DK_VEC];
+    if (my_query_row < n_q) {
+        const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
+        const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
+        #pragma unroll
+        for (int i = 0; i < DK_VEC; ++i) {
+            q_priv[i] = CONVERT_ACC4(q_ptr[i]);
+        }
+    }
+
+    ACC_TYPE4 o_acc[DV_VEC];
+    #pragma unroll
+    for (int i = 0; i < DV_VEC; ++i) {
+        o_acc[i] = (ACC_TYPE4)(0.0f);
+    }
+    ACC_TYPE m_i = -INFINITY;
+    ACC_TYPE l_i = 0.0f;
+
+    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
+
+    __local DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
+    __local DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
+
+    for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
+        for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
+            const int row = i / DK_VEC;
+            const int col = i % DK_VEC;
+            const int k_row_idx = k_start + row;
+            if (k_row_idx < n_kv) {
+                const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
+                l_k[row][col] = ((__global DATA_TYPE4*)(k_base + k_row_offset))[col];
+            }
+        }
+        for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
+            const int row = i / DV_VEC;
+            const int col = i % DV_VEC;
+            const int v_row_idx = k_start + row;
+            if (v_row_idx < n_kv) {
+                const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
+                l_v[row][col] = ((__global DATA_TYPE4*)(v_base + v_row_offset))[col];
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        if (my_query_row >= n_q) {
+            continue;
+        }
+
+        for (int j = 0; j < BLOCK_N; j += 2) {
+            const int k_row0 = k_start + j;
+            const int k_row1 = k_start + j + 1;
+
+            ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
+            ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
+            #pragma unroll
+            for (int k = 0; k < DK_VEC; k++) {
+                dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
+                dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
+            }
+            ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
+            ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
+
+            if (is_causal) {
+                if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
+                if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
+            }
+
+            if (k_row0 >= n_kv) score0 = -INFINITY;
+            if (k_row1 >= n_kv) score1 = -INFINITY;
+
+            if (mask_base != NULL) {
+                const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
+                if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
+                if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
+            }
+
+            if (logit_softcap > 0.0f) {
+                score0 = logit_softcap * tanh(score0 / logit_softcap);
+                score1 = logit_softcap * tanh(score1 / logit_softcap);
+            }
+
+            const ACC_TYPE m_new = max(m_i, max(score0, score1));
+            const ACC_TYPE p0 = exp(score0 - m_new);
+            const ACC_TYPE p1 = exp(score1 - m_new);
+            const ACC_TYPE scale_prev = exp(m_i - m_new);
+
+            #pragma unroll
+            for (int i = 0; i < DV_VEC; ++i) {
+                o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
+            }
+            l_i = l_i * scale_prev + p0 + p1;
+            m_i = m_new;
+        }
+    }
+
+    if (my_query_row < n_q) {
+        const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
+        global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
+        if (l_i > 0.0f) {
+            const ACC_TYPE l_inv = 1.0f / l_i;
+            #pragma unroll
+            for (int i = 0; i < DV_VEC; ++i) {
+                o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
+            }
+        } else {
+            #pragma unroll
+            for (int i = 0; i < DV_VEC; ++i) {
+                o_row[i] = (DATA_TYPE4)(0.0f);
+            }
+        }
+    }
+}
+
+__kernel void flash_attn_f32_q1(
+    const global void * q_void, ulong q_offset,
+    const global void * k_void, ulong k_offset,
+    const global void * v_void, ulong v_offset,
+    global void * o_void, ulong o_offset,
+    const float scale,
+    const int n_q,
+    const int n_kv,
+    const int is_causal,
+    const int n_head,
+    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
+    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
+    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
+    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
+    const float max_bias,
+    const float m0,
+    const float m1,
+    const int n_head_log2,
+    const float logit_softcap,
+    const int n_head_kv,
+    const global void* mask_void,
+    const ulong mask_offset,
+    const ulong mask_nb1,
+    const ulong mask_nb2,
+    const ulong mask_nb3,
+    const int mask_ne2,
+    const int mask_ne3
+) {
+    const int tid = get_local_id(0);
+    const int head_batch_idx = get_global_id(1);
+
+    const int batch_idx = head_batch_idx / n_head;
+    const int head_idx = head_batch_idx % n_head;
+
+    const int gqa_ratio = n_head / n_head_kv;
+    const int head_kv_idx = head_idx / gqa_ratio;
+
+    const global char* q_base = (const global char*)q_void + q_offset;
+    const global char* k_base = (const global char*)k_void + k_offset;
+    const global char* v_base = (const global char*)v_void + v_offset;
+    global char* o_base = (global char*)o_void + o_offset;
+
+    const global char* mask_base = NULL;
+    if (mask_void != NULL) {
+        const int mask_head_idx = head_idx % mask_ne2;
+        const int mask_batch_idx = batch_idx % mask_ne3;
+        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
+    }
+
+    ACC_TYPE4 q_priv[DK_VEC];
+    const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
+    const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
+    #pragma unroll
+    for (int i = 0; i < DK_VEC; ++i) {
+        q_priv[i] = CONVERT_ACC4(q_ptr[i]);
+    }
+
+    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
+
+    ACC_TYPE m_i = -INFINITY;
+    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
+        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
+        const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
+        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
+        #pragma unroll
+        for (int k = 0; k < DK_VEC; k++) {
+            dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
+        }
+        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
+        if (mask_base != NULL) {
+            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
+            score += slope * (ACC_TYPE)mask_ptr[k_idx];
+        }
+        if (logit_softcap > 0.0f) {
+            score = logit_softcap * tanh(score / logit_softcap);
+        }
+        m_i = max(m_i, score);
+    }
+
+    __local ACC_TYPE local_m[Q1_WG_SIZE];
+    local_m[tid] = m_i;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    #pragma unroll
+    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    const ACC_TYPE m_final = local_m[0];
+
+    ACC_TYPE4 o_acc[DV_VEC];
+    #pragma unroll
+    for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
+    ACC_TYPE l_i = 0.0f;
+
+    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
+        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
+        const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
+        const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
+        const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
+        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
+        #pragma unroll
+        for (int k = 0; k < DK_VEC; k++) {
+            dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
+        }
+        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
+        if (mask_base != NULL) {
+            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
+            score += slope * (ACC_TYPE)mask_ptr[k_idx];
+        }
+        if (logit_softcap > 0.0f) {
+            score = logit_softcap * tanh(score / logit_softcap);
+        }
+        const ACC_TYPE p = exp(score - m_final);
+        l_i += p;
+        #pragma unroll
+        for (int i = 0; i < DV_VEC; i++) {
+            o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
+        }
+    }
+
+    __local ACC_TYPE local_l[Q1_WG_SIZE];
+    __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
+    local_l[tid] = l_i;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    #pragma unroll
+    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) local_l[tid] += local_l[tid + s];
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
+    global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
+    const ACC_TYPE l_final = local_l[0];
+
+    if (l_final > 0.0f) {
+        const ACC_TYPE l_inv = 1.0f / l_final;
+        for (int i = 0; i < DV_VEC; i++) {
+            local_o_comp[tid] = o_acc[i];
+            barrier(CLK_LOCAL_MEM_FENCE);
+            #pragma unroll
+            for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
+                if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
+                barrier(CLK_LOCAL_MEM_FENCE);
+            }
+            if (tid == 0) {
+                o_row[i] = CONVERT_DATA4(local_o_comp[0] * l_inv);
+            }
+        }
+    } else if (tid == 0) {
+        #pragma unroll
+        for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl b/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl
new file mode 100644
index 00000000000..7067bd2591f
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl
@@ -0,0 +1,346 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define ACC_TYPE float
+#define ACC_TYPE4 float4
+#define Q_DATA_TYPE4 float4
+#define KV_DATA_TYPE4 half4
+#define O_DATA_TYPE4 float4
+#define MASK_DATA_TYPE half
+#define CONVERT_Q_ACC4(x) (x)
+#define CONVERT_KV_ACC4(x) convert_float4(x)
+#define CONVERT_O_DATA4(x) (x)
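+// Mixed-precision variant: Q is read as f32 and O is written as f32, while
+// K, V and the mask are f16; all accumulation is done in f32.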
+
+#define DK_VEC (DK/4)
+#define DV_VEC (DV/4)
+#define WG_SIZE (BLOCK_M)
+#define Q1_WG_SIZE 64
+
+inline float get_alibi_slope(
+    const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
+) {
+    if (max_bias <= 0.0f) {
+        return 1.0f;
+    }
+    const float base = h < n_head_log2 ? m0 : m1;
+    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+    return pow(base, exph);
+}
+__kernel void flash_attn_f32_f16(
+    const global void * q_void, ulong q_offset,
+    const global void * k_void, ulong k_offset,
+    const global void * v_void, ulong v_offset,
+    global void * o_void, ulong o_offset,
+    const float scale,
+    const int n_q,
+    const int n_kv,
+    const int is_causal,
+    const int n_head,
+    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
+    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
+    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
+    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
+    const float max_bias,
+    const float m0,
+    const float m1,
+    const int n_head_log2,
+    const float logit_softcap,
+    const int n_head_kv,
+    const global void* mask_void,
+    const ulong mask_offset,
+    const ulong mask_nb1,
+    const ulong mask_nb2,
+    const ulong mask_nb3,
+    const int mask_ne2,
+    const int mask_ne3
+) {
+    const int tid = get_local_id(0);
+    const int block_q_idx = get_group_id(0);
+    const int head_batch_idx = get_global_id(1);
+
+    const int my_query_row = block_q_idx * BLOCK_M + tid;
+
+    const int batch_idx = head_batch_idx / n_head;
+    const int head_idx = head_batch_idx % n_head;
+
+    const int gqa_ratio = n_head / n_head_kv;
+    const int head_kv_idx = head_idx / gqa_ratio;
+
+    const global char* q_base = (const global char*)q_void + q_offset;
+    const global char* k_base = (const global char*)k_void + k_offset;
+    const global char* v_base = (const global char*)v_void + v_offset;
+    global char* o_base = (global char*)o_void + o_offset;
+
+    const global char* mask_base = NULL;
+    if (mask_void != NULL) {
+        const int mask_head_idx = head_idx % mask_ne2;
+        const int mask_batch_idx = batch_idx % mask_ne3;
+        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
+    }
+
+    ACC_TYPE4 q_priv[DK_VEC];
+    if (my_query_row < n_q) {
+        const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
+        const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
+        #pragma unroll
+        for (int i = 0; i < DK_VEC; ++i) {
+            q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
+        }
+    }
+
+    ACC_TYPE4 o_acc[DV_VEC];
+    #pragma unroll
+    for (int i = 0; i < DV_VEC; ++i) {
+        o_acc[i] = (ACC_TYPE4)(0.0f);
+    }
+    ACC_TYPE m_i = -INFINITY;
+    ACC_TYPE l_i = 0.0f;
+
+    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
+
+    __local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
+    __local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
+
+    for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
+        for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
+            const int row = i / DK_VEC;
+            const int col = i % DK_VEC;
+            const int k_row_idx = k_start + row;
+            if (k_row_idx < n_kv) {
+                const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
+                l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_base + k_row_offset))[col];
+            }
+        }
+        for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
+            const int row = i / DV_VEC;
+            const int col = i % DV_VEC;
+            const int v_row_idx = k_start + row;
+            if (v_row_idx < n_kv) {
+                const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
+                l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_base + v_row_offset))[col];
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        if (my_query_row >= n_q) {
+            continue;
+        }
+
+        for (int j = 0; j < BLOCK_N; j += 2) {
+            const int k_row0 = k_start + j;
+            const int k_row1 = k_start + j + 1;
+
+            ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
+            ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
+            #pragma unroll
+            for (int k = 0; k < DK_VEC; k++) {
+                dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
+                dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
+            }
+            ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
+            ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
+
+            if (is_causal) {
+                if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
+                if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
+            }
+
+            if (k_row0 >= n_kv) score0 = -INFINITY;
+            if (k_row1 >= n_kv) score1 = -INFINITY;
+
+            if (mask_base != NULL) {
+                const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
+                if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
+                if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
+            }
+
+            if (logit_softcap > 0.0f) {
+                score0 = logit_softcap * tanh(score0 / logit_softcap);
+                score1 = logit_softcap * tanh(score1 / logit_softcap);
+            }
+
+            const ACC_TYPE m_new = max(m_i, max(score0, score1));
+            const ACC_TYPE p0 = exp(score0 - m_new);
+            const ACC_TYPE p1 = exp(score1 - m_new);
+            const ACC_TYPE scale_prev = exp(m_i - m_new);
+
+            #pragma unroll
+            for (int i = 0; i < DV_VEC; ++i) {
+                o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_KV_ACC4(l_v[j][i]) + p1 * CONVERT_KV_ACC4(l_v[j+1][i]);
+            }
+            l_i = l_i * scale_prev + p0 + p1;
+            m_i = m_new;
+        }
+    }
+
+    if (my_query_row < n_q) {
+        const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
+        global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
+        if (l_i > 0.0f) {
+            const ACC_TYPE l_inv = 1.0f / l_i;
+            #pragma unroll
+            for (int i = 0; i < DV_VEC; ++i) {
+                o_row[i] = CONVERT_O_DATA4(o_acc[i] * l_inv);
+            }
+        } else {
+            #pragma unroll
+            for (int i = 0; i < DV_VEC; ++i) {
+                o_row[i] = (O_DATA_TYPE4)(0.0f);
+            }
+        }
+    }
+}
+
+__kernel void flash_attn_f32_f16_q1(
+    const global void * q_void, ulong q_offset,
+    const global void * k_void, ulong k_offset,
+    const global void * v_void, ulong v_offset,
+    global void * o_void, ulong o_offset,
+    const float scale,
+    const int n_q,
+    const int n_kv,
+    const int is_causal,
+    const int n_head,
+    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
+    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
+    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
+    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
+    const float max_bias,
+    const float m0,
+    const float m1,
+    const int n_head_log2,
+    const float logit_softcap,
+    const int n_head_kv,
+    const global void* mask_void,
+    const ulong mask_offset,
+    const ulong mask_nb1,
+    const ulong mask_nb2,
+    const ulong mask_nb3,
+    const int mask_ne2,
+    const int mask_ne3
+) {
+    const int tid = get_local_id(0);
+    const int head_batch_idx = get_global_id(1);
+
+    const int batch_idx = head_batch_idx / n_head;
+    const int head_idx = head_batch_idx % n_head;
+
+    const int gqa_ratio = n_head / n_head_kv;
+    const int head_kv_idx = head_idx / gqa_ratio;
+
+    const global char* q_base = (const global char*)q_void + q_offset;
+    const global char* k_base = (const global char*)k_void + k_offset;
+    const global char* v_base = (const global char*)v_void + v_offset;
+    global char* o_base = (global char*)o_void + o_offset;
+
+    const global char* mask_base = NULL;
+    if (mask_void != NULL) {
+        const int mask_head_idx = head_idx % mask_ne2;
+        const int mask_batch_idx = batch_idx % mask_ne3;
+        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
+    }
+
+    ACC_TYPE4 q_priv[DK_VEC];
+    const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
+    const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
+    #pragma unroll
+    for (int i = 0; i < DK_VEC; ++i) {
+        q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
+    }
+
+    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
+
+    ACC_TYPE m_i = -INFINITY;
+    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
+        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
+        const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
+        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
+        #pragma unroll
+        for (int k = 0; k < DK_VEC; k++) {
+            dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
+        }
+        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
+        if (mask_base != NULL) {
+            const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
+            score += slope * (ACC_TYPE)mask_ptr[k_idx];
+        }
+        if (logit_softcap > 0.0f) {
+            score = logit_softcap * tanh(score / logit_softcap);
+        }
+        m_i = max(m_i, score);
+    }
+
+    __local ACC_TYPE local_m[Q1_WG_SIZE];
+    local_m[tid] = m_i;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    #pragma unroll
+    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    const ACC_TYPE m_final = local_m[0];
+
+    ACC_TYPE4 o_acc[DV_VEC];
+    #pragma unroll
+    for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
+    ACC_TYPE l_i = 0.0f;
+
+    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
+        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
+        const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
+        const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
+        const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset);
+        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
+        #pragma unroll
+        for (int k = 0; k < DK_VEC; k++) {
+            dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
+        }
+        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
+        if (mask_base != NULL) {
+            const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
+            score += slope * (ACC_TYPE)mask_ptr[k_idx];
+        }
+        if (logit_softcap > 0.0f) {
+            score = logit_softcap * tanh(score / logit_softcap);
+        }
+        const ACC_TYPE p = exp(score - m_final);
+        l_i += p;
+        #pragma unroll
+        for (int i = 0; i < DV_VEC; i++) {
+            o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
+        }
+    }
+
+    __local ACC_TYPE local_l[Q1_WG_SIZE];
+    __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
+    local_l[tid] = l_i;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    #pragma unroll
+    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) local_l[tid] += local_l[tid + s];
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
+    global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
+    const ACC_TYPE l_final = local_l[0];
+
+    if (l_final > 0.0f) {
+        const ACC_TYPE l_inv = 1.0f / l_final;
+        for (int i = 0; i < DV_VEC; i++) {
+            local_o_comp[tid] = o_acc[i];
+            barrier(CLK_LOCAL_MEM_FENCE);
+            #pragma unroll
+            for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
+                if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
+                barrier(CLK_LOCAL_MEM_FENCE);
+            }
+            if (tid == 0) {
+                o_row[i] = CONVERT_O_DATA4(local_o_comp[0] * l_inv);
+            }
+        }
+    } else if (tid == 0) {
+        #pragma unroll
+        for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f);
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/glu.cl b/ggml/src/ggml-opencl/kernels/glu.cl
index 7cca16e6a9e..059a4bbf1ba 100644
--- a/ggml/src/ggml-opencl/kernels/glu.cl
+++ b/ggml/src/ggml-opencl/kernels/glu.cl
@@ -202,6 +202,47 @@ kernel void kernel_swiglu_f16(
     }
 }
 
+//------------------------------------------------------------------------------
+// swiglu_oai
+//------------------------------------------------------------------------------
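+// Computes out = (x0 * sigmoid(alpha * x0)) * (1 + x1), where x0 comes from
+// src0 (clamped from above by `limit`) and x1 from src1 (clamped to
+// [-limit, limit]); ne00_off and ne10_off are element offsets into each
+// source row.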
+kernel void kernel_swiglu_oai(
+    global char * src0,
+    ulong         offset0,
+    global char * src1,
+    ulong         offset1,
+    global char * dst,
+    ulong         offsetd,
+    ulong         nb01,
+    ulong         nb11,
+    int           ne0,
+    ulong         nb1,
+    int           ne00_off,
+    int           ne10_off,
+    float         limit,
+    float         alpha
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst  + offsetd;
+
+    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+    global float * dst_row  = (global float *) ((global char *) dst  + get_group_id(0)*nb1);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        float x0 = src0_row[i0];
+        float x1 = src1_row[i0];
+
+        x0 = min(x0, limit);
+        x1 = max(min(x1, limit), -limit);
+
+        float out_glu = x0 / (1.0f + exp(-x0 * alpha));
+        out_glu = out_glu * (1.0f + x1);
+
+        dst_row[i0] = out_glu;
+    }
+}
+
 //------------------------------------------------------------------------------
 // geglu_erf
 //------------------------------------------------------------------------------
diff --git a/ggml/src/ggml-opencl/kernels/im2col_f16.cl b/ggml/src/ggml-opencl/kernels/im2col_f16.cl
index b84c8984653..cf6cdaa4ce5 100644
--- a/ggml/src/ggml-opencl/kernels/im2col_f16.cl
+++ b/ggml/src/ggml-opencl/kernels/im2col_f16.cl
@@ -31,7 +31,7 @@ kernel void kernel_im2col_f16(
     src1 = (global float*)((global char*)src1 + offset1);
     dst = (global half*)((global char*)dst + offsetd);
 
-    long  ksize = OW * (KH > 1 ? KW : 1);
+    long  ksize = OW * KH;
     long  kx = i / ksize;
     long  kd = kx * ksize;
     long  ky = (i - kd) / OW;
diff --git a/ggml/src/ggml-opencl/kernels/im2col_f32.cl b/ggml/src/ggml-opencl/kernels/im2col_f32.cl
index 4bf65e4eaaf..1ecdb2344ad 100644
--- a/ggml/src/ggml-opencl/kernels/im2col_f32.cl
+++ b/ggml/src/ggml-opencl/kernels/im2col_f32.cl
@@ -31,7 +31,7 @@ kernel void kernel_im2col_f32(
     src1 = (global float*)((global char*)src1 + offset1);
     dst = (global float*)((global char*)dst + offsetd);
 
-    long  ksize = OW * (KH > 1 ? KW : 1);
+    long  ksize = OW * KH;
     long  kx = i / ksize;
     long  kd = kx * ksize;
     long  ky = (i - kd) / OW;
diff --git a/ggml/src/ggml-opencl/kernels/mul.cl b/ggml/src/ggml-opencl/kernels/mul.cl
index 2a2b4eb70a1..b12a592165f 100644
--- a/ggml/src/ggml-opencl/kernels/mul.cl
+++ b/ggml/src/ggml-opencl/kernels/mul.cl
@@ -77,3 +77,76 @@ kernel void kernel_mul_row(
     uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
     dst[gid] = src0[gid] * src1[idx1];
 }
+
+kernel void kernel_mul_f16(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        int ne13,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int i13 = i03 % ne13;
+    int i12 = i02 % ne12;
+    int i11 = i01 % ne11;
+
+    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const int i10 = i0 % ne10;
+        *((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) * *((global half *)(src1_ptr + i10*nb10));
+    }
+}
+
+kernel void kernel_mul_row_f16(
+        global half4 * src0,
+        ulong offset0,
+        global half4 * src1,
+        ulong offset1,
+        global half4 * dst,
+        ulong offsetd,
+        int ne
+) {
+    src0 = (global half4*)((global char*)src0 + offset0);
+    src1 = (global half4*)((global char*)src1 + offset1);
+    dst = (global half4*)((global char*)dst + offsetd);
+
+    // This performs better than using %.
+    uint gid = get_global_id(0);
+    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
+    dst[gid] = src0[gid] * src1[idx1];
+}
diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl
new file mode 100644
index 00000000000..9599a0e1572
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl
@@ -0,0 +1,132 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define LOAD_VEC_A 4
+#define LOAD_VEC_B 4
+
+#define BM 64
+#define BN 64
+#define BK 16
+#define TM 4
+#define TN 8
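+// Each work-group computes a BM x BN tile of the output, stepping over the
+// shared dimension in BK-wide slices staged in local memory; each work-item
+// accumulates a TM x TN register tile. Loads are 4-wide (LOAD_VEC_A/B).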
+
+kernel void kernel_mul_mm_f16_f32_l4_lm(
+    global half4 * src0,
+    ulong offset0,
+    global float4 * src1,
+    ulong offset1,
+    global float * dst,
+    ulong offsetd,
+
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne11,
+    int ne12,
+
+    int stride_a,
+    int stride_b,
+    int stride_d,
+
+    int batch_stride_a,
+    int batch_stride_b,
+    int batch_stride_d,
+
+    int r2,
+    int r3
+) {
+    src0 = (global half4*)((global char*)src0 + offset0);
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    local half  buf_a[BM * BK];
+    local float buf_b[BN * BK];
+
+    const int batch_idx = get_global_id(2);
+
+    const int i13 = batch_idx / ne12;
+    const int i12 = batch_idx % ne12;
+
+    const int i03 = i13 / r3;
+    const int i02 = i12 / r2;
+
+    const int batch_idx_a = i03 * ne02 + i02;
+
+    const int ir = get_group_id(0);
+    const int ic = get_group_id(1);
+
+    const int tid = get_local_id(0);
+    const int th_r  = tid % (BM / TM);
+    const int th_c  = tid / (BM / TM);
+
+    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
+    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
+    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
+    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
+
+    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
+    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
+
+    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
+    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
+
+    float sums[TM * TN];
+    half  cache_a[TM];
+    float cache_b[TN];
+
+    for (int i = 0; i < TM * TN; i++) {
+        sums[i] = 0.0f;
+    }
+
+    for (int block = 0; block < ne00; block += BK) {
+        for (int l = 0; l < BM; l += loadstride_a) {
+            const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+            buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
+            buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
+            buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
+            buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
+        }
+
+        for (int l = 0; l < BN; l += loadstride_b) {
+            const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+            buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+            buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+            buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+            buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        pos_a += BK / LOAD_VEC_A;
+        pos_b += BK / LOAD_VEC_B;
+
+        for (int i = 0; i < BK; i++) {
+            for (int j = 0; j < TM; j++) {
+                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
+            }
+            for (int j = 0; j < TN; j++) {
+                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
+            }
+
+            for (int cc = 0; cc < TN; cc++) {
+                for (int cr = 0; cr < TM; cr++) {
+                    const int sums_idx = cc*TM + cr;
+                    sums[sums_idx] = mad(convert_float(cache_a[cr]), cache_b[cc], sums[sums_idx]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    const int dr = ir * BM + th_r * TM;
+    const int dc = ic * BN + th_c * TN;
+
+    const int offsets = batch_idx * batch_stride_d;
+
+    for (int cc = 0; cc < TN; cc++) {
+        for (int cr = 0; cr < TM; cr++) {
+            if (dr + cr < ne01 && dc + cc < ne11) {
+                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
+            }
+        }
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl
new file mode 100644
index 00000000000..58c5178e39c
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl
@@ -0,0 +1,133 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define LOAD_VEC_A 4
+#define LOAD_VEC_B 4
+
+#define BM 64
+#define BN 64
+#define BK 16
+#define TM 4
+#define TN 8
+
+kernel void kernel_mul_mm_f32_f32_l4_lm(
+    global float4 * src0,
+    ulong offset0,
+    global float4 * src1,
+    ulong offset1,
+    global float * dst,
+    ulong offsetd,
+
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne11,
+    int ne12,
+
+    int stride_a,
+    int stride_b,
+    int stride_d,
+
+    int batch_stride_a,
+    int batch_stride_b,
+    int batch_stride_d,
+
+    int r2,
+    int r3
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    local float buf_a[BM * BK];
+    local float buf_b[BN * BK];
+
+    const int batch_idx = get_global_id(2);
+
+    const int i13 = batch_idx / ne12;
+    const int i12 = batch_idx % ne12;
+
+    const int i03 = i13 / r3;
+    const int i02 = i12 / r2;
+
+    const int batch_idx_a = i03 * ne02 + i02;
+
+    const int ir = get_group_id(0);
+    const int ic = get_group_id(1);
+
+    const int tid = get_local_id(0);
+    const int th_r  = tid % (BM / TM);
+    const int th_c  = tid / (BM / TM);
+
+    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
+    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
+    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
+    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
+
+    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
+    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
+
+    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
+    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
+
+    float sums[TM * TN];
+    float cache_a[TM];
+    float cache_b[TN];
+
+    for (int i = 0; i < TM * TN; i++) {
+        sums[i] = 0.0f;
+    }
+
+    for (int block = 0; block < ne00; block += BK) {
+        for (int l = 0; l < BM; l += loadstride_a) {
+            const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+            buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
+            buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
+            buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
+            buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
+        }
+
+        for (int l = 0; l < BN; l += loadstride_b) {
+            const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+            buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+            buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+            buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+            buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        pos_a += BK / LOAD_VEC_A;
+        pos_b += BK / LOAD_VEC_B;
+
+        for (int i = 0; i < BK; i++) {
+            for (int j = 0; j < TM; j++) {
+                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
+            }
+
+            for (int j = 0; j < TN; j++) {
+                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
+            }
+
+            for (int cc = 0; cc < TN; cc++) {
+                for (int cr = 0; cr < TM; cr++) {
+                    const int sums_idx = cc*TM + cr;
+                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    const int dr = ir * BM + th_r * TM;
+    const int dc = ic * BN + th_c * TN;
+
+    const int offsets = batch_idx * batch_stride_d;
+
+    for (int cc = 0; cc < TN; cc++) {
+        for (int cr = 0; cr < TM; cr++) {
+            if (dr + cr < ne01 && dc + cc < ne11) {
+                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
+            }
+        }
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl
new file mode 100644
index 00000000000..d50bd1fc428
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl
@@ -0,0 +1,189 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#ifdef cl_intel_subgroups
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#endif
+
+#define QK_MXFP4 32
+typedef struct {
+    uchar e; // E8M0
+    uchar qs[QK_MXFP4/2];
+} block_mxfp4;
+
+constant static float kvalues_mxfp4_f[16] = {
+    0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
+};
+
+static inline float e8m0_to_fp32(uchar x) {
+    int bits;
+
+    if (x == 0) {
+        bits = 0x00400000;
+    } else {
+        bits = (uint) x << 23;
+    }
+
+    return as_float(bits);
+}
+
+#ifdef INTEL_GPU
+#define N_R0_MXFP4 2 // number of rows each subgroup works on
+#define N_SG_MXFP4 2 // number of subgroups in a work group
+#define N_SIMDWIDTH 16 // subgroup size
+#elif defined (ADRENO_GPU)
+#define N_R0_MXFP4 2
+#define N_SG_MXFP4 2
+#define N_SIMDWIDTH 64
+#endif
+
+inline void mul_mv_mxfp4_f32(
+    global char * src0,
+    global char * src1,
+    global char * dst,
+    int ne00,
+    ulong nb01,
+    ulong nb02,
+    ulong nb03,
+    int ne12,
+    ulong nb11,
+    ulong nb12,
+    ulong nb13,
+    int ne0,
+    int ne1,
+    int r2,
+    int r3,
+    local  char * shmem
+) {
+    local float * shmem_f32 = (local float *) shmem;
+    int nb = ne00/QK_MXFP4;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = 0;
+
+    int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;
+
+    uint i12 = im%ne12;
+    uint i13 = im/ne12;
+
+    ulong offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    ulong offset_src1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+    global block_mxfp4 * x = (global block_mxfp4 *) (src0 + offset_src0);
+    global float       * y = (global float       *) (src1 + offset_src1);
+
+    const short ix = get_sub_group_local_id()/2;  // 0...15
+    const short it = get_sub_group_local_id()%2;  // 0 or 1
+
+    shmem_f32[get_sub_group_local_id()] = kvalues_mxfp4_f[get_sub_group_local_id()%16];
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    float4 yl[4];
+    float sumf[N_R0_MXFP4] = {0.f};
+
+    global float * yb = y + ix * QK_MXFP4 + it * 8;
+
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        global float4 * y4 = (global float4 *)yb;
+        yl[0] = y4[0];
+        yl[1] = y4[4];
+        yl[2] = y4[1];
+        yl[3] = y4[5];
+
+        for (short row = 0; row < N_R0_MXFP4; row++) {
+            global block_mxfp4 * xb = x + row*nb + ib;
+            global uchar       * q2 = (global uchar *)(xb->qs + 8*it);
+
+            float4 acc1 = yl[0]*(float4)(shmem_f32[q2[0] &  0x0F], shmem_f32[q2[1] &  0x0F], shmem_f32[q2[2] &  0x0F], shmem_f32[q2[3] &  0x0F]);
+            float4 acc2 = yl[1]*(float4)(shmem_f32[q2[0] >> 4   ], shmem_f32[q2[1] >> 4   ], shmem_f32[q2[2] >> 4   ], shmem_f32[q2[3] >> 4   ]);
+            float4 acc3 = yl[2]*(float4)(shmem_f32[q2[4] &  0x0F], shmem_f32[q2[5] &  0x0F], shmem_f32[q2[6] &  0x0F], shmem_f32[q2[7] &  0x0F]);
+            float4 acc4 = yl[3]*(float4)(shmem_f32[q2[4] >> 4   ], shmem_f32[q2[5] >> 4   ], shmem_f32[q2[6] >> 4   ], shmem_f32[q2[7] >> 4   ]);
+
+            acc1 = (acc1 + acc3) + (acc2 + acc4);
+
+            sumf[row] += e8m0_to_fp32(xb->e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));
+        }
+
+        yb += (N_SIMDWIDTH/2) * QK_MXFP4;
+    }
+
+    global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
+
+    for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {
+        float sum_all = sub_group_reduce_add(sumf[row]);
+        if (get_sub_group_local_id() == 0) {
+            dst_f32[first_row + row] = sum_all;
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mv_id_mxfp4_f32(
+    global char * src0,
+    ulong         offset0,
+    global char * src1,
+    ulong         offset1,
+    global char * src2,
+    ulong         offset2,
+    global char * dst,
+    ulong         offsetd,
+    int           ne00,
+    ulong         nb01,
+    ulong         nb02,
+    ulong         nb03,
+    int           ne11,
+    int           ne12,
+    ulong         nb11,
+    ulong         nb12,
+    ulong         nb13,
+    int           ne20,
+    int           ne21,
+    ulong         nb21,
+    int           ne0,
+    int           ne1,
+    int           r2,
+    int           r3,
+    local  char * shmem
+) {
+    src0 = (global char *)((global char *)src0 + offset0);
+    src1 = (global char *)((global char *)src1 + offset1);
+    src2 = (global char *)((global char *)src2 + offset2);
+    dst  = (global char *)((global char *)dst  + offsetd);
+
+    const int iid1 = get_group_id(2)/ne20;
+    const int idx  = get_group_id(2)%ne20;
+
+    int i02 = ((global int *) (src2 + iid1*nb21))[idx];
+
+    int i11 = idx % ne11;
+    int i12 = iid1;
+
+    int i1 = idx;
+    int i2 = i12;
+
+    global char * src0_cur = src0 + i02*nb02;
+    global char * src1_cur = src1 + i11*nb11 + i12*nb12;
+
+    global char * dst_cur = dst + (i1*ne0 + i2*ne1*ne0)*sizeof(float);
+
+    mul_mv_mxfp4_f32(src0_cur, src1_cur, dst_cur,
+        ne00, nb01, nb02, nb03, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shmem);
+}
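The wrapper above implements the indirect (mul_mat_id) variant used for mixture-of-experts weights: get_group_id(2) enumerates (token, expert-slot) pairs, src2 supplies the expert index for each pair, and that index selects which matrix inside src0 is multiplied. A dense-float host sketch of the same indexing, ignoring the MXFP4 quantization and assuming one shared activation row per token (ne11 == 1); names such as n_expert_used are illustrative:

static void mul_mat_id_ref(const float * src0,  // [n_expert][ne01][ne00] expert matrices
                           const float * src1,  // [n_token][ne00] activations
                           const int   * ids,   // [n_token][n_expert_used] expert indices (src2)
                           float       * dst,   // [n_token][n_expert_used][ne01]
                           int ne00, int ne01, int n_token, int n_expert_used) {
    for (int iid1 = 0; iid1 < n_token; ++iid1) {           // get_group_id(2) / ne20
        for (int idx = 0; idx < n_expert_used; ++idx) {     // get_group_id(2) % ne20
            const int i02 = ids[iid1*n_expert_used + idx];  // expert picked for this slot
            const float * w = src0 + (long)i02*ne01*ne00;   // src0 + i02*nb02
            const float * x = src1 + (long)iid1*ne00;       // src1 + i11*nb11 + i12*nb12
            float       * y = dst  + ((long)iid1*n_expert_used + idx)*ne01;
            for (int r = 0; r < ne01; ++r) {
                float acc = 0.0f;
                for (int k = 0; k < ne00; ++k) {
                    acc += w[(long)r*ne00 + k] * x[k];
                }
                y[r] = acc;
            }
        }
    }
}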
diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl
new file mode 100644
index 00000000000..9a4d4b9bad1
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl
@@ -0,0 +1,144 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#ifdef cl_intel_subgroups
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#endif
+
+#define QK_MXFP4 32
+typedef struct {
+    uchar e; // E8M0
+    uchar qs[QK_MXFP4/2];
+} block_mxfp4;
+
+constant static float kvalues_mxfp4_f[16] = {
+    0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
+};
+
+static inline float e8m0_to_fp32(uchar x) {
+    int bits;
+
+    if (x == 0) {
+        bits = 0x00400000;
+    } else {
+        bits = (uint) x << 23;
+    }
+
+    return as_float(bits);
+}
+
+#ifdef INTEL_GPU
+#define N_R0_MXFP4 2 // number of rows each subgroup works on
+#define N_SG_MXFP4 2 // number of subgroups in a work group
+#define N_SIMDWIDTH 16 // subgroup size
+#elif defined (ADRENO_GPU)
+#define N_R0_MXFP4 2
+#define N_SG_MXFP4 2
+#define N_SIMDWIDTH 64
+#endif
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mv_mxfp4_f32(
+    global char * src0,
+    ulong         offset0,
+    global char * src1,
+    ulong         offset1,
+    global char * dst,
+    ulong         offsetd,
+    int ne00,
+    ulong nb01,
+    ulong nb02,
+    ulong nb03,
+    int ne12,
+    ulong nb11,
+    ulong nb12,
+    ulong nb13,
+    int ne0,
+    int ne1,
+    int r2,
+    int r3,
+    local  char * shmem
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst  = (global char*)((global char*)dst  + offsetd);
+
+    local float * shmem_f32 = (local float *) shmem;
+    int nb = ne00/QK_MXFP4;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    int first_row = (r0 * N_SG_MXFP4 + get_sub_group_id()) * N_R0_MXFP4;
+
+    uint i12 = im%ne12;
+    uint i13 = im/ne12;
+
+    ulong offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    ulong offset_src1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+    global block_mxfp4 * x = (global block_mxfp4 *) (src0 + offset_src0);
+    global float       * y = (global float       *) (src1 + offset_src1);
+
+    const short ix = get_sub_group_local_id()/2;  // 0...15
+    const short it = get_sub_group_local_id()%2;  // 0 or 1
+
+    shmem_f32[get_sub_group_local_id()] = kvalues_mxfp4_f[get_sub_group_local_id()%16];
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    float4 yl[4];
+    float sumf[N_R0_MXFP4] = {0.f};
+
+    global float * yb = y + ix * QK_MXFP4 + it * 8;
+
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        global float4 * y4 = (global float4 *)yb;
+        yl[0] = y4[0];
+        yl[1] = y4[4];
+        yl[2] = y4[1];
+        yl[3] = y4[5];
+
+        for (short row = 0; row < N_R0_MXFP4; row++) {
+            global block_mxfp4 * xb = x + row*nb + ib;
+            global uchar       * q2 = (global uchar *)(xb->qs + 8*it);
+
+            float4 acc1 = yl[0]*(float4)(shmem_f32[q2[0] &  0x0F], shmem_f32[q2[1] &  0x0F], shmem_f32[q2[2] &  0x0F], shmem_f32[q2[3] &  0x0F]);
+            float4 acc2 = yl[1]*(float4)(shmem_f32[q2[0] >> 4   ], shmem_f32[q2[1] >> 4   ], shmem_f32[q2[2] >> 4   ], shmem_f32[q2[3] >> 4   ]);
+            float4 acc3 = yl[2]*(float4)(shmem_f32[q2[4] &  0x0F], shmem_f32[q2[5] &  0x0F], shmem_f32[q2[6] &  0x0F], shmem_f32[q2[7] &  0x0F]);
+            float4 acc4 = yl[3]*(float4)(shmem_f32[q2[4] >> 4   ], shmem_f32[q2[5] >> 4   ], shmem_f32[q2[6] >> 4   ], shmem_f32[q2[7] >> 4   ]);
+
+            acc1 = (acc1 + acc3) + (acc2 + acc4);
+
+            sumf[row] += e8m0_to_fp32(xb->e) * ((acc1.s0 + acc1.s1) + (acc1.s2 + acc1.s3));
+        }
+
+        yb += (N_SIMDWIDTH/2) * QK_MXFP4;
+    }
+
+    global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
+
+    for (int row = 0; row < N_R0_MXFP4 && first_row + row < ne0; ++row) {
+        float sum_all = sub_group_reduce_add(sumf[row]);
+        if (get_sub_group_local_id() == 0) {
+            dst_f32[first_row + row] = sum_all;
+        }
+    }
+}
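Both MXFP4 kernels decode the same block layout: one shared E8M0 exponent plus 32 signed 4-bit codes looked up in the 16-entry table, with the scale applied once per block. A host-side C sketch of that decode, using the kernels' convention of true FP4 values and a 2^(e-127) scale (the CPU reference added later in this patch expresses the same numbers with a doubled integer table and a halved scale):

#include <stdint.h>
#include <string.h>

#define QK_MXFP4 32

typedef struct {
    uint8_t e;               // shared E8M0 scale, value 2^(e - 127)
    uint8_t qs[QK_MXFP4/2];  // 32 4-bit codes: low nibbles = first half, high nibbles = second half
} block_mxfp4;

static const float kvalues_mxfp4_f[16] = {
    0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0.f, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
};

static float e8m0_to_fp32(uint8_t x) {
    // same bit trick as the kernel's as_float((uint)x << 23); e == 0 maps to 2^-127
    uint32_t bits = x == 0 ? 0x00400000u : (uint32_t) x << 23;
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

// Decode one block into 32 floats. Example: with e == 127 the scale is 1.0f,
// so the largest positive code (0x7) decodes to 6.0f.
static void decode_block_mxfp4(const block_mxfp4 * b, float * out) {
    const float d = e8m0_to_fp32(b->e);
    for (int j = 0; j < QK_MXFP4/2; ++j) {
        out[j]              = d * kvalues_mxfp4_f[b->qs[j] & 0x0F];
        out[j + QK_MXFP4/2] = d * kvalues_mxfp4_f[b->qs[j] >> 4];
    }
}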
diff --git a/ggml/src/ggml-opencl/kernels/rms_norm.cl b/ggml/src/ggml-opencl/kernels/rms_norm.cl
index 9d21f3398ec..ecd053cb4c1 100644
--- a/ggml/src/ggml-opencl/kernels/rms_norm.cl
+++ b/ggml/src/ggml-opencl/kernels/rms_norm.cl
@@ -94,3 +94,82 @@ kernel void kernel_rms_norm(
         }
     }
 }
+
+//------------------------------------------------------------------------------
+// rms_norm_mul
+//------------------------------------------------------------------------------
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_32
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_rms_norm_mul(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        int ne13,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
+        float eps,
+        local float * sum
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst  + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    global float4 * x = (global float4 *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    global float4 * f = (global float4 *) (src1 + (i03%ne13)*nb13 + (i02%ne12)*nb12 + (i01%ne11)*nb11);
+
+    float sumf = 0;
+
+    // parallel sum
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        sumf += dot(x[i00], x[i00]);
+    }
+    sumf = sub_group_reduce_add(sumf);
+    if (get_sub_group_local_id() == 0) {
+        sum[get_sub_group_id()] = sumf;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
+       if (get_local_id(0) < i) {
+           sum[get_local_id(0)] += sum[get_local_id(0) + i];
+       }
+    }
+    if (get_local_id(0) == 0) {
+        sum[0] /= ne00;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    float mean  = sum[0];
+    float scale = 1.0f/sqrt(mean + eps);
+
+    global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        y[i00] = (x[i00] * scale) * f[i00%(ne10/4)];
+    }
+}
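The new kernel fuses the RMS normalization of src0 with an element-wise multiply by src1 (broadcast when src1 is shorter), saving one pass over memory compared to running the two ops separately. The per-row math, as a plain C reference:

#include <math.h>

// y = (x / sqrt(mean(x^2) + eps)) * f, with f broadcast when ne10 < ne00.
static void rms_norm_mul_row_ref(const float * x, const float * f, float * y,
                                 int ne00, int ne10, float eps) {
    float sumf = 0.0f;
    for (int i = 0; i < ne00; ++i) {
        sumf += x[i] * x[i];
    }
    const float scale = 1.0f / sqrtf(sumf / ne00 + eps);
    for (int i = 0; i < ne00; ++i) {
        y[i] = (x[i] * scale) * f[i % ne10];
    }
}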
diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl
index a6d8ede6701..571d16507c6 100644
--- a/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl
+++ b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4_f16(
         ulong offset0,
         global char * src1,
         ulong offset1,
+        global char * src2,
+        ulong offset2,
         global char * dst,
         ulong offsetd,
         int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4_f16(
 ) {
     src0 = src0 + offset0;
     src1 = src1 + offset1;
+    src2 = src2 + offset2;
     dst  = dst  + offsetd;
 
     int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4_f16(
 
     global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
     global half4  * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float  * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
     global float4 * pdst4 = (global float4 *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 
     float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4_f16(
     }
 
     // parallel max
-    float4 lmax4 = -INFINITY;
+    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
     for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
         lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f));
     }
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4_f16(
     }
     float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
 
-    const float sum = sub_group_reduce_add(lsum);
+    float sum = sub_group_reduce_add(lsum);
+
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max);
+    }
 
     for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
         pdst4[i00] /= sum;
diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl
index 35b5573b46a..1f944b2201d 100644
--- a/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl
+++ b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4(
         ulong offset0,
         global char * src1,
         ulong offset1,
+        global char * src2,
+        ulong offset2,
         global char * dst,
         ulong offsetd,
         int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4(
 ) {
     src0 = src0 + offset0;
     src1 = src1 + offset1;
+    src2 = src2 + offset2;
     dst  = dst  + offsetd;
 
     int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4(
 
     global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
     global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float  * psrc2 = src2 != src0 ? (global float  *)(src2) : 0;
     global float4 * pdst4 = (global float4 *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 
     float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4(
     }
 
     // parallel max
-    float4 lmax4 = -INFINITY;
+    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
     for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
         lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
     }
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4(
     }
     float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
 
-    const float sum = sub_group_reduce_add(lsum);
+    float sum = sub_group_reduce_add(lsum);
+
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max);
+    }
 
     for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
         pdst4[i00] /= sum;
diff --git a/ggml/src/ggml-opencl/kernels/softmax_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_f16.cl
index 9d292b57465..4baa6c28e4f 100644
--- a/ggml/src/ggml-opencl/kernels/softmax_f16.cl
+++ b/ggml/src/ggml-opencl/kernels/softmax_f16.cl
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_f16(
         ulong offset0,
         global char * src1,
         ulong offset1,
+        global char * src2,
+        ulong offset2,
         global char * dst,
         ulong offsetd,
         int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_f16(
 ) {
     src0 = src0 + offset0;
     src1 = src1 + offset1;
+    src2 = src2 + offset2;
     dst  = dst  + offsetd;
 
     int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_f16(
 
     global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
     global half  * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
     global float * pdst  = (global float *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 
     float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_f16(
     }
 
     // parallel max
-    float lmax = -INFINITY;
+    float lmax = psrc2 ? psrc2[i02] : -INFINITY;
     for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
         lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
     }
@@ -91,7 +95,11 @@ kernel void kernel_soft_max_f16(
         pdst[i00] = exp_psrc0;
     }
 
-    const float sum = sub_group_reduce_add(lsum);
+    float sum = sub_group_reduce_add(lsum);
+
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max);
+    }
 
     for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
         pdst[i00] /= sum;
diff --git a/ggml/src/ggml-opencl/kernels/softmax_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_f32.cl
index 7c53dfbe5a2..d503190b476 100644
--- a/ggml/src/ggml-opencl/kernels/softmax_f32.cl
+++ b/ggml/src/ggml-opencl/kernels/softmax_f32.cl
@@ -26,6 +26,8 @@ kernel void kernel_soft_max(
         ulong offset0,
         global char * src1,
         ulong offset1,
+        global char * src2,
+        ulong offset2,
         global char * dst,
         ulong offsetd,
         int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max(
 ) {
     src0 = src0 + offset0;
     src1 = src1 + offset1;
+    src2 = src2 + offset2;
     dst  = dst  + offsetd;
 
     int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max(
 
     global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
     global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
     global float * pdst  = (global float *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 
     float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max(
     }
 
     // parallel max
-    float lmax = -INFINITY;
+    float lmax = psrc2 ? psrc2[i02] : -INFINITY;
     for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
         lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
     }
@@ -91,7 +95,11 @@ kernel void kernel_soft_max(
         pdst[i00] = exp_psrc0;
     }
 
-    const float sum = sub_group_reduce_add(lsum);
+    float sum = sub_group_reduce_add(lsum);
+
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max);
+    }
 
     for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
         pdst[i00] /= sum;
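All four soft_max kernels gain an optional src2 operand holding one extra logit per head: it joins the running maximum and the normalizing sum but never receives an output slot, which is the shape of an attention-sink term. A single-row C reference of the same arithmetic (the mask and ALiBi slope handled by the kernels are omitted here for brevity):

#include <math.h>

// Softmax over n scaled logits plus an optional extra logit ("sink") that is
// counted in the denominator but produces no output; if sink is NULL this is
// the ordinary behavior before this change.
static void soft_max_row_with_sink_ref(const float * x, float * y, int n,
                                       float scale, const float * sink) {
    float max = sink ? *sink : -INFINITY;
    for (int i = 0; i < n; ++i) {
        const float v = x[i] * scale;
        if (v > max) max = v;
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        y[i] = expf(x[i] * scale - max);
        sum += y[i];
    }
    if (sink) {
        sum += expf(*sink - max);
    }
    for (int i = 0; i < n; ++i) {
        y[i] /= sum;
    }
}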
diff --git a/ggml/src/ggml-opencl/kernels/sub.cl b/ggml/src/ggml-opencl/kernels/sub.cl
index 041e88ad3a0..423ed595ca8 100644
--- a/ggml/src/ggml-opencl/kernels/sub.cl
+++ b/ggml/src/ggml-opencl/kernels/sub.cl
@@ -70,3 +70,69 @@ kernel void kernel_sub_row(
     uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
     dst[gid] = src0[gid] - src1[idx1];
 }
+
+kernel void kernel_sub_f16(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        int ne13,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int i13 = i03 % ne13;
+    int i12 = i02 % ne12;
+    int i11 = i01 % ne11;
+
+    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const int i10 = i0 % ne10;
+        *((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) - *((global half *)(src1_ptr + i10*nb10));
+    }
+}
+
+kernel void kernel_sub_row_f16(
+        global half4 * src0,
+        ulong offset0,
+        global half4 * src1,
+        ulong offset1,
+        global half4 * dst,
+        ulong offsetd,
+        int ne
+) {
+    src0 = (global half4*)((global char*)src0 + offset0);
+    src1 = (global half4*)((global char*)src1 + offset1);
+    dst = (global half4*)((global char*)dst + offsetd);
+
+    // This performs better than using %.
+    uint gid = get_global_id(0);
+    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
+    dst[gid] = src0[gid] - src1[idx1];
+}
diff --git a/ggml/src/ggml-opencl/kernels/transpose.cl b/ggml/src/ggml-opencl/kernels/transpose.cl
index a11490b304c..536dd560a91 100644
--- a/ggml/src/ggml-opencl/kernels/transpose.cl
+++ b/ggml/src/ggml-opencl/kernels/transpose.cl
@@ -24,6 +24,26 @@ kernel void kernel_transpose_16(
     write_imageh(output, (i_2+3)*rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
 }
 
+// Padded kernel for irregular shape
+kernel void kernel_transpose_16_4x1(
+    __read_only image1d_buffer_t input,
+    __write_only image1d_buffer_t output,
+    const uint rows,
+    const uint cols
+) {
+
+    const int i = get_global_id(0);
+    const int j = get_global_id(1);
+    const int j_2 = j << 2;
+
+    half temp0 = read_imageh(input, (j_2 + 0) * cols + i).x;
+    half temp1 = read_imageh(input, (j_2 + 1) * cols + i).x;
+    half temp2 = read_imageh(input, (j_2 + 2) * cols + i).x;
+    half temp3 = read_imageh(input, (j_2 + 3) * cols + i).x;
+
+    write_imageh(output, i * rows + j, (half4)(temp0, temp1, temp2, temp3));
+}
+
 // 32-bit transpose, loading/storing a 4x4 tile of elements
 kernel void kernel_transpose_32(
     __read_only image1d_buffer_t input,
diff --git a/ggml/src/ggml-opt.cpp b/ggml/src/ggml-opt.cpp
index a3c82d67577..e078ad14a39 100644
--- a/ggml/src/ggml-opt.cpp
+++ b/ggml/src/ggml-opt.cpp
@@ -64,9 +64,11 @@ struct ggml_opt_context {
     int32_t opt_i              = 0;
     bool    loss_per_datapoint = false;
 
-    ggml_opt_get_optimizer_params get_opt_pars = nullptr;
-    void * get_opt_pars_ud                     = nullptr;
-    struct ggml_tensor * adamw_params          = nullptr;
+    ggml_opt_get_optimizer_params get_opt_pars    = nullptr;
+    void *                        get_opt_pars_ud = nullptr;
+    struct ggml_tensor *          opt_step_params = nullptr; // Stores output of get_opt_pars.
+
+    enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
 };
 
 struct ggml_opt_result {
@@ -229,9 +231,13 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
     result.adamw.eps   = 1e-8f;
     result.adamw.wd    = 0.0f;
 
+    result.sgd.alpha   = 1e-3f;
+    result.sgd.wd      = 0.0f;
+
     return result;
 }
 
+
 struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
     return *((struct ggml_opt_optimizer_params *) userdata);
 }
@@ -249,6 +255,7 @@ struct ggml_opt_params ggml_opt_default_params(
         /*opt_period      =*/ 1,
         /*get_opt_pars    =*/ ggml_opt_get_default_optimizer_params,
         /*get_opt_pars_ud =*/ nullptr,
+        /*optimizer       =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,
     };
 }
 
@@ -316,9 +323,14 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
     GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
 
+    const enum ggml_opt_optimizer_type optimizer = opt_ctx->optimizer;
+
     const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
         !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
 
+    const bool need_momenta = opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT &&
+        opt_ctx->optimizer == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+
     ggml_set_input(opt_ctx->inputs);
     ggml_set_output(opt_ctx->outputs);
 
@@ -340,8 +352,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
         //   - pred (if using static graphs)
         //   - ncorrect (if using static graphs, 2 tensors).
         constexpr size_t n_loss = 1;
-        const size_t tensors_per_param = (accumulate ? 1 : 0) +
-            (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
+        const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0);
         const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
         const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
         struct ggml_init_params params = {
@@ -458,7 +469,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
             }
         }
 
-        if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
+        if (need_momenta && opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
             opt_ctx->grad_m.resize(n_nodes);
             opt_ctx->grad_v.resize(n_nodes);
             for (int i = 0; i < n_nodes; ++i) {
@@ -492,23 +503,36 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
     opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
 
-    opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7);
-    ggml_set_input(opt_ctx->adamw_params);
-    ggml_set_name(opt_ctx->adamw_params, "adamw_params");
-
+    opt_ctx->opt_step_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, need_momenta ? 7 : 2);
+    ggml_tensor * adamw_params = opt_ctx->opt_step_params;
+    ggml_set_input(adamw_params);
+    const char * optimizer_name = ggml_opt_optimizer_name(opt_ctx->optimizer);
+    ggml_format_name(adamw_params, "%s_params", optimizer_name);
     for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
         struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
         struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);
 
         if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
-            struct ggml_tensor * m        = opt_ctx->grad_m[i];
-            struct ggml_tensor * v        = opt_ctx->grad_v[i];
-            struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
-
-            ggml_set_name(m,        (std::string("AdamW m for ")    + std::string(node->name)).c_str());
-            ggml_set_name(v,        (std::string("AdamW v for ")    + std::string(node->name)).c_str());
-            ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
-
+            struct ggml_tensor * m = nullptr;
+            struct ggml_tensor * v = nullptr;
+            if (need_momenta) {
+                m = opt_ctx->grad_m[i];
+                v = opt_ctx->grad_v[i];
+                ggml_format_name(m, "AdamW m for %s", node->name);
+                ggml_format_name(v, "AdamW v for %s", node->name);
+            }
+            struct ggml_tensor * opt_step;
+            switch (optimizer) {
+                case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
+                    opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params);
+                    break;
+                case GGML_OPT_OPTIMIZER_TYPE_SGD:
+                    opt_step = ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params);
+                    break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+            ggml_format_name(opt_step, "%s step for %s", optimizer_name, node->name);
             ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
         }
     }
@@ -534,6 +558,7 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
     result->opt_period       = params.opt_period;
     result->get_opt_pars     = params.get_opt_pars;
     result->get_opt_pars_ud  = params.get_opt_pars_ud;
+    result->optimizer        = params.optimizer;
 
     GGML_ASSERT(result->opt_period >= 1);
 
@@ -756,29 +781,43 @@ void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
 void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
     GGML_ASSERT(opt_ctx->eval_ready);
     if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
-        struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
-
-        GGML_ASSERT(opt_pars.adamw.alpha >  0.0f);
-        GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
-        GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
-        GGML_ASSERT(opt_pars.adamw.eps   >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.wd    >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.wd    <= 1.0f);
-
-        // beta1, beta2 after applying warmup
-        const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
-        const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
-
-        float * adamw_par_data = ggml_get_data_f32(opt_ctx->adamw_params);
-        adamw_par_data[0] = opt_pars.adamw.alpha;
-        adamw_par_data[1] = opt_pars.adamw.beta1;
-        adamw_par_data[2] = opt_pars.adamw.beta2;
-        adamw_par_data[3] = opt_pars.adamw.eps;
-        adamw_par_data[4] = opt_pars.adamw.wd;
-        adamw_par_data[5] = beta1h;
-        adamw_par_data[6] = beta2h;
+        const ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
+
+        switch (opt_ctx->optimizer) {
+            case GGML_OPT_OPTIMIZER_TYPE_ADAMW: {
+                GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
+                GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
+                GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
+
+                // beta1, beta2 after applying warmup
+                const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
+                const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
+
+                float * adamw_par_data = ggml_get_data_f32(opt_ctx->opt_step_params);
+                adamw_par_data[0] = opt_pars.adamw.alpha;
+                adamw_par_data[1] = opt_pars.adamw.beta1;
+                adamw_par_data[2] = opt_pars.adamw.beta2;
+                adamw_par_data[3] = opt_pars.adamw.eps;
+                adamw_par_data[4] = opt_pars.adamw.wd;
+                adamw_par_data[5] = beta1h;
+                adamw_par_data[6] = beta2h;
+            } break;
+            case GGML_OPT_OPTIMIZER_TYPE_SGD: {
+                GGML_ASSERT(opt_pars.sgd.alpha > 0.0f);
+                GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);
+                GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);
+                float * sgd = ggml_get_data_f32(opt_ctx->opt_step_params);
+                sgd[0] = opt_pars.sgd.alpha;
+                sgd[1] = opt_pars.sgd.wd;
+            } break;
+            default:
+                GGML_ABORT("fatal error");
+        }
     }
 
     ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
@@ -963,6 +1002,7 @@ void ggml_opt_fit(
         ggml_tensor                   * outputs,
         ggml_opt_dataset_t              dataset,
         enum ggml_opt_loss_type         loss_type,
+        enum ggml_opt_optimizer_type    optimizer,
         ggml_opt_get_optimizer_params   get_opt_pars,
         int64_t                         nepoch,
         int64_t                         nbatch_logical,
@@ -993,6 +1033,7 @@ void ggml_opt_fit(
     params.opt_period      = opt_period;
     params.get_opt_pars    = get_opt_pars;
     params.get_opt_pars_ud = &epoch;
+    params.optimizer       = optimizer;
     ggml_opt_context_t opt_ctx = ggml_opt_init(params);
 
     // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
@@ -1035,3 +1076,18 @@ void ggml_opt_fit(
     ggml_opt_result_free(result_train);
     ggml_opt_result_free(result_val);
 }
+
+enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t c) {
+    return c->optimizer;
+}
+
+GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type o) {
+    switch (o) {
+        case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
+            return "adamw";
+        case GGML_OPT_OPTIMIZER_TYPE_SGD:
+            return "sgd";
+        default:
+            return "undefined";
+    };
+}
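ggml_opt now carries the optimizer choice through ggml_opt_params and ggml_opt_fit, and packs either seven AdamW scalars or two SGD scalars (alpha, wd) into opt_step_params. The actual update is whatever the backend's ggml_opt_step_sgd op implements; the sketch below is one common decoupled-weight-decay formulation consuming that two-float layout, shown only to make the parameter packing concrete:

// w <- w*(1 - alpha*wd) - alpha*grad, reading the layout written in
// ggml_opt_eval above: sgd_params[0] = alpha, sgd_params[1] = wd.
static void sgd_step_ref(float * w, const float * grad, const float * sgd_params, int n) {
    const float alpha = sgd_params[0];                 // learning rate
    const float keep  = 1.0f - alpha * sgd_params[1];  // decoupled weight decay factor
    for (int i = 0; i < n; ++i) {
        w[i] = w[i] * keep - alpha * grad[i];
    }
}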
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 9a7d1b22d79..727932123e4 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -21,6 +21,17 @@
 
 #define UNUSED GGML_UNUSED
 
+static inline int best_index_int8(int n, const int8_t * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
 // reference implementation for deterministic creation of model files
 void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK4_0;
@@ -246,6 +257,53 @@ void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_REST
     }
 }
 
+static inline int best_index_mxfp4(float x, float e) {
+    int best_index = 0;
+    float best_err = fabsf(kvalues_mxfp4[0]*e - x);
+    for (int i = 1; i < 16; i++) {
+        float err = fabsf(kvalues_mxfp4[i]*e - x);
+        if (err < best_err) {
+            best_index = i;
+            best_err = err;
+        }
+    }
+    return best_index;
+}
+
+void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
+    static const int qk = QK_MXFP4;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+            }
+        }
+
+        const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
+
+        const float d = GGML_E8M0_TO_FP32_HALF(e);
+
+        y[i].e = e;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t x0 = best_index_mxfp4(x[i*qk + 0    + j], d);
+            const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);
+
+            y[i].qs[j]  = x0;
+            y[i].qs[j] |= x1 << 4;
+        }
+    }
+}
+
 void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK4_0;
 
@@ -356,6 +414,26 @@ void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRI
     }
 }
 
+void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+    static const int qk = QK_MXFP4;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_E8M0_TO_FP32_HALF(x[i].e);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F];
+            const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >>   4];
+
+            y[i*qk + j + 0   ] = x0*d;
+            y[i*qk + j + qk/2] = x1*d;
+        }
+    }
+}
+
 //
 // 2-6 bit quantization in super-blocks
 //
@@ -488,7 +566,7 @@ static float make_q3_quants(int n, int nmax, const float * GGML_RESTRICT x, int8
         for (int i = 0; i < n; ++i) {
             L[i] += nmax;
         }
-        return sumlx / suml2;
+        return suml2 > 0.0f ? sumlx / suml2 : 0.0f;
     }
     for (int i = 0; i < n; ++i) {
         int l = nearest_int(iscale * x[i]);
@@ -823,7 +901,7 @@ static float make_qp_quants(int n, int nmax, const float * GGML_RESTRICT x, uint
     for (int i = 0; i < n; ++i) {
         max = MAX(max, x[i]);
     }
-    if (!max) { // all zero
+    if (max < GROUP_MAX_EPS) { // all zero
         for (int i = 0; i < n; ++i) { L[i] = 0; }
         return 0.f;
     }
@@ -888,7 +966,7 @@ static float make_qp_quants(int n, int nmax, const float * GGML_RESTRICT x, uint
             break;
         }
     }
-    return sumlx/suml2;
+    return suml2 > 0.0f ? sumlx / suml2 : 0.0f;
 }
 
 static void quantize_row_q2_K_impl(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int k, const float * GGML_RESTRICT quant_weights) {
@@ -2014,6 +2092,12 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
     return nrow * row_size;
 }
 
+size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    GGML_UNUSED(quant_weights);
+    quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
+    return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
+}
+
 // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
 
 void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
@@ -4182,7 +4266,7 @@ static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_R
                     sumw[j+1] = sumw[j] + weight[i];
                 }
             }
-            float best_score = -FLT_MIN, scale = max;
+            float best_score = -FLT_MAX, scale = max;
             int besti1 = -1, besti2 = -1, best_shift = 0;
             for (int i1 = 0; i1 <= block_size; ++i1) {
                 for (int i2 = i1; i2 <= block_size; ++i2) {
@@ -4358,7 +4442,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
                 idx[2*j] = j;
             }
             qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
-            float best_score = -FLT_MIN, scale = max;
+            float best_score = -FLT_MAX, scale = max;
             int besti1 = -1, besti2 = -1, best_k = -1;
             // 0: +, +
             // 1: +, -
@@ -4551,17 +4635,6 @@ size_t quantize_iq1_m(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
 
 // ============================ 4-bit non-linear quants
 
-static inline int best_index_int8(int n, const int8_t * val, float x) {
-    if (x <= val[0]) return 0;
-    if (x >= val[n-1]) return n-1;
-    int ml = 0, mu = n-1;
-    while (mu-ml > 1) {
-        int mav = (ml+mu)/2;
-        if (x < val[mav]) mu = mav; else ml = mav;
-    }
-    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
-}
-
 static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x,
         ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
         float * scales, float * weight, uint8_t * L,
@@ -4961,6 +5034,15 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
     return true;
 }
 
+static bool validate_e_e8m0(uint8_t e, size_t i) {
+    if (e == 0xff) {
+        fprintf(stderr, "ggml_validate_row_data: found invalid e value %d at block %zu\n", e, i);
+        return false;
+    }
+
+    return true;
+}
+
 #define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
     const type * q = (const type *) (data); \
     for (size_t i = 0; i < (nb); ++i) { \
@@ -4977,6 +5059,14 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
         } \
     }
 
+#define VALIDATE_ROW_DATA_E_E8M0_IMPL(type, data, nb) \
+    const type * q = (const type *) (data); \
+    for (size_t i = 0; i < (nb); ++i) { \
+        if (!validate_e_e8m0(q[i].e, i)) { \
+            return false; \
+        } \
+    }
+
 #define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
     const type * q = (const type *) (data); \
     for (size_t i = 0; i < (nb); ++i) { \
@@ -5130,6 +5220,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
             {
                 VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
             } break;
+        case GGML_TYPE_MXFP4:
+            {
+                VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
+            } break;
         case GGML_TYPE_Q2_K:
             {
                 VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
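A small round-trip of the new reference routines, assuming it is compiled inside the ggml source tree so that ggml-quants.h and the block_mxfp4 type it pulls in are available; the comments work through the exponent formula used by quantize_row_mxfp4_ref:

#include <stdio.h>
#include "ggml-quants.h"   // declares quantize_row_mxfp4_ref / dequantize_row_mxfp4

int main(void) {
    float src[QK_MXFP4], out[QK_MXFP4];
    for (int i = 0; i < QK_MXFP4; ++i) {
        src[i] = 0.1f * (i - 16);                  // test data, amax = 1.6
    }

    block_mxfp4 q;
    quantize_row_mxfp4_ref(src, &q, QK_MXFP4);     // e = floor(log2(1.6)) - 2 + 127 = 125
    dequantize_row_mxfp4(&q, out, QK_MXFP4);       // scale = 2^(125-127)/2 = 0.125

    for (int i = 0; i < QK_MXFP4; ++i) {
        printf("%8.4f -> %8.4f\n", src[i], out[i]);
    }
    return 0;
}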
diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h
index d09173e1116..3b688f31c21 100644
--- a/ggml/src/ggml-quants.h
+++ b/ggml/src/ggml-quants.h
@@ -21,6 +21,8 @@ GGML_API void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 *
 GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
 
+GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
+
 GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
@@ -45,6 +47,8 @@ GGML_API void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GG
 GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 //GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 
+GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
 GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 GGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -90,6 +94,8 @@ GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTR
 GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 
+GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
 GGML_API void iq2xs_init_impl(enum ggml_type type);
 GGML_API void iq2xs_free_impl(enum ggml_type type);
 GGML_API void iq3xs_init_impl(int grid_size);
diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
index f468f796d57..e84ff93efc3 100644
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -29,9 +29,12 @@
 #include 
 #include 
 #include 
+#include 
 
 namespace fs = std::filesystem;
 
+static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB
+
 #ifdef _WIN32
 typedef SOCKET sockfd_t;
 using ssize_t = __int64;
@@ -323,11 +326,14 @@ static std::shared_ptr create_server_socket(const char * host, int por
 static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
     size_t bytes_sent = 0;
     while (bytes_sent < size) {
-        ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0);
+        size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE);
+        ssize_t n = send(sockfd, (const char *)data + bytes_sent, size_to_send, 0);
         if (n < 0) {
+            GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n",
+                           bytes_sent, size_to_send);
             return false;
         }
-        bytes_sent += n;
+        bytes_sent += (size_t)n;
     }
     return true;
 }
@@ -335,11 +341,18 @@ static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
 static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
     size_t bytes_recv = 0;
     while (bytes_recv < size) {
-        ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0);
-        if (n <= 0) {
+        size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE);
+        ssize_t n = recv(sockfd, (char *)data + bytes_recv, size_to_recv, 0);
+        if (n < 0) {
+            GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n",
+                           bytes_recv, size_to_recv);
+            return false;
+        }
+        if (n == 0) {
+            GGML_LOG_ERROR("recv returned 0 (peer closed?)\n");
             return false;
         }
-        bytes_recv += n;
+        bytes_recv += (size_t)n;
     }
     return true;
 }
@@ -823,10 +836,10 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
     };
 
     ggml_backend_t backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_rpc_guid(),
-        /* .interface = */ ggml_backend_rpc_interface,
-        /* .device    = */ ggml_backend_rpc_add_device(endpoint),
-        /* .context   = */ ctx
+        /* .guid    = */ ggml_backend_rpc_guid(),
+        /* .iface   = */ ggml_backend_rpc_interface,
+        /* .device  = */ ggml_backend_rpc_add_device(endpoint),
+        /* .context = */ ctx
     };
     return backend;
 }
@@ -1055,7 +1068,7 @@ bool rpc_server::set_tensor(const std::vector & input) {
     GGML_ASSERT(ctx_ptr != nullptr);
     ggml_context * ctx = ctx_ptr.get();
     ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
-    if (tensor == nullptr) {
+    if (tensor == nullptr || tensor->buffer == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
         return false;
     }
@@ -1124,7 +1137,7 @@ bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rp
     GGML_ASSERT(ctx_ptr != nullptr);
     ggml_context * ctx = ctx_ptr.get();
     ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
-    if (tensor == nullptr) {
+    if (tensor == nullptr || tensor->buffer == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
         return false;
     }
@@ -1192,7 +1205,7 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
     GGML_ASSERT(ctx_ptr != nullptr);
     ggml_context * ctx = ctx_ptr.get();
     ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
-    if (tensor == nullptr) {
+    if (tensor == nullptr || tensor->buffer == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
         return false;
     }
@@ -1229,7 +1242,7 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
 
     ggml_tensor * src = deserialize_tensor(ctx, &request.src);
     ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
-    if (src == nullptr || dst == nullptr) {
+    if (src == nullptr || dst == nullptr || src->buffer == nullptr || dst->buffer == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
         return false;
     }
diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp
index f839a42bc90..410a67b0195 100644
--- a/ggml/src/ggml-sycl/backend.hpp
+++ b/ggml/src/ggml-sycl/backend.hpp
@@ -28,6 +28,7 @@
 #include "mmvq.hpp"
 #include "norm.hpp"
 #include "outprod.hpp"
+#include "quantize.hpp"
 #include "quants.hpp"
 #include "rope.hpp"
 #include "set_rows.hpp"
diff --git a/ggml/src/ggml-sycl/cpy.cpp b/ggml/src/ggml-sycl/cpy.cpp
index 1ffd7f12267..3d321b58ac6 100644
--- a/ggml/src/ggml-sycl/cpy.cpp
+++ b/ggml/src/ggml-sycl/cpy.cpp
@@ -1,31 +1,12 @@
 #include "cpy.hpp"
 
 #include 
-#include 
 
 #include "dequantize.hpp"
 #include "ggml-sycl/common.hpp"
 #include "ggml-sycl/presets.hpp"
 #include "ggml.h"
 
-static __dpct_inline__ int best_index_int8(int n, const int8_t * val, float x) {
-    if (x <= val[0]) {
-        return 0;
-    }
-    if (x >= val[n - 1]) {
-        return n - 1;
-    }
-    int ml = 0, mu = n - 1;
-    while (mu - ml > 1) {
-        int mav = (ml + mu) / 2;
-        if (x < val[mav]) {
-            mu = mav;
-        } else {
-            ml = mav;
-        }
-    }
-    return x - val[mu - 1] < val[mu] - x ? mu - 1 : mu;
-}
 
 static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
     const float * xi   = (const float *) cxi;
@@ -97,28 +78,6 @@ static void cpy_f32_f16(const char * cx, char * cdst, const int ne, const int ne
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
-static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
-    const float * xi   = (const float *) cxi;
-    block_q8_0 *  dsti = (block_q8_0 *) cdsti;
-
-    float amax = 0.0f;  // absolute max
-
-    for (int j = 0; j < QK8_0; j++) {
-        const float v = xi[j];
-        amax          = sycl::fmax(amax, sycl::fabs((float) v));
-    }
-
-    const float d  = amax / ((1 << 7) - 1);
-    const float id = d ? 1.0f / d : 0.0f;
-
-    dsti->d = d;
-
-    for (int j = 0; j < QK8_0; ++j) {
-        const float x0 = xi[j] * id;
-
-        dsti->qs[j] = sycl::round((float) x0);
-    }
-}
 
 /* quantized type same copy */
 template
@@ -140,178 +99,7 @@ static void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
     }
 }
 
-static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
-    const float * xi   = (const float *) cxi;
-    block_q4_0 *  dsti = (block_q4_0 *) cdsti;
-
-    float amax = 0.0f;
-    float vmax = 0.0f;
-
-    for (int j = 0; j < QK4_0; ++j) {
-        const float v = xi[j];
-        if (amax < sycl::fabs((float) v)) {
-            amax = sycl::fabs((float) v);
-            vmax = v;
-        }
-    }
-
-    const float d  = vmax / -8;
-    const float id = d ? 1.0f / d : 0.0f;
-
-    dsti->d = d;
-
-    for (int j = 0; j < QK4_0 / 2; ++j) {
-        const float x0 = xi[0 + j] * id;
-        const float x1 = xi[QK4_0 / 2 + j] * id;
-
-        const uint8_t xi0 = dpct::min(15, (int8_t) (x0 + 8.5f));
-        const uint8_t xi1 = dpct::min(15, (int8_t) (x1 + 8.5f));
-
-        dsti->qs[j] = xi0;
-        dsti->qs[j] |= xi1 << 4;
-    }
-}
-
-static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
-    const float * xi   = (const float *) cxi;
-    block_q4_1 *  dsti = (block_q4_1 *) cdsti;
-
-    float vmin = FLT_MAX;
-    float vmax = -FLT_MAX;
-
-    for (int j = 0; j < QK4_1; ++j) {
-        const float v = xi[j];
-
-        if (v < vmin) {
-            vmin = v;
-        }
-        if (v > vmax) {
-            vmax = v;
-        }
-    }
-
-    const float d  = (vmax - vmin) / ((1 << 4) - 1);
-    const float id = d ? 1.0f / d : 0.0f;
-
-    dsti->dm.x() = d;
-    dsti->dm.y() = vmin;
-
-    for (int j = 0; j < QK4_1 / 2; ++j) {
-        const float x0 = (xi[0 + j] - vmin) * id;
-        const float x1 = (xi[QK4_1 / 2 + j] - vmin) * id;
-
-        const uint8_t xi0 = dpct::min(15, (int8_t) (x0 + 0.5f));
-        const uint8_t xi1 = dpct::min(15, (int8_t) (x1 + 0.5f));
 
-        dsti->qs[j] = xi0;
-        dsti->qs[j] |= xi1 << 4;
-    }
-}
-
-static void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
-    const float * xi   = (const float *) cxi;
-    block_q5_0 *  dsti = (block_q5_0 *) cdsti;
-
-    float amax = 0.0f;
-    float vmax = 0.0f;
-
-    for (int j = 0; j < QK5_0; ++j) {
-        const float v = xi[j];
-        if (amax < sycl::fabs((float) v)) {
-            amax = sycl::fabs((float) v);
-            vmax = v;
-        }
-    }
-
-    const float d  = vmax / -16;
-    const float id = d ? 1.0f / d : 0.0f;
-
-    dsti->d = d;
-
-    uint32_t qh = 0;
-    for (int j = 0; j < QK5_0 / 2; ++j) {
-        const float x0 = xi[0 + j] * id;
-        const float x1 = xi[QK5_0 / 2 + j] * id;
-
-        const uint8_t xi0 = dpct::min(31, (int8_t) (x0 + 16.5f));
-        const uint8_t xi1 = dpct::min(31, (int8_t) (x1 + 16.5f));
-
-        dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
-        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
-        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0 / 2);
-    }
-    memcpy(dsti->qh, &qh, sizeof(qh));
-}
-
-static void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
-    const float * xi   = (const float *) cxi;
-    block_q5_1 *  dsti = (block_q5_1 *) cdsti;
-
-    float min = xi[0];
-    float max = xi[0];
-
-    for (int j = 1; j < QK5_1; ++j) {
-        const float v = xi[j];
-        min           = v < min ? v : min;
-        max           = v > max ? v : max;
-    }
-
-    const float d  = (max - min) / 31;
-    const float id = d ? 1.0f / d : 0.0f;
-
-    dsti->dm.x() = d;
-    dsti->dm.y() = min;
-
-    uint32_t qh = 0;
-    for (int j = 0; j < QK5_1 / 2; ++j) {
-        const float x0 = (xi[0 + j] - min) * id;
-        const float x1 = (xi[QK5_1 / 2 + j] - min) * id;
-
-        const uint8_t xi0 = (uint8_t) (x0 + 0.5f);
-        const uint8_t xi1 = (uint8_t) (x1 + 0.5f);
-
-        dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
-        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
-        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1 / 2);
-    }
-    memcpy(dsti->qh, &qh, sizeof(qh));
-}
-
-static void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
-    const float *  xi   = (const float *) cxi;
-    block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
-
-    float amax = 0.0f;
-    float vmax = 0.0f;
-
-    for (int j = 0; j < QK4_NL; ++j) {
-        const float v = xi[j];
-        if (amax < sycl::fabs((float) v)) {
-            amax = sycl::fabs((float) v);
-            vmax = v;
-        }
-    }
-
-    float       d  = vmax / kvalues_iq4nl[0];
-    const float id = d ? 1.0f / d : 0.0f;
-
-    float sumqx = 0, sumq2 = 0;
-    for (int j = 0; j < QK4_NL / 2; ++j) {
-        const float   x0  = xi[0 + j] * id;
-        const float   x1  = xi[QK4_NL / 2 + j] * id;
-        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
-        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
-        dsti->qs[j]       = xi0 | (xi1 << 4);
-        const float v0    = kvalues_iq4nl[xi0];
-        const float v1    = kvalues_iq4nl[xi1];
-        const float w0    = xi[0 + j] * xi[0 + j];
-        const float w1    = xi[QK4_NL / 2 + j] * xi[QK4_NL / 2 + j];
-        sumqx += w0 * v0 * xi[j] + w1 * v1 * xi[QK4_NL / 2 + j];
-        sumq2 += w0 * v0 * v0 + w1 * v1 * v1;
-    }
-
-    dsti->d = sumq2 > 0 ? sumqx / sumq2 : d;
-}
 
 template  static void cpy_blck_q_f32(const char * cxi, char * cdsti) {
     float * cdstf = (float *) (cdsti);
diff --git a/ggml/src/ggml-sycl/cpy.hpp b/ggml/src/ggml-sycl/cpy.hpp
index 0a0f561d230..3c331f1ef27 100644
--- a/ggml/src/ggml-sycl/cpy.hpp
+++ b/ggml/src/ggml-sycl/cpy.hpp
@@ -2,10 +2,222 @@
 #define GGML_SYCL_CPY_HPP
 
 #include "common.hpp"
+#include <float.h>
 
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
+__dpct_inline__ int best_index_int8(int n, const int8_t * val, float x) {
+    if (x <= val[0]) {
+        return 0;
+    }
+    if (x >= val[n - 1]) {
+        return n - 1;
+    }
+    int ml = 0, mu = n - 1;
+    while (mu - ml > 1) {
+        int mav = (ml + mu) / 2;
+        if (x < val[mav]) {
+            mu = mav;
+        } else {
+            ml = mav;
+        }
+    }
+    return x - val[mu - 1] < val[mu] - x ? mu - 1 : mu;
+}
+
+inline void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
+    const float * xi   = (const float *) cxi;
+    block_q8_0 *  dsti = (block_q8_0 *) cdsti;
+
+    float amax = 0.0f;  // absolute max
+
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = xi[j];
+        amax          = sycl::fmax(amax, sycl::fabs((float) v));
+    }
+
+    const float d  = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f / d : 0.0f;
+
+    dsti->d = d;
+
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = xi[j] * id;
+
+        dsti->qs[j] = sycl::round((float) x0);
+    }
+}
+
+inline void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
+    const float * xi   = (const float *) cxi;
+    block_q4_0 *  dsti = (block_q4_0 *) cdsti;
+
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK4_0; ++j) {
+        const float v = xi[j];
+        if (amax < sycl::fabs((float) v)) {
+            amax = sycl::fabs((float) v);
+            vmax = v;
+        }
+    }
+
+    const float d  = vmax / -8;
+    const float id = d ? 1.0f / d : 0.0f;
+
+    dsti->d = d;
+
+    for (int j = 0; j < QK4_0 / 2; ++j) {
+        const float x0 = xi[0 + j] * id;
+        const float x1 = xi[QK4_0 / 2 + j] * id;
+
+        const uint8_t xi0 = dpct::min(15, (int8_t) (x0 + 8.5f));
+        const uint8_t xi1 = dpct::min(15, (int8_t) (x1 + 8.5f));
+
+        dsti->qs[j] = xi0;
+        dsti->qs[j] |= xi1 << 4;
+    }
+}
+
+inline void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
+    const float * xi   = (const float *) cxi;
+    block_q4_1 *  dsti = (block_q4_1 *) cdsti;
+
+    float vmin = FLT_MAX;
+    float vmax = -FLT_MAX;
+
+    for (int j = 0; j < QK4_1; ++j) {
+        const float v = xi[j];
+
+        vmin = sycl::min(v, vmin);
+        vmax = sycl::max(v, vmax);
+    }
+
+    const float d  = (vmax - vmin) / ((1 << 4) - 1);
+    const float id = d ? 1.0f / d : 0.0f;
+
+    dsti->dm.x() = d;
+    dsti->dm.y() = vmin;
+
+    for (int j = 0; j < QK4_1 / 2; ++j) {
+        const float x0 = (xi[0 + j] - vmin) * id;
+        const float x1 = (xi[QK4_1 / 2 + j] - vmin) * id;
+
+        const uint8_t xi0 = dpct::min(15, (int8_t) (x0 + 0.5f));
+        const uint8_t xi1 = dpct::min(15, (int8_t) (x1 + 0.5f));
+
+        dsti->qs[j] = xi0;
+        dsti->qs[j] |= xi1 << 4;
+    }
+}
+
+inline void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
+    const float * xi   = (const float *) cxi;
+    block_q5_0 *  dsti = (block_q5_0 *) cdsti;
+
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK5_0; ++j) {
+        const float v = xi[j];
+        if (amax < sycl::fabs((float) v)) {
+            amax = sycl::fabs((float) v);
+            vmax = v;
+        }
+    }
+
+    const float d  = vmax / -16;
+    const float id = d ? 1.0f / d : 0.0f;
+
+    dsti->d = d;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_0 / 2; ++j) {
+        const float x0 = xi[0 + j] * id;
+        const float x1 = xi[QK5_0 / 2 + j] * id;
+
+        const uint8_t xi0 = dpct::min(31, (int8_t) (x0 + 16.5f));
+        const uint8_t xi1 = dpct::min(31, (int8_t) (x1 + 16.5f));
+
+        dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0 / 2);
+    }
+    memcpy(dsti->qh, &qh, sizeof(qh));
+}
+
+inline void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
+    const float * xi   = (const float *) cxi;
+    block_q5_1 *  dsti = (block_q5_1 *) cdsti;
+
+    float min = xi[0];
+    float max = xi[0];
+
+    for (int j = 1; j < QK5_1; ++j) {
+        const float v = xi[j];
+        min           = v < min ? v : min;
+        max           = v > max ? v : max;
+    }
+
+    const float d  = (max - min) / 31;
+    const float id = d ? 1.0f / d : 0.0f;
+
+    dsti->dm.x() = d;
+    dsti->dm.y() = min;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_1 / 2; ++j) {
+        const float x0 = (xi[0 + j] - min) * id;
+        const float x1 = (xi[QK5_1 / 2 + j] - min) * id;
+
+        const uint8_t xi0 = (uint8_t) (x0 + 0.5f);
+        const uint8_t xi1 = (uint8_t) (x1 + 0.5f);
+
+        dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1 / 2);
+    }
+    memcpy(dsti->qh, &qh, sizeof(qh));
+}
+
+inline void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
+    const float *  xi   = (const float *) cxi;
+    block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
+
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK4_NL; ++j) {
+        const float v = xi[j];
+        if (amax < sycl::fabs((float) v)) {
+            amax = sycl::fabs((float) v);
+            vmax = v;
+        }
+    }
+
+    float       d  = vmax / kvalues_iq4nl[0];
+    const float id = d ? 1.0f / d : 0.0f;
+
+    float sumqx = 0, sumq2 = 0;
+    for (int j = 0; j < QK4_NL / 2; ++j) {
+        const float   x0  = xi[0 + j] * id;
+        const float   x1  = xi[QK4_NL / 2 + j] * id;
+        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
+        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
+        dsti->qs[j]       = xi0 | (xi1 << 4);
+        const float v0    = kvalues_iq4nl[xi0];
+        const float v1    = kvalues_iq4nl[xi1];
+        const float w0    = xi[0 + j] * xi[0 + j];
+        const float w1    = xi[QK4_NL / 2 + j] * xi[QK4_NL / 2 + j];
+        sumqx += w0 * v0 * xi[j] + w1 * v1 * xi[QK4_NL / 2 + j];
+        sumq2 += w0 * v0 * v0 + w1 * v1 * v1;
+    }
+
+    dsti->d = sumq2 > 0 ? sumqx / sumq2 : d;
+}
+
 void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1);
 void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-#endif // GGML_SYCL_CPY_HPP
+#endif  // GGML_SYCL_CPY_HPP
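
Note on the block quantizers moved into cpy.hpp above: they all follow the same shape — scan the block for its extremum, derive a per-block scale d, then round every value onto the narrow integer grid and pack pairs of 4-bit quants into bytes. As a reference for the arithmetic only (not the SYCL code itself), here is a minimal host-side C++ sketch of the Q4_0 variant; the struct layout and helper names are illustrative stand-ins, since the real block_q4_0 stores its scale as ggml_half and lives in ggml's own headers.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative stand-in for ggml's block_q4_0 layout (assumed: one float scale + 16 packed nibbles).
constexpr int QK4_0 = 32;
struct q4_0_block {
    float   d;              // per-block scale
    uint8_t qs[QK4_0 / 2];  // two 4-bit quants per byte
};

// Host-side mirror of cpy_blck_f32_q4_0: scale by the signed max, round to [0, 15].
static void quantize_block_q4_0_ref(const float * x, q4_0_block * dst) {
    float amax = 0.0f, vmax = 0.0f;
    for (int j = 0; j < QK4_0; ++j) {
        if (std::fabs(x[j]) > amax) { amax = std::fabs(x[j]); vmax = x[j]; }
    }
    const float d  = vmax / -8;           // signed max maps to quant value 0
    const float id = d ? 1.0f / d : 0.0f;
    dst->d = d;
    for (int j = 0; j < QK4_0 / 2; ++j) {
        const int q0 = std::min(15, (int) (x[j] * id + 8.5f));
        const int q1 = std::min(15, (int) (x[QK4_0 / 2 + j] * id + 8.5f));
        dst->qs[j] = (uint8_t) (q0 | (q1 << 4));
    }
}

int main() {
    float x[QK4_0];
    for (int j = 0; j < QK4_0; ++j) x[j] = std::sin(0.3f * j);  // arbitrary test data
    q4_0_block blk;
    quantize_block_q4_0_ref(x, &blk);
    // dequantize the first pair ((q - 8) * d, as in ggml's q4_0) and compare to the inputs
    const float y0 = ((blk.qs[0] & 0x0f) - 8) * blk.d;
    const float y1 = ((blk.qs[0] >> 4)   - 8) * blk.d;
    std::printf("x0=% .4f -> % .4f, x16=% .4f -> % .4f\n", x[0], y0, x[QK4_0 / 2], y1);
    return 0;
}
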
diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp
index 5efe03d364b..dcf6c7aeeb4 100644
--- a/ggml/src/ggml-sycl/gemm.hpp
+++ b/ggml/src/ggml-sycl/gemm.hpp
@@ -32,39 +32,28 @@ class DnnlGemmWrapper {
         else static_assert(0);
     }
 
-    // matrix A has m rows, k columns
-    // matrix B has k rows, n columns
-    // nra - number of elements to skip when moving into next row in A
-    // nrb - number of elements to skip when moving into next row in B
-    // nca - number of elements to skip when moving into next column in A
-    // ncb - number of elements to skip when moving into next column in B
-    // stride_a - number of elements to skip when moving to next A matrix
-    // stride_b - number of elements to skip when moving to next B matrix
-    // batches_a - number of A matrices
-    // batches_b - number of B matrices
     static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
-        const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
-        const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
+        const void * a, dt at, dnnl_dim_t stra0, dnnl_dim_t stra1, dnnl_dim_t stra2,
+        const void * b, dt bt, dnnl_dim_t strb0, dnnl_dim_t strb1, dnnl_dim_t strb2,
         void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
 
         auto stream = ctx.stream_dnnl(q);
         auto eng = ctx.engine_dnnl(q);
 
-        // { # strides, # rows, # columns }
-        dnnl::memory::dims a_dims = { batches_a, m, k };
-        dnnl::memory::dims b_dims = { batches_b, k, n };
-        dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
-
-        // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
-        dnnl::memory::dims a_strides = { stride_a, nra, nca };
-        dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
-
+        dnnl::memory::dims a_dims = {batches_a, m, k };
+        dnnl::memory::dims a_strides = {stra2, stra1, stra0};
         const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
+
+        dnnl::memory::dims b_dims = {batches_b, k, n };
+        dnnl::memory::dims b_strides = {strb2, strb0, strb1};
         const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
-        const auto c_md    = dnnl::memory::desc(c_dims, ct, tag::abc);
 
+        dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n};
+        dnnl::memory::dims c_strides = {m*n, 1,  m };
+        const auto c_md    = dnnl::memory::desc(c_dims, ct, c_strides);
         dnnl::primitive_attr primitive_attr;
         primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+
 #ifdef GGML_SYCL_F16
         primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);
 #endif
@@ -76,24 +65,23 @@ class DnnlGemmWrapper {
 
         auto scratchpad_md = matmul_pd.scratchpad_desc();
         auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
+
         auto matmul_prim = dnnl::matmul(matmul_pd);
 
         std::unordered_map<int, dnnl::memory> matmul_args;
         matmul_args.insert({ DNNL_ARG_SRC, a_mem });
         matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+
         matmul_args.insert({ DNNL_ARG_DST, c_mem });
         matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
 
         matmul_prim.execute(stream, matmul_args);
     }
 
-    // matrices A and B are column major, both having k rows
-    // matrix A has m column, matrix B has n columns
-    // output: column major matrix C = A transposed * B
     static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
         const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
 
-        gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
+        gemm(ctx, m, n, k, a, at, 1, k, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
     }
 };
 
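
Note on the gemm.hpp rewrite above: DnnlGemmWrapper::gemm now receives raw element strides for A and B ({stra0, stra1, stra2} / {strb0, strb1, strb2}) and feeds them straight into the oneDNN memory descriptors, with C described explicitly as a column-major m x n matrix (c_strides = {m*n, 1, m}); row_gemm's stride choice therefore still realizes the old contract of C = A^T * B for column-major k x m A and k x n B. The standalone C++ sketch below only demonstrates that stride convention with a naive reference matmul; the name ref_batched_gemm and the test data are made up for illustration and no oneDNN calls are involved.

#include <cstdio>
#include <vector>

// Naive reference for the stride convention used by DnnlGemmWrapper::gemm.
// Logical dims are {batch, rows, cols}; each stride triplet gives the element
// step for {next batch, next row, next col} of the matrix handed to the matmul.
static void ref_batched_gemm(int m, int n, int k,
                             const float * a, long sa_batch, long sa_row, long sa_col,
                             const float * b, long sb_batch, long sb_row, long sb_col,
                             float * c /* column-major m x n per batch */, int batches) {
    for (int bt = 0; bt < batches; ++bt) {
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                float acc = 0.0f;
                for (int l = 0; l < k; ++l) {
                    acc += a[bt * sa_batch + i * sa_row + l * sa_col] *
                           b[bt * sb_batch + l * sb_row + j * sb_col];
                }
                c[bt * (long) m * n + j * m + i] = acc;  // mirrors c_strides = {m*n, 1, m}
            }
        }
    }
}

int main() {
    const int m = 2, n = 3, k = 4;
    // a: column-major k x m (k rows, m columns); b: column-major k x n
    std::vector<float> a(k * m), b(k * n), c(m * n);
    for (int i = 0; i < k * m; ++i) a[i] = (float) i;
    for (int i = 0; i < k * n; ++i) b[i] = (float) (i % 5);
    // Same strides row_gemm passes: A {batch = k*m, row = k, col = 1}, B {batch = n*k, row = 1, col = k}.
    ref_batched_gemm(m, n, k, a.data(), k * m, k, 1, b.data(), (long) n * k, 1, k, c.data(), 1);
    // c now holds C = A^T * B, column-major m x n
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) std::printf("%6.1f ", c[j * m + i]);
        std::printf("\n");
    }
    return 0;
}
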
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 65b26fd0276..a0a650e92e4 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -44,6 +44,7 @@
 #include "ggml-sycl/set_rows.hpp"
 #include "ggml-sycl/sycl_hw.hpp"
 #include "ggml-sycl/getrows.hpp"
+#include "ggml-sycl/quantize.hpp"
 #include "ggml.h"
 
 static bool g_sycl_loaded = false;
@@ -1373,120 +1374,6 @@ typedef void (*ggml_sycl_op_mul_mat_t)(
 
 
 
-template <int QUANT_BLOCK_TILE>
-static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
-
-    if (ix >= kx_padded) {
-        return;
-    }
-
-    const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                   item_ct1.get_local_id(1);
-
-    const int i_padded = iy*kx_padded + ix;
-
-    block_q8_1 * y = (block_q8_1 *) vy;
-
-    const int ib = i_padded / QK8_1; // block index
-    const int iqs = i_padded % QK8_1; // quant index
-    typedef  sycl::vec<float, QUANT_BLOCK_TILE> TC;
-    typedef  sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
-    TC zeros;
-    TQ qzeros;
-#pragma unroll
-    for (int i = 0; i < QUANT_BLOCK_TILE; i++)
-    {
-        zeros[i] = 0.f;
-        qzeros[i] = 0;
-    }
-    const TC xi = ix < kx ? *(const TC *)&x[iy * kx + ix] : zeros;
-    float sum = xi[0];
-    float amax = sycl::fabs(xi[0]);
-#pragma unroll
-    for (int i = 1; i < QUANT_BLOCK_TILE; i++)
-    {
-        sum += xi[i];
-        amax = sycl::fmax(sycl::fabs(xi[i]), amax);
-    }
-    sum = warp_reduce_sum(sum, item_ct1);
-    amax = warp_reduce_max(amax, item_ct1);
-
-    const float d = amax / 127;
-    TQ q = qzeros;
-    if (amax != 0.0f)
-    {
-#pragma unroll
-        for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
-            q[i] = sycl::round(xi[i] / d);
-        }
-    }
-
-    *(TQ *)&y[ib].qs[iqs] = q;
-
-    if (iqs > 0) {
-        return;
-    }
-
-    reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d;
-    reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
-}
-
-template <int ElementsPerWI>
-static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor,
-                                                      const int kx, const int kx_padded, const sycl::nd_item<1> & it) {
-    /*
-        Quantizes and reorders the resultant q8 tensor in a per row fashion
-        Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
-    */
-
-    auto subgroup_id = it.get_group(0);
-    auto wi_id       = it.get_local_id(0);
-
-    const int num_blocks_per_row = kx / QK8_1;
-    auto      row                = subgroup_id / num_blocks_per_row;
-    auto      col                = subgroup_id % num_blocks_per_row;
-
-    auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
-    auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
-
-    auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
-    auto ds_ptr    = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
-
-    sycl::vec<float, ElementsPerWI>  wi_f32_vals;
-    sycl::vec<int8_t, ElementsPerWI> quantized_values;
-
-    auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
-    wi_f32_vals           = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
-
-    float sum  = 0.0f;
-    float amax = 0.0f;
-
-#pragma unroll(ElementsPerWI)
-    for (int i = 0; i < ElementsPerWI; i++) {
-        sum += wi_f32_vals[i];
-        amax                = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
-        quantized_values[i] = 0;
-    }
-    sum     = sycl::reduce_over_group(it.get_group(), sum, sycl::plus<float>());
-    amax    = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum<float>());
-    float d = amax == 0 ? 1 : amax / 127;
-
-#pragma unroll(ElementsPerWI)
-    for (int i = 0; i < ElementsPerWI; i++) {
-        quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
-    }
-
-    d = amax == 0 ? 0 : d;
-
-    *reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
-    if (wi_id == 0) {
-        *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
-    }
-}
-
 static void mul_mat_p021_f16_f32(
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
@@ -1546,7 +1433,7 @@ static void mul_mat_p021_f16_f32(
 
 static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
+    const int row_stride_x, const int channel_stride_x,const int channel_stride_y, const int channel_x_divisor,
     const sycl::nd_item<3> &item_ct1) {
 
     const sycl::half *x = (const sycl::half *)vx;
@@ -1557,7 +1444,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
                         item_ct1.get_local_id(0);
     const int channel_x = channel / channel_x_divisor;
 
-    const int nrows_y   = ncols_x;
     const int nrows_dst = nrows_x;
     const int row_dst   = row_x;
 
@@ -1576,7 +1462,7 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
         const int row_y = col_x;
 
         const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const int iy = channel*nrows_y + row_y;
+        const int iy = channel * channel_stride_y + row_y;
 
         const float xi =
             sycl::vec<sycl::half, 1>(x[ix])
@@ -1771,32 +1657,6 @@ static  void pool2d_nchw_kernel(
         o_ptr[cur_oh * ow + cur_ow] = res;
 }
 
-static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
-                                   bool reorder_q8_tensor, queue_ptr stream) {
-    if (reorder_q8_tensor) {
-        auto local_range      = std::size_t(WARP_SIZE);
-        auto num_quant_blocks = ky * (kx / QK8_1);
-        auto global_range     = num_quant_blocks * local_range;
-        stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
-                             [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                 quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
-                             });
-    } else {
-        const int            block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
-        const sycl::range<3> num_blocks(1, ky, block_num_x);
-        int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
-        static_assert(QK8_1 % WARP_SIZE == 0);
-        const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
-        {
-            dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
-
-            stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size),
-                                 [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                     quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
-                                 });
-        }
-    }
-}
 
 static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
                                            float *dst, const int ncols_x,
@@ -1823,7 +1683,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
 static void ggml_mul_mat_vec_nc_f16_f32_sycl(
     const void *vx, const float *y, float *dst, const int ncols_x,
     const int nrows_x, const int row_stride_x, const int nchannels_x,
-    const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
+    const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) {
 
     const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -1835,7 +1695,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
-                                       row_stride_x, channel_stride_x,
+                                       row_stride_x, channel_stride_x, channel_stride_y,
                                        nchannels_y / nchannels_x, item_ct1);
             });
     }
@@ -2124,8 +1984,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
-            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
-                                      DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+                DnnlGemmWrapper::row_gemm(ctx,row_diff, src1_ncols , ne10, src0_ptr,
+                                     DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
                                       dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
         }
         else
@@ -2171,8 +2031,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
-            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
-                                      DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+            DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i,
+                                      DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
                                       dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
         }
         else
@@ -2373,10 +2233,10 @@ static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
     peer_access_enabled = enable_peer_access;
 }
 
+template