-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[InstCombine] optimize unnecessary sext instruction with add + cmp #152291
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-llvm-transforms Author: Gaurav Dhingra (gxyd) ChangesFixes #130088 Full diff: https://github.com/llvm/llvm-project/pull/152291.diff 4 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index cf94d28100488..ae329e9aaec4c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6381,6 +6381,36 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) {
return new ICmpInst(CmpInst::ICMP_SLT, X, Constant::getNullValue(SrcTy));
}
+Instruction *InstCombinerImpl::foldICmpWithSextAndAdd(ICmpInst &ICmp) {
+ Value *X;
+ ConstantInt *Y, *Z;
+ // Match the pattern: icmp ult (add (sext X), Y), Z
+ // where X is a value, Y and Z are integer constants
+ // icmp ult (add(sext(X), Y)), Z -> icmp ult (add(X, Y)), Z
+ if (match(&ICmp, m_SpecificICmp(CmpInst::ICMP_ULT,
+ m_Add(m_SExt(m_Value(X)), m_ConstantInt(Y)),
+ m_ConstantInt(Z)))) {
+ Type *XType = X->getType();
+ if (!XType->isIntegerTy())
+ return nullptr;
+
+ unsigned XBitWidth = XType->getIntegerBitWidth();
+ auto YValue = Y->getSExtValue();
+ auto ZValue = Z->getSExtValue();
+
+ auto MinValue = -(1LL << (XBitWidth - 1));
+ auto MaxValue = (1LL << (XBitWidth - 1)) - 1;
+
+ // // Check if Y and Z fit within X's type without wrapping
+ if (YValue < MinValue || YValue > MaxValue || ZValue < MinValue || ZValue > MaxValue)
+ return nullptr; // Cannot optimize if Y or Z would wrap in X's type
+
+ Value *NewAdd = Builder.CreateAdd(X, ConstantInt::get(XType, YValue));
+ return new ICmpInst(CmpInst::ICMP_ULT, NewAdd, ConstantInt::get(XType, ZValue));
+ }
+ return nullptr;
+}
+
/// Handle icmp (cast x), (cast or constant).
Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) {
// If any operand of ICmp is a inttoptr roundtrip cast then remove it as
@@ -7727,6 +7757,10 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
}
}
+ if (Instruction *Res = foldICmpWithSextAndAdd(I)) {
+ return Res;
+ }
+
// In case of a comparison with two select instructions having the same
// condition, check whether one of the resulting branches can be simplified.
// If so, just compare the other branch and select the appropriate result.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index c67e27e5b3e7c..ed652720f0dac 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -744,6 +744,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
const APInt &C);
Instruction *foldICmpTruncWithTruncOrExt(ICmpInst &Cmp,
const SimplifyQuery &Q);
+ Instruction *foldICmpWithSextAndAdd(ICmpInst &ICmp);
Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And,
const APInt &C);
Instruction *foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor,
diff --git a/llvm/test/Transforms/InstCombine/icmp-add.ll b/llvm/test/Transforms/InstCombine/icmp-add.ll
index 1a41c1f3e1045..df85f5a3caa65 100644
--- a/llvm/test/Transforms/InstCombine/icmp-add.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-add.ll
@@ -3259,9 +3259,8 @@ define i1 @zext_range_check_mergable(i8 %x) {
define i1 @sext_range_check_ult(i8 %x) {
; CHECK-LABEL: @sext_range_check_ult(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[CONV:%.*]] = sext i8 [[X:%.*]] to i32
-; CHECK-NEXT: [[ADD:%.*]] = add nsw i32 [[CONV]], -4
-; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[ADD]], 3
+; CHECK-NEXT: [[TMP0:%.*]] = add i8 [[X:%.*]], -4
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[TMP0]], 3
; CHECK-NEXT: ret i1 [[CMP]]
;
entry:
diff --git a/llvm/test/Transforms/InstCombine/sext_for_add_cmp.ll b/llvm/test/Transforms/InstCombine/sext_for_add_cmp.ll
new file mode 100644
index 0000000000000..b8617df750abb
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/sext_for_add_cmp.ll
@@ -0,0 +1,16 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+define i1 @sext_add_cmp_to_add_cmp(i16 noundef %x) {
+; CHECK-LABEL: define i1 @sext_add_cmp_to_add_cmp(
+; CHECK-SAME: i16 noundef [[X:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = add i16 [[X]], 128
+; CHECK-NEXT: [[OR_COND_I:%.*]] = icmp ult i16 [[TMP0]], 256
+; CHECK-NEXT: ret i1 [[OR_COND_I]]
+;
+entry:
+ %conv = sext i16 %x to i128
+ %add = add i128 %conv, 128
+ %or.cond.i = icmp ult i128 %add, 256
+ ret i1 %or.cond.i
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please see llvm.org/docs/InstCombineContributorGuide.html. Please add an alive2 proof to the PR description.
m_Add(m_SExt(m_Value(X)), m_ConstantInt(Y)), | ||
m_ConstantInt(Z)))) { | ||
Type *XType = X->getType(); | ||
if (!XType->isIntegerTy()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
X must be an integer or vector of integers. We should handle both. Please remove this check and add a vector test.
// where X is a value, Y and Z are integer constants | ||
// icmp ult (add(sext(X), Y)), Z -> icmp ult (add(X, Y)), Z | ||
if (match(&ICmp, m_SpecificICmp(CmpInst::ICMP_ULT, | ||
m_Add(m_SExt(m_Value(X)), m_ConstantInt(Y)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
m_ConstantInt -> m_APInt and then work on APInt instead of int64_t returned by getSExtValue().
I've actually left a comment on the linked issue: #130088 (comment), where I need help with the issue understanding currently. |
5efe0e5
to
dde501e
Compare
dde501e
to
7d93fd6
Compare
You can test this locally with the following command:git-clang-format --diff HEAD~1 HEAD --extensions cpp,h -- llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp llvm/lib/Transforms/InstCombine/InstCombineInternal.h View the diff from clang-format here.diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 3bfa1a22c..55cea92eb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6395,7 +6395,8 @@ Instruction *InstCombinerImpl::foldICmpWithSextAndAdd(ICmpInst &ICmp) {
return nullptr;
unsigned XBitWidth = XType->getIntegerBitWidth();
- auto ExtractValue = [&](Value *V, Type *TargetType, int64_t &OutValue) -> Value* {
+ auto ExtractValue = [&](Value *V, Type *TargetType,
+ int64_t &OutValue) -> Value * {
if (auto *C = dyn_cast<ConstantInt>(V)) {
OutValue = C->getSExtValue();
return ConstantInt::get(TargetType, OutValue);
@@ -6414,7 +6415,8 @@ Instruction *InstCombinerImpl::foldICmpWithSextAndAdd(ICmpInst &ICmp) {
if (!NewY || !NewZ)
return nullptr;
- bool AreYZVariables = match(Y, m_SExt(m_Value())) && match(Z, m_SExt(m_Value()));
+ bool AreYZVariables =
+ match(Y, m_SExt(m_Value())) && match(Z, m_SExt(m_Value()));
if (AreYZVariables) {
bool FoundCondition = false;
BasicBlock *BB = ICmp.getParent();
@@ -6427,9 +6429,11 @@ Instruction *InstCombinerImpl::foldICmpWithSextAndAdd(ICmpInst &ICmp) {
MaxValue <<= (XBitWidth - 1);
APInt Limit = MaxValue;
Limit += 1;
- if (match(Cond, m_SpecificICmp(CmpInst::ICMP_ULT,
- m_Sub(m_SExt(m_Value(Y)), m_SExt(m_Value(Z))), m_ConstantInt(C))) &&
- C->getValue().ule(Limit)) {
+ if (match(Cond, m_SpecificICmp(
+ CmpInst::ICMP_ULT,
+ m_Sub(m_SExt(m_Value(Y)), m_SExt(m_Value(Z))),
+ m_ConstantInt(C))) &&
+ C->getValue().ule(Limit)) {
FoundCondition = true;
if (Assume->use_empty()) {
eraseInstFromFunction(*Assume);
|
Yes, now I've added a link to proof on alive2 platform as well. |
auto ExtractValue = [&](Value *V, Type *TargetType, | ||
int64_t &OutValue) -> Value * { | ||
if (auto *C = dyn_cast<ConstantInt>(V)) { | ||
OutValue = C->getSExtValue(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nikic , can you kindly help me how to avoid using int64_t
here? (You suggested in another comment that I should use APInt instead)
I've tried using:
OutValue = C->getValue().sextOrTrunc(TargetType->getIntegerBitWidth());
but I'm not sure if it's an approach in the right direction.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dtcxzyw maybe you can help me out with a possible solution to this problem? Thank you.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use the m_APInt
matcher.
APInt &OutValue)
const APInt *C;
if (match(V, m_APInt(C))) {
OutValue = C->getValue().sextOrTrunc(TargetType->getScalarSizeInBits());
}
BTW, you can use m_APInt
in the matching code above (i.e., match(&ICmp, ...)
), as we don't want to handle non-constant cases in the original issue. Y and Z must be constants.
Then you can simply check if Y
and Z
have less than X->getType()->getScalarSizeInBits()
significant bits.
const APInt *Y;
const APInt *Z;
if (match(&ICmp, m_SpecificICmp(CmpInst::ICMP_ULT,
m_Add(m_SExt(m_Value(X)), m_APInt(Y)),
m_APInt(Z))) && Y->isSignedIntN(X->getType()->getScalarSizeInBits()) && Z->isSignedIntN(X->getType()->getScalarSizeInBits())) {
...
}
return nullptr; | ||
} else { | ||
auto MaxValue = (1LL << (XBitWidth - 1)); | ||
if (YValue - ZValue > MaxValue || YValue - ZValue < -MaxValue) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I should avoid using int64_t
here as well and instead use APInt
(should the bitwidth be Z->getType()->getIntegerBitWidth()
or should it be XBitWidth
?
Fixes #130088
Godbolt proof: https://godbolt.org/z/8rxGd1q73
Alive2 proof: https://alive2.llvm.org/ce/z/KA0Evi