From 16c3c78b69e0a274961f41a408683060b376761b Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Thu, 10 Oct 2019 01:19:32 +0900 Subject: [PATCH] introduce and adapt vm.identical_or_equal to test `is` or else `__eq__` --- vm/src/dictdatatype.rs | 6 ++---- vm/src/obj/objdict.rs | 4 +--- vm/src/obj/objlist.rs | 26 ++++---------------------- vm/src/obj/objslice.rs | 6 +++--- vm/src/obj/objtuple.rs | 22 ++++------------------ vm/src/stdlib/collections.rs | 8 ++++---- vm/src/vm.rs | 13 ++++++++++--- 7 files changed, 28 insertions(+), 57 deletions(-) diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index 222e364a19..308c09d898 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -1,4 +1,3 @@ -use crate::obj::objbool; use crate::obj::objstr::PyString; use crate::pyhash; use crate::pyobject::{IdProtocol, IntoPyObject, PyObjectRef, PyResult}; @@ -305,8 +304,7 @@ impl DictKey for &PyObjectRef { } fn do_eq(self, vm: &VirtualMachine, other_key: &PyObjectRef) -> PyResult { - let result = vm._eq(self.clone(), other_key.clone())?; - objbool::boolval(vm, result) + vm.identical_or_equal(self, other_key) } } @@ -399,7 +397,7 @@ mod tests { assert_eq!(true, dict.contains(&vm, "x").unwrap()); let val = dict.get(&vm, "x").unwrap().unwrap(); - vm._eq(val, value2) + vm.bool_eq(val, value2) .expect("retrieved value must be equal to inserted value."); } diff --git a/vm/src/obj/objdict.rs b/vm/src/obj/objdict.rs index a630fc9025..d4c700cf03 100644 --- a/vm/src/obj/objdict.rs +++ b/vm/src/obj/objdict.rs @@ -1,7 +1,6 @@ use std::cell::{Cell, RefCell}; use std::fmt; -use super::objbool; use super::objiter; use super::objstr; use super::objtype::{self, PyClassRef}; @@ -121,8 +120,7 @@ impl PyDictRef { if v1.is(&v2) { continue; } - let value = objbool::boolval(vm, vm._eq(v1, v2)?)?; - if !value { + if !vm.bool_eq(v1, v2)? { return Ok(false); } } diff --git a/vm/src/obj/objlist.rs b/vm/src/obj/objlist.rs index b187f01c31..b3e013daa7 100644 --- a/vm/src/obj/objlist.rs +++ b/vm/src/obj/objlist.rs @@ -447,13 +447,8 @@ impl PyListRef { fn count(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { let mut count: usize = 0; for element in self.elements.borrow().iter() { - if needle.is(element) { + if vm.identical_or_equal(element, &needle)? { count += 1; - } else { - let py_equal = vm._eq(element.clone(), needle.clone())?; - if objbool::boolval(vm, py_equal)? { - count += 1; - } } } Ok(count) @@ -461,11 +456,7 @@ impl PyListRef { fn contains(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { for element in self.elements.borrow().iter() { - if needle.is(element) { - return Ok(true); - } - let py_equal = vm._eq(element.clone(), needle.clone())?; - if objbool::boolval(vm, py_equal)? { + if vm.identical_or_equal(element, &needle)? { return Ok(true); } } @@ -475,11 +466,7 @@ impl PyListRef { fn index(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { for (index, element) in self.elements.borrow().iter().enumerate() { - if needle.is(element) { - return Ok(index); - } - let py_equal = vm._eq(needle.clone(), element.clone())?; - if objbool::boolval(vm, py_equal)? { + if vm.identical_or_equal(element, &needle)? { return Ok(index); } } @@ -505,12 +492,7 @@ impl PyListRef { fn remove(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { let mut ri: Option = None; for (index, element) in self.elements.borrow().iter().enumerate() { - if needle.is(element) { - ri = Some(index); - break; - } - let py_equal = vm._eq(needle.clone(), element.clone())?; - if objbool::get_value(&py_equal) { + if vm.identical_or_equal(element, &needle)? { ri = Some(index); break; } diff --git a/vm/src/obj/objslice.rs b/vm/src/obj/objslice.rs index e2f1787cf1..2f1e7d56f6 100644 --- a/vm/src/obj/objslice.rs +++ b/vm/src/obj/objslice.rs @@ -117,13 +117,13 @@ impl PySlice { } fn inner_eq(&self, other: &PySlice, vm: &VirtualMachine) -> PyResult { - if !vm.bool_eq(self.start(vm), other.start(vm))? { + if !vm.identical_or_equal(&self.start(vm), &other.start(vm))? { return Ok(false); } - if !vm.bool_eq(self.stop(vm), other.stop(vm))? { + if !vm.identical_or_equal(&self.stop(vm), &other.stop(vm))? { return Ok(false); } - if !vm.bool_eq(self.step(vm), other.step(vm))? { + if !vm.identical_or_equal(&self.step(vm), &other.step(vm))? { return Ok(false); } Ok(true) diff --git a/vm/src/obj/objtuple.rs b/vm/src/obj/objtuple.rs index f84cc2ff91..f05c9bb03a 100644 --- a/vm/src/obj/objtuple.rs +++ b/vm/src/obj/objtuple.rs @@ -1,7 +1,6 @@ use std::cell::Cell; use std::fmt; -use super::objbool; use super::objiter; use super::objsequence::{ get_elements_tuple, get_item, seq_equal, seq_ge, seq_gt, seq_le, seq_lt, seq_mul, @@ -10,7 +9,7 @@ use super::objtype::{self, PyClassRef}; use crate::function::OptionalArg; use crate::pyhash; use crate::pyobject::{ - IdProtocol, IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, + IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, }; use crate::vm::{ReprGuard, VirtualMachine}; @@ -142,13 +141,8 @@ impl PyTuple { fn count(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { let mut count: usize = 0; for element in self.elements.iter() { - if element.is(&needle) { + if vm.identical_or_equal(element, &needle)? { count += 1; - } else { - let is_eq = vm._eq(element.clone(), needle.clone())?; - if objbool::boolval(vm, is_eq)? { - count += 1; - } } } Ok(count) @@ -236,11 +230,7 @@ impl PyTuple { #[pymethod(name = "index")] fn index(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { for (index, element) in self.elements.iter().enumerate() { - if element.is(&needle) { - return Ok(index); - } - let is_eq = vm._eq(needle.clone(), element.clone())?; - if objbool::boolval(vm, is_eq)? { + if vm.identical_or_equal(element, &needle)? { return Ok(index); } } @@ -250,11 +240,7 @@ impl PyTuple { #[pymethod(name = "__contains__")] fn contains(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { for element in self.elements.iter() { - if element.is(&needle) { - return Ok(true); - } - let is_eq = vm._eq(needle.clone(), element.clone())?; - if objbool::boolval(vm, is_eq)? { + if vm.identical_or_equal(element, &needle)? { return Ok(true); } } diff --git a/vm/src/stdlib/collections.rs b/vm/src/stdlib/collections.rs index 1aec08810b..79c35c2822 100644 --- a/vm/src/stdlib/collections.rs +++ b/vm/src/stdlib/collections.rs @@ -1,5 +1,5 @@ use crate::function::OptionalArg; -use crate::obj::{objbool, objiter, objsequence, objtype::PyClassRef}; +use crate::obj::{objiter, objsequence, objtype::PyClassRef}; use crate::pyobject::{IdProtocol, PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::ReprGuard; use crate::VirtualMachine; @@ -78,7 +78,7 @@ impl PyDeque { fn count(&self, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { let mut count = 0; for elem in self.deque.borrow().iter() { - if objbool::boolval(vm, vm._eq(elem.clone(), obj.clone())?)? { + if vm.identical_or_equal(elem, &obj)? { count += 1; } } @@ -114,7 +114,7 @@ impl PyDeque { let start = start.unwrap_or(0); let stop = stop.unwrap_or_else(|| deque.len()); for (i, elem) in deque.iter().skip(start).take(stop - start).enumerate() { - if objbool::boolval(vm, vm._eq(elem.clone(), obj.clone())?)? { + if vm.identical_or_equal(elem, &obj)? { return Ok(i); } } @@ -171,7 +171,7 @@ impl PyDeque { let mut deque = self.deque.borrow_mut(); let mut idx = None; for (i, elem) in deque.iter().enumerate() { - if objbool::boolval(vm, vm._eq(elem.clone(), obj.clone())?)? { + if vm.identical_or_equal(elem, &obj)? { idx = Some(i); break; } diff --git a/vm/src/vm.rs b/vm/src/vm.rs index f4145b846f..84e966f470 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -1236,8 +1236,7 @@ impl VirtualMachine { let iter = objiter::get_iter(self, &haystack)?; loop { if let Some(element) = objiter::get_next_object(self, &iter)? { - let equal = self._eq(needle.clone(), element.clone())?; - if objbool::get_value(&equal) { + if self.bool_eq(needle.clone(), element.clone())? { return Ok(self.new_bool(true)); } else { continue; @@ -1270,11 +1269,19 @@ impl VirtualMachine { } pub fn bool_eq(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult { - let eq = self._eq(a.clone(), b.clone())?; + let eq = self._eq(a, b)?; let value = objbool::boolval(self, eq)?; Ok(value) } + pub fn identical_or_equal(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult { + if a.is(b) { + Ok(true) + } else { + self.bool_eq(a.clone(), b.clone()) + } + } + pub fn bool_seq_lt(&self, a: PyObjectRef, b: PyObjectRef) -> PyResult> { let value = if objbool::boolval(self, self._lt(a.clone(), b.clone())?)? { Some(true)