Skip to content

Conversation

AZero13
Copy link
Contributor

@AZero13 AZero13 commented Aug 25, 2025

Remove redundant fold regarding CSEL thanks to this too.

@llvmbot
Copy link
Member

llvmbot commented Aug 25, 2025

@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-backend-arm

Author: AZero13 (AZero13)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/155311.diff

3 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+38-3)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp (+31)
  • (modified) llvm/lib/Target/ARM/ARMISelLowering.cpp (+13)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fbd8f7a979d66..3826ad61a5823 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3586,7 +3586,8 @@ static SDValue emitStrictFPComparison(SDValue LHS, SDValue RHS, const SDLoc &DL,
 }
 
 static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
-                              const SDLoc &DL, SelectionDAG &DAG) {
+                              const SDLoc &DL, SelectionDAG &DAG,
+                              bool optimizeMIOrPL = false) {
   EVT VT = LHS.getValueType();
   const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
 
@@ -3630,6 +3631,40 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
       // Use result of ANDS
       return LHS.getValue(1);
     }
+
+    if (LHS.getOpcode() == ISD::SUB) {
+      if (LHS->getFlags().hasNoSignedWrap() ||
+          ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL)) {
+        const SDValue SUBSNode =
+            DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, FlagsVT),
+                        LHS.getOperand(0), LHS.getOperand(1));
+        // Replace all users of (sub X, Y) with newly generated (subs X, Y)
+        DAG.ReplaceAllUsesWith(LHS, SUBSNode);
+        return SUBSNode.getValue(1);
+      }
+    } else if (LHS.getOpcode() == AArch64ISD::SUBS) {
+      if (LHS->getFlags().hasNoSignedWrap() ||
+          ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL)) {
+        return LHS.getValue(1);
+      }
+    }
+
+    if (LHS.getOpcode() == ISD::ADD) {
+      if (LHS->getFlags().hasNoSignedWrap() ||
+          ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL)) {
+        const SDValue ADDSNode =
+            DAG.getNode(AArch64ISD::ADDS, DL, DAG.getVTList(VT, FlagsVT),
+                        LHS.getOperand(0), LHS.getOperand(1));
+        // Replace all users of (add X, Y) with newly generated (adds X, Y)
+        DAG.ReplaceAllUsesWith(LHS, ADDSNode);
+        return ADDSNode.getValue(1);
+      }
+    } else if (LHS.getOpcode() == AArch64ISD::ADDS) {
+      if (LHS->getFlags().hasNoSignedWrap() ||
+          ((CC == ISD::SETLT || CC == ISD::SETGE) && optimizeMIOrPL)) {
+        return LHS.getValue(1);
+      }
+    }
   }
 
   return DAG.getNode(Opcode, DL, DAG.getVTList(VT, FlagsVT), LHS, RHS)
@@ -3843,7 +3878,7 @@ static SDValue emitConjunctionRec(SelectionDAG &DAG, SDValue Val,
 
     // Produce a normal comparison if we are first in the chain
     if (!CCOp)
-      return emitComparison(LHS, RHS, CC, DL, DAG);
+      return emitComparison(LHS, RHS, CC, DL, DAG, isInteger);
     // Otherwise produce a ccmp.
     return emitConditionalComparison(LHS, RHS, CC, CCOp, Predicate, OutCC, DL,
                                      DAG);
@@ -4125,7 +4160,7 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
   }
 
   if (!Cmp) {
-    Cmp = emitComparison(LHS, RHS, CC, DL, DAG);
+    Cmp = emitComparison(LHS, RHS, CC, DL, DAG, true);
     AArch64CC = changeIntCCToAArch64CC(CC, RHS);
   }
   AArch64cc = getCondCode(DAG, AArch64CC);
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 0bceb322726d1..a96323e6534eb 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -4466,6 +4466,7 @@ MachineInstr *AArch64InstructionSelector::emitIntegerCompare(
   // Fold the compare into a cmn or tst if possible.
   if (auto FoldCmp = tryFoldIntegerCompare(LHS, RHS, Predicate, MIRBuilder))
     return FoldCmp;
+
   auto Dst = MRI.cloneVirtualRegister(LHS.getReg());
   return emitSUBS(Dst, LHS, RHS, MIRBuilder);
 }
@@ -5102,6 +5103,36 @@ MachineInstr *AArch64InstructionSelector::tryFoldIntegerCompare(
     return emitCMN(LHSDef->getOperand(2), RHS, MIRBuilder);
   }
 
+  // Given this:
+  //
+  // z = G_SUB/G_ADD x, y
+  // G_ICMP z, 0
+  //
+  // Produce this for eq/ne/slt/sge, or any non-unsigned compare when z is nsw:
+  //
+  // cmp/cmn x, y
+  if ((LHSDef->getFlag(MachineInstr::NoSWrap) && !CmpInst::isUnsigned(P)) ||
+      (P == CmpInst::ICMP_EQ || P == CmpInst::ICMP_NE ||
+       P == CmpInst::ICMP_SLT || P == CmpInst::ICMP_SGE)) {
+
+    if (LHSDef->getOpcode() == TargetOpcode::G_SUB ||
+        LHSDef->getOpcode() == TargetOpcode::G_ADD) {
+      // Make sure that the RHS is 0.
+      auto ValAndVReg = getIConstantVRegValWithLookThrough(RHS.getReg(), MRI);
+      if (!ValAndVReg || ValAndVReg->Value != 0)
+        return nullptr;
+
+      if (LHSDef->getOpcode() == TargetOpcode::G_SUB) {
+        auto Dst = MRI.cloneVirtualRegister(LHS.getReg());
+        return emitSUBS(Dst, LHSDef->getOperand(1), LHSDef->getOperand(2),
+                        MIRBuilder);
+      } else {
+        return emitCMN(LHSDef->getOperand(1), LHSDef->getOperand(2),
+                       MIRBuilder);
+      }
+    }
+  }
+
   // Given this:
   //
   // z = G_AND x, y
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 12d2d678ff63a..072392bc5315b 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -4808,6 +4808,19 @@ SDValue ARMTargetLowering::getARMCmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
     CompareType = ARMISD::CMPZ;
     break;
   }
+
+  // On MI/PL/EQ/NE, or signed conds with an nsw SUB, compare the SUB operands.
+  if (CondCode == ARMCC::MI || CondCode == ARMCC::PL || CondCode == ARMCC::EQ ||
+      CondCode == ARMCC::NE ||
+      (LHS.getFlags().hasNoSignedWrap() &&
+       (CondCode == ARMCC::LT || CondCode == ARMCC::GE ||
+        CondCode == ARMCC::LE || CondCode == ARMCC::GT))) {
+    if (LHS.getOpcode() == ISD::SUB) {
+      ARMcc = DAG.getConstant(CondCode, dl, MVT::i32);
+      return DAG.getNode(CompareType, dl, FlagsVT, LHS.getOperand(0),
+                         LHS.getOperand(1));
+    }
+  }
   ARMcc = DAG.getConstant(CondCode, dl, MVT::i32);
   return DAG.getNode(CompareType, dl, FlagsVT, LHS, RHS);
 }

@AZero13 AZero13 marked this pull request as draft August 25, 2025 21:56
Remove redundant fold regarding CSEL thanks to this too.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants