Skip to content

Commit d1f0bdf

Browse files
committed
[SDAG] remove use restriction in isNegatibleForFree() when called from getNegatedExpression()
This is an alternate fix for the bug discussed in D70595. This also includes minimal tests for other in-tree targets to show the problem more generally. We check the number of uses as a predicate for whether some value is free to negate, but that use count can change as we rewrite the expression in getNegatedExpression(). So something that was marked free to negate during the cost evaluation phase can become not free to negate during the rewrite phase (or the inverse - something that was not free becomes free). This can lead to a crash/assert because we expect everything in an expression that is negatible to be handled by the corresponding code within getNegatedExpression(). This patch skips the use check during the rewrite phase. So we still determine whether some expression isNegatibleForFree() exactly as we did before this patch, but during the rewrite we don't rely on use counts to decide how to create the optimal expression. Differential Revision: https://reviews.llvm.org/D70975
1 parent d8c31d4 commit d1f0bdf

File tree

6 files changed

+72
-20
lines changed

6 files changed

+72
-20
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3442,8 +3442,16 @@ class TargetLowering : public TargetLoweringBase {
34423442
/// Return 1 if we can compute the negated form of the specified expression
34433443
/// for the same cost as the expression itself, or 2 if we can compute the
34443444
/// negated form more cheaply than the expression itself. Else return 0.
3445+
///
3446+
/// EnableUseCheck specifies whether the number of uses of a value affects
3447+
/// if negation is considered free. This is needed because the number of uses
3448+
/// of any value may change as we rewrite the expression. Therefore, when
3449+
/// called from getNegatedExpression(), we must explicitly set EnableUseCheck
3450+
/// to false to avoid getting a different answer than when called from other
3451+
/// contexts.
34453452
virtual char isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
34463453
bool LegalOperations, bool ForCodeSize,
3454+
bool EnableUseCheck = true,
34473455
unsigned Depth = 0) const;
34483456

34493457
/// If isNegatibleForFree returns true, return the newly negated expression.

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5413,18 +5413,21 @@ verifyReturnAddressArgumentIsConstant(SDValue Op, SelectionDAG &DAG) const {
54135413

54145414
char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
54155415
bool LegalOperations, bool ForCodeSize,
5416+
bool EnableUseCheck,
54165417
unsigned Depth) const {
54175418
// fneg is removable even if it has multiple uses.
54185419
if (Op.getOpcode() == ISD::FNEG)
54195420
return 2;
54205421

5421-
// Don't allow anything with multiple uses unless we know it is free.
5422+
// If the caller requires checking uses, don't allow anything with multiple
5423+
// uses unless we know it is free.
54225424
EVT VT = Op.getValueType();
54235425
const SDNodeFlags Flags = Op->getFlags();
54245426
const TargetOptions &Options = DAG.getTarget().Options;
5425-
if (!Op.hasOneUse() && !(Op.getOpcode() == ISD::FP_EXTEND &&
5426-
isFPExtFree(VT, Op.getOperand(0).getValueType())))
5427-
return 0;
5427+
if (EnableUseCheck)
5428+
if (!Op.hasOneUse() && !(Op.getOpcode() == ISD::FP_EXTEND &&
5429+
isFPExtFree(VT, Op.getOperand(0).getValueType())))
5430+
return 0;
54285431

54295432
// Don't recurse exponentially.
54305433
if (Depth > SelectionDAG::MaxRecursionDepth)
@@ -5468,11 +5471,11 @@ char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
54685471

54695472
// fold (fneg (fadd A, B)) -> (fsub (fneg A), B)
54705473
if (char V = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations,
5471-
ForCodeSize, Depth + 1))
5474+
ForCodeSize, EnableUseCheck, Depth + 1))
54725475
return V;
54735476
// fold (fneg (fadd A, B)) -> (fsub (fneg B), A)
54745477
return isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations,
5475-
ForCodeSize, Depth + 1);
5478+
ForCodeSize, EnableUseCheck, Depth + 1);
54765479
case ISD::FSUB:
54775480
// We can't turn -(A-B) into B-A when we honor signed zeros.
54785481
if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
@@ -5485,7 +5488,7 @@ char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
54855488
case ISD::FDIV:
54865489
// fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y) or (fmul X, (fneg Y))
54875490
if (char V = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations,
5488-
ForCodeSize, Depth + 1))
5491+
ForCodeSize, EnableUseCheck, Depth + 1))
54895492
return V;
54905493

54915494
// Ignore X * 2.0 because that is expected to be canonicalized to X + X.
@@ -5494,7 +5497,7 @@ char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
54945497
return 0;
54955498

54965499
return isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations,
5497-
ForCodeSize, Depth + 1);
5500+
ForCodeSize, EnableUseCheck, Depth + 1);
54985501

54995502
case ISD::FMA:
55005503
case ISD::FMAD: {
@@ -5504,15 +5507,15 @@ char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
55045507
// fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
55055508
// fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z))
55065509
char V2 = isNegatibleForFree(Op.getOperand(2), DAG, LegalOperations,
5507-
ForCodeSize, Depth + 1);
5510+
ForCodeSize, EnableUseCheck, Depth + 1);
55085511
if (!V2)
55095512
return 0;
55105513

55115514
// One of Op0/Op1 must be cheaply negatible, then select the cheapest.
55125515
char V0 = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations,
5513-
ForCodeSize, Depth + 1);
5516+
ForCodeSize, EnableUseCheck, Depth + 1);
55145517
char V1 = isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations,
5515-
ForCodeSize, Depth + 1);
5518+
ForCodeSize, EnableUseCheck, Depth + 1);
55165519
char V01 = std::max(V0, V1);
55175520
return V01 ? std::max(V01, V2) : 0;
55185521
}
@@ -5521,7 +5524,7 @@ char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
55215524
case ISD::FP_ROUND:
55225525
case ISD::FSIN:
55235526
return isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations,
5524-
ForCodeSize, Depth + 1);
5527+
ForCodeSize, EnableUseCheck, Depth + 1);
55255528
}
55265529

55275530
return 0;
@@ -5565,7 +5568,7 @@ SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
55655568

55665569
// fold (fneg (fadd A, B)) -> (fsub (fneg A), B)
55675570
if (isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, ForCodeSize,
5568-
Depth + 1))
5571+
false, Depth + 1))
55695572
return DAG.getNode(ISD::FSUB, SDLoc(Op), Op.getValueType(),
55705573
getNegatedExpression(Op.getOperand(0), DAG,
55715574
LegalOperations, ForCodeSize,
@@ -5592,7 +5595,7 @@ SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
55925595
case ISD::FDIV:
55935596
// fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y)
55945597
if (isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, ForCodeSize,
5595-
Depth + 1))
5598+
false, Depth + 1))
55965599
return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(),
55975600
getNegatedExpression(Op.getOperand(0), DAG,
55985601
LegalOperations, ForCodeSize,
@@ -5616,9 +5619,9 @@ SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
56165619
ForCodeSize, Depth + 1);
56175620

56185621
char V0 = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations,
5619-
ForCodeSize, Depth + 1);
5622+
ForCodeSize, false, Depth + 1);
56205623
char V1 = isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations,
5621-
ForCodeSize, Depth + 1);
5624+
ForCodeSize, false, Depth + 1);
56225625
if (V0 >= V1) {
56235626
// fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
56245627
SDValue Neg0 = getNegatedExpression(

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41898,6 +41898,7 @@ static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
4189841898
char X86TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
4189941899
bool LegalOperations,
4190041900
bool ForCodeSize,
41901+
bool EnableUseCheck,
4190141902
unsigned Depth) const {
4190241903
// fneg patterns are removable even if they have multiple uses.
4190341904
if (isFNEG(DAG, Op.getNode(), Depth))
@@ -41926,7 +41927,7 @@ char X86TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
4192641927
// extra operand negations as well.
4192741928
for (int i = 0; i != 3; ++i) {
4192841929
char V = isNegatibleForFree(Op.getOperand(i), DAG, LegalOperations,
41929-
ForCodeSize, Depth + 1);
41930+
ForCodeSize, EnableUseCheck, Depth + 1);
4193041931
if (V == 2)
4193141932
return V;
4193241933
}
@@ -41935,7 +41936,8 @@ char X86TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG,
4193541936
}
4193641937

4193741938
return TargetLowering::isNegatibleForFree(Op, DAG, LegalOperations,
41938-
ForCodeSize, Depth);
41939+
ForCodeSize, EnableUseCheck,
41940+
Depth);
4193941941
}
4194041942

4194141943
SDValue X86TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
@@ -41967,7 +41969,7 @@ SDValue X86TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
4196741969
SmallVector<SDValue, 4> NewOps(Op.getNumOperands(), SDValue());
4196841970
for (int i = 0; i != 3; ++i) {
4196941971
char V = isNegatibleForFree(Op.getOperand(i), DAG, LegalOperations,
41970-
ForCodeSize, Depth + 1);
41972+
ForCodeSize, false, Depth + 1);
4197141973
if (V == 2)
4197241974
NewOps[i] = getNegatedExpression(Op.getOperand(i), DAG, LegalOperations,
4197341975
ForCodeSize, Depth + 1);

llvm/lib/Target/X86/X86ISelLowering.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,8 @@ namespace llvm {
809809
/// for the same cost as the expression itself, or 2 if we can compute the
810810
/// negated form more cheaply than the expression itself. Else return 0.
811811
char isNegatibleForFree(SDValue Op, SelectionDAG &DAG, bool LegalOperations,
812-
bool ForCodeSize, unsigned Depth) const override;
812+
bool ForCodeSize, bool EnableUseCheck,
813+
unsigned Depth) const override;
813814

814815
/// If isNegatibleForFree returns true, return the newly negated expression.
815816
SDValue getNegatedExpression(SDValue Op, SelectionDAG &DAG,

llvm/test/CodeGen/AArch64/arm64-fmadd.ll

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,5 +88,23 @@ entry:
8888
ret double %0
8989
}
9090

91+
; This would crash while trying getNegatedExpression().
92+
93+
define float @negated_constant(float %x) {
94+
; CHECK-LABEL: negated_constant:
95+
; CHECK: // %bb.0:
96+
; CHECK-NEXT: mov w8, #-1037565952
97+
; CHECK-NEXT: mov w9, #1109917696
98+
; CHECK-NEXT: fmov s1, w8
99+
; CHECK-NEXT: fmul s1, s0, s1
100+
; CHECK-NEXT: fmov s2, w9
101+
; CHECK-NEXT: fmadd s0, s0, s2, s1
102+
; CHECK-NEXT: ret
103+
%m = fmul float %x, 42.0
104+
%fma = call nsz float @llvm.fma.f32(float %x, float -42.0, float %m)
105+
%nfma = fneg float %fma
106+
ret float %nfma
107+
}
108+
91109
declare float @llvm.fma.f32(float, float, float) nounwind readnone
92110
declare double @llvm.fma.f64(double, double, double) nounwind readnone

llvm/test/CodeGen/X86/fma-fneg-combine-2.ll

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,24 @@ entry:
8686
ret float %1
8787
}
8888

89+
; This would crash while trying getNegatedExpression().
90+
91+
define float @negated_constant(float %x) {
92+
; FMA3-LABEL: negated_constant:
93+
; FMA3: # %bb.0:
94+
; FMA3-NEXT: vmulss {{.*}}(%rip), %xmm0, %xmm1
95+
; FMA3-NEXT: vfmadd132ss {{.*#+}} xmm0 = (xmm0 * mem) + xmm1
96+
; FMA3-NEXT: retq
97+
;
98+
; FMA4-LABEL: negated_constant:
99+
; FMA4: # %bb.0:
100+
; FMA4-NEXT: vmulss {{.*}}(%rip), %xmm0, %xmm1
101+
; FMA4-NEXT: vfmaddss %xmm1, {{.*}}(%rip), %xmm0, %xmm0
102+
; FMA4-NEXT: retq
103+
%m = fmul float %x, 42.0
104+
%fma = call nsz float @llvm.fma.f32(float %x, float -42.0, float %m)
105+
%nfma = fneg float %fma
106+
ret float %nfma
107+
}
108+
89109
declare float @llvm.fma.f32(float, float, float)

0 commit comments

Comments
 (0)