Enable explicitly vectorized _weight_int8pack_mm op for FP16 dtype on x86_64 CPU #146777
base: main
Conversation
Merge AVX512 & AVX2 implementations of vectorized int8 WoQ GEMM for CPU, and make it common for Float/BFloat16/Half dtypes.

TODO:
- [ ] Run CI for all devices
- [ ] Check for regressions
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146777. Note: links to docs will display an error until the docs builds have been completed. ❗ There is 1 currently active SEV; if your PR is affected, please view it below. ✅ No failures as of commit 7792c8d with merge base cc444e7. This comment was automatically generated by Dr. CI and updates every 15 minutes.
LGTM
@pytorchbot merge
Merge failed. Reason: Approvers from one of the following sets are needed:
Hi @malfet, can you please help review & land this PR? Thank you!
@pytorchbot rebase -b main
@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here
Rebase failed due to Command
Raised by https://github.com/pytorch/pytorch/actions/runs/15487849970
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as `Stale`.
Summary
Currently, _weight_int8pack_mm is only explicitly vectorized for BF16 activations on x86_64 CPU, and it has separate AVX2 & AVX512 implementations. This PR unifies the AVX512 & AVX2 implementations and makes them common to Float/BFloat16/Half activation dtypes, which is feasible since compute & accumulation happen in FP32 even for FP16/BF16 activations.
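For illustration only (this snippet is not from the PR), a minimal sketch of calling the private op with FP16 activations on CPU, assuming the signature `torch._weight_int8pack_mm(activation, int8_weight, scales)` with an `[M, K]` activation, an `[N, K]` int8 weight, and `[N]` per-output-channel scales whose dtype matches the activation:

```python
import torch

# Hypothetical usage sketch, not taken from the PR.
M, K, N = 4, 4096, 4096

a = torch.randn(M, K, dtype=torch.float16)               # FP16 activations
w = torch.randint(-128, 127, (N, K), dtype=torch.int8)   # int8 weight, one row per output channel
scales = torch.rand(N, dtype=torch.float16)              # per-output-channel scales

# Compute & accumulation happen in FP32 internally; the result is cast back
# to the activation dtype.
out = torch._weight_int8pack_mm(a, w, scales)

# FP32 reference for a rough sanity check.
ref = (a.float() @ (w.float() * scales.float().unsqueeze(1)).t()).to(a.dtype)
print(out.shape, torch.allclose(out.float(), ref.float(), atol=1e-2, rtol=1e-2))
```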
Most of the code added in this PR has been copy-pasted from the Inductor-CPP FP32 GEMM micro-kernel template (so, credit to the original authors).
No performance regression was observed. The input shapes (M, N, K) benchmarked were:
[1, 4096, 4096], [1, 4096, 11008], [1, 11008, 4096], [4, 4096, 4096], [4, 4096, 11008], [4, 11008, 4096], [1, 4096, 14336], [1, 14336, 4096], [4, 4096, 14336], [4, 14336, 4096]
Intel OpenMP & tcmalloc were preloaded for benchmarking.
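A rough benchmarking sketch over the shapes above (this is not the PR's actual harness) could use `torch.utils.benchmark`; Intel OpenMP & tcmalloc would be preloaded, e.g. via `LD_PRELOAD`, before launching Python:

```python
import torch
from torch.utils import benchmark

# (M, N, K) shapes from the list above.
shapes = [
    (1, 4096, 4096), (1, 4096, 11008), (1, 11008, 4096),
    (4, 4096, 4096), (4, 4096, 11008), (4, 11008, 4096),
    (1, 4096, 14336), (1, 14336, 4096),
    (4, 4096, 14336), (4, 14336, 4096),
]

for M, N, K in shapes:
    a = torch.randn(M, K, dtype=torch.float16)
    w = torch.randint(-128, 127, (N, K), dtype=torch.int8)
    s = torch.rand(N, dtype=torch.float16)
    timer = benchmark.Timer(
        stmt="torch._weight_int8pack_mm(a, w, s)",
        globals={"torch": torch, "a": a, "w": w, "s": s},
    )
    m = timer.blocked_autorange(min_run_time=1.0)
    print(f"M={M} N={N} K={K}: {m.median * 1e6:.1f} us")
```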
Now the non-vectorized (not explicitly vectorized) micro-kernel would only be used when:
1. `ATEN_CPU_CAPABILITY` is `default` (a quick runtime check is sketched after this list).
2. The build is an MSVC build for x86_64 CPUs.
3. The build is an aarch64 build with `C10_MOBILE` true? Not sure if such builds exist on PyTorch CI.

cc @jgong5 @mingfeima @XiaobingSuper @ashokei @jingxu10 @jerryzh168
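As an aside (assumption: `torch.backends.cpu.get_cpu_capability()` is available, i.e. a recent PyTorch build), a quick way to see which SIMD level the current ATen build dispatches to, and hence whether the explicitly vectorized path would be taken:

```python
import torch

# ATEN_CPU_CAPABILITY=default (set in the environment before importing torch)
# forces the scalar fallback path.
cap = torch.backends.cpu.get_cpu_capability()  # e.g. "AVX512", "AVX2", "DEFAULT"
print("ATen CPU capability:", cap)

if cap in ("AVX2", "AVX512"):
    print("Explicitly vectorized int8 WoQ GEMM micro-kernel is expected to be used.")
else:
    print("Non-vectorized (scalar) micro-kernel fallback.")
```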