-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[InstCombine] Fold switch (cond ? A + C1 : C2)
to switch (cond ? A : C2 - C1)
#156816
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Remove redundant binary operation in a switch's select condition while adjusting the switch's arms accordingly.
@llvm/pr-subscribers-llvm-transforms Author: Veera (veera-sivarajan) ChangesFixes #152626 Remove redundant binary operation in a switch's select condition while adjusting the switch's arms accordingly. Full diff: https://github.com/llvm/llvm-project/pull/156816.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 5ee3bb1abe86e..420fc25c18093 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -4226,6 +4226,73 @@ static Value *simplifySwitchOnSelectUsingRanges(SwitchInst &SI,
return X;
}
+/// Remove redundant binary operation in a switch's select condition.
+///
+/// For example:
+/// ```
+/// switch (cond ? A + C1 : C2) {
+/// C3 => {}
+/// C4 => {}
+/// C5 => {}
+/// }
+/// ```
+///
+/// is turned into:
+/// ```
+/// switch (cond ? A : C2 - C1) {
+/// C3 - C1 => {}
+/// C4 - C1 => {}
+/// C5 - C1 => {}
+/// }
+/// ```
+static Value *removeRedundantBinOp(SwitchInst &SI, SelectInst *Select,
+ InstCombiner::BuilderTy &Builder) {
+ Value *Cond, *TrueVal;
+ Constant *FalseC;
+ if (!match(Select, m_OneUse(m_Select(m_Value(Cond), m_Value(TrueVal),
+ m_Constant(FalseC)))))
+ return nullptr;
+
+ auto *BinOp = dyn_cast<BinaryOperator>(TrueVal);
+ if (!BinOp || !BinOp->hasOneUse())
+ return nullptr;
+
+ Value *BinOpLHS;
+ Constant *BinOpRHS;
+ if (!match(BinOp, m_BinOp(m_Value(BinOpLHS), m_Constant(BinOpRHS))))
+ return nullptr;
+
+ auto getInverseBinaryOp = [](Instruction::BinaryOps BinOp)
+ -> std::optional<Instruction::BinaryOps> {
+ switch (BinOp) {
+ default:
+ return std::nullopt;
+ case Instruction::Add:
+ return Instruction::Sub;
+ case Instruction::Sub:
+ return Instruction::Add;
+ }
+ };
+
+ auto InverseBinOp = getInverseBinaryOp(BinOp->getOpcode());
+ if (!InverseBinOp)
+ return nullptr;
+
+ const DataLayout &DL = SI.getDataLayout();
+ Constant *NewFalseVal =
+ ConstantFoldBinaryOpOperands(*InverseBinOp, FalseC, BinOpRHS, DL);
+ if (!NewFalseVal)
+ return nullptr;
+
+ auto *NewSelect = Builder.CreateSelect(Cond, BinOpLHS, NewFalseVal);
+ for (auto &OrigCase : SI.cases()) {
+ auto *NewCase = ConstantFoldBinaryOpOperands(
+ *InverseBinOp, OrigCase.getCaseValue(), BinOpRHS, DL);
+ OrigCase.setValue(cast<ConstantInt>(NewCase));
+ }
+ return NewSelect;
+}
+
Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
Value *Cond = SI.getCondition();
Value *Op0;
@@ -4299,14 +4366,17 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
}
}
- // Fold switch(select cond, X, Y) into switch(X/Y) if possible
if (auto *Select = dyn_cast<SelectInst>(Cond)) {
+ // Fold switch(select cond, X, Y) into switch(X/Y) if possible
if (Value *V =
simplifySwitchOnSelectUsingRanges(SI, Select, /*IsTrueArm=*/true))
return replaceOperand(SI, 0, V);
if (Value *V =
simplifySwitchOnSelectUsingRanges(SI, Select, /*IsTrueArm=*/false))
return replaceOperand(SI, 0, V);
+
+ if (Value *V = removeRedundantBinOp(SI, Select, Builder))
+ return replaceOperand(SI, 0, V);
}
KnownBits Known = computeKnownBits(Cond, &SI);
diff --git a/llvm/test/Transforms/InstCombine/switch-select.ll b/llvm/test/Transforms/InstCombine/switch-select.ll
index 60757c5d22527..2ad130fff8342 100644
--- a/llvm/test/Transforms/InstCombine/switch-select.ll
+++ b/llvm/test/Transforms/InstCombine/switch-select.ll
@@ -75,13 +75,12 @@ bb3:
define void @test_ult_rhsc_invalid_cond(i8 %x, i8 %y) {
; CHECK-LABEL: define void @test_ult_rhsc_invalid_cond(
; CHECK-SAME: i8 [[X:%.*]], i8 [[Y:%.*]]) {
-; CHECK-NEXT: [[VAL:%.*]] = add nsw i8 [[X]], -2
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[Y]], 11
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
+; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[X]], i8 8
; CHECK-NEXT: switch i8 [[COND]], label [[BB1:%.*]] [
-; CHECK-NEXT: i8 0, label [[BB2:%.*]]
-; CHECK-NEXT: i8 10, label [[BB3:%.*]]
-; CHECK-NEXT: i8 13, label [[BB3]]
+; CHECK-NEXT: i8 2, label [[BB2:%.*]]
+; CHECK-NEXT: i8 12, label [[BB3:%.*]]
+; CHECK-NEXT: i8 15, label [[BB3]]
; CHECK-NEXT: ]
; CHECK: bb1:
; CHECK-NEXT: call void @func1()
@@ -154,6 +153,335 @@ bb3:
unreachable
}
+define void @remove_redundant_binop_add(i8 %lhs, i8 %value) {
+; CHECK-LABEL: define void @remove_redundant_binop_add(
+; CHECK-SAME: i8 [[LHS:%.*]], i8 [[VALUE:%.*]]) {
+; CHECK-NEXT: [[COND_INV:%.*]] = icmp ult i8 [[LHS]], 2
+; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND_INV]], i8 2, i8 [[VALUE]]
+; CHECK-NEXT: switch i8 [[TMP1]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 2, label [[BB1]]
+; CHECK-NEXT: i8 3, label [[BB2:%.*]]
+; CHECK-NEXT: i8 4, label [[BB3:%.*]]
+; CHECK-NEXT: ]
+; CHECK: bb1:
+; CHECK-NEXT: call void @func1()
+; CHECK-NEXT: unreachable
+; CHECK: bb2:
+; CHECK-NEXT: call void @func2()
+; CHECK-NEXT: unreachable
+; CHECK: bb3:
+; CHECK-NEXT: call void @func3()
+; CHECK-NEXT: unreachable
+;
+ %cond = icmp samesign ugt i8 %lhs, 1
+ %binop = add nuw nsw i8 %value, 25
+ %select = select i1 %cond, i8 %binop, i8 27
+ switch i8 %select, label %bb1 [
+ i8 27, label %bb1
+ i8 28, label %bb2
+ i8 29, label %bb3
+ ]
+
+bb1:
+ call void @func1()
+ unreachable
+bb2:
+ call void @func2()
+ unreachable
+bb3:
+ call void @func3()
+ unreachable
+}
+
+define void @remove_redundant_binop_sub(i8 %lhs, i8 %value) {
+; CHECK-LABEL: define void @remove_redundant_binop_sub(
+; CHECK-SAME: i8 [[LHS:%.*]], i8 [[VALUE:%.*]]) {
+; CHECK-NEXT: [[COND:%.*]] = icmp samesign ugt i8 [[LHS]], 1
+; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND]], i8 [[VALUE]], i8 0
+; CHECK-NEXT: switch i8 [[TMP1]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 0, label [[BB1]]
+; CHECK-NEXT: i8 1, label [[BB2:%.*]]
+; CHECK-NEXT: i8 2, label [[BB3:%.*]]
+; CHECK-NEXT: ]
+; CHECK: bb1:
+; CHECK-NEXT: call void @func1()
+; CHECK-NEXT: unreachable
+; CHECK: bb2:
+; CHECK-NEXT: call void @func2()
+; CHECK-NEXT: unreachable
+; CHECK: bb3:
+; CHECK-NEXT: call void @func3()
+; CHECK-NEXT: unreachable
+;
+ %cond = icmp samesign ugt i8 %lhs, 1
+ %binop = sub nuw nsw i8 %value, 10
+ %select = select i1 %cond, i8 %binop, i8 -10
+ switch i8 %select, label %bb1 [
+ i8 -10, label %bb1
+ i8 -9, label %bb2
+ i8 -8, label %bb3
+ ]
+
+bb1:
+ call void @func1()
+ unreachable
+bb2:
+ call void @func2()
+ unreachable
+bb3:
+ call void @func3()
+ unreachable
+}
+
+define void @remove_redundant_binop_multiuse_binop_negative(i8 %lhs, i8 %value) {
+; CHECK-LABEL: define void @remove_redundant_binop_multiuse_binop_negative(
+; CHECK-SAME: i8 [[LHS:%.*]], i8 [[VALUE:%.*]]) {
+; CHECK-NEXT: [[COND:%.*]] = icmp samesign ugt i8 [[LHS]], 1
+; CHECK-NEXT: [[BINOP:%.*]] = add nsw i8 [[VALUE]], -10
+; CHECK-NEXT: call void @use(i8 [[BINOP]])
+; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND]], i8 [[BINOP]], i8 -10
+; CHECK-NEXT: switch i8 [[TMP1]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 -10, label [[BB1]]
+; CHECK-NEXT: i8 -9, label [[BB2:%.*]]
+; CHECK-NEXT: i8 -8, label [[BB3:%.*]]
+; CHECK-NEXT: ]
+; CHECK: bb1:
+; CHECK-NEXT: call void @func1()
+; CHECK-NEXT: unreachable
+; CHECK: bb2:
+; CHECK-NEXT: call void @func2()
+; CHECK-NEXT: unreachable
+; CHECK: bb3:
+; CHECK-NEXT: call void @func3()
+; CHECK-NEXT: unreachable
+;
+ %cond = icmp samesign ugt i8 %lhs, 1
+ %binop = sub nuw nsw i8 %value, 10
+ call void @use(i8 %binop)
+ %select = select i1 %cond, i8 %binop, i8 -10
+ switch i8 %select, label %bb1 [
+ i8 -10, label %bb1
+ i8 -9, label %bb2
+ i8 -8, label %bb3
+ ]
+
+bb1:
+ call void @func1()
+ unreachable
+bb2:
+ call void @func2()
+ unreachable
+bb3:
+ call void @func3()
+ unreachable
+}
+
+define void @remove_redundant_binop_unknown_binop_negative(i8 %lhs, i8 %value) {
+; CHECK-LABEL: define void @remove_redundant_binop_unknown_binop_negative(
+; CHECK-SAME: i8 [[LHS:%.*]], i8 [[VALUE:%.*]]) {
+; CHECK-NEXT: [[COND:%.*]] = icmp samesign ugt i8 [[LHS]], 1
+; CHECK-NEXT: [[BINOP:%.*]] = mul nuw nsw i8 [[VALUE]], 10
+; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[COND]], i8 [[BINOP]], i8 -10
+; CHECK-NEXT: switch i8 [[SELECT]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 -10, label [[BB1]]
+; CHECK-NEXT: i8 -9, label [[BB2:%.*]]
+; CHECK-NEXT: i8 -8, label [[BB3:%.*]]
+; CHECK-NEXT: ]
+; CHECK: bb1:
+; CHECK-NEXT: call void @func1()
+; CHECK-NEXT: unreachable
+; CHECK: bb2:
+; CHECK-NEXT: call void @func2()
+; CHECK-NEXT: unreachable
+; CHECK: bb3:
+; CHECK-NEXT: call void @func3()
+; CHECK-NEXT: unreachable
+;
+ %cond = icmp samesign ugt i8 %lhs, 1
+ %binop = mul nuw nsw i8 %value, 10
+ %select = select i1 %cond, i8 %binop, i8 -10
+ switch i8 %select, label %bb1 [
+ i8 -10, label %bb1
+ i8 -9, label %bb2
+ i8 -8, label %bb3
+ ]
+
+bb1:
+ call void @func1()
+ unreachable
+bb2:
+ call void @func2()
+ unreachable
+bb3:
+ call void @func3()
+ unreachable
+}
+
+define void @remove_redundant_binop_no_binop_negative(i8 %lhs, i8 %value) {
+; CHECK-LABEL: define void @remove_redundant_binop_no_binop_negative(
+; CHECK-SAME: i8 [[LHS:%.*]], i8 [[VALUE:%.*]]) {
+; CHECK-NEXT: [[COND:%.*]] = icmp samesign ugt i8 [[LHS]], 1
+; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[COND]], i8 [[VALUE]], i8 -10
+; CHECK-NEXT: switch i8 [[SELECT]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 -10, label [[BB1]]
+; CHECK-NEXT: i8 -9, label [[BB2:%.*]]
+; CHECK-NEXT: i8 -8, label [[BB3:%.*]]
+; CHECK-NEXT: ]
+; CHECK: bb1:
+; CHECK-NEXT: call void @func1()
+; CHECK-NEXT: unreachable
+; CHECK: bb2:
+; CHECK-NEXT: call void @func2()
+; CHECK-NEXT: unreachable
+; CHECK: bb3:
+; CHECK-NEXT: call void @func3()
+; CHECK-NEXT: unreachable
+;
+ %cond = icmp samesign ugt i8 %lhs, 1
+ %select = select i1 %cond, i8 %value, i8 -10
+ switch i8 %select, label %bb1 [
+ i8 -10, label %bb1
+ i8 -9, label %bb2
+ i8 -8, label %bb3
+ ]
+
+bb1:
+ call void @func1()
+ unreachable
+bb2:
+ call void @func2()
+ unreachable
+bb3:
+ call void @func3()
+ unreachable
+}
+
+define void @remove_redundant_binop_multiuse_select_negative(i8 %lhs, i8 %value) {
+; CHECK-LABEL: define void @remove_redundant_binop_multiuse_select_negative(
+; CHECK-SAME: i8 [[LHS:%.*]], i8 [[VALUE:%.*]]) {
+; CHECK-NEXT: [[COND:%.*]] = icmp samesign ugt i8 [[LHS]], 1
+; CHECK-NEXT: [[BINOP:%.*]] = add nsw i8 [[VALUE]], -10
+; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[COND]], i8 [[BINOP]], i8 -10
+; CHECK-NEXT: call void @use(i8 [[SELECT]])
+; CHECK-NEXT: switch i8 [[SELECT]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 -10, label [[BB1]]
+; CHECK-NEXT: i8 -9, label [[BB2:%.*]]
+; CHECK-NEXT: i8 -8, label [[BB3:%.*]]
+; CHECK-NEXT: ]
+; CHECK: bb1:
+; CHECK-NEXT: call void @func1()
+; CHECK-NEXT: unreachable
+; CHECK: bb2:
+; CHECK-NEXT: call void @func2()
+; CHECK-NEXT: unreachable
+; CHECK: bb3:
+; CHECK-NEXT: call void @func3()
+; CHECK-NEXT: unreachable
+;
+ %cond = icmp samesign ugt i8 %lhs, 1
+ %binop = sub nuw nsw i8 %value, 10
+ %select = select i1 %cond, i8 %binop, i8 -10
+ call void @use(i8 %select)
+ switch i8 %select, label %bb1 [
+ i8 -10, label %bb1
+ i8 -9, label %bb2
+ i8 -8, label %bb3
+ ]
+
+bb1:
+ call void @func1()
+ unreachable
+bb2:
+ call void @func2()
+ unreachable
+bb3:
+ call void @func3()
+ unreachable
+}
+
+define void @redundant_binop_variable_binop_negative(i8 %lhs, i8 %value, i8 %rhs) {
+; CHECK-LABEL: define void @redundant_binop_variable_binop_negative(
+; CHECK-SAME: i8 [[LHS:%.*]], i8 [[VALUE:%.*]], i8 [[RHS:%.*]]) {
+; CHECK-NEXT: [[COND:%.*]] = icmp samesign ugt i8 [[LHS]], 1
+; CHECK-NEXT: [[BINOP:%.*]] = sub nuw nsw i8 [[VALUE]], [[RHS]]
+; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[COND]], i8 [[BINOP]], i8 -10
+; CHECK-NEXT: switch i8 [[SELECT]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 -10, label [[BB1]]
+; CHECK-NEXT: i8 -9, label [[BB2:%.*]]
+; CHECK-NEXT: i8 -8, label [[BB3:%.*]]
+; CHECK-NEXT: ]
+; CHECK: bb1:
+; CHECK-NEXT: call void @func1()
+; CHECK-NEXT: unreachable
+; CHECK: bb2:
+; CHECK-NEXT: call void @func2()
+; CHECK-NEXT: unreachable
+; CHECK: bb3:
+; CHECK-NEXT: call void @func3()
+; CHECK-NEXT: unreachable
+;
+ %cond = icmp samesign ugt i8 %lhs, 1
+ %binop = sub nuw nsw i8 %value, %rhs
+ %select = select i1 %cond, i8 %binop, i8 -10
+ switch i8 %select, label %bb1 [
+ i8 -10, label %bb1
+ i8 -9, label %bb2
+ i8 -8, label %bb3
+ ]
+
+bb1:
+ call void @func1()
+ unreachable
+bb2:
+ call void @func2()
+ unreachable
+bb3:
+ call void @func3()
+ unreachable
+}
+
+define void @remove_redundant_op_variable_select_negative(i8 %lhs, i8 %value, i8 %falseVal) {
+; CHECK-LABEL: define void @remove_redundant_op_variable_select_negative(
+; CHECK-SAME: i8 [[LHS:%.*]], i8 [[VALUE:%.*]], i8 [[FALSEVAL:%.*]]) {
+; CHECK-NEXT: [[COND:%.*]] = icmp samesign ugt i8 [[LHS]], 1
+; CHECK-NEXT: [[BINOP:%.*]] = add nsw i8 [[VALUE]], -10
+; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[COND]], i8 [[BINOP]], i8 [[FALSEVAL]]
+; CHECK-NEXT: switch i8 [[SELECT]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 -10, label [[BB1]]
+; CHECK-NEXT: i8 -9, label [[BB2:%.*]]
+; CHECK-NEXT: i8 -8, label [[BB3:%.*]]
+; CHECK-NEXT: ]
+; CHECK: bb1:
+; CHECK-NEXT: call void @func1()
+; CHECK-NEXT: unreachable
+; CHECK: bb2:
+; CHECK-NEXT: call void @func2()
+; CHECK-NEXT: unreachable
+; CHECK: bb3:
+; CHECK-NEXT: call void @func3()
+; CHECK-NEXT: unreachable
+;
+ %cond = icmp samesign ugt i8 %lhs, 1
+ %binop = sub nuw nsw i8 %value, 10
+ %select = select i1 %cond, i8 %binop, i8 %falseVal
+ switch i8 %select, label %bb1 [
+ i8 -10, label %bb1
+ i8 -9, label %bb2
+ i8 -8, label %bb3
+ ]
+
+bb1:
+ call void @func1()
+ unreachable
+bb2:
+ call void @func2()
+ unreachable
+bb3:
+ call void @func3()
+ unreachable
+}
+
declare void @func1()
declare void @func2()
declare void @func3()
+declare void @use(i8)
|
static Value *removeRedundantBinOp(SwitchInst &SI, SelectInst *Select, | ||
InstCombiner::BuilderTy &Builder) { | ||
Value *Cond, *TrueVal; | ||
Constant *FalseC; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As the condition is a scalar, we can use m_APInt
to avoid constant folding calls (they may return constant expressions).
; | ||
%cond = icmp samesign ugt i8 %lhs, 1 | ||
%binop = add nuw nsw i8 %value, 25 | ||
%select = select i1 %cond, i8 %binop, i8 27 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a commuted test with select cond, 27, binop
? Not sure if it is canonicalized.
for (auto &OrigCase : SI.cases()) { | ||
auto *NewCase = ConstantFoldBinaryOpOperands( | ||
*InverseBinOp, OrigCase.getCaseValue(), BinOpRHS, DL); | ||
OrigCase.setValue(cast<ConstantInt>(NewCase)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See my previous comment. NewCase
may be a constant expr.
https://godbolt.org/z/oqMEzYhrP |
Fixes #152626
Proof: https://alive2.llvm.org/ce/z/LXE5ep
Remove redundant binary operation in a switch's select condition while adjusting the switch's arms accordingly.