
Conversation

david-arm (Contributor)

For IR like this:

%mask = ... @llvm.vector.interleave2(<vscale x 16 x i1> %a, <vscale x 16 x i1> %a)
%vec = ... @llvm.masked.load(..., <vscale x 32 x i1> %mask, ...)
%dvec = ... @llvm.vector.deinterleave2(<vscale x 32 x i8> %vec)

where we're deinterleaving a wide masked load of a supported type
with an interleaved mask, we can lower this directly to an ld2b
instruction. Similarly, we can also support the other variants of
ld2 and ld4.

This PR adds a DAG combine to spot such patterns and lower to ld2X
or ld4X variants accordingly, whilst being careful to ensure the
masked load is only used by the deinterleave intrinsic.
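As an illustration of the data movement involved (not part of the patch — a minimal scalar model with invented names), a 2-way masked deinterleaving load with a zeroing predicate, which is what an ld2b performs, can be sketched as:

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Scalar model of a 2-way masked deinterleaving load (SVE ld2-style):
// lanes 2*i and 2*i+1 of the interleaved buffer feed element i of the
// two results. Because the wide mask is an interleave of one narrow
// mask with itself, both lanes of a pair share one predicate bit, and
// inactive lanes produce zero (matching a zeroing predicate, p/z).
static std::pair<std::vector<int>, std::vector<int>>
maskedDeinterleave2(const std::vector<int> &buf,
                    const std::vector<bool> &mask) {
  size_t n = buf.size() / 2;
  std::vector<int> even(n), odd(n);
  for (size_t i = 0; i < n; ++i) {
    even[i] = mask[i] ? buf[2 * i] : 0;
    odd[i] = mask[i] ? buf[2 * i + 1] : 0;
  }
  return {even, odd};
}
```

The combine matches exactly this shape in the DAG (interleaved mask feeding a masked load feeding a deinterleave) and replaces it with the structured-load intrinsic, instead of doing a wide contiguous load followed by unzip operations.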

@llvmbot (Member)

llvmbot commented Aug 19, 2025

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-aarch64

Author: David Sherwood (david-arm)



Patch is 39.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/154338.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+118)
  • (added) llvm/test/CodeGen/AArch64/masked_deinterleaved_loads.ll (+542)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 63a85faf344c4..cff28e006b22e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1178,6 +1178,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(ISD::SCALAR_TO_VECTOR);
 
   setTargetDAGCombine(ISD::SHL);
+  setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE);
 
   // In case of strict alignment, avoid an excessive number of byte wide stores.
   MaxStoresPerMemsetOptSize = 8;
@@ -27010,6 +27011,121 @@ performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   return NVCAST;
 }
 
+static SDValue performVectorDeinterleaveCombine(
+    SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
+  unsigned NumParts = N->getNumOperands();
+  if (NumParts != 2 && NumParts != 4)
+    return SDValue();
+
+  EVT SubVecTy = N->getValueType(0);
+
+  // At the moment we're unlikely to see a fixed-width vector deinterleave as
+  // we usually generate shuffles instead.
+  unsigned MinNumElements = SubVecTy.getVectorMinNumElements();
+  if (!SubVecTy.isScalableVT() ||
+      SubVecTy.getSizeInBits().getKnownMinValue() != 128 || MinNumElements == 1)
+    return SDValue();
+
+  // Make sure each input operand is the correct extract_subvector of the same
+  // wider vector.
+  SDValue Op0 = N->getOperand(0);
+  for (unsigned I = 0; I < NumParts; I++) {
+    SDValue OpI = N->getOperand(I);
+    if (OpI->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+        OpI->getOperand(0) != Op0->getOperand(0))
+      return SDValue();
+    auto *Idx = cast<ConstantSDNode>(OpI->getOperand(1));
+    if (Idx->getZExtValue() != (I * MinNumElements))
+      return SDValue();
+  }
+
+  // Normal loads are currently already handled by the InterleavedAccessPass so
+  // we don't expect to see them here. Bail out if the masked load has an
+  // unexpected number of uses, since we want to avoid a situation where we have
+  // both deinterleaving loads and normal loads in the same block. Also, discard
+  // masked loads that are extending, indexed, have an unexpected offset or have
+  // an unsupported passthru value until we find a valid use case.
+  auto MaskedLoad = dyn_cast<MaskedLoadSDNode>(Op0->getOperand(0));
+  if (!MaskedLoad || !MaskedLoad->hasNUsesOfValue(NumParts, 0) ||
+      MaskedLoad->getExtensionType() != ISD::NON_EXTLOAD ||
+      MaskedLoad->getAddressingMode() != ISD::UNINDEXED ||
+      !MaskedLoad->getOffset().isUndef() ||
+      (!MaskedLoad->getPassThru()->isUndef() &&
+       !isZerosVector(MaskedLoad->getPassThru().getNode())))
+    return SDValue();
+
+  // Now prove that the mask is an interleave of identical masks.
+  SDValue Mask = MaskedLoad->getMask();
+  if (Mask->getOpcode() != ISD::SPLAT_VECTOR &&
+      Mask->getOpcode() != ISD::CONCAT_VECTORS)
+    return SDValue();
+
+  // TODO: Do we need to check the element type of the mask?
+  SDValue NarrowMask;
+  SDLoc DL(N);
+  if (Mask->getOpcode() == ISD::CONCAT_VECTORS) {
+    if (Mask->getNumOperands() != NumParts)
+      return SDValue();
+
+    // We should be concatenating each sequential result from a
+    // VECTOR_INTERLEAVE.
+    SDValue InterleaveOp = Mask->getOperand(0);
+    if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
+        InterleaveOp->getNumOperands() != NumParts)
+      return SDValue();
+
+    for (unsigned I = 0; I < NumParts; I++) {
+      SDValue ConcatOp = Mask->getOperand(I);
+      if (ConcatOp.getResNo() != I ||
+          ConcatOp.getNode() != InterleaveOp.getNode())
+        return SDValue();
+    }
+
+    // Make sure the inputs to the vector interleave are identical.
+    for (unsigned I = 1; I < NumParts; I++) {
+      if (InterleaveOp->getOperand(I) != InterleaveOp->getOperand(0))
+        return SDValue();
+    }
+
+    NarrowMask = InterleaveOp->getOperand(0);
+  } else { // ISD::SPLAT_VECTOR
+    auto *SplatVal = dyn_cast<ConstantSDNode>(Mask->getOperand(0));
+    if (!SplatVal || SplatVal->getZExtValue() != 1)
+      return SDValue();
+    ElementCount EC = Mask.getValueType().getVectorElementCount();
+    assert((EC.getKnownMinValue() % NumParts) == 0 &&
+           "Expected element count divisible by number of parts");
+    EC = ElementCount::getScalable(EC.getKnownMinValue() / NumParts);
+    NarrowMask =
+        DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::getVectorVT(MVT::i1, EC),
+                    DAG.getConstant(1, DL, MVT::i1));
+  }
+
+  const Intrinsic::ID IID = NumParts == 2 ? Intrinsic::aarch64_sve_ld2_sret
+                                          : Intrinsic::aarch64_sve_ld4_sret;
+  SDValue NewLdOps[] = {MaskedLoad->getChain(),
+                        DAG.getConstant(IID, DL, MVT::i32), NarrowMask,
+                        MaskedLoad->getBasePtr()};
+  SDValue Res;
+  if (NumParts == 2)
+    Res = DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL,
+                      {SubVecTy, SubVecTy, MVT::Other}, NewLdOps);
+  else
+    Res = DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL,
+                      {SubVecTy, SubVecTy, SubVecTy, SubVecTy, MVT::Other},
+                      NewLdOps);
+
+  // We can now generate a structured load!
+  SmallVector<SDValue, 4> ResOps(NumParts);
+  for (unsigned Idx = 0; Idx < NumParts; Idx++)
+    ResOps[Idx] = SDValue(Res.getNode(), Idx);
+
+  // Replace uses of the original chain result with the new chain result.
+  DAG.ReplaceAllUsesOfValueWith(SDValue(MaskedLoad, 1),
+                                SDValue(Res.getNode(), NumParts));
+  return DCI.CombineTo(N, ResOps, false);
+}
+
 /// If the operand is a bitwise AND with a constant RHS, and the shift has a
 /// constant RHS and is the only use, we can pull it out of the shift, i.e.
 ///
@@ -27078,6 +27194,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   default:
     LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
     break;
+  case ISD::VECTOR_DEINTERLEAVE:
+    return performVectorDeinterleaveCombine(N, DCI, DAG);
   case ISD::VECREDUCE_AND:
   case ISD::VECREDUCE_OR:
   case ISD::VECREDUCE_XOR:
diff --git a/llvm/test/CodeGen/AArch64/masked_deinterleaved_loads.ll b/llvm/test/CodeGen/AArch64/masked_deinterleaved_loads.ll
new file mode 100644
index 0000000000000..e7d9cb353941e
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/masked_deinterleaved_loads.ll
@@ -0,0 +1,542 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
+
+define <vscale x 16 x i8> @foo_ld2_nxv16i8(<vscale x 16 x i1> %mask, ptr %p) {
+; CHECK-LABEL: foo_ld2_nxv16i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ld2b { z0.b, z1.b }, p0/z, [x0]
+; CHECK-NEXT:    add z0.b, z0.b, z1.b
+; CHECK-NEXT:    ret
+  %interleaved.mask = tail call <vscale x 32 x i1> @llvm.vector.interleave2.nxv32i1(<vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask)
+  %wide.masked.vec = tail call <vscale x 32 x i8> @llvm.masked.load.nxv32i8.p0(ptr %p, i32 1, <vscale x 32 x i1> %interleaved.mask, <vscale x 32 x i8> poison)
+  %strided.vec = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %wide.masked.vec)
+  %part1 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 0
+  %part2 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 1
+  %add = add <vscale x 16 x i8> %part1, %part2
+  ret <vscale x 16 x i8> %add
+}
+
+define <vscale x 8 x i16> @foo_ld2_nxv8i16(<vscale x 8 x i1> %mask, ptr %p) {
+; CHECK-LABEL: foo_ld2_nxv8i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ld2h { z0.h, z1.h }, p0/z, [x0]
+; CHECK-NEXT:    add z0.h, z0.h, z1.h
+; CHECK-NEXT:    ret
+  %interleaved.mask = tail call <vscale x 16 x i1> @llvm.vector.interleave2.nxv16i1(<vscale x 8 x i1> %mask, <vscale x 8 x i1> %mask)
+  %wide.masked.vec = tail call <vscale x 16 x i16> @llvm.masked.load.nxv16i16.p0(ptr %p, i32 2, <vscale x 16 x i1> %interleaved.mask, <vscale x 16 x i16> poison)
+  %strided.vec = tail call { <vscale x 8 x i16>, <vscale x 8 x i16> } @llvm.vector.deinterleave2.nxv16i16(<vscale x 16 x i16> %wide.masked.vec)
+  %part1 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16> } %strided.vec, 0
+  %part2 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16> } %strided.vec, 1
+  %add = add <vscale x 8 x i16> %part1, %part2
+  ret <vscale x 8 x i16> %add
+}
+
+define <vscale x 4 x float> @foo_ld2_nxv4f32(<vscale x 4 x i1> %mask, ptr %p) {
+; CHECK-LABEL: foo_ld2_nxv4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ld2w { z0.s, z1.s }, p0/z, [x0]
+; CHECK-NEXT:    fadd z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %interleaved.mask = tail call <vscale x 8 x i1> @llvm.vector.interleave2.nxv8i1(<vscale x 4 x i1> %mask, <vscale x 4 x i1> %mask)
+  %wide.masked.vec = tail call <vscale x 8 x float> @llvm.masked.load.nxv8f32.p0(ptr %p, i32 4, <vscale x 8 x i1> %interleaved.mask, <vscale x 8 x float> poison)
+  %strided.vec = tail call { <vscale x 4 x float>, <vscale x 4 x float> } @llvm.vector.deinterleave2.nxv16f32(<vscale x 8 x float> %wide.masked.vec)
+  %part1 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } %strided.vec, 0
+  %part2 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } %strided.vec, 1
+  %add = fadd <vscale x 4 x float> %part1, %part2
+  ret <vscale x 4 x float> %add
+}
+
+define <vscale x 2 x double> @foo_ld2_nxv2f64(<vscale x 2 x i1> %mask, ptr %p) {
+; CHECK-LABEL: foo_ld2_nxv2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ld2d { z0.d, z1.d }, p0/z, [x0]
+; CHECK-NEXT:    fadd z0.d, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %interleaved.mask = tail call <vscale x 4 x i1> @llvm.vector.interleave2.nxv4i1(<vscale x 2 x i1> %mask, <vscale x 2 x i1> %mask)
+  %wide.masked.vec = tail call <vscale x 4 x double> @llvm.masked.load.nxv4f64.p0(ptr %p, i32 8, <vscale x 4 x i1> %interleaved.mask, <vscale x 4 x double> poison)
+  %strided.vec = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %wide.masked.vec)
+  %part1 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 0
+  %part2 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 1
+  %add = fadd <vscale x 2 x double> %part1, %part2
+  ret <vscale x 2 x double> %add
+}
+
+define <vscale x 16 x i8> @foo_ld4_nxv16i8(<vscale x 16 x i1> %mask, ptr %p) {
+; CHECK-LABEL: foo_ld4_nxv16i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ld4b { z0.b - z3.b }, p0/z, [x0]
+; CHECK-NEXT:    add z4.b, z0.b, z1.b
+; CHECK-NEXT:    add z0.b, z2.b, z3.b
+; CHECK-NEXT:    add z0.b, z4.b, z0.b
+; CHECK-NEXT:    ret
+  %interleaved.mask = tail call <vscale x 64 x i1> @llvm.vector.interleave4.nxv64i1(<vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask)
+  %wide.masked.vec = tail call <vscale x 64 x i8> @llvm.masked.load.nxv64i8.p0(ptr %p, i32 1, <vscale x 64 x i1> %interleaved.mask, <vscale x 64 x i8> poison)
+  %strided.vec = tail call { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave4.nxv64i8(<vscale x 64 x i8> %wide.masked.vec)
+  %part1 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 0
+  %part2 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 1
+  %part3 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 2
+  %part4 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 3
+  %add1 = add <vscale x 16 x i8> %part1, %part2
+  %add2 = add <vscale x 16 x i8> %part3, %part4
+  %add3 = add <vscale x 16 x i8> %add1, %add2
+  ret <vscale x 16 x i8> %add3
+}
+
+define <vscale x 8 x i16> @foo_ld4_nxv8i16(<vscale x 8 x i1> %mask, ptr %p) {
+; CHECK-LABEL: foo_ld4_nxv8i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ld4h { z0.h - z3.h }, p0/z, [x0]
+; CHECK-NEXT:    add z4.h, z0.h, z1.h
+; CHECK-NEXT:    add z0.h, z2.h, z3.h
+; CHECK-NEXT:    add z0.h, z4.h, z0.h
+; CHECK-NEXT:    ret
+  %interleaved.mask = tail call <vscale x 32 x i1> @llvm.vector.interleave4.nxv32i1(<vscale x 8 x i1> %mask, <vscale x 8 x i1> %mask, <vscale x 8 x i1> %mask, <vscale x 8 x i1> %mask)
+  %wide.masked.vec = tail call <vscale x 32 x i16> @llvm.masked.load.nxv32i16.p0(ptr %p, i32 2, <vscale x 32 x i1> %interleaved.mask, <vscale x 32 x i16> poison)
+  %strided.vec = tail call { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } @llvm.vector.deinterleave4.nxv32i16(<vscale x 32 x i16> %wide.masked.vec)
+  %part1 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %strided.vec, 0
+  %part2 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %strided.vec, 1
+  %part3 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %strided.vec, 2
+  %part4 = extractvalue { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } %strided.vec, 3
+  %add1 = add <vscale x 8 x i16> %part1, %part2
+  %add2 = add <vscale x 8 x i16> %part3, %part4
+  %add3 = add <vscale x 8 x i16> %add1, %add2
+  ret <vscale x 8 x i16> %add3
+}
+
+define <vscale x 4 x float> @foo_ld4_nxv4f32(<vscale x 4 x i1> %mask, ptr %p) {
+; CHECK-LABEL: foo_ld4_nxv4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ld4w { z0.s - z3.s }, p0/z, [x0]
+; CHECK-NEXT:    fadd z4.s, z0.s, z1.s
+; CHECK-NEXT:    fadd z0.s, z2.s, z3.s
+; CHECK-NEXT:    fadd z0.s, z4.s, z0.s
+; CHECK-NEXT:    ret
+  %interleaved.mask = tail call <vscale x 16 x i1> @llvm.vector.interleave4.nxv16i1(<vscale x 4 x i1> %mask, <vscale x 4 x i1> %mask, <vscale x 4 x i1> %mask, <vscale x 4 x i1> %mask)
+  %wide.masked.vec = tail call <vscale x 16 x float> @llvm.masked.load.nxv16f32.p0(ptr %p, i32 4, <vscale x 16 x i1> %interleaved.mask, <vscale x 16 x float> poison)
+  %strided.vec = tail call { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } @llvm.vector.deinterleave4.nxv16f32(<vscale x 16 x float> %wide.masked.vec)
+  %part1 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %strided.vec, 0
+  %part2 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %strided.vec, 1
+  %part3 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %strided.vec, 2
+  %part4 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float> } %strided.vec, 3
+  %add1 = fadd <vscale x 4 x float> %part1, %part2
+  %add2 = fadd <vscale x 4 x float> %part3, %part4
+  %add3 = fadd <vscale x 4 x float> %add1, %add2
+  ret <vscale x 4 x float> %add3
+}
+
+define <vscale x 2 x double> @foo_ld4_nxv2f64(<vscale x 2 x i1> %mask, ptr %p) {
+; CHECK-LABEL: foo_ld4_nxv2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ld4d { z0.d - z3.d }, p0/z, [x0]
+; CHECK-NEXT:    fadd z4.d, z0.d, z1.d
+; CHECK-NEXT:    fadd z0.d, z2.d, z3.d
+; CHECK-NEXT:    fadd z0.d, z4.d, z0.d
+; CHECK-NEXT:    ret
+  %interleaved.mask = tail call <vscale x 8 x i1> @llvm.vector.interleave4.nxv8i1(<vscale x 2 x i1> %mask, <vscale x 2 x i1> %mask, <vscale x 2 x i1> %mask, <vscale x 2 x i1> %mask)
+  %wide.masked.vec = tail call <vscale x 8 x double> @llvm.masked.load.nxv8f64.p0(ptr %p, i32 8, <vscale x 8 x i1> %interleaved.mask, <vscale x 8 x double> poison)
+  %strided.vec = tail call { <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave4.nxv8f64(<vscale x 8 x double> %wide.masked.vec)
+  %part1 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 0
+  %part2 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 1
+  %part3 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 2
+  %part4 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 3
+  %add1 = fadd <vscale x 2 x double> %part1, %part2
+  %add2 = fadd <vscale x 2 x double> %part3, %part4
+  %add3 = fadd <vscale x 2 x double> %add1, %add2
+  ret <vscale x 2 x double> %add3
+}
+
+
+define <vscale x 16 x i8> @foo_ld4_nxv16i8_mul_use_of_mask(<vscale x 16 x i1> %mask, ptr %p, ptr %p2) {
+; CHECK-LABEL: foo_ld4_nxv16i8_mul_use_of_mask:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ld4b { z0.b - z3.b }, p0/z, [x0]
+; CHECK-NEXT:    zip2 p1.b, p0.b, p0.b
+; CHECK-NEXT:    zip1 p0.b, p0.b, p0.b
+; CHECK-NEXT:    zip2 p2.b, p1.b, p1.b
+; CHECK-NEXT:    zip1 p1.b, p1.b, p1.b
+; CHECK-NEXT:    zip2 p3.b, p0.b, p0.b
+; CHECK-NEXT:    add z4.b, z0.b, z1.b
+; CHECK-NEXT:    add z0.b, z2.b, z3.b
+; CHECK-NEXT:    movi v1.2d, #0000000000000000
+; CHECK-NEXT:    zip1 p0.b, p0.b, p0.b
+; CHECK-NEXT:    add z0.b, z4.b, z0.b
+; CHECK-NEXT:    st1b { z1.b }, p2, [x1, #3, mul vl]
+; CHECK-NEXT:    st1b { z1.b }, p1, [x1, #2, mul vl]
+; CHECK-NEXT:    st1b { z1.b }, p3, [x1, #1, mul vl]
+; CHECK-NEXT:    st1b { z1.b }, p0, [x1]
+; CHECK-NEXT:    ret
+  %interleaved.mask = tail call <vscale x 64 x i1> @llvm.vector.interleave4.nxv64i1(<vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask)
+  %wide.masked.vec = tail call <vscale x 64 x i8> @llvm.masked.load.nxv64i8.p0(ptr %p, i32 4, <vscale x 64 x i1> %interleaved.mask, <vscale x 64 x i8> poison)
+  %strided.vec = tail call { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave4.nxv64i8(<vscale x 64 x i8> %wide.masked.vec)
+  %part1 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 0
+  %part2 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 1
+  %part3 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 2
+  %part4 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 3
+  %add1 = add <vscale x 16 x i8> %part1, %part2
+  %add2 = add <vscale x 16 x i8> %part3, %part4
+  %add3 = add <vscale x 16 x i8> %add1, %add2
+  tail call void @llvm.masked.store.nxv64i8.p0(<vscale x 64 x i8> zeroinitializer, ptr %p2, i32 1, <vscale x 64 x i1> %interleaved.mask)
+  ret <vscale x 16 x i8> %add3
+}
+
+define <vscale x 16 x i8> @foo_ld4_nxv16i8_mask_of_interleaved_ones(ptr %p) {
+; CHECK-LABEL: foo_ld4_nxv16i8_mask_of_interleaved_ones:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    ld4b { z0.b - z3.b }, p0/z, [x0]
+; CHECK-NEXT:    add z4.b, z0.b, z1.b
+; CHECK-NEXT:    add z0.b, z2.b, z3.b
+; CHECK-NEXT:    add z0.b, z4.b, z0.b
+; CHECK-NEXT:    ret
+  %interleaved.mask = tail call <vscale x 64 x i1> @llvm.vector.interleave4.nxv64i1(<vscale x 16 x i1> splat(i1 1), <vscale x 16 x i1> splat(i1 1), <vscale x 16 x i1> splat(i1 1), <vscale x 16 x i1> splat(i1 1))
+  %wide.masked.vec = tail call <vscale x 64 x i8> @llvm.masked.load.nxv64i8.p0(ptr %p, i32 4, <vscale x 64 x i1> %interleaved.mask, <vscale x 64 x i8> poison)
+  %strided.vec = tail call { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave4.nxv64i8(<vscale x 64 x i8> %wide.masked.vec)
+  %part1 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 0
+  %part2 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 1
+  %part3 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 2
+  %part4 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } %strided.vec, 3
+  %add1 = add <vscale x 16 x i8> %part1, %part2
+  %add2 = add <vscale x 16 x i8> %part3, %pa...
[truncated]

%interleaved.mask = tail call <vscale x 64 x i1> @llvm.vector.interleave4.nxv64i1(<vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask)
%wide.masked.vec = tail call <vscale x 64 x i8> @llvm.masked.load.nxv64i8.p0(ptr %p, i32 1, <vscale x 64 x i1> %interleaved.mask, <vscale x 64 x i8> poison)
%wide.masked.vec.ext = zext <vscale x 64 x i8> %wide.masked.vec to <vscale x 64 x i16>
%strided.vec = tail call { <vscale x 16 x i16>, <vscale x 16 x i16>, <vscale x 16 x i16>, <vscale x 16 x i16> } @llvm.vector.deinterleave4.nxv64i8(<vscale x 64 x i16> %wide.masked.vec.ext)
Contributor Author

I just realised the signatures for the deinterleave4 intrinsics here and in the test below are incorrect. I'll fix them!

Collaborator

is this test not just checking nothing happens for illegal type like foo_ld4_nxv8i8?

Contributor Author

Well, this is a negative test because the input to the deinterleave does not come from a masked load, although I see it will get rejected early on anyway because the vector subparts are illegal. I'll create another test for this.

It's also different to foo_ld2_nxv8i8_exti16 because in that case the zext is folded into the masked load to create an extending masked load. Even though we currently reject deinterleaves with illegal types, I thought it's useful to have in case we add support for illegal types in future.

@c-rhodes (Collaborator) left a comment

thanks for the patch Dave, mostly LGTM, but there are a fair few codepaths missing tests, which I've pointed out

@david-arm david-arm force-pushed the masked_deinterleave branch from 62780d8 to b4bb182 Compare August 27, 2025 15:36
@c-rhodes (Collaborator) left a comment

thanks for the updates, and sorry for the noise pointing out all the untested codepaths in the first round of review. There are a ton of constraints in the pattern, which is pretty overwhelming, so I went through and added unreachables to figure out what was what, and noted the ones that didn't fire.

I've got some more suggestions to improve the tests, mostly minor, but the actual code changes LGTM, cheers!

Comment on lines 13 to 16
%part1 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %deinterleaved.vec, 0
%part2 = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %deinterleaved.vec, 1
%add = add <vscale x 16 x i8> %part1, %part2
ret <vscale x 16 x i8> %add
Collaborator

the extracts + add aren't necessary for these tests, I think you should just return the tuple like we do in llvm/test/CodeGen/AArch64/sve-vector-deinterleave.ll

%add1 = add <vscale x 16 x i8> %part1, %part2
%add2 = add <vscale x 16 x i8> %part3, %part4
%add3 = add <vscale x 16 x i8> %add1, %add2
call void @llvm.masked.store.nxv64i8.p0(<vscale x 64 x i8> zeroinitializer, ptr %p2, i32 1, <vscale x 64 x i1> %interleaved.mask)
Collaborator

FYI there's an intrinsic llvm.fake.use for this purpose: https://godbolt.org/z/cnbMdKP6s

Contributor Author

Nice! Didn't know about that.

ret <vscale x 16 x i8> %add1
}

define <vscale x 16 x i8> @foo_ld2_nxv16i8_bad_mask(<vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask2, ptr %p, ptr %p2) {
Collaborator

nit: a comment would be helpful. Same for the tests below; I added comments as I went through and figured out what they're doing, which others may find useful.

Suggested change
define <vscale x 16 x i8> @foo_ld2_nxv16i8_bad_mask(<vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask2, ptr %p, ptr %p2) {
; Mask must be an interleave of identical masks.
define <vscale x 16 x i8> @foo_ld2_nxv16i8_bad_mask(<vscale x 16 x i1> %mask, <vscale x 16 x i1> %mask2, ptr %p, ptr %p2) {

ret <vscale x 16 x i8> %add1
}

define <vscale x 16 x i8> @foo_ld4_nxv16i8_bad_mask2(<vscale x 32 x i1> %mask, ptr %p, ptr %p2) {
Collaborator

Suggested change
define <vscale x 16 x i8> @foo_ld4_nxv16i8_bad_mask2(<vscale x 32 x i1> %mask, ptr %p, ptr %p2) {
; Number of parts in mask interleave must match deinterleave.
define <vscale x 16 x i8> @foo_ld4_nxv16i8_bad_mask2(<vscale x 32 x i1> %mask, ptr %p, ptr %p2) {

ret <vscale x 16 x i8> %add3
}

define <vscale x 8 x i8> @foo_ld2_nxv8i8(<vscale x 8 x i1> %mask, ptr %p) {
Collaborator

Suggested change
define <vscale x 8 x i8> @foo_ld2_nxv8i8(<vscale x 8 x i1> %mask, ptr %p) {
; Vectors must be at least 128-bit.
define <vscale x 8 x i8> @foo_ld2_nxv8i8(<vscale x 8 x i1> %mask, ptr %p) {

ret <vscale x 8 x i8> %add1
}

define <vscale x 16 x i8> @foo_ld2_nxv16i8_bad_passthru(<vscale x 16 x i1> %mask, ptr %p) {
Collaborator

Suggested change
define <vscale x 16 x i8> @foo_ld2_nxv16i8_bad_passthru(<vscale x 16 x i1> %mask, ptr %p) {
; Passthru must be uninitialized.
define <vscale x 16 x i8> @foo_ld2_nxv16i8_bad_passthru(<vscale x 16 x i1> %mask, ptr %p) {

}


define <vscale x 8 x i16> @foo_deinterleave2_not_load(<vscale x 16 x i16> %vec) {
Collaborator

there's already a test for this in llvm/test/CodeGen/AArch64/sve-vector-deinterleave.ll, this can be removed. Apologies if this is one you added based on my last set of comments.

Contributor Author

Removed.

@david-arm david-arm force-pushed the masked_deinterleave branch from b4bb182 to 19987ea Compare August 29, 2025 09:00
@c-rhodes (Collaborator)

updates LGTM, cheers!

Comment on lines 27096 to 27098
auto *SplatVal = dyn_cast<ConstantSDNode>(Mask->getOperand(0));
if (!SplatVal || SplatVal->getZExtValue() != 1)
return SDValue();
Collaborator

Why does the value being splatted matter? I think you can drop this logic and just allow all splat values?

Contributor Author

I wasn't sure if we should perform the DAG combine when the mask is all-false, because we would then generate an intrinsic that we know is a no-op, but which the DAGCombiner does not understand and so cannot optimise away. I thought that leaving the original code gave a greater chance of the whole lot being deleted.

Collaborator

I would expect the all-false case to be caught by instcombine so there would be no need for the code generator to worry about it.
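For context on this exchange, the masked-load semantics involved can be modelled in scalar code (a hedged sketch with an invented helper name, not LLVM code): inactive lanes take the passthru value, so an all-false mask reduces the load to "return the passthru", which is exactly the no-op instcombine can fold away. It also shows why the combine only accepts an undef or all-zero passthru, since a zeroing ld2/ld4 forces inactive lanes to zero.

```cpp
#include <cstddef>
#include <vector>

// Scalar model of llvm.masked.load semantics: active lanes read memory,
// inactive lanes take the corresponding passthru element. With an
// all-false mask the result is just the passthru, independent of memory.
static std::vector<int> maskedLoad(const std::vector<int> &mem,
                                   const std::vector<bool> &mask,
                                   const std::vector<int> &passthru) {
  std::vector<int> out(mem.size());
  for (size_t i = 0; i < mem.size(); ++i)
    out[i] = mask[i] ? mem[i] : passthru[i];
  return out;
}
```

Under this model, replacing the load with a zeroing structured load is only value-preserving when the passthru contributes nothing (undef) or already matches the zeroing behaviour (all-zero), which is the condition the combine checks.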

ElementCount EC = Mask.getValueType().getVectorElementCount();
assert((EC.getKnownMinValue() % NumParts) == 0 &&
"Expected element count divisible by number of parts");
EC = ElementCount::getScalable(EC.getKnownMinValue() / NumParts);
Collaborator

Suggested change
EC = ElementCount::getScalable(EC.getKnownMinValue() / NumParts);
EC = EC.divideCoefficientBy(NumParts);
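Both this suggestion and the isKnownMultipleOf one below lean on ElementCount helpers. A minimal standalone model of their semantics (a simplified stand-in, not LLVM's real ElementCount class) looks like:

```cpp
// Hypothetical stand-in for llvm::ElementCount, to illustrate that
// divideCoefficientBy() divides only the known-minimum element count
// (the coefficient of vscale) while preserving scalability.
struct ElementCount {
  unsigned MinVal; // known-minimum number of elements
  bool Scalable;   // true for <vscale x N x T> types

  static ElementCount getScalable(unsigned N) { return {N, true}; }
  unsigned getKnownMinValue() const { return MinVal; }
  bool isKnownMultipleOf(unsigned N) const { return MinVal % N == 0; }
  ElementCount divideCoefficientBy(unsigned N) const {
    return {MinVal / N, Scalable};
  }
};
```

So for a <vscale x 32 x i1> mask split into 4 parts, the result is <vscale x 8 x i1>, exactly what the hand-written getScalable(KnownMin / NumParts) computed.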

Comment on lines 27100 to 27101
assert((EC.getKnownMinValue() % NumParts) == 0 &&
"Expected element count divisible by number of parts");
Collaborator

Suggested change
assert((EC.getKnownMinValue() % NumParts) == 0 &&
"Expected element count divisible by number of parts");
assert(EC.isKnownMultipleOf(NumParts) &&
"Expected element count divisible by number of parts");

Although based on my "do we need to strengthen the type checking" comment, this assert might be unnecessary.

@@ -27015,6 +27016,120 @@ performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
return NVCAST;
}

static SDValue performVectorDeinterleaveCombine(
Collaborator

Given the combine is only useful prior to type legalisation, it is worth having an early bail-out for when it's run after that point.

// we usually generate shuffles instead.
unsigned MinNumElements = SubVecTy.getVectorMinNumElements();
if (!SubVecTy.isScalableVT() ||
SubVecTy.getSizeInBits().getKnownMinValue() != 128 || MinNumElements == 1)
Collaborator

Do we need to strengthen the type checking here because being before type legalisation there must be many vscale*128-bit sized types that we'll not support. Perhaps replace the MinNumElements check with !isTypeLegal(SubVecTy)?

// At the moment we're unlikely to see a fixed-width vector deinterleave as
// we usually generate shuffles instead.
unsigned MinNumElements = SubVecTy.getVectorMinNumElements();
if (!SubVecTy.isScalableVT() ||
Collaborator

Suggested change
if (!SubVecTy.isScalableVT() ||
if (!SubVecTy.isScalableVector() ||

Comment on lines 27055 to 27056
MaskedLoad->getExtensionType() != ISD::NON_EXTLOAD ||
MaskedLoad->getAddressingMode() != ISD::UNINDEXED ||
Collaborator

Suggested change
MaskedLoad->getExtensionType() != ISD::NON_EXTLOAD ||
MaskedLoad->getAddressingMode() != ISD::UNINDEXED ||
!MaskedLoad->isSimple() || !MaskedLoad->isNormalLoad() ||

Contributor Author

isNormalLoad doesn't work for masked loads, so I've added a new isNormalMaskedLoad helper.

Comment on lines 27042 to 27043
auto *Idx = cast<ConstantSDNode>(OpI->getOperand(1));
if (Idx->getZExtValue() != (I * MinNumElements))
Collaborator

Suggested change
auto *Idx = cast<ConstantSDNode>(OpI->getOperand(1));
if (Idx->getZExtValue() != (I * MinNumElements))
if (OpI->getConstantOperandVal(1) != (I * MinNumElements))

Comment on lines 27089 to 27092
for (unsigned I = 1; I < NumParts; I++) {
if (InterleaveOp->getOperand(I) != InterleaveOp->getOperand(0))
return SDValue();
}
Collaborator

Suggested change
for (unsigned I = 1; I < NumParts; I++) {
if (InterleaveOp->getOperand(I) != InterleaveOp->getOperand(0))
return SDValue();
}
if (!llvm::all_equal(InterleaveOp->op_values()))
return SDValue();
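llvm::all_equal simply checks that every element of a range compares equal, which is exactly what the replaced loop did for the interleave operands. A standalone equivalent using only the standard library (a sketch of the semantics, not LLVM's own implementation) is:

```cpp
#include <algorithm>
#include <functional>
#include <iterator>

// Equivalent of llvm::all_equal: true iff every element of the range
// compares equal. Vacuously true for empty or single-element ranges.
template <typename R>
bool all_equal(const R &Range) {
  // adjacent_find locates the first pair of neighbours that differ;
  // if there is no such pair, all elements are equal.
  return std::adjacent_find(std::begin(Range), std::end(Range),
                            std::not_equal_to<>()) == std::end(Range);
}
```

Using the library helper both shortens the code and makes the intent (all interleave operands are the same value) obvious at a glance.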

Comment on lines 27083 to 27084
if (ConcatOp.getResNo() != I ||
ConcatOp.getNode() != InterleaveOp.getNode())
Collaborator

If you make InterleaveOp an SDNode (which is how you're using it anyway for several of the cases by using ->) you can simplify this to if (Mask.getOperand(I) == SDValue(InterleaveOp, I))

@llvmbot added the llvm:SelectionDAG label on Aug 29, 2025
Labels: backend:AArch64, llvm:SelectionDAG
4 participants