
Conversation

alexey-bataev

In the initial FMAD patch, potential FMAD nodes were completely
excluded from the reduction analysis to keep that patch small, but this
exclusion may cause regressions.

This patch adds better detection of scalar FMAD reduction operations
and more accurately calculates the cost of FMAD reduction operations
(excluding the cost of the scalar fmuls) and of split reduction
operations combined with regular FMADs.

Created using spr 1.3.5
@llvmbot

llvmbot commented Aug 8, 2025

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-backend-risc-v

Author: Alexey Bataev (alexey-bataev)

Changes

In the initial FMAD patch, potential FMAD nodes were completely
excluded from the reduction analysis to keep that patch small, but this
exclusion may cause regressions.

This patch adds better detection of scalar FMAD reduction operations
and more accurately calculates the cost of FMAD reduction operations
(excluding the cost of the scalar fmuls) and of split reduction
operations combined with regular FMADs.


Patch is 72.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152787.diff

9 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+63-20)
  • (modified) llvm/test/Transforms/SLPVectorizer/AArch64/reused-scalar-repeated-in-node.ll (+92-95)
  • (modified) llvm/test/Transforms/SLPVectorizer/AArch64/scalarization-overhead.ll (+18-39)
  • (modified) llvm/test/Transforms/SLPVectorizer/RISCV/vec3-base.ll (+56-38)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/dot-product.ll (+12-82)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/horizontal-list.ll (+28-126)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/redux-feed-buildvector.ll (+13-56)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/redux-feed-insertelement.ll (+3-19)
  • (modified) llvm/test/Transforms/SLPVectorizer/X86/slp-fma-loss.ll (+23-1)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index ec06a217667c2..afdf73ab58184 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -22730,21 +22730,11 @@ class HorizontalReduction {
   /// Try to find a reduction tree.
   bool matchAssociativeReduction(BoUpSLP &R, Instruction *Root,
                                  ScalarEvolution &SE, const DataLayout &DL,
-                                 const TargetLibraryInfo &TLI,
-                                 DominatorTree &DT, TargetTransformInfo &TTI) {
+                                 const TargetLibraryInfo &TLI) {
     RdxKind = HorizontalReduction::getRdxKind(Root);
     if (!isVectorizable(RdxKind, Root))
       return false;
 
-    // FMA reduction root - skip.
-    auto CheckForFMA = [&](Instruction *I) {
-      return RdxKind == RecurKind::FAdd &&
-             canConvertToFMA(I, getSameOpcode(I, TLI), DT, DL, TTI, TLI)
-                 .isValid();
-    };
-    if (CheckForFMA(Root))
-      return false;
-
     // Analyze "regular" integer/FP types for reductions - no target-specific
     // types or pointers.
     Type *Ty = Root->getType();
@@ -22782,7 +22772,7 @@ class HorizontalReduction {
         // Also, do not try to reduce const values, if the operation is not
         // foldable.
         if (!EdgeInst || Level > RecursionMaxDepth ||
-            getRdxKind(EdgeInst) != RdxKind || CheckForFMA(EdgeInst) ||
+            getRdxKind(EdgeInst) != RdxKind ||
             IsCmpSelMinMax != isCmpSelMinMax(EdgeInst) ||
             !hasRequiredNumberOfUses(IsCmpSelMinMax, EdgeInst) ||
             !isVectorizable(RdxKind, EdgeInst) ||
@@ -22901,7 +22891,8 @@ class HorizontalReduction {
 
   /// Attempt to vectorize the tree found by matchAssociativeReduction.
   Value *tryToReduce(BoUpSLP &V, const DataLayout &DL, TargetTransformInfo *TTI,
-                     const TargetLibraryInfo &TLI, AssumptionCache *AC) {
+                     const TargetLibraryInfo &TLI, AssumptionCache *AC,
+                     DominatorTree &DT) {
     constexpr unsigned RegMaxNumber = 4;
     constexpr unsigned RedValsMaxNumber = 128;
     // If there are a sufficient number of reduction values, reduce
@@ -23302,7 +23293,7 @@ class HorizontalReduction {
 
         // Estimate cost.
         InstructionCost ReductionCost =
-            getReductionCost(TTI, VL, IsCmpSelMinMax, RdxFMF, V);
+            getReductionCost(TTI, VL, IsCmpSelMinMax, RdxFMF, V, DT, DL, TLI);
         InstructionCost Cost = V.getTreeCost(VL, ReductionCost);
         LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
                           << " for reduction\n");
@@ -23607,7 +23598,9 @@ class HorizontalReduction {
   InstructionCost getReductionCost(TargetTransformInfo *TTI,
                                    ArrayRef<Value *> ReducedVals,
                                    bool IsCmpSelMinMax, FastMathFlags FMF,
-                                   const BoUpSLP &R) {
+                                   const BoUpSLP &R, DominatorTree &DT,
+                                   const DataLayout &DL,
+                                   const TargetLibraryInfo &TLI) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Type *ScalarTy = ReducedVals.front()->getType();
     unsigned ReduxWidth = ReducedVals.size();
@@ -23632,6 +23625,22 @@ class HorizontalReduction {
         for (User *U : RdxVal->users()) {
           auto *RdxOp = cast<Instruction>(U);
           if (hasRequiredNumberOfUses(IsCmpSelMinMax, RdxOp)) {
+            if (RdxKind == RecurKind::FAdd) {
+              InstructionCost FMACost = canConvertToFMA(
+                  RdxOp, getSameOpcode(RdxOp, TLI), DT, DL, *TTI, TLI);
+              if (FMACost.isValid()) {
+                LLVM_DEBUG(dbgs() << "FMA cost: " << FMACost << "\n");
+                if (auto *I = dyn_cast<Instruction>(RdxVal)) {
+                  // Also, exclude scalar fmul cost.
+                  InstructionCost FMulCost =
+                      TTI->getInstructionCost(I, CostKind);
+                  LLVM_DEBUG(dbgs() << "Minus FMul cost: " << FMulCost << "\n");
+                  FMACost -= FMulCost;
+                }
+                ScalarCost += FMACost;
+                continue;
+              }
+            }
             ScalarCost += TTI->getInstructionCost(RdxOp, CostKind);
             continue;
           }
@@ -23696,8 +23705,42 @@ class HorizontalReduction {
           auto [RType, IsSigned] = R.getRootNodeTypeWithNoCast().value_or(
               std::make_pair(RedTy, true));
           VectorType *RVecTy = getWidenedType(RType, ReduxWidth);
-          VectorCost +=
-              TTI->getArithmeticInstrCost(RdxOpcode, RVecTy, CostKind);
+          InstructionCost FMACost = InstructionCost::getInvalid();
+          if (RdxKind == RecurKind::FAdd) {
+            // Check if the reduction operands can be converted to FMA.
+            SmallVector<Value *> Ops;
+            FastMathFlags FMF;
+            FMF.set();
+            for (Value *RdxVal : ReducedVals) {
+              if (!RdxVal->hasOneUse()) {
+                Ops.clear();
+                break;
+              }
+              if (auto *FPCI = dyn_cast<FPMathOperator>(RdxVal))
+                FMF &= FPCI->getFastMathFlags();
+              Ops.push_back(RdxVal->user_back());
+            }
+            FMACost = canConvertToFMA(
+                Ops, getSameOpcode(Ops, TLI), DT, DL, *TTI, TLI);
+            if (FMACost.isValid()) {
+              // Calculate actual FMAD cost.
+              IntrinsicCostAttributes ICA(Intrinsic::fmuladd, RVecTy,
+                                          {RVecTy, RVecTy, RVecTy}, FMF);
+              FMACost = TTI->getIntrinsicInstrCost(ICA, CostKind);
+
+              LLVM_DEBUG(dbgs() << "Vector FMA cost: " << FMACost << "\n");
+              // Also, exclude vector fmul cost.
+              InstructionCost FMulCost = TTI->getArithmeticInstrCost(
+                  Instruction::FMul, RVecTy, CostKind);
+              LLVM_DEBUG(dbgs() << "Minus vector FMul cost: " << FMulCost << "\n");
+              FMACost -= FMulCost;
+            }
+          }
+          if (FMACost.isValid())
+            VectorCost += FMACost;
+          else
+            VectorCost +=
+                TTI->getArithmeticInstrCost(RdxOpcode, RVecTy, CostKind);
           if (RType != RedTy) {
             unsigned Opcode = Instruction::Trunc;
             if (RedTy->getScalarSizeInBits() > RType->getScalarSizeInBits())
@@ -24357,9 +24400,9 @@ bool SLPVectorizerPass::vectorizeHorReduction(
     if (!isReductionCandidate(Inst))
       return nullptr;
     HorizontalReduction HorRdx;
-    if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI, *DT, *TTI))
+    if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI))
       return nullptr;
-    return HorRdx.tryToReduce(R, *DL, TTI, *TLI, AC);
+    return HorRdx.tryToReduce(R, *DL, TTI, *TLI, AC, *DT);
   };
   auto TryAppendToPostponedInsts = [&](Instruction *FutureSeed) {
     if (TryOperandsAsNewSeeds && FutureSeed == Root) {
@@ -24504,7 +24547,7 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) {
     if (RedCost >= ScalarCost)
       return false;
 
-    return HorRdx.tryToReduce(R, *DL, &TTI, *TLI, AC) != nullptr;
+    return HorRdx.tryToReduce(R, *DL, &TTI, *TLI, AC, *DT) != nullptr;
   };
   if (Candidates.size() == 1)
     return TryToReduce(I, {Op0, Op1}) || tryToVectorizeList({Op0, Op1}, R);
diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/reused-scalar-repeated-in-node.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/reused-scalar-repeated-in-node.ll
index 2e684320ba10e..cca58d8d66f04 100644
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/reused-scalar-repeated-in-node.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/reused-scalar-repeated-in-node.ll
@@ -27,114 +27,111 @@ define void @test() {
 ; CHECK-NEXT:    [[I76:%.*]] = load float, ptr poison, align 4
 ; CHECK-NEXT:    br i1 poison, label %[[BB167:.*]], label %[[BB77:.*]]
 ; CHECK:       [[BB77]]:
+; CHECK-NEXT:    [[TMP0:%.*]] = insertelement <2 x float> poison, float [[I70]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x float> poison, float [[I68]], i32 0
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x float> poison, float [[I67]], i32 0
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <2 x float> [[TMP2]], float [[I69]], i32 1
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <2 x float> poison, float [[I66]], i32 0
 ; CHECK-NEXT:    br label %[[BB78:.*]]
 ; CHECK:       [[BB78]]:
-; CHECK-NEXT:    [[I85:%.*]] = phi nsz float [ [[I66]], %[[BB77]] ], [ [[I103:%.*]], %[[BB78]] ]
-; CHECK-NEXT:    [[I80:%.*]] = phi nsz float [ [[I67]], %[[BB77]] ], [ [[I104:%.*]], %[[BB78]] ]
-; CHECK-NEXT:    [[I81:%.*]] = phi nsz float [ [[I68]], %[[BB77]] ], [ [[I105:%.*]], %[[BB78]] ]
-; CHECK-NEXT:    [[I82:%.*]] = phi nsz float [ poison, %[[BB77]] ], [ [[I106:%.*]], %[[BB78]] ]
-; CHECK-NEXT:    [[I84:%.*]] = phi nsz float [ poison, %[[BB77]] ], [ [[I123:%.*]], %[[BB78]] ]
-; CHECK-NEXT:    [[I127:%.*]] = phi nsz float [ [[I69]], %[[BB77]] ], [ [[I124:%.*]], %[[BB78]] ]
-; CHECK-NEXT:    [[I131:%.*]] = phi nsz float [ poison, %[[BB77]] ], [ [[I125:%.*]], %[[BB78]] ]
-; CHECK-NEXT:    [[I86:%.*]] = phi nsz float [ [[I70]], %[[BB77]] ], [ [[I126:%.*]], %[[BB78]] ]
-; CHECK-NEXT:    [[I87:%.*]] = fmul fast float [[I85]], poison
-; CHECK-NEXT:    [[I88:%.*]] = fmul fast float [[I80]], poison
-; CHECK-NEXT:    [[I89:%.*]] = fmul fast float [[I81]], poison
-; CHECK-NEXT:    [[I90:%.*]] = fmul fast float [[I82]], poison
-; CHECK-NEXT:    [[I91:%.*]] = fmul fast float [[I84]], poison
-; CHECK-NEXT:    [[I92:%.*]] = fadd fast float [[I91]], [[I87]]
-; CHECK-NEXT:    [[I93:%.*]] = fmul fast float [[I127]], poison
-; CHECK-NEXT:    [[I94:%.*]] = fadd fast float [[I93]], [[I88]]
-; CHECK-NEXT:    [[I95:%.*]] = fmul fast float [[I131]], poison
-; CHECK-NEXT:    [[I96:%.*]] = fadd fast float [[I95]], [[I89]]
-; CHECK-NEXT:    [[I97:%.*]] = fmul fast float [[I86]], poison
-; CHECK-NEXT:    [[I98:%.*]] = fadd fast float [[I97]], [[I90]]
-; CHECK-NEXT:    [[I99:%.*]] = fadd fast float [[I92]], poison
-; CHECK-NEXT:    [[I100:%.*]] = fadd fast float [[I94]], poison
-; CHECK-NEXT:    [[I101:%.*]] = fadd fast float [[I96]], poison
-; CHECK-NEXT:    [[I102:%.*]] = fadd fast float [[I98]], poison
-; CHECK-NEXT:    [[I103]] = fadd fast float [[I99]], poison
-; CHECK-NEXT:    [[I104]] = fadd fast float [[I100]], poison
-; CHECK-NEXT:    [[I105]] = fadd fast float [[I101]], poison
-; CHECK-NEXT:    [[I106]] = fadd fast float [[I102]], poison
-; CHECK-NEXT:    [[I107:%.*]] = fmul fast float [[I85]], poison
-; CHECK-NEXT:    [[I108:%.*]] = fmul fast float [[I80]], poison
-; CHECK-NEXT:    [[I109:%.*]] = fmul fast float [[I81]], poison
-; CHECK-NEXT:    [[I110:%.*]] = fmul fast float [[I82]], poison
-; CHECK-NEXT:    [[I111:%.*]] = fmul fast float [[I84]], poison
-; CHECK-NEXT:    [[I112:%.*]] = fadd fast float [[I111]], [[I107]]
-; CHECK-NEXT:    [[I113:%.*]] = fmul fast float [[I127]], poison
-; CHECK-NEXT:    [[I114:%.*]] = fadd fast float [[I113]], [[I108]]
-; CHECK-NEXT:    [[I115:%.*]] = fmul fast float [[I131]], poison
-; CHECK-NEXT:    [[I116:%.*]] = fadd fast float [[I115]], [[I109]]
-; CHECK-NEXT:    [[I117:%.*]] = fmul fast float [[I86]], poison
-; CHECK-NEXT:    [[I118:%.*]] = fadd fast float [[I117]], [[I110]]
-; CHECK-NEXT:    [[I119:%.*]] = fadd fast float [[I112]], poison
-; CHECK-NEXT:    [[I120:%.*]] = fadd fast float [[I114]], poison
-; CHECK-NEXT:    [[I121:%.*]] = fadd fast float [[I116]], poison
-; CHECK-NEXT:    [[I122:%.*]] = fadd fast float [[I118]], poison
-; CHECK-NEXT:    [[I123]] = fadd fast float [[I119]], poison
-; CHECK-NEXT:    [[I124]] = fadd fast float [[I120]], poison
-; CHECK-NEXT:    [[I125]] = fadd fast float [[I121]], poison
-; CHECK-NEXT:    [[I126]] = fadd fast float [[I122]], poison
+; CHECK-NEXT:    [[I85:%.*]] = phi nsz float [ [[I66]], %[[BB77]] ], [ [[TMP31:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[I80:%.*]] = phi nsz float [ [[I67]], %[[BB77]] ], [ [[TMP46:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[I81:%.*]] = phi nsz float [ [[I68]], %[[BB77]] ], [ [[TMP37:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[I82:%.*]] = phi nsz float [ poison, %[[BB77]] ], [ [[TMP39:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[I84:%.*]] = phi nsz float [ poison, %[[BB77]] ], [ [[TMP30:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[I127:%.*]] = phi nsz float [ [[I69]], %[[BB77]] ], [ [[TMP53:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[I131:%.*]] = phi nsz float [ poison, %[[BB77]] ], [ [[TMP36:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[I86:%.*]] = phi nsz float [ [[I70]], %[[BB77]] ], [ [[TMP40:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[TMP5:%.*]] = phi <2 x float> [ [[TMP0]], %[[BB77]] ], [ [[TMP38:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[TMP6:%.*]] = phi <2 x float> [ [[TMP1]], %[[BB77]] ], [ [[TMP35:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[TMP7:%.*]] = phi <2 x float> [ [[TMP3]], %[[BB77]] ], [ [[TMP32:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[TMP8:%.*]] = phi <2 x float> [ [[TMP4]], %[[BB77]] ], [ [[TMP29:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[TMP9:%.*]] = fmul fast <2 x float> [[TMP8]], poison
+; CHECK-NEXT:    [[TMP10:%.*]] = fmul fast <2 x float> [[TMP7]], poison
+; CHECK-NEXT:    [[TMP11:%.*]] = fmul fast <2 x float> [[TMP6]], poison
+; CHECK-NEXT:    [[TMP12:%.*]] = fmul fast <2 x float> [[TMP5]], poison
+; CHECK-NEXT:    [[TMP13:%.*]] = fmul fast <2 x float> [[TMP8]], poison
+; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <2 x float> [[TMP13]], <2 x float> poison, <2 x i32> <i32 1, i32 0>
+; CHECK-NEXT:    [[TMP15:%.*]] = fmul fast <2 x float> [[TMP7]], poison
+; CHECK-NEXT:    [[TMP16:%.*]] = shufflevector <2 x float> [[TMP15]], <2 x float> poison, <2 x i32> <i32 1, i32 0>
+; CHECK-NEXT:    [[TMP17:%.*]] = fmul fast <2 x float> [[TMP6]], poison
+; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <2 x float> [[TMP17]], <2 x float> poison, <2 x i32> <i32 1, i32 0>
+; CHECK-NEXT:    [[TMP19:%.*]] = fmul fast <2 x float> [[TMP5]], poison
+; CHECK-NEXT:    [[TMP20:%.*]] = shufflevector <2 x float> [[TMP19]], <2 x float> poison, <2 x i32> <i32 1, i32 0>
+; CHECK-NEXT:    [[TMP21:%.*]] = fadd fast <2 x float> [[TMP14]], [[TMP9]]
+; CHECK-NEXT:    [[TMP22:%.*]] = fadd fast <2 x float> [[TMP21]], poison
+; CHECK-NEXT:    [[TMP23:%.*]] = fadd fast <2 x float> [[TMP16]], [[TMP10]]
+; CHECK-NEXT:    [[TMP24:%.*]] = fadd fast <2 x float> [[TMP23]], poison
+; CHECK-NEXT:    [[TMP25:%.*]] = fadd fast <2 x float> [[TMP18]], [[TMP11]]
+; CHECK-NEXT:    [[TMP26:%.*]] = fadd fast <2 x float> [[TMP25]], poison
+; CHECK-NEXT:    [[TMP27:%.*]] = fadd fast <2 x float> [[TMP20]], [[TMP12]]
+; CHECK-NEXT:    [[TMP28:%.*]] = fadd fast <2 x float> [[TMP27]], poison
+; CHECK-NEXT:    [[TMP29]] = fadd fast <2 x float> [[TMP22]], poison
+; CHECK-NEXT:    [[TMP30]] = extractelement <2 x float> [[TMP29]], i32 1
+; CHECK-NEXT:    [[TMP31]] = extractelement <2 x float> [[TMP29]], i32 0
+; CHECK-NEXT:    [[TMP32]] = fadd fast <2 x float> [[TMP24]], poison
+; CHECK-NEXT:    [[TMP53]] = extractelement <2 x float> [[TMP32]], i32 1
+; CHECK-NEXT:    [[TMP46]] = extractelement <2 x float> [[TMP32]], i32 0
+; CHECK-NEXT:    [[TMP35]] = fadd fast <2 x float> [[TMP26]], poison
+; CHECK-NEXT:    [[TMP36]] = extractelement <2 x float> [[TMP35]], i32 1
+; CHECK-NEXT:    [[TMP37]] = extractelement <2 x float> [[TMP35]], i32 0
+; CHECK-NEXT:    [[TMP38]] = fadd fast <2 x float> [[TMP28]], poison
+; CHECK-NEXT:    [[TMP39]] = extractelement <2 x float> [[TMP38]], i32 1
+; CHECK-NEXT:    [[TMP40]] = extractelement <2 x float> [[TMP38]], i32 0
 ; CHECK-NEXT:    [[I135:%.*]] = fmul fast float [[I85]], [[I65]]
 ; CHECK-NEXT:    [[I128:%.*]] = fmul fast float [[I80]], [[I65]]
 ; CHECK-NEXT:    [[I129:%.*]] = fmul fast float [[I81]], [[I65]]
 ; CHECK-NEXT:    [[I130:%.*]] = fmul fast float [[I82]], [[I65]]
 ; CHECK-NEXT:    [[I133:%.*]] = fmul fast float [[I84]], [[I77]]
-; CHECK-NEXT:    [[I134:%.*]] = fadd fast float [[I133]], [[I135]]
 ; CHECK-NEXT:    [[I136:%.*]] = fmul fast float [[I127]], [[I77]]
-; CHECK-NEXT:    [[TMP51:%.*]] = fadd fast float [[I136]], [[I128]]
 ; CHECK-NEXT:    [[I138:%.*]] = fmul fast float [[I131]], [[I77]]
-; CHECK-NEXT:    [[TMP52:%.*]] = fadd fast float [[I138]], [[I129]]
 ; CHECK-NEXT:    [[I137:%.*]] = fmul fast float [[I86]], [[I77]]
-; CHECK-NEXT:    [[I139:%.*]] = fadd fast float [[I137]], [[I130]]
-; CHECK-NEXT:    [[I140:%.*]] = fadd fast float [[I134]], poison
-; CHECK-NEXT:    [[I141:%.*]] = fadd fast float [[TMP51]], poison
-; CHECK-NEXT:    [[I142:%.*]] = fadd fast float [[TMP52]], poison
-; CHECK-NEXT:    [[I143:%.*]] = fadd fast float [[I139]], poison
-; CHECK-NEXT:    [[I144:%.*]] = fadd fast float [[I140]], poison
-; CHECK-NEXT:    [[I145:%.*]] = fadd fast float [[I141]], poison
-; CHECK-NEXT:    [[I146:%.*]] = fadd fast float [[I142]], poison
-; CHECK-NEXT:    [[I152:%.*]] = fadd fast float [[I143]], poison
-; CHECK-NEXT:    [[I147:%.*]] = fmul fast float [[I85]], poison
-; CHECK-NEXT:    [[I148:%.*]] = fmul fast float [[I80]], poison
-; CHECK-NEXT:    [[I149:%.*]] = fmul fast float [[I81]], poison
-; CHECK-NEXT:    [[I150:%.*]] = fmul fast float [[I82]], poison
-; CHECK-NEXT:    [[I151:%.*]] = fmul fast float [[I84]], poison
-; CHECK-NEXT:    [[TMP57:%.*]] = fadd fast float [[I151]], [[I147]]
-; CHECK-NEXT:    [[I153:%.*]] = fmul fast float [[I127]], poison
-; CHECK-NEXT:    [[TMP58:%.*]] = fadd fast float [[I153]], [[I148]]
-; CHECK-NEXT:    [[I155:%.*]] = fmul fast float [[I131]], poison
-; CHECK-NEXT:    [[TMP59:%.*]] = fadd fast float [[I155]], [[I149]]
-; CHECK-NEXT:    [[I157:%.*]] = fmul fast float [[I86]], poison
+; CHECK-NEXT:    [[OP_RDX14:%.*]] = fadd fast float poison, [[I133]]
+; CHECK-NEXT:    [[OP_RDX15:%.*]] = fadd fast float [[OP_RDX14]], [[I135]]
+; CHECK-NEXT:    [[OP_RDX12:%.*]] = fadd fast float poison, [[I136]]
+; CHECK-NEXT:    [[OP_RDX13:%.*]] = fadd fast float [[OP_RDX12]], [[I128]]
+; CHECK-NEXT:    [[OP_RDX10:%.*]] = fadd fast float poison, [[I138]]
+; CHECK-NEXT:    [[OP_RDX11:%.*]] = fadd fast float [[OP_RDX10]], [[I129]]
+; CHECK-NEXT:    [[OP_RDX8:%.*]] = fadd fast float poison, [[I137]]
+; CHECK-NEXT:    [[OP_RDX9:%.*]] = fadd fast float [[OP_RDX8]], [[I130]]
+; CHECK-NEXT:    [[TMP41:%.*]] = fmul fast <2 x float> [[TMP8]], poison
+; CHECK-NEXT:    [[TMP42:%.*]] = fmul fast <2 x float> [[TMP7]], poison
+; CHECK-NEXT:    [[TMP43:%.*]] = fmul fast <2 x float> [[TMP6]], poison
+; CHECK-NEXT:    [[TMP44:%.*]] = fmul fast <2 x float> [[TMP5]], poison
+; CHECK-NEXT:    [[TMP45:%.*]] = extractelement <2 x float> [[TMP41]], i32 1
+; CHECK-NEXT:    [[I157:%.*]] = fadd fast float poison, [[TMP45]]
+; CHECK-NEXT:    [[I150:%.*]] = extractelement <2 x float> [[TMP41]], i32 0
 ; CHECK-NEXT:    [[TMP60:%.*]] = fadd fast float [[I157]], [[I150]]
-; CHECK-NEXT:    [[I159:%.*]] = fadd fast float [[TMP57]], poison
-; CHECK-NEXT:    [[I160:%.*]] = fadd fast float [[TMP58]], poison
-; CHECK-NEXT:    [[I161:%.*]] = fadd fast float [[TMP59]], poison
-; CHECK-NEXT:    [[I162:%.*]] = fadd fast float [[TMP60]], poison
-; CHECK-NEXT:    [[I163:%.*]] = fadd fast float [[I159]], poison
-; CHECK-NEXT:    [[I164:%.*]] = fadd fast float [[I160]], poison
-; CHECK-NEXT:    [[I165:%.*]] = fadd fast float [[I161]], poison
-; CHECK-NEXT:    [[I166:%.*]] = fadd fast float [[I162]], poison
+; CHECK-NEXT:    [[TMP47:%.*]] = extractelement <2 x float> [[TMP42]], i32 1
+; CHECK-NEXT:    [[OP_RDX4:%.*]] = fadd fast float poison, [[TMP47]]
+; CHECK-NEXT:    [[TMP48:%.*]] = extractelement <2 x float> [[TMP42]], i32 0
+; CHECK-NEXT:    [[OP_RDX5:%.*]] = fadd fast float [[OP_RDX4]], [[TMP48]]
+; CHECK-NEXT:    [[TMP49:%.*]] = extractelement <2 x float> [[TMP43]], i32 1
+; CHECK-NEXT:    [[OP_RDX2:%.*]] = fadd fast float poison, [[TMP49]]
+; CHECK-NEXT:    [[TMP50:%.*]] = extractelement <2 x float> [[TMP43]], i32 0
+; CHECK-NEXT:    [[OP_RDX3:%.*]] = fadd fast float [[OP_RDX2]], [[TMP50]]
+; CHECK-NEXT:    [[TMP51:%.*]] = extractelement <2 x float> [[TMP44]], i32 0
+; CHECK-NEXT:    [[OP_RDX:%.*]] = fadd fast float poison, [[TMP51]]
+; CHECK-NEXT:    [[TMP52:%.*]] = extractelement <2 x float> [[TMP44]], i32 1
+; CHECK-NEXT:    [[OP_RDX1:%.*]] = fadd fast float [[OP_RDX]], [[TMP52]]
 ; CHECK-NEXT:    br i1 po...
[truncated]

@llvmbot

llvmbot commented Aug 8, 2025

@llvm/pr-subscribers-vectorizers

Author: Alexey Bataev (alexey-bataev)

-; CHECK-NEXT:    [[I97:%.*]] = fmul fast float [[I86]], poison
-; CHECK-NEXT:    [[I98:%.*]] = fadd fast float [[I97]], [[I90]]
-; CHECK-NEXT:    [[I99:%.*]] = fadd fast float [[I92]], poison
-; CHECK-NEXT:    [[I100:%.*]] = fadd fast float [[I94]], poison
-; CHECK-NEXT:    [[I101:%.*]] = fadd fast float [[I96]], poison
-; CHECK-NEXT:    [[I102:%.*]] = fadd fast float [[I98]], poison
-; CHECK-NEXT:    [[I103]] = fadd fast float [[I99]], poison
-; CHECK-NEXT:    [[I104]] = fadd fast float [[I100]], poison
-; CHECK-NEXT:    [[I105]] = fadd fast float [[I101]], poison
-; CHECK-NEXT:    [[I106]] = fadd fast float [[I102]], poison
-; CHECK-NEXT:    [[I107:%.*]] = fmul fast float [[I85]], poison
-; CHECK-NEXT:    [[I108:%.*]] = fmul fast float [[I80]], poison
-; CHECK-NEXT:    [[I109:%.*]] = fmul fast float [[I81]], poison
-; CHECK-NEXT:    [[I110:%.*]] = fmul fast float [[I82]], poison
-; CHECK-NEXT:    [[I111:%.*]] = fmul fast float [[I84]], poison
-; CHECK-NEXT:    [[I112:%.*]] = fadd fast float [[I111]], [[I107]]
-; CHECK-NEXT:    [[I113:%.*]] = fmul fast float [[I127]], poison
-; CHECK-NEXT:    [[I114:%.*]] = fadd fast float [[I113]], [[I108]]
-; CHECK-NEXT:    [[I115:%.*]] = fmul fast float [[I131]], poison
-; CHECK-NEXT:    [[I116:%.*]] = fadd fast float [[I115]], [[I109]]
-; CHECK-NEXT:    [[I117:%.*]] = fmul fast float [[I86]], poison
-; CHECK-NEXT:    [[I118:%.*]] = fadd fast float [[I117]], [[I110]]
-; CHECK-NEXT:    [[I119:%.*]] = fadd fast float [[I112]], poison
-; CHECK-NEXT:    [[I120:%.*]] = fadd fast float [[I114]], poison
-; CHECK-NEXT:    [[I121:%.*]] = fadd fast float [[I116]], poison
-; CHECK-NEXT:    [[I122:%.*]] = fadd fast float [[I118]], poison
-; CHECK-NEXT:    [[I123]] = fadd fast float [[I119]], poison
-; CHECK-NEXT:    [[I124]] = fadd fast float [[I120]], poison
-; CHECK-NEXT:    [[I125]] = fadd fast float [[I121]], poison
-; CHECK-NEXT:    [[I126]] = fadd fast float [[I122]], poison
+; CHECK-NEXT:    [[I85:%.*]] = phi nsz float [ [[I66]], %[[BB77]] ], [ [[TMP31:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[I80:%.*]] = phi nsz float [ [[I67]], %[[BB77]] ], [ [[TMP46:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[I81:%.*]] = phi nsz float [ [[I68]], %[[BB77]] ], [ [[TMP37:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[I82:%.*]] = phi nsz float [ poison, %[[BB77]] ], [ [[TMP39:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[I84:%.*]] = phi nsz float [ poison, %[[BB77]] ], [ [[TMP30:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[I127:%.*]] = phi nsz float [ [[I69]], %[[BB77]] ], [ [[TMP53:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[I131:%.*]] = phi nsz float [ poison, %[[BB77]] ], [ [[TMP36:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[I86:%.*]] = phi nsz float [ [[I70]], %[[BB77]] ], [ [[TMP40:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[TMP5:%.*]] = phi <2 x float> [ [[TMP0]], %[[BB77]] ], [ [[TMP38:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[TMP6:%.*]] = phi <2 x float> [ [[TMP1]], %[[BB77]] ], [ [[TMP35:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[TMP7:%.*]] = phi <2 x float> [ [[TMP3]], %[[BB77]] ], [ [[TMP32:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[TMP8:%.*]] = phi <2 x float> [ [[TMP4]], %[[BB77]] ], [ [[TMP29:%.*]], %[[BB78]] ]
+; CHECK-NEXT:    [[TMP9:%.*]] = fmul fast <2 x float> [[TMP8]], poison
+; CHECK-NEXT:    [[TMP10:%.*]] = fmul fast <2 x float> [[TMP7]], poison
+; CHECK-NEXT:    [[TMP11:%.*]] = fmul fast <2 x float> [[TMP6]], poison
+; CHECK-NEXT:    [[TMP12:%.*]] = fmul fast <2 x float> [[TMP5]], poison
+; CHECK-NEXT:    [[TMP13:%.*]] = fmul fast <2 x float> [[TMP8]], poison
+; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <2 x float> [[TMP13]], <2 x float> poison, <2 x i32> <i32 1, i32 0>
+; CHECK-NEXT:    [[TMP15:%.*]] = fmul fast <2 x float> [[TMP7]], poison
+; CHECK-NEXT:    [[TMP16:%.*]] = shufflevector <2 x float> [[TMP15]], <2 x float> poison, <2 x i32> <i32 1, i32 0>
+; CHECK-NEXT:    [[TMP17:%.*]] = fmul fast <2 x float> [[TMP6]], poison
+; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <2 x float> [[TMP17]], <2 x float> poison, <2 x i32> <i32 1, i32 0>
+; CHECK-NEXT:    [[TMP19:%.*]] = fmul fast <2 x float> [[TMP5]], poison
+; CHECK-NEXT:    [[TMP20:%.*]] = shufflevector <2 x float> [[TMP19]], <2 x float> poison, <2 x i32> <i32 1, i32 0>
+; CHECK-NEXT:    [[TMP21:%.*]] = fadd fast <2 x float> [[TMP14]], [[TMP9]]
+; CHECK-NEXT:    [[TMP22:%.*]] = fadd fast <2 x float> [[TMP21]], poison
+; CHECK-NEXT:    [[TMP23:%.*]] = fadd fast <2 x float> [[TMP16]], [[TMP10]]
+; CHECK-NEXT:    [[TMP24:%.*]] = fadd fast <2 x float> [[TMP23]], poison
+; CHECK-NEXT:    [[TMP25:%.*]] = fadd fast <2 x float> [[TMP18]], [[TMP11]]
+; CHECK-NEXT:    [[TMP26:%.*]] = fadd fast <2 x float> [[TMP25]], poison
+; CHECK-NEXT:    [[TMP27:%.*]] = fadd fast <2 x float> [[TMP20]], [[TMP12]]
+; CHECK-NEXT:    [[TMP28:%.*]] = fadd fast <2 x float> [[TMP27]], poison
+; CHECK-NEXT:    [[TMP29]] = fadd fast <2 x float> [[TMP22]], poison
+; CHECK-NEXT:    [[TMP30]] = extractelement <2 x float> [[TMP29]], i32 1
+; CHECK-NEXT:    [[TMP31]] = extractelement <2 x float> [[TMP29]], i32 0
+; CHECK-NEXT:    [[TMP32]] = fadd fast <2 x float> [[TMP24]], poison
+; CHECK-NEXT:    [[TMP53]] = extractelement <2 x float> [[TMP32]], i32 1
+; CHECK-NEXT:    [[TMP46]] = extractelement <2 x float> [[TMP32]], i32 0
+; CHECK-NEXT:    [[TMP35]] = fadd fast <2 x float> [[TMP26]], poison
+; CHECK-NEXT:    [[TMP36]] = extractelement <2 x float> [[TMP35]], i32 1
+; CHECK-NEXT:    [[TMP37]] = extractelement <2 x float> [[TMP35]], i32 0
+; CHECK-NEXT:    [[TMP38]] = fadd fast <2 x float> [[TMP28]], poison
+; CHECK-NEXT:    [[TMP39]] = extractelement <2 x float> [[TMP38]], i32 1
+; CHECK-NEXT:    [[TMP40]] = extractelement <2 x float> [[TMP38]], i32 0
 ; CHECK-NEXT:    [[I135:%.*]] = fmul fast float [[I85]], [[I65]]
 ; CHECK-NEXT:    [[I128:%.*]] = fmul fast float [[I80]], [[I65]]
 ; CHECK-NEXT:    [[I129:%.*]] = fmul fast float [[I81]], [[I65]]
 ; CHECK-NEXT:    [[I130:%.*]] = fmul fast float [[I82]], [[I65]]
 ; CHECK-NEXT:    [[I133:%.*]] = fmul fast float [[I84]], [[I77]]
-; CHECK-NEXT:    [[I134:%.*]] = fadd fast float [[I133]], [[I135]]
 ; CHECK-NEXT:    [[I136:%.*]] = fmul fast float [[I127]], [[I77]]
-; CHECK-NEXT:    [[TMP51:%.*]] = fadd fast float [[I136]], [[I128]]
 ; CHECK-NEXT:    [[I138:%.*]] = fmul fast float [[I131]], [[I77]]
-; CHECK-NEXT:    [[TMP52:%.*]] = fadd fast float [[I138]], [[I129]]
 ; CHECK-NEXT:    [[I137:%.*]] = fmul fast float [[I86]], [[I77]]
-; CHECK-NEXT:    [[I139:%.*]] = fadd fast float [[I137]], [[I130]]
-; CHECK-NEXT:    [[I140:%.*]] = fadd fast float [[I134]], poison
-; CHECK-NEXT:    [[I141:%.*]] = fadd fast float [[TMP51]], poison
-; CHECK-NEXT:    [[I142:%.*]] = fadd fast float [[TMP52]], poison
-; CHECK-NEXT:    [[I143:%.*]] = fadd fast float [[I139]], poison
-; CHECK-NEXT:    [[I144:%.*]] = fadd fast float [[I140]], poison
-; CHECK-NEXT:    [[I145:%.*]] = fadd fast float [[I141]], poison
-; CHECK-NEXT:    [[I146:%.*]] = fadd fast float [[I142]], poison
-; CHECK-NEXT:    [[I152:%.*]] = fadd fast float [[I143]], poison
-; CHECK-NEXT:    [[I147:%.*]] = fmul fast float [[I85]], poison
-; CHECK-NEXT:    [[I148:%.*]] = fmul fast float [[I80]], poison
-; CHECK-NEXT:    [[I149:%.*]] = fmul fast float [[I81]], poison
-; CHECK-NEXT:    [[I150:%.*]] = fmul fast float [[I82]], poison
-; CHECK-NEXT:    [[I151:%.*]] = fmul fast float [[I84]], poison
-; CHECK-NEXT:    [[TMP57:%.*]] = fadd fast float [[I151]], [[I147]]
-; CHECK-NEXT:    [[I153:%.*]] = fmul fast float [[I127]], poison
-; CHECK-NEXT:    [[TMP58:%.*]] = fadd fast float [[I153]], [[I148]]
-; CHECK-NEXT:    [[I155:%.*]] = fmul fast float [[I131]], poison
-; CHECK-NEXT:    [[TMP59:%.*]] = fadd fast float [[I155]], [[I149]]
-; CHECK-NEXT:    [[I157:%.*]] = fmul fast float [[I86]], poison
+; CHECK-NEXT:    [[OP_RDX14:%.*]] = fadd fast float poison, [[I133]]
+; CHECK-NEXT:    [[OP_RDX15:%.*]] = fadd fast float [[OP_RDX14]], [[I135]]
+; CHECK-NEXT:    [[OP_RDX12:%.*]] = fadd fast float poison, [[I136]]
+; CHECK-NEXT:    [[OP_RDX13:%.*]] = fadd fast float [[OP_RDX12]], [[I128]]
+; CHECK-NEXT:    [[OP_RDX10:%.*]] = fadd fast float poison, [[I138]]
+; CHECK-NEXT:    [[OP_RDX11:%.*]] = fadd fast float [[OP_RDX10]], [[I129]]
+; CHECK-NEXT:    [[OP_RDX8:%.*]] = fadd fast float poison, [[I137]]
+; CHECK-NEXT:    [[OP_RDX9:%.*]] = fadd fast float [[OP_RDX8]], [[I130]]
+; CHECK-NEXT:    [[TMP41:%.*]] = fmul fast <2 x float> [[TMP8]], poison
+; CHECK-NEXT:    [[TMP42:%.*]] = fmul fast <2 x float> [[TMP7]], poison
+; CHECK-NEXT:    [[TMP43:%.*]] = fmul fast <2 x float> [[TMP6]], poison
+; CHECK-NEXT:    [[TMP44:%.*]] = fmul fast <2 x float> [[TMP5]], poison
+; CHECK-NEXT:    [[TMP45:%.*]] = extractelement <2 x float> [[TMP41]], i32 1
+; CHECK-NEXT:    [[I157:%.*]] = fadd fast float poison, [[TMP45]]
+; CHECK-NEXT:    [[I150:%.*]] = extractelement <2 x float> [[TMP41]], i32 0
 ; CHECK-NEXT:    [[TMP60:%.*]] = fadd fast float [[I157]], [[I150]]
-; CHECK-NEXT:    [[I159:%.*]] = fadd fast float [[TMP57]], poison
-; CHECK-NEXT:    [[I160:%.*]] = fadd fast float [[TMP58]], poison
-; CHECK-NEXT:    [[I161:%.*]] = fadd fast float [[TMP59]], poison
-; CHECK-NEXT:    [[I162:%.*]] = fadd fast float [[TMP60]], poison
-; CHECK-NEXT:    [[I163:%.*]] = fadd fast float [[I159]], poison
-; CHECK-NEXT:    [[I164:%.*]] = fadd fast float [[I160]], poison
-; CHECK-NEXT:    [[I165:%.*]] = fadd fast float [[I161]], poison
-; CHECK-NEXT:    [[I166:%.*]] = fadd fast float [[I162]], poison
+; CHECK-NEXT:    [[TMP47:%.*]] = extractelement <2 x float> [[TMP42]], i32 1
+; CHECK-NEXT:    [[OP_RDX4:%.*]] = fadd fast float poison, [[TMP47]]
+; CHECK-NEXT:    [[TMP48:%.*]] = extractelement <2 x float> [[TMP42]], i32 0
+; CHECK-NEXT:    [[OP_RDX5:%.*]] = fadd fast float [[OP_RDX4]], [[TMP48]]
+; CHECK-NEXT:    [[TMP49:%.*]] = extractelement <2 x float> [[TMP43]], i32 1
+; CHECK-NEXT:    [[OP_RDX2:%.*]] = fadd fast float poison, [[TMP49]]
+; CHECK-NEXT:    [[TMP50:%.*]] = extractelement <2 x float> [[TMP43]], i32 0
+; CHECK-NEXT:    [[OP_RDX3:%.*]] = fadd fast float [[OP_RDX2]], [[TMP50]]
+; CHECK-NEXT:    [[TMP51:%.*]] = extractelement <2 x float> [[TMP44]], i32 0
+; CHECK-NEXT:    [[OP_RDX:%.*]] = fadd fast float poison, [[TMP51]]
+; CHECK-NEXT:    [[TMP52:%.*]] = extractelement <2 x float> [[TMP44]], i32 1
+; CHECK-NEXT:    [[OP_RDX1:%.*]] = fadd fast float [[OP_RDX]], [[TMP52]]
 ; CHECK-NEXT:    br i1 po...
[truncated]


github-actions bot commented Aug 8, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Created using spr 1.3.5
@RKSimon RKSimon requested a review from gregbedwell August 9, 2025 11:02
Created using spr 1.3.5
@alexey-bataev
Member Author

Ping!

// Calculate actual FMAD cost.
IntrinsicCostAttributes ICA(Intrinsic::fmuladd, RVecTy,
                            {RVecTy, RVecTy, RVecTy}, FMF);
FMACost = TTI->getIntrinsicInstrCost(ICA, CostKind);
Collaborator

Why is this different to the valid cost that canConvertToFMA returned?

Member Author

canConvertToFMA returns the cost of the scalars in Ops, but here we need the cost of the vector FMAD. So canConvertToFMA checks whether the scalars can be represented as an FMAD, and then we update the cost with the actual vector cost.

Created using spr 1.3.5
Collaborator

@davemgreen davemgreen left a comment

I ran a number of tests and, whilst it is difficult to be sure, this didn't seem to cause problems in what I tried. It seems OK performance-wise now.

Collaborator

@RKSimon RKSimon left a comment

LGTM

Created using spr 1.3.7
@alexey-bataev alexey-bataev merged commit 74230ff into main Sep 1, 2025
5 of 9 checks passed
@alexey-bataev alexey-bataev deleted the users/alexey-bataev/spr/slpimprovedfixed-fmad-support-in-reductions branch September 1, 2025 21:01
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Sep 1, 2025
In the initial patch for FMAD, potential FMAD nodes were completely
excluded from the reduction analysis for the smaller patch. But it may
cause regressions.

This patch adds better detection of scalar FMAD reduction operations and
tries to correctly calculate the costs of the FMAD reduction operations
(also, excluding the costs of the scalar fmuls) and split reduction
operations, combined with regular FMADs.

Reviewers: RKSimon, gregbedwell, hiraditya

Reviewed By: RKSimon

Pull Request: llvm/llvm-project#152787
alexey-bataev added a commit that referenced this pull request Sep 2, 2025
In the initial patch for FMAD, potential FMAD nodes were completely
excluded from the reduction analysis for the smaller patch. But it may
cause regressions.

This patch adds better detection of scalar FMAD reduction operations and
tries to correctly calculate the costs of the FMAD reduction operations
(also, excluding the costs of the scalar fmuls) and split reduction
operations, combined with regular FMADs.

Fixed the handling for reduced values with many uses.

Reviewers: RKSimon, gregbedwell, hiraditya

Reviewed By: RKSimon

Pull Request: #152787
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Sep 2, 2025
In the initial patch for FMAD, potential FMAD nodes were completely
excluded from the reduction analysis for the smaller patch. But it may
cause regressions.

This patch adds better detection of scalar FMAD reduction operations and
tries to correctly calculate the costs of the FMAD reduction operations
(also, excluding the costs of the scalar fmuls) and split reduction
operations, combined with regular FMADs.

Fixed the handling for reduced values with many uses.

Reviewers: RKSimon, gregbedwell, hiraditya

Reviewed By: RKSimon

Pull Request: llvm/llvm-project#152787