-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[LV] Bundle sub reductions into VPExpressionRecipe #147255
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-vectorizers @llvm/pr-subscribers-backend-arm Author: Sam Tebbs (SamTebbs33) ChangesThis PR bundles sub reductions into the VPExpressionRecipe class and adjusts the cost functions to take the negation into account. Patch is 23.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147255.diff 14 Files Affected:
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index c43870392361d..3cc0ea01953c3 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1645,8 +1645,10 @@ class TargetTransformInfo {
/// extensions. This is the cost of as:
/// ResTy vecreduce.add(mul (A, B)).
/// ResTy vecreduce.add(mul(ext(Ty A), ext(Ty B)).
+ /// The multiply can optionally be negated, which signifies that it is a sub
+ /// reduction.
LLVM_ABI InstructionCost getMulAccReductionCost(
- bool IsUnsigned, Type *ResTy, VectorType *Ty,
+ bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
/// Calculate the cost of an extended reduction pattern, similar to
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 12f87226c5f57..fd22981a5dbf3 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -960,7 +960,7 @@ class TargetTransformInfoImplBase {
virtual InstructionCost
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
- TTI::TargetCostKind CostKind) const {
+ bool Negated, TTI::TargetCostKind CostKind) const {
return 1;
}
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index bf958e100f2ac..a9c9fa6d1db0d 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -3116,7 +3116,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
InstructionCost
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
+ bool Negated,
TTI::TargetCostKind CostKind) const override {
+ if (Negated)
+ return InstructionCost::getInvalid(CostKind);
// Without any native support, this is equivalent to the cost of
// vecreduce.add(mul(ext(Ty A), ext(Ty B))) or
// vecreduce.add(mul(A, B)).
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 3ebd9d487ba04..ba0d070bffe6d 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1274,9 +1274,10 @@ InstructionCost TargetTransformInfo::getExtendedReductionCost(
}
InstructionCost TargetTransformInfo::getMulAccReductionCost(
- bool IsUnsigned, Type *ResTy, VectorType *Ty,
+ bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
TTI::TargetCostKind CostKind) const {
- return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, CostKind);
+ return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, Negated,
+ CostKind);
}
InstructionCost
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 380faa6cf6939..d9a367535baf4 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5316,8 +5316,10 @@ InstructionCost AArch64TTIImpl::getExtendedReductionCost(
InstructionCost
AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
- VectorType *VecTy,
+ VectorType *VecTy, bool Negated,
TTI::TargetCostKind CostKind) const {
+ if (Negated)
+ return InstructionCost::getInvalid(CostKind);
EVT VecVT = TLI->getValueType(DL, VecTy);
EVT ResVT = TLI->getValueType(DL, ResTy);
@@ -5332,7 +5334,8 @@ AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
return LT.first + 2;
}
- return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, CostKind);
+ return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, Negated,
+ CostKind);
}
InstructionCost
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 9ada70bd7086a..8bb31d2a3dac5 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -447,7 +447,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
TTI::TargetCostKind CostKind) const override;
InstructionCost getMulAccReductionCost(
- bool IsUnsigned, Type *ResTy, VectorType *Ty,
+ bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const override;
InstructionCost
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 203fb76d7be86..27eb22b5f9986 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1884,8 +1884,10 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
InstructionCost
ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
- VectorType *ValTy,
+ VectorType *ValTy, bool Negated,
TTI::TargetCostKind CostKind) const {
+ if (Negated)
+ return InstructionCost::getInvalid(CostKind);
EVT ValVT = TLI->getValueType(DL, ValTy);
EVT ResVT = TLI->getValueType(DL, ResTy);
@@ -1906,7 +1908,8 @@ ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
return ST->getMVEVectorCostFactor(CostKind) * LT.first;
}
- return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, CostKind);
+ return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, Negated,
+ CostKind);
}
InstructionCost
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index ca06b9e3cb661..43f47f3e7aa6f 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -299,6 +299,7 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
TTI::TargetCostKind CostKind) const override;
InstructionCost
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *ValTy,
+ bool Negated,
TTI::TargetCostKind CostKind) const override;
InstructionCost
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1cfbcf1336620..0adff8d957e98 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5538,7 +5538,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
TTI::CastContextHint::None, CostKind, RedOp);
InstructionCost RedCost = TTI.getMulAccReductionCost(
- IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
+ IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind);
if (RedCost.isValid() &&
RedCost < ExtCost * 2 + MulCost + Ext2Cost + BaseCost)
@@ -5583,7 +5583,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
InstructionCost RedCost = TTI.getMulAccReductionCost(
- IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
+ IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind);
InstructionCost ExtraExtCost = 0;
if (Op0Ty != LargestOpTy || Op1Ty != LargestOpTy) {
Instruction *ExtraExtOp = (Op0Ty != LargestOpTy) ? Op0 : Op1;
@@ -5602,7 +5602,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
InstructionCost RedCost = TTI.getMulAccReductionCost(
- true, RdxDesc.getRecurrenceType(), VectorTy, CostKind);
+ true, RdxDesc.getRecurrenceType(), VectorTy, false, CostKind);
if (RedCost.isValid() && RedCost < MulCost + BaseCost)
return I == RetI ? RedCost : 0;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index d460573f5bec6..1bc926db301d8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2757,6 +2757,12 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
/// vector operands, performing a reduction.add on the result, and adding
/// the scalar result to a chain.
MulAccReduction,
+ /// Represent an inloop multiply-accumulate reduction, multiplying the
+ /// extended vector operands, negating the multiplication, performing a
+ /// reduction.add
+ /// on the result, and adding
+ /// the scalar result to a chain.
+ ExtNegatedMulAccReduction,
};
/// Type of the expression.
@@ -2780,6 +2786,11 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
VPWidenRecipe *Mul, VPReductionRecipe *Red)
: VPExpressionRecipe(ExpressionTypes::ExtMulAccReduction,
{Ext0, Ext1, Mul, Red}) {}
+ VPExpressionRecipe(VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1,
+ VPWidenRecipe *Mul, VPWidenRecipe *Sub,
+ VPReductionRecipe *Red)
+ : VPExpressionRecipe(ExpressionTypes::ExtNegatedMulAccReduction,
+ {Ext0, Ext1, Mul, Sub, Red}) {}
~VPExpressionRecipe() override {
for (auto *R : reverse(ExpressionRecipes))
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 318e8171e098d..c20b1920c3791 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2672,13 +2672,17 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
RedTy, SrcVecTy, std::nullopt, Ctx.CostKind);
}
case ExpressionTypes::MulAccReduction:
- return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, Ctx.CostKind);
+ return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, false,
+ Ctx.CostKind);
- case ExpressionTypes::ExtMulAccReduction:
+ case ExpressionTypes::ExtNegatedMulAccReduction:
+ case ExpressionTypes::ExtMulAccReduction: {
+ bool Negated = ExpressionType == ExpressionTypes::ExtNegatedMulAccReduction;
return Ctx.TTI.getMulAccReductionCost(
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
Instruction::ZExt,
- RedTy, SrcVecTy, Ctx.CostKind);
+ RedTy, SrcVecTy, Negated, Ctx.CostKind);
+ }
}
llvm_unreachable("Unknown VPExpressionRecipe::ExpressionTypes enum");
}
@@ -2725,6 +2729,31 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
O << ")";
break;
}
+ case ExpressionTypes::ExtNegatedMulAccReduction: {
+ getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
+ O << " + ";
+ O << "reduce."
+ << Instruction::getOpcodeName(
+ RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))
+ << " (sub (0, mul";
+ auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
+ Mul->printFlags(O);
+ O << "(";
+ getOperand(0)->printAsOperand(O, SlotTracker);
+ auto *Ext0 = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
+ O << " " << Instruction::getOpcodeName(Ext0->getOpcode()) << " to "
+ << *Ext0->getResultType() << "), (";
+ getOperand(1)->printAsOperand(O, SlotTracker);
+ auto *Ext1 = cast<VPWidenCastRecipe>(ExpressionRecipes[1]);
+ O << " " << Instruction::getOpcodeName(Ext1->getOpcode()) << " to "
+ << *Ext1->getResultType() << ")";
+ if (Red->isConditional()) {
+ O << ", ";
+ Red->getCondOp()->printAsOperand(O, SlotTracker);
+ }
+ O << "))";
+ break;
+ }
case ExpressionTypes::MulAccReduction:
case ExpressionTypes::ExtMulAccReduction: {
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 931d4d42f56e4..a09d2037e97b4 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2908,16 +2908,17 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
// Clamp the range if using multiply-accumulate-reduction is profitable.
auto IsMulAccValidAndClampRange =
- [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
- VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
+ [&](bool IsZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
+ VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt,
+ bool Negated = false) -> bool {
return LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) {
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
Type *SrcTy =
Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
- InstructionCost MulAccCost =
- Ctx.TTI.getMulAccReductionCost(isZExt, RedTy, SrcVecTy, CostKind);
+ InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
+ IsZExt, RedTy, SrcVecTy, Negated, CostKind);
InstructionCost MulCost = Mul->computeCost(VF, Ctx);
InstructionCost RedCost = Red->computeCost(VF, Ctx);
InstructionCost ExtCost = 0;
@@ -2935,14 +2936,22 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
};
VPValue *VecOp = Red->getVecOp();
+ VPValue *Mul = nullptr;
+ VPValue *Sub = nullptr;
VPValue *A, *B;
+ // Sub reductions will have a sub between the add reduction and vec op.
+ if (match(VecOp,
+ m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Mul))))
+ Sub = VecOp;
+ else
+ Mul = VecOp;
// Try to match reduce.add(mul(...)).
- if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
+ if (match(Mul, m_Mul(m_VPValue(A), m_VPValue(B)))) {
auto *RecipeA =
dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
auto *RecipeB =
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
- auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
+ auto *MulR = cast<VPWidenRecipe>(Mul->getDefiningRecipe());
// Match reduce.add(mul(ext, ext)).
if (RecipeA && RecipeB &&
@@ -2951,12 +2960,16 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
Instruction::CastOps::ZExt,
- Mul, RecipeA, RecipeB, nullptr)) {
- return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red);
+ MulR, RecipeA, RecipeB, nullptr, Sub)) {
+ if (Sub)
+ return new VPExpressionRecipe(
+ RecipeA, RecipeB, MulR,
+ cast<VPWidenRecipe>(Sub->getDefiningRecipe()), Red);
+ return new VPExpressionRecipe(RecipeA, RecipeB, MulR, Red);
}
// Match reduce.add(mul).
- if (IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
- return new VPExpressionRecipe(Mul, Red);
+ if (IsMulAccValidAndClampRange(true, MulR, nullptr, nullptr, nullptr, Sub))
+ return new VPExpressionRecipe(MulR, Red);
}
// Match reduce.add(ext(mul(ext(A), ext(B)))).
// All extend recipes must have same opcode or A == B
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index b2fced47b9527..7953aec48c8b0 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1401,8 +1401,8 @@ static void analyzeCostOfVecReduction(const IntrinsicInst &II,
TTI::CastContextHint::None, CostKind, RedOp);
CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
- CostAfterReduction =
- TTI.getMulAccReductionCost(IsUnsigned, II.getType(), ExtType, CostKind);
+ CostAfterReduction = TTI.getMulAccReductionCost(IsUnsigned, II.getType(),
+ ExtType, false, CostKind);
return;
}
CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy,
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
index 4af3fa9202c77..8059ac12ecd2e 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
@@ -416,3 +416,146 @@ exit:
%r.0.lcssa = phi i64 [ %rdx.next, %loop ]
ret i64 %r.0.lcssa
}
+
+define i32 @print_mulacc_sub(ptr %a, ptr %b) {
+; CHECK: VPlan 'Initial VPlan for VF={4},UF>=1' {
+; CHECK-NEXT: Live-in vp<%0> = VF
+; CHECK-NEXT: Live-in vp<%1> = VF * UF
+; CHECK-NEXT: Live-in vp<%2> = vector-trip-count
+; CHECK-NEXT: Live-in ir<1024> = original trip-count
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<entry>:
+; CHECK-NEXT: Successor(s): scalar.ph, vector.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: vector.ph:
+; CHECK-NEXT: EMIT vp<%3> = reduction-start-vector ir<0>, ir<0>, ir<1>
+; CHECK-NEXT: Successor(s): vector loop
+; CHECK-EMPTY:
+; CHECK-NEXT: <x1> vector loop: {
+; CHECK-NEXT: vector.body:
+; CHECK-NEXT: EMIT vp<%4> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
+; CHECK-NEXT: WIDEN-REDUCTION-PHI ir<%accum> = phi vp<%3>, vp<%8>
+; CHECK-NEXT: vp<%5> = SCALAR-STEPS vp<%4>, ir<1>, vp<%0>
+; CHECK-NEXT: CLONE ir<%gep.a> = getelementptr ir<%a>, vp<%5>
+; CHECK-NEXT: vp<%6> = vector-pointer ir<%gep.a>
+; CHECK-NEXT: WIDEN ir<%load.a> = load vp<%6>
+; CHECK-NEXT: CLONE ir<%gep.b> = getelementptr ir<%b>, vp<%5>
+; CHECK-NEXT: vp<%7> = vector-pointer ir<%gep.b>
+; CHECK-NEXT: WIDEN ir<%load.b> = load vp<%7>
+; CHECK-NEXT: EXPRESSION vp<%8> = ir<%accum> + reduce.add (sub (0, mul (ir<%load.b> zext to i32), (ir<%load.a> zext to i32)))
+; CHECK-NEXT: EMIT vp<%index.next> = add nuw vp<%4>, vp<%1>
+; CHECK-NEXT: EMIT branch-on-count vp<%index.next>, vp<%2>
+; CHECK-NEXT: No successors
+; CHECK-NEXT: }
+; CHECK-NEXT: Successor(s): middle.block
+; CHECK-EMPTY:
+; CHECK-NEXT: middle.block:
+; CHECK-NEXT: EMIT vp<%10> = compute-reduction-result ir<%accum>, vp<%8>
+; CHECK-NEXT: EMIT vp<%cmp.n> = icmp eq ir<1024>, vp<%2>
+; CHECK-NEXT: EMIT branch-on-cond vp<%cmp.n>
+; CHECK-NEXT: Successor(s): ir-bb<for.exit>, scalar.ph
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.exit>:
+; CHECK-NEXT: IR %add.lcssa = phi i32 [ %add, %for.body ] (extra operand: vp<%10> from middle.block)
+; CHECK-NEXT: No successors
+; CHECK-EMPTY:
+; CHECK-NEXT: scalar.ph:
+; CHECK-NEXT: EMIT-SCALAR vp<%bc.resume.val> = phi [ vp<%2>, middle.block ], [ ir<0>, ir-bb<entry> ]
+; CHECK-NEXT: EMIT-SCALAR vp<%bc.merge.rdx> = phi [ vp<%10>, middle.block ], [ ir<0>, ir-bb<entry> ]
+; CHECK-NEXT: Successor(s): ir-bb<for.body>
+; CHECK-EMPTY:
+; CHECK-NEXT: ir-bb<for.body>:
+; CHECK-NEXT: IR %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ] (extra operand: vp<%bc.resume.val> from scalar.ph)
+; CHECK-NEXT: IR %accum = phi i32 [ 0, %entry ], [ %add, %for.body ] (extra operand: vp<%bc.merge.rdx> from scalar.ph)
+; CHECK-NEXT: IR %gep.a = getelementptr i8, ptr %a, i64 %iv
+; CHECK-NEXT: IR %load.a = load i8, ptr %gep.a, align 1
+; CHECK-NEXT: IR %ext.a = zext i8 %load.a to i32
+; CHECK-NEXT: IR %gep.b = getelementptr i8, ptr %b, i64 %iv
+; CHECK-NEXT: IR %load.b = load i8, ptr %gep.b, align 1
+; CHECK-NEXT: IR %ext.b = zext i8 %load.b to i32
+; CHECK-NEXT: IR %mul = mul i32 %ext.b, %ext.a
+; CHECK-NEXT: IR %add = sub i32 %accum, %mul
+; CHECK-NEXT: IR %iv.next = add i64 %iv, 1
+; CHECK-NEXT: IR %exitcond.not = icmp eq i64 %iv.next, 1024
+; CHECK-NEXT: No successors
+; CH...
[truncated]
|
@@ -2725,6 +2729,31 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, | |||
O << ")"; | |||
break; | |||
} | |||
case ExpressionTypes::ExtNegatedMulAccReduction: { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a way to commonise this with the ExtMulAccReduction case if the only difference is a negate?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That was my initial approach but it required checking the number of operands to know if there was a sub or not, and I was asked to create an expression type to not rely on operand ordering being stable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you still could re-use the code for printing, by just checking the expression type to decide whether to print the sub or not.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if you've seen the comment above?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did miss this, thanks. It should be irrelevant now that that expression type is gone.
CostAfterReduction = | ||
TTI.getMulAccReductionCost(IsUnsigned, II.getType(), ExtType, CostKind); | ||
CostAfterReduction = TTI.getMulAccReductionCost(IsUnsigned, II.getType(), | ||
ExtType, false, CostKind); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Probably better written as /*Negated=*/false
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
LLVM_ABI InstructionCost getMulAccReductionCost( | ||
bool IsUnsigned, Type *ResTy, VectorType *Ty, | ||
bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it worth keeping the booleans together, i.e. next to IsUnsigned
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea, done.
TTI::TargetCostKind CostKind) const override { | ||
if (Negated) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why can't we add a cost for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I've added a cost for the sub.
@@ -5538,7 +5538,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I, | |||
TTI::CastContextHint::None, CostKind, RedOp); | |||
|
|||
InstructionCost RedCost = TTI.getMulAccReductionCost( | |||
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind); | |||
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: /*Negated=*/false
and same for other below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
/// Represent an inloop multiply-accumulate reduction, multiplying the | ||
/// extended vector operands, negating the multiplication, performing a | ||
/// reduction.add | ||
/// on the result, and adding |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Formatting of the comment looks a bit odd - can you fix it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
This PR bundles partial reductions inside the VPExpressionRecipe class. Depends on llvm#147255 .
39f3dab
to
0b93c24
Compare
b7c9820
to
fa30b51
Compare
Ping |
This PR allows the loop vectorizer to handle in-loop sub reductions by forming a normal in-loop add reduction with a negated input. Stacked PRs: 1. -> llvm/llvm-project#147026 2. llvm/llvm-project#147255 3. llvm/llvm-project#147302 4. llvm/llvm-project#147513
TTI::TargetCostKind CostKind) const override { | ||
// Without any native support, this is equivalent to the cost of | ||
// vecreduce.add(mul(ext(Ty A), ext(Ty B))) or | ||
// vecreduce.add(mul(A, B)). | ||
// vecreduce.add(mul(A, B)) with an optional negation of the mul. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: The wording makes it seem like the optional negation only applies to the second form.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
fa30b51
to
1d7cb25
Compare
Apologies for the noise. |
Ping |
entry: | ||
br label %for.body | ||
|
||
for.body: ; preds = %for.body, %entry |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for.body: ; preds = %for.body, %entry | |
loop: |
nit: consistency with other functions in file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
%exitcond.not = icmp eq i64 %iv.next, 1024 | ||
br i1 %exitcond.not, label %for.exit, label %for.body | ||
|
||
for.exit: ; preds = %for.body |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for.exit: ; preds = %for.body | |
exit: |
nit: consistency with other functions in file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
@@ -416,3 +416,146 @@ exit: | |||
%r.0.lcssa = phi i64 [ %rdx.next, %loop ] | |||
ret i64 %r.0.lcssa | |||
} | |||
|
|||
define i32 @print_mulacc_sub(ptr %a, ptr %b) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you also add a test that checks the generated code? IIUC there should also be changes in costing/vectorization factors we chose, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, the codegen changes will be covered by the existing tests, just curious if it would be possible to add a test that benefits from the cost changes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've tried but haven't been able to come up with a test that is different in the VF chosen without these changes. I reckon that it will be easier once perhaps the AArch64 or ARM getMulAccReductionCost
functions accept the sub version.
@@ -2725,6 +2729,31 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, | |||
O << ")"; | |||
break; | |||
} | |||
case ExpressionTypes::ExtNegatedMulAccReduction: { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if you've seen the comment above?
(not sure why, but it looks like the precommit tests on Linux/Windows have not been triggered for some reason, but the libx++ ones have; may be solve by updating the latest main again?) |
1d7cb25
to
229331e
Compare
Thanks for the review @fhahn , I'll have a look at it tomorrow. I've just rebased the patch on top of main after the sub reduction patch was merged. This involved removing the negated expression type since (non-chained) a sub reduction is now represented with a subtraction, rather than an addition with a negated input. |
/// ResTy vecreduce.add(mul(ext(Ty A), ext(Ty B)). | ||
/// getArithmeticReductionCost of an Add/Sub reduction with multiply and | ||
/// optional extensions. This is the cost of as: | ||
/// ResTy vecreduce.add/sub(mul (A, B)). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
/// ResTy vecreduce.add/sub(mul (A, B)). | |
/// ResTy vecreduce.add/sub(mul(A, B)). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
TTI::TargetCostKind CostKind) const override { | ||
// Without any native support, this is equivalent to the cost of | ||
// vecreduce.add(mul(ext(Ty A), ext(Ty B))) or | ||
// vecreduce.add(mul(A, B)). | ||
// vecreduce.add(mul(A, B)). IsNegated determines if the mul is negated. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IsNegated
no longer exists.
please also add an assert that RedOpcode is either an add or a sub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
TTI::TargetCostKind CostKind) const { | ||
if (RedOpcode != Instruction::Add) | ||
return InstructionCost::getInvalid(CostKind); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this returning an invalid cost, rather than adding the cost of a negation of the operand?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was going with the most conservative approach at first, but I've now allowed subs here (but made sure that it's an add in the UDOT case below). We don't need to consider a negation of the operand since this function isn't used for the chained add+sub case at the moment.
@@ -3126,14 +3127,15 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, | |||
}; | |||
|
|||
VPValue *VecOp = Red->getVecOp(); | |||
VPValue *Mul = VecOp; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks like a partially NFC change, and I'd prefer the use of VecOp
as it was before this change. The reason for this is that in the case that VecOp
is not a multiply, Mul
is still defined (to the value of VecOp
), which I don't think is right.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, I think this is left over from when we were checking for a negation. Done.
It looks like one of the vectorizer tests is failing in precommit and possibly needs updating? |
This PR bundles sub reductions into the VPExpressionRecipe class and adjusts the cost functions to take the negation into account.
7d40358
to
33236a3
Compare
Yeah it needed a rebase, done. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with a few inline comments remaining, thanks!
/// ResTy vecreduce.add/sub(mul (A, B)). | ||
/// ResTy vecreduce.add/sub(mul(ext(Ty A), ext(Ty B)). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// ResTy vecreduce.add/sub(mul (A, B)). | |
/// ResTy vecreduce.add/sub(mul(ext(Ty A), ext(Ty B)). | |
/// * ResTy vecreduce.add/sub(mul (A, B)) or, | |
/// * ResTy vecreduce.add/sub(mul(ext(Ty A), ext(Ty B)). |
VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt, | ||
unsigned Opcode) -> bool { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt, | |
unsigned Opcode) -> bool { | |
VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt | |
) -> bool { |
Can we just use the captured Opcode?
CostAfterReduction = TTI.getMulAccReductionCost( | ||
IsUnsigned, ReductionOpc, II.getType(), ExtType, CostKind); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it would be nice to have a test for this, but not sure if that's possible.
This PR bundles sub reductions into the VPExpressionRecipe class and adjusts the cost functions to take the negation into account.
Stacked PRs: