Skip to content

Commit 4edcca2

Browse files
committed
introduce and adapt vm.bool_equal to test is or else __eq__
1 parent b769675 commit 4edcca2

File tree

7 files changed

+29
-57
lines changed

7 files changed

+29
-57
lines changed

vm/src/dictdatatype.rs

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use crate::obj::objbool;
21
use crate::obj::objstr::PyString;
32
use crate::pyhash;
43
use crate::pyobject::{IdProtocol, IntoPyObject, PyObjectRef, PyResult};
@@ -305,8 +304,7 @@ impl DictKey for &PyObjectRef {
305304
}
306305

307306
fn do_eq(self, vm: &VirtualMachine, other_key: &PyObjectRef) -> PyResult<bool> {
308-
let result = vm._eq(self.clone(), other_key.clone())?;
309-
objbool::boolval(vm, result)
307+
vm.bool_equal(self, other_key)
310308
}
311309
}
312310

@@ -399,7 +397,7 @@ mod tests {
399397
assert_eq!(true, dict.contains(&vm, "x").unwrap());
400398

401399
let val = dict.get(&vm, "x").unwrap().unwrap();
402-
vm._eq(val, value2)
400+
vm.bool_eq(val, value2)
403401
.expect("retrieved value must be equal to inserted value.");
404402
}
405403

vm/src/obj/objdict.rs

+1-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ use crate::pyobject::{
88
};
99
use crate::vm::{ReprGuard, VirtualMachine};
1010

11-
use super::objbool;
1211
use super::objiter;
1312
use super::objstr;
1413
use super::objtype;
@@ -124,8 +123,7 @@ impl PyDictRef {
124123
if v1.is(&v2) {
125124
continue;
126125
}
127-
let value = objbool::boolval(vm, vm._eq(v1, v2)?)?;
128-
if !value {
126+
if !vm.bool_eq(v1, v2)? {
129127
return Ok(false);
130128
}
131129
}

vm/src/obj/objlist.rs

+4-22
Original file line numberDiff line numberDiff line change
@@ -450,25 +450,16 @@ impl PyListRef {
450450
fn count(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
451451
let mut count: usize = 0;
452452
for element in self.elements.borrow().iter() {
453-
if needle.is(element) {
453+
if vm.bool_equal(element, &needle)? {
454454
count += 1;
455-
} else {
456-
let py_equal = vm._eq(element.clone(), needle.clone())?;
457-
if objbool::boolval(vm, py_equal)? {
458-
count += 1;
459-
}
460455
}
461456
}
462457
Ok(count)
463458
}
464459

465460
fn contains(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
466461
for element in self.elements.borrow().iter() {
467-
if needle.is(element) {
468-
return Ok(true);
469-
}
470-
let py_equal = vm._eq(element.clone(), needle.clone())?;
471-
if objbool::boolval(vm, py_equal)? {
462+
if vm.bool_equal(element, &needle)? {
472463
return Ok(true);
473464
}
474465
}
@@ -478,11 +469,7 @@ impl PyListRef {
478469

479470
fn index(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
480471
for (index, element) in self.elements.borrow().iter().enumerate() {
481-
if needle.is(element) {
482-
return Ok(index);
483-
}
484-
let py_equal = vm._eq(needle.clone(), element.clone())?;
485-
if objbool::boolval(vm, py_equal)? {
472+
if vm.bool_equal(element, &needle)? {
486473
return Ok(index);
487474
}
488475
}
@@ -508,12 +495,7 @@ impl PyListRef {
508495
fn remove(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
509496
let mut ri: Option<usize> = None;
510497
for (index, element) in self.elements.borrow().iter().enumerate() {
511-
if needle.is(element) {
512-
ri = Some(index);
513-
break;
514-
}
515-
let py_equal = vm._eq(needle.clone(), element.clone())?;
516-
if objbool::get_value(&py_equal) {
498+
if vm.bool_equal(element, &needle)? {
517499
ri = Some(index);
518500
break;
519501
}

vm/src/obj/objslice.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,13 @@ impl PySlice {
117117
}
118118

119119
fn inner_eq(&self, other: &PySlice, vm: &VirtualMachine) -> PyResult<bool> {
120-
if !vm.bool_eq(self.start(vm), other.start(vm))? {
120+
if !vm.bool_equal(&self.start(vm), &other.start(vm))? {
121121
return Ok(false);
122122
}
123-
if !vm.bool_eq(self.stop(vm), other.stop(vm))? {
123+
if !vm.bool_equal(&self.stop(vm), &other.stop(vm))? {
124124
return Ok(false);
125125
}
126-
if !vm.bool_eq(self.step(vm), other.step(vm))? {
126+
if !vm.bool_equal(&self.step(vm), &other.step(vm))? {
127127
return Ok(false);
128128
}
129129
Ok(true)

vm/src/obj/objtuple.rs

+4-18
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@ use std::fmt;
44
use crate::function::OptionalArg;
55
use crate::pyhash;
66
use crate::pyobject::{
7-
IdProtocol, IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue,
7+
IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue,
88
};
99
use crate::vm::{ReprGuard, VirtualMachine};
1010

11-
use super::objbool;
1211
use super::objiter;
1312
use super::objsequence::{
1413
get_elements_tuple, get_item, seq_equal, seq_ge, seq_gt, seq_le, seq_lt, seq_mul,
@@ -143,13 +142,8 @@ impl PyTuple {
143142
fn count(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
144143
let mut count: usize = 0;
145144
for element in self.elements.iter() {
146-
if element.is(&needle) {
145+
if vm.bool_equal(element, &needle)? {
147146
count += 1;
148-
} else {
149-
let is_eq = vm._eq(element.clone(), needle.clone())?;
150-
if objbool::boolval(vm, is_eq)? {
151-
count += 1;
152-
}
153147
}
154148
}
155149
Ok(count)
@@ -237,11 +231,7 @@ impl PyTuple {
237231
#[pymethod(name = "index")]
238232
fn index(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
239233
for (index, element) in self.elements.iter().enumerate() {
240-
if element.is(&needle) {
241-
return Ok(index);
242-
}
243-
let is_eq = vm._eq(needle.clone(), element.clone())?;
244-
if objbool::boolval(vm, is_eq)? {
234+
if vm.bool_equal(element, &needle)? {
245235
return Ok(index);
246236
}
247237
}
@@ -251,11 +241,7 @@ impl PyTuple {
251241
#[pymethod(name = "__contains__")]
252242
fn contains(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
253243
for element in self.elements.iter() {
254-
if element.is(&needle) {
255-
return Ok(true);
256-
}
257-
let is_eq = vm._eq(needle.clone(), element.clone())?;
258-
if objbool::boolval(vm, is_eq)? {
244+
if vm.bool_equal(element, &needle)? {
259245
return Ok(true);
260246
}
261247
}

vm/src/stdlib/collections.rs

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::function::OptionalArg;
2-
use crate::obj::{objbool, objsequence, objtype::PyClassRef};
2+
use crate::obj::objsequence;
3+
use crate::obj::objtype::PyClassRef;
34
use crate::pyobject::{IdProtocol, PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult, PyValue};
45
use crate::vm::ReprGuard;
56
use crate::VirtualMachine;
@@ -77,7 +78,7 @@ impl PyDeque {
7778
fn count(&self, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
7879
let mut count = 0;
7980
for elem in self.deque.borrow().iter() {
80-
if objbool::boolval(vm, vm._eq(elem.clone(), obj.clone())?)? {
81+
if vm.bool_equal(elem, &obj)? {
8182
count += 1;
8283
}
8384
}
@@ -113,7 +114,7 @@ impl PyDeque {
113114
let start = start.unwrap_or(0);
114115
let stop = stop.unwrap_or_else(|| deque.len());
115116
for (i, elem) in deque.iter().skip(start).take(stop - start).enumerate() {
116-
if objbool::boolval(vm, vm._eq(elem.clone(), obj.clone())?)? {
117+
if vm.bool_equal(elem, &obj)? {
117118
return Ok(i);
118119
}
119120
}
@@ -170,7 +171,7 @@ impl PyDeque {
170171
let mut deque = self.deque.borrow_mut();
171172
let mut idx = None;
172173
for (i, elem) in deque.iter().enumerate() {
173-
if objbool::boolval(vm, vm._eq(elem.clone(), obj.clone())?)? {
174+
if vm.bool_equal(elem, &obj)? {
174175
idx = Some(i);
175176
break;
176177
}

vm/src/vm.rs

+10-3
Original file line numberDiff line numberDiff line change
@@ -1233,8 +1233,7 @@ impl VirtualMachine {
12331233
let iter = objiter::get_iter(self, &haystack)?;
12341234
loop {
12351235
if let Some(element) = objiter::get_next_object(self, &iter)? {
1236-
let equal = self._eq(needle.clone(), element.clone())?;
1237-
if objbool::get_value(&equal) {
1236+
if self.bool_eq(needle.clone(), element.clone())? {
12381237
return Ok(self.new_bool(true));
12391238
} else {
12401239
continue;
@@ -1267,11 +1266,19 @@ impl VirtualMachine {
12671266
}
12681267

12691268
pub fn bool_eq(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult<bool> {
1270-
let eq = self._eq(a.clone(), b.clone())?;
1269+
let eq = self._eq(a, b)?;
12711270
let value = objbool::boolval(self, eq)?;
12721271
Ok(value)
12731272
}
12741273

1274+
pub fn bool_equal(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<bool> {
1275+
if a.is(b) {
1276+
Ok(true)
1277+
} else {
1278+
self.bool_eq(a.clone(), b.clone())
1279+
}
1280+
}
1281+
12751282
pub fn bool_seq_lt(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult<Option<bool>> {
12761283
let value = if objbool::boolval(self, self._lt(a.clone(), b.clone())?)? {
12771284
Some(true)

0 commit comments

Comments
 (0)