-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[AArch64] Improve lowering for scalable masked deinterleaving loads #154338
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-backend-aarch64 Author: David Sherwood (david-arm)

Changes: For IR like this:

%mask = ... @llvm.vector.interleave2(<vscale x 16 x i1> %a, <vscale x 16 x i1> %a)
%vec = ... @llvm.masked.load(..., <vscale x 32 x i1> %mask, ...)
%dvec = ... @llvm.vector.deinterleave2(<vscale x 32 x i8> %vec)

where we're deinterleaving a wide masked load of the supported type and with an interleaved mask, we can lower this directly to a ld2b instruction. Similarly we can also support other variants of ld2 and ld4. This PR adds a DAG combine to spot such patterns and lower to ld2X or ld4X variants accordingly. Patch is 39.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/154338.diff 2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 63a85faf344c4..cff28e006b22e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1178,6 +1178,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::SCALAR_TO_VECTOR);
setTargetDAGCombine(ISD::SHL);
+ setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE);
// In case of strict alignment, avoid an excessive number of byte wide stores.
MaxStoresPerMemsetOptSize = 8;
@@ -27010,6 +27011,121 @@ performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
return NVCAST;
}
+static SDValue performVectorDeinterleaveCombine(
+ SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
+ unsigned NumParts = N->getNumOperands();
+ if (NumParts != 2 && NumParts != 4)
+ return SDValue();
+
+ EVT SubVecTy = N->getValueType(0);
+
+ // At the moment we're unlikely to see a fixed-width vector deinterleave as
+ // we usually generate shuffles instead.
+ unsigned MinNumElements = SubVecTy.getVectorMinNumElements();
+ if (!SubVecTy.isScalableVT() ||
+ SubVecTy.getSizeInBits().getKnownMinValue() != 128 || MinNumElements == 1)
+ return SDValue();
+
+ // Make sure each input operand is the correct extract_subvector of the same
+ // wider vector.
+ SDValue Op0 = N->getOperand(0);
+ for (unsigned I = 0; I < NumParts; I++) {
+ SDValue OpI = N->getOperand(I);
+ if (OpI->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+ OpI->getOperand(0) != Op0->getOperand(0))
+ return SDValue();
+ auto *Idx = cast<ConstantSDNode>(OpI->getOperand(1));
+ if (Idx->getZExtValue() != (I * MinNumElements))
+ return SDValue();
+ }
+
+ // Normal loads are currently already handled by the InterleavedAccessPass so
+ // we don't expect to see them here. Bail out if the masked load has an
+ // unexpected number of uses, since we want to avoid a situation where we have
+ // both deinterleaving loads and normal loads in the same block. Also, discard
+ // masked loads that are extending, indexed, have an unexpected offset or have
+ // an unsupported passthru value until we find a valid use case.
+ auto MaskedLoad = dyn_cast<MaskedLoadSDNode>(Op0->getOperand(0));
+ if (!MaskedLoad || !MaskedLoad->hasNUsesOfValue(NumParts, 0) ||
+ MaskedLoad->getExtensionType() != ISD::NON_EXTLOAD ||
+ MaskedLoad->getAddressingMode() != ISD::UNINDEXED ||
+ !MaskedLoad->getOffset().isUndef() ||
+ (!MaskedLoad->getPassThru()->isUndef() &&
+ !isZerosVector(MaskedLoad->getPassThru().getNode())))
+ return SDValue();
+
+ // Now prove that the mask is an interleave of identical masks.
+ SDValue Mask = MaskedLoad->getMask();
+ if (Mask->getOpcode() != ISD::SPLAT_VECTOR &&
+ Mask->getOpcode() != ISD::CONCAT_VECTORS)
+ return SDValue();
+
+ // TODO: Do we need to check the element type of the mask?
+ SDValue NarrowMask;
+ SDLoc DL(N);
+ if (Mask->getOpcode() == ISD::CONCAT_VECTORS) {
+ if (Mask->getNumOperands() != NumParts)
+ return SDValue();
+
+ // We should be concatenating each sequential result from a
+ // VECTOR_INTERLEAVE.
+ SDValue InterleaveOp = Mask->getOperand(0);
+ if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
+ InterleaveOp->getNumOperands() != NumParts)
+ return SDValue();
+
+ for (unsigned I = 0; I < NumParts; I++) {
+ SDValue ConcatOp = Mask->getOperand(I);
+ if (ConcatOp.getResNo() != I ||
+ ConcatOp.getNode() != InterleaveOp.getNode())
+ return SDValue();
+ }
+
+ // Make sure the inputs to the vector interleave are identical.
+ for (unsigned I = 1; I < NumParts; I++) {
+ if (InterleaveOp->getOperand(I) != InterleaveOp->getOperand(0))
+ return SDValue();
+ }
+
+ NarrowMask = InterleaveOp->getOperand(0);
+ } else { // ISD::SPLAT_VECTOR
+ auto *SplatVal = dyn_cast<ConstantSDNode>(Mask->getOperand(0));
+ if (!SplatVal || SplatVal->getZExtValue() != 1)
+ return SDValue();
+ ElementCount EC = Mask.getValueType().getVectorElementCount();
+ assert((EC.getKnownMinValue() % NumParts) == 0 &&
+ "Expected element count divisible by number of parts");
+ EC = ElementCount::getScalable(EC.getKnownMinValue() / NumParts);
+ NarrowMask =
+ DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::getVectorVT(MVT::i1, EC),
+ DAG.getConstant(1, DL, MVT::i1));
+ }
+
+ const Intrinsic::ID IID = NumParts == 2 ? Intrinsic::aarch64_sve_ld2_sret
+ : Intrinsic::aarch64_sve_ld4_sret;
+ SDValue NewLdOps[] = {MaskedLoad->getChain(),
+ DAG.getConstant(IID, DL, MVT::i32), NarrowMask,
+ MaskedLoad->getBasePtr()};
+ SDValue Res;
+ if (NumParts == 2)
+ Res = DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL,
+ {SubVecTy, SubVecTy, MVT::Other}, NewLdOps);
+ else
+ Res = DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL,
+ {SubVecTy, SubVecTy, SubVecTy, SubVecTy, MVT::Other},
+ NewLdOps);
+
+ // We can now generate a structured load!
+ SmallVector<SDValue, 4> ResOps(NumParts);
+ for (unsigned Idx = 0; Idx < NumParts; Idx++)
+ ResOps[Idx] = SDValue(Res.getNode(), Idx);
+
+ // Replace uses of the original chain result with the new chain result.
+ DAG.ReplaceAllUsesOfValueWith(SDValue(MaskedLoad, 1),
+ SDValue(Res.getNode(), NumParts));
+ return DCI.CombineTo(N, ResOps, false);
+}
+
/// If the operand is a bitwise AND with a constant RHS, and the shift has a
/// constant RHS and is the only use, we can pull it out of the shift, i.e.
///
@@ -27078,6 +27194,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
default:
LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
break;
+ case ISD::VECTOR_DEINTERLEAVE:
+ return performVectorDeinterleaveCombine(N, DCI, DAG);
case ISD::VECREDUCE_AND:
case ISD::VECREDUCE_OR:
case ISD::VECREDUCE_XOR:
diff --git a/llvm/test/CodeGen/AArch64/masked_deinterleaved_loads.ll b/llvm/test/CodeGen/AArch64/masked_deinterleaved_loads.ll
new file mode 100644
index 0000000000000..e7d9cb353941e
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/masked_deinterleaved_loads.ll
@@ -0,0 +1,542 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
+
+define <vscale x 16 x i8> @foo_ld2_nxv16i8(<vscale x 16 x i1> %mask, ptr %p) {
+; CHECK-LABEL: foo_ld2_nxv16i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ld2b { z0.b, z1.b }, p0/z, [x0]
+; CHECK-NEXT: add z0.b, z0.b, z1.b
+; CHECK-NEXT: ret
+ %interleaved.mask = tail call <vscale x 32 x i1> @llvm.vector.interleave2.nxv32i1(<vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask)
+ %wide.masked.vec = tail call <vscale x 32 x i8> @llvm.masked.load.nxv32i8.p0(ptr %p, i32 1, <vscale x 32 x i1> %interleaved.mask, <vscale x 32 x i8> poison)
+ %strided.vec = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %wide.masked.vec)
+ %part1 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 0
+ %part2 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 1
+ %add = add <vscale x 16 x i8> %part1, %part2
+ ret <vscale x 16 x i8> %add
+}
+
+define <vscale x 8 x i16> @foo_ld2_nxv8i16(<vscale x 8 x i1> %mask, ptr %p) {
+; CHECK-LABEL: foo_ld2_nxv8i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ld2h { z0.h, z1.h }, p0/z, [x0]
+; CHECK-NEXT: add z0.h, z0.h, z1.h
+; CHECK-NEXT: ret
+ %interleaved.mask = tail call <vscale x 16 x i1> @llvm.vector.interleave2.nxv16i1(<vscale x 8 x i1> %mask, <vscale x 8 x i1> %mask)
+ %wide.masked.vec = tail call <vscale x 16 x i16> @llvm.masked.load.nxv16i16.p0(ptr %p, i32 2, <vscale x 16 x i1> %interleaved.mask, <vscale x 16 x i16> poison)
+ %strided.vec = tail call { <vscale x 8 x i16>, <vscale x 8 x i16> } @llvm.vector.deinterleave2.nxv16i16(<vscale x 16 x i16> %wide.masked.vec)
+ %part1 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16> } %strided.vec, 0
+ %part2 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16> } %strided.vec, 1
+ %add = add <vscale x 8 x i16> %part1, %part2
+ ret <vscale x 8 x i16> %add
+}
+
+define <vscale x 4 x float> @foo_ld2_nxv4f32(<vscale x 4 x i1> %mask, ptr %p) {
+; CHECK-LABEL: foo_ld2_nxv4f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ld2w { z0.s, z1.s }, p0/z, [x0]
+; CHECK-NEXT: fadd z0.s, z0.s, z1.s
+; CHECK-NEXT: ret
+ %interleaved.mask = tail call <vscale x 8 x i1> @llvm.vector.interleave2.nxv8i1(<vscale x 4 x i1> %mask, <vscale x 4 x i1> %mask)
+ %wide.masked.vec = tail call <vscale x 8 x float> @llvm.masked.load.nxv8f32.p0(ptr %p, i32 4, <vscale x 8 x i1> %interleaved.mask, <vscale x 8 x float> poison)
+  %strided.vec = tail call { <vscale x 4 x float>, <vscale x 4 x float> } @llvm.vector.deinterleave2.nxv8f32(<vscale x 8 x float> %wide.masked.vec)
+ %part1 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } %strided.vec, 0
+ %part2 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } %strided.vec, 1
+ %add = fadd <vscale x 4 x float> %part1, %part2
+ ret <vscale x 4 x float> %add
+}
+
+define <vscale x 2 x double> @foo_ld2_nxv2f64(<vscale x 2 x i1> %mask, ptr %p) {
+; CHECK-LABEL: foo_ld2_nxv2f64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ld2d { z0.d, z1.d }, p0/z, [x0]
+; CHECK-NEXT: fadd z0.d, z0.d, z1.d
+; CHECK-NEXT: ret
+ %interleaved.mask = tail call <vscale x 4 x i1> @llvm.vector.interleave2.nxv4i1(<vscale x 2 x i1> %mask, <vscale x 2 x i1> %mask)
+ %wide.masked.vec = tail call <vscale x 4 x double> @llvm.masked.load.nxv4f64.p0(ptr %p, i32 8, <vscale x 4 x i1> %interleaved.mask, <vscale x 4 x double> poison)
+ %strided.vec = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %wide.masked.vec)
+ %part1 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 0
+ %part2 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 1
+ %add = fadd <vscale x 2 x double> %part1, %part2
+ ret <vscale x 2 x double> %add
+}
+
+define <vscale x 16 x i8> @foo_ld4_nxv16i8(<vscale x 16 x i1> %mask, ptr %p) {
+; CHECK-LABEL: foo_ld4_nxv16i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ld4b { z0.b - z3.b }, p0/z, [x0]
+; CHECK-NEXT: add z4.b, z0.b, z1.b
+; CHECK-NEXT: add z0.b, z2.b, z3.b
+; CHECK-NEXT: add z0.b, z4.b, z0.b
+; CHECK-NEXT: ret
+ %interleaved.mask = tail call <vscale x 64 x i1> @llvm.vector.interleave4.nxv64i1(<vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask)
+ %wide.masked.vec = tail call <vscale x 64 x i8> @llvm.masked.load.nxv64i8.p0(ptr %p, i32 1, <vscale x 64 x i1> %interleaved.mask, <vscale x 64 x i8> poison)
+ %strided.vec = tail call { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave4.nxv64i8(<vscale x 64 x i8> %wide.masked.vec)
+ %part1 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 0
+ %part2 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 1
+ %part3 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 2
+ %part4 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 3
+ %add1 = add <vscale x 16 x i8> %part1, %part2
+ %add2 = add <vscale x 16 x i8> %part3, %part4
+ %add3 = add <vscale x 16 x i8> %add1, %add2
+ ret <vscale x 16 x i8> %add3
+}
+
+define <vscale x 8 x i16> @foo_ld4_nxv8i16(<vscale x 8 x i1> %mask, ptr %p) {
+; CHECK-LABEL: foo_ld4_nxv8i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ld4h { z0.h - z3.h }, p0/z, [x0]
+; CHECK-NEXT: add z4.h, z0.h, z1.h
+; CHECK-NEXT: add z0.h, z2.h, z3.h
+; CHECK-NEXT: add z0.h, z4.h, z0.h
+; CHECK-NEXT: ret
+ %interleaved.mask = tail call <vscale x 32 x i1> @llvm.vector.interleave4.nxv32i1(<vscale x 8 x i1> %mask, <vscale x 8 x i1> %mask, <vscale x 8 x i1> %mask, <vscale x 8 x i1> %mask)
+ %wide.masked.vec = tail call <vscale x 32 x i16> @llvm.masked.load.nxv32i16.p0(ptr %p, i32 2, <vscale x 32 x i1> %interleaved.mask, <vscale x 32 x i16> poison)
+ %strided.vec = tail call { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } @llvm.vector.deinterleave4.nxv32i16(<vscale x 32 x i16> %wide.masked.vec)
+ %part1 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %strided.vec, 0
+ %part2 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %strided.vec, 1
+ %part3 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %strided.vec, 2
+ %part4 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %strided.vec, 3
+ %add1 = add <vscale x 8 x i16> %part1, %part2
+ %add2 = add <vscale x 8 x i16> %part3, %part4
+ %add3 = add <vscale x 8 x i16> %add1, %add2
+ ret <vscale x 8 x i16> %add3
+}
+
+define <vscale x 4 x float> @foo_ld4_nxv4f32(<vscale x 4 x i1> %mask, ptr %p) {
+; CHECK-LABEL: foo_ld4_nxv4f32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ld4w { z0.s - z3.s }, p0/z, [x0]
+; CHECK-NEXT: fadd z4.s, z0.s, z1.s
+; CHECK-NEXT: fadd z0.s, z2.s, z3.s
+; CHECK-NEXT: fadd z0.s, z4.s, z0.s
+; CHECK-NEXT: ret
+ %interleaved.mask = tail call <vscale x 16 x i1> @llvm.vector.interleave4.nxv16i1(<vscale x 4 x i1> %mask, <vscale x 4 x i1> %mask, <vscale x 4 x i1> %mask, <vscale x 4 x i1> %mask)
+ %wide.masked.vec = tail call <vscale x 16 x float> @llvm.masked.load.nxv16f32.p0(ptr %p, i32 4, <vscale x 16 x i1> %interleaved.mask, <vscale x 16 x float> poison)
+ %strided.vec = tail call { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } @llvm.vector.deinterleave4.nxv16f32(<vscale x 16 x float> %wide.masked.vec)
+ %part1 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %strided.vec, 0
+ %part2 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %strided.vec, 1
+ %part3 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %strided.vec, 2
+ %part4 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %strided.vec, 3
+ %add1 = fadd <vscale x 4 x float> %part1, %part2
+ %add2 = fadd <vscale x 4 x float> %part3, %part4
+ %add3 = fadd <vscale x 4 x float> %add1, %add2
+ ret <vscale x 4 x float> %add3
+}
+
+define <vscale x 2 x double> @foo_ld4_nxv2f64(<vscale x 2 x i1> %mask, ptr %p) {
+; CHECK-LABEL: foo_ld4_nxv2f64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ld4d { z0.d - z3.d }, p0/z, [x0]
+; CHECK-NEXT: fadd z4.d, z0.d, z1.d
+; CHECK-NEXT: fadd z0.d, z2.d, z3.d
+; CHECK-NEXT: fadd z0.d, z4.d, z0.d
+; CHECK-NEXT: ret
+ %interleaved.mask = tail call <vscale x 8 x i1> @llvm.vector.interleave4.nxv8i1(<vscale x 2 x i1> %mask, <vscale x 2 x i1> %mask, <vscale x 2 x i1> %mask, <vscale x 2 x i1> %mask)
+ %wide.masked.vec = tail call <vscale x 8 x double> @llvm.masked.load.nxv8f64.p0(ptr %p, i32 8, <vscale x 8 x i1> %interleaved.mask, <vscale x 8 x double> poison)
+ %strided.vec = tail call { <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave4.nxv8f64(<vscale x 8 x double> %wide.masked.vec)
+ %part1 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 0
+ %part2 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 1
+ %part3 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 2
+ %part4 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 3
+ %add1 = fadd <vscale x 2 x double> %part1, %part2
+ %add2 = fadd <vscale x 2 x double> %part3, %part4
+ %add3 = fadd <vscale x 2 x double> %add1, %add2
+ ret <vscale x 2 x double> %add3
+}
+
+
+define <vscale x 16 x i8> @foo_ld4_nxv16i8_mul_use_of_mask(<vscale x 16 x i1> %mask, ptr %p, ptr %p2) {
+; CHECK-LABEL: foo_ld4_nxv16i8_mul_use_of_mask:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ld4b { z0.b - z3.b }, p0/z, [x0]
+; CHECK-NEXT: zip2 p1.b, p0.b, p0.b
+; CHECK-NEXT: zip1 p0.b, p0.b, p0.b
+; CHECK-NEXT: zip2 p2.b, p1.b, p1.b
+; CHECK-NEXT: zip1 p1.b, p1.b, p1.b
+; CHECK-NEXT: zip2 p3.b, p0.b, p0.b
+; CHECK-NEXT: add z4.b, z0.b, z1.b
+; CHECK-NEXT: add z0.b, z2.b, z3.b
+; CHECK-NEXT: movi v1.2d, #0000000000000000
+; CHECK-NEXT: zip1 p0.b, p0.b, p0.b
+; CHECK-NEXT: add z0.b, z4.b, z0.b
+; CHECK-NEXT: st1b { z1.b }, p2, [x1, #3, mul vl]
+; CHECK-NEXT: st1b { z1.b }, p1, [x1, #2, mul vl]
+; CHECK-NEXT: st1b { z1.b }, p3, [x1, #1, mul vl]
+; CHECK-NEXT: st1b { z1.b }, p0, [x1]
+; CHECK-NEXT: ret
+ %interleaved.mask = tail call <vscale x 64 x i1> @llvm.vector.interleave4.nxv64i1(<vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask)
+ %wide.masked.vec = tail call <vscale x 64 x i8> @llvm.masked.load.nxv64i8.p0(ptr %p, i32 4, <vscale x 64 x i1> %interleaved.mask, <vscale x 64 x i8> poison)
+ %strided.vec = tail call { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave4.nxv64i8(<vscale x 64 x i8> %wide.masked.vec)
+ %part1 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 0
+ %part2 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 1
+ %part3 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 2
+ %part4 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 3
+ %add1 = add <vscale x 16 x i8> %part1, %part2
+ %add2 = add <vscale x 16 x i8> %part3, %part4
+ %add3 = add <vscale x 16 x i8> %add1, %add2
+ tail call void @llvm.masked.store.nxv64i8.p0(<vscale x 64 x i8> zeroinitializer, ptr %p2, i32 1, <vscale x 64 x i1> %interleaved.mask)
+ ret <vscale x 16 x i8> %add3
+}
+
+define <vscale x 16 x i8> @foo_ld4_nxv16i8_mask_of_interleaved_ones(ptr %p) {
+; CHECK-LABEL: foo_ld4_nxv16i8_mask_of_interleaved_ones:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.b
+; CHECK-NEXT: ld4b { z0.b - z3.b }, p0/z, [x0]
+; CHECK-NEXT: add z4.b, z0.b, z1.b
+; CHECK-NEXT: add z0.b, z2.b, z3.b
+; CHECK-NEXT: add z0.b, z4.b, z0.b
+; CHECK-NEXT: ret
+ %interleaved.mask = tail call <vscale x 64 x i1> @llvm.vector.interleave4.nxv64i1(<vscale x 16 x i1> splat(i1 1), <vscale x 16 x i1> splat(i1 1), <vscale x 16 x i1> splat(i1 1), <vscale x 16 x i1> splat(i1 1))
+ %wide.masked.vec = tail call <vscale x 64 x i8> @llvm.masked.load.nxv64i8.p0(ptr %p, i32 4, <vscale x 64 x i1> %interleaved.mask, <vscale x 64 x i8> poison)
+ %strided.vec = tail call { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave4.nxv64i8(<vscale x 64 x i8> %wide.masked.vec)
+ %part1 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 0
+ %part2 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 1
+ %part3 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 2
+ %part4 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 3
+ %add1 = add <vscale x 16 x i8> %part1, %part2
+ %add2 = add <vscale x 16 x i8> %part3, %pa...
[truncated]
|
%interleaved.mask = tail call <vscale x 64 x i1> @llvm.vector.interleave4.nxv64i1(<vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask)
%wide.masked.vec = tail call <vscale x 64 x i8> @llvm.masked.load.nxv64i8.p0(ptr %p, i32 1, <vscale x 64 x i1> %interleaved.mask, <vscale x 64 x i8> poison)
%wide.masked.vec.ext = zext <vscale x 64 x i8> %wide.masked.vec to <vscale x 64 x i16>
%strided.vec = tail call { <vscale x 16 x i16>, <vscale x 16 x i16>, <vscale x 16 x i16>, <vscale x 16 x i16> } @llvm.vector.deinterleave4.nxv64i16(<vscale x 64 x i16> %wide.masked.vec.ext)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just realised the signatures for the deinterleave4 intrinsics here and in the test below are incorrect. I'll fix them!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this test not just checking nothing happens for illegal type like foo_ld4_nxv8i8
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, this is a negative test because the input to the deinterleave does not come from a masked load, although I see it will get rejected early on anyway because the vector subparts are illegal. I'll create another test for this.
It's also different to foo_ld2_nxv8i8_exti16 because in that case the zext is folded into the masked load to create an extending masked load. Even though we currently reject deinterleaves with illegal types, I thought it's useful to have in case we add support for illegal types in future.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for patch Dave, mostly LGTM, but there is a fair few codepaths that are missing tests I've pointed out
%interleaved.mask = tail call <vscale x 64 x i1> @llvm.vector.interleave4.nxv64i1(<vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask)
%wide.masked.vec = tail call <vscale x 64 x i8> @llvm.masked.load.nxv64i8.p0(ptr %p, i32 1, <vscale x 64 x i1> %interleaved.mask, <vscale x 64 x i8> poison)
%wide.masked.vec.ext = zext <vscale x 64 x i8> %wide.masked.vec to <vscale x 64 x i16>
%strided.vec = tail call { <vscale x 16 x i16>, <vscale x 16 x i16>, <vscale x 16 x i16>, <vscale x 16 x i16> } @llvm.vector.deinterleave4.nxv64i16(<vscale x 64 x i16> %wide.masked.vec.ext)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this test not just checking nothing happens for illegal type like foo_ld4_nxv8i8
?
62780d8
to
b4bb182
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for updates, sorry for the noise pointing out all the untested codepaths in the first round of review, it's just there's a ton of constraints in the pattern that's pretty overwhelming so I went through and added unreachables to figure out what was what, and noted ones that didn't fire.
I've got some more suggestions to improve the tests, mostly minor, but the actual code changes LGTM, cheers!
%part1 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %deinterleaved.vec, 0 | ||
%part2 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %deinterleaved.vec, 1 | ||
%add = add <vscale x 16 x i8> %part1, %part2 | ||
ret <vscale x 16 x i8> %add |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the extracts + add aren't necessary for these tests, I think you should just return the tuple like we do in llvm/test/CodeGen/AArch64/sve-vector-deinterleave.ll
%add1 = add <vscale x 16 x i8> %part1, %part2 | ||
%add2 = add <vscale x 16 x i8> %part3, %part4 | ||
%add3 = add <vscale x 16 x i8> %add1, %add2 | ||
call void @llvm.masked.store.nxv64i8.p0(<vscale x 64 x i8> zeroinitializer, ptr %p2, i32 1, <vscale x 64 x i1> %interleaved.mask) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI there's an intrinsic llvm.fake.use for this purpose: https://godbolt.org/z/cnbMdKP6s
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! Didn't know about that.
ret <vscale x 16 x i8> %add1 | ||
} | ||
|
||
define <vscale x 16 x i8> @foo_ld2_nxv16i8_bad_mask(<vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask2, ptr %p, ptr %p2) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: a comment would be helpful. Same for tests below, I added comments as I went thru and figured out what they're doing which others may find useful.
define <vscale x 16 x i8> @foo_ld2_nxv16i8_bad_mask(<vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask2, ptr %p, ptr %p2) { | |
; Mask must be an interleave of identical masks. | |
define <vscale x 16 x i8> @foo_ld2_nxv16i8_bad_mask(<vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask2, ptr %p, ptr %p2) { |
ret <vscale x 16 x i8> %add1 | ||
} | ||
|
||
define <vscale x 16 x i8> @foo_ld4_nxv16i8_bad_mask2(<vscale x 32 x i1> %mask, ptr %p, ptr %p2) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
define <vscale x 16 x i8> @foo_ld4_nxv16i8_bad_mask2(<vscale x 32 x i1> %mask, ptr %p, ptr %p2) { | |
; Number of parts in mask interleave must match deinterleave. | |
define <vscale x 16 x i8> @foo_ld4_nxv16i8_bad_mask2(<vscale x 32 x i1> %mask, ptr %p, ptr %p2) { |
ret <vscale x 16 x i8> %add3 | ||
} | ||
|
||
define <vscale x 8 x i8> @foo_ld2_nxv8i8(<vscale x 8 x i1> %mask, ptr %p) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
define <vscale x 8 x i8> @foo_ld2_nxv8i8(<vscale x 8 x i1> %mask, ptr %p) { | |
; Vectors must be at least 128-bit. | |
define <vscale x 8 x i8> @foo_ld2_nxv8i8(<vscale x 8 x i1> %mask, ptr %p) { |
ret <vscale x 8 x i8> %add1 | ||
} | ||
|
||
define <vscale x 16 x i8> @foo_ld2_nxv16i8_bad_passthru(<vscale x 16 x i1> %mask, ptr %p) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
define <vscale x 16 x i8> @foo_ld2_nxv16i8_bad_passthru(<vscale x 16 x i1> %mask, ptr %p) { | |
; Passthru must be uninitialized. | |
define <vscale x 16 x i8> @foo_ld2_nxv16i8_bad_passthru(<vscale x 16 x i1> %mask, ptr %p) { |
} | ||
|
||
|
||
define <vscale x 8 x i16> @foo_deinterleave2_not_load(<vscale x 16 x i16> %vec) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there's already a test for this in llvm/test/CodeGen/AArch64/sve-vector-deinterleave.ll
, this can be removed. Apologies if this is one you added based on my last set of comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed.
llvm/test/CodeGen/AArch64/scalable_masked_deinterleaved_loads.ll
Outdated
Show resolved
Hide resolved
llvm/test/CodeGen/AArch64/scalable_masked_deinterleaved_loads.ll
Outdated
Show resolved
Hide resolved
For IR like this: %mask = ... @llvm.vector.interleave2(<vscale x 16 x i1> %a, <vscale x 16 x i1> %a) %vec = ... @llvm.masked.load(..., <vscale x 32 x i1> %mask, ...) %dvec = ... @llvm.vector.deinterleave2(<vscale x 32 x i8> %vec) where we're deinterleaving a wide masked load of the supported type and with an interleaved mask we can lower this directly to a ld2b instruction. Similarly we can also support other variants of ld2 and ld4. This PR adds a DAG combine to spot such patterns and lower to ld2X or ld4X variants accordingly, whilst being careful to ensure the masked load is only used by the deinterleave intrinsic.
b4bb182
to
19987ea
Compare
updates LGTM, cheers!
auto *SplatVal = dyn_cast<ConstantSDNode>(Mask->getOperand(0));
if (!SplatVal || SplatVal->getZExtValue() != 1)
  return SDValue();
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does the value being splatted matter? I think you can drop this logic and just allow all splat values?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wasn't sure if we should perform the DAG combine if the mask is all-false because we will then generate an intrinsic that we know is a nop and also the DAGCombiner will not understand and not optimise away. I thought that leaving the original code left a greater chance of the whole lot being deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would expect the all-false case to be caught by instcombine so there would be no need for the code generator to worry about it.
ElementCount EC = Mask.getValueType().getVectorElementCount();
assert((EC.getKnownMinValue() % NumParts) == 0 &&
       "Expected element count divisible by number of parts");
EC = ElementCount::getScalable(EC.getKnownMinValue() / NumParts);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
EC = ElementCount::getScalable(EC.getKnownMinValue() / NumParts); | |
EC = EC.divideCoefficientBy(NumParts); |
assert((EC.getKnownMinValue() % NumParts) == 0 && | ||
"Expected element count divisible by number of parts"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
assert((EC.getKnownMinValue() % NumParts) == 0 && | |
"Expected element count divisible by number of parts"); | |
assert((EC.isKnownMultipleOf(NumParts) && | |
"Expected element count divisible by number of parts"); |
Although based on my "do we need to strengthen the type checking" comment, this assert might be unnecessary.
@@ -27015,6 +27016,120 @@ performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, | |||
return NVCAST; | |||
} | |||
|
|||
static SDValue performVectorDeinterleaveCombine( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given the combine is only useful prior to type legalisation it is worth having an early bail out for when it's run after that point.
// we usually generate shuffles instead. | ||
unsigned MinNumElements = SubVecTy.getVectorMinNumElements(); | ||
if (!SubVecTy.isScalableVT() || | ||
SubVecTy.getSizeInBits().getKnownMinValue() != 128 || MinNumElements == 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to strengthen the type checking here because being before type legalisation there must be many vscale*128-bit sized types that we'll not support. Perhaps replace the MinNumElements
check with !isTypeLegal(SubVecTy)
?
// At the moment we're unlikely to see a fixed-width vector deinterleave as | ||
// we usually generate shuffles instead. | ||
unsigned MinNumElements = SubVecTy.getVectorMinNumElements(); | ||
if (!SubVecTy.isScalableVT() || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (!SubVecTy.isScalableVT() || | |
if (!SubVecTy.isScalableVector() || |
MaskedLoad->getExtensionType() != ISD::NON_EXTLOAD ||
MaskedLoad->getAddressingMode() != ISD::UNINDEXED ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MaskedLoad->getExtensionType() != ISD::NON_EXTLOAD || | |
MaskedLoad->getAddressingMode() != ISD::UNINDEXED || | |
!MaskedLoad->isSimple() || !MaskedLoad->isNormalLoad() || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isNormalLoad
doesn't work for masked loads, so I've added a new isNormalMaskedLoad
auto *Idx = cast<ConstantSDNode>(OpI->getOperand(1)); | ||
if (Idx->getZExtValue() != (I * MinNumElements)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto *Idx = cast<ConstantSDNode>(OpI->getOperand(1)); | |
if (Idx->getZExtValue() != (I * MinNumElements)) | |
if (OpI->getConstantOperandVal(1) != (I * MinNumElements)) |
for (unsigned I = 1; I < NumParts; I++) { | ||
if (InterleaveOp->getOperand(I) != InterleaveOp->getOperand(0)) | ||
return SDValue(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for (unsigned I = 1; I < NumParts; I++) { | |
if (InterleaveOp->getOperand(I) != InterleaveOp->getOperand(0)) | |
return SDValue(); | |
} | |
if (!llvm::all_equal(InterleaveOp->op_values())) | |
return SDValue(); |
if (ConcatOp.getResNo() != I || | ||
ConcatOp.getNode() != InterleaveOp.getNode()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you make InterleaveOp
an SDNode (which is how you're using it anyway for several of the cases by using ->
) you can simplify this to if (Mask.getOperand(I) == SDValue(InterleaveOp, I))
For IR like this:
%mask = ... @llvm.vector.interleave2(<vscale x 16 x i1> %a, <vscale x 16 x i1> %a)
%vec = ... @llvm.masked.load(..., <vscale x 32 x i1> %mask, ...)
%dvec = ... @llvm.vector.deinterleave2(<vscale x 32 x i8> %vec)
where we're deinterleaving a wide masked load of the supported type
and with an interleaved mask we can lower this directly to a ld2b
instruction. Similarly we can also support other variants of ld2
and ld4.
This PR adds a DAG combine to spot such patterns and lower to ld2X
or ld4X variants accordingly, whilst being careful to ensure the
masked load is only used by the deinterleave intrinsic.