Skip to content

Conversation

AZero13
Copy link
Contributor

@AZero13 AZero13 commented Aug 23, 2025

This pattern arises when comparing the two bit positions produced by a pair of ctlz/bsr computations: both results are adjusted by the same value (XOR with the bit-width mask during bsr lowering), so the common adjustment can be canceled before the equality compare.

@llvmbot llvmbot added backend:X86 llvm:SelectionDAG SelectionDAGISel as well labels Aug 23, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 23, 2025

@llvm/pr-subscribers-llvm-selectiondag

Author: AZero13 (AZero13)

Changes

This pattern arises when comparing the two bit positions produced by a pair of ctlz/bsr computations: both results are adjusted by the same value (XOR with the bit-width mask during bsr lowering), so the common adjustment can be canceled before the equality compare.


Full diff: https://github.com/llvm/llvm-project/pull/155106.diff

3 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+40)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+41)
  • (modified) llvm/test/CodeGen/X86/ctlz.ll (+61)
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 402a012e8e555..88d52979bb1f0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5686,6 +5686,46 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
     return DAG.getSetCC(dl, VT, N0.getOperand(0), N1.getOperand(0), Cond);
   }
 
+  // Optimize arithmetic comparisons: A op B == A op C -> B == C (for op = -, ^, +)
+  if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
+      N0.getValueType().isInteger() && N1.getValueType().isInteger()) {
+    unsigned Opc0 = N0.getOpcode();
+    unsigned Opc1 = N1.getOpcode();
+    
+    // Check if both operands have the same arithmetic operation
+    if (Opc0 == Opc1 && 
+        (Opc0 == ISD::SUB || Opc0 == ISD::XOR || Opc0 == ISD::ADD)) {
+      SDValue A0 = N0.getOperand(0);
+      SDValue B0 = N0.getOperand(1);
+      SDValue A1 = N1.getOperand(0);
+      SDValue B1 = N1.getOperand(1);
+      
+      // Pattern: (A - B) == (A - C) -> B == C
+      // Pattern: (A ^ B) == (A ^ C) -> B == C  
+      // Pattern: (A + B) == (A + C) -> B == C
+      if (A0 == A1) {
+        return DAG.getSetCC(dl, VT, B0, B1, Cond);
+      }
+      
+      // For commutative operations (XOR, ADD), also try the other combinations
+      if (Opc0 == ISD::XOR || Opc0 == ISD::ADD) {
+        // Pattern: (A ^ B) == (C ^ A) -> B == C
+        // Pattern: (A + B) == (C + A) -> B == C
+        if (A0 == B1) {
+          return DAG.getSetCC(dl, VT, B0, A1, Cond);
+        }
+        // Pattern: (B ^ A) == (A ^ C) -> B == C
+        if (B0 == A1) {
+          return DAG.getSetCC(dl, VT, A0, B1, Cond);
+        }
+        // Pattern: (B ^ A) == (C ^ A) -> B == C
+        if (B0 == B1) {
+          return DAG.getSetCC(dl, VT, A0, A1, Cond);
+        }
+      }
+    }
+  }
+
   // Could not fold it.
   return SDValue();
 }
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 19131fbd4102b..6ecb6268d4637 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -49009,6 +49009,45 @@ static SDValue combineSetCCMOVMSK(SDValue EFLAGS, X86::CondCode &CC,
 static SDValue combineSetCCEFLAGS(SDValue EFLAGS, X86::CondCode &CC,
                                   SelectionDAG &DAG,
                                   const X86Subtarget &Subtarget) {
+  // For EQ/NE, rewrite SUB(xor A,B, xor C,B) -> SUB(A,C) and
+  //            SUB(add A,K, add C,K)       -> SUB(A,C)
+  // Only safe if the arithmetic result (value #0) is unused, since we only
+  // care about flags for EQ/NE. This avoids creating new nodes/VTLists.
+  if ((CC == X86::COND_E || CC == X86::COND_NE) &&
+      EFLAGS.getOpcode() == X86ISD::SUB &&
+      !EFLAGS.getNode()->hasAnyUseOfValue(0)) {
+    using namespace llvm::SDPatternMatch;
+    SDValue LHS = EFLAGS.getOperand(0);
+    SDValue RHS = EFLAGS.getOperand(1);
+    SDValue A, B, C, K;
+
+    bool DidRewrite = false;
+
+    // (xor A, B) - (xor C, B) -> A - C
+    if (!DidRewrite && sd_match(LHS, m_Xor(m_Value(A), m_Value(B))) &&
+        sd_match(RHS, m_Xor(m_Value(C), m_Specific(B)))) {
+      DAG.UpdateNodeOperands(EFLAGS.getNode(), A, C);
+      DidRewrite = true;
+    }
+
+    // (add A, K) - (add C, K) -> A - C
+    if (!DidRewrite && sd_match(LHS, m_Add(m_Value(A), m_Value(K))) &&
+        sd_match(RHS, m_Add(m_Value(C), m_Specific(K)))) {
+      DAG.UpdateNodeOperands(EFLAGS.getNode(), A, C);
+      DidRewrite = true;
+    }
+
+    // (sub A, K) - (sub C, K) -> A - C
+    if (!DidRewrite && sd_match(LHS, m_Sub(m_Value(A), m_Value(K))) &&
+        sd_match(RHS, m_Sub(m_Value(C), m_Specific(K)))) {
+      DAG.UpdateNodeOperands(EFLAGS.getNode(), A, C);
+      DidRewrite = true;
+    }
+
+    if (DidRewrite)
+      return SDValue(EFLAGS.getNode(), 1);
+  }
+
   if (CC == X86::COND_B)
     if (SDValue Flags = combineCarryThroughADD(EFLAGS, DAG))
       return Flags;
@@ -56893,6 +56932,8 @@ static SDValue combineX86SetCC(SDNode *N, SelectionDAG &DAG,
   X86::CondCode CC = X86::CondCode(N->getConstantOperandVal(0));
   SDValue EFLAGS = N->getOperand(1);
 
+
+
   // Try to simplify the EFLAGS and condition code operands.
   if (SDValue Flags = combineSetCCEFLAGS(EFLAGS, CC, DAG, Subtarget))
     return getSETCC(CC, Flags, DL, DAG);
diff --git a/llvm/test/CodeGen/X86/ctlz.ll b/llvm/test/CodeGen/X86/ctlz.ll
index 1267fe9033454..43b654caee9f9 100644
--- a/llvm/test/CodeGen/X86/ctlz.ll
+++ b/llvm/test/CodeGen/X86/ctlz.ll
@@ -1334,3 +1334,64 @@ define i64 @ctlz_i32_zext(i32 %x) {
   %ext = zext i32 %xor to i64
   ret i64 %ext
 }
+
+define i32 @bsr_eq(i32 %a, i32 %b) {
+; X86-LABEL: bsr_eq:
+; X86:       # %bb.0:
+; X86-NEXT:    bsrl {{[0-9]+}}(%esp), %ecx
+; X86-NEXT:    bsrl {{[0-9]+}}(%esp), %edx
+; X86-NEXT:    xorl %eax, %eax
+; X86-NEXT:    cmpl %edx, %ecx
+; X86-NEXT:    sete %al
+; X86-NEXT:    retl
+;
+; X64-LABEL: bsr_eq:
+; X64:       # %bb.0:
+; X64-NEXT:    bsrl %edi, %ecx
+; X64-NEXT:    bsrl %esi, %edx
+; X64-NEXT:    xorl %eax, %eax
+; X64-NEXT:    cmpl %edx, %ecx
+; X64-NEXT:    sete %al
+; X64-NEXT:    retq
+;
+; X86-CLZ-LABEL: bsr_eq:
+; X86-CLZ:       # %bb.0:
+; X86-CLZ-NEXT:    lzcntl {{[0-9]+}}(%esp), %ecx
+; X86-CLZ-NEXT:    lzcntl {{[0-9]+}}(%esp), %edx
+; X86-CLZ-NEXT:    xorl %eax, %eax
+; X86-CLZ-NEXT:    cmpl %edx, %ecx
+; X86-CLZ-NEXT:    sete %al
+; X86-CLZ-NEXT:    retl
+;
+; X64-CLZ-LABEL: bsr_eq:
+; X64-CLZ:       # %bb.0:
+; X64-CLZ-NEXT:    lzcntl %edi, %ecx
+; X64-CLZ-NEXT:    lzcntl %esi, %edx
+; X64-CLZ-NEXT:    xorl %eax, %eax
+; X64-CLZ-NEXT:    cmpl %edx, %ecx
+; X64-CLZ-NEXT:    sete %al
+; X64-CLZ-NEXT:    retq
+;
+; X64-FASTLZCNT-LABEL: bsr_eq:
+; X64-FASTLZCNT:       # %bb.0:
+; X64-FASTLZCNT-NEXT:    lzcntl %edi, %ecx
+; X64-FASTLZCNT-NEXT:    lzcntl %esi, %edx
+; X64-FASTLZCNT-NEXT:    xorl %eax, %eax
+; X64-FASTLZCNT-NEXT:    cmpl %edx, %ecx
+; X64-FASTLZCNT-NEXT:    sete %al
+; X64-FASTLZCNT-NEXT:    retq
+;
+; X86-FASTLZCNT-LABEL: bsr_eq:
+; X86-FASTLZCNT:       # %bb.0:
+; X86-FASTLZCNT-NEXT:    lzcntl {{[0-9]+}}(%esp), %ecx
+; X86-FASTLZCNT-NEXT:    lzcntl {{[0-9]+}}(%esp), %edx
+; X86-FASTLZCNT-NEXT:    xorl %eax, %eax
+; X86-FASTLZCNT-NEXT:    cmpl %edx, %ecx
+; X86-FASTLZCNT-NEXT:    sete %al
+; X86-FASTLZCNT-NEXT:    retl
+  %1 = tail call range(i32 0, 33) i32 @llvm.ctlz.i32(i32 %a, i1 true)
+  %2 = tail call range(i32 0, 33) i32 @llvm.ctlz.i32(i32 %b, i1 true)
+  %cmp = icmp eq i32 %1, %2
+  %conv = zext i1 %cmp to i32
+  ret i32 %conv
+}

@llvmbot
Copy link
Member

llvmbot commented Aug 23, 2025

@llvm/pr-subscribers-backend-x86

Author: AZero13 (AZero13)

Changes

This pattern arises when comparing the two bit positions produced by a pair of ctlz/bsr computations: both results are adjusted by the same value (XOR with the bit-width mask during bsr lowering), so the common adjustment can be canceled before the equality compare.


Full diff: https://github.com/llvm/llvm-project/pull/155106.diff

3 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+40)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+41)
  • (modified) llvm/test/CodeGen/X86/ctlz.ll (+61)
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 402a012e8e555..88d52979bb1f0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5686,6 +5686,46 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
     return DAG.getSetCC(dl, VT, N0.getOperand(0), N1.getOperand(0), Cond);
   }
 
+  // Optimize arithmetic comparisons: A op B == A op C -> B == C (for op = -, ^, +)
+  if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
+      N0.getValueType().isInteger() && N1.getValueType().isInteger()) {
+    unsigned Opc0 = N0.getOpcode();
+    unsigned Opc1 = N1.getOpcode();
+    
+    // Check if both operands have the same arithmetic operation
+    if (Opc0 == Opc1 && 
+        (Opc0 == ISD::SUB || Opc0 == ISD::XOR || Opc0 == ISD::ADD)) {
+      SDValue A0 = N0.getOperand(0);
+      SDValue B0 = N0.getOperand(1);
+      SDValue A1 = N1.getOperand(0);
+      SDValue B1 = N1.getOperand(1);
+      
+      // Pattern: (A - B) == (A - C) -> B == C
+      // Pattern: (A ^ B) == (A ^ C) -> B == C  
+      // Pattern: (A + B) == (A + C) -> B == C
+      if (A0 == A1) {
+        return DAG.getSetCC(dl, VT, B0, B1, Cond);
+      }
+      
+      // For commutative operations (XOR, ADD), also try the other combinations
+      if (Opc0 == ISD::XOR || Opc0 == ISD::ADD) {
+        // Pattern: (A ^ B) == (C ^ A) -> B == C
+        // Pattern: (A + B) == (C + A) -> B == C
+        if (A0 == B1) {
+          return DAG.getSetCC(dl, VT, B0, A1, Cond);
+        }
+        // Pattern: (B ^ A) == (A ^ C) -> B == C
+        if (B0 == A1) {
+          return DAG.getSetCC(dl, VT, A0, B1, Cond);
+        }
+        // Pattern: (B ^ A) == (C ^ A) -> B == C
+        if (B0 == B1) {
+          return DAG.getSetCC(dl, VT, A0, A1, Cond);
+        }
+      }
+    }
+  }
+
   // Could not fold it.
   return SDValue();
 }
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 19131fbd4102b..6ecb6268d4637 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -49009,6 +49009,45 @@ static SDValue combineSetCCMOVMSK(SDValue EFLAGS, X86::CondCode &CC,
 static SDValue combineSetCCEFLAGS(SDValue EFLAGS, X86::CondCode &CC,
                                   SelectionDAG &DAG,
                                   const X86Subtarget &Subtarget) {
+  // For EQ/NE, rewrite SUB(xor A,B, xor C,B) -> SUB(A,C) and
+  //            SUB(add A,K, add C,K)       -> SUB(A,C)
+  // Only safe if the arithmetic result (value #0) is unused, since we only
+  // care about flags for EQ/NE. This avoids creating new nodes/VTLists.
+  if ((CC == X86::COND_E || CC == X86::COND_NE) &&
+      EFLAGS.getOpcode() == X86ISD::SUB &&
+      !EFLAGS.getNode()->hasAnyUseOfValue(0)) {
+    using namespace llvm::SDPatternMatch;
+    SDValue LHS = EFLAGS.getOperand(0);
+    SDValue RHS = EFLAGS.getOperand(1);
+    SDValue A, B, C, K;
+
+    bool DidRewrite = false;
+
+    // (xor A, B) - (xor C, B) -> A - C
+    if (!DidRewrite && sd_match(LHS, m_Xor(m_Value(A), m_Value(B))) &&
+        sd_match(RHS, m_Xor(m_Value(C), m_Specific(B)))) {
+      DAG.UpdateNodeOperands(EFLAGS.getNode(), A, C);
+      DidRewrite = true;
+    }
+
+    // (add A, K) - (add C, K) -> A - C
+    if (!DidRewrite && sd_match(LHS, m_Add(m_Value(A), m_Value(K))) &&
+        sd_match(RHS, m_Add(m_Value(C), m_Specific(K)))) {
+      DAG.UpdateNodeOperands(EFLAGS.getNode(), A, C);
+      DidRewrite = true;
+    }
+
+    // (sub A, K) - (sub C, K) -> A - C
+    if (!DidRewrite && sd_match(LHS, m_Sub(m_Value(A), m_Value(K))) &&
+        sd_match(RHS, m_Sub(m_Value(C), m_Specific(K)))) {
+      DAG.UpdateNodeOperands(EFLAGS.getNode(), A, C);
+      DidRewrite = true;
+    }
+
+    if (DidRewrite)
+      return SDValue(EFLAGS.getNode(), 1);
+  }
+
   if (CC == X86::COND_B)
     if (SDValue Flags = combineCarryThroughADD(EFLAGS, DAG))
       return Flags;
@@ -56893,6 +56932,8 @@ static SDValue combineX86SetCC(SDNode *N, SelectionDAG &DAG,
   X86::CondCode CC = X86::CondCode(N->getConstantOperandVal(0));
   SDValue EFLAGS = N->getOperand(1);
 
+
+
   // Try to simplify the EFLAGS and condition code operands.
   if (SDValue Flags = combineSetCCEFLAGS(EFLAGS, CC, DAG, Subtarget))
     return getSETCC(CC, Flags, DL, DAG);
diff --git a/llvm/test/CodeGen/X86/ctlz.ll b/llvm/test/CodeGen/X86/ctlz.ll
index 1267fe9033454..43b654caee9f9 100644
--- a/llvm/test/CodeGen/X86/ctlz.ll
+++ b/llvm/test/CodeGen/X86/ctlz.ll
@@ -1334,3 +1334,64 @@ define i64 @ctlz_i32_zext(i32 %x) {
   %ext = zext i32 %xor to i64
   ret i64 %ext
 }
+
+define i32 @bsr_eq(i32 %a, i32 %b) {
+; X86-LABEL: bsr_eq:
+; X86:       # %bb.0:
+; X86-NEXT:    bsrl {{[0-9]+}}(%esp), %ecx
+; X86-NEXT:    bsrl {{[0-9]+}}(%esp), %edx
+; X86-NEXT:    xorl %eax, %eax
+; X86-NEXT:    cmpl %edx, %ecx
+; X86-NEXT:    sete %al
+; X86-NEXT:    retl
+;
+; X64-LABEL: bsr_eq:
+; X64:       # %bb.0:
+; X64-NEXT:    bsrl %edi, %ecx
+; X64-NEXT:    bsrl %esi, %edx
+; X64-NEXT:    xorl %eax, %eax
+; X64-NEXT:    cmpl %edx, %ecx
+; X64-NEXT:    sete %al
+; X64-NEXT:    retq
+;
+; X86-CLZ-LABEL: bsr_eq:
+; X86-CLZ:       # %bb.0:
+; X86-CLZ-NEXT:    lzcntl {{[0-9]+}}(%esp), %ecx
+; X86-CLZ-NEXT:    lzcntl {{[0-9]+}}(%esp), %edx
+; X86-CLZ-NEXT:    xorl %eax, %eax
+; X86-CLZ-NEXT:    cmpl %edx, %ecx
+; X86-CLZ-NEXT:    sete %al
+; X86-CLZ-NEXT:    retl
+;
+; X64-CLZ-LABEL: bsr_eq:
+; X64-CLZ:       # %bb.0:
+; X64-CLZ-NEXT:    lzcntl %edi, %ecx
+; X64-CLZ-NEXT:    lzcntl %esi, %edx
+; X64-CLZ-NEXT:    xorl %eax, %eax
+; X64-CLZ-NEXT:    cmpl %edx, %ecx
+; X64-CLZ-NEXT:    sete %al
+; X64-CLZ-NEXT:    retq
+;
+; X64-FASTLZCNT-LABEL: bsr_eq:
+; X64-FASTLZCNT:       # %bb.0:
+; X64-FASTLZCNT-NEXT:    lzcntl %edi, %ecx
+; X64-FASTLZCNT-NEXT:    lzcntl %esi, %edx
+; X64-FASTLZCNT-NEXT:    xorl %eax, %eax
+; X64-FASTLZCNT-NEXT:    cmpl %edx, %ecx
+; X64-FASTLZCNT-NEXT:    sete %al
+; X64-FASTLZCNT-NEXT:    retq
+;
+; X86-FASTLZCNT-LABEL: bsr_eq:
+; X86-FASTLZCNT:       # %bb.0:
+; X86-FASTLZCNT-NEXT:    lzcntl {{[0-9]+}}(%esp), %ecx
+; X86-FASTLZCNT-NEXT:    lzcntl {{[0-9]+}}(%esp), %edx
+; X86-FASTLZCNT-NEXT:    xorl %eax, %eax
+; X86-FASTLZCNT-NEXT:    cmpl %edx, %ecx
+; X86-FASTLZCNT-NEXT:    sete %al
+; X86-FASTLZCNT-NEXT:    retl
+  %1 = tail call range(i32 0, 33) i32 @llvm.ctlz.i32(i32 %a, i1 true)
+  %2 = tail call range(i32 0, 33) i32 @llvm.ctlz.i32(i32 %b, i1 true)
+  %cmp = icmp eq i32 %1, %2
+  %conv = zext i1 %cmp to i32
+  ret i32 %conv
+}

Copy link

github-actions bot commented Aug 23, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@AZero13 AZero13 changed the title [X86] Implement canceling out of XOR with sub nodes [X86] Implement canceling out of XOR with equality Aug 23, 2025
This can happen when comparing the two bit positions.
@AZero13 AZero13 requested a review from arsenm August 24, 2025 00:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:X86 llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants