Skip to content

Commit 11e8e62

Browse files
Merge pull request #1526 from seeeturtle/itertools
Implement itertools.product
2 parents a828249 + 8deb936 commit 11e8e62

File tree

2 files changed

+145
-2
lines changed

2 files changed

+145
-2
lines changed

tests/snippets/stdlib_itertools.py

+22
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,25 @@ def assert_matches_seq(it, seq):
299299
t = itertools.tee([1,2,3])
300300
assert list(t[0]) == [1,2,3]
301301
assert list(t[0]) == []
302+
303+
# itertools.product
304+
it = itertools.product([1, 2], [3, 4])
305+
assert (1, 3) == next(it)
306+
assert (1, 4) == next(it)
307+
assert (2, 3) == next(it)
308+
assert (2, 4) == next(it)
309+
with assert_raises(StopIteration):
310+
next(it)
311+
312+
it = itertools.product([1, 2], repeat=2)
313+
assert (1, 1) == next(it)
314+
assert (1, 2) == next(it)
315+
assert (2, 1) == next(it)
316+
assert (2, 2) == next(it)
317+
with assert_raises(StopIteration):
318+
next(it)
319+
320+
with assert_raises(TypeError):
321+
itertools.product(None)
322+
with assert_raises(TypeError):
323+
itertools.product([1, 2], repeat=None)

vm/src/stdlib/itertools.rs

+123-2
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
use std::cell::{Cell, RefCell};
22
use std::cmp::Ordering;
3+
use std::iter;
34
use std::ops::{AddAssign, SubAssign};
45
use std::rc::Rc;
56

67
use num_bigint::BigInt;
78
use num_traits::ToPrimitive;
89

9-
use crate::function::{OptionalArg, PyFuncArgs};
10+
use crate::function::{Args, OptionalArg, PyFuncArgs};
1011
use crate::obj::objbool;
1112
use crate::obj::objint::{self, PyInt, PyIntRef};
12-
use crate::obj::objiter::{call_next, get_iter, new_stop_iteration};
13+
use crate::obj::objiter::{call_next, get_all, get_iter, new_stop_iteration};
1314
use crate::obj::objtuple::PyTuple;
1415
use crate::obj::objtype::{self, PyClassRef};
1516
use crate::pyobject::{
@@ -730,6 +731,123 @@ impl PyItertoolsTee {
730731
}
731732
}
732733

734+
#[pyclass]
735+
#[derive(Debug)]
736+
struct PyIterToolsProduct {
737+
pools: Vec<Vec<PyObjectRef>>,
738+
idxs: RefCell<Vec<usize>>,
739+
cur: Cell<usize>,
740+
stop: Cell<bool>,
741+
}
742+
743+
impl PyValue for PyIterToolsProduct {
744+
fn class(vm: &VirtualMachine) -> PyClassRef {
745+
vm.class("itertools", "product")
746+
}
747+
}
748+
749+
#[derive(FromArgs)]
750+
struct ProductArgs {
751+
#[pyarg(keyword_only, optional = true)]
752+
repeat: OptionalArg<usize>,
753+
}
754+
755+
#[pyimpl]
756+
impl PyIterToolsProduct {
757+
#[pyslot(new)]
758+
fn tp_new(
759+
cls: PyClassRef,
760+
iterables: Args<PyObjectRef>,
761+
args: ProductArgs,
762+
vm: &VirtualMachine,
763+
) -> PyResult<PyRef<Self>> {
764+
let repeat = match args.repeat.into_option() {
765+
Some(i) => i,
766+
None => 1,
767+
};
768+
769+
let mut pools = Vec::new();
770+
for arg in iterables.into_iter() {
771+
let it = get_iter(vm, &arg)?;
772+
let pool = get_all(vm, &it)?;
773+
774+
pools.push(pool);
775+
}
776+
let pools = iter::repeat(pools)
777+
.take(repeat)
778+
.flatten()
779+
.collect::<Vec<Vec<PyObjectRef>>>();
780+
781+
let l = pools.len();
782+
783+
PyIterToolsProduct {
784+
pools,
785+
idxs: RefCell::new(vec![0; l]),
786+
cur: Cell::new(l - 1),
787+
stop: Cell::new(false),
788+
}
789+
.into_ref_with_type(vm, cls)
790+
}
791+
792+
#[pymethod(name = "__next__")]
793+
fn next(&self, vm: &VirtualMachine) -> PyResult {
794+
// stop signal
795+
if self.stop.get() {
796+
return Err(new_stop_iteration(vm));
797+
}
798+
799+
let pools = &self.pools;
800+
801+
for p in pools {
802+
if p.is_empty() {
803+
return Err(new_stop_iteration(vm));
804+
}
805+
}
806+
807+
let res = PyTuple::from(
808+
pools
809+
.iter()
810+
.zip(self.idxs.borrow().iter())
811+
.map(|(pool, idx)| pool[*idx].clone())
812+
.collect::<Vec<PyObjectRef>>(),
813+
);
814+
815+
self.update_idxs();
816+
817+
if self.is_end() {
818+
self.stop.set(true);
819+
}
820+
821+
Ok(res.into_ref(vm).into_object())
822+
}
823+
824+
fn is_end(&self) -> bool {
825+
(self.idxs.borrow()[self.cur.get()] == &self.pools[self.cur.get()].len() - 1
826+
&& self.cur.get() == 0)
827+
}
828+
829+
fn update_idxs(&self) {
830+
let lst_idx = &self.pools[self.cur.get()].len() - 1;
831+
832+
if self.idxs.borrow()[self.cur.get()] == lst_idx {
833+
if self.is_end() {
834+
return;
835+
}
836+
self.idxs.borrow_mut()[self.cur.get()] = 0;
837+
self.cur.set(self.cur.get() - 1);
838+
self.update_idxs();
839+
} else {
840+
self.idxs.borrow_mut()[self.cur.get()] += 1;
841+
self.cur.set(self.idxs.borrow().len() - 1);
842+
}
843+
}
844+
845+
#[pymethod(name = "__iter__")]
846+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
847+
zelf
848+
}
849+
}
850+
733851
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
734852
let ctx = &vm.ctx;
735853

@@ -761,6 +879,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
761879

762880
let tee = ctx.new_class("tee", ctx.object());
763881
PyItertoolsTee::extend_class(ctx, &tee);
882+
let product = ctx.new_class("product", ctx.object());
883+
PyIterToolsProduct::extend_class(ctx, &product);
764884

765885
py_module!(vm, "itertools", {
766886
"chain" => chain,
@@ -774,5 +894,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
774894
"filterfalse" => filterfalse,
775895
"accumulate" => accumulate,
776896
"tee" => tee,
897+
"product" => product,
777898
})
778899
}

0 commit comments

Comments
 (0)