Skip to content

Conversation

SamTebbs33
Copy link
Collaborator

@SamTebbs33 SamTebbs33 commented Jul 7, 2025

This PR bundles sub reductions into the VPExpressionRecipe class and adjusts the cost functions to take the negation into account.

Stacked PRs:

  1. [LV] Create in-loop sub reductions #147026
  2. -> [LV] Bundle sub reductions into VPExpressionRecipe #147255
  3. [LV] Bundle partial reductions inside VPExpressionRecipe #147302
  4. [LV] Use VPReductionRecipe for partial reductions #147513

@llvmbot
Copy link
Member

llvmbot commented Jul 7, 2025

@llvm/pr-subscribers-vectorizers
@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-backend-arm

Author: Sam Tebbs (SamTebbs33)

Changes

This PR bundles sub reductions into the VPExpressionRecipe class and adjusts the cost functions to take the negation into account.


Patch is 23.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147255.diff

14 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+3-1)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+1-1)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+3)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+3-2)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+5-2)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+1-1)
  • (modified) llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp (+5-2)
  • (modified) llvm/lib/Target/ARM/ARMTargetTransformInfo.h (+1)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+3-3)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+11)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+32-3)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+23-10)
  • (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+2-2)
  • (modified) llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll (+143)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index c43870392361d..3cc0ea01953c3 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1645,8 +1645,10 @@ class TargetTransformInfo {
   /// extensions. This is the cost of as:
   /// ResTy vecreduce.add(mul (A, B)).
   /// ResTy vecreduce.add(mul(ext(Ty A), ext(Ty B)).
+  /// The multiply can optionally be negated, which signifies that it is a sub
+  /// reduction.
   LLVM_ABI InstructionCost getMulAccReductionCost(
-      bool IsUnsigned, Type *ResTy, VectorType *Ty,
+      bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
 
   /// Calculate the cost of an extended reduction pattern, similar to
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 12f87226c5f57..fd22981a5dbf3 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -960,7 +960,7 @@ class TargetTransformInfoImplBase {
 
   virtual InstructionCost
   getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
-                         TTI::TargetCostKind CostKind) const {
+                         bool Negated, TTI::TargetCostKind CostKind) const {
     return 1;
   }
 
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index bf958e100f2ac..a9c9fa6d1db0d 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -3116,7 +3116,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   InstructionCost
   getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
+                         bool Negated,
                          TTI::TargetCostKind CostKind) const override {
+    if (Negated)
+      return InstructionCost::getInvalid(CostKind);
     // Without any native support, this is equivalent to the cost of
     // vecreduce.add(mul(ext(Ty A), ext(Ty B))) or
     // vecreduce.add(mul(A, B)).
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 3ebd9d487ba04..ba0d070bffe6d 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1274,9 +1274,10 @@ InstructionCost TargetTransformInfo::getExtendedReductionCost(
 }
 
 InstructionCost TargetTransformInfo::getMulAccReductionCost(
-    bool IsUnsigned, Type *ResTy, VectorType *Ty,
+    bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
     TTI::TargetCostKind CostKind) const {
-  return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, CostKind);
+  return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, Negated,
+                                         CostKind);
 }
 
 InstructionCost
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 380faa6cf6939..d9a367535baf4 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5316,8 +5316,10 @@ InstructionCost AArch64TTIImpl::getExtendedReductionCost(
 
 InstructionCost
 AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
-                                       VectorType *VecTy,
+                                       VectorType *VecTy, bool Negated,
                                        TTI::TargetCostKind CostKind) const {
+  if (Negated)
+    return InstructionCost::getInvalid(CostKind);
   EVT VecVT = TLI->getValueType(DL, VecTy);
   EVT ResVT = TLI->getValueType(DL, ResTy);
 
@@ -5332,7 +5334,8 @@ AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
       return LT.first + 2;
   }
 
-  return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, CostKind);
+  return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, Negated,
+                                       CostKind);
 }
 
 InstructionCost
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 9ada70bd7086a..8bb31d2a3dac5 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -447,7 +447,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
                            TTI::TargetCostKind CostKind) const override;
 
   InstructionCost getMulAccReductionCost(
-      bool IsUnsigned, Type *ResTy, VectorType *Ty,
+      bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const override;
 
   InstructionCost
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 203fb76d7be86..27eb22b5f9986 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1884,8 +1884,10 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
 
 InstructionCost
 ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
-                                   VectorType *ValTy,
+                                   VectorType *ValTy, bool Negated,
                                    TTI::TargetCostKind CostKind) const {
+  if (Negated)
+    return InstructionCost::getInvalid(CostKind);
   EVT ValVT = TLI->getValueType(DL, ValTy);
   EVT ResVT = TLI->getValueType(DL, ResTy);
 
@@ -1906,7 +1908,8 @@ ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
       return ST->getMVEVectorCostFactor(CostKind) * LT.first;
   }
 
-  return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, CostKind);
+  return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, Negated,
+                                       CostKind);
 }
 
 InstructionCost
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index ca06b9e3cb661..43f47f3e7aa6f 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -299,6 +299,7 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
                            TTI::TargetCostKind CostKind) const override;
   InstructionCost
   getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *ValTy,
+                         bool Negated,
                          TTI::TargetCostKind CostKind) const override;
 
   InstructionCost
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1cfbcf1336620..0adff8d957e98 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5538,7 +5538,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
                              TTI::CastContextHint::None, CostKind, RedOp);
 
     InstructionCost RedCost = TTI.getMulAccReductionCost(
-        IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
+        IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind);
 
     if (RedCost.isValid() &&
         RedCost < ExtCost * 2 + MulCost + Ext2Cost + BaseCost)
@@ -5583,7 +5583,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
           TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
 
       InstructionCost RedCost = TTI.getMulAccReductionCost(
-          IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
+          IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind);
       InstructionCost ExtraExtCost = 0;
       if (Op0Ty != LargestOpTy || Op1Ty != LargestOpTy) {
         Instruction *ExtraExtOp = (Op0Ty != LargestOpTy) ? Op0 : Op1;
@@ -5602,7 +5602,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
           TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
 
       InstructionCost RedCost = TTI.getMulAccReductionCost(
-          true, RdxDesc.getRecurrenceType(), VectorTy, CostKind);
+          true, RdxDesc.getRecurrenceType(), VectorTy, false, CostKind);
 
       if (RedCost.isValid() && RedCost < MulCost + BaseCost)
         return I == RetI ? RedCost : 0;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index d460573f5bec6..1bc926db301d8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2757,6 +2757,12 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
     /// vector operands, performing a reduction.add on the result, and adding
     /// the scalar result to a chain.
     MulAccReduction,
+    /// Represent an inloop multiply-accumulate reduction, multiplying the
+    /// extended vector operands, negating the multiplication, performing a
+    /// reduction.add
+    /// on the result, and adding
+    /// the scalar result to a chain.
+    ExtNegatedMulAccReduction,
   };
 
   /// Type of the expression.
@@ -2780,6 +2786,11 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
                      VPWidenRecipe *Mul, VPReductionRecipe *Red)
       : VPExpressionRecipe(ExpressionTypes::ExtMulAccReduction,
                            {Ext0, Ext1, Mul, Red}) {}
+  VPExpressionRecipe(VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1,
+                     VPWidenRecipe *Mul, VPWidenRecipe *Sub,
+                     VPReductionRecipe *Red)
+      : VPExpressionRecipe(ExpressionTypes::ExtNegatedMulAccReduction,
+                           {Ext0, Ext1, Mul, Sub, Red}) {}
 
   ~VPExpressionRecipe() override {
     for (auto *R : reverse(ExpressionRecipes))
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 318e8171e098d..c20b1920c3791 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2672,13 +2672,17 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
         RedTy, SrcVecTy, std::nullopt, Ctx.CostKind);
   }
   case ExpressionTypes::MulAccReduction:
-    return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, Ctx.CostKind);
+    return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, false,
+                                          Ctx.CostKind);
 
-  case ExpressionTypes::ExtMulAccReduction:
+  case ExpressionTypes::ExtNegatedMulAccReduction:
+  case ExpressionTypes::ExtMulAccReduction: {
+    bool Negated = ExpressionType == ExpressionTypes::ExtNegatedMulAccReduction;
     return Ctx.TTI.getMulAccReductionCost(
         cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
             Instruction::ZExt,
-        RedTy, SrcVecTy, Ctx.CostKind);
+        RedTy, SrcVecTy, Negated, Ctx.CostKind);
+  }
   }
   llvm_unreachable("Unknown VPExpressionRecipe::ExpressionTypes enum");
 }
@@ -2725,6 +2729,31 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
     O << ")";
     break;
   }
+  case ExpressionTypes::ExtNegatedMulAccReduction: {
+    getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
+    O << " + ";
+    O << "reduce."
+      << Instruction::getOpcodeName(
+             RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))
+      << " (sub (0, mul";
+    auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
+    Mul->printFlags(O);
+    O << "(";
+    getOperand(0)->printAsOperand(O, SlotTracker);
+    auto *Ext0 = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
+    O << " " << Instruction::getOpcodeName(Ext0->getOpcode()) << " to "
+      << *Ext0->getResultType() << "), (";
+    getOperand(1)->printAsOperand(O, SlotTracker);
+    auto *Ext1 = cast<VPWidenCastRecipe>(ExpressionRecipes[1]);
+    O << " " << Instruction::getOpcodeName(Ext1->getOpcode()) << " to "
+      << *Ext1->getResultType() << ")";
+    if (Red->isConditional()) {
+      O << ", ";
+      Red->getCondOp()->printAsOperand(O, SlotTracker);
+    }
+    O << "))";
+    break;
+  }
   case ExpressionTypes::MulAccReduction:
   case ExpressionTypes::ExtMulAccReduction: {
     getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 931d4d42f56e4..a09d2037e97b4 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2908,16 +2908,17 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
 
   // Clamp the range if using multiply-accumulate-reduction is profitable.
   auto IsMulAccValidAndClampRange =
-      [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
-          VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
+      [&](bool IsZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+          VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt,
+          bool Negated = false) -> bool {
     return LoopVectorizationPlanner::getDecisionAndClampRange(
         [&](ElementCount VF) {
           TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
           Type *SrcTy =
               Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
           auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
-          InstructionCost MulAccCost =
-              Ctx.TTI.getMulAccReductionCost(isZExt, RedTy, SrcVecTy, CostKind);
+          InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
+              IsZExt, RedTy, SrcVecTy, Negated, CostKind);
           InstructionCost MulCost = Mul->computeCost(VF, Ctx);
           InstructionCost RedCost = Red->computeCost(VF, Ctx);
           InstructionCost ExtCost = 0;
@@ -2935,14 +2936,22 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
   };
 
   VPValue *VecOp = Red->getVecOp();
+  VPValue *Mul = nullptr;
+  VPValue *Sub = nullptr;
   VPValue *A, *B;
+  // Sub reductions will have a sub between the add reduction and vec op.
+  if (match(VecOp,
+            m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Mul))))
+    Sub = VecOp;
+  else
+    Mul = VecOp;
   // Try to match reduce.add(mul(...)).
-  if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+  if (match(Mul, m_Mul(m_VPValue(A), m_VPValue(B)))) {
     auto *RecipeA =
         dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
     auto *RecipeB =
         dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
-    auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
+    auto *MulR = cast<VPWidenRecipe>(Mul->getDefiningRecipe());
 
     // Match reduce.add(mul(ext, ext)).
     if (RecipeA && RecipeB &&
@@ -2951,12 +2960,16 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
         match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
         IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
                                        Instruction::CastOps::ZExt,
-                                   Mul, RecipeA, RecipeB, nullptr)) {
-      return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red);
+                                   MulR, RecipeA, RecipeB, nullptr, Sub)) {
+      if (Sub)
+        return new VPExpressionRecipe(
+            RecipeA, RecipeB, MulR,
+            cast<VPWidenRecipe>(Sub->getDefiningRecipe()), Red);
+      return new VPExpressionRecipe(RecipeA, RecipeB, MulR, Red);
     }
     // Match reduce.add(mul).
-    if (IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
-      return new VPExpressionRecipe(Mul, Red);
+    if (IsMulAccValidAndClampRange(true, MulR, nullptr, nullptr, nullptr, Sub))
+      return new VPExpressionRecipe(MulR, Red);
   }
   // Match reduce.add(ext(mul(ext(A), ext(B)))).
   // All extend recipes must have same opcode or A == B
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index b2fced47b9527..7953aec48c8b0 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1401,8 +1401,8 @@ static void analyzeCostOfVecReduction(const IntrinsicInst &II,
                              TTI::CastContextHint::None, CostKind, RedOp);
 
     CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
-    CostAfterReduction =
-        TTI.getMulAccReductionCost(IsUnsigned, II.getType(), ExtType, CostKind);
+    CostAfterReduction = TTI.getMulAccReductionCost(IsUnsigned, II.getType(),
+                                                    ExtType, false, CostKind);
     return;
   }
   CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy,
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
index 4af3fa9202c77..8059ac12ecd2e 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
@@ -416,3 +416,146 @@ exit:
   %r.0.lcssa = phi i64 [ %rdx.next, %loop ]
   ret i64 %r.0.lcssa
 }
+
+define i32 @print_mulacc_sub(ptr %a, ptr %b) {
+; CHECK:      VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF
+; CHECK-NEXT: Live-in vp<%1> = VF * UF
+; CHECK-NEXT: Live-in vp<%2> = vector-trip-count
+; CHECK-NEXT: Live-in ir<1024> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<entry>:
+; CHECK-NEXT: Successor(s): scalar.ph, vector.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT:   EMIT vp<%3> = reduction-start-vector ir<0>, ir<0>, ir<1>
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT:   vector.body:
+; CHECK-NEXT:     EMIT vp<%4> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT:     WIDEN-REDUCTION-PHI ir<%accum> = phi vp<%3>, vp<%8>
+; CHECK-NEXT:     vp<%5> = SCALAR-STEPS vp<%4>, ir<1>, vp<%0>
+; CHECK-NEXT:     CLONE ir<%gep.a> = getelementptr ir<%a>, vp<%5>
+; CHECK-NEXT:     vp<%6> = vector-pointer ir<%gep.a>
+; CHECK-NEXT:     WIDEN ir<%load.a> = load vp<%6>
+; CHECK-NEXT:     CLONE ir<%gep.b> = getelementptr ir<%b>, vp<%5>
+; CHECK-NEXT:     vp<%7> = vector-pointer ir<%gep.b>
+; CHECK-NEXT:     WIDEN ir<%load.b> = load vp<%7>
+; CHECK-NEXT:     EXPRESSION vp<%8> = ir<%accum> + reduce.add (sub (0, mul (ir<%load.b> zext to i32), (ir<%load.a> zext to i32)))
+; CHECK-NEXT:     EMIT vp<%index.next> = add nuw vp<%4>, vp<%1>
+; CHECK-NEXT:     EMIT branch-on-count vp<%index.next>, vp<%2>
+; CHECK-NEXT:   No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT:   EMIT vp<%10> = compute-reduction-result ir<%accum>, vp<%8>
+; CHECK-NEXT:   EMIT vp<%cmp.n> = icmp eq ir<1024>, vp<%2>
+; CHECK-NEXT:   EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.exit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.exit>:
+; CHECK-NEXT:   IR   %add.lcssa = phi i32 [ %add, %for.body ] (extra operand: vp<%10> from middle.block)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT:   EMIT-SCALAR vp<%bc.resume.val> = phi [ vp<%2>, middle.block ], [ ir<0>, ir-bb<entry> ]
+; CHECK-NEXT:   EMIT-SCALAR vp<%bc.merge.rdx> = phi [ vp<%10>, middle.block ], [ ir<0>, ir-bb<entry> ]
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT:   IR   %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ] (extra operand: vp<%bc.resume.val> from scalar.ph)
+; CHECK-NEXT:   IR   %accum = phi i32 [ 0, %entry ], [ %add, %for.body ] (extra operand: vp<%bc.merge.rdx> from scalar.ph)
+; CHECK-NEXT:   IR   %gep.a = getelementptr i8, ptr %a, i64 %iv
+; CHECK-NEXT:   IR   %load.a = load i8, ptr %gep.a, align 1
+; CHECK-NEXT:   IR   %ext.a = zext i8 %load.a to i32
+; CHECK-NEXT:   IR   %gep.b = getelementptr i8, ptr %b, i64 %iv
+; CHECK-NEXT:   IR   %load.b = load i8, ptr %gep.b, align 1
+; CHECK-NEXT:   IR   %ext.b = zext i8 %load.b to i32
+; CHECK-NEXT:   IR   %mul = mul i32 %ext.b, %ext.a
+; CHECK-NEXT:   IR   %add = sub i32 %accum, %mul
+; CHECK-NEXT:   IR   %iv.next = add i64 %iv, 1
+; CHECK-NEXT:   IR   %exitcond.not = icmp eq i64 %iv.next, 1024
+; CHECK-NEXT: No successors
+; CH...
[truncated]

@@ -2725,6 +2729,31 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
O << ")";
break;
}
case ExpressionTypes::ExtNegatedMulAccReduction: {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to commonise this with the ExtMulAccReduction case if the only difference is a negate?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was my initial approach but it required checking the number of operands to know if there was a sub or not, and I was asked to create an expression type to not rely on operand ordering being stable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you still could re-use the code for printing, by just checking the expression type to decide whether to print the sub or not.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if you've seen the comment above?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did miss this, thanks. It should be irrelevant now that that expression type is gone.

CostAfterReduction =
TTI.getMulAccReductionCost(IsUnsigned, II.getType(), ExtType, CostKind);
CostAfterReduction = TTI.getMulAccReductionCost(IsUnsigned, II.getType(),
ExtType, false, CostKind);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Probably better written as /*Negated=*/false

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

LLVM_ABI InstructionCost getMulAccReductionCost(
bool IsUnsigned, Type *ResTy, VectorType *Ty,
bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth keeping the booleans together, i.e. next to IsUnsigned?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, done.

TTI::TargetCostKind CostKind) const override {
if (Negated)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't we add a cost for this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I've added a cost for the sub.

@@ -5538,7 +5538,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
TTI::CastContextHint::None, CostKind, RedOp);

InstructionCost RedCost = TTI.getMulAccReductionCost(
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: /*Negated=*/false and same for other below.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

/// Represent an inloop multiply-accumulate reduction, multiplying the
/// extended vector operands, negating the multiplication, performing a
/// reduction.add
/// on the result, and adding
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Formatting of the comment looks a bit odd - can you fix it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

SamTebbs33 added a commit to SamTebbs33/llvm-project that referenced this pull request Jul 7, 2025
This PR bundles partial reductions inside the VPExpressionRecipe class.

Depends on llvm#147255 .
@SamTebbs33 SamTebbs33 force-pushed the users/SamTebbs33/sub-reductions branch from 39f3dab to 0b93c24 Compare July 22, 2025 15:43
@SamTebbs33 SamTebbs33 force-pushed the expression-recipe-sub branch from b7c9820 to fa30b51 Compare July 22, 2025 15:45
@SamTebbs33
Copy link
Collaborator Author

Ping

SamTebbs33 added a commit that referenced this pull request Aug 12, 2025
This PR allows the loop vectorizer to handle in-loop sub reductions by
forming a normal in-loop add reduction with a negated input.

Stacked PRs:
1. -> #147026
2. #147255
3. #147302
4. #147513
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Aug 12, 2025
This PR allows the loop vectorizer to handle in-loop sub reductions by
forming a normal in-loop add reduction with a negated input.

Stacked PRs:
1. -> llvm/llvm-project#147026
2. llvm/llvm-project#147255
3. llvm/llvm-project#147302
4. llvm/llvm-project#147513
TTI::TargetCostKind CostKind) const override {
// Without any native support, this is equivalent to the cost of
// vecreduce.add(mul(ext(Ty A), ext(Ty B))) or
// vecreduce.add(mul(A, B)).
// vecreduce.add(mul(A, B)) with an optional negation of the mul.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: The wording makes it seem like the optional negation only applies to the second form.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@SamTebbs33
Copy link
Collaborator Author

Apologies for the noise.

@SamTebbs33 SamTebbs33 changed the base branch from users/SamTebbs33/sub-reductions to main August 14, 2025 09:16
@SamTebbs33
Copy link
Collaborator Author

Ping

entry:
br label %for.body

for.body: ; preds = %for.body, %entry
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for.body: ; preds = %for.body, %entry
loop:

nit: consistency with other functions in file

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

%exitcond.not = icmp eq i64 %iv.next, 1024
br i1 %exitcond.not, label %for.exit, label %for.body

for.exit: ; preds = %for.body
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for.exit: ; preds = %for.body
exit:

nit: consistency with other functions in file

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -416,3 +416,146 @@ exit:
%r.0.lcssa = phi i64 [ %rdx.next, %loop ]
ret i64 %r.0.lcssa
}

define i32 @print_mulacc_sub(ptr %a, ptr %b) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add a test that checks the generated code? IIUC there should also be changes in costing/vectorization factors we chose, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, the codegen changes will be covered by the existing tests, just curious if it would be possible to add a test that benefits from the cost changes?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've tried but haven't been able to come up with a test that is different in the VF chosen without these changes. I reckon that it will be easier once perhaps the AArch64 or ARM getMulAccReductionCost functions accept the sub version.

@@ -2725,6 +2729,31 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
O << ")";
break;
}
case ExpressionTypes::ExtNegatedMulAccReduction: {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if you've seen the comment above?

@fhahn
Copy link
Contributor

fhahn commented Aug 20, 2025

(not sure why, but it looks like the precommit tests on Linux/Windows have not been triggered for some reason, but the libx++ ones have; may be solve by updating the latest main again?)

@SamTebbs33 SamTebbs33 force-pushed the expression-recipe-sub branch from 1d7cb25 to 229331e Compare August 20, 2025 19:20
@SamTebbs33
Copy link
Collaborator Author

Thanks for the review @fhahn , I'll have a look at it tomorrow. I've just rebased the patch on top of main after the sub reduction patch was merged. This involved removing the negated expression type since (non-chained) a sub reduction is now represented with a subtraction, rather than an addition with a negated input.

/// ResTy vecreduce.add(mul(ext(Ty A), ext(Ty B)).
/// getArithmeticReductionCost of an Add/Sub reduction with multiply and
/// optional extensions. This is the cost of as:
/// ResTy vecreduce.add/sub(mul (A, B)).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
/// ResTy vecreduce.add/sub(mul (A, B)).
/// ResTy vecreduce.add/sub(mul(A, B)).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

TTI::TargetCostKind CostKind) const override {
// Without any native support, this is equivalent to the cost of
// vecreduce.add(mul(ext(Ty A), ext(Ty B))) or
// vecreduce.add(mul(A, B)).
// vecreduce.add(mul(A, B)). IsNegated determines if the mul is negated.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IsNegated no longer exists.

please also add an assert that RedOpcode is either an add or a sub.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

TTI::TargetCostKind CostKind) const {
if (RedOpcode != Instruction::Add)
return InstructionCost::getInvalid(CostKind);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this returning an invalid cost, rather than adding the cost of a negation of the operand?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going with the most conservative approach at first, but I've now allowed subs here (but made sure that it's an add in the UDOT case below). We don't need to consider a negation of the operand since this function isn't used for the chained add+sub case at the moment.

@@ -3126,14 +3127,15 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
};

VPValue *VecOp = Red->getVecOp();
VPValue *Mul = VecOp;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a partially NFC change, and I'd prefer the use of VecOp as it was before this change. The reason for this is that in the case that VecOp is not a multiply, Mul is still defined (to the value of VecOp), which I don't think is right.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, I think this is left over from when we were checking for a negation. Done.

@fhahn
Copy link
Contributor

fhahn commented Aug 28, 2025

It looks like one of the vectorizer tests is failing in precommit and possibly needs updating?

@SamTebbs33 SamTebbs33 force-pushed the expression-recipe-sub branch from 7d40358 to 33236a3 Compare August 29, 2025 12:30
@SamTebbs33
Copy link
Collaborator Author

It looks like one of the vectorizer tests is failing in precommit and possibly needs updating?

Yeah it needed a rebase, done.

Copy link
Contributor

@fhahn fhahn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with a few inline comments remaining, thanks!

Comment on lines +1652 to +1653
/// ResTy vecreduce.add/sub(mul (A, B)).
/// ResTy vecreduce.add/sub(mul(ext(Ty A), ext(Ty B)).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// ResTy vecreduce.add/sub(mul (A, B)).
/// ResTy vecreduce.add/sub(mul(ext(Ty A), ext(Ty B)).
/// * ResTy vecreduce.add/sub(mul (A, B)) or,
/// * ResTy vecreduce.add/sub(mul(ext(Ty A), ext(Ty B)).

Comment on lines +3162 to +3163
VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt,
unsigned Opcode) -> bool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt,
unsigned Opcode) -> bool {
VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt
) -> bool {

Can we just use the captured Opcode?

Comment on lines +1471 to +1472
CostAfterReduction = TTI.getMulAccReductionCost(
IsUnsigned, ReductionOpc, II.getType(), ExtType, CostKind);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would be nice to have a test for this, but not sure if that's possible.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants