clastb representation in existing IR, and AArch64 codegen #112738
Conversation
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-backend-aarch64

Author: Graham Hunter (huntergr-arm)

Changes

These commits show a possible representation of SVE's clastb instruction using existing IR instructions and intrinsics, along with DAGCombines to emit the actual instruction.

At 9 instructions to represent clastb, it feels a little fragile and may be changed by passes between LoopVectorize and codegen. While we can sink the loop-invariant terms back into the right block in CGP, I do wonder if we want a more direct intrinsic to represent this kind of operation.

Perhaps something like llvm.vector.extract.last.active(data, mask)?

This is something we would use to support the CSA vectorization in #106560 for SVE, though we would prefer to use clastb inside the vector loop instead of after it. That patch uses an int max reduction to determine the index instead of the cttz.elts based approach in this PR, so we have another existing IR option to use if we want.

Full diff: https://github.com/llvm/llvm-project/pull/112738.diff

2 Files Affected:
- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
- llvm/test/CodeGen/AArch64/sve-clastb.ll
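For concreteness, the 9-instruction pattern being matched is the one in the new test file below: a vector.reverse of the mask, a cttz.elts on the reversed mask, a reduce.or for the any-active test, vscale/shl/sub/sub arithmetic to form the lane index, an extractelement, and a final select. The proposed intrinsic would collapse that into a single call; the sketch below is illustrative only (the intrinsic eventually added was llvm.experimental.vector.extract.last.active, which also takes a passthru operand, and the exact name mangling shown here is an assumption).

; Hypothetical single-call form of the i8 case from the test file below:
define i8 @extract_last_direct(<vscale x 16 x i8> %data, <vscale x 16 x i1> %pg, i8 %existing) {
  ; Returns the last element of %data whose %pg bit is set, or %existing if no bits are set.
  %res = call i8 @llvm.experimental.vector.extract.last.active.nxv16i8(<vscale x 16 x i8> %data, <vscale x 16 x i1> %pg, i8 %existing)
  ret i8 %res
}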
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b5657584016ea6..ccd89af6eb8770 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -19512,6 +19512,57 @@ performLastTrueTestVectorCombine(SDNode *N,
return getPTest(DAG, N->getValueType(0), Pg, N0, AArch64CC::LAST_ACTIVE);
}
+static SDValue
+performLastActiveExtractEltCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const AArch64Subtarget *Subtarget) {
+ SDValue Index = N->getOperand(1);
+ // FIXME: Make this more generic. Should be a utility func somewhere?
+ if (Index.getOpcode() == ISD::ZERO_EXTEND)
+ Index = Index.getOperand(0);
+
+ // Looking for an add of an inverted value.
+ if (Index.getOpcode() != ISD::ADD)
+ return SDValue();
+
+ SDValue Size = Index.getOperand(1);
+
+ if (Size.getOpcode() == ISD::TRUNCATE)
+ Size = Size.getOperand(0);
+
+ // Check that we're looking at the size of the overall vector...
+ // FIXME: What about VSL codegen?
+ if (Size.getOpcode() != ISD::VSCALE)
+ return SDValue();
+
+ unsigned NElts = N->getOperand(0)->getValueType(0).getVectorElementCount().getKnownMinValue();
+ if (Size.getConstantOperandVal(0) != NElts)
+ return SDValue();
+
+ SDValue Invert = Index.getOperand(0);
+ if (Invert.getOpcode() != ISD::XOR)
+ return SDValue();
+
+ if (!Invert.getConstantOperandAPInt(1).isAllOnes())
+ return SDValue();
+
+ SDValue LZeroes = Invert.getOperand(0);
+ if (LZeroes.getOpcode() == ISD::TRUNCATE)
+ LZeroes = LZeroes.getOperand(0);
+
+ // Check that we're looking at a cttz.elts from a reversed predicate...
+ if (LZeroes.getOpcode() != AArch64ISD::CTTZ_ELTS)
+ return SDValue();
+
+ SDValue Pred = LZeroes.getOperand(0);
+ if (Pred.getOpcode() != ISD::VECTOR_REVERSE)
+ return SDValue();
+
+ // Matched a LASTB pattern.
+ return DCI.DAG.getNode(AArch64ISD::LASTB, SDLoc(N), N->getValueType(0),
+ Pred.getOperand(0), N->getOperand(0));
+}
+
static SDValue
performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
@@ -19520,6 +19571,8 @@ performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
return Res;
if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget))
return Res;
+ if (SDValue Res = performLastActiveExtractEltCombine(N, DCI, Subtarget))
+ return Res;
SelectionDAG &DAG = DCI.DAG;
SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
@@ -24363,6 +24416,50 @@ static SDValue foldCSELOfCSEL(SDNode *Op, SelectionDAG &DAG) {
return DAG.getNode(AArch64ISD::CSEL, DL, VT, L, R, CCValue, Cond);
}
+static SDValue foldCSELOfLASTB(SDNode *N, SelectionDAG &DAG) {
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ AArch64CC::CondCode CC =
+ static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2));
+ SDValue PTAny = N->getOperand(3);
+
+ // FIXME: Handle the inverse?
+ if (Op0.getOpcode() != AArch64ISD::LASTB)
+ return SDValue();
+
+ if (PTAny.getOpcode() != AArch64ISD::PTEST_ANY)
+ return SDValue();
+
+ // Get the predicate...
+ SDValue LBPred = Op0.getOperand(0);
+
+ // Look through reinterprets...
+ SDValue PTestPG = PTAny.getOperand(0);
+ if (PTestPG.getOpcode() == AArch64ISD::REINTERPRET_CAST)
+ PTestPG = PTestPG.getOperand(0);
+
+ SDValue PTestOp = PTAny.getOperand(1);
+ if (PTestOp.getOpcode() == AArch64ISD::REINTERPRET_CAST)
+ PTestOp = PTestOp.getOperand(0);
+
+ // And compare against the csel cmp.
+ // Make sure the same predicate is used.
+ if (PTestOp != LBPred)
+ return SDValue();
+
+ // Make sure that PG for the test is either the same as the input or
+ // an explicit ptrue.
+ // FIXME:... look for ptrue_all instead of just ptrue...
+ if (PTestPG != LBPred && PTestPG.getOpcode() != AArch64ISD::PTRUE)
+ return SDValue();
+
+ if (CC != AArch64CC::NE)
+ return SDValue();
+
+ return DAG.getNode(AArch64ISD::CLASTB_N, SDLoc(N), N->getValueType(0),
+ LBPred, Op1, Op0.getOperand(1));
+}
+
// Optimize CSEL instructions
static SDValue performCSELCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
@@ -24379,6 +24476,9 @@ static SDValue performCSELCombine(SDNode *N,
if (SDValue Folded = foldCSELofCTTZ(N, DAG))
return Folded;
+ if (SDValue CLastB = foldCSELOfLASTB(N, DAG))
+ return CLastB;
+
return performCONDCombine(N, DCI, DAG, 2, 3);
}
diff --git a/llvm/test/CodeGen/AArch64/sve-clastb.ll b/llvm/test/CodeGen/AArch64/sve-clastb.ll
new file mode 100644
index 00000000000000..e2ff1e478f6b87
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-clastb.ll
@@ -0,0 +1,104 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3
+; RUN: llc -mtriple=aarch64-linux-unknown -mattr=+sve -o - < %s | FileCheck %s
+
+define i8 @clastb_i8(<vscale x 16 x i8> %data, <vscale x 16 x i1> %pg, i8 %existing) {
+; CHECK-LABEL: clastb_i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: clastb w0, p0, w0, z0.b
+; CHECK-NEXT: ret
+ %rev.pg = call <vscale x 16 x i1> @llvm.vector.reverse.nxv16i1(<vscale x 16 x i1> %pg)
+ %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv16i1(<vscale x 16 x i1> %rev.pg, i1 false)
+ %any.set = call i1 @llvm.vector.reduce.or.nxv16i1(<vscale x 16 x i1> %pg)
+ %vscale = call i32 @llvm.vscale.i32()
+ %size = shl i32 %vscale, 4
+ %sub = sub i32 %size, %tz.cnt
+ %idx = sub i32 %sub, 1
+ %extr = extractelement <vscale x 16 x i8> %data, i32 %idx
+ %res = select i1 %any.set, i8 %extr, i8 %existing
+ ret i8 %res
+}
+
+define i16 @clastb_i16(<vscale x 8 x i16> %data, <vscale x 8 x i1> %pg, i16 %existing) {
+; CHECK-LABEL: clastb_i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: clastb w0, p0, w0, z0.h
+; CHECK-NEXT: ret
+ %rev.pg = call <vscale x 8 x i1> @llvm.vector.reverse.nxv8i1(<vscale x 8 x i1> %pg)
+ %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv8i1(<vscale x 8 x i1> %rev.pg, i1 false)
+ %any.set = call i1 @llvm.vector.reduce.or.nxv8i1(<vscale x 8 x i1> %pg)
+ %vscale = call i32 @llvm.vscale.i32()
+ %size = shl i32 %vscale, 3
+ %sub = sub i32 %size, %tz.cnt
+ %idx = sub i32 %sub, 1
+ %extr = extractelement <vscale x 8 x i16> %data, i32 %idx
+ %res = select i1 %any.set, i16 %extr, i16 %existing
+ ret i16 %res
+}
+
+define i32 @clastb_i32(<vscale x 4 x i32> %data, <vscale x 4 x i1> %pg, i32 %existing) {
+; CHECK-LABEL: clastb_i32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: clastb w0, p0, w0, z0.s
+; CHECK-NEXT: ret
+ %rev.pg = call <vscale x 4 x i1> @llvm.vector.reverse.nxv4i1(<vscale x 4 x i1> %pg)
+ %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv4i1(<vscale x 4 x i1> %rev.pg, i1 false)
+ %any.set = call i1 @llvm.vector.reduce.or.nxv4i1(<vscale x 4 x i1> %pg)
+ %vscale = call i32 @llvm.vscale.i32()
+ %size = shl i32 %vscale, 2
+ %sub = sub i32 %size, %tz.cnt
+ %idx = sub i32 %sub, 1
+ %extr = extractelement <vscale x 4 x i32> %data, i32 %idx
+ %res = select i1 %any.set, i32 %extr, i32 %existing
+ ret i32 %res
+}
+
+define i64 @clastb_i64(<vscale x 2 x i64> %data, <vscale x 2 x i1> %pg, i64 %existing) {
+; CHECK-LABEL: clastb_i64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: clastb x0, p0, x0, z0.d
+; CHECK-NEXT: ret
+ %rev.pg = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %pg)
+ %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv2i1(<vscale x 2 x i1> %rev.pg, i1 false)
+ %any.set = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> %pg)
+ %vscale = call i32 @llvm.vscale.i32()
+ %size = shl i32 %vscale, 1
+ %sub = sub i32 %size, %tz.cnt
+ %idx = sub i32 %sub, 1
+ %extr = extractelement <vscale x 2 x i64> %data, i32 %idx
+ %res = select i1 %any.set, i64 %extr, i64 %existing
+ ret i64 %res
+}
+
+define float @clastb_float(float %existing, <vscale x 4 x float> %data, <vscale x 4 x i1> %pg) {
+; CHECK-LABEL: clastb_float:
+; CHECK: // %bb.0:
+; CHECK-NEXT: clastb s0, p0, s0, z1.s
+; CHECK-NEXT: ret
+ %rev.pg = call <vscale x 4 x i1> @llvm.vector.reverse.nxv4i1(<vscale x 4 x i1> %pg)
+ %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv4i1(<vscale x 4 x i1> %rev.pg, i1 false)
+ %any.set = call i1 @llvm.vector.reduce.or.nxv4i1(<vscale x 4 x i1> %pg)
+ %vscale = call i32 @llvm.vscale.i32()
+ %size = shl i32 %vscale, 2
+ %sub = sub i32 %size, %tz.cnt
+ %idx = sub i32 %sub, 1
+ %extr = extractelement <vscale x 4 x float> %data, i32 %idx
+ %res = select i1 %any.set, float %extr, float %existing
+ ret float %res
+}
+
+define double @clastb_double(double %existing, <vscale x 2 x double> %data, <vscale x 2 x i1> %pg) {
+; CHECK-LABEL: clastb_double:
+; CHECK: // %bb.0:
+; CHECK-NEXT: clastb d0, p0, d0, z1.d
+; CHECK-NEXT: ret
+ %rev.pg = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %pg)
+ %tz.cnt = call i32 @llvm.experimental.cttz.elts.i32.nxv2i1(<vscale x 2 x i1> %rev.pg, i1 false)
+ %any.set = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> %pg)
+ %vscale = call i32 @llvm.vscale.i32()
+ %size = shl i32 %vscale, 1
+ %sub = sub i32 %size, %tz.cnt
+ %idx = sub i32 %sub, 1
+ %extr = extractelement <vscale x 2 x double> %data, i32 %idx
+ %res = select i1 %any.set, double %extr, double %existing
+ ret double %res
+}
✅ With the latest revision this PR passed the C/C++ code formatter.
Reliably matching the integer arithmetic seems problematic; I think I'd lean towards an intrinsic. Maybe …
As discussed in #112738, it may be better to have an intrinsic to represent vector element extracts based on mask bits. This intrinsic is for the case of extracting the last active element, if any, or a default value if the mask is all-false. The target-agnostic SelectionDAG lowering is similar to the IR in #106560.
Force-pushed from b80e292 to 5df59e4.
Rebased on top of the intrinsic.
Looks good to me apart from a doc comment being needed.
@@ -490,6 +490,11 @@ class TargetLoweringBase {
    return true;
  }

  virtual bool
  shouldExpandVectorExtractLastActive(const IntrinsicInst *I) const {
I think we'll want a doc comment here.
Bad TargetLowering API with no use context. SelectionDAGBuilder should directly translate into dag nodes and not be a point of configuration
I was following a pattern of existing TLI functions (starting with …). I'll investigate leaving it as an intrinsic and moving the current code in SelectionDAGBuilder to an expand legalization function instead.
It's developing an ad-hoc, parallel legalization mechanism to the regular one that happens in a different place. This expands the surface area of developing a target, and I have endless trouble fighting with non-trivial logic in the builder. All the legalization decisions should be made during legalization, not before.
Force-pushed from 5df59e4 to e7bfbf6.
Rebased on the vecops-expansion-based lowering from #118810. I implemented basic SVE isel pattern matching for the new ISD node, then added a couple of new DAG combines to fold to lastb, then clastb.
@arsenm, has the changed design addressed your concerns? The SDAGBuilder lowering is the same for all targets, and now I'm just marking the new ISD node as Legal for AArch64 to avoid the default vecops expansion and improve SVE codegen.
LGTM, I think @arsenm's main concern has been addressed as there's no new target lowering hook anymore. I think the PR description needs updating though, as the llvm.vector.extract.last.active intrinsic now exists and this patch just improves the SVE lowering.
: SDNode<"ISD::VECTOR_FIND_LAST_ACTIVE", SDTypeProfile<1, 1, []>, []>;
Is it worth adding any constraints for find_last_active?
// Find index of last active lane. This is a fallback in case we miss the
// opportunity to fold into a lastb or clastb directly.
Are these fallback patterns tested in the final patch?
I agree it would be good to have some tests for these.
Sadly it's pretty difficult to do this once the combines have been added. I don't see a global switch to disable combining, just the target-independent combines. We do check the optimization level in a few places in AArch64ISelLowering, but mostly for TLI methods for IR-level decisions. Deliberately turning off (c)lastb pattern matching at O0 feels odd. Adding a new switch just for this feature also feels excessive.
I could potentially add a globalisel-based test, though I'm not sure how much code that requires. We've added a few new ISD nodes recently, and none have added support in globalisel.
I guess this is mostly due to it being hard to just create a selectiondag without IR and run selection over it.
SDValue Vec = N->getOperand(0);
SDValue Idx = N->getOperand(1);

if (!Subtarget->hasSVE() || DCI.isBeforeLegalize() ||
nit: the test for Subtarget->hasSVE() is redundant; it's already covered by the getOperationAction(..) == Legal check below.
if (LastB.getOpcode() != AArch64ISD::LASTB || LastB.getOperand(0) != AnyPred)
  return SDValue();

SDValue Vec = LastB.getOperand(1);
nit: only has one use, can just as well propagate into expression below.
SDValue Idx = N->getOperand(1);

if (!Subtarget->hasSVE() || DCI.isBeforeLegalize() ||
    Idx.getOpcode() != ISD::VECTOR_FIND_LAST_ACTIVE)
LASTB is only supported for element types i8, i16, i32 and i64. Could you add a check for this and add some tests for doing a llvm.experimental.vector.extract.last.active on a predicate vector?
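For illustration, such a test might look like the sketch below; the function name is made up and the intrinsic mangling/signature (with a passthru operand) is an assumption rather than something taken from the patch, but the point is that an i1 element type cannot use LASTB and so should fall back to the generic expansion.

define i1 @extract_last_active_i1(<vscale x 16 x i1> %data, <vscale x 16 x i1> %mask, i1 %passthru) {
  ; Element type is i1, so the LASTB/CLASTB combines should bail out and the
  ; target-independent VECTOR_FIND_LAST_ACTIVE expansion should be used instead.
  %res = call i1 @llvm.experimental.vector.extract.last.active.nxv16i1(<vscale x 16 x i1> %data, <vscale x 16 x i1> %mask, i1 %passthru)
  ret i1 %res
}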
if (TLI.getOperationAction(ISD::VECTOR_FIND_LAST_ACTIVE,
                           Mask.getValueType()) != TargetLowering::Legal)
isOperationLegal?
Force-pushed from e7bfbf6 to cf5b283.
// Only legal for 8, 16, 32, and 64 bit element types.
EVT EltVT = Vec.getValueType().getVectorElementType();
if (!is_contained(
        ArrayRef({MVT::i8, MVT::i16, MVT::i32, MVT::i64, MVT::f32, MVT::f64}),
What about f16 and bf16?
…#112738) This patch improves SVE codegen for the vector extract last active intrinsic, using either the lastb instruction (if the passthru value was poison or undef), or the clastb instruction.
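As a quick illustration of that final behavior, the two cases might be written as below (the intrinsic mangling is assumed; these are sketches rather than lines taken from the patch's tests):

; Poison passthru: nothing to merge, so this is expected to select LASTB.
define i8 @last_active_to_lastb(<vscale x 16 x i8> %data, <vscale x 16 x i1> %mask) {
  %res = call i8 @llvm.experimental.vector.extract.last.active.nxv16i8(<vscale x 16 x i8> %data, <vscale x 16 x i1> %mask, i8 poison)
  ret i8 %res
}

; Real passthru: expected to select CLASTB, which keeps %prev when %mask has no active lanes.
define i8 @last_active_to_clastb(<vscale x 16 x i8> %data, <vscale x 16 x i1> %mask, i8 %prev) {
  %res = call i8 @llvm.experimental.vector.extract.last.active.nxv16i8(<vscale x 16 x i8> %data, <vscale x 16 x i1> %mask, i8 %prev)
  ret i8 %res
}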