
Commit c24cf97

Author: Anna Welker
[ARM][MVE] Enable extending gathers
Enables the masked gather pass to create extending masked gathers.

Differential Revision: https://reviews.llvm.org/D72451
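The pattern this enables is a masked gather whose result is immediately sign- or zero-extended to a full 128-bit MVE vector. A minimal sketch in LLVM IR of the kind of input that can now be lowered directly (hypothetical function and value names, modelled on the updated tests below):

; An <8 x i8> gather whose only use is a zext to <8 x i16>. Instead of
; scalarising, the pass can now emit a single extending offset gather
; (vldrb.u16) that performs the widening in the load itself.
declare <8 x i8> @llvm.masked.gather.v8i8.v8p0i8(<8 x i8*>, i32, <8 x i1>, <8 x i8>)

define arm_aapcs_vfpcc <8 x i16> @gather_zext(i8* %base, <8 x i32> %offs) {
entry:
  %ptrs = getelementptr inbounds i8, i8* %base, <8 x i32> %offs
  %gather = call <8 x i8> @llvm.masked.gather.v8i8.v8p0i8(<8 x i8*> %ptrs, i32 1, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i8> undef)
  %ext = zext <8 x i8> %gather to <8 x i16>
  ret <8 x i16> %ext
}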
1 parent 2b81d3c commit c24cf97

5 files changed (+182, -622 lines)


llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp

Lines changed: 63 additions & 24 deletions
@@ -84,7 +84,7 @@ class MVEGatherScatterLowering : public FunctionPass {
   bool lowerGather(IntrinsicInst *I);
   // Create a gather from a base + vector of offsets
   Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
-                                     IRBuilder<> Builder);
+                                     Instruction *&Root, IRBuilder<> Builder);
   // Create a gather from a vector of pointers
   Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
                                    IRBuilder<> Builder);
@@ -104,9 +104,9 @@ Pass *llvm::createMVEGatherScatterLoweringPass() {
 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
                                                        unsigned ElemSize,
                                                        unsigned Alignment) {
-  // Do only allow non-extending gathers for now
-  if (((NumElements == 4 && ElemSize == 32) ||
-       (NumElements == 8 && ElemSize == 16) ||
+  if (((NumElements == 4 &&
+        (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
+       (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
        (NumElements == 16 && ElemSize == 8)) &&
       ElemSize / 8 <= Alignment)
     return true;
@@ -126,9 +126,6 @@ Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, Type *Ty, Value *Ptr,
                     << " from base + vector of offsets\n");
   Value *GEPPtr = GEP->getPointerOperand();
   if (GEPPtr->getType()->isVectorTy()) {
-    LLVM_DEBUG(dbgs() << "masked gathers: gather from a vector of pointers"
-                      << " hidden behind a getelementptr currently not"
-                      << " supported. Expanding.\n");
     return nullptr;
   }
   if (GEP->getNumOperands() != 2) {
@@ -194,7 +191,10 @@ bool MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
   IRBuilder<> Builder(I->getContext());
   Builder.SetInsertPoint(I);
   Builder.SetCurrentDebugLocation(I->getDebugLoc());
-  Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Builder);
+
+  Instruction *Root = I;
+
+  Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
   if (!Load)
     Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
   if (!Load)
@@ -206,18 +206,24 @@ bool MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
     Load = Builder.CreateSelect(Mask, Load, PassThru);
   }
 
+  Root->replaceAllUsesWith(Load);
+  Root->eraseFromParent();
+  if (Root != I)
+    // If this was an extending gather, we need to get rid of the sext/zext
+    // as well as of the gather itself
+    I->eraseFromParent();
   LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n");
-  I->replaceAllUsesWith(Load);
-  I->eraseFromParent();
   return true;
 }
 
 Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
     IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder) {
   using namespace PatternMatch;
-  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
+
   Type *Ty = I->getType();
-  if (Ty->getVectorNumElements() != 4)
+
+  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
+  if (Ty->getVectorNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
     // Can't build an intrinsic for this
     return nullptr;
   Value *Mask = I->getArgOperand(2);
@@ -233,23 +239,55 @@ Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
 }
 
 Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
-    IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder) {
+    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> Builder) {
   using namespace PatternMatch;
-  Type *Ty = I->getType();
+
+  Type *OriginalTy = I->getType();
+  Type *ResultTy = OriginalTy;
+
+  unsigned Unsigned = 1;
+  // The size of the gather was already checked in isLegalTypeAndAlignment;
+  // if it was not a full vector width an appropriate extend should follow.
+  auto *Extend = Root;
+  if (OriginalTy->getPrimitiveSizeInBits() < 128) {
+    // Only transform gathers with exactly one use
+    if (!I->hasOneUse())
+      return nullptr;
+
+    // The correct root to replace is not the CallInst itself, but the
+    // instruction which extends it
+    Extend = cast<Instruction>(*I->users().begin());
+    if (isa<SExtInst>(Extend)) {
+      Unsigned = 0;
+    } else if (!isa<ZExtInst>(Extend)) {
+      LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. "
+                        << "Expanding\n");
+      return nullptr;
+    }
+    LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n");
+    ResultTy = Extend->getType();
+    // The final size of the gather must be a full vector width
+    if (ResultTy->getPrimitiveSizeInBits() != 128) {
+      LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. "
+                        << "Expanding\n");
+      return nullptr;
+    }
+  }
+
   Value *Offsets;
-  Value *BasePtr = checkGEP(Offsets, Ty, Ptr, Builder);
+  Value *BasePtr = checkGEP(Offsets, ResultTy, Ptr, Builder);
   if (!BasePtr)
     return nullptr;
 
   unsigned Scale;
   int GEPElemSize =
       BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits();
-  int ResultElemSize = Ty->getScalarSizeInBits();
+  int MemoryElemSize = OriginalTy->getScalarSizeInBits();
   // This can be a 32bit load scaled by 4, a 16bit load scaled by 2, or a
   // 8bit, 16bit or 32bit load scaled by 1
-  if (GEPElemSize == 32 && ResultElemSize == 32) {
+  if (GEPElemSize == 32 && MemoryElemSize == 32) {
     Scale = 2;
-  } else if (GEPElemSize == 16 && ResultElemSize == 16) {
+  } else if (GEPElemSize == 16 && MemoryElemSize == 16) {
     Scale = 1;
   } else if (GEPElemSize == 8) {
     Scale = 0;
@@ -258,20 +296,21 @@ Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
                       << " create masked gather\n");
     return nullptr;
   }
+  Root = Extend;
 
   Value *Mask = I->getArgOperand(2);
   if (!match(Mask, m_One()))
     return Builder.CreateIntrinsic(
         Intrinsic::arm_mve_vldr_gather_offset_predicated,
-        {Ty, BasePtr->getType(), Offsets->getType(), Mask->getType()},
-        {BasePtr, Offsets, Builder.getInt32(Ty->getScalarSizeInBits()),
-         Builder.getInt32(Scale), Builder.getInt32(1), Mask});
+        {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
+        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
+         Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
   else
     return Builder.CreateIntrinsic(
         Intrinsic::arm_mve_vldr_gather_offset,
-        {Ty, BasePtr->getType(), Offsets->getType()},
-        {BasePtr, Offsets, Builder.getInt32(Ty->getScalarSizeInBits()),
-         Builder.getInt32(Scale), Builder.getInt32(1)});
+        {ResultTy, BasePtr->getType(), Offsets->getType()},
+        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
+         Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
 }
 
 bool MVEGatherScatterLowering::runOnFunction(Function &F) {
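Note on the rewrite above: for an extending gather, Root is redirected from the masked.gather call to the sext/zext that consumes it, and the emitted MVE intrinsic already produces the extended 128-bit type. A hedged sketch of the IR that CreateIntrinsic builds for an unpredicated sign-extending i8-to-i16 gather, assuming <8 x i16> offsets (the overload suffix follows the {ResultTy, BasePtr, Offsets} type list passed in; the exact mangling is an assumption):

; i32 8 = element size in memory (OriginalTy), i32 0 = Scale,
; i32 0 = Unsigned, i.e. a sign-extending load.
%load = call <8 x i16> @llvm.arm.mve.vldr.gather.offset.v8i16.p0i8.v8i16(i8* %base, <8 x i16> %offsets, i32 8, i32 0, i32 0)

lowerGather then replaces all uses of the sext with %load and erases both the extend and the original gather call.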

llvm/test/CodeGen/Thumb2/mve-gather-ind16-unscaled.ll

Lines changed: 12 additions & 128 deletions
@@ -4,38 +4,9 @@
 define arm_aapcs_vfpcc <8 x i16> @zext_unscaled_i8_i16(i8* %base, <8 x i16>* %offptr) {
 ; CHECK-LABEL: zext_unscaled_i8_i16:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .save {r4, r5, r7, lr}
-; CHECK-NEXT:    push {r4, r5, r7, lr}
-; CHECK-NEXT:    vldrh.u32 q0, [r1]
-; CHECK-NEXT:    vldrh.u32 q1, [r1, #8]
-; CHECK-NEXT:    vadd.i32 q0, q0, r0
-; CHECK-NEXT:    vadd.i32 q1, q1, r0
-; CHECK-NEXT:    vmov r2, s3
-; CHECK-NEXT:    vmov r3, s1
-; CHECK-NEXT:    vmov r5, s0
-; CHECK-NEXT:    vmov r0, s4
-; CHECK-NEXT:    vmov r1, s5
-; CHECK-NEXT:    vmov r4, s7
-; CHECK-NEXT:    ldrb.w r12, [r2]
-; CHECK-NEXT:    vmov r2, s2
-; CHECK-NEXT:    ldrb.w lr, [r3]
-; CHECK-NEXT:    vmov r3, s6
-; CHECK-NEXT:    ldrb r5, [r5]
-; CHECK-NEXT:    ldrb r0, [r0]
-; CHECK-NEXT:    vmov.16 q0[0], r5
-; CHECK-NEXT:    ldrb r1, [r1]
-; CHECK-NEXT:    vmov.16 q0[1], lr
-; CHECK-NEXT:    ldrb r4, [r4]
-; CHECK-NEXT:    ldrb r2, [r2]
-; CHECK-NEXT:    ldrb r3, [r3]
-; CHECK-NEXT:    vmov.16 q0[2], r2
-; CHECK-NEXT:    vmov.16 q0[3], r12
-; CHECK-NEXT:    vmov.16 q0[4], r0
-; CHECK-NEXT:    vmov.16 q0[5], r1
-; CHECK-NEXT:    vmov.16 q0[6], r3
-; CHECK-NEXT:    vmov.16 q0[7], r4
-; CHECK-NEXT:    vmovlb.u8 q0, q0
-; CHECK-NEXT:    pop {r4, r5, r7, pc}
+; CHECK-NEXT:    vldrh.u16 q1, [r1]
+; CHECK-NEXT:    vldrb.u16 q0, [r0, q1]
+; CHECK-NEXT:    bx lr
 entry:
   %offs = load <8 x i16>, <8 x i16>* %offptr, align 2
   %offs.zext = zext <8 x i16> %offs to <8 x i32>
@@ -48,38 +19,9 @@ entry:
 define arm_aapcs_vfpcc <8 x i16> @sext_unscaled_i8_i16(i8* %base, <8 x i16>* %offptr) {
 ; CHECK-LABEL: sext_unscaled_i8_i16:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .save {r4, r5, r7, lr}
-; CHECK-NEXT:    push {r4, r5, r7, lr}
-; CHECK-NEXT:    vldrh.u32 q0, [r1]
-; CHECK-NEXT:    vldrh.u32 q1, [r1, #8]
-; CHECK-NEXT:    vadd.i32 q0, q0, r0
-; CHECK-NEXT:    vadd.i32 q1, q1, r0
-; CHECK-NEXT:    vmov r2, s3
-; CHECK-NEXT:    vmov r3, s1
-; CHECK-NEXT:    vmov r5, s0
-; CHECK-NEXT:    vmov r0, s4
-; CHECK-NEXT:    vmov r1, s5
-; CHECK-NEXT:    vmov r4, s7
-; CHECK-NEXT:    ldrb.w r12, [r2]
-; CHECK-NEXT:    vmov r2, s2
-; CHECK-NEXT:    ldrb.w lr, [r3]
-; CHECK-NEXT:    vmov r3, s6
-; CHECK-NEXT:    ldrb r5, [r5]
-; CHECK-NEXT:    ldrb r0, [r0]
-; CHECK-NEXT:    vmov.16 q0[0], r5
-; CHECK-NEXT:    ldrb r1, [r1]
-; CHECK-NEXT:    vmov.16 q0[1], lr
-; CHECK-NEXT:    ldrb r4, [r4]
-; CHECK-NEXT:    ldrb r2, [r2]
-; CHECK-NEXT:    ldrb r3, [r3]
-; CHECK-NEXT:    vmov.16 q0[2], r2
-; CHECK-NEXT:    vmov.16 q0[3], r12
-; CHECK-NEXT:    vmov.16 q0[4], r0
-; CHECK-NEXT:    vmov.16 q0[5], r1
-; CHECK-NEXT:    vmov.16 q0[6], r3
-; CHECK-NEXT:    vmov.16 q0[7], r4
-; CHECK-NEXT:    vmovlb.s8 q0, q0
-; CHECK-NEXT:    pop {r4, r5, r7, pc}
+; CHECK-NEXT:    vldrh.u16 q1, [r1]
+; CHECK-NEXT:    vldrb.s16 q0, [r0, q1]
+; CHECK-NEXT:    bx lr
 entry:
   %offs = load <8 x i16>, <8 x i16>* %offptr, align 2
   %offs.zext = zext <8 x i16> %offs to <8 x i32>
@@ -122,38 +64,9 @@ entry:
 define arm_aapcs_vfpcc <8 x i16> @zext_unsigned_unscaled_i8_i8(i8* %base, <8 x i8>* %offptr) {
 ; CHECK-LABEL: zext_unsigned_unscaled_i8_i8:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .save {r4, r5, r7, lr}
-; CHECK-NEXT:    push {r4, r5, r7, lr}
-; CHECK-NEXT:    vldrb.u32 q0, [r1]
-; CHECK-NEXT:    vldrb.u32 q1, [r1, #4]
-; CHECK-NEXT:    vadd.i32 q0, q0, r0
-; CHECK-NEXT:    vadd.i32 q1, q1, r0
-; CHECK-NEXT:    vmov r2, s3
-; CHECK-NEXT:    vmov r3, s1
-; CHECK-NEXT:    vmov r5, s0
-; CHECK-NEXT:    vmov r0, s4
-; CHECK-NEXT:    vmov r1, s5
-; CHECK-NEXT:    vmov r4, s7
-; CHECK-NEXT:    ldrb.w r12, [r2]
-; CHECK-NEXT:    vmov r2, s2
-; CHECK-NEXT:    ldrb.w lr, [r3]
-; CHECK-NEXT:    vmov r3, s6
-; CHECK-NEXT:    ldrb r5, [r5]
-; CHECK-NEXT:    ldrb r0, [r0]
-; CHECK-NEXT:    vmov.16 q0[0], r5
-; CHECK-NEXT:    ldrb r1, [r1]
-; CHECK-NEXT:    vmov.16 q0[1], lr
-; CHECK-NEXT:    ldrb r4, [r4]
-; CHECK-NEXT:    ldrb r2, [r2]
-; CHECK-NEXT:    ldrb r3, [r3]
-; CHECK-NEXT:    vmov.16 q0[2], r2
-; CHECK-NEXT:    vmov.16 q0[3], r12
-; CHECK-NEXT:    vmov.16 q0[4], r0
-; CHECK-NEXT:    vmov.16 q0[5], r1
-; CHECK-NEXT:    vmov.16 q0[6], r3
-; CHECK-NEXT:    vmov.16 q0[7], r4
-; CHECK-NEXT:    vmovlb.u8 q0, q0
-; CHECK-NEXT:    pop {r4, r5, r7, pc}
+; CHECK-NEXT:    vldrb.u16 q1, [r1]
+; CHECK-NEXT:    vldrb.u16 q0, [r0, q1]
+; CHECK-NEXT:    bx lr
 entry:
   %offs = load <8 x i8>, <8 x i8>* %offptr, align 1
   %offs.zext = zext <8 x i8> %offs to <8 x i32>
@@ -166,38 +79,9 @@ entry:
 define arm_aapcs_vfpcc <8 x i16> @sext_unsigned_unscaled_i8_i8(i8* %base, <8 x i8>* %offptr) {
 ; CHECK-LABEL: sext_unsigned_unscaled_i8_i8:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .save {r4, r5, r7, lr}
-; CHECK-NEXT:    push {r4, r5, r7, lr}
-; CHECK-NEXT:    vldrb.u32 q0, [r1]
-; CHECK-NEXT:    vldrb.u32 q1, [r1, #4]
-; CHECK-NEXT:    vadd.i32 q0, q0, r0
-; CHECK-NEXT:    vadd.i32 q1, q1, r0
-; CHECK-NEXT:    vmov r2, s3
-; CHECK-NEXT:    vmov r3, s1
-; CHECK-NEXT:    vmov r5, s0
-; CHECK-NEXT:    vmov r0, s4
-; CHECK-NEXT:    vmov r1, s5
-; CHECK-NEXT:    vmov r4, s7
-; CHECK-NEXT:    ldrb.w r12, [r2]
-; CHECK-NEXT:    vmov r2, s2
-; CHECK-NEXT:    ldrb.w lr, [r3]
-; CHECK-NEXT:    vmov r3, s6
-; CHECK-NEXT:    ldrb r5, [r5]
-; CHECK-NEXT:    ldrb r0, [r0]
-; CHECK-NEXT:    vmov.16 q0[0], r5
-; CHECK-NEXT:    ldrb r1, [r1]
-; CHECK-NEXT:    vmov.16 q0[1], lr
-; CHECK-NEXT:    ldrb r4, [r4]
-; CHECK-NEXT:    ldrb r2, [r2]
-; CHECK-NEXT:    ldrb r3, [r3]
-; CHECK-NEXT:    vmov.16 q0[2], r2
-; CHECK-NEXT:    vmov.16 q0[3], r12
-; CHECK-NEXT:    vmov.16 q0[4], r0
-; CHECK-NEXT:    vmov.16 q0[5], r1
-; CHECK-NEXT:    vmov.16 q0[6], r3
-; CHECK-NEXT:    vmov.16 q0[7], r4
-; CHECK-NEXT:    vmovlb.s8 q0, q0
-; CHECK-NEXT:    pop {r4, r5, r7, pc}
+; CHECK-NEXT:    vldrb.u16 q1, [r1]
+; CHECK-NEXT:    vldrb.s16 q0, [r0, q1]
+; CHECK-NEXT:    bx lr
 entry:
   %offs = load <8 x i8>, <8 x i8>* %offptr, align 1
   %offs.zext = zext <8 x i8> %offs to <8 x i32>