[X86] Recognise VPMADD52L pattern with AVX512IFMA/AVXIFMA (#153787) #156714
base: main
Conversation
Match (X * Y) + Z in combineAdd. If the target supports VPMADD52L and we know the multiply cannot overflow, rewrite using VPMADD52L.
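To make the overflow condition concrete, here is a scalar model of one 64-bit lane of vpmadd52luq (an illustrative sketch, not code from this patch): the instruction multiplies the low 52 bits of its two source operands and adds the low 52 bits of the 104-bit product to the accumulator, so rewriting a full 64-bit multiply-add this way is only sound when both factors and the 64-bit product are known to fit in 52 bits.

#include <cstdint>

// Scalar model of one 64-bit lane of vpmadd52luq (illustrative sketch;
// unsigned __int128 is the GCC/Clang extension type).
uint64_t vpmadd52luq_lane(uint64_t Acc, uint64_t X, uint64_t Y) {
  const uint64_t Lo52 = (1ULL << 52) - 1;
  // 104-bit intermediate product of the low 52 bits of each source.
  unsigned __int128 Prod = (unsigned __int128)(X & Lo52) * (Y & Lo52);
  // Add the low 52 bits of the product to the 64-bit accumulator.
  return Acc + (uint64_t)(Prod & Lo52);
}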
@llvm/pr-subscribers-backend-x86

Author: Justin Riddell (Arghnews)

Changes

Match (X * Y) + Z in combineAdd. If the target supports it and we don't overflow (i.e. we know the top 12 bits are unset), rewrite using VPMADD52L.

Have deliberately not put tests into the existing combine-vpmadd52.ll as the flags make it a pain.

Have just done the L version for now at least; wanted to get feedback before continuing. @RKSimon

Full diff: https://github.com/llvm/llvm-project/pull/156714.diff

2 Files Affected:
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 47cea933d0836..bd0ab5fe96630 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -57966,6 +57966,51 @@ static SDValue pushAddIntoCmovOfConsts(SDNode *N, const SDLoc &DL,
Cmov.getOperand(3));
}
+static SDValue matchIntegerMultiplyAdd(SDNode *N, SelectionDAG &DAG,
+ SDValue Op0, SDValue Op1,
+ const SDLoc &DL, EVT VT,
+ const X86Subtarget &Subtarget) {
+ using namespace SDPatternMatch;
+ if (!VT.isVector() || VT.getScalarType() != MVT::i64 ||
+ !Subtarget.hasAVX512() ||
+ (!Subtarget.hasAVXIFMA() && !Subtarget.hasIFMA()) ||
+ !DAG.getTargetLoweringInfo().isOperationLegalOrCustom(X86ISD::VPMADD52L,
+ VT) ||
+ Op0.getValueType() != VT || Op1.getValueType() != VT)
+ return SDValue();
+
+ SDValue X, Y, Acc;
+ if (!sd_match(N, m_Add(m_Mul(m_Value(X), m_Value(Y)), m_Value(Acc))))
+ return SDValue();
+
+ auto CheckMulOperand = [&DAG, &VT](const SDValue &M, SDValue &Xval,
+ SDValue &Yval) -> bool {
+ if (M.getOpcode() != ISD::MUL)
+ return false;
+ const SDValue A = M.getOperand(0);
+ const SDValue B = M.getOperand(1);
+ const APInt Top12Set = APInt::getHighBitsSet(64, 12);
+ if (A.getValueType() != VT || B.getValueType() != VT ||
+ !DAG.MaskedValueIsZero(A, Top12Set) ||
+ !DAG.MaskedValueIsZero(B, Top12Set) ||
+ !DAG.MaskedValueIsZero(M, Top12Set))
+ return false;
+ Xval = A;
+ Yval = B;
+ return true;
+ };
+
+ if (CheckMulOperand(Op0, X, Y)) {
+ Acc = Op1;
+ } else if (CheckMulOperand(Op1, X, Y)) {
+ Acc = Op0;
+ } else {
+ return SDValue();
+ }
+
+ return DAG.getNode(X86ISD::VPMADD52L, DL, VT, Acc, X, Y);
+}
+
static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
@@ -58069,6 +58114,11 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
Op0.getOperand(0), Op0.getOperand(2));
}
+ if (SDValue node =
+ matchIntegerMultiplyAdd(N, DAG, Op0, Op1, DL, VT, Subtarget)) {
+ return node;
+ }
+
return combineAddOrSubToADCOrSBB(N, DL, DAG);
}
diff --git a/llvm/test/CodeGen/X86/ifma-combine-vpmadd52.ll b/llvm/test/CodeGen/X86/ifma-combine-vpmadd52.ll
new file mode 100644
index 0000000000000..6a37b1b814cdc
--- /dev/null
+++ b/llvm/test/CodeGen/X86/ifma-combine-vpmadd52.ll
@@ -0,0 +1,111 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -O1 -mtriple=x86_64-unknown-unknown -mattr=+avx512dq | FileCheck %s --check-prefixes=X64
+
+; 67108863 == (1 << 26) - 1
+
+define dso_local <8 x i64> @test_512_combine_evex(<8 x i64> noundef %0, <8 x i64> noundef %1, <8 x i64> noundef %2) local_unnamed_addr #0 {
+; X64-LABEL: test_512_combine_evex:
+; X64: # %bb.0:
+; X64-NEXT: vpbroadcastq {{.*#+}} zmm3 = [67108863,67108863,67108863,67108863,67108863,67108863,67108863,67108863]
+; X64-NEXT: vpandq %zmm3, %zmm0, %zmm0
+; X64-NEXT: vpandq %zmm3, %zmm1, %zmm1
+; X64-NEXT: vpandq %zmm3, %zmm2, %zmm2
+; X64-NOT: vpmul
+; X64-NOT: vpadd
+; X64-NEXT: vpmadd52luq %zmm1, %zmm2, %zmm0
+; X64-NEXT: retq
+ %4 = and <8 x i64> %0, splat (i64 67108863)
+ %5 = and <8 x i64> %1, splat (i64 67108863)
+ %6 = and <8 x i64> %2, splat (i64 67108863)
+ %7 = mul nuw nsw <8 x i64> %5, %4
+ %8 = add nuw nsw <8 x i64> %7, %6
+ ret <8 x i64> %8
+}
+
+define dso_local <8 x i64> @fff(<8 x i64> noundef %0, <8 x i64> noundef %1, <8 x i64> noundef %2) local_unnamed_addr #0 {
+ %4 = and <8 x i64> %0, splat (i64 67108863)
+ %5 = and <8 x i64> %1, splat (i64 67108863)
+ %6 = and <8 x i64> %2, splat (i64 67108863)
+ %7 = mul nuw nsw <8 x i64> %5, %4
+ %8 = mul nuw nsw <8 x i64> %7, %6
+ %9 = add nuw nsw <8 x i64> %8, %7
+ ret <8 x i64> %9
+}
+
+define dso_local noundef <8 x i64> @test_512_no_combine_evex(<8 x i64> noundef %0, <8 x i64> noundef %1, <8 x i64> noundef %2) local_unnamed_addr #0 {
+; X64-LABEL: test_512_no_combine_evex:
+; X64: # %bb.0:
+; X64-NOT: vpmadd52
+; X64-NEXT: vpmullq %zmm0, %zmm1, %zmm0
+; X64-NEXT: vpaddq %zmm2, %zmm0, %zmm0
+; X64-NEXT: retq
+ %4 = mul <8 x i64> %1, %0
+ %5 = add <8 x i64> %4, %2
+ ret <8 x i64> %5
+}
+
+define dso_local <4 x i64> @test_256_combine_evex(<4 x i64> noundef %0, <4 x i64> noundef %1, <4 x i64> noundef %2) local_unnamed_addr #1 {
+; X64-LABEL: test_256_combine_evex:
+; X64: # %bb.0:
+; X64-NEXT: vpbroadcastq {{.*#+}} ymm3 = [67108863,67108863,67108863,67108863]
+; X64-NEXT: vpand %ymm3, %ymm0, %ymm0
+; X64-NEXT: vpand %ymm3, %ymm1, %ymm1
+; X64-NEXT: vpand %ymm3, %ymm2, %ymm2
+; X64-NOT: vpmul
+; X64-NOT: vpadd
+; X64-NEXT: vpmadd52luq %ymm1, %ymm2, %ymm0
+; X64-NEXT: retq
+ %4 = and <4 x i64> %0, <i64 67108863, i64 67108863, i64 67108863, i64 67108863>
+ %5 = and <4 x i64> %1, <i64 67108863, i64 67108863, i64 67108863, i64 67108863>
+ %6 = and <4 x i64> %2, <i64 67108863, i64 67108863, i64 67108863, i64 67108863>
+ %7 = mul nuw nsw <4 x i64> %5, %4
+ %8 = add nuw nsw <4 x i64> %7, %6
+ ret <4 x i64> %8
+}
+
+define dso_local noundef <4 x i64> @test_256_no_combine_evex(<4 x i64> noundef %0, <4 x i64> noundef %1, <4 x i64> noundef %2) local_unnamed_addr #1 {
+; X64-LABEL: test_256_no_combine_evex:
+; X64: # %bb.0:
+; X64-NOT: vpmadd52
+; X64-NEXT: vpmullq %ymm0, %ymm1, %ymm0
+; X64-NEXT: vpaddq %ymm2, %ymm0, %ymm0
+; X64-NEXT: retq
+ %4 = mul <4 x i64> %1, %0
+ %5 = add <4 x i64> %4, %2
+ ret <4 x i64> %5
+}
+
+define dso_local <4 x i64> @test_256_combine_vex(<4 x i64> noundef %0, <4 x i64> noundef %1, <4 x i64> noundef %2) local_unnamed_addr #2 {
+; X64-LABEL: test_256_combine_vex:
+; X64: # %bb.0:
+; X64-NEXT: vpbroadcastq {{.*#+}} ymm3 = [67108863,67108863,67108863,67108863]
+; X64-NEXT: vpand %ymm3, %ymm0, %ymm0
+; X64-NEXT: vpand %ymm3, %ymm1, %ymm1
+; X64-NEXT: vpand %ymm3, %ymm2, %ymm2
+; X64-NOT: vpmul
+; X64-NOT: vpadd
+; X64-NEXT: {vex} vpmadd52luq %ymm1, %ymm2, %ymm0
+; X64-NEXT: retq
+ %4 = and <4 x i64> %0, <i64 67108863, i64 67108863, i64 67108863, i64 67108863>
+ %5 = and <4 x i64> %1, <i64 67108863, i64 67108863, i64 67108863, i64 67108863>
+ %6 = and <4 x i64> %2, <i64 67108863, i64 67108863, i64 67108863, i64 67108863>
+ %7 = mul nuw nsw <4 x i64> %5, %4
+ %8 = add nuw nsw <4 x i64> %7, %6
+ ret <4 x i64> %8
+}
+
+define dso_local noundef <4 x i64> @test_256_no_combine_vex(<4 x i64> noundef %0, <4 x i64> noundef %1, <4 x i64> noundef %2) local_unnamed_addr #2 {
+; X64-LABEL: test_256_no_combine_vex:
+; X64: # %bb.0:
+; X64-NOT: vpmadd52
+; X64-NEXT: vpmullq %ymm0, %ymm1, %ymm0
+; X64-NEXT: vpaddq %ymm2, %ymm0, %ymm0
+; X64-NEXT: retq
+ %4 = mul <4 x i64> %1, %0
+ %5 = add <4 x i64> %4, %2
+ ret <4 x i64> %5
+}
+
+attributes #0 = { mustprogress nofree norecurse nosync nounwind willreturn memory(none) uwtable "min-legal-vector-width"="512" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+avx,+avx2,+avx512dq,+avx512f,+avx512ifma,+cmov,+crc32,+cx8,+evex512,+f16c,+fma,+fxsr,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave" "tune-cpu"="generic" }
+attributes #1 = { mustprogress nofree norecurse nosync nounwind willreturn memory(none) uwtable "min-legal-vector-width"="256" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+avx,+avx2,+avx512dq,+avx512f,+avx512ifma,+avx512vl,+cmov,+crc32,+cx8,+evex512,+f16c,+fma,+fxsr,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave" "tune-cpu"="generic" }
+attributes #2 = { mustprogress nofree norecurse nosync nounwind willreturn memory(none) uwtable "min-legal-vector-width"="256" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+avx,+avx2,+avx512dq,+avx512f,+avx512vl,+avxifma,+cmov,+crc32,+cx8,+evex512,+f16c,+fma,+fxsr,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave" "tune-cpu"="generic" }
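A note on the 67108863 masks in these tests: (1 << 26) - 1 keeps each factor to 26 bits, so the full product needs at most 52 bits and the top 12 bits of the 64-bit multiply are provably zero, which is exactly the MaskedValueIsZero condition the combine checks. A quick compile-time sanity check (illustrative, not part of the patch):

// (2^26 - 1)^2 = 2^52 - 2^27 + 1, which still fits in 52 bits.
static_assert((((1ULL << 26) - 1) * ((1ULL << 26) - 1)) >> 52 == 0,
              "a 26-bit by 26-bit product fits in 52 bits");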
if (!VT.isVector() || VT.getScalarType() != MVT::i64 ||
    !Subtarget.hasAVX512() ||
    (!Subtarget.hasAVXIFMA() && !Subtarget.hasIFMA()) ||
    !DAG.getTargetLoweringInfo().isOperationLegalOrCustom(X86ISD::VPMADD52L,
You should be able to relax the type requirements and just use SplitOpsAndApply - see matchPMADDWD_2 for reference
I'm still unclear on exactly how to use SplitOpsAndApply but will have a further look.
Can you help me understand: I believe I pass VT to SplitOpsAndApply, but in the lambda, what VT do I pass to the DAG.getNode call?
Should I be able to generate a test case for this that fails before and passes after by tinkering with the mattr options, where the case is a vector that is too wide or too narrow?
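For context, matchPMADDWD_2 (which the suggestion above points to) passes SplitOpsAndApply a builder callback that is invoked on operands already split to a legal width, so the node is built at the split type rather than the original VT. A rough sketch of what that could look like here (operand order and names are illustrative and untested):

// Hypothetical rewiring of the combine through SplitOpsAndApply: the
// builder emits the node at the width of the already-split operands,
// i.e. the VT passed to DAG.getNode is Ops[0].getValueType().
auto PMADD52Builder = [](SelectionDAG &DAG, const SDLoc &DL,
                         ArrayRef<SDValue> Ops) {
  return DAG.getNode(X86ISD::VPMADD52L, DL, Ops[0].getValueType(), Ops);
};
return SplitOpsAndApply(DAG, Subtarget, DL, VT, {Acc, X, Y}, PMADD52Builder);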
✅ With the latest revision this PR passed the undef deprecator.