Skip to content

Commit 6765943

Browse files
committed
[CodeGen][SelectionDAG] More efficient code for X % C == 0 (SREM case)
Summary: This implements an optimization described in Hacker's Delight 10-17: when `C` is constant, the result of `X % C == 0` can be computed more cheaply without actually calculating the remainder. The motivation is discussed here: https://bugs.llvm.org/show_bug.cgi?id=35479. One huge caveat: this signed case is only valid for positive divisors. While we can freely negate negative divisors, we can't negate `INT_MIN`, so for now if `INT_MIN` is encountered, we bail out. As a follow-up, it should be possible to handle that more gracefully via extra `and`+`setcc`+`select`. This passes llvm's test-suite, and from cursory(!) cross-examination the folds (the assembly) match those of GCC, and manual checking via alive did not reveal any issues (other than the `INT_MIN` case) Reviewers: RKSimon, spatel, hermord, craig.topper, xbolva00 Reviewed By: RKSimon, xbolva00 Subscribers: xbolva00, thakis, javed.absar, hiraditya, dexonsmith, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D65366 llvm-svn: 368702
1 parent f4de7ed commit 6765943

File tree

11 files changed

+774
-1444
lines changed

11 files changed

+774
-1444
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4164,6 +4164,14 @@ class TargetLowering : public TargetLoweringBase {
41644164
SDValue buildUREMEqFold(EVT SETCCVT, SDValue REMNode, SDValue CompTargetNode,
41654165
ISD::CondCode Cond, DAGCombinerInfo &DCI,
41664166
const SDLoc &DL) const;
4167+
4168+
SDValue prepareSREMEqFold(EVT SETCCVT, SDValue REMNode,
4169+
SDValue CompTargetNode, ISD::CondCode Cond,
4170+
DAGCombinerInfo &DCI, const SDLoc &DL,
4171+
SmallVectorImpl<SDNode *> &Created) const;
4172+
SDValue buildSREMEqFold(EVT SETCCVT, SDValue REMNode, SDValue CompTargetNode,
4173+
ISD::CondCode Cond, DAGCombinerInfo &DCI,
4174+
const SDLoc &DL) const;
41674175
};
41684176

41694177
/// Given an LLVM IR type and return type attributes, compute the return value

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 221 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3802,15 +3802,21 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
38023802
}
38033803

38043804
// Fold remainder of division by a constant.
3805-
if (N0.getOpcode() == ISD::UREM && N0.hasOneUse() &&
3806-
(Cond == ISD::SETEQ || Cond == ISD::SETNE)) {
3805+
if ((N0.getOpcode() == ISD::UREM || N0.getOpcode() == ISD::SREM) &&
3806+
N0.hasOneUse() && (Cond == ISD::SETEQ || Cond == ISD::SETNE)) {
38073807
AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
38083808

38093809
// When division is cheap or optimizing for minimum size,
38103810
// fall through to DIVREM creation by skipping this fold.
3811-
if (!isIntDivCheap(VT, Attr) && !Attr.hasFnAttribute(Attribute::MinSize))
3812-
if (SDValue Folded = buildUREMEqFold(VT, N0, N1, Cond, DCI, dl))
3813-
return Folded;
3811+
if (!isIntDivCheap(VT, Attr) && !Attr.hasFnAttribute(Attribute::MinSize)) {
3812+
if (N0.getOpcode() == ISD::UREM) {
3813+
if (SDValue Folded = buildUREMEqFold(VT, N0, N1, Cond, DCI, dl))
3814+
return Folded;
3815+
} else if (N0.getOpcode() == ISD::SREM) {
3816+
if (SDValue Folded = buildSREMEqFold(VT, N0, N1, Cond, DCI, dl))
3817+
return Folded;
3818+
}
3819+
}
38143820
}
38153821

38163822
// Fold away ALL boolean setcc's.
@@ -5004,6 +5010,216 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode,
50045010
((Cond == ISD::SETEQ) ? ISD::SETULE : ISD::SETUGT));
50055011
}
50065012

5013+
/// Given an ISD::SREM used only by an ISD::SETEQ or ISD::SETNE
5014+
/// where the divisor is constant and the comparison target is zero,
5015+
/// return a DAG expression that will generate the same comparison result
5016+
/// using only multiplications, additions and shifts/rotations.
5017+
/// Ref: "Hacker's Delight" 10-17.
5018+
SDValue TargetLowering::buildSREMEqFold(EVT SETCCVT, SDValue REMNode,
5019+
SDValue CompTargetNode,
5020+
ISD::CondCode Cond,
5021+
DAGCombinerInfo &DCI,
5022+
const SDLoc &DL) const {
5023+
SmallVector<SDNode *, 3> Built;
5024+
if (SDValue Folded = prepareSREMEqFold(SETCCVT, REMNode, CompTargetNode, Cond,
5025+
DCI, DL, Built)) {
5026+
for (SDNode *N : Built)
5027+
DCI.AddToWorklist(N);
5028+
return Folded;
5029+
}
5030+
5031+
return SDValue();
5032+
}
5033+
5034+
SDValue
5035+
TargetLowering::prepareSREMEqFold(EVT SETCCVT, SDValue REMNode,
5036+
SDValue CompTargetNode, ISD::CondCode Cond,
5037+
DAGCombinerInfo &DCI, const SDLoc &DL,
5038+
SmallVectorImpl<SDNode *> &Created) const {
5039+
// Fold:
5040+
// (seteq/ne (srem N, D), 0)
5041+
// To:
5042+
// (setule/ugt (rotr (add (mul N, P), A), K), Q)
5043+
//
5044+
// - D must be constant, with D = D0 * 2^K where D0 is odd
5045+
// - P is the multiplicative inverse of D0 modulo 2^W
5046+
// - A = bitwiseand(floor((2^(W - 1) - 1) / D0), (-(2^k)))
5047+
// - Q = floor((2 * A) / (2^K))
5048+
// where W is the width of the common type of N and D.
5049+
assert((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
5050+
"Only applicable for (in)equality comparisons.");
5051+
5052+
SelectionDAG &DAG = DCI.DAG;
5053+
5054+
EVT VT = REMNode.getValueType();
5055+
EVT SVT = VT.getScalarType();
5056+
EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
5057+
EVT ShSVT = ShVT.getScalarType();
5058+
5059+
// If MUL is unavailable, we cannot proceed in any case.
5060+
if (!isOperationLegalOrCustom(ISD::MUL, VT))
5061+
return SDValue();
5062+
5063+
// TODO: Could support comparing with non-zero too.
5064+
ConstantSDNode *CompTarget = isConstOrConstSplat(CompTargetNode);
5065+
if (!CompTarget || !CompTarget->isNullValue())
5066+
return SDValue();
5067+
5068+
bool HadOneDivisor = false;
5069+
bool AllDivisorsAreOnes = true;
5070+
bool HadEvenDivisor = false;
5071+
bool NeedToApplyOffset = false;
5072+
bool AllDivisorsArePowerOfTwo = true;
5073+
SmallVector<SDValue, 16> PAmts, AAmts, KAmts, QAmts;
5074+
5075+
auto BuildSREMPattern = [&](ConstantSDNode *C) {
5076+
// Division by 0 is UB. Leave it to be constant-folded elsewhere.
5077+
if (C->isNullValue())
5078+
return false;
5079+
5080+
// FIXME: we don't fold `rem %X, -C` to `rem %X, C` in DAGCombine.
5081+
5082+
// WARNING: this fold is only valid for positive divisors!
5083+
APInt D = C->getAPIntValue();
5084+
if (D.isMinSignedValue())
5085+
return false; // We can't negate INT_MIN.
5086+
if (D.isNegative())
5087+
D.negate(); // `rem %X, -C` is equivalent to `rem %X, C`
5088+
5089+
assert(!D.isNegative() && "The fold is only valid for positive divisors!");
5090+
5091+
// If all divisors are ones, we will prefer to avoid the fold.
5092+
HadOneDivisor |= D.isOneValue();
5093+
AllDivisorsAreOnes &= D.isOneValue();
5094+
5095+
// Decompose D into D0 * 2^K
5096+
unsigned K = D.countTrailingZeros();
5097+
assert((!D.isOneValue() || (K == 0)) && "For divisor '1' we won't rotate.");
5098+
APInt D0 = D.lshr(K);
5099+
5100+
// D is even if it has trailing zeros.
5101+
HadEvenDivisor |= (K != 0);
5102+
// D is a power-of-two if D0 is one.
5103+
// If all divisors are power-of-two, we will prefer to avoid the fold.
5104+
AllDivisorsArePowerOfTwo &= D0.isOneValue();
5105+
5106+
// P = inv(D0, 2^W)
5107+
// 2^W requires W + 1 bits, so we have to extend and then truncate.
5108+
unsigned W = D.getBitWidth();
5109+
APInt P = D0.zext(W + 1)
5110+
.multiplicativeInverse(APInt::getSignedMinValue(W + 1))
5111+
.trunc(W);
5112+
assert(!P.isNullValue() && "No multiplicative inverse!"); // unreachable
5113+
assert((D0 * P).isOneValue() && "Multiplicative inverse sanity check.");
5114+
5115+
// A = floor((2^(W - 1) - 1) / D0) & -2^K
5116+
APInt A = APInt::getSignedMaxValue(W).udiv(D0);
5117+
A.clearLowBits(K);
5118+
5119+
NeedToApplyOffset |= A != 0;
5120+
5121+
// Q = floor((2 * A) / (2^K))
5122+
APInt Q = (2 * A).udiv(APInt::getOneBitSet(W, K));
5123+
5124+
assert(APInt::getAllOnesValue(SVT.getSizeInBits()).ugt(A) &&
5125+
"We are expecting that A is always less than all-ones for SVT");
5126+
assert(APInt::getAllOnesValue(ShSVT.getSizeInBits()).ugt(K) &&
5127+
"We are expecting that K is always less than all-ones for ShSVT");
5128+
5129+
// If the divisor is 1 the result can be constant-folded.
5130+
if (D.isOneValue()) {
5131+
// Set P, A and K to a bogus values so we can try to splat them.
5132+
P = 0;
5133+
A = -1;
5134+
K = -1;
5135+
5136+
// x ?% 1 == 0 <--> true <--> x u<= -1
5137+
Q = -1;
5138+
}
5139+
5140+
PAmts.push_back(DAG.getConstant(P, DL, SVT));
5141+
AAmts.push_back(DAG.getConstant(A, DL, SVT));
5142+
KAmts.push_back(
5143+
DAG.getConstant(APInt(ShSVT.getSizeInBits(), K), DL, ShSVT));
5144+
QAmts.push_back(DAG.getConstant(Q, DL, SVT));
5145+
return true;
5146+
};
5147+
5148+
SDValue N = REMNode.getOperand(0);
5149+
SDValue D = REMNode.getOperand(1);
5150+
5151+
// Collect the values from each element.
5152+
if (!ISD::matchUnaryPredicate(D, BuildSREMPattern))
5153+
return SDValue();
5154+
5155+
// If this is a srem by a one, avoid the fold since it can be constant-folded.
5156+
if (AllDivisorsAreOnes)
5157+
return SDValue();
5158+
5159+
// If this is a srem by a powers-of-two, avoid the fold since it can be
5160+
// best implemented as a bit test.
5161+
if (AllDivisorsArePowerOfTwo)
5162+
return SDValue();
5163+
5164+
SDValue PVal, AVal, KVal, QVal;
5165+
if (VT.isVector()) {
5166+
if (HadOneDivisor) {
5167+
// Try to turn PAmts into a splat, since we don't care about the values
5168+
// that are currently '0'. If we can't, just keep '0'`s.
5169+
turnVectorIntoSplatVector(PAmts, isNullConstant);
5170+
// Try to turn AAmts into a splat, since we don't care about the
5171+
// values that are currently '-1'. If we can't, change them to '0'`s.
5172+
turnVectorIntoSplatVector(AAmts, isAllOnesConstant,
5173+
DAG.getConstant(0, DL, SVT));
5174+
// Try to turn KAmts into a splat, since we don't care about the values
5175+
// that are currently '-1'. If we can't, change them to '0'`s.
5176+
turnVectorIntoSplatVector(KAmts, isAllOnesConstant,
5177+
DAG.getConstant(0, DL, ShSVT));
5178+
}
5179+
5180+
PVal = DAG.getBuildVector(VT, DL, PAmts);
5181+
AVal = DAG.getBuildVector(VT, DL, AAmts);
5182+
KVal = DAG.getBuildVector(ShVT, DL, KAmts);
5183+
QVal = DAG.getBuildVector(VT, DL, QAmts);
5184+
} else {
5185+
PVal = PAmts[0];
5186+
AVal = AAmts[0];
5187+
KVal = KAmts[0];
5188+
QVal = QAmts[0];
5189+
}
5190+
5191+
// (mul N, P)
5192+
SDValue Op0 = DAG.getNode(ISD::MUL, DL, VT, N, PVal);
5193+
Created.push_back(Op0.getNode());
5194+
5195+
if (NeedToApplyOffset) {
5196+
// We need ADD to do this.
5197+
if (!isOperationLegalOrCustom(ISD::ADD, VT))
5198+
return SDValue();
5199+
5200+
// (add (mul N, P), A)
5201+
Op0 = DAG.getNode(ISD::ADD, DL, VT, Op0, AVal);
5202+
Created.push_back(Op0.getNode());
5203+
}
5204+
5205+
// Rotate right only if any divisor was even. We avoid rotates for all-odd
5206+
// divisors as a performance improvement, since rotating by 0 is a no-op.
5207+
if (HadEvenDivisor) {
5208+
// We need ROTR to do this.
5209+
if (!isOperationLegalOrCustom(ISD::ROTR, VT))
5210+
return SDValue();
5211+
SDNodeFlags Flags;
5212+
Flags.setExact(true);
5213+
// SREM: (rotr (add (mul N, P), A), K)
5214+
Op0 = DAG.getNode(ISD::ROTR, DL, VT, Op0, KVal, Flags);
5215+
Created.push_back(Op0.getNode());
5216+
}
5217+
5218+
// SREM: (setule/setugt (rotr (add (mul N, P), A), K), Q)
5219+
return DAG.getSetCC(DL, SETCCVT, Op0, QVal,
5220+
((Cond == ISD::SETEQ) ? ISD::SETULE : ISD::SETUGT));
5221+
}
5222+
50075223
bool TargetLowering::
50085224
verifyReturnAddressArgumentIsConstant(SDValue Op, SelectionDAG &DAG) const {
50095225
if (!isa<ConstantSDNode>(Op.getOperand(0))) {

llvm/test/CodeGen/AArch64/srem-seteq-optsize.ll

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,16 @@ define i32 @test_minsize(i32 %X) optsize minsize nounwind readnone {
2121
define i32 @test_optsize(i32 %X) optsize nounwind readnone {
2222
; CHECK-LABEL: test_optsize:
2323
; CHECK: // %bb.0:
24-
; CHECK-NEXT: mov w8, #26215
25-
; CHECK-NEXT: movk w8, #26214, lsl #16
26-
; CHECK-NEXT: smull x8, w0, w8
27-
; CHECK-NEXT: lsr x10, x8, #63
28-
; CHECK-NEXT: asr x8, x8, #33
29-
; CHECK-NEXT: add w8, w8, w10
30-
; CHECK-NEXT: add w8, w8, w8, lsl #2
31-
; CHECK-NEXT: mov w9, #-10
32-
; CHECK-NEXT: cmp w0, w8
24+
; CHECK-NEXT: mov w8, #52429
25+
; CHECK-NEXT: mov w9, #39321
26+
; CHECK-NEXT: movk w8, #52428, lsl #16
27+
; CHECK-NEXT: movk w9, #6553, lsl #16
28+
; CHECK-NEXT: mov w10, #858993459
29+
; CHECK-NEXT: madd w8, w0, w8, w9
30+
; CHECK-NEXT: mov w11, #-10
31+
; CHECK-NEXT: cmp w8, w10
3332
; CHECK-NEXT: mov w8, #42
34-
; CHECK-NEXT: csel w0, w8, w9, eq
33+
; CHECK-NEXT: csel w0, w8, w11, lo
3534
; CHECK-NEXT: ret
3635
%rem = srem i32 %X, 5
3736
%cmp = icmp eq i32 %rem, 0

0 commit comments

Comments
 (0)