Skip to content

Commit ec10cb2

Browse files
committed
[InstCombine] Allow values with multiple users in SimplifyDemandedVectorElts
Summary: Allow the single-use check in SimplifyDemandedVectorElts to be skipped so that operands can be simplified when DemandedElts is known to contain the union of the elements used by all users. It is the responsibility of the caller of SimplifyDemandedVectorElts to supply a correct DemandedElts. Simplify a series of extractelement instructions if only a subset of elements is used.

Reviewers: reames, arsenm, majnemer, nhaehnle
Reviewed By: nhaehnle
Subscribers: wdng, jvesely, nhaehnle, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D67345

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@375395 91177308-0d34-0410-b5e6-96231b3b80d8
1 parent 919bed3 commit ec10cb2

File tree

4 files changed

+186
-31
lines changed

4 files changed

+186
-31
lines changed

lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner
858858
int DmaskIdx = -1);
859859

860860
Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
861-
APInt &UndefElts, unsigned Depth = 0);
861+
APInt &UndefElts, unsigned Depth = 0,
862+
bool AllowMultipleUsers = false);
862863

863864
/// Canonicalize the position of binops relative to shufflevector.
864865
Instruction *foldVectorBinop(BinaryOperator &Inst);

lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,16 +1074,22 @@ Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II,
10741074
}
10751075

10761076
/// The specified value produces a vector with any number of elements.
1077+
/// This method analyzes which elements of the operand are undef and returns
1078+
/// that information in UndefElts.
1079+
///
10771080
/// DemandedElts contains the set of elements that are actually used by the
1078-
/// caller. This method analyzes which elements of the operand are undef and
1079-
/// returns that information in UndefElts.
1081+
/// caller, and by default (AllowMultipleUsers equals false) the value is
1082+
/// simplified only if it has a single user. If AllowMultipleUsers is set
1083+
/// to true, DemandedElts refers to the union of sets of elements that are
1084+
/// used by all users.
10801085
///
10811086
/// If the information about demanded elements can be used to simplify the
10821087
/// operation, the operation is simplified, then the resultant value is
10831088
/// returned. This returns null if no change was made.
10841089
Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
10851090
APInt &UndefElts,
1086-
unsigned Depth) {
1091+
unsigned Depth,
1092+
bool AllowMultipleUsers) {
10871093
unsigned VWidth = V->getType()->getVectorNumElements();
10881094
APInt EltMask(APInt::getAllOnesValue(VWidth));
10891095
assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!");
@@ -1137,19 +1143,21 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
11371143
if (Depth == 10)
11381144
return nullptr;
11391145

1140-
// If multiple users are using the root value, proceed with
1141-
// simplification conservatively assuming that all elements
1142-
// are needed.
1143-
if (!V->hasOneUse()) {
1144-
// Quit if we find multiple users of a non-root value though.
1145-
// They'll be handled when it's their turn to be visited by
1146-
// the main instcombine process.
1147-
if (Depth != 0)
1148-
// TODO: Just compute the UndefElts information recursively.
1149-
return nullptr;
1146+
if (!AllowMultipleUsers) {
1147+
// If multiple users are using the root value, proceed with
1148+
// simplification conservatively assuming that all elements
1149+
// are needed.
1150+
if (!V->hasOneUse()) {
1151+
// Quit if we find multiple users of a non-root value though.
1152+
// They'll be handled when it's their turn to be visited by
1153+
// the main instcombine process.
1154+
if (Depth != 0)
1155+
// TODO: Just compute the UndefElts information recursively.
1156+
return nullptr;
11501157

1151-
// Conservatively assume that all elements are needed.
1152-
DemandedElts = EltMask;
1158+
// Conservatively assume that all elements are needed.
1159+
DemandedElts = EltMask;
1160+
}
11531161
}
11541162

11551163
Instruction *I = dyn_cast<Instruction>(V);

lib/Transforms/InstCombine/InstCombineVectorOps.cpp

Lines changed: 90 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,69 @@ static Instruction *foldBitcastExtElt(ExtractElementInst &Ext,
253253
return nullptr;
254254
}
255255

256+
/// Find elements of V demanded by UserInstr.
257+
static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) {
258+
unsigned VWidth = V->getType()->getVectorNumElements();
259+
260+
// Conservatively assume that all elements are needed.
261+
APInt UsedElts(APInt::getAllOnesValue(VWidth));
262+
263+
switch (UserInstr->getOpcode()) {
264+
case Instruction::ExtractElement: {
265+
ExtractElementInst *EEI = cast<ExtractElementInst>(UserInstr);
266+
assert(EEI->getVectorOperand() == V);
267+
ConstantInt *EEIIndexC = dyn_cast<ConstantInt>(EEI->getIndexOperand());
268+
if (EEIIndexC && EEIIndexC->getValue().ult(VWidth)) {
269+
UsedElts = APInt::getOneBitSet(VWidth, EEIIndexC->getZExtValue());
270+
}
271+
break;
272+
}
273+
case Instruction::ShuffleVector: {
274+
ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(UserInstr);
275+
unsigned MaskNumElts = UserInstr->getType()->getVectorNumElements();
276+
277+
UsedElts = APInt(VWidth, 0);
278+
for (unsigned i = 0; i < MaskNumElts; i++) {
279+
unsigned MaskVal = Shuffle->getMaskValue(i);
280+
if (MaskVal == -1u || MaskVal >= 2 * VWidth)
281+
continue;
282+
if (Shuffle->getOperand(0) == V && (MaskVal < VWidth))
283+
UsedElts.setBit(MaskVal);
284+
if (Shuffle->getOperand(1) == V &&
285+
((MaskVal >= VWidth) && (MaskVal < 2 * VWidth)))
286+
UsedElts.setBit(MaskVal - VWidth);
287+
}
288+
break;
289+
}
290+
default:
291+
break;
292+
}
293+
return UsedElts;
294+
}
295+
296+
/// Find union of elements of V demanded by all its users.
297+
/// If it is known by querying findDemandedEltsBySingleUser that
298+
/// no user demands an element of V, then the corresponding bit
299+
/// remains unset in the returned value.
300+
static APInt findDemandedEltsByAllUsers(Value *V) {
301+
unsigned VWidth = V->getType()->getVectorNumElements();
302+
303+
APInt UnionUsedElts(VWidth, 0);
304+
for (const Use &U : V->uses()) {
305+
if (Instruction *I = dyn_cast<Instruction>(U.getUser())) {
306+
UnionUsedElts |= findDemandedEltsBySingleUser(V, I);
307+
} else {
308+
UnionUsedElts = APInt::getAllOnesValue(VWidth);
309+
break;
310+
}
311+
312+
if (UnionUsedElts.isAllOnesValue())
313+
break;
314+
}
315+
316+
return UnionUsedElts;
317+
}
318+
256319
Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) {
257320
Value *SrcVec = EI.getVectorOperand();
258321
Value *Index = EI.getIndexOperand();
@@ -271,19 +334,35 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) {
271334
return nullptr;
272335

273336
// This instruction only demands the single element from the input vector.
274-
// If the input vector has a single use, simplify it based on this use
275-
// property.
276-
if (SrcVec->hasOneUse() && NumElts != 1) {
277-
APInt UndefElts(NumElts, 0);
278-
APInt DemandedElts(NumElts, 0);
279-
DemandedElts.setBit(IndexC->getZExtValue());
280-
if (Value *V = SimplifyDemandedVectorElts(SrcVec, DemandedElts,
281-
UndefElts)) {
282-
EI.setOperand(0, V);
283-
return &EI;
337+
if (NumElts != 1) {
338+
// If the input vector has a single use, simplify it based on this use
339+
// property.
340+
if (SrcVec->hasOneUse()) {
341+
APInt UndefElts(NumElts, 0);
342+
APInt DemandedElts(NumElts, 0);
343+
DemandedElts.setBit(IndexC->getZExtValue());
344+
if (Value *V =
345+
SimplifyDemandedVectorElts(SrcVec, DemandedElts, UndefElts)) {
346+
EI.setOperand(0, V);
347+
return &EI;
348+
}
349+
} else {
350+
// If the input vector has multiple uses, simplify it based on a union
351+
// of all elements used.
352+
APInt DemandedElts = findDemandedEltsByAllUsers(SrcVec);
353+
if (!DemandedElts.isAllOnesValue()) {
354+
APInt UndefElts(NumElts, 0);
355+
if (Value *V = SimplifyDemandedVectorElts(
356+
SrcVec, DemandedElts, UndefElts, 0 /* Depth */,
357+
true /* AllowMultipleUsers */)) {
358+
if (V != SrcVec) {
359+
SrcVec->replaceAllUsesWith(V);
360+
return &EI;
361+
}
362+
}
363+
}
284364
}
285365
}
286-
287366
if (Instruction *I = foldBitcastExtElt(EI, Builder, DL.isBigEndian()))
288367
return I;
289368

test/Transforms/InstCombine/AMDGPU/amdgcn-demanded-vector-elts.ll

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,10 @@ define amdgpu_ps <3 x float> @extract_elt0_elt2_elt3_buffer_load_v4f32(<4 x i32>
152152
ret <3 x float> %shuf
153153
}
154154

155-
; FIXME: Not handled even though only 2 elts used
156155
; CHECK-LABEL: @extract_elt0_elt1_buffer_load_v4f32_2(
157-
; CHECK-NEXT: %data = call <4 x float> @llvm.amdgcn.buffer.load.v4f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
158-
; CHECK-NEXT: %elt0 = extractelement <4 x float> %data, i32 0
159-
; CHECK-NEXT: %elt1 = extractelement <4 x float> %data, i32 1
156+
; CHECK-NEXT: %data = call <2 x float> @llvm.amdgcn.buffer.load.v2f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
157+
; CHECK-NEXT: %elt0 = extractelement <2 x float> %data, i32 0
158+
; CHECK-NEXT: %elt1 = extractelement <2 x float> %data, i32 1
160159
; CHECK-NEXT: %ins0 = insertvalue { float, float } undef, float %elt0, 0
161160
; CHECK-NEXT: %ins1 = insertvalue { float, float } %ins0, float %elt1, 1
162161
; CHECK-NEXT: ret { float, float } %ins1
@@ -169,6 +168,74 @@ define amdgpu_ps { float, float } @extract_elt0_elt1_buffer_load_v4f32_2(<4 x i3
169168
ret { float, float } %ins1
170169
}
171170

171+
; CHECK-LABEL: @extract_elt0_elt1_elt2_buffer_load_v4f32_2(
172+
; CHECK-NEXT: %data = call <3 x float> @llvm.amdgcn.buffer.load.v3f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
173+
; CHECK-NEXT: %elt0 = extractelement <3 x float> %data, i32 0
174+
; CHECK-NEXT: %elt1 = extractelement <3 x float> %data, i32 1
175+
; CHECK-NEXT: %elt2 = extractelement <3 x float> %data, i32 2
176+
; CHECK-NEXT: %ins0 = insertvalue { float, float, float } undef, float %elt0, 0
177+
; CHECK-NEXT: %ins1 = insertvalue { float, float, float } %ins0, float %elt1, 1
178+
; CHECK-NEXT: %ins2 = insertvalue { float, float, float } %ins1, float %elt2, 2
179+
; CHECK-NEXT: ret { float, float, float } %ins2
180+
define amdgpu_ps { float, float, float } @extract_elt0_elt1_elt2_buffer_load_v4f32_2(<4 x i32> inreg %rsrc, i32 %idx, i32 %ofs) #0 {
181+
%data = call <4 x float> @llvm.amdgcn.buffer.load.v4f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
182+
%elt0 = extractelement <4 x float> %data, i32 0
183+
%elt1 = extractelement <4 x float> %data, i32 1
184+
%elt2 = extractelement <4 x float> %data, i32 2
185+
%ins0 = insertvalue { float, float, float } undef, float %elt0, 0
186+
%ins1 = insertvalue { float, float, float } %ins0, float %elt1, 1
187+
%ins2 = insertvalue { float, float, float } %ins1, float %elt2, 2
188+
ret { float, float, float } %ins2
189+
}
190+
191+
; CHECK-LABEL: @extract_elt0_elt1_elt2_buffer_load_v4f32_3(
192+
; CHECK-NEXT: %data = call <3 x float> @llvm.amdgcn.buffer.load.v3f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
193+
; CHECK-NEXT: %ins1 = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> <i32 0, i32 2>
194+
; CHECK-NEXT: %shuf = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> <i32 undef, i32 1>
195+
; CHECK-NEXT: %ret = fadd <2 x float> %ins1, %shuf
196+
define amdgpu_ps <2 x float> @extract_elt0_elt1_elt2_buffer_load_v4f32_3(<4 x i32> inreg %rsrc, i32 %idx, i32 %ofs) #0 {
197+
%data = call <4 x float> @llvm.amdgcn.buffer.load.v4f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
198+
%elt0 = extractelement <4 x float> %data, i32 0
199+
%elt2 = extractelement <4 x float> %data, i32 2
200+
%ins0 = insertelement <2 x float> undef, float %elt0, i32 0
201+
%ins1 = insertelement <2 x float> %ins0, float %elt2, i32 1
202+
%shuf = shufflevector <4 x float> %data, <4 x float> undef, <2 x i32> <i32 4, i32 1>
203+
%ret = fadd <2 x float> %ins1, %shuf
204+
ret <2 x float> %ret
205+
}
206+
207+
; CHECK-LABEL: @extract_elt0_elt1_elt2_buffer_load_v4f32_4(
208+
; CHECK-NEXT: %data = call <3 x float> @llvm.amdgcn.buffer.load.v3f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
209+
; CHECK-NEXT: %ins1 = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> <i32 0, i32 2>
210+
; CHECK-NEXT: %shuf = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> <i32 1, i32 undef>
211+
; CHECK-NEXT: %ret = fadd <2 x float> %ins1, %shuf
212+
; CHECK-NEXT: ret <2 x float> %ret
213+
define amdgpu_ps <2 x float> @extract_elt0_elt1_elt2_buffer_load_v4f32_4(<4 x i32> inreg %rsrc, i32 %idx, i32 %ofs) #0 {
214+
%data = call <4 x float> @llvm.amdgcn.buffer.load.v4f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
215+
%elt0 = extractelement <4 x float> %data, i32 0
216+
%elt2 = extractelement <4 x float> %data, i32 2
217+
%ins0 = insertelement <2 x float> undef, float %elt0, i32 0
218+
%ins1 = insertelement <2 x float> %ins0, float %elt2, i32 1
219+
%shuf = shufflevector <4 x float> undef, <4 x float> %data, <2 x i32> <i32 5, i32 1>
220+
%ret = fadd <2 x float> %ins1, %shuf
221+
ret <2 x float> %ret
222+
}
223+
224+
; CHECK-LABEL: @extract_elt0_elt1_elt2_buffer_load_v4f32_5(
225+
; CHECK-NEXT: %data = call <3 x float> @llvm.amdgcn.buffer.load.v3f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
226+
; CHECK-NEXT: %ins1 = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> <i32 2, i32 2>
227+
; CHECK-NEXT: %shuf = shufflevector <3 x float> %data, <3 x float> undef, <2 x i32> <i32 0, i32 1>
228+
; CHECK-NEXT: %ret = fadd <2 x float> %ins1, %shuf
229+
define amdgpu_ps <2 x float> @extract_elt0_elt1_elt2_buffer_load_v4f32_5(<4 x i32> inreg %rsrc, i32 %idx, i32 %ofs) #0 {
230+
%data = call <4 x float> @llvm.amdgcn.buffer.load.v4f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
231+
%elt2 = extractelement <4 x float> %data, i32 2
232+
%ins0 = insertelement <2 x float> undef, float %elt2, i32 0
233+
%ins1 = insertelement <2 x float> %ins0, float %elt2, i32 1
234+
%shuf = shufflevector <4 x float> %data, <4 x float> %data, <2 x i32> <i32 0, i32 5>
235+
%ret = fadd <2 x float> %ins1, %shuf
236+
ret <2 x float> %ret
237+
}
238+
172239
; CHECK-LABEL: @extract_elt0_buffer_load_v3f32(
173240
; CHECK-NEXT: %data = call float @llvm.amdgcn.buffer.load.f32(<4 x i32> %rsrc, i32 %idx, i32 %ofs, i1 false, i1 false)
174241
; CHECK-NEXT: ret float %data

0 commit comments

Comments
 (0)