
[ARM] Integrate INT4→BF16 via KleidiAI, with fallback #158250


Draft

usamahz wants to merge 4 commits into base: main from integrate/int4-bf16-kleidiai

Conversation


@usamahz usamahz commented Jul 14, 2025

Co-authored-by: Nikhil Gupta nikhil.gupta2@arm.com

This PR enables the use of KleidiAI INT4 kernels that directly produce BF16 outputs within PyTorch.

✅ Key Results
• Integration of KleidiAI direct support for INT4→BF16 kernel execution.
• Kernels exposed in PyTorch: INT4 channelwise kernels now support BF16 output when available.

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @malfet @snadampal @milpuz01 @aditew01 @nikhil-arm @fadara01 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com>
@pytorch-bot pytorch-bot bot added the module: cpu (CPU-specific problem, e.g. perf, algorithm) and release notes: linalg_frontend (release notes category) labels on Jul 14, 2025

pytorch-bot bot commented Jul 14, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158250

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 49979eb with merge base ffaed8c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


usamahz commented Jul 14, 2025

@pytorchbot label "module: arm"

@pytorch-bot pytorch-bot bot added the module: arm (related to ARM architecture builds of PyTorch; includes Apple M1) label on Jul 14, 2025
if (cpuinfo_has_arm_bf16()) {
kleidiai::kai_quant_pack_lhs_int4_mm_bf16_channelwise(
output, input, weight, m, n, k);
} else {
Collaborator

There is no need to fall back to fp32. If the platform does not support the Arm BF16 vector extension, or the platform is not Arm, fall back to the BF16 scalar reference implementation.
If the BF16 scalar path is not supported, error out.
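A minimal sketch of the dispatch order suggested here, reusing the KleidiAI call from the diff above; the capability predicate and the reference kernel's argument list are assumptions for illustration, not the PR's actual code:

if (cpuinfo_has_arm_bf16()) {
  // Fast path: KleidiAI INT4 x BF16 channelwise kernel using the BF16 vector extension.
  kleidiai::kai_quant_pack_lhs_int4_mm_bf16_channelwise(
      output, input, weight, m, n, k);
} else if (can_run_bf16_scalar_reference) {  // assumed capability check, not a real API
  // Portable path: scalar reference kernel that still produces BF16 output.
  ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
      output, input, weight, m, n, k);
} else {
  TORCH_CHECK(false, "INT4 -> BF16 matmul is not supported on this platform");
}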

Author

Done - 48e16e1

@@ -217,6 +237,17 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
const int64_t m,
const int64_t n,
const int64_t k) {

at::Tensor input_fp32;
Collaborator

This code is not needed, as we are not falling back to the fp32 path. We can keep the existing fp32 path untouched.

Author

Done - Removed

@@ -13,6 +15,19 @@

namespace at::native::kleidiai {

at::Tensor kleidi_bf16_to_fp32(const at::Tensor& src) {
Collaborator

Not required; there is no fallback to fp32.

Author

Done - Removed

@@ -5,6 +5,8 @@

namespace at::native::kleidiai {

at::Tensor kleidi_bf16_to_fp32(const at::Tensor& src);
Collaborator

Remove this

Author

Done - Removed

3 // Channelwise 4 bit GEMM
3, // Channelwise 4 bit GEMM
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod =
4, // Channelwise 4 bit GEMV
Collaborator

Please fix the comments and explain that this is a BF16 INT4 GEMM and how it differs from the other kernels.
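For illustration, the clarified comment could read something like this (the enum name and value are copied from the diff above; the exact wording landed in the PR may differ):

matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod =
    4, // Channelwise 4-bit kernel with BF16 output: same dynamically quantized INT8 LHS and
       // channelwise INT4 RHS as the FP32 entries above, but the result is written as BF16.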

Author

Done

}
} // namespace at::native::kleidiai
#endif
#endif
Collaborator

Add a newline at the end of the file.

Author

Done

uint8_t* dst_act_mtx_f32 = reinterpret_cast<uint8_t*>(output.data_ptr());
const uint8_t* lhs_native_mtx_f32 =
reinterpret_cast<const uint8_t*>(input.data_ptr());
TORCH_CHECK(input_fp32.scalar_type() == at::kFloat, "Input tensor must be float.");
Collaborator

Remove the checks and warnings from the performance-critical code.

Author

Done

@@ -238,13 +269,17 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(

const size_t lhs_packed_size =
kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
const size_t padding = 128; // extra bytes
auto lhs_packed_tensor = at::empty({(int64_t)(lhs_packed_size + padding)}, at::kByte);
Collaborator

Why are we using at::empty instead of the original unique array?

Author

at::empty handles allocation, alignment, padding safety, and cleanup, and it can reuse memory internally. That gives memory safety and proper integration with PyTorch's allocator, preventing memory corruption.
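A small standalone sketch of the two allocation strategies being compared, assuming only the lhs_packed_size and 128-byte padding shown in the diff above:

#include <ATen/ATen.h>
#include <cstddef>
#include <cstdint>
#include <memory>

void allocate_lhs_packed(size_t lhs_packed_size) {
  const size_t padding = 128;  // extra bytes, as in the diff above

  // Raw owning array: the caller must reason about alignment and lifetime on its own.
  auto lhs_packed_raw = std::make_unique<uint8_t[]>(lhs_packed_size + padding);

  // Tensor-backed buffer: ATen's allocator provides aligned storage and automatic cleanup,
  // and the storage participates in PyTorch's internal memory reuse.
  at::Tensor lhs_packed_tensor =
      at::empty({static_cast<int64_t>(lhs_packed_size + padding)}, at::kByte);
  uint8_t* lhs_packed = lhs_packed_tensor.data_ptr<uint8_t>();

  (void)lhs_packed_raw;  // silence unused-variable warnings in this standalone sketch
  (void)lhs_packed;
}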

Co-author: Nikhil Gupta <nikhil.gupta2@arm.com>
float scalar_min,
float scalar_max) {

std::unique_ptr<int8_t[]> lhs_quantized(new int8_t[m * (k + sizeof(float) + sizeof(int32_t))]);
Contributor

are you adding elements to bytes here?

Author

Yeah, it's just adding extra bytes per row: 4 bytes for the scale (float) and 4 for the zero point (int32).

Since int8_t is 1 byte, we're still just allocating the total number of bytes needed. We're not mixing types, just accounting for all the data in raw byte space.
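For concreteness, with assumed example sizes the allocation works out like this (m and k here are made up for the sketch; in the kernel they come from the caller):

#include <cstddef>
#include <cstdint>
#include <memory>

const size_t m = 4, k = 64;  // example dimensions, not taken from the PR
const size_t bytes_per_row =
    k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);  // 64 + 4 + 4 = 72 bytes
const size_t total_bytes = m * bytes_per_row;              // 4 * 72 = 288 bytes
// int8_t is one byte, so allocating "elements" here is the same as allocating bytes.
std::unique_ptr<int8_t[]> lhs_quantized(new int8_t[total_bytes]);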

scalar_max);
}

static void ref_dyn_quant_matmul_4bit_groupwise_kernel_bf16(
Collaborator

This is not needed, as we are only adding the BF16 channelwise kernel?

Author

Cool, I'll remove it.

return result;
}

static inline uint16_t kai_cast_bf16_f32(float val) {
Collaborator

Remove kai_* prefix from functions in the reference code

Author

done

@@ -793,6 +793,215 @@ bool can_use_kleidiai(
}
#endif

static inline size_t roundup(size_t a, size_t b) {
Collaborator

Can we remove these helper functions from here and perform these operations within the reference kernel function?

Author

Done - integrated internally
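As a tiny illustration of folding the helper in, the round-up just becomes a local expression inside the reference kernel (the stride and alignment values below are assumptions for the sketch, not the PR's actual constants):

// roundup(a, b) == ((a + b - 1) / b) * b, written inline instead of as a file-level helper.
const size_t rhs_stride = (k + 1) / 2;                           // e.g. two INT4 values per byte
const size_t rhs_stride_padded = ((rhs_stride + 31) / 32) * 32;  // roundup(rhs_stride, 32)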

}
}

static void ref_quant_qa8dx_bf16(size_t m, size_t k, const uint16_t* lhs_native_mtx_bf16, int8_t* lhs_ref_mtx_qa8dx) {
Collaborator

Can you please move this inside the reference kernel function, like this:

auto input_quant_pack_8bit_channelwise =

Author

Done - added as a lambda
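A rough sketch of that refactor, with the quantization routine becoming a lambda local to the reference kernel; the parameter list is taken from the ref_quant_qa8dx_bf16 declaration above, and the body is only summarized:

auto input_quant_pack_8bit_channelwise =
    [&](size_t m, size_t k, const uint16_t* lhs_native_mtx_bf16, int8_t* lhs_ref_mtx_qa8dx) {
      // Per-row min/max scan, scale and zero-point computation, then each row's int8 values
      // are stored alongside its float scale and int32 zero point (as in the diff above).
    };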

}
}

static void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
Collaborator

@nikhil-arm commented Jul 21, 2025

Can we limit the whole reference kernel to one function instead of dividing it into smaller ones?
The kernel can effectively become ref_quant_qa8dx_bf16 + ref_matmul_mxn_mxk_nxk_bf16_qa8dx_qs4cx.

We will move the LHS quant logic out later to avoid code duplication when we add the BF16-INT4 groupwise kernel, as well as for the F32-INT4 kernels.

Author

Done - Refactored
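The consolidated kernel then has roughly this shape; the parameter list is an assumption pieced together from the declarations visible in the diff, and the two numbered steps summarize the bodies rather than reproducing them:

static void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
    size_t m, size_t n, size_t k,
    const uint16_t* lhs_bf16, const uint8_t* rhs_qs4cx, const float* rhs_scales,
    const float* bias, uint16_t* dst_bf16, float scalar_min, float scalar_max) {
  // 1) Dynamically quantize the BF16 LHS row by row to int8 with a per-row scale and
  //    zero point (the logic formerly in ref_quant_qa8dx_bf16, now a local lambda).
  // 2) For each output element, accumulate the int8 x int4 channelwise dot product,
  //    dequantize with the per-row and per-column scales, add the bias, clamp to
  //    [scalar_min, scalar_max], and store the result as BF16
  //    (the logic formerly in ref_matmul_mxn_mxk_nxk_bf16_qa8dx_qs4cx).
}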

@robert-hardwick
Collaborator

@pytorchbot label "ciflow/linux-aarch64"

@pytorch-bot pytorch-bot bot added the ciflow/linux-aarch64 (linux aarch64 CI workflow) label on Jul 21, 2025

pytorch-bot bot commented Jul 21, 2025

To add the ciflow label ciflow/linux-aarch64 please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed the ciflow/linux-aarch64 (linux aarch64 CI workflow) label on Jul 21, 2025

usamahz commented Jul 21, 2025

@pytorchbot label "ciflow/linux-aarch64"

@pytorch-bot pytorch-bot bot added the ciflow/linux-aarch64 (linux aarch64 CI workflow) label on Jul 21, 2025
@usamahz usamahz changed the title from "Integrate INT4→BF16 via KleidiAI, with fallback to FP32" to "Integrate INT4→BF16 via KleidiAI, with fallback" on Jul 21, 2025
@pytorch-bot pytorch-bot bot removed the ciflow/linux-aarch64 (linux aarch64 CI workflow) label on Jul 23, 2025
Comment on lines +811 to +824
// Cast bfloat16 to float32 inline
auto cast_bf16_to_f32 = [](uint16_t bf16_val) {
uint32_t tmp = static_cast<uint32_t>(bf16_val) << 16;
float f;
std::memcpy(&f, &tmp, sizeof(f));
return f;
};

// Cast float32 to bfloat16 inline
auto cast_f32_to_bf16 = [](float f) {
uint32_t bits;
std::memcpy(&bits, &f, sizeof(bits));
return static_cast<uint16_t>(bits >> 16);
};
Collaborator

I wonder if we can make use of the vectorized class convert for the cast here. It might turn out that autovectorization does a good enough job, but it's a possible perf improvement.

Author

Since this is the scalar reference implementation, meant to run on CPUs without SVE/NEON, I avoided vectorization intentionally. The focus here is correctness and portability, not performance. This inline cast is simple and safe for platforms without a vector ISA.
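As a quick standalone check of the two casts in the diff above (input value chosen purely for illustration), the f32→bf16 direction truncates the low 16 mantissa bits, so the round trip is slightly lossy:

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  auto cast_f32_to_bf16 = [](float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);  // keep sign, exponent, top 7 mantissa bits
  };
  auto cast_bf16_to_f32 = [](uint16_t bf16_val) {
    uint32_t tmp = static_cast<uint32_t>(bf16_val) << 16;
    float f;
    std::memcpy(&f, &tmp, sizeof(f));
    return f;
  };
  const float x = 1.2345678f;
  std::printf("%.7f -> %.7f\n", x, cast_bf16_to_f32(cast_f32_to_bf16(x)));  // 1.2345678 -> 1.2343750
  return 0;
}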

float rmax = std::max(0.0f, mx);
const float qmin = static_cast<float>(INT8_MIN);
const float qmax = static_cast<float>(INT8_MAX);
float scale = (rmin == rmax) ? 1.f : (qmax - qmin) / (rmax - rmin);
Collaborator

What happens when rmax == rmin? (We should catch this.)

Collaborator

It sets scale to 1. If rmin and rmax are equal, there is no need for scaling in quantization, as the tensor has the same value throughout.

Collaborator

Just to double-check: can we guarantee this check (rmin == rmax) is safe when dealing with float values that are really close to each other?

float err_min = qmin + des_min;
float err_max = qmax + des_max;
float zp_f = (err_min + err_max) > 0
? qmin - des_min
Collaborator

Is this a standard way of calculating zero points?

Collaborator

Yes, we are using this logic in the KleidiAI reference kernels.
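For context, the complete per-row scale and zero-point computation that this hunk belongs to looks roughly like the following in KleidiAI-style reference code (a sketch only: rmin/rmax are the row extrema from the surrounding code, <algorithm>, <cmath>, and <cstdint> are assumed to be included, and the PR's exact code may differ):

const float qmin = static_cast<float>(INT8_MIN);  // -128
const float qmax = static_cast<float>(INT8_MAX);  //  127
const float scale = (rmin == rmax) ? 1.f : (qmax - qmin) / (rmax - rmin);

// Row extrema expressed in quantized units.
const float des_min = rmin * scale;
const float des_max = rmax * scale;

// Pick the zero point from whichever endpoint error dominates, clamp it into the
// representable int8 range, and round to the nearest integer.
const float err_min = qmin + des_min;
const float err_max = qmax + des_max;
float zp_f = (err_min + err_max) > 0 ? qmin - des_min : qmax - des_max;
zp_f = std::max(zp_f, qmin);
zp_f = std::min(zp_f, qmax);
const int32_t zero_point = static_cast<int32_t>(std::lrintf(zp_f));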

@usamahz usamahz changed the title from "Integrate INT4→BF16 via KleidiAI, with fallback" to "[ARM] Integrate INT4→BF16 via KleidiAI, with fallback" on Jul 25, 2025
@nikhil-arm
Collaborator

nikhil-arm commented Aug 5, 2025

kai_pack_int4_rhs in kai_kernels.cpp is not specialised for the bf16 case. Even if the rhs kernel is the same, we need to use the pack size and packing from the correct struct.
The same applies to torch/_meta_registrations.py: getting the KleidiAI packed size in Python is not specialized for the bf16 case.

kai_pack_rhs_int4_size is also not specialized for the bf16 case.
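For illustration only, "specialising" the packed-size query could look roughly like this; the packet accessors and struct members below are assumptions for the sketch, not the actual kai_kernels.cpp API:

size_t kai_pack_rhs_int4_size(int64_t n, int64_t k, at::ScalarType out_dtype) {
  // Pick the kernel packet that matches the requested output dtype so both the packed
  // size and the packing routine come from the same (correct) struct.
  const auto& packet = (out_dtype == at::kBFloat16)
      ? channelwise_packet_bf16()   // hypothetical accessor for the BF16 kernel packet
      : channelwise_packet_f32();   // hypothetical accessor for the existing FP32 packet
  return packet.kai_get_rhs_packed_size(n, k, packet.nr, packet.kr, packet.sr);
}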

Add Bias in BF16 Reference Kernels

Refactoring

Update Comments
@usamahz usamahz force-pushed the integrate/int4-bf16-kleidiai branch from dcb6d4a to 49979eb on August 12, 2025 at 13:17
Labels
module: arm, module: cpu, module: inductor, open source, release notes: linalg_frontend

7 participants