Skip to content

Commit c114374

Browse files
authored
Merge pull request #1701 from youknowone/comparison
Fix comparison operator
2 parents 32d544d + 3698d0e commit c114374

File tree

2 files changed

+50
-14
lines changed

2 files changed

+50
-14
lines changed

Lib/test/test_compare.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,6 @@ def __ne__(*args):
7070
Left() != Right()
7171
self.assertSequenceEqual(calls, ['Left.__eq__', 'Right.__ne__'])
7272

73-
# TODO: RUSTPYTHON
74-
@unittest.expectedFailure
7573
def test_ne_low_priority(self):
7674
"""object.__ne__() should not invoke reflected __eq__()"""
7775
calls = []

vm/src/vm.rs

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,40 +1145,78 @@ impl VirtualMachine {
11451145
})
11461146
}
11471147

1148+
// Perform a comparison, raising TypeError when the requested comparison
1149+
// operator is not supported.
1150+
// see: CPython PyObject_RichCompare
1151+
fn _cmp<F>(
1152+
&self,
1153+
v: PyObjectRef,
1154+
w: PyObjectRef,
1155+
op: &str,
1156+
swap_op: &str,
1157+
default: F,
1158+
) -> PyResult
1159+
where
1160+
F: Fn(&VirtualMachine, PyObjectRef, PyObjectRef) -> PyResult,
1161+
{
1162+
// TODO: _Py_EnterRecursiveCall(tstate, " in comparison")
1163+
1164+
let mut checked_reverse_op = false;
1165+
if !v.typ.is(&w.typ) && objtype::issubclass(&w.class(), &v.class()) {
1166+
if let Some(method_or_err) = self.get_method(w.clone(), swap_op) {
1167+
let method = method_or_err?;
1168+
checked_reverse_op = true;
1169+
1170+
let result = self.invoke(&method, vec![v.clone()])?;
1171+
if !result.is(&self.ctx.not_implemented()) {
1172+
return Ok(result);
1173+
}
1174+
}
1175+
}
1176+
1177+
self.call_or_unsupported(v, w, op, |vm, v, w| {
1178+
if !checked_reverse_op {
1179+
self.call_or_unsupported(w, v, swap_op, |vm, v, w| default(vm, v, w))
1180+
} else {
1181+
default(vm, v, w)
1182+
}
1183+
})
1184+
1185+
// TODO: _Py_LeaveRecursiveCall(tstate);
1186+
}
1187+
11481188
pub fn _eq(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
1149-
self.call_or_reflection(
1150-
a,
1151-
b,
1152-
"__eq__",
1153-
"__eq__",
1154-
|vm, _a, _b| Ok(vm.new_bool(false)),
1155-
)
1189+
self._cmp(a, b, "__eq__", "__eq__", |vm, a, b| {
1190+
Ok(vm.new_bool(a.is(&b)))
1191+
})
11561192
}
11571193

11581194
pub fn _ne(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
1159-
self.call_or_reflection(a, b, "__ne__", "__ne__", |vm, _a, _b| Ok(vm.new_bool(true)))
1195+
self._cmp(a, b, "__ne__", "__ne__", |vm, a, b| {
1196+
Ok(vm.new_bool(!a.is(&b)))
1197+
})
11601198
}
11611199

11621200
pub fn _lt(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
1163-
self.call_or_reflection(a, b, "__lt__", "__gt__", |vm, a, b| {
1201+
self._cmp(a, b, "__lt__", "__gt__", |vm, a, b| {
11641202
Err(vm.new_unsupported_operand_error(a, b, "<"))
11651203
})
11661204
}
11671205

11681206
pub fn _le(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
1169-
self.call_or_reflection(a, b, "__le__", "__ge__", |vm, a, b| {
1207+
self._cmp(a, b, "__le__", "__ge__", |vm, a, b| {
11701208
Err(vm.new_unsupported_operand_error(a, b, "<="))
11711209
})
11721210
}
11731211

11741212
pub fn _gt(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
1175-
self.call_or_reflection(a, b, "__gt__", "__lt__", |vm, a, b| {
1213+
self._cmp(a, b, "__gt__", "__lt__", |vm, a, b| {
11761214
Err(vm.new_unsupported_operand_error(a, b, ">"))
11771215
})
11781216
}
11791217

11801218
pub fn _ge(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
1181-
self.call_or_reflection(a, b, "__ge__", "__le__", |vm, a, b| {
1219+
self._cmp(a, b, "__ge__", "__le__", |vm, a, b| {
11821220
Err(vm.new_unsupported_operand_error(a, b, ">="))
11831221
})
11841222
}

0 commit comments

Comments
 (0)