Skip to content

Commit 038f486

Browse files
committed
Implement itertools.product
This implements `itertools.product` of standard library. Related with #1361
1 parent 40783d1 commit 038f486

File tree

2 files changed

+145
-2
lines changed

2 files changed

+145
-2
lines changed

tests/snippets/stdlib_itertools.py

+22
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,25 @@ def assert_matches_seq(it, seq):
299299
t = itertools.tee([1,2,3])
300300
assert list(t[0]) == [1,2,3]
301301
assert list(t[0]) == []
302+
303+
# itertools.product
304+
it = itertools.product([1, 2], [3, 4])
305+
assert (1, 3) == next(it)
306+
assert (1, 4) == next(it)
307+
assert (2, 3) == next(it)
308+
assert (2, 4) == next(it)
309+
with assert_raises(StopIteration):
310+
next(it)
311+
312+
it = itertools.product([1, 2], repeat=2)
313+
assert (1, 1) == next(it)
314+
assert (1, 2) == next(it)
315+
assert (2, 1) == next(it)
316+
assert (2, 2) == next(it)
317+
with assert_raises(StopIteration):
318+
next(it)
319+
320+
with assert_raises(TypeError):
321+
itertools.product(None)
322+
with assert_raises(TypeError):
323+
itertools.product([1, 2], repeat=None)

vm/src/stdlib/itertools.rs

+123-2
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
use std::cell::{Cell, RefCell};
22
use std::cmp::Ordering;
3+
use std::iter;
34
use std::ops::{AddAssign, SubAssign};
45
use std::rc::Rc;
56

67
use num_bigint::BigInt;
78
use num_traits::ToPrimitive;
89

9-
use crate::function::{OptionalArg, PyFuncArgs};
10+
use crate::function::{Args, OptionalArg, PyFuncArgs};
1011
use crate::obj::objbool;
1112
use crate::obj::objint::{self, PyInt, PyIntRef};
12-
use crate::obj::objiter::{call_next, get_iter, new_stop_iteration};
13+
use crate::obj::objiter::{call_next, get_all, get_iter, new_stop_iteration};
1314
use crate::obj::objtuple::PyTuple;
1415
use crate::obj::objtype::{self, PyClassRef};
1516
use crate::pyobject::{
@@ -736,6 +737,123 @@ impl PyItertoolsTee {
736737
}
737738
}
738739

740+
#[pyclass]
741+
#[derive(Debug)]
742+
struct PyIterToolsProduct {
743+
pools: Vec<Vec<PyObjectRef>>,
744+
idxs: RefCell<Vec<usize>>,
745+
cur: Cell<usize>,
746+
stop: Cell<bool>,
747+
}
748+
749+
impl PyValue for PyIterToolsProduct {
750+
fn class(vm: &VirtualMachine) -> PyClassRef {
751+
vm.class("itertools", "product")
752+
}
753+
}
754+
755+
#[derive(FromArgs)]
756+
struct ProductArgs {
757+
#[pyarg(keyword_only, optional = true)]
758+
repeat: OptionalArg<usize>,
759+
}
760+
761+
#[pyimpl]
762+
impl PyIterToolsProduct {
763+
#[pyslot(new)]
764+
fn new(
765+
cls: PyClassRef,
766+
iterables: Args<PyObjectRef>,
767+
args: ProductArgs,
768+
vm: &VirtualMachine,
769+
) -> PyResult<PyRef<Self>> {
770+
let repeat = match args.repeat.into_option() {
771+
Some(i) => i,
772+
None => 1,
773+
};
774+
775+
let mut pools = Vec::new();
776+
for arg in iterables.into_iter() {
777+
let it = get_iter(vm, &arg)?;
778+
let pool = get_all(vm, &it)?;
779+
780+
pools.push(pool);
781+
}
782+
let pools = iter::repeat(pools)
783+
.take(repeat)
784+
.flatten()
785+
.collect::<Vec<Vec<PyObjectRef>>>();
786+
787+
let l = pools.len();
788+
789+
PyIterToolsProduct {
790+
pools,
791+
idxs: RefCell::new(vec![0; l]),
792+
cur: Cell::new(l - 1),
793+
stop: Cell::new(false),
794+
}
795+
.into_ref_with_type(vm, cls)
796+
}
797+
798+
#[pymethod(name = "__next__")]
799+
fn next(&self, vm: &VirtualMachine) -> PyResult {
800+
// stop signal
801+
if self.stop.get() {
802+
return Err(new_stop_iteration(vm));
803+
}
804+
805+
let pools = &self.pools;
806+
807+
for p in pools {
808+
if p.is_empty() {
809+
return Err(new_stop_iteration(vm));
810+
}
811+
}
812+
813+
let res = PyTuple::from(
814+
pools
815+
.iter()
816+
.zip(self.idxs.borrow().iter())
817+
.map(|(pool, idx)| pool[*idx].clone())
818+
.collect::<Vec<PyObjectRef>>(),
819+
);
820+
821+
self.update_idxs();
822+
823+
if self.is_end() {
824+
self.stop.set(true);
825+
}
826+
827+
Ok(res.into_ref(vm).into_object())
828+
}
829+
830+
fn is_end(&self) -> bool {
831+
(self.idxs.borrow()[self.cur.get()] == &self.pools[self.cur.get()].len() - 1
832+
&& self.cur.get() == 0)
833+
}
834+
835+
fn update_idxs(&self) {
836+
let lst_idx = &self.pools[self.cur.get()].len() - 1;
837+
838+
if self.idxs.borrow()[self.cur.get()] == lst_idx {
839+
if self.is_end() {
840+
return;
841+
}
842+
self.idxs.borrow_mut()[self.cur.get()] = 0;
843+
self.cur.set(self.cur.get() - 1);
844+
self.update_idxs();
845+
} else {
846+
self.idxs.borrow_mut()[self.cur.get()] += 1;
847+
self.cur.set(self.idxs.borrow().len() - 1);
848+
}
849+
}
850+
851+
#[pymethod(name = "__iter__")]
852+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
853+
zelf
854+
}
855+
}
856+
739857
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
740858
let ctx = &vm.ctx;
741859

@@ -767,6 +885,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
767885

768886
let tee = ctx.new_class("tee", ctx.object());
769887
PyItertoolsTee::extend_class(ctx, &tee);
888+
let product = ctx.new_class("product", ctx.object());
889+
PyIterToolsProduct::extend_class(ctx, &product);
770890

771891
py_module!(vm, "itertools", {
772892
"chain" => chain,
@@ -780,5 +900,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
780900
"filterfalse" => filterfalse,
781901
"accumulate" => accumulate,
782902
"tee" => tee,
903+
"product" => product,
783904
})
784905
}

0 commit comments

Comments
 (0)