diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index 9368a8ed76..f133cfc04e 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -299,3 +299,25 @@ def assert_matches_seq(it, seq): t = itertools.tee([1,2,3]) assert list(t[0]) == [1,2,3] assert list(t[0]) == [] + +# itertools.product +it = itertools.product([1, 2], [3, 4]) +assert (1, 3) == next(it) +assert (1, 4) == next(it) +assert (2, 3) == next(it) +assert (2, 4) == next(it) +with assert_raises(StopIteration): + next(it) + +it = itertools.product([1, 2], repeat=2) +assert (1, 1) == next(it) +assert (1, 2) == next(it) +assert (2, 1) == next(it) +assert (2, 2) == next(it) +with assert_raises(StopIteration): + next(it) + +with assert_raises(TypeError): + itertools.product(None) +with assert_raises(TypeError): + itertools.product([1, 2], repeat=None) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 195d9d58a2..7b695b98b4 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -1,15 +1,16 @@ use std::cell::{Cell, RefCell}; use std::cmp::Ordering; +use std::iter; use std::ops::{AddAssign, SubAssign}; use std::rc::Rc; use num_bigint::BigInt; use num_traits::ToPrimitive; -use crate::function::{OptionalArg, PyFuncArgs}; +use crate::function::{Args, OptionalArg, PyFuncArgs}; use crate::obj::objbool; use crate::obj::objint::{self, PyInt, PyIntRef}; -use crate::obj::objiter::{call_next, get_iter, new_stop_iteration}; +use crate::obj::objiter::{call_next, get_all, get_iter, new_stop_iteration}; use crate::obj::objtuple::PyTuple; use crate::obj::objtype::{self, PyClassRef}; use crate::pyobject::{ @@ -736,6 +737,123 @@ impl PyItertoolsTee { } } +#[pyclass] +#[derive(Debug)] +struct PyIterToolsProduct { + pools: Vec>, + idxs: RefCell>, + cur: Cell, + stop: Cell, +} + +impl PyValue for PyIterToolsProduct { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "product") + } +} + +#[derive(FromArgs)] +struct ProductArgs { + #[pyarg(keyword_only, optional = true)] + repeat: OptionalArg, +} + +#[pyimpl] +impl PyIterToolsProduct { + #[pyslot(new)] + fn tp_new( + cls: PyClassRef, + iterables: Args, + args: ProductArgs, + vm: &VirtualMachine, + ) -> PyResult> { + let repeat = match args.repeat.into_option() { + Some(i) => i, + None => 1, + }; + + let mut pools = Vec::new(); + for arg in iterables.into_iter() { + let it = get_iter(vm, &arg)?; + let pool = get_all(vm, &it)?; + + pools.push(pool); + } + let pools = iter::repeat(pools) + .take(repeat) + .flatten() + .collect::>>(); + + let l = pools.len(); + + PyIterToolsProduct { + pools, + idxs: RefCell::new(vec![0; l]), + cur: Cell::new(l - 1), + stop: Cell::new(false), + } + .into_ref_with_type(vm, cls) + } + + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + // stop signal + if self.stop.get() { + return Err(new_stop_iteration(vm)); + } + + let pools = &self.pools; + + for p in pools { + if p.is_empty() { + return Err(new_stop_iteration(vm)); + } + } + + let res = PyTuple::from( + pools + .iter() + .zip(self.idxs.borrow().iter()) + .map(|(pool, idx)| pool[*idx].clone()) + .collect::>(), + ); + + self.update_idxs(); + + if self.is_end() { + self.stop.set(true); + } + + Ok(res.into_ref(vm).into_object()) + } + + fn is_end(&self) -> bool { + (self.idxs.borrow()[self.cur.get()] == &self.pools[self.cur.get()].len() - 1 + && self.cur.get() == 0) + } + + fn update_idxs(&self) { + let lst_idx = &self.pools[self.cur.get()].len() - 1; + + if self.idxs.borrow()[self.cur.get()] == lst_idx { + if self.is_end() { + return; + } + self.idxs.borrow_mut()[self.cur.get()] = 0; + self.cur.set(self.cur.get() - 1); + self.update_idxs(); + } else { + self.idxs.borrow_mut()[self.cur.get()] += 1; + self.cur.set(self.idxs.borrow().len() - 1); + } + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf + } +} + pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; @@ -767,6 +885,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let tee = ctx.new_class("tee", ctx.object()); PyItertoolsTee::extend_class(ctx, &tee); + let product = ctx.new_class("product", ctx.object()); + PyIterToolsProduct::extend_class(ctx, &product); py_module!(vm, "itertools", { "chain" => chain, @@ -780,5 +900,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "filterfalse" => filterfalse, "accumulate" => accumulate, "tee" => tee, + "product" => product, }) }