diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index cf94d28100488..f170d8ba2512a 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -6381,6 +6381,107 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) { return new ICmpInst(CmpInst::ICMP_SLT, X, Constant::getNullValue(SrcTy)); } +Instruction *InstCombinerImpl::foldICmpWithSextAndAdd(ICmpInst &ICmp) { + Value *X; + Value *Y, *Z; + // Match the pattern: icmp ult (add (sext X), Y), Z + // where X is a value, Y and Z are integer constants + // icmp ult (add(sext(X), Y)), Z -> icmp ult (add(X, Y)), Z + if (match(&ICmp, m_SpecificICmp(CmpInst::ICMP_ULT, + m_Add(m_SExt(m_Value(X)), m_Value(Y)), + m_Value(Z)))) { + Type *XType = X->getType(); + if (!XType->isIntegerTy()) + return nullptr; + + unsigned XBitWidth = XType->getIntegerBitWidth(); + auto ExtractValue = [&](Value *V, Type *TargetType, + int64_t &OutValue) -> Value * { + if (auto *C = dyn_cast(V)) { + OutValue = C->getSExtValue(); + return ConstantInt::get(TargetType, OutValue); + } + if (match(V, m_SExt(m_Value()))) { + Value *Src = cast(V)->getOperand(0); + if (Src->getType() == TargetType) + return Src; + } + return nullptr; + }; + + int64_t YValue, ZValue; + Value *NewY = ExtractValue(Y, XType, YValue); + Value *NewZ = ExtractValue(Z, XType, ZValue); + if (!NewY || !NewZ) + return nullptr; + + bool IsYVariable = match(Y, m_SExt(m_Value())); + bool IsYOrZVariable = + match(Y, m_SExt(m_Value())) || match(Z, m_SExt(m_Value())); + if (IsYOrZVariable) { + bool FoundCondition = false; + BasicBlock *BB = ICmp.getParent(); + for (Instruction &I : *BB) { + if (auto *Assume = dyn_cast(&I)) { + if (Assume->getIntrinsicID() == Intrinsic::assume) { + Value *Cond = Assume->getOperand(0); + ConstantInt *C, *C2; + unsigned CommonWidth = Z->getType()->getIntegerBitWidth(); + APInt MaxValue(CommonWidth, 1); + MaxValue <<= (XBitWidth - 1); + APInt Limit = MaxValue; + Limit += 1; + // if both "Y" and "Z" are variables, Op would be "Sub" + if (match(Cond, m_SpecificICmp( + CmpInst::ICMP_ULT, + m_Sub(m_SExt(m_Value(Y)), m_SExt(m_Value(Z))), + m_ConstantInt(C))) || + match(Cond, + m_SpecificICmp(CmpInst::ICMP_ULT, + m_Add(m_SExt(m_Value(Y)), m_Value(Z)), + m_ConstantInt(C)))) { + if (C->getValue().ugt(Limit)) + return nullptr; + + FoundCondition = true; + if (Assume->use_empty()) { + eraseInstFromFunction(*Assume); + } + break; + } + if (!IsYVariable && + match(Cond, m_SpecificICmp(CmpInst::ICMP_ULT, + m_Add(m_Value(Z), m_ConstantInt(C2)), + m_ConstantInt(C)))) { + auto *K = llvm::dyn_cast(Y); + if ((C2->getValue().sextOrTrunc(CommonWidth) - + C->getValue().sextOrTrunc(CommonWidth)) + .slt(K->getValue() - Limit)) + return nullptr; + + FoundCondition = true; + if (Assume->use_empty()) { + eraseInstFromFunction(*Assume); + } + break; + } + } + } + } + if (!FoundCondition) + return nullptr; + } else { + auto MaxValue = (1LL << (XBitWidth - 1)); + if (YValue - ZValue > MaxValue || YValue - ZValue < -MaxValue) + return nullptr; + } + + Value *NewAdd = Builder.CreateAdd(X, NewY); + return new ICmpInst(CmpInst::ICMP_ULT, NewAdd, NewZ); + } + return nullptr; +} + /// Handle icmp (cast x), (cast or constant). Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) { // If any operand of ICmp is a inttoptr roundtrip cast then remove it as @@ -7727,6 +7828,10 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { } } + if (Instruction *Res = foldICmpWithSextAndAdd(I)) { + return Res; + } + // In case of a comparison with two select instructions having the same // condition, check whether one of the resulting branches can be simplified. // If so, just compare the other branch and select the appropriate result. diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index c67e27e5b3e7c..ed652720f0dac 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -744,6 +744,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final const APInt &C); Instruction *foldICmpTruncWithTruncOrExt(ICmpInst &Cmp, const SimplifyQuery &Q); + Instruction *foldICmpWithSextAndAdd(ICmpInst &ICmp); Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And, const APInt &C); Instruction *foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor, diff --git a/llvm/test/Transforms/InstCombine/icmp-add.ll b/llvm/test/Transforms/InstCombine/icmp-add.ll index 1a41c1f3e1045..df85f5a3caa65 100644 --- a/llvm/test/Transforms/InstCombine/icmp-add.ll +++ b/llvm/test/Transforms/InstCombine/icmp-add.ll @@ -3259,9 +3259,8 @@ define i1 @zext_range_check_mergable(i8 %x) { define i1 @sext_range_check_ult(i8 %x) { ; CHECK-LABEL: @sext_range_check_ult( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[CONV:%.*]] = sext i8 [[X:%.*]] to i32 -; CHECK-NEXT: [[ADD:%.*]] = add nsw i32 [[CONV]], -4 -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[ADD]], 3 +; CHECK-NEXT: [[TMP0:%.*]] = add i8 [[X:%.*]], -4 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[TMP0]], 3 ; CHECK-NEXT: ret i1 [[CMP]] ; entry: diff --git a/llvm/test/Transforms/InstCombine/sext_for_add_cmp.ll b/llvm/test/Transforms/InstCombine/sext_for_add_cmp.ll new file mode 100644 index 0000000000000..b8617df750abb --- /dev/null +++ b/llvm/test/Transforms/InstCombine/sext_for_add_cmp.ll @@ -0,0 +1,16 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt < %s -passes=instcombine -S | FileCheck %s +define i1 @sext_add_cmp_to_add_cmp(i16 noundef %x) { +; CHECK-LABEL: define i1 @sext_add_cmp_to_add_cmp( +; CHECK-SAME: i16 noundef [[X:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = add i16 [[X]], 128 +; CHECK-NEXT: [[OR_COND_I:%.*]] = icmp ult i16 [[TMP0]], 256 +; CHECK-NEXT: ret i1 [[OR_COND_I]] +; +entry: + %conv = sext i16 %x to i128 + %add = add i128 %conv, 128 + %or.cond.i = icmp ult i128 %add, 256 + ret i1 %or.cond.i +}