Skip to content

Commit 16c3c78

Browse files
committed
introduce and adapt vm.identical_or_equal to test is or else __eq__
1 parent e9ebd95 commit 16c3c78

File tree

7 files changed

+28
-57
lines changed

7 files changed

+28
-57
lines changed

vm/src/dictdatatype.rs

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use crate::obj::objbool;
21
use crate::obj::objstr::PyString;
32
use crate::pyhash;
43
use crate::pyobject::{IdProtocol, IntoPyObject, PyObjectRef, PyResult};
@@ -305,8 +304,7 @@ impl DictKey for &PyObjectRef {
305304
}
306305

307306
fn do_eq(self, vm: &VirtualMachine, other_key: &PyObjectRef) -> PyResult<bool> {
308-
let result = vm._eq(self.clone(), other_key.clone())?;
309-
objbool::boolval(vm, result)
307+
vm.identical_or_equal(self, other_key)
310308
}
311309
}
312310

@@ -399,7 +397,7 @@ mod tests {
399397
assert_eq!(true, dict.contains(&vm, "x").unwrap());
400398

401399
let val = dict.get(&vm, "x").unwrap().unwrap();
402-
vm._eq(val, value2)
400+
vm.bool_eq(val, value2)
403401
.expect("retrieved value must be equal to inserted value.");
404402
}
405403

vm/src/obj/objdict.rs

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use std::cell::{Cell, RefCell};
22
use std::fmt;
33

4-
use super::objbool;
54
use super::objiter;
65
use super::objstr;
76
use super::objtype::{self, PyClassRef};
@@ -121,8 +120,7 @@ impl PyDictRef {
121120
if v1.is(&v2) {
122121
continue;
123122
}
124-
let value = objbool::boolval(vm, vm._eq(v1, v2)?)?;
125-
if !value {
123+
if !vm.bool_eq(v1, v2)? {
126124
return Ok(false);
127125
}
128126
}

vm/src/obj/objlist.rs

+4-22
Original file line numberDiff line numberDiff line change
@@ -447,25 +447,16 @@ impl PyListRef {
447447
fn count(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
448448
let mut count: usize = 0;
449449
for element in self.elements.borrow().iter() {
450-
if needle.is(element) {
450+
if vm.identical_or_equal(element, &needle)? {
451451
count += 1;
452-
} else {
453-
let py_equal = vm._eq(element.clone(), needle.clone())?;
454-
if objbool::boolval(vm, py_equal)? {
455-
count += 1;
456-
}
457452
}
458453
}
459454
Ok(count)
460455
}
461456

462457
fn contains(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
463458
for element in self.elements.borrow().iter() {
464-
if needle.is(element) {
465-
return Ok(true);
466-
}
467-
let py_equal = vm._eq(element.clone(), needle.clone())?;
468-
if objbool::boolval(vm, py_equal)? {
459+
if vm.identical_or_equal(element, &needle)? {
469460
return Ok(true);
470461
}
471462
}
@@ -475,11 +466,7 @@ impl PyListRef {
475466

476467
fn index(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
477468
for (index, element) in self.elements.borrow().iter().enumerate() {
478-
if needle.is(element) {
479-
return Ok(index);
480-
}
481-
let py_equal = vm._eq(needle.clone(), element.clone())?;
482-
if objbool::boolval(vm, py_equal)? {
469+
if vm.identical_or_equal(element, &needle)? {
483470
return Ok(index);
484471
}
485472
}
@@ -505,12 +492,7 @@ impl PyListRef {
505492
fn remove(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
506493
let mut ri: Option<usize> = None;
507494
for (index, element) in self.elements.borrow().iter().enumerate() {
508-
if needle.is(element) {
509-
ri = Some(index);
510-
break;
511-
}
512-
let py_equal = vm._eq(needle.clone(), element.clone())?;
513-
if objbool::get_value(&py_equal) {
495+
if vm.identical_or_equal(element, &needle)? {
514496
ri = Some(index);
515497
break;
516498
}

vm/src/obj/objslice.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,13 @@ impl PySlice {
117117
}
118118

119119
fn inner_eq(&self, other: &PySlice, vm: &VirtualMachine) -> PyResult<bool> {
120-
if !vm.bool_eq(self.start(vm), other.start(vm))? {
120+
if !vm.identical_or_equal(&self.start(vm), &other.start(vm))? {
121121
return Ok(false);
122122
}
123-
if !vm.bool_eq(self.stop(vm), other.stop(vm))? {
123+
if !vm.identical_or_equal(&self.stop(vm), &other.stop(vm))? {
124124
return Ok(false);
125125
}
126-
if !vm.bool_eq(self.step(vm), other.step(vm))? {
126+
if !vm.identical_or_equal(&self.step(vm), &other.step(vm))? {
127127
return Ok(false);
128128
}
129129
Ok(true)

vm/src/obj/objtuple.rs

+4-18
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use std::cell::Cell;
22
use std::fmt;
33

4-
use super::objbool;
54
use super::objiter;
65
use super::objsequence::{
76
get_elements_tuple, get_item, seq_equal, seq_ge, seq_gt, seq_le, seq_lt, seq_mul,
@@ -10,7 +9,7 @@ use super::objtype::{self, PyClassRef};
109
use crate::function::OptionalArg;
1110
use crate::pyhash;
1211
use crate::pyobject::{
13-
IdProtocol, IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue,
12+
IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue,
1413
};
1514
use crate::vm::{ReprGuard, VirtualMachine};
1615

@@ -142,13 +141,8 @@ impl PyTuple {
142141
fn count(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
143142
let mut count: usize = 0;
144143
for element in self.elements.iter() {
145-
if element.is(&needle) {
144+
if vm.identical_or_equal(element, &needle)? {
146145
count += 1;
147-
} else {
148-
let is_eq = vm._eq(element.clone(), needle.clone())?;
149-
if objbool::boolval(vm, is_eq)? {
150-
count += 1;
151-
}
152146
}
153147
}
154148
Ok(count)
@@ -236,11 +230,7 @@ impl PyTuple {
236230
#[pymethod(name = "index")]
237231
fn index(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
238232
for (index, element) in self.elements.iter().enumerate() {
239-
if element.is(&needle) {
240-
return Ok(index);
241-
}
242-
let is_eq = vm._eq(needle.clone(), element.clone())?;
243-
if objbool::boolval(vm, is_eq)? {
233+
if vm.identical_or_equal(element, &needle)? {
244234
return Ok(index);
245235
}
246236
}
@@ -250,11 +240,7 @@ impl PyTuple {
250240
#[pymethod(name = "__contains__")]
251241
fn contains(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
252242
for element in self.elements.iter() {
253-
if element.is(&needle) {
254-
return Ok(true);
255-
}
256-
let is_eq = vm._eq(needle.clone(), element.clone())?;
257-
if objbool::boolval(vm, is_eq)? {
243+
if vm.identical_or_equal(element, &needle)? {
258244
return Ok(true);
259245
}
260246
}

vm/src/stdlib/collections.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::function::OptionalArg;
2-
use crate::obj::{objbool, objiter, objsequence, objtype::PyClassRef};
2+
use crate::obj::{objiter, objsequence, objtype::PyClassRef};
33
use crate::pyobject::{IdProtocol, PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult, PyValue};
44
use crate::vm::ReprGuard;
55
use crate::VirtualMachine;
@@ -78,7 +78,7 @@ impl PyDeque {
7878
fn count(&self, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
7979
let mut count = 0;
8080
for elem in self.deque.borrow().iter() {
81-
if objbool::boolval(vm, vm._eq(elem.clone(), obj.clone())?)? {
81+
if vm.identical_or_equal(elem, &obj)? {
8282
count += 1;
8383
}
8484
}
@@ -114,7 +114,7 @@ impl PyDeque {
114114
let start = start.unwrap_or(0);
115115
let stop = stop.unwrap_or_else(|| deque.len());
116116
for (i, elem) in deque.iter().skip(start).take(stop - start).enumerate() {
117-
if objbool::boolval(vm, vm._eq(elem.clone(), obj.clone())?)? {
117+
if vm.identical_or_equal(elem, &obj)? {
118118
return Ok(i);
119119
}
120120
}
@@ -171,7 +171,7 @@ impl PyDeque {
171171
let mut deque = self.deque.borrow_mut();
172172
let mut idx = None;
173173
for (i, elem) in deque.iter().enumerate() {
174-
if objbool::boolval(vm, vm._eq(elem.clone(), obj.clone())?)? {
174+
if vm.identical_or_equal(elem, &obj)? {
175175
idx = Some(i);
176176
break;
177177
}

vm/src/vm.rs

+10-3
Original file line numberDiff line numberDiff line change
@@ -1236,8 +1236,7 @@ impl VirtualMachine {
12361236
let iter = objiter::get_iter(self, &haystack)?;
12371237
loop {
12381238
if let Some(element) = objiter::get_next_object(self, &iter)? {
1239-
let equal = self._eq(needle.clone(), element.clone())?;
1240-
if objbool::get_value(&equal) {
1239+
if self.bool_eq(needle.clone(), element.clone())? {
12411240
return Ok(self.new_bool(true));
12421241
} else {
12431242
continue;
@@ -1270,11 +1269,19 @@ impl VirtualMachine {
12701269
}
12711270

12721271
pub fn bool_eq(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult<bool> {
1273-
let eq = self._eq(a.clone(), b.clone())?;
1272+
let eq = self._eq(a, b)?;
12741273
let value = objbool::boolval(self, eq)?;
12751274
Ok(value)
12761275
}
12771276

1277+
pub fn identical_or_equal(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<bool> {
1278+
if a.is(b) {
1279+
Ok(true)
1280+
} else {
1281+
self.bool_eq(a.clone(), b.clone())
1282+
}
1283+
}
1284+
12781285
pub fn bool_seq_lt(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult<Option<bool>> {
12791286
let value = if objbool::boolval(self, self._lt(a.clone(), b.clone())?)? {
12801287
Some(true)

0 commit comments

Comments
 (0)