From 301e5a99ed564ca438e6c4638dccd9ad8ba1ccc6 Mon Sep 17 00:00:00 2001 From: j30ng Date: Fri, 4 Oct 2019 02:59:46 +0900 Subject: [PATCH 1/5] Implement built-in itertools.tee --- vm/src/stdlib/itertools.rs | 118 ++++++++++++++++++++++++++++++++++++- 1 file changed, 117 insertions(+), 1 deletion(-) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index ddac6cc263..54188c731a 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -1,6 +1,7 @@ use std::cell::{Cell, RefCell}; use std::cmp::Ordering; use std::ops::{AddAssign, SubAssign}; +use std::rc::Rc; use num_bigint::BigInt; use num_traits::ToPrimitive; @@ -10,9 +11,12 @@ use crate::obj::objbool; use crate::obj::objint; use crate::obj::objint::{PyInt, PyIntRef}; use crate::obj::objiter::{call_next, get_iter, new_stop_iteration}; +use crate::obj::objtuple::PyTuple; use crate::obj::objtype; use crate::obj::objtype::PyClassRef; -use crate::pyobject::{IdProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{ + IdProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, +}; use crate::vm::VirtualMachine; #[pyclass(name = "chain")] @@ -650,6 +654,114 @@ impl PyItertoolsAccumulate { } } +#[derive(Debug)] +struct PyItertoolsTeeData { + iterable: PyObjectRef, + values: RefCell>, +} + +impl PyItertoolsTeeData { + fn new( + iterable: PyObjectRef, + vm: &VirtualMachine, + ) -> Result, PyObjectRef> { + Ok(Rc::new(PyItertoolsTeeData { + iterable: get_iter(vm, &iterable)?, + values: RefCell::new(vec![]), + })) + } + + fn get_item(&self, vm: &VirtualMachine, index: usize) -> PyResult { + if self.values.borrow().len() == index { + let result = call_next(vm, &self.iterable)?; + self.values.borrow_mut().push(result); + } + Ok(self.values.borrow()[index].clone()) + } +} + +#[pyclass] +#[derive(Debug)] +struct PyItertoolsTee { + tee_data: Rc, + index: Cell, +} + +impl PyValue for PyItertoolsTee { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "tee") + } +} + +#[pyimpl] +impl PyItertoolsTee { + fn from_iter(iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let it = get_iter(vm, &iterable)?; + if it.class().is(&PyItertoolsTee::class(vm)) { + return vm.call_method(&it, "__copy__", PyFuncArgs::from(vec![])); + } + Ok(PyItertoolsTee { + tee_data: PyItertoolsTeeData::new(it, vm)?, + index: Cell::from(0), + } + .into_ref_with_type(vm, PyItertoolsTee::class(vm))? + .into_object()) + } + + #[pymethod(name = "__new__")] + #[allow(clippy::new_ret_no_self)] + fn new( + _cls: PyClassRef, + iterable: PyObjectRef, + n: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult> { + let n = match n { + OptionalArg::Present(x) => match x.as_bigint().to_usize() { + Some(y) => y, + None => return Err(vm.new_overflow_error(String::from("n is too big"))), + }, + OptionalArg::Missing => 2, + }; + + let copyable = if objtype::class_has_attr(&iterable.class(), "__copy__") { + vm.call_method(&iterable, "__copy__", PyFuncArgs::from(vec![]))? + } else { + PyItertoolsTee::from_iter(iterable, vm)? + }; + + let mut tee_vec: Vec = Vec::with_capacity(n); + for _ in 0..n { + let no_args = PyFuncArgs::from(vec![]); + tee_vec.push(vm.call_method(©able, "__copy__", no_args)?); + } + + Ok(PyTuple::from(tee_vec).into_ref(vm)) + } + + #[pymethod(name = "__copy__")] + fn copy(&self, vm: &VirtualMachine) -> PyResult { + Ok(PyItertoolsTee { + tee_data: self.tee_data.clone(), + index: self.index.clone(), + } + .into_ref_with_type(vm, Self::class(vm))? + .into_object()) + } + + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + let result = self.tee_data.get_item(vm, self.index.get()); + self.index.set(self.index.get() + 1); + result + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf + } +} + pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; @@ -679,6 +791,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let accumulate = ctx.new_class("accumulate", ctx.object()); PyItertoolsAccumulate::extend_class(ctx, &accumulate); + let tee = ctx.new_class("tee", ctx.object()); + PyItertoolsTee::extend_class(ctx, &tee); + py_module!(vm, "itertools", { "chain" => chain, "compress" => compress, @@ -690,5 +805,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "islice" => islice, "filterfalse" => filterfalse, "accumulate" => accumulate, + "tee" => tee, }) } From 048ec3e0ee9792460cc3162a8842fc9d38baefdb Mon Sep 17 00:00:00 2001 From: j30ng Date: Fri, 4 Oct 2019 03:33:47 +0900 Subject: [PATCH 2/5] Add Testcases --- tests/snippets/stdlib_itertools.py | 31 ++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index b83c6cfb69..4cc28b489b 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -243,3 +243,34 @@ def assert_matches_seq(it, seq): assert list(itertools.compress("ABCDEF", [1,0,1])) == list("AC") assert list(itertools.compress("ABC", [0,1,1,1,1,1])) == list("BC") assert list(itertools.compress("ABCDEF", [True,False,"t","",1,9])) == list("ACEF") + + +# itertools.tee +t = itertools.tee([]) +assert len(t) == 2 +assert t[0] is not t[1] +assert list(t[0]) == list(t[1]) == [] + +with assert_raises(TypeError): + itertools.tee() + +t = itertools.tee(range(1000)) +assert list(t[0]) == list(t[1]) == list(range(1000)) + +t = itertools.tee([1,22,333], 3) +assert len(t) == 3 +assert 1 == next(t[0]) +assert 1 == next(t[1]) +assert 22 == next(t[0]) +assert 333 == next(t[0]) +assert 1 == next(t[2]) +assert 22 == next(t[1]) +assert 333 == next(t[1]) +assert 22 == next(t[2]) +with assert_raises(StopIteration): + next(t[0]) +assert 333 == next(t[2]) +with assert_raises(StopIteration): + next(t[2]) +with assert_raises(StopIteration): + next(t[1]) From 5351b7d98d87a1fe333e3bee67518fe178e97e8f Mon Sep 17 00:00:00 2001 From: j30ng Date: Fri, 4 Oct 2019 04:03:34 +0900 Subject: [PATCH 3/5] Use Rc::clone(&rc) instead of rc.clone() --- vm/src/stdlib/itertools.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 54188c731a..eefe6e354d 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -742,7 +742,7 @@ impl PyItertoolsTee { #[pymethod(name = "__copy__")] fn copy(&self, vm: &VirtualMachine) -> PyResult { Ok(PyItertoolsTee { - tee_data: self.tee_data.clone(), + tee_data: Rc::clone(&self.tee_data), index: self.index.clone(), } .into_ref_with_type(vm, Self::class(vm))? From 418de0f62f637f0a969b482e331bbd06d0206317 Mon Sep 17 00:00:00 2001 From: j30ng Date: Fri, 4 Oct 2019 15:44:57 +0900 Subject: [PATCH 4/5] Add Tests * Test `__copy__` method for `tee` objects works properly. * Cover the case where the iterable argument passed to `itertools.tee()` has `__copy__` method implemented. --- tests/snippets/stdlib_itertools.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index 4cc28b489b..13cb7cdbec 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -274,3 +274,24 @@ def assert_matches_seq(it, seq): next(t[2]) with assert_raises(StopIteration): next(t[1]) + +t0, t1 = itertools.tee([1,2,3]) +tc = t0.__copy__() +assert list(t0) == [1,2,3] +assert list(t1) == [1,2,3] +assert list(tc) == [1,2,3] + +t0, t1 = itertools.tee([1,2,3]) +assert 1 == next(t0) # advance index of t0 by 1 before __copy__() +t0c = t0.__copy__() +t1c = t1.__copy__() +assert list(t0) == [2,3] +assert list(t0c) == [2,3] +assert list(t1) == [1,2,3] +assert list(t1c) == [1,2,3] + +t0, t1 = itertools.tee([1,2,3]) +t2, t3 = itertools.tee(t0) +assert list(t1) == [1,2,3] +assert list(t2) == [1,2,3] +assert list(t3) == [1,2,3] From 0b7da124637e1f424d59449e054925c4197d966c Mon Sep 17 00:00:00 2001 From: j30ng Date: Fri, 4 Oct 2019 16:11:43 +0900 Subject: [PATCH 5/5] Fix Bug in `tee.__next__` + Add Test --- tests/snippets/stdlib_itertools.py | 4 ++++ vm/src/stdlib/itertools.rs | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index 13cb7cdbec..9368a8ed76 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -295,3 +295,7 @@ def assert_matches_seq(it, seq): assert list(t1) == [1,2,3] assert list(t2) == [1,2,3] assert list(t3) == [1,2,3] + +t = itertools.tee([1,2,3]) +assert list(t[0]) == [1,2,3] +assert list(t[0]) == [] diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index eefe6e354d..721d54aa78 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -751,9 +751,9 @@ impl PyItertoolsTee { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - let result = self.tee_data.get_item(vm, self.index.get()); + let value = self.tee_data.get_item(vm, self.index.get())?; self.index.set(self.index.get() + 1); - result + Ok(value) } #[pymethod(name = "__iter__")]