-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[X86] Implement canceling out of XOR with equality #155106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-llvm-selectiondag Author: AZero13 (AZero13) Changes: This can happen when comparing the two bit positions. Full diff: https://github.com/llvm/llvm-project/pull/155106.diff 3 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 402a012e8e555..88d52979bb1f0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5686,6 +5686,46 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
return DAG.getSetCC(dl, VT, N0.getOperand(0), N1.getOperand(0), Cond);
}
+ // Optimize arithmetic comparisons: A op B == A op C -> B == C (for op = -, ^, +)
+ if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
+ N0.getValueType().isInteger() && N1.getValueType().isInteger()) {
+ unsigned Opc0 = N0.getOpcode();
+ unsigned Opc1 = N1.getOpcode();
+
+ // Check if both operands have the same arithmetic operation
+ if (Opc0 == Opc1 &&
+ (Opc0 == ISD::SUB || Opc0 == ISD::XOR || Opc0 == ISD::ADD)) {
+ SDValue A0 = N0.getOperand(0);
+ SDValue B0 = N0.getOperand(1);
+ SDValue A1 = N1.getOperand(0);
+ SDValue B1 = N1.getOperand(1);
+
+ // Pattern: (A - B) == (A - C) -> B == C
+ // Pattern: (A ^ B) == (A ^ C) -> B == C
+ // Pattern: (A + B) == (A + C) -> B == C
+ if (A0 == A1) {
+ return DAG.getSetCC(dl, VT, B0, B1, Cond);
+ }
+
+ // For commutative operations (XOR, ADD), also try the other combinations
+ if (Opc0 == ISD::XOR || Opc0 == ISD::ADD) {
+ // Pattern: (A ^ B) == (C ^ A) -> B == C
+ // Pattern: (A + B) == (C + A) -> B == C
+ if (A0 == B1) {
+ return DAG.getSetCC(dl, VT, B0, A1, Cond);
+ }
+ // Pattern: (B ^ A) == (A ^ C) -> B == C
+ if (B0 == A1) {
+ return DAG.getSetCC(dl, VT, A0, B1, Cond);
+ }
+ // Pattern: (B ^ A) == (C ^ A) -> B == C
+ if (B0 == B1) {
+ return DAG.getSetCC(dl, VT, A0, A1, Cond);
+ }
+ }
+ }
+ }
+
// Could not fold it.
return SDValue();
}
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 19131fbd4102b..6ecb6268d4637 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -49009,6 +49009,45 @@ static SDValue combineSetCCMOVMSK(SDValue EFLAGS, X86::CondCode &CC,
static SDValue combineSetCCEFLAGS(SDValue EFLAGS, X86::CondCode &CC,
SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
+ // For EQ/NE, rewrite SUB(xor A,B, xor C,B) -> SUB(A,C) and
+ // SUB(add A,K, add C,K) -> SUB(A,C)
+ // Only safe if the arithmetic result (value #0) is unused, since we only
+ // care about flags for EQ/NE. This avoids creating new nodes/VTLists.
+ if ((CC == X86::COND_E || CC == X86::COND_NE) &&
+ EFLAGS.getOpcode() == X86ISD::SUB &&
+ !EFLAGS.getNode()->hasAnyUseOfValue(0)) {
+ using namespace llvm::SDPatternMatch;
+ SDValue LHS = EFLAGS.getOperand(0);
+ SDValue RHS = EFLAGS.getOperand(1);
+ SDValue A, B, C, K;
+
+ bool DidRewrite = false;
+
+ // (xor A, B) - (xor C, B) -> A - C
+ if (!DidRewrite && sd_match(LHS, m_Xor(m_Value(A), m_Value(B))) &&
+ sd_match(RHS, m_Xor(m_Value(C), m_Specific(B)))) {
+ DAG.UpdateNodeOperands(EFLAGS.getNode(), A, C);
+ DidRewrite = true;
+ }
+
+ // (add A, K) - (add C, K) -> A - C
+ if (!DidRewrite && sd_match(LHS, m_Add(m_Value(A), m_Value(K))) &&
+ sd_match(RHS, m_Add(m_Value(C), m_Specific(K)))) {
+ DAG.UpdateNodeOperands(EFLAGS.getNode(), A, C);
+ DidRewrite = true;
+ }
+
+ // (sub A, K) - (sub C, K) -> A - C
+ if (!DidRewrite && sd_match(LHS, m_Sub(m_Value(A), m_Value(K))) &&
+ sd_match(RHS, m_Sub(m_Value(C), m_Specific(K)))) {
+ DAG.UpdateNodeOperands(EFLAGS.getNode(), A, C);
+ DidRewrite = true;
+ }
+
+ if (DidRewrite)
+ return SDValue(EFLAGS.getNode(), 1);
+ }
+
if (CC == X86::COND_B)
if (SDValue Flags = combineCarryThroughADD(EFLAGS, DAG))
return Flags;
@@ -56893,6 +56932,8 @@ static SDValue combineX86SetCC(SDNode *N, SelectionDAG &DAG,
X86::CondCode CC = X86::CondCode(N->getConstantOperandVal(0));
SDValue EFLAGS = N->getOperand(1);
+
+
// Try to simplify the EFLAGS and condition code operands.
if (SDValue Flags = combineSetCCEFLAGS(EFLAGS, CC, DAG, Subtarget))
return getSETCC(CC, Flags, DL, DAG);
diff --git a/llvm/test/CodeGen/X86/ctlz.ll b/llvm/test/CodeGen/X86/ctlz.ll
index 1267fe9033454..43b654caee9f9 100644
--- a/llvm/test/CodeGen/X86/ctlz.ll
+++ b/llvm/test/CodeGen/X86/ctlz.ll
@@ -1334,3 +1334,64 @@ define i64 @ctlz_i32_zext(i32 %x) {
%ext = zext i32 %xor to i64
ret i64 %ext
}
+
+define i32 @bsr_eq(i32 %a, i32 %b) {
+; X86-LABEL: bsr_eq:
+; X86: # %bb.0:
+; X86-NEXT: bsrl {{[0-9]+}}(%esp), %ecx
+; X86-NEXT: bsrl {{[0-9]+}}(%esp), %edx
+; X86-NEXT: xorl %eax, %eax
+; X86-NEXT: cmpl %edx, %ecx
+; X86-NEXT: sete %al
+; X86-NEXT: retl
+;
+; X64-LABEL: bsr_eq:
+; X64: # %bb.0:
+; X64-NEXT: bsrl %edi, %ecx
+; X64-NEXT: bsrl %esi, %edx
+; X64-NEXT: xorl %eax, %eax
+; X64-NEXT: cmpl %edx, %ecx
+; X64-NEXT: sete %al
+; X64-NEXT: retq
+;
+; X86-CLZ-LABEL: bsr_eq:
+; X86-CLZ: # %bb.0:
+; X86-CLZ-NEXT: lzcntl {{[0-9]+}}(%esp), %ecx
+; X86-CLZ-NEXT: lzcntl {{[0-9]+}}(%esp), %edx
+; X86-CLZ-NEXT: xorl %eax, %eax
+; X86-CLZ-NEXT: cmpl %edx, %ecx
+; X86-CLZ-NEXT: sete %al
+; X86-CLZ-NEXT: retl
+;
+; X64-CLZ-LABEL: bsr_eq:
+; X64-CLZ: # %bb.0:
+; X64-CLZ-NEXT: lzcntl %edi, %ecx
+; X64-CLZ-NEXT: lzcntl %esi, %edx
+; X64-CLZ-NEXT: xorl %eax, %eax
+; X64-CLZ-NEXT: cmpl %edx, %ecx
+; X64-CLZ-NEXT: sete %al
+; X64-CLZ-NEXT: retq
+;
+; X64-FASTLZCNT-LABEL: bsr_eq:
+; X64-FASTLZCNT: # %bb.0:
+; X64-FASTLZCNT-NEXT: lzcntl %edi, %ecx
+; X64-FASTLZCNT-NEXT: lzcntl %esi, %edx
+; X64-FASTLZCNT-NEXT: xorl %eax, %eax
+; X64-FASTLZCNT-NEXT: cmpl %edx, %ecx
+; X64-FASTLZCNT-NEXT: sete %al
+; X64-FASTLZCNT-NEXT: retq
+;
+; X86-FASTLZCNT-LABEL: bsr_eq:
+; X86-FASTLZCNT: # %bb.0:
+; X86-FASTLZCNT-NEXT: lzcntl {{[0-9]+}}(%esp), %ecx
+; X86-FASTLZCNT-NEXT: lzcntl {{[0-9]+}}(%esp), %edx
+; X86-FASTLZCNT-NEXT: xorl %eax, %eax
+; X86-FASTLZCNT-NEXT: cmpl %edx, %ecx
+; X86-FASTLZCNT-NEXT: sete %al
+; X86-FASTLZCNT-NEXT: retl
+ %1 = tail call range(i32 0, 33) i32 @llvm.ctlz.i32(i32 %a, i1 true)
+ %2 = tail call range(i32 0, 33) i32 @llvm.ctlz.i32(i32 %b, i1 true)
+ %cmp = icmp eq i32 %1, %2
+ %conv = zext i1 %cmp to i32
+ ret i32 %conv
+}
|
@llvm/pr-subscribers-backend-x86 Author: AZero13 (AZero13) Changes: This can happen when comparing the two bit positions. Full diff: https://github.com/llvm/llvm-project/pull/155106.diff 3 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 402a012e8e555..88d52979bb1f0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5686,6 +5686,46 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
return DAG.getSetCC(dl, VT, N0.getOperand(0), N1.getOperand(0), Cond);
}
+ // Optimize arithmetic comparisons: A op B == A op C -> B == C (for op = -, ^, +)
+ if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
+ N0.getValueType().isInteger() && N1.getValueType().isInteger()) {
+ unsigned Opc0 = N0.getOpcode();
+ unsigned Opc1 = N1.getOpcode();
+
+ // Check if both operands have the same arithmetic operation
+ if (Opc0 == Opc1 &&
+ (Opc0 == ISD::SUB || Opc0 == ISD::XOR || Opc0 == ISD::ADD)) {
+ SDValue A0 = N0.getOperand(0);
+ SDValue B0 = N0.getOperand(1);
+ SDValue A1 = N1.getOperand(0);
+ SDValue B1 = N1.getOperand(1);
+
+ // Pattern: (A - B) == (A - C) -> B == C
+ // Pattern: (A ^ B) == (A ^ C) -> B == C
+ // Pattern: (A + B) == (A + C) -> B == C
+ if (A0 == A1) {
+ return DAG.getSetCC(dl, VT, B0, B1, Cond);
+ }
+
+ // For commutative operations (XOR, ADD), also try the other combinations
+ if (Opc0 == ISD::XOR || Opc0 == ISD::ADD) {
+ // Pattern: (A ^ B) == (C ^ A) -> B == C
+ // Pattern: (A + B) == (C + A) -> B == C
+ if (A0 == B1) {
+ return DAG.getSetCC(dl, VT, B0, A1, Cond);
+ }
+ // Pattern: (B ^ A) == (A ^ C) -> B == C
+ if (B0 == A1) {
+ return DAG.getSetCC(dl, VT, A0, B1, Cond);
+ }
+ // Pattern: (B ^ A) == (C ^ A) -> B == C
+ if (B0 == B1) {
+ return DAG.getSetCC(dl, VT, A0, A1, Cond);
+ }
+ }
+ }
+ }
+
// Could not fold it.
return SDValue();
}
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 19131fbd4102b..6ecb6268d4637 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -49009,6 +49009,45 @@ static SDValue combineSetCCMOVMSK(SDValue EFLAGS, X86::CondCode &CC,
static SDValue combineSetCCEFLAGS(SDValue EFLAGS, X86::CondCode &CC,
SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
+ // For EQ/NE, rewrite SUB(xor A,B, xor C,B) -> SUB(A,C) and
+ // SUB(add A,K, add C,K) -> SUB(A,C)
+ // Only safe if the arithmetic result (value #0) is unused, since we only
+ // care about flags for EQ/NE. This avoids creating new nodes/VTLists.
+ if ((CC == X86::COND_E || CC == X86::COND_NE) &&
+ EFLAGS.getOpcode() == X86ISD::SUB &&
+ !EFLAGS.getNode()->hasAnyUseOfValue(0)) {
+ using namespace llvm::SDPatternMatch;
+ SDValue LHS = EFLAGS.getOperand(0);
+ SDValue RHS = EFLAGS.getOperand(1);
+ SDValue A, B, C, K;
+
+ bool DidRewrite = false;
+
+ // (xor A, B) - (xor C, B) -> A - C
+ if (!DidRewrite && sd_match(LHS, m_Xor(m_Value(A), m_Value(B))) &&
+ sd_match(RHS, m_Xor(m_Value(C), m_Specific(B)))) {
+ DAG.UpdateNodeOperands(EFLAGS.getNode(), A, C);
+ DidRewrite = true;
+ }
+
+ // (add A, K) - (add C, K) -> A - C
+ if (!DidRewrite && sd_match(LHS, m_Add(m_Value(A), m_Value(K))) &&
+ sd_match(RHS, m_Add(m_Value(C), m_Specific(K)))) {
+ DAG.UpdateNodeOperands(EFLAGS.getNode(), A, C);
+ DidRewrite = true;
+ }
+
+ // (sub A, K) - (sub C, K) -> A - C
+ if (!DidRewrite && sd_match(LHS, m_Sub(m_Value(A), m_Value(K))) &&
+ sd_match(RHS, m_Sub(m_Value(C), m_Specific(K)))) {
+ DAG.UpdateNodeOperands(EFLAGS.getNode(), A, C);
+ DidRewrite = true;
+ }
+
+ if (DidRewrite)
+ return SDValue(EFLAGS.getNode(), 1);
+ }
+
if (CC == X86::COND_B)
if (SDValue Flags = combineCarryThroughADD(EFLAGS, DAG))
return Flags;
@@ -56893,6 +56932,8 @@ static SDValue combineX86SetCC(SDNode *N, SelectionDAG &DAG,
X86::CondCode CC = X86::CondCode(N->getConstantOperandVal(0));
SDValue EFLAGS = N->getOperand(1);
+
+
// Try to simplify the EFLAGS and condition code operands.
if (SDValue Flags = combineSetCCEFLAGS(EFLAGS, CC, DAG, Subtarget))
return getSETCC(CC, Flags, DL, DAG);
diff --git a/llvm/test/CodeGen/X86/ctlz.ll b/llvm/test/CodeGen/X86/ctlz.ll
index 1267fe9033454..43b654caee9f9 100644
--- a/llvm/test/CodeGen/X86/ctlz.ll
+++ b/llvm/test/CodeGen/X86/ctlz.ll
@@ -1334,3 +1334,64 @@ define i64 @ctlz_i32_zext(i32 %x) {
%ext = zext i32 %xor to i64
ret i64 %ext
}
+
+define i32 @bsr_eq(i32 %a, i32 %b) {
+; X86-LABEL: bsr_eq:
+; X86: # %bb.0:
+; X86-NEXT: bsrl {{[0-9]+}}(%esp), %ecx
+; X86-NEXT: bsrl {{[0-9]+}}(%esp), %edx
+; X86-NEXT: xorl %eax, %eax
+; X86-NEXT: cmpl %edx, %ecx
+; X86-NEXT: sete %al
+; X86-NEXT: retl
+;
+; X64-LABEL: bsr_eq:
+; X64: # %bb.0:
+; X64-NEXT: bsrl %edi, %ecx
+; X64-NEXT: bsrl %esi, %edx
+; X64-NEXT: xorl %eax, %eax
+; X64-NEXT: cmpl %edx, %ecx
+; X64-NEXT: sete %al
+; X64-NEXT: retq
+;
+; X86-CLZ-LABEL: bsr_eq:
+; X86-CLZ: # %bb.0:
+; X86-CLZ-NEXT: lzcntl {{[0-9]+}}(%esp), %ecx
+; X86-CLZ-NEXT: lzcntl {{[0-9]+}}(%esp), %edx
+; X86-CLZ-NEXT: xorl %eax, %eax
+; X86-CLZ-NEXT: cmpl %edx, %ecx
+; X86-CLZ-NEXT: sete %al
+; X86-CLZ-NEXT: retl
+;
+; X64-CLZ-LABEL: bsr_eq:
+; X64-CLZ: # %bb.0:
+; X64-CLZ-NEXT: lzcntl %edi, %ecx
+; X64-CLZ-NEXT: lzcntl %esi, %edx
+; X64-CLZ-NEXT: xorl %eax, %eax
+; X64-CLZ-NEXT: cmpl %edx, %ecx
+; X64-CLZ-NEXT: sete %al
+; X64-CLZ-NEXT: retq
+;
+; X64-FASTLZCNT-LABEL: bsr_eq:
+; X64-FASTLZCNT: # %bb.0:
+; X64-FASTLZCNT-NEXT: lzcntl %edi, %ecx
+; X64-FASTLZCNT-NEXT: lzcntl %esi, %edx
+; X64-FASTLZCNT-NEXT: xorl %eax, %eax
+; X64-FASTLZCNT-NEXT: cmpl %edx, %ecx
+; X64-FASTLZCNT-NEXT: sete %al
+; X64-FASTLZCNT-NEXT: retq
+;
+; X86-FASTLZCNT-LABEL: bsr_eq:
+; X86-FASTLZCNT: # %bb.0:
+; X86-FASTLZCNT-NEXT: lzcntl {{[0-9]+}}(%esp), %ecx
+; X86-FASTLZCNT-NEXT: lzcntl {{[0-9]+}}(%esp), %edx
+; X86-FASTLZCNT-NEXT: xorl %eax, %eax
+; X86-FASTLZCNT-NEXT: cmpl %edx, %ecx
+; X86-FASTLZCNT-NEXT: sete %al
+; X86-FASTLZCNT-NEXT: retl
+ %1 = tail call range(i32 0, 33) i32 @llvm.ctlz.i32(i32 %a, i1 true)
+ %2 = tail call range(i32 0, 33) i32 @llvm.ctlz.i32(i32 %b, i1 true)
+ %cmp = icmp eq i32 %1, %2
+ %conv = zext i1 %cmp to i32
+ ret i32 %conv
+}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
This can happen when comparing the two bit positions.
This can happen when comparing the two bit positions.