diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index b83c6cfb69..9368a8ed76 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -243,3 +243,59 @@ def assert_matches_seq(it, seq): assert list(itertools.compress("ABCDEF", [1,0,1])) == list("AC") assert list(itertools.compress("ABC", [0,1,1,1,1,1])) == list("BC") assert list(itertools.compress("ABCDEF", [True,False,"t","",1,9])) == list("ACEF") + + +# itertools.tee +t = itertools.tee([]) +assert len(t) == 2 +assert t[0] is not t[1] +assert list(t[0]) == list(t[1]) == [] + +with assert_raises(TypeError): + itertools.tee() + +t = itertools.tee(range(1000)) +assert list(t[0]) == list(t[1]) == list(range(1000)) + +t = itertools.tee([1,22,333], 3) +assert len(t) == 3 +assert 1 == next(t[0]) +assert 1 == next(t[1]) +assert 22 == next(t[0]) +assert 333 == next(t[0]) +assert 1 == next(t[2]) +assert 22 == next(t[1]) +assert 333 == next(t[1]) +assert 22 == next(t[2]) +with assert_raises(StopIteration): + next(t[0]) +assert 333 == next(t[2]) +with assert_raises(StopIteration): + next(t[2]) +with assert_raises(StopIteration): + next(t[1]) + +t0, t1 = itertools.tee([1,2,3]) +tc = t0.__copy__() +assert list(t0) == [1,2,3] +assert list(t1) == [1,2,3] +assert list(tc) == [1,2,3] + +t0, t1 = itertools.tee([1,2,3]) +assert 1 == next(t0) # advance index of t0 by 1 before __copy__() +t0c = t0.__copy__() +t1c = t1.__copy__() +assert list(t0) == [2,3] +assert list(t0c) == [2,3] +assert list(t1) == [1,2,3] +assert list(t1c) == [1,2,3] + +t0, t1 = itertools.tee([1,2,3]) +t2, t3 = itertools.tee(t0) +assert list(t1) == [1,2,3] +assert list(t2) == [1,2,3] +assert list(t3) == [1,2,3] + +t = itertools.tee([1,2,3]) +assert list(t[0]) == [1,2,3] +assert list(t[0]) == [] diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index ddac6cc263..721d54aa78 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -1,6 +1,7 @@ use std::cell::{Cell, RefCell}; use std::cmp::Ordering; use std::ops::{AddAssign, SubAssign}; +use std::rc::Rc; use num_bigint::BigInt; use num_traits::ToPrimitive; @@ -10,9 +11,12 @@ use crate::obj::objbool; use crate::obj::objint; use crate::obj::objint::{PyInt, PyIntRef}; use crate::obj::objiter::{call_next, get_iter, new_stop_iteration}; +use crate::obj::objtuple::PyTuple; use crate::obj::objtype; use crate::obj::objtype::PyClassRef; -use crate::pyobject::{IdProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{ + IdProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, +}; use crate::vm::VirtualMachine; #[pyclass(name = "chain")] @@ -650,6 +654,114 @@ impl PyItertoolsAccumulate { } } +#[derive(Debug)] +struct PyItertoolsTeeData { + iterable: PyObjectRef, + values: RefCell>, +} + +impl PyItertoolsTeeData { + fn new( + iterable: PyObjectRef, + vm: &VirtualMachine, + ) -> Result, PyObjectRef> { + Ok(Rc::new(PyItertoolsTeeData { + iterable: get_iter(vm, &iterable)?, + values: RefCell::new(vec![]), + })) + } + + fn get_item(&self, vm: &VirtualMachine, index: usize) -> PyResult { + if self.values.borrow().len() == index { + let result = call_next(vm, &self.iterable)?; + self.values.borrow_mut().push(result); + } + Ok(self.values.borrow()[index].clone()) + } +} + +#[pyclass] +#[derive(Debug)] +struct PyItertoolsTee { + tee_data: Rc, + index: Cell, +} + +impl PyValue for PyItertoolsTee { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "tee") + } +} + +#[pyimpl] +impl PyItertoolsTee { + fn from_iter(iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let it = get_iter(vm, &iterable)?; + if it.class().is(&PyItertoolsTee::class(vm)) { + return vm.call_method(&it, "__copy__", PyFuncArgs::from(vec![])); + } + Ok(PyItertoolsTee { + tee_data: PyItertoolsTeeData::new(it, vm)?, + index: Cell::from(0), + } + .into_ref_with_type(vm, PyItertoolsTee::class(vm))? + .into_object()) + } + + #[pymethod(name = "__new__")] + #[allow(clippy::new_ret_no_self)] + fn new( + _cls: PyClassRef, + iterable: PyObjectRef, + n: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult> { + let n = match n { + OptionalArg::Present(x) => match x.as_bigint().to_usize() { + Some(y) => y, + None => return Err(vm.new_overflow_error(String::from("n is too big"))), + }, + OptionalArg::Missing => 2, + }; + + let copyable = if objtype::class_has_attr(&iterable.class(), "__copy__") { + vm.call_method(&iterable, "__copy__", PyFuncArgs::from(vec![]))? + } else { + PyItertoolsTee::from_iter(iterable, vm)? + }; + + let mut tee_vec: Vec = Vec::with_capacity(n); + for _ in 0..n { + let no_args = PyFuncArgs::from(vec![]); + tee_vec.push(vm.call_method(©able, "__copy__", no_args)?); + } + + Ok(PyTuple::from(tee_vec).into_ref(vm)) + } + + #[pymethod(name = "__copy__")] + fn copy(&self, vm: &VirtualMachine) -> PyResult { + Ok(PyItertoolsTee { + tee_data: Rc::clone(&self.tee_data), + index: self.index.clone(), + } + .into_ref_with_type(vm, Self::class(vm))? + .into_object()) + } + + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + let value = self.tee_data.get_item(vm, self.index.get())?; + self.index.set(self.index.get() + 1); + Ok(value) + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf + } +} + pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; @@ -679,6 +791,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let accumulate = ctx.new_class("accumulate", ctx.object()); PyItertoolsAccumulate::extend_class(ctx, &accumulate); + let tee = ctx.new_class("tee", ctx.object()); + PyItertoolsTee::extend_class(ctx, &tee); + py_module!(vm, "itertools", { "chain" => chain, "compress" => compress, @@ -690,5 +805,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "islice" => islice, "filterfalse" => filterfalse, "accumulate" => accumulate, + "tee" => tee, }) }