Commit e9c1982
[ARM] Basic gather scatter cost model
This is a very basic MVE gather/scatter cost model, based roughly on the code that we will currently produce. It does not handle truncating scatters or extending gathers correctly yet, as it is difficult to tell that they are going to be correctly extended/truncated from the limited information in the cost function. This can be improved as we extend support for these in the future.

Based on code originally written by David Sherwood.

Differential Revision: https://reviews.llvm.org/D73021
Parent: 0b83e14
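As a rough worked example of the heuristic (a standalone sketch, not code from this commit; gatherCost, LegalizationFactor and PerLaneOverhead are hypothetical stand-ins for the real LT.first and BaseT::getScalarizationOverhead values):

// Hypothetical model of the costing: one load per lane either way, plus
// per-lane extract/insert overhead when the gather must be scalarised.
#include <cstdio>

unsigned gatherCost(unsigned NumElems, unsigned EltSizeBits, bool VectorFormOk,
                    unsigned LegalizationFactor, unsigned PerLaneOverhead) {
  unsigned VectorCost = NumElems * LegalizationFactor;
  unsigned ScalarCost =
      NumElems * LegalizationFactor + NumElems * PerLaneOverhead;
  // Only a full 128-bit vector with at least 4 lanes can use the MVE gather.
  if (EltSizeBits * NumElems != 128 || NumElems < 4 || !VectorFormOk)
    return ScalarCost;
  return VectorCost;
}

int main() {
  // <4 x i32>, aligned: qualifies for the vector form, cost 4 * 1 = 4.
  std::printf("v4i32 vector:     %u\n", gatherCost(4, 32, true, 1, 3));
  // The same type when it fails a check (e.g. underaligned): 4 + 12 = 16.
  std::printf("v4i32 scalarised: %u\n", gatherCost(4, 32, false, 1, 3));
}

That 4-versus-16 gap per gather is why, as the message says, loops still look profitable to vectorise even under this deliberately conservative serialised model.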

3 files changed: 189 additions, 112 deletions

llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp

Lines changed: 61 additions & 0 deletions

@@ -860,6 +860,67 @@ int ARMTTIImpl::getInterleavedMemoryOpCost(
                                            UseMaskForCond, UseMaskForGaps);
 }
 
+unsigned ARMTTIImpl::getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
+                                            Value *Ptr, bool VariableMask,
+                                            unsigned Alignment) {
+  if (!ST->hasMVEIntegerOps() || !EnableMaskedGatherScatters)
+    return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
+                                         Alignment);
+
+  assert(DataTy->isVectorTy() && "Can't do gather/scatters on scalar!");
+  VectorType *VTy = cast<VectorType>(DataTy);
+
+  // TODO: Splitting, once we do that.
+  // TODO: trunc/sext/zext the result/input
+
+  unsigned NumElems = VTy->getNumElements();
+  unsigned EltSize = VTy->getScalarSizeInBits();
+  std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, DataTy);
+
+  // For now, it is assumed that for the MVE gather instructions the loads are
+  // all effectively serialised. This means the cost is the scalar cost
+  // multiplied by the number of elements being loaded. This is possibly very
+  // conservative, but even so we still end up vectorising loops because the
+  // cost per iteration for many loops is lower than for scalar loops.
+  unsigned VectorCost = NumElems * LT.first;
+  // The scalarization cost should be a lot higher. We use the number of vector
+  // elements plus the scalarization overhead.
+  unsigned ScalarCost =
+      NumElems * LT.first + BaseT::getScalarizationOverhead(DataTy, {});
+
+  // TODO: Cost extended gathers or trunc stores correctly.
+  if (EltSize * NumElems != 128 || NumElems < 4)
+    return ScalarCost;
+  if (Alignment < EltSize / 8)
+    return ScalarCost;
+
+  // Any (aligned) i32 gather will not need to be scalarised.
+  if (EltSize == 32)
+    return VectorCost;
+  // For smaller types, we need to ensure that the gep's inputs are correctly
+  // extended from a small enough value. Other sizes (including i64) are
+  // scalarized for now.
+  if (EltSize != 8 && EltSize != 16)
+    return ScalarCost;
+
+  if (auto BC = dyn_cast<BitCastInst>(Ptr))
+    Ptr = BC->getOperand(0);
+  if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
+    if (GEP->getNumOperands() != 2)
+      return ScalarCost;
+    unsigned Scale = DL.getTypeAllocSize(GEP->getResultElementType());
+    // Scale needs to be correct (which is only relevant for i16s).
+    if (Scale != 1 && Scale * 8 != EltSize)
+      return ScalarCost;
+    // And we need to zext (not sext) the indexes from a small enough type.
+    if (auto ZExt = dyn_cast<ZExtInst>(GEP->getOperand(1)))
+      if (ZExt->getOperand(0)->getType()->getScalarSizeInBits() <= EltSize)
+        return VectorCost;
+    return ScalarCost;
+  }
+  return ScalarCost;
+}
+
 bool ARMTTIImpl::isLoweredToCall(const Function *F) {
   if (!F->isIntrinsic())
     return BaseT::isLoweredToCall(F);
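The i8/i16 fast path at the bottom of the new function wants a getelementptr with exactly one vector index, a scale matching the element size, and indexes zero-extended from a type no wider than the element. For reference (a sketch, not from the commit; it assumes MVE and the EnableMaskedGatherScatters option are enabled, and the names are illustrative), a source loop of roughly this shape produces that pattern once vectorised:

// Indexing a 16-bit array with unsigned 16-bit indexes: the vectorised GEP
// gets a 2-byte scale (Scale * 8 == EltSize) and indexes zero-extended from
// i16, so the cost model above returns VectorCost rather than ScalarCost.
void gather16(short *dst, const short *src, const unsigned short *idx, int n) {
  for (int i = 0; i < n; i++)
    dst[i] = src[idx[i]];
}

If idx were signed, the index extension would be a sext rather than the zext the check looks for, and the gather would keep the scalarised cost.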

llvm/lib/Target/ARM/ARMTargetTransformInfo.h

Lines changed: 3 additions & 0 deletions

@@ -202,6 +202,9 @@ class ARMTTIImpl : public BasicTTIImplBase<ARMTTIImpl> {
                                  bool UseMaskForCond = false,
                                  bool UseMaskForGaps = false);
 
+  unsigned getGatherScatterOpCost(unsigned Opcode, Type *DataTy, Value *Ptr,
+                                  bool VariableMask, unsigned Alignment);
+
   bool isLoweredToCall(const Function *F);
   bool isHardwareLoopProfitable(Loop *L, ScalarEvolution &SE,
                                 AssumptionCache &AC,
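One note on the declaration (context, not part of the diff): BasicTTIImplBase is a CRTP base, so this is a non-virtual shadow of the base-class hook, and the BaseT:: calls in the .cpp file statically delegate to the generic costing. A minimal sketch of that dispatch pattern, with made-up names:

#include <cstdio>

// Illustrative CRTP shape only; not the real LLVM class hierarchy.
template <typename DerivedT> struct GenericTTI {
  unsigned getGatherScatterOpCost(unsigned NumElems) {
    return NumElems * 4; // pessimistic generic fallback
  }
};

struct MveTTI : GenericTTI<MveTTI> {
  using BaseT = GenericTTI<MveTTI>;
  unsigned getGatherScatterOpCost(unsigned NumElems) {
    if (NumElems == 4)
      return NumElems;                              // cheap vector form
    return BaseT::getGatherScatterOpCost(NumElems); // delegate to the base
  }
};

int main() {
  MveTTI TTI;
  // Prints "4 8": the 4-lane case takes the cheap path, the 2-lane case
  // falls back to the generic cost.
  std::printf("%u %u\n", TTI.getGatherScatterOpCost(4),
              TTI.getGatherScatterOpCost(2));
}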
