[ARM] Integrate INT4→BF16 via KleidiAI, with fallback #158250
Conversation
Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com>
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158250
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
✅ No Failures as of commit dcb6d4a with merge base ffaed8c.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "module: arm"
if (cpuinfo_has_arm_bf16()) {
  kleidiai::kai_quant_pack_lhs_int4_mm_bf16_channelwise(
      output, input, weight, m, n, k);
} else {
There is no need to fall back to fp32. If the platform does not support the Arm BF16 vector extension, or is not Arm at all, fall back to the BF16 scalar reference implementation.
If the BF16 scalar path is not supported, error out.
Done - 48e16e1
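For readers skimming the thread, here is a minimal sketch of the dispatch order agreed on above. The KleidiAI entry point is the one shown in the diff; the reference-kernel signature and the wrapper name are illustrative assumptions, not the PR's actual code.

#include <ATen/ATen.h>
#include <cpuinfo.h>

// Assumed declarations (signatures illustrative): the KleidiAI BF16 fast path
// from the diff above and the scalar BF16 reference kernel added by this PR.
namespace at::native::kleidiai {
void kai_quant_pack_lhs_int4_mm_bf16_channelwise(
    at::Tensor& output, const at::Tensor& input, const at::Tensor& weight,
    int64_t m, int64_t n, int64_t k);
} // namespace at::native::kleidiai
void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
    at::Tensor& output, const at::Tensor& input, const at::Tensor& weight,
    int64_t m, int64_t n, int64_t k);

// Sketch of the dispatch policy: use the vector BF16 kernel when the CPU
// supports the Arm BF16 extension, otherwise fall back to the portable
// BF16 scalar reference implementation (error out only if neither is usable).
void int4_bf16_mm_dispatch(
    at::Tensor& output, const at::Tensor& input, const at::Tensor& weight,
    int64_t m, int64_t n, int64_t k) {
  if (cpuinfo_has_arm_bf16()) {
    at::native::kleidiai::kai_quant_pack_lhs_int4_mm_bf16_channelwise(
        output, input, weight, m, n, k);
  } else {
    ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
        output, input, weight, m, n, k);
  }
}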
@@ -217,6 +237,17 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
    const int64_t m,
    const int64_t n,
    const int64_t k) {

  at::Tensor input_fp32;
This code is not needed, as we are not falling back to the fp32 path. We can keep the existing fp32 path untouched.
Done - Removed
@@ -13,6 +15,19 @@
namespace at::native::kleidiai {

at::Tensor kleidi_bf16_to_fp32(const at::Tensor& src) {
Not required. There is no fallback to fp32.
Done - Removed
@@ -5,6 +5,8 @@
namespace at::native::kleidiai {

at::Tensor kleidi_bf16_to_fp32(const at::Tensor& src);
Remove this
Done - Removed
3 // Channelwise 4 bit GEMM
3, // Channelwise 4 bit GEMM
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod =
    4, // Channelwise 4 bit GEMV
Please fix the comments and explain that this is the BF16 INT4 GEMM and how it differs from the other kernels.
Done
}
} // namespace at::native::kleidiai
#endif
#endif
Add a newline at the end of the file.
Done
  uint8_t* dst_act_mtx_f32 = reinterpret_cast<uint8_t*>(output.data_ptr());
  const uint8_t* lhs_native_mtx_f32 =
      reinterpret_cast<const uint8_t*>(input.data_ptr());
  TORCH_CHECK(input_fp32.scalar_type() == at::kFloat, "Input tensor must be float.");
Remove the checks and warnings from the performance-critical code.
Done
@@ -238,13 +269,17 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(

  const size_t lhs_packed_size =
      kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
  auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
  const size_t padding = 128; // extra bytes
  auto lhs_packed_tensor = at::empty({(int64_t)(lhs_packed_size + padding)}, at::kByte);
Why are we using at::empty instead of the original unique array?
at::empty handles allocation, alignment, padding safety, and cleanup, and it can reuse memory internally, which gives us memory safety and PyTorch integration and helps prevent memory corruption.
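For context, a small sketch of the at::empty-based allocation discussed here; the size and the 128-byte padding come from the diff above, while the wrapper function is just illustrative.

#include <ATen/ATen.h>
#include <cstdint>

// Sketch only: allocating the packed-LHS scratch buffer through ATen instead of
// a std::make_unique<uint8_t[]>(lhs_packed_size) raw buffer. at::empty gives us
// allocator-managed lifetime and alignment, plus a small tail padding.
at::Tensor allocate_lhs_scratch(size_t lhs_packed_size) {
  const size_t padding = 128; // extra bytes, as in the diff above
  at::Tensor lhs_packed_tensor =
      at::empty({static_cast<int64_t>(lhs_packed_size + padding)}, at::kByte);
  // The kernel would then read/write through lhs_packed_tensor.data_ptr<uint8_t>().
  return lhs_packed_tensor;
}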
    float scalar_min,
    float scalar_max) {

  std::unique_ptr<int8_t[]> lhs_quantized(new int8_t[m * (k + sizeof(float) + sizeof(int32_t))]);
Are you adding element counts to byte counts here?
Yeah, it's just adding extra bytes per row: 4 bytes for the scale (float) and 4 for the zero point (int32).
Since int8_t is 1 byte, we're still just allocating the total number of bytes needed; we're not mixing types, just accounting for all the data in raw byte space.
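A tiny sketch of that size computation, under the layout described above (k int8 values plus one float scale and one int32 zero point per row); the helper name is illustrative.

#include <cstddef>
#include <cstdint>

// Sketch only: total bytes for the dynamically quantized LHS. Each of the m rows
// carries k int8 values plus a 4-byte float scale and a 4-byte int32 zero point,
// all counted in raw bytes.
size_t quantized_lhs_bytes(size_t m, size_t k) {
  const size_t row_bytes = k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
  return m * row_bytes; // matches m * (k + sizeof(float) + sizeof(int32_t))
}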
      scalar_max);
}

static void ref_dyn_quant_matmul_4bit_groupwise_kernel_bf16(
This isn't needed, since we are only adding the BF16 channelwise kernel?
Cool, I'll remove it.
  return result;
}

static inline uint16_t kai_cast_bf16_f32(float val) {
Remove kai_* prefix from functions in the reference code
done
@@ -793,6 +793,215 @@ bool can_use_kleidiai(
}
#endif

static inline size_t roundup(size_t a, size_t b) {
Can we remove these helper functions from here and perform these operations within the reference kernel function?
Done - integrated internally
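For reference, the helper being folded into the kernel is presumably the usual round-up-to-a-multiple idiom; its body is not shown in the diff, so the sketch below is an assumption.

#include <cstddef>

// Sketch only: round a up to the nearest multiple of b (b > 0).
static inline size_t roundup(size_t a, size_t b) {
  return ((a + b - 1) / b) * b;
}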
  }
}

static void ref_quant_qa8dx_bf16(size_t m, size_t k, const uint16_t* lhs_native_mtx_bf16, int8_t* lhs_ref_mtx_qa8dx) {
Can you please move this inside the reference kernel function, like this:
auto input_quant_pack_8bit_channelwise =
Done - added as a lambda
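A rough sketch of that refactor: the free function ref_quant_qa8dx_bf16 becomes a lambda inside the reference kernel. Parameter lists and the body are abbreviated and illustrative, not the PR's actual code.

#include <cstddef>
#include <cstdint>

// Sketch only: the LHS quantization helper captured as a lambda inside the
// reference kernel, as suggested above.
static void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(/* ... */) {
  auto input_quant_pack_8bit_channelwise =
      [](size_t m, size_t k,
         const uint16_t* lhs_native_mtx_bf16,
         int8_t* lhs_ref_mtx_qa8dx) {
        // per-row dynamic int8 quantization of the BF16 LHS
        // (same logic as the former ref_quant_qa8dx_bf16 free function)
      };
  // ... later in the kernel:
  // input_quant_pack_8bit_channelwise(m, k, lhs_bf16, lhs_qa8dx);
}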
  }
}

static void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
Can we limit the whole reference kernel to a single function instead of dividing it into smaller ones?
The kernel can effectively become ref_quant_qa8dx_bf16 + ref_matmul_mxn_mxk_nxk_bf16_qa8dx_qs4cx.
We will move the LHS quant logic out later to avoid code duplication when we add the BF16-INT4 groupwise kernel as well as the F32-INT4 kernels.
Done - Refactored
@pytorchbot label "ciflow/linux-aarch64"
To add the ciflow label, the workflows awaiting approval must first be approved. This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.
@pytorchbot label "ciflow/linux-aarch64"
Can we please add test cases for the bf16 data type?
  // Cast bfloat16 to float32 inline
  auto cast_bf16_to_f32 = [](uint16_t bf16_val) {
    uint32_t tmp = static_cast<uint32_t>(bf16_val) << 16;
    float f;
    std::memcpy(&f, &tmp, sizeof(f));
    return f;
  };

  // Cast float32 to bfloat16 inline
  auto cast_f32_to_bf16 = [](float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16);
  };
I wonder if we can make use of the vectorized class convert for the cast here. It might turn out that auto-vectorization does a good enough job, but it's a possible perf improvement.
Since this is the scalar reference implementation, meant to run on CPUs without SVE/NEON, I avoided vectorization intentionally. The focus here is correctness and portability, not performance. This inline cast is simple and safe for platforms without a vector ISA.
  float rmax = std::max(0.0f, mx);
  const float qmin = static_cast<float>(INT8_MIN);
  const float qmax = static_cast<float>(INT8_MAX);
  float scale = (rmin == rmax) ? 1.f : (qmax - qmin) / (rmax - rmin);
What happens when rmax == rmin? (We should catch this.)
It sets scale to 1. If rmin and rmax are equal, there is no need for scaling in quantization, since the tensor has the same value throughout.
Just to double-check: can we guarantee this check (rmin == rmax) is safe when dealing with float values that are really close to each other?
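For context, a minimal sketch of the per-row scale computation being discussed, including the degenerate-range guard the question above refers to. rmin is assumed to be computed symmetrically as std::min(0.0f, mn); only rmax appears in the quoted snippet, and the wrapper function is illustrative.

#include <algorithm>
#include <cstdint>

// Sketch only: per-row int8 quantization scale factor. If the (zero-extended)
// range collapses to a single point, the scale falls back to 1.0f, since every
// element quantizes to the same value and no scaling is needed.
float compute_row_scale(float mn, float mx) {
  const float rmin = std::min(0.0f, mn); // assumed; not shown in the snippet
  const float rmax = std::max(0.0f, mx);
  const float qmin = static_cast<float>(INT8_MIN); // -128
  const float qmax = static_cast<float>(INT8_MAX); //  127
  return (rmin == rmax) ? 1.f : (qmax - qmin) / (rmax - rmin);
}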
  float err_min = qmin + des_min;
  float err_max = qmax + des_max;
  float zp_f = (err_min + err_max) > 0
      ? qmin - des_min
Is this a standard way of calculating zero points?
Yes, we use this logic for the KleidiAI reference kernels.
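To spell out the full logic the quoted lines belong to, here is a sketch of the standard asymmetric-quantization zero-point selection: pick the candidate with the smaller rounding error, then clamp and round. des_min and des_max are the descaled range endpoints from the surrounding code; the clamp-and-round tail is an assumption based on common reference-kernel practice rather than quoted from this diff.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Sketch only: choose the zero point from whichever end of the range gives the
// smaller error, then clamp to the int8 range and round to the nearest integer.
int32_t compute_zero_point(float des_min, float des_max, float qmin, float qmax) {
  const float err_min = qmin + des_min; // error if zp is anchored at the min
  const float err_max = qmax + des_max; // error if zp is anchored at the max
  float zp_f = (err_min + err_max) > 0 ? qmin - des_min : qmax - des_max;
  zp_f = std::max(zp_f, qmin);
  zp_f = std::min(zp_f, qmax);
  return static_cast<int32_t>(std::lrintf(zp_f));
}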
kai_pack_int4_rhs in kai_kernels.cpp is not specialized for the bf16 case. Even if the RHS kernel is the same, we need to use the pack size and packing from the correct struct. kai_pack_rhs_int4_size is also not specialized for the bf16 case.
Co-authored-by: Nikhil Gupta nikhil.gupta2@arm.com
This PR enables the use of KleidiAI INT4 kernels that directly produce BF16 outputs within PyTorch.
✅ Key Results
• Integrated KleidiAI's direct support for INT4→BF16 kernel execution.
• Kernels exposed in PyTorch: INT4 channelwise kernels now support BF16 output when available.
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168 @malfet @snadampal @milpuz01 @aditew01 @nikhil-arm @fadara01 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben