Skip to content

Commit fa30b51

Browse files
committed
Move IsNegated parameter and cost sub
1 parent a439b53 commit fa30b51

13 files changed

+36
-33
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1648,7 +1648,7 @@ class TargetTransformInfo {
16481648
/// The multiply can optionally be negated, which signifies that it is a sub
16491649
/// reduction.
16501650
LLVM_ABI InstructionCost getMulAccReductionCost(
1651-
bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
1651+
bool IsUnsigned, bool IsNegated, Type *ResTy, VectorType *Ty,
16521652
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
16531653

16541654
/// Calculate the cost of an extended reduction pattern, similar to

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -959,8 +959,8 @@ class TargetTransformInfoImplBase {
959959
}
960960

961961
virtual InstructionCost
962-
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
963-
bool Negated, TTI::TargetCostKind CostKind) const {
962+
getMulAccReductionCost(bool IsUnsigned, bool IsNegated, Type *ResTy,
963+
VectorType *Ty, TTI::TargetCostKind CostKind) const {
964964
return 1;
965965
}
966966

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3187,14 +3187,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
31873187
}
31883188

31893189
InstructionCost
3190-
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
3191-
bool Negated,
3190+
getMulAccReductionCost(bool IsUnsigned, bool IsNegated, Type *ResTy,
3191+
VectorType *Ty,
31923192
TTI::TargetCostKind CostKind) const override {
3193-
if (Negated)
3194-
return InstructionCost::getInvalid(CostKind);
31953193
// Without any native support, this is equivalent to the cost of
31963194
// vecreduce.add(mul(ext(Ty A), ext(Ty B))) or
3197-
// vecreduce.add(mul(A, B)).
3195+
// vecreduce.add(mul(A, B)) with an optional negation of the mul.
31983196
VectorType *ExtTy = VectorType::get(ResTy, Ty);
31993197
InstructionCost RedCost = thisT()->getArithmeticReductionCost(
32003198
Instruction::Add, ExtTy, std::nullopt, CostKind);
@@ -3204,8 +3202,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
32043202

32053203
InstructionCost MulCost =
32063204
thisT()->getArithmeticInstrCost(Instruction::Mul, ExtTy, CostKind);
3205+
InstructionCost SubCost =
3206+
IsNegated
3207+
? thisT()->getArithmeticInstrCost(Instruction::Sub, ExtTy, CostKind)
3208+
: 0;
32073209

3208-
return RedCost + MulCost + 2 * ExtCost;
3210+
return RedCost + SubCost + MulCost + 2 * ExtCost;
32093211
}
32103212

32113213
InstructionCost getVectorSplitCost() const { return 1; }

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,9 +1274,9 @@ InstructionCost TargetTransformInfo::getExtendedReductionCost(
12741274
}
12751275

12761276
InstructionCost TargetTransformInfo::getMulAccReductionCost(
1277-
bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
1277+
bool IsUnsigned, bool IsNegated, Type *ResTy, VectorType *Ty,
12781278
TTI::TargetCostKind CostKind) const {
1279-
return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, Negated,
1279+
return TTIImpl->getMulAccReductionCost(IsUnsigned, IsNegated, ResTy, Ty,
12801280
CostKind);
12811281
}
12821282

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5323,10 +5323,10 @@ InstructionCost AArch64TTIImpl::getExtendedReductionCost(
53235323
}
53245324

53255325
InstructionCost
5326-
AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
5327-
VectorType *VecTy, bool Negated,
5326+
AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, bool IsNegated,
5327+
Type *ResTy, VectorType *VecTy,
53285328
TTI::TargetCostKind CostKind) const {
5329-
if (Negated)
5329+
if (IsNegated)
53305330
return InstructionCost::getInvalid(CostKind);
53315331
EVT VecVT = TLI->getValueType(DL, VecTy);
53325332
EVT ResVT = TLI->getValueType(DL, ResTy);
@@ -5342,7 +5342,7 @@ AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
53425342
return LT.first + 2;
53435343
}
53445344

5345-
return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, Negated,
5345+
return BaseT::getMulAccReductionCost(IsUnsigned, IsNegated, ResTy, VecTy,
53465346
CostKind);
53475347
}
53485348

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
446446
TTI::TargetCostKind CostKind) const override;
447447

448448
InstructionCost getMulAccReductionCost(
449-
bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
449+
bool IsUnsigned, bool IsNegated, Type *ResTy, VectorType *Ty,
450450
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const override;
451451

452452
InstructionCost

llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1882,10 +1882,10 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
18821882
}
18831883

18841884
InstructionCost
1885-
ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
1886-
VectorType *ValTy, bool Negated,
1885+
ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, bool IsNegated, Type *ResTy,
1886+
VectorType *ValTy,
18871887
TTI::TargetCostKind CostKind) const {
1888-
if (Negated)
1888+
if (IsNegated)
18891889
return InstructionCost::getInvalid(CostKind);
18901890
EVT ValVT = TLI->getValueType(DL, ValTy);
18911891
EVT ResVT = TLI->getValueType(DL, ResTy);
@@ -1907,7 +1907,7 @@ ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
19071907
return ST->getMVEVectorCostFactor(CostKind) * LT.first;
19081908
}
19091909

1910-
return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, Negated,
1910+
return BaseT::getMulAccReductionCost(IsUnsigned, IsNegated, ResTy, ValTy,
19111911
CostKind);
19121912
}
19131913

llvm/lib/Target/ARM/ARMTargetTransformInfo.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,8 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
298298
VectorType *ValTy, std::optional<FastMathFlags> FMF,
299299
TTI::TargetCostKind CostKind) const override;
300300
InstructionCost
301-
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *ValTy,
302-
bool Negated,
301+
getMulAccReductionCost(bool IsUnsigned, bool IsNegated, Type *ResTy,
302+
VectorType *ValTy,
303303
TTI::TargetCostKind CostKind) const override;
304304

305305
InstructionCost

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5451,7 +5451,8 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
54515451
TTI::CastContextHint::None, CostKind, RedOp);
54525452

54535453
InstructionCost RedCost = TTI.getMulAccReductionCost(
5454-
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind);
5454+
IsUnsigned, /*IsNegated=*/false, RdxDesc.getRecurrenceType(), ExtType,
5455+
CostKind);
54555456

54565457
if (RedCost.isValid() &&
54575458
RedCost < ExtCost * 2 + MulCost + Ext2Cost + BaseCost)
@@ -5496,7 +5497,8 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
54965497
TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
54975498

54985499
InstructionCost RedCost = TTI.getMulAccReductionCost(
5499-
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind);
5500+
IsUnsigned, /*IsNegated=*/false, RdxDesc.getRecurrenceType(), ExtType,
5501+
CostKind);
55005502
InstructionCost ExtraExtCost = 0;
55015503
if (Op0Ty != LargestOpTy || Op1Ty != LargestOpTy) {
55025504
Instruction *ExtraExtOp = (Op0Ty != LargestOpTy) ? Op0 : Op1;
@@ -5515,7 +5517,8 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
55155517
TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
55165518

55175519
InstructionCost RedCost = TTI.getMulAccReductionCost(
5518-
true, RdxDesc.getRecurrenceType(), VectorTy, false, CostKind);
5520+
true, /*IsNegated=*/false, RdxDesc.getRecurrenceType(), VectorTy,
5521+
CostKind);
55195522

55205523
if (RedCost.isValid() && RedCost < MulCost + BaseCost)
55215524
return I == RetI ? RedCost : 0;

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2799,9 +2799,7 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
27992799
MulAccReduction,
28002800
/// Represent an inloop multiply-accumulate reduction, multiplying the
28012801
/// extended vector operands, negating the multiplication, performing a
2802-
/// reduction.add
2803-
/// on the result, and adding
2804-
/// the scalar result to a chain.
2802+
/// reduction.add on the result, and adding the scalar result to a chain.
28052803
ExtNegatedMulAccReduction,
28062804
};
28072805

0 commit comments

Comments
 (0)