Skip to content

Conversation

SamTebbs33
Copy link
Collaborator

This PR adds splitting in the AArch64 backend for the LOOP_DEPENDENCE_MASK nodes so that even large vector types can be turned into whilewr/rw.

@llvmbot
Copy link
Member

llvmbot commented Aug 12, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Sam Tebbs (SamTebbs33)

Changes

This PR adds splitting in the AArch64 backend for the LOOP_DEPENDENCE_MASK nodes so that even large vector types can be turned into whilewr/rw.


Patch is 75.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153187.diff

3 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+71-26)
  • (modified) llvm/test/CodeGen/AArch64/alias_mask.ll (+381-315)
  • (modified) llvm/test/CodeGen/AArch64/alias_mask_scalable.ll (+256-486)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1a3095fc0895c..da587406d1844 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5220,49 +5220,94 @@ AArch64TargetLowering::LowerLOOP_DEPENDENCE_MASK(SDValue Op,
                                                  SelectionDAG &DAG) const {
   SDLoc DL(Op);
   uint64_t EltSize = Op.getConstantOperandVal(2);
-  EVT VT = Op.getValueType();
+  EVT FullVT = Op.getValueType();
+  unsigned NumElements = FullVT.getVectorMinNumElements();
+  unsigned NumSplits = 0;
+  EVT EltVT;
   switch (EltSize) {
   case 1:
-    if (VT != MVT::v16i8 && VT != MVT::nxv16i1)
-      return SDValue();
+    EltVT = MVT::i8;
     break;
   case 2:
-    if (VT != MVT::v8i8 && VT != MVT::nxv8i1)
-      return SDValue();
+    if (NumElements >= 16)
+      NumSplits = NumElements / 16;
+    EltVT = MVT::i16;
     break;
   case 4:
-    if (VT != MVT::v4i16 && VT != MVT::nxv4i1)
-      return SDValue();
+    if (NumElements >= 8)
+      NumSplits = NumElements / 8;
+    EltVT = MVT::i32;
     break;
   case 8:
-    if (VT != MVT::v2i32 && VT != MVT::nxv2i1)
-      return SDValue();
+    if (NumElements >= 4)
+      NumSplits = NumElements / 4;
+    EltVT = MVT::i64;
     break;
   default:
     // Other element sizes are incompatible with whilewr/rw, so expand instead
     return SDValue();
   }
 
-  SDValue PtrA = Op.getOperand(0);
-  SDValue PtrB = Op.getOperand(1);
+  auto LowerToWhile = [&](EVT VT, unsigned AddrScale) {
+    SDValue PtrA = Op.getOperand(0);
+    SDValue PtrB = Op.getOperand(1);
 
-  if (VT.isScalableVT())
-    return DAG.getNode(Op.getOpcode(), DL, VT, PtrA, PtrB, Op.getOperand(2));
+    EVT StoreVT = EVT::getVectorVT(*DAG.getContext(), EltVT,
+                                   VT.getVectorMinNumElements(), false);
+    if (AddrScale > 0) {
+      unsigned Offset = StoreVT.getStoreSizeInBits() / 8 * AddrScale;
+      SDValue Addend;
 
-  // We can use the SVE whilewr/whilerw instruction to lower this
-  // intrinsic by creating the appropriate sequence of scalable vector
-  // operations and then extracting a fixed-width subvector from the scalable
-  // vector. Scalable vector variants are already legal.
-  EVT ContainerVT =
-      EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
-                       VT.getVectorNumElements(), true);
-  EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
+      if (VT.isScalableVT())
+        Addend = DAG.getVScale(DL, MVT::i64, APInt(64, Offset));
+      else
+        Addend = DAG.getConstant(Offset, DL, MVT::i64);
 
-  SDValue Mask =
-      DAG.getNode(Op.getOpcode(), DL, WhileVT, PtrA, PtrB, Op.getOperand(2));
-  SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, ContainerVT, Mask);
-  return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, MaskAsInt,
-                     DAG.getVectorIdxConstant(0, DL));
+      PtrA = DAG.getNode(ISD::ADD, DL, MVT::i64, PtrA, Addend);
+      PtrB = DAG.getNode(ISD::ADD, DL, MVT::i64, PtrB, Addend);
+    }
+
+    if (VT.isScalableVT())
+      return DAG.getNode(Op.getOpcode(), DL, VT, PtrA, PtrB, Op.getOperand(2));
+
+    // We can use the SVE whilewr/whilerw instruction to lower this
+    // intrinsic by creating the appropriate sequence of scalable vector
+    // operations and then extracting a fixed-width subvector from the scalable
+    // vector. Scalable vector variants are already legal.
+    EVT ContainerVT =
+        EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
+                         VT.getVectorNumElements(), true);
+    EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
+
+    SDValue Mask =
+        DAG.getNode(Op.getOpcode(), DL, WhileVT, PtrA, PtrB, Op.getOperand(2));
+    SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, ContainerVT, Mask);
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, MaskAsInt,
+                       DAG.getVectorIdxConstant(0, DL));
+  };
+
+  if (NumSplits == 0)
+    return LowerToWhile(FullVT, 0);
+
+  SDValue FullVec = DAG.getUNDEF(FullVT);
+
+  unsigned NumElementsPerSplit = NumElements / (2 * NumSplits);
+  EVT PartVT =
+      EVT::getVectorVT(*DAG.getContext(), FullVT.getVectorElementType(),
+                       NumElementsPerSplit, FullVT.isScalableVT());
+  for (unsigned Split = 0, InsertIdx = 0; Split < NumSplits;
+       Split++, InsertIdx += 2) {
+    SDValue Low = LowerToWhile(PartVT, InsertIdx);
+    SDValue High = LowerToWhile(PartVT, InsertIdx + 1);
+    unsigned InsertIdxLow = InsertIdx * NumElementsPerSplit;
+    unsigned InsertIdxHigh = (InsertIdx + 1) * NumElementsPerSplit;
+    SDValue Insert =
+        DAG.getNode(ISD::INSERT_SUBVECTOR, DL, FullVT, FullVec, Low,
+                    DAG.getVectorIdxConstant(InsertIdxLow, DL));
+    FullVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, FullVT, Insert, High,
+                          DAG.getVectorIdxConstant(InsertIdxHigh, DL));
+  }
+  return FullVec;
 }
 
 SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
diff --git a/llvm/test/CodeGen/AArch64/alias_mask.ll b/llvm/test/CodeGen/AArch64/alias_mask.ll
index b1491c41135fa..ec00f9de39c22 100644
--- a/llvm/test/CodeGen/AArch64/alias_mask.ll
+++ b/llvm/test/CodeGen/AArch64/alias_mask.ll
@@ -511,44 +511,13 @@ entry:
 define <16 x i1> @whilewr_16_split(ptr %a, ptr %b) {
 ; CHECK-SVE-LABEL: whilewr_16_split:
 ; CHECK-SVE:       // %bb.0: // %entry
-; CHECK-SVE-NEXT:    index z0.d, #0, #1
-; CHECK-SVE-NEXT:    sub x8, x1, x0
-; CHECK-SVE-NEXT:    add x8, x8, x8, lsr #63
-; CHECK-SVE-NEXT:    asr x8, x8, #1
-; CHECK-SVE-NEXT:    mov z1.d, z0.d
-; CHECK-SVE-NEXT:    mov z2.d, z0.d
-; CHECK-SVE-NEXT:    mov z4.d, z0.d
-; CHECK-SVE-NEXT:    mov z5.d, z0.d
-; CHECK-SVE-NEXT:    mov z6.d, z0.d
-; CHECK-SVE-NEXT:    mov z7.d, z0.d
-; CHECK-SVE-NEXT:    mov z16.d, z0.d
-; CHECK-SVE-NEXT:    dup v3.2d, x8
-; CHECK-SVE-NEXT:    cmp x8, #1
-; CHECK-SVE-NEXT:    add z1.d, z1.d, #12 // =0xc
-; CHECK-SVE-NEXT:    add z2.d, z2.d, #10 // =0xa
-; CHECK-SVE-NEXT:    add z4.d, z4.d, #8 // =0x8
-; CHECK-SVE-NEXT:    add z5.d, z5.d, #6 // =0x6
-; CHECK-SVE-NEXT:    add z6.d, z6.d, #4 // =0x4
-; CHECK-SVE-NEXT:    add z7.d, z7.d, #2 // =0x2
-; CHECK-SVE-NEXT:    add z16.d, z16.d, #14 // =0xe
-; CHECK-SVE-NEXT:    cmhi v0.2d, v3.2d, v0.2d
-; CHECK-SVE-NEXT:    cset w8, lt
-; CHECK-SVE-NEXT:    cmhi v1.2d, v3.2d, v1.2d
-; CHECK-SVE-NEXT:    cmhi v2.2d, v3.2d, v2.2d
-; CHECK-SVE-NEXT:    cmhi v4.2d, v3.2d, v4.2d
-; CHECK-SVE-NEXT:    cmhi v5.2d, v3.2d, v5.2d
-; CHECK-SVE-NEXT:    cmhi v6.2d, v3.2d, v6.2d
-; CHECK-SVE-NEXT:    cmhi v16.2d, v3.2d, v16.2d
-; CHECK-SVE-NEXT:    cmhi v3.2d, v3.2d, v7.2d
-; CHECK-SVE-NEXT:    uzp1 v2.4s, v4.4s, v2.4s
-; CHECK-SVE-NEXT:    uzp1 v4.4s, v6.4s, v5.4s
-; CHECK-SVE-NEXT:    uzp1 v1.4s, v1.4s, v16.4s
-; CHECK-SVE-NEXT:    uzp1 v0.4s, v0.4s, v3.4s
-; CHECK-SVE-NEXT:    uzp1 v1.8h, v2.8h, v1.8h
-; CHECK-SVE-NEXT:    uzp1 v0.8h, v0.8h, v4.8h
+; CHECK-SVE-NEXT:    add x8, x1, #16
+; CHECK-SVE-NEXT:    add x9, x0, #16
+; CHECK-SVE-NEXT:    whilewr p0.h, x0, x1
+; CHECK-SVE-NEXT:    whilewr p1.h, x9, x8
+; CHECK-SVE-NEXT:    mov z0.h, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-SVE-NEXT:    mov z1.h, p1/z, #-1 // =0xffffffffffffffff
 ; CHECK-SVE-NEXT:    uzp1 v0.16b, v0.16b, v1.16b
-; CHECK-SVE-NEXT:    dup v1.16b, w8
-; CHECK-SVE-NEXT:    orr v0.16b, v0.16b, v1.16b
 ; CHECK-SVE-NEXT:    ret
 ;
 ; CHECK-NOSVE-LABEL: whilewr_16_split:
@@ -601,54 +570,38 @@ entry:
 define <32 x i1> @whilewr_16_split2(ptr %a, ptr %b) {
 ; CHECK-SVE-LABEL: whilewr_16_split2:
 ; CHECK-SVE:       // %bb.0: // %entry
-; CHECK-SVE-NEXT:    index z0.d, #0, #1
-; CHECK-SVE-NEXT:    sub x9, x1, x0
-; CHECK-SVE-NEXT:    add x9, x9, x9, lsr #63
-; CHECK-SVE-NEXT:    asr x9, x9, #1
-; CHECK-SVE-NEXT:    mov z1.d, z0.d
-; CHECK-SVE-NEXT:    mov z2.d, z0.d
-; CHECK-SVE-NEXT:    mov z3.d, z0.d
-; CHECK-SVE-NEXT:    mov z4.d, z0.d
-; CHECK-SVE-NEXT:    mov z5.d, z0.d
-; CHECK-SVE-NEXT:    mov z7.d, z0.d
-; CHECK-SVE-NEXT:    mov z16.d, z0.d
-; CHECK-SVE-NEXT:    dup v6.2d, x9
-; CHECK-SVE-NEXT:    cmp x9, #1
-; CHECK-SVE-NEXT:    add z1.d, z1.d, #12 // =0xc
-; CHECK-SVE-NEXT:    add z2.d, z2.d, #10 // =0xa
-; CHECK-SVE-NEXT:    add z3.d, z3.d, #8 // =0x8
-; CHECK-SVE-NEXT:    add z4.d, z4.d, #6 // =0x6
-; CHECK-SVE-NEXT:    add z5.d, z5.d, #4 // =0x4
-; CHECK-SVE-NEXT:    add z7.d, z7.d, #2 // =0x2
-; CHECK-SVE-NEXT:    add z16.d, z16.d, #14 // =0xe
-; CHECK-SVE-NEXT:    cmhi v0.2d, v6.2d, v0.2d
-; CHECK-SVE-NEXT:    cset w9, lt
-; CHECK-SVE-NEXT:    cmhi v1.2d, v6.2d, v1.2d
-; CHECK-SVE-NEXT:    cmhi v2.2d, v6.2d, v2.2d
-; CHECK-SVE-NEXT:    cmhi v3.2d, v6.2d, v3.2d
-; CHECK-SVE-NEXT:    cmhi v4.2d, v6.2d, v4.2d
-; CHECK-SVE-NEXT:    cmhi v5.2d, v6.2d, v5.2d
-; CHECK-SVE-NEXT:    cmhi v7.2d, v6.2d, v7.2d
-; CHECK-SVE-NEXT:    cmhi v6.2d, v6.2d, v16.2d
-; CHECK-SVE-NEXT:    uzp1 v2.4s, v3.4s, v2.4s
-; CHECK-SVE-NEXT:    uzp1 v3.4s, v5.4s, v4.4s
-; CHECK-SVE-NEXT:    uzp1 v0.4s, v0.4s, v7.4s
-; CHECK-SVE-NEXT:    uzp1 v1.4s, v1.4s, v6.4s
-; CHECK-SVE-NEXT:    uzp1 v0.8h, v0.8h, v3.8h
-; CHECK-SVE-NEXT:    uzp1 v1.8h, v2.8h, v1.8h
-; CHECK-SVE-NEXT:    uzp1 v0.16b, v0.16b, v1.16b
-; CHECK-SVE-NEXT:    dup v1.16b, w9
+; CHECK-SVE-NEXT:    add x9, x1, #48
+; CHECK-SVE-NEXT:    add x10, x0, #48
+; CHECK-SVE-NEXT:    add x11, x1, #16
+; CHECK-SVE-NEXT:    whilewr p1.h, x10, x9
+; CHECK-SVE-NEXT:    add x9, x1, #32
+; CHECK-SVE-NEXT:    add x10, x0, #32
+; CHECK-SVE-NEXT:    add x12, x0, #16
+; CHECK-SVE-NEXT:    whilewr p0.h, x0, x1
+; CHECK-SVE-NEXT:    whilewr p2.h, x10, x9
+; CHECK-SVE-NEXT:    mov z0.h, p1/z, #-1 // =0xffffffffffffffff
 ; CHECK-SVE-NEXT:    adrp x9, .LCPI11_0
-; CHECK-SVE-NEXT:    orr v0.16b, v0.16b, v1.16b
-; CHECK-SVE-NEXT:    ldr q1, [x9, :lo12:.LCPI11_0]
+; CHECK-SVE-NEXT:    whilewr p3.h, x12, x11
+; CHECK-SVE-NEXT:    mov z2.h, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-SVE-NEXT:    mov z1.h, p2/z, #-1 // =0xffffffffffffffff
+; CHECK-SVE-NEXT:    mov z3.h, p3/z, #-1 // =0xffffffffffffffff
+; CHECK-SVE-NEXT:    uzp1 v0.16b, v1.16b, v0.16b
+; CHECK-SVE-NEXT:    uzp1 v1.16b, v2.16b, v3.16b
+; CHECK-SVE-NEXT:    ldr q2, [x9, :lo12:.LCPI11_0]
 ; CHECK-SVE-NEXT:    shl v0.16b, v0.16b, #7
+; CHECK-SVE-NEXT:    shl v1.16b, v1.16b, #7
 ; CHECK-SVE-NEXT:    cmlt v0.16b, v0.16b, #0
-; CHECK-SVE-NEXT:    and v0.16b, v0.16b, v1.16b
-; CHECK-SVE-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-SVE-NEXT:    zip1 v0.16b, v0.16b, v1.16b
+; CHECK-SVE-NEXT:    cmlt v1.16b, v1.16b, #0
+; CHECK-SVE-NEXT:    and v0.16b, v0.16b, v2.16b
+; CHECK-SVE-NEXT:    and v1.16b, v1.16b, v2.16b
+; CHECK-SVE-NEXT:    ext v2.16b, v0.16b, v0.16b, #8
+; CHECK-SVE-NEXT:    ext v3.16b, v1.16b, v1.16b, #8
+; CHECK-SVE-NEXT:    zip1 v0.16b, v0.16b, v2.16b
+; CHECK-SVE-NEXT:    zip1 v1.16b, v1.16b, v3.16b
 ; CHECK-SVE-NEXT:    addv h0, v0.8h
+; CHECK-SVE-NEXT:    addv h1, v1.8h
 ; CHECK-SVE-NEXT:    str h0, [x8, #2]
-; CHECK-SVE-NEXT:    str h0, [x8]
+; CHECK-SVE-NEXT:    str h1, [x8]
 ; CHECK-SVE-NEXT:    ret
 ;
 ; CHECK-NOSVE-LABEL: whilewr_16_split2:
@@ -711,31 +664,28 @@ entry:
 define <8 x i1> @whilewr_32_split(ptr %a, ptr %b) {
 ; CHECK-SVE-LABEL: whilewr_32_split:
 ; CHECK-SVE:       // %bb.0: // %entry
-; CHECK-SVE-NEXT:    index z0.d, #0, #1
-; CHECK-SVE-NEXT:    sub x8, x1, x0
-; CHECK-SVE-NEXT:    add x9, x8, #3
-; CHECK-SVE-NEXT:    cmp x8, #0
-; CHECK-SVE-NEXT:    csel x8, x9, x8, lt
-; CHECK-SVE-NEXT:    asr x8, x8, #2
-; CHECK-SVE-NEXT:    mov z2.d, z0.d
-; CHECK-SVE-NEXT:    mov z3.d, z0.d
-; CHECK-SVE-NEXT:    mov z4.d, z0.d
-; CHECK-SVE-NEXT:    dup v1.2d, x8
-; CHECK-SVE-NEXT:    cmp x8, #1
-; CHECK-SVE-NEXT:    cset w8, lt
-; CHECK-SVE-NEXT:    add z4.d, z4.d, #6 // =0x6
-; CHECK-SVE-NEXT:    add z2.d, z2.d, #4 // =0x4
-; CHECK-SVE-NEXT:    add z3.d, z3.d, #2 // =0x2
-; CHECK-SVE-NEXT:    cmhi v0.2d, v1.2d, v0.2d
-; CHECK-SVE-NEXT:    cmhi v4.2d, v1.2d, v4.2d
-; CHECK-SVE-NEXT:    cmhi v2.2d, v1.2d, v2.2d
-; CHECK-SVE-NEXT:    cmhi v1.2d, v1.2d, v3.2d
-; CHECK-SVE-NEXT:    uzp1 v2.4s, v2.4s, v4.4s
-; CHECK-SVE-NEXT:    uzp1 v0.4s, v0.4s, v1.4s
-; CHECK-SVE-NEXT:    dup v1.8b, w8
-; CHECK-SVE-NEXT:    uzp1 v0.8h, v0.8h, v2.8h
-; CHECK-SVE-NEXT:    xtn v0.8b, v0.8h
-; CHECK-SVE-NEXT:    orr v0.8b, v0.8b, v1.8b
+; CHECK-SVE-NEXT:    whilewr p0.s, x0, x1
+; CHECK-SVE-NEXT:    add x10, x0, #16
+; CHECK-SVE-NEXT:    mov z0.s, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-SVE-NEXT:    mov w8, v0.s[1]
+; CHECK-SVE-NEXT:    mov v1.16b, v0.16b
+; CHECK-SVE-NEXT:    mov w9, v0.s[2]
+; CHECK-SVE-NEXT:    mov v1.h[1], w8
+; CHECK-SVE-NEXT:    mov w8, v0.s[3]
+; CHECK-SVE-NEXT:    mov v1.h[2], w9
+; CHECK-SVE-NEXT:    add x9, x1, #16
+; CHECK-SVE-NEXT:    whilewr p0.s, x10, x9
+; CHECK-SVE-NEXT:    mov z0.s, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-SVE-NEXT:    mov v1.h[3], w8
+; CHECK-SVE-NEXT:    fmov w8, s0
+; CHECK-SVE-NEXT:    mov w9, v0.s[1]
+; CHECK-SVE-NEXT:    mov v1.h[4], w8
+; CHECK-SVE-NEXT:    mov w8, v0.s[2]
+; CHECK-SVE-NEXT:    mov v1.h[5], w9
+; CHECK-SVE-NEXT:    mov w9, v0.s[3]
+; CHECK-SVE-NEXT:    mov v1.h[6], w8
+; CHECK-SVE-NEXT:    mov v1.h[7], w9
+; CHECK-SVE-NEXT:    xtn v0.8b, v1.8h
 ; CHECK-SVE-NEXT:    ret
 ;
 ; CHECK-NOSVE-LABEL: whilewr_32_split:
@@ -775,46 +725,51 @@ entry:
 define <16 x i1> @whilewr_32_split2(ptr %a, ptr %b) {
 ; CHECK-SVE-LABEL: whilewr_32_split2:
 ; CHECK-SVE:       // %bb.0: // %entry
-; CHECK-SVE-NEXT:    index z0.d, #0, #1
-; CHECK-SVE-NEXT:    sub x8, x1, x0
-; CHECK-SVE-NEXT:    add x9, x8, #3
-; CHECK-SVE-NEXT:    cmp x8, #0
-; CHECK-SVE-NEXT:    csel x8, x9, x8, lt
-; CHECK-SVE-NEXT:    asr x8, x8, #2
-; CHECK-SVE-NEXT:    mov z1.d, z0.d
-; CHECK-SVE-NEXT:    mov z2.d, z0.d
-; CHECK-SVE-NEXT:    mov z4.d, z0.d
-; CHECK-SVE-NEXT:    mov z5.d, z0.d
-; CHECK-SVE-NEXT:    mov z6.d, z0.d
-; CHECK-SVE-NEXT:    mov z7.d, z0.d
-; CHECK-SVE-NEXT:    mov z16.d, z0.d
-; CHECK-SVE-NEXT:    dup v3.2d, x8
-; CHECK-SVE-NEXT:    cmp x8, #1
-; CHECK-SVE-NEXT:    add z1.d, z1.d, #12 // =0xc
-; CHECK-SVE-NEXT:    add z2.d, z2.d, #10 // =0xa
-; CHECK-SVE-NEXT:    add z4.d, z4.d, #8 // =0x8
-; CHECK-SVE-NEXT:    add z5.d, z5.d, #6 // =0x6
-; CHECK-SVE-NEXT:    add z6.d, z6.d, #4 // =0x4
-; CHECK-SVE-NEXT:    add z7.d, z7.d, #2 // =0x2
-; CHECK-SVE-NEXT:    add z16.d, z16.d, #14 // =0xe
-; CHECK-SVE-NEXT:    cmhi v0.2d, v3.2d, v0.2d
-; CHECK-SVE-NEXT:    cset w8, lt
-; CHECK-SVE-NEXT:    cmhi v1.2d, v3.2d, v1.2d
-; CHECK-SVE-NEXT:    cmhi v2.2d, v3.2d, v2.2d
-; CHECK-SVE-NEXT:    cmhi v4.2d, v3.2d, v4.2d
-; CHECK-SVE-NEXT:    cmhi v5.2d, v3.2d, v5.2d
-; CHECK-SVE-NEXT:    cmhi v6.2d, v3.2d, v6.2d
-; CHECK-SVE-NEXT:    cmhi v16.2d, v3.2d, v16.2d
-; CHECK-SVE-NEXT:    cmhi v3.2d, v3.2d, v7.2d
-; CHECK-SVE-NEXT:    uzp1 v2.4s, v4.4s, v2.4s
-; CHECK-SVE-NEXT:    uzp1 v4.4s, v6.4s, v5.4s
-; CHECK-SVE-NEXT:    uzp1 v1.4s, v1.4s, v16.4s
-; CHECK-SVE-NEXT:    uzp1 v0.4s, v0.4s, v3.4s
-; CHECK-SVE-NEXT:    uzp1 v1.8h, v2.8h, v1.8h
-; CHECK-SVE-NEXT:    uzp1 v0.8h, v0.8h, v4.8h
-; CHECK-SVE-NEXT:    uzp1 v0.16b, v0.16b, v1.16b
-; CHECK-SVE-NEXT:    dup v1.16b, w8
-; CHECK-SVE-NEXT:    orr v0.16b, v0.16b, v1.16b
+; CHECK-SVE-NEXT:    add x8, x1, #32
+; CHECK-SVE-NEXT:    add x9, x0, #32
+; CHECK-SVE-NEXT:    whilewr p0.s, x0, x1
+; CHECK-SVE-NEXT:    whilewr p1.s, x9, x8
+; CHECK-SVE-NEXT:    mov z0.s, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-SVE-NEXT:    mov z1.s, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-SVE-NEXT:    mov w8, v0.s[1]
+; CHECK-SVE-NEXT:    mov v2.16b, v0.16b
+; CHECK-SVE-NEXT:    mov w10, v0.s[2]
+; CHECK-SVE-NEXT:    mov w9, v1.s[1]
+; CHECK-SVE-NEXT:    mov v3.16b, v1.16b
+; CHECK-SVE-NEXT:    mov w11, v1.s[3]
+; CHECK-SVE-NEXT:    mov v2.h[1], w8
+; CHECK-SVE-NEXT:    mov w8, v1.s[2]
+; CHECK-SVE-NEXT:    mov v3.h[1], w9
+; CHECK-SVE-NEXT:    mov w9, v0.s[3]
+; CHECK-SVE-NEXT:    mov v2.h[2], w10
+; CHECK-SVE-NEXT:    add x10, x1, #16
+; CHECK-SVE-NEXT:    mov v3.h[2], w8
+; CHECK-SVE-NEXT:    add x8, x0, #16
+; CHECK-SVE-NEXT:    whilewr p0.s, x8, x10
+; CHECK-SVE-NEXT:    add x8, x1, #48
+; CHECK-SVE-NEXT:    add x10, x0, #48
+; CHECK-SVE-NEXT:    whilewr p1.s, x10, x8
+; CHECK-SVE-NEXT:    mov z0.s, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-SVE-NEXT:    mov v2.h[3], w9
+; CHECK-SVE-NEXT:    mov z1.s, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-SVE-NEXT:    mov v3.h[3], w11
+; CHECK-SVE-NEXT:    fmov w9, s0
+; CHECK-SVE-NEXT:    mov w8, v0.s[1]
+; CHECK-SVE-NEXT:    fmov w10, s1
+; CHECK-SVE-NEXT:    mov w11, v1.s[1]
+; CHECK-SVE-NEXT:    mov v2.h[4], w9
+; CHECK-SVE-NEXT:    mov w9, v0.s[2]
+; CHECK-SVE-NEXT:    mov v3.h[4], w10
+; CHECK-SVE-NEXT:    mov w10, v1.s[2]
+; CHECK-SVE-NEXT:    mov v2.h[5], w8
+; CHECK-SVE-NEXT:    mov w8, v0.s[3]
+; CHECK-SVE-NEXT:    mov v3.h[5], w11
+; CHECK-SVE-NEXT:    mov w11, v1.s[3]
+; CHECK-SVE-NEXT:    mov v2.h[6], w9
+; CHECK-SVE-NEXT:    mov v3.h[6], w10
+; CHECK-SVE-NEXT:    mov v2.h[7], w8
+; CHECK-SVE-NEXT:    mov v3.h[7], w11
+; CHECK-SVE-NEXT:    uzp1 v0.16b, v2.16b, v3.16b
 ; CHECK-SVE-NEXT:    ret
 ;
 ; CHECK-NOSVE-LABEL: whilewr_32_split2:
@@ -869,55 +824,113 @@ entry:
 define <32 x i1> @whilewr_32_split3(ptr %a, ptr %b) {
 ; CHECK-SVE-LABEL: whilewr_32_split3:
 ; CHECK-SVE:       // %bb.0: // %entry
-; CHECK-SVE-NEXT:    index z0.d, #0, #1
-; CHECK-SVE-NEXT:    sub x9, x1, x0
-; CHECK-SVE-NEXT:    add x10, x9, #3
-; CHECK-SVE-NEXT:    cmp x9, #0
-; CHECK-SVE-NEXT:    csel x9, x10, x9, lt
-; CHECK-SVE-NEXT:    asr x9, x9, #2
-; CHECK-SVE-NEXT:    mov z1.d, z0.d
-; CHECK-SVE-NEXT:    mov z2.d, z0.d
-; CHECK-SVE-NEXT:    mov z3.d, z0.d
-; CHECK-SVE-NEXT:    mov z4.d, z0.d
-; CHECK-SVE-NEXT:    mov z5.d, z0.d
-; CHECK-SVE-NEXT:    mov z7.d, z0.d
-; CHECK-SVE-NEXT:    mov z16.d, z0.d
-; CHECK-SVE-NEXT:    dup v6.2d, x9
-; CHECK-SVE-NEXT:    cmp x9, #1
-; CHECK-SVE-NEXT:    add z1.d, z1.d, #12 // =0xc
-; CHECK-SVE-NEXT:    add z2.d, z2.d, #10 // =0xa
-; CHECK-SVE-NEXT:    add z3.d, z3.d, #8 // =0x8
-; CHECK-SVE-NEXT:    add z4.d, z4.d, #6 // =0x6
-; CHECK-SVE-NEXT:    add z5.d, z5.d, #4 // =0x4
-; CHECK-SVE-NEXT:    add z7.d, z7.d, #2 // =0x2
-; CHECK-SVE-NEXT:    add z16.d, z16.d, #14 // =0xe
-; CHECK-SVE-NEXT:    cmhi v0.2d, v6.2d, v0.2d
-; CHECK-SVE-NEXT:    cset w9, lt
-; CHECK-SVE-NEXT:    cmhi v1.2d, v6.2d, v1.2d
-; CHECK-SVE-NEXT:    cmhi v2.2d, v6.2d, v2.2d
-; CHECK-SVE-NEXT:    cmhi v3.2d, v6.2d, v3.2d
-; CHECK-SVE-NEXT:    cmhi v4.2d, v6.2d, v4.2d
-; CHECK-SVE-NEXT:    cmhi v5.2d, v6.2d, v5.2d
-; CHECK-SVE-NEXT:    cmhi v7.2d, v6.2d, v7.2d
-; CHECK-SVE-NEXT:    cmhi v6.2d, v6.2d, v16.2d
-; CHECK-SVE-NEXT:    uzp1 v2.4s, v3.4s, v2.4s
-; CHECK-SVE-NEXT:    uzp1 v3.4s, v5.4s, v4.4s
-; CHECK-SVE-NEXT:    uzp1 v0.4s, v0.4s, v7.4s
-; CHECK-SVE-NEXT:    uzp1 v1.4s, v1.4s, v6.4s
-; CHECK-SVE-NEXT:    uzp1 v0.8h, v0.8h, v3.8h
-; CHECK-SVE-NEXT:    uzp1 v1.8h, v2.8h, v1.8h
-; CHECK-SVE-NEXT:    uzp1 v0.16b, v0.16b, v1.16b
-; CHECK-SVE-NEXT:    dup v1.16b, w9
+; CHECK-SVE-NEXT:    whilewr p0.s, x0, x1
+; CHECK-SVE-NEXT:    add x9, x1, #96
+; CHECK-SVE-NEXT:    add x10, x0, #96
+; CHECK-SVE-NEXT:    add x11, x1, #64
+; CHECK-SVE-NEXT:    add x12, x0, #64
+; CHECK-SVE-NEXT:    whilewr p1.s, x10, x9
+; CHECK-SVE-NEXT:    add x9, x1, #32
+; CHECK-SVE-NEXT:    add x10, x0, #32
+; CHECK-SVE-NEXT:    whilewr p2.s, x12, x11
+; CHECK-SVE-NEXT:    mov z0.s, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-SVE-NEXT:    whilewr p0.s, x10, x9
+; CHECK-SVE-NEXT:    mov z4.s, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-SVE-NEXT:    mov z5.s, p2/z, #-1 // =0xffffffffffffffff
+; CHECK-SVE-NEXT:    mov z1.s, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-SVE-NEXT:    mov w9, v0.s[1]
+; CHECK-SVE-NEXT:    mov w12, v4.s[1]
+; CHECK-SVE-NEXT:    mov w10, v0.s[2]
+; CHECK-SVE-NEXT:    mov w13, v5.s[1]
+; CHECK-SVE-NEXT:    mov w11, v0.s[3]
+; CHECK-SVE-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-SVE-NEXT:    mov v2.16b, v4.16b
+; CHECK-SVE-NEXT:    mov w14, v1.s[1]
+; CHECK-SVE-NEXT:    mov v3.16b, v5.16b
+; CHECK-SVE-NEXT:    mov w15, v1.s[2]
+; CHECK-SVE-NEXT:    mov w16, v1.s[3]
+; CHECK-SVE-NEXT:    // kill: def $q1 killed $q1 killed $z1
+; CHECK-SVE-NEXT:    mov w17, v4.s[2]
+; CHECK-SVE-NEXT:    mov w18, v5.s[2]
+; CHECK-SVE-NEXT:    mov v0.h[1], w9
+; CHECK-SVE-NEXT:    mov v2.h[1], w12
+; CHECK-SVE-NEXT:    add x9, x1, #16
+; CHECK-SVE-NEXT:    mov v3.h[1], w13
+; CHECK-SVE-NEXT:    add x12, x0, #16
+; CHECK-SVE-NE...
[truncated]

This PR adds splitting in the AArch64 backend for the LOOP_DEPENDENCE_MASK nodes
so that even large vector types can be turned into whilewr/rw.
@SamTebbs33 SamTebbs33 force-pushed the alias-intrinsic-backend-splitting branch from 1f5a509 to bbe60d6 Compare August 18, 2025 09:47
Comment on lines 5251 to 5272
EVT FullVT = Op.getValueType();
unsigned NumElements = FullVT.getVectorMinNumElements();
unsigned NumSplits = 0;
EVT EltVT;
switch (EltSize) {
case 1:
if (VT != MVT::v16i8 && VT != MVT::nxv16i1)
return SDValue();
EltVT = MVT::i8;
break;
case 2:
if (VT != MVT::v8i8 && VT != MVT::nxv8i1)
return SDValue();
if (NumElements >= 16)
NumSplits = NumElements / 16;
EltVT = MVT::i16;
break;
case 4:
if (VT != MVT::v4i16 && VT != MVT::nxv4i1)
return SDValue();
if (NumElements >= 8)
NumSplits = NumElements / 8;
EltVT = MVT::i32;
break;
case 8:
if (VT != MVT::v2i32 && VT != MVT::nxv2i1)
return SDValue();
if (NumElements >= 4)
NumSplits = NumElements / 4;
EltVT = MVT::i64;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found the underlying pattern here somewhat obfuscated, I think this can be simplified quite a bit to:

  uint64_t EltSizeInBytes = Op.getConstantOperandVal(2);

  // Other element sizes are incompatible with whilewr/rw, so expand instead
  if (!is_contained({1u, 2u, 4u, 8u}, EltSizeInBytes))
    return SDValue();

  EVT FullVT = Op.getValueType();
  EVT EltVT = MVT::getIntegerVT(EltSizeInBytes * 8);

  unsigned NumElements = FullVT.getVectorMinNumElements();
  unsigned PredElements = getPackedSVEVectorVT(EltVT).getVectorMinNumElements();
  unsigned NumWhiles = NumElements / PredElements;

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks.

@@ -5248,49 +5248,94 @@ AArch64TargetLowering::LowerLOOP_DEPENDENCE_MASK(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
uint64_t EltSize = Op.getConstantOperandVal(2);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: EltSize -> EltSizeInBytes (it's not obvious what the units are otherwise)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines 5317 to 5337
if (NumSplits == 0)
return LowerToWhile(FullVT, 0);

SDValue FullVec = DAG.getUNDEF(FullVT);

unsigned NumElementsPerSplit = NumElements / (2 * NumSplits);
EVT PartVT =
EVT::getVectorVT(*DAG.getContext(), FullVT.getVectorElementType(),
NumElementsPerSplit, FullVT.isScalableVT());
for (unsigned Split = 0, InsertIdx = 0; Split < NumSplits;
Split++, InsertIdx += 2) {
SDValue Low = LowerToWhile(PartVT, InsertIdx);
SDValue High = LowerToWhile(PartVT, InsertIdx + 1);
unsigned InsertIdxLow = InsertIdx * NumElementsPerSplit;
unsigned InsertIdxHigh = (InsertIdx + 1) * NumElementsPerSplit;
SDValue Insert =
DAG.getNode(ISD::INSERT_SUBVECTOR, DL, FullVT, FullVec, Low,
DAG.getVectorIdxConstant(InsertIdxLow, DL));
FullVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, FullVT, Insert, High,
DAG.getVectorIdxConstant(InsertIdxHigh, DL));
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following from my first suggestion (that adds NumWhiles) this can be simplified to:

  if (NumWhiles <= 1)
    return LowerToWhile(FullVT, 0);

  unsigned NumElementsPerSplit = NumElements / NumWhiles;
  EVT PartVT =
      EVT::getVectorVT(*DAG.getContext(), FullVT.getVectorElementType(),
                       NumElementsPerSplit, FullVT.isScalableVT());
  SDValue FullVec = DAG.getUNDEF(FullVT);
  for (unsigned I = 0; I < NumWhiles; I++) {
    SDValue While = LowerToWhile(PartVT, I);
    unsigned InsertIdx = I * NumElementsPerSplit;
    FullVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, FullVT, FullVec, While,
                          DAG.getVectorIdxConstant(InsertIdx, DL));
  }

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is nicer but I've removed the loop now.

Comment on lines 5301 to 5304
// We can use the SVE whilewr/whilerw instruction to lower this
// intrinsic by creating the appropriate sequence of scalable vector
// operations and then extracting a fixed-width subvector from the scalable
// vector. Scalable vector variants are already legal.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to assume SVE is available to lower fixed-size masks. What happens if that's not the case?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ISD::LOOP_DEPENDENCE_{WAR,RAW}_MASK is only marked as custom if sve2 or sme is present, otherwise it will be expanded. I have added an assertion.

if (NumSplits == 0)
return LowerToWhile(FullVT, 0);

SDValue FullVec = DAG.getUNDEF(FullVT);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
SDValue FullVec = DAG.getUNDEF(FullVT);
SDValue FullVec = DAG.getPOISON(FullVT);

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines 5283 to 5284
EVT StoreVT = EVT::getVectorVT(*DAG.getContext(), EltVT,
VT.getVectorMinNumElements(), false);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not necessary to create StoreVT, you can do unsigned Offset = VT.getVectorMinNumElements() * EltSize * AddrScale.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, I wanted to be safe and use getStoreSizeInBits() in case that was ever different from EltSize * AddrScale * NumElements. Done.

@SamTebbs33 SamTebbs33 force-pushed the alias-intrinsic-backend-splitting branch from fbd622b to 1898d9c Compare August 27, 2025 09:32
@@ -765,3 +456,75 @@ entry:
%0 = call <vscale x 16 x i1> @llvm.loop.dependence.war.mask.nxv16i1(ptr %a, ptr %b, i64 3)
ret <vscale x 16 x i1> %0
}

; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

// 128-bit vector. Shorter vector lengths can be widened then extracted
FullVT = FullVT.getDoubleNumVectorElementsVT(*DAG.getContext());

auto LowerToWhile = [&](EVT VT, unsigned AddrScale) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To simplify the logic a bit, can you split the concerns of "lowering to a fixed-length operation" and "splitting the operation in half", such that this function only implements the latter? That would make this function into one that merely adds AddrScale (which is probably better renamed to Offset) and returns a LOOP_DEPENDENCE_[RAW|WAR]_MASK node. These nodes will then be revisited again, and if lowering to a fixed-length vector is needed, that can be handled separately in LowerLOOP_DEPENDENCE_MASK.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


SDValue Result;
if (!Split) {
Result = LowerToWhile(FullVT, 0);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If no splitting is required, do we need to call LowerToWhile, or can we just return the node?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was being called to re-use the containerisation, not needed now with the new approach.

Result = LowerToWhile(FullVT, 0);
} else {

if (Split) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The big difference with this revision is that you keep splitting masks, then only once you can't split anymore do you attempt to containerize fixed-width vectors into scalable vectors. Your previous approach would containerize fixed-width vectors as part of the first split, which leads to a better lowering.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I separated the splitting and containerisation logic as per the request here. I've experimented with adding another lambda that containerises without re-entering and that seems to have fixed the codegen.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants