@@ -3802,15 +3802,21 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
3802
3802
}
3803
3803
3804
3804
// Fold remainder of division by a constant.
3805
- if (N0.getOpcode () == ISD::UREM && N0.hasOneUse ( ) &&
3806
- (Cond == ISD::SETEQ || Cond == ISD::SETNE)) {
3805
+ if (( N0.getOpcode () == ISD::UREM || N0.getOpcode () == ISD::SREM ) &&
3806
+ N0. hasOneUse () && (Cond == ISD::SETEQ || Cond == ISD::SETNE)) {
3807
3807
AttributeList Attr = DAG.getMachineFunction ().getFunction ().getAttributes ();
3808
3808
3809
3809
// When division is cheap or optimizing for minimum size,
3810
3810
// fall through to DIVREM creation by skipping this fold.
3811
- if (!isIntDivCheap (VT, Attr) && !Attr.hasFnAttribute (Attribute::MinSize))
3812
- if (SDValue Folded = buildUREMEqFold (VT, N0, N1, Cond, DCI, dl))
3813
- return Folded;
3811
+ if (!isIntDivCheap (VT, Attr) && !Attr.hasFnAttribute (Attribute::MinSize)) {
3812
+ if (N0.getOpcode () == ISD::UREM) {
3813
+ if (SDValue Folded = buildUREMEqFold (VT, N0, N1, Cond, DCI, dl))
3814
+ return Folded;
3815
+ } else if (N0.getOpcode () == ISD::SREM) {
3816
+ if (SDValue Folded = buildSREMEqFold (VT, N0, N1, Cond, DCI, dl))
3817
+ return Folded;
3818
+ }
3819
+ }
3814
3820
}
3815
3821
3816
3822
// Fold away ALL boolean setcc's.
@@ -5004,6 +5010,216 @@ TargetLowering::prepareUREMEqFold(EVT SETCCVT, SDValue REMNode,
5004
5010
((Cond == ISD::SETEQ) ? ISD::SETULE : ISD::SETUGT));
5005
5011
}
5006
5012
5013
+ // / Given an ISD::SREM used only by an ISD::SETEQ or ISD::SETNE
5014
+ // / where the divisor is constant and the comparison target is zero,
5015
+ // / return a DAG expression that will generate the same comparison result
5016
+ // / using only multiplications, additions and shifts/rotations.
5017
+ // / Ref: "Hacker's Delight" 10-17.
5018
+ SDValue TargetLowering::buildSREMEqFold (EVT SETCCVT, SDValue REMNode,
5019
+ SDValue CompTargetNode,
5020
+ ISD::CondCode Cond,
5021
+ DAGCombinerInfo &DCI,
5022
+ const SDLoc &DL) const {
5023
+ SmallVector<SDNode *, 3 > Built;
5024
+ if (SDValue Folded = prepareSREMEqFold (SETCCVT, REMNode, CompTargetNode, Cond,
5025
+ DCI, DL, Built)) {
5026
+ for (SDNode *N : Built)
5027
+ DCI.AddToWorklist (N);
5028
+ return Folded;
5029
+ }
5030
+
5031
+ return SDValue ();
5032
+ }
5033
+
5034
+ SDValue
5035
+ TargetLowering::prepareSREMEqFold (EVT SETCCVT, SDValue REMNode,
5036
+ SDValue CompTargetNode, ISD::CondCode Cond,
5037
+ DAGCombinerInfo &DCI, const SDLoc &DL,
5038
+ SmallVectorImpl<SDNode *> &Created) const {
5039
+ // Fold:
5040
+ // (seteq/ne (srem N, D), 0)
5041
+ // To:
5042
+ // (setule/ugt (rotr (add (mul N, P), A), K), Q)
5043
+ //
5044
+ // - D must be constant, with D = D0 * 2^K where D0 is odd
5045
+ // - P is the multiplicative inverse of D0 modulo 2^W
5046
+ // - A = bitwiseand(floor((2^(W - 1) - 1) / D0), (-(2^k)))
5047
+ // - Q = floor((2 * A) / (2^K))
5048
+ // where W is the width of the common type of N and D.
5049
+ assert ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
5050
+ " Only applicable for (in)equality comparisons." );
5051
+
5052
+ SelectionDAG &DAG = DCI.DAG ;
5053
+
5054
+ EVT VT = REMNode.getValueType ();
5055
+ EVT SVT = VT.getScalarType ();
5056
+ EVT ShVT = getShiftAmountTy (VT, DAG.getDataLayout ());
5057
+ EVT ShSVT = ShVT.getScalarType ();
5058
+
5059
+ // If MUL is unavailable, we cannot proceed in any case.
5060
+ if (!isOperationLegalOrCustom (ISD::MUL, VT))
5061
+ return SDValue ();
5062
+
5063
+ // TODO: Could support comparing with non-zero too.
5064
+ ConstantSDNode *CompTarget = isConstOrConstSplat (CompTargetNode);
5065
+ if (!CompTarget || !CompTarget->isNullValue ())
5066
+ return SDValue ();
5067
+
5068
+ bool HadOneDivisor = false ;
5069
+ bool AllDivisorsAreOnes = true ;
5070
+ bool HadEvenDivisor = false ;
5071
+ bool NeedToApplyOffset = false ;
5072
+ bool AllDivisorsArePowerOfTwo = true ;
5073
+ SmallVector<SDValue, 16 > PAmts, AAmts, KAmts, QAmts;
5074
+
5075
+ auto BuildSREMPattern = [&](ConstantSDNode *C) {
5076
+ // Division by 0 is UB. Leave it to be constant-folded elsewhere.
5077
+ if (C->isNullValue ())
5078
+ return false ;
5079
+
5080
+ // FIXME: we don't fold `rem %X, -C` to `rem %X, C` in DAGCombine.
5081
+
5082
+ // WARNING: this fold is only valid for positive divisors!
5083
+ APInt D = C->getAPIntValue ();
5084
+ if (D.isMinSignedValue ())
5085
+ return false ; // We can't negate INT_MIN.
5086
+ if (D.isNegative ())
5087
+ D.negate (); // `rem %X, -C` is equivalent to `rem %X, C`
5088
+
5089
+ assert (!D.isNegative () && " The fold is only valid for positive divisors!" );
5090
+
5091
+ // If all divisors are ones, we will prefer to avoid the fold.
5092
+ HadOneDivisor |= D.isOneValue ();
5093
+ AllDivisorsAreOnes &= D.isOneValue ();
5094
+
5095
+ // Decompose D into D0 * 2^K
5096
+ unsigned K = D.countTrailingZeros ();
5097
+ assert ((!D.isOneValue () || (K == 0 )) && " For divisor '1' we won't rotate." );
5098
+ APInt D0 = D.lshr (K);
5099
+
5100
+ // D is even if it has trailing zeros.
5101
+ HadEvenDivisor |= (K != 0 );
5102
+ // D is a power-of-two if D0 is one.
5103
+ // If all divisors are power-of-two, we will prefer to avoid the fold.
5104
+ AllDivisorsArePowerOfTwo &= D0.isOneValue ();
5105
+
5106
+ // P = inv(D0, 2^W)
5107
+ // 2^W requires W + 1 bits, so we have to extend and then truncate.
5108
+ unsigned W = D.getBitWidth ();
5109
+ APInt P = D0.zext (W + 1 )
5110
+ .multiplicativeInverse (APInt::getSignedMinValue (W + 1 ))
5111
+ .trunc (W);
5112
+ assert (!P.isNullValue () && " No multiplicative inverse!" ); // unreachable
5113
+ assert ((D0 * P).isOneValue () && " Multiplicative inverse sanity check." );
5114
+
5115
+ // A = floor((2^(W - 1) - 1) / D0) & -2^K
5116
+ APInt A = APInt::getSignedMaxValue (W).udiv (D0);
5117
+ A.clearLowBits (K);
5118
+
5119
+ NeedToApplyOffset |= A != 0 ;
5120
+
5121
+ // Q = floor((2 * A) / (2^K))
5122
+ APInt Q = (2 * A).udiv (APInt::getOneBitSet (W, K));
5123
+
5124
+ assert (APInt::getAllOnesValue (SVT.getSizeInBits ()).ugt (A) &&
5125
+ " We are expecting that A is always less than all-ones for SVT" );
5126
+ assert (APInt::getAllOnesValue (ShSVT.getSizeInBits ()).ugt (K) &&
5127
+ " We are expecting that K is always less than all-ones for ShSVT" );
5128
+
5129
+ // If the divisor is 1 the result can be constant-folded.
5130
+ if (D.isOneValue ()) {
5131
+ // Set P, A and K to a bogus values so we can try to splat them.
5132
+ P = 0 ;
5133
+ A = -1 ;
5134
+ K = -1 ;
5135
+
5136
+ // x ?% 1 == 0 <--> true <--> x u<= -1
5137
+ Q = -1 ;
5138
+ }
5139
+
5140
+ PAmts.push_back (DAG.getConstant (P, DL, SVT));
5141
+ AAmts.push_back (DAG.getConstant (A, DL, SVT));
5142
+ KAmts.push_back (
5143
+ DAG.getConstant (APInt (ShSVT.getSizeInBits (), K), DL, ShSVT));
5144
+ QAmts.push_back (DAG.getConstant (Q, DL, SVT));
5145
+ return true ;
5146
+ };
5147
+
5148
+ SDValue N = REMNode.getOperand (0 );
5149
+ SDValue D = REMNode.getOperand (1 );
5150
+
5151
+ // Collect the values from each element.
5152
+ if (!ISD::matchUnaryPredicate (D, BuildSREMPattern))
5153
+ return SDValue ();
5154
+
5155
+ // If this is a srem by a one, avoid the fold since it can be constant-folded.
5156
+ if (AllDivisorsAreOnes)
5157
+ return SDValue ();
5158
+
5159
+ // If this is a srem by a powers-of-two, avoid the fold since it can be
5160
+ // best implemented as a bit test.
5161
+ if (AllDivisorsArePowerOfTwo)
5162
+ return SDValue ();
5163
+
5164
+ SDValue PVal, AVal, KVal, QVal;
5165
+ if (VT.isVector ()) {
5166
+ if (HadOneDivisor) {
5167
+ // Try to turn PAmts into a splat, since we don't care about the values
5168
+ // that are currently '0'. If we can't, just keep '0'`s.
5169
+ turnVectorIntoSplatVector (PAmts, isNullConstant);
5170
+ // Try to turn AAmts into a splat, since we don't care about the
5171
+ // values that are currently '-1'. If we can't, change them to '0'`s.
5172
+ turnVectorIntoSplatVector (AAmts, isAllOnesConstant,
5173
+ DAG.getConstant (0 , DL, SVT));
5174
+ // Try to turn KAmts into a splat, since we don't care about the values
5175
+ // that are currently '-1'. If we can't, change them to '0'`s.
5176
+ turnVectorIntoSplatVector (KAmts, isAllOnesConstant,
5177
+ DAG.getConstant (0 , DL, ShSVT));
5178
+ }
5179
+
5180
+ PVal = DAG.getBuildVector (VT, DL, PAmts);
5181
+ AVal = DAG.getBuildVector (VT, DL, AAmts);
5182
+ KVal = DAG.getBuildVector (ShVT, DL, KAmts);
5183
+ QVal = DAG.getBuildVector (VT, DL, QAmts);
5184
+ } else {
5185
+ PVal = PAmts[0 ];
5186
+ AVal = AAmts[0 ];
5187
+ KVal = KAmts[0 ];
5188
+ QVal = QAmts[0 ];
5189
+ }
5190
+
5191
+ // (mul N, P)
5192
+ SDValue Op0 = DAG.getNode (ISD::MUL, DL, VT, N, PVal);
5193
+ Created.push_back (Op0.getNode ());
5194
+
5195
+ if (NeedToApplyOffset) {
5196
+ // We need ADD to do this.
5197
+ if (!isOperationLegalOrCustom (ISD::ADD, VT))
5198
+ return SDValue ();
5199
+
5200
+ // (add (mul N, P), A)
5201
+ Op0 = DAG.getNode (ISD::ADD, DL, VT, Op0, AVal);
5202
+ Created.push_back (Op0.getNode ());
5203
+ }
5204
+
5205
+ // Rotate right only if any divisor was even. We avoid rotates for all-odd
5206
+ // divisors as a performance improvement, since rotating by 0 is a no-op.
5207
+ if (HadEvenDivisor) {
5208
+ // We need ROTR to do this.
5209
+ if (!isOperationLegalOrCustom (ISD::ROTR, VT))
5210
+ return SDValue ();
5211
+ SDNodeFlags Flags;
5212
+ Flags.setExact (true );
5213
+ // SREM: (rotr (add (mul N, P), A), K)
5214
+ Op0 = DAG.getNode (ISD::ROTR, DL, VT, Op0, KVal, Flags);
5215
+ Created.push_back (Op0.getNode ());
5216
+ }
5217
+
5218
+ // SREM: (setule/setugt (rotr (add (mul N, P), A), K), Q)
5219
+ return DAG.getSetCC (DL, SETCCVT, Op0, QVal,
5220
+ ((Cond == ISD::SETEQ) ? ISD::SETULE : ISD::SETUGT));
5221
+ }
5222
+
5007
5223
bool TargetLowering::
5008
5224
verifyReturnAddressArgumentIsConstant (SDValue Op, SelectionDAG &DAG) const {
5009
5225
if (!isa<ConstantSDNode>(Op.getOperand (0 ))) {
0 commit comments