
Commit 4cf16ef

[AArch64][SVE] Add patterns for unpredicated load/store to frame-indices.
This patch also fixes up a number of cases in DAGCombine and
SelectionDAGBuilder where the size of a scalable vector is used in a
fixed-width context (thus triggering an assertion failure).

Reviewers: efriedma, c-rhodes, rovka, cameron.mcinally

Reviewed By: efriedma

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D71215
1 parent e0f0d0e commit 4cf16ef
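
For context, a minimal sketch of the kind of source that exercises the new patterns. The function name is made up, and it assumes an SVE-enabled toolchain (e.g. -march=armv8-a+sve) that accepts a volatile-qualified SVE local; it illustrates an unpredicated load/store through a frame index and is not taken from this commit or its tests:

    // Illustrative only: a volatile SVE local forces an unpredicated store to,
    // and load from, a stack slot (a frame index at the SelectionDAG level).
    #include <arm_sve.h>

    svint8_t round_trip_vector(svint8_t v) {
      volatile svint8_t slot = v;  // store <vscale x 16 x i8> to the frame index
      return slot;                 // reload it from the same frame index
    }

Per the patterns added in AArch64SVEInstrInfo.td below, the intent is that such an access is selected to ST1B/LD1B with an all-true predicate against the frame-index operand.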

12 files changed (+341, -7 lines)

llvm/include/llvm/Analysis/MemoryLocation.h

Lines changed: 7 additions & 0 deletions

@@ -19,6 +19,7 @@
 #include "llvm/ADT/Optional.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Metadata.h"
+#include "llvm/Support/TypeSize.h"

 namespace llvm {

@@ -240,6 +241,12 @@ class MemoryLocation {
     return getForArgument(Call, ArgIdx, &TLI);
   }

+  // Return the exact size if the exact size is known at compiletime,
+  // otherwise return MemoryLocation::UnknownSize.
+  static uint64_t getSizeOrUnknown(const TypeSize &T) {
+    return T.isScalable() ? UnknownSize : T.getFixedSize();
+  }
+
   explicit MemoryLocation(const Value *Ptr = nullptr,
                           LocationSize Size = LocationSize::unknown(),
                           const AAMDNodes &AATags = AAMDNodes())
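
To make the intended contract concrete, here is a small standalone sketch (plain C++, not LLVM code). SimpleTypeSize and the UnknownSize constant are stand-ins for llvm::TypeSize and MemoryLocation::UnknownSize (the all-ones uint64_t sentinel):

    // Standalone model of getSizeOrUnknown: scalable sizes collapse to an
    // "unknown" sentinel, fixed sizes pass through unchanged.
    #include <cassert>
    #include <cstdint>

    struct SimpleTypeSize {   // stand-in for llvm::TypeSize
      uint64_t MinSize;       // size in bytes when vscale == 1
      bool Scalable;          // true for <vscale x N x T> types
    };

    constexpr uint64_t UnknownSize = ~UINT64_C(0);

    uint64_t getSizeOrUnknown(const SimpleTypeSize &T) {
      return T.Scalable ? UnknownSize : T.MinSize;
    }

    int main() {
      assert(getSizeOrUnknown({16, false}) == 16);          // e.g. <4 x i32>
      assert(getSizeOrUnknown({16, true}) == UnknownSize);  // e.g. <vscale x 4 x i32>
      return 0;
    }

As used in the SelectionDAG and DAGCombiner changes below, the effect is that analyses which only understand fixed sizes treat scalable accesses conservatively instead of asserting.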

llvm/lib/Analysis/Loads.cpp

Lines changed: 3 additions & 1 deletion

@@ -140,7 +140,9 @@ bool llvm::isDereferenceableAndAlignedPointer(const Value *V, Type *Ty,
                                               const DataLayout &DL,
                                               const Instruction *CtxI,
                                               const DominatorTree *DT) {
-  if (!Ty->isSized())
+  // For unsized types or scalable vectors we don't know exactly how many bytes
+  // are dereferenced, so bail out.
+  if (!Ty->isSized() || (Ty->isVectorTy() && Ty->getVectorIsScalable()))
     return false;

   // When dereferenceability information is provided by a dereferenceable

llvm/lib/CodeGen/CodeGenPrepare.cpp

Lines changed: 8 additions & 0 deletions

@@ -6809,6 +6809,14 @@ static bool splitMergedValStore(StoreInst &SI, const DataLayout &DL,
                                 const TargetLowering &TLI) {
   // Handle simple but common cases only.
   Type *StoreType = SI.getValueOperand()->getType();
+
+  // The code below assumes shifting a value by <number of bits>,
+  // whereas scalable vectors would have to be shifted by
+  // <2log(vscale) + number of bits> in order to store the
+  // low/high parts. Bailing out for now.
+  if (StoreType->isVectorTy() && StoreType->getVectorIsScalable())
+    return false;
+
   if (!DL.typeSizeEqualsStoreSize(StoreType) ||
       DL.getTypeSizeInBits(StoreType) == 0)
     return false;

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 16 additions & 1 deletion

@@ -15738,7 +15738,14 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) {
   if (OptLevel == CodeGenOpt::None || !EnableStoreMerging)
     return false;

+  // TODO: Extend this function to merge stores of scalable vectors.
+  // (i.e. two <vscale x 8 x i8> stores can be merged to one <vscale x 16 x i8>
+  // store since we know <vscale x 16 x i8> is exactly twice as large as
+  // <vscale x 8 x i8>). Until then, bail out for scalable vectors.
   EVT MemVT = St->getMemoryVT();
+  if (MemVT.isScalableVector())
+    return false;
+
   int64_t ElementSizeBytes = MemVT.getStoreSize();
   unsigned NumMemElts = MemVT.isVector() ? MemVT.getVectorNumElements() : 1;

@@ -20842,9 +20849,11 @@ bool DAGCombiner::isAlias(SDNode *Op0, SDNode *Op1) const {
                          : (LSN->getAddressingMode() == ISD::PRE_DEC)
                                ? -1 * C->getSExtValue()
                                : 0;
+    uint64_t Size =
+        MemoryLocation::getSizeOrUnknown(LSN->getMemoryVT().getStoreSize());
     return {LSN->isVolatile(), LSN->isAtomic(), LSN->getBasePtr(),
             Offset /*base offset*/,
-            Optional<int64_t>(LSN->getMemoryVT().getStoreSize()),
+            Optional<int64_t>(Size),
             LSN->getMemOperand()};
   }
   if (const auto *LN = cast<LifetimeSDNode>(N))

@@ -21124,6 +21133,12 @@ bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
   if (BasePtr.getBase().isUndef())
     return false;

+  // BaseIndexOffset assumes that offsets are fixed-size, which
+  // is not valid for scalable vectors where the offsets are
+  // scaled by `vscale`, so bail out early.
+  if (St->getMemoryVT().isScalableVector())
+    return false;
+
   // Add ST's interval.
   Intervals.insert(0, (St->getMemoryVT().getSizeInBits() + 7) / 8, Unit);

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 6 additions & 3 deletions

@@ -6824,9 +6824,10 @@ SDValue SelectionDAG::getLoad(ISD::MemIndexedMode AM, ISD::LoadExtType ExtType,
   if (PtrInfo.V.isNull())
     PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr, Offset);

+  uint64_t Size = MemoryLocation::getSizeOrUnknown(MemVT.getStoreSize());
   MachineFunction &MF = getMachineFunction();
   MachineMemOperand *MMO = MF.getMachineMemOperand(
-      PtrInfo, MMOFlags, MemVT.getStoreSize(), Alignment, AAInfo, Ranges);
+      PtrInfo, MMOFlags, Size, Alignment, AAInfo, Ranges);
   return getLoad(AM, ExtType, VT, dl, Chain, Ptr, Offset, MemVT, MMO);
 }

@@ -6946,8 +6947,10 @@ SDValue SelectionDAG::getStore(SDValue Chain, const SDLoc &dl, SDValue Val,
     PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr);

   MachineFunction &MF = getMachineFunction();
-  MachineMemOperand *MMO = MF.getMachineMemOperand(
-      PtrInfo, MMOFlags, Val.getValueType().getStoreSize(), Alignment, AAInfo);
+  uint64_t Size =
+      MemoryLocation::getSizeOrUnknown(Val.getValueType().getStoreSize());
+  MachineMemOperand *MMO =
+      MF.getMachineMemOperand(PtrInfo, MMOFlags, Size, Alignment, AAInfo);
   return getStore(Chain, dl, Val, Ptr, MMO);
 }

llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp

Lines changed: 19 additions & 0 deletions

@@ -220,6 +220,8 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
   void SelectLoadLane(SDNode *N, unsigned NumVecs, unsigned Opc);
   void SelectPostLoadLane(SDNode *N, unsigned NumVecs, unsigned Opc);

+  bool SelectAddrModeFrameIndexSVE(SDValue N, SDValue &Base, SDValue &OffImm);
+
   void SelectStore(SDNode *N, unsigned NumVecs, unsigned Opc);
   void SelectPostStore(SDNode *N, unsigned NumVecs, unsigned Opc);
   void SelectStoreLane(SDNode *N, unsigned NumVecs, unsigned Opc);

@@ -1374,6 +1376,23 @@ void AArch64DAGToDAGISel::SelectStore(SDNode *N, unsigned NumVecs,
   ReplaceNode(N, St);
 }

+bool AArch64DAGToDAGISel::SelectAddrModeFrameIndexSVE(SDValue N, SDValue &Base,
+                                                      SDValue &OffImm) {
+  SDLoc dl(N);
+  const DataLayout &DL = CurDAG->getDataLayout();
+  const TargetLowering *TLI = getTargetLowering();
+
+  // Try to match it for the frame address
+  if (auto FINode = dyn_cast<FrameIndexSDNode>(N)) {
+    int FI = FINode->getIndex();
+    Base = CurDAG->getTargetFrameIndex(FI, TLI->getPointerTy(DL));
+    OffImm = CurDAG->getTargetConstant(0, dl, MVT::i64);
+    return true;
+  }
+
+  return false;
+}
+
 void AArch64DAGToDAGISel::SelectPostStore(SDNode *N, unsigned NumVecs,
                                           unsigned Opc) {
   SDLoc dl(N);

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 4 additions & 0 deletions

@@ -9458,6 +9458,10 @@ bool AArch64TargetLowering::isLegalAddressingMode(const DataLayout &DL,
   if (AM.HasBaseReg && AM.BaseOffs && AM.Scale)
     return false;

+  // FIXME: Update this method to support scalable addressing modes.
+  if (Ty->isVectorTy() && Ty->getVectorIsScalable())
+    return AM.HasBaseReg && !AM.BaseOffs && !AM.Scale;
+
   // check reg + imm case:
   // i.e., reg + 0, reg + imm9, reg + SIZE_IN_BYTES * uimm12
   uint64_t NumBytes = 0;

llvm/lib/Target/AArch64/AArch64InstrFormats.td

Lines changed: 2 additions & 0 deletions

@@ -349,6 +349,8 @@ def simm7s16 : Operand<i32> {
   let PrintMethod = "printImmScale<16>";
 }

+def am_sve_fi : ComplexPattern<i64, 2, "SelectAddrModeFrameIndexSVE", []>;
+
 def am_indexed7s8 : ComplexPattern<i64, 2, "SelectAddrModeIndexed7S8", []>;
 def am_indexed7s16 : ComplexPattern<i64, 2, "SelectAddrModeIndexed7S16", []>;
 def am_indexed7s32 : ComplexPattern<i64, 2, "SelectAddrModeIndexed7S32", []>;

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 40 additions & 2 deletions

@@ -1691,6 +1691,8 @@ unsigned AArch64InstrInfo::isStoreToStackSlot(const MachineInstr &MI,
   case AArch64::STRSui:
   case AArch64::STRDui:
   case AArch64::STRQui:
+  case AArch64::LDR_PXI:
+  case AArch64::STR_PXI:
     if (MI.getOperand(0).getSubReg() == 0 && MI.getOperand(1).isFI() &&
         MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0) {
       FrameIndex = MI.getOperand(1).getIndex();

@@ -1803,9 +1805,19 @@ unsigned AArch64InstrInfo::getLoadStoreImmIdx(unsigned Opc) {
   case AArch64::STNPSi:
   case AArch64::LDG:
   case AArch64::STGPi:
+  case AArch64::LD1B_IMM:
+  case AArch64::LD1H_IMM:
+  case AArch64::LD1W_IMM:
+  case AArch64::LD1D_IMM:
+  case AArch64::ST1B_IMM:
+  case AArch64::ST1H_IMM:
+  case AArch64::ST1W_IMM:
+  case AArch64::ST1D_IMM:
     return 3;
   case AArch64::ADDG:
   case AArch64::STGOffset:
+  case AArch64::LDR_PXI:
+  case AArch64::STR_PXI:
     return 2;
   }
 }

@@ -2056,6 +2068,7 @@ AArch64InstrInfo::getMemOpBaseRegImmOfsOffsetOperand(MachineInstr &LdSt) const {
 bool AArch64InstrInfo::getMemOpInfo(unsigned Opcode, unsigned &Scale,
                                     unsigned &Width, int64_t &MinOffset,
                                     int64_t &MaxOffset) {
+  const unsigned SVEMaxBytesPerVector = AArch64::SVEMaxBitsPerVector / 8;
   switch (Opcode) {
   // Not a memory operation or something we want to handle.
   default:

@@ -2220,16 +2233,33 @@ bool AArch64InstrInfo::getMemOpInfo(unsigned Opcode, unsigned &Scale,
     break;
   case AArch64::LDR_PXI:
   case AArch64::STR_PXI:
-    Scale = Width = 2;
+    Scale = 2;
+    Width = SVEMaxBytesPerVector / 8;
     MinOffset = -256;
     MaxOffset = 255;
     break;
   case AArch64::LDR_ZXI:
   case AArch64::STR_ZXI:
-    Scale = Width = 16;
+    Scale = 16;
+    Width = SVEMaxBytesPerVector;
     MinOffset = -256;
     MaxOffset = 255;
     break;
+  case AArch64::LD1B_IMM:
+  case AArch64::LD1H_IMM:
+  case AArch64::LD1W_IMM:
+  case AArch64::LD1D_IMM:
+  case AArch64::ST1B_IMM:
+  case AArch64::ST1H_IMM:
+  case AArch64::ST1W_IMM:
+  case AArch64::ST1D_IMM:
+    // A full vectors worth of data
+    // Width = mbytes * elements
+    Scale = 16;
+    Width = SVEMaxBytesPerVector;
+    MinOffset = -8;
+    MaxOffset = 7;
+    break;
   case AArch64::ST2GOffset:
   case AArch64::STZ2GOffset:
     Scale = 16;

@@ -3433,6 +3463,14 @@ static bool isSVEScaledImmInstruction(unsigned Opcode) {
   case AArch64::STR_ZXI:
   case AArch64::LDR_PXI:
   case AArch64::STR_PXI:
+  case AArch64::LD1B_IMM:
+  case AArch64::LD1H_IMM:
+  case AArch64::LD1W_IMM:
+  case AArch64::LD1D_IMM:
+  case AArch64::ST1B_IMM:
+  case AArch64::ST1H_IMM:
+  case AArch64::ST1W_IMM:
+  case AArch64::ST1D_IMM:
     return true;
   default:
     return false;

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 46 additions & 0 deletions

@@ -1261,6 +1261,52 @@ let Predicates = [HasSVE] in {
   defm : pred_store<nxv4i32, nxv4i1, non_temporal_store, STNT1W_ZRI>;
   defm : pred_store<nxv2i64, nxv2i1, non_temporal_store, STNT1D_ZRI>;

+  multiclass unpred_store<ValueType Ty, Instruction RegImmInst, Instruction PTrue> {
+    def _fi : Pat<(store (Ty ZPR:$val), (am_sve_fi GPR64sp:$base, simm4s1:$offset)),
+                  (RegImmInst ZPR:$val, (PTrue 31), GPR64sp:$base, simm4s1:$offset)>;
+  }
+
+  defm Pat_ST1B          : unpred_store<nxv16i8, ST1B_IMM, PTRUE_B>;
+  defm Pat_ST1H          : unpred_store<nxv8i16, ST1H_IMM, PTRUE_H>;
+  defm Pat_ST1W          : unpred_store<nxv4i32, ST1W_IMM, PTRUE_S>;
+  defm Pat_ST1D          : unpred_store<nxv2i64, ST1D_IMM, PTRUE_D>;
+  defm Pat_ST1H_float16  : unpred_store<nxv8f16, ST1H_IMM, PTRUE_H>;
+  defm Pat_ST1W_float    : unpred_store<nxv4f32, ST1W_IMM, PTRUE_S>;
+  defm Pat_ST1D_double   : unpred_store<nxv2f64, ST1D_IMM, PTRUE_D>;
+
+  multiclass unpred_load<ValueType Ty, Instruction RegImmInst, Instruction PTrue> {
+    def _fi : Pat<(Ty (load (am_sve_fi GPR64sp:$base, simm4s1:$offset))),
+                  (RegImmInst (PTrue 31), GPR64sp:$base, simm4s1:$offset)>;
+  }
+
+  defm Pat_LD1B          : unpred_load<nxv16i8, LD1B_IMM, PTRUE_B>;
+  defm Pat_LD1H          : unpred_load<nxv8i16, LD1H_IMM, PTRUE_H>;
+  defm Pat_LD1W          : unpred_load<nxv4i32, LD1W_IMM, PTRUE_S>;
+  defm Pat_LD1D          : unpred_load<nxv2i64, LD1D_IMM, PTRUE_D>;
+  defm Pat_LD1H_float16  : unpred_load<nxv8f16, LD1H_IMM, PTRUE_H>;
+  defm Pat_LD1W_float    : unpred_load<nxv4f32, LD1W_IMM, PTRUE_S>;
+  defm Pat_LD1D_double   : unpred_load<nxv2f64, LD1D_IMM, PTRUE_D>;
+
+  multiclass unpred_store_predicate<ValueType Ty, Instruction Store> {
+    def _fi : Pat<(store (Ty PPR:$val), (am_sve_fi GPR64sp:$base, simm9:$offset)),
+                  (Store PPR:$val, GPR64sp:$base, simm9:$offset)>;
+  }
+
+  defm Pat_Store_P16 : unpred_store_predicate<nxv16i1, STR_PXI>;
+  defm Pat_Store_P8  : unpred_store_predicate<nxv8i1, STR_PXI>;
+  defm Pat_Store_P4  : unpred_store_predicate<nxv4i1, STR_PXI>;
+  defm Pat_Store_P2  : unpred_store_predicate<nxv2i1, STR_PXI>;
+
+  multiclass unpred_load_predicate<ValueType Ty, Instruction Load> {
+    def _fi : Pat<(Ty (load (am_sve_fi GPR64sp:$base, simm9:$offset))),
+                  (Load GPR64sp:$base, simm9:$offset)>;
+  }
+
+  defm Pat_Load_P16 : unpred_load_predicate<nxv16i1, LDR_PXI>;
+  defm Pat_Load_P8  : unpred_load_predicate<nxv8i1, LDR_PXI>;
+  defm Pat_Load_P4  : unpred_load_predicate<nxv4i1, LDR_PXI>;
+  defm Pat_Load_P2  : unpred_load_predicate<nxv2i1, LDR_PXI>;
+
   multiclass ldnf1<Instruction I, ValueType Ty, SDPatternOperator Load, ValueType PredTy, ValueType MemVT> {
     // base
     def : Pat<(Ty (Load (PredTy PPR:$gp), GPR64:$base, MemVT)),
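
As a companion to the earlier vector example, a hedged sketch of source that should exercise the predicate patterns above (again assuming an SVE-enabled toolchain that accepts a volatile-qualified SVE local; the function name is made up). The expectation that this selects STR_PXI/LDR_PXI against the frame index is read off the unpred_store_predicate/unpred_load_predicate patterns, not from verified output:

    #include <arm_sve.h>

    svbool_t round_trip_predicate(svbool_t p) {
      volatile svbool_t slot = p;  // store <vscale x 16 x i1> to a stack slot
      return slot;                 // reload the predicate from the same frame index
    }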

llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h

Lines changed: 1 addition & 0 deletions

@@ -658,6 +658,7 @@ namespace AArch64 {
 // in index i*P of a <n x (M*P) x t> vector. The other elements of the
 // <n x (M*P) x t> vector (such as index 1) are undefined.
 static constexpr unsigned SVEBitsPerBlock = 128;
+static constexpr unsigned SVEMaxBitsPerVector = 2048;
 const unsigned NeonBitsPerVector = 128;
 } // end namespace AArch64
 } // end namespace llvm
