Skip to content

Commit b5b4835

Browse files
committed
Implement itertools.product
This implements `itertools.product` of standard library. Related with #1361
1 parent 40783d1 commit b5b4835

File tree

2 files changed

+166
-3
lines changed

2 files changed

+166
-3
lines changed

tests/snippets/stdlib_itertools.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,25 @@ def assert_matches_seq(it, seq):
299299
t = itertools.tee([1,2,3])
300300
assert list(t[0]) == [1,2,3]
301301
assert list(t[0]) == []
302+
303+
# itertools.product
304+
it = itertools.product([1, 2], [3, 4])
305+
assert (1, 3) == next(it)
306+
assert (1, 4) == next(it)
307+
assert (2, 3) == next(it)
308+
assert (2, 4) == next(it)
309+
with assert_raises(StopIteration):
310+
next(it)
311+
312+
it = itertools.product([1, 2], repeat=2)
313+
assert (1, 1) == next(it)
314+
assert (1, 2) == next(it)
315+
assert (2, 1) == next(it)
316+
assert (2, 2) == next(it)
317+
with assert_raises(StopIteration):
318+
next(it)
319+
320+
with assert_raises(TypeError):
321+
itertools.product(None)
322+
with assert_raises(TypeError):
323+
itertools.product([1, 2], repeat=None)

vm/src/stdlib/itertools.rs

Lines changed: 144 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
use std::cell::{Cell, RefCell};
22
use std::cmp::Ordering;
3+
use std::iter;
34
use std::ops::{AddAssign, SubAssign};
45
use std::rc::Rc;
56

6-
use num_bigint::BigInt;
7+
use num_bigint::{BigInt, Sign};
78
use num_traits::ToPrimitive;
89

9-
use crate::function::{OptionalArg, PyFuncArgs};
10+
use crate::function::{Args, OptionalArg, PyFuncArgs};
1011
use crate::obj::objbool;
1112
use crate::obj::objint::{self, PyInt, PyIntRef};
12-
use crate::obj::objiter::{call_next, get_iter, new_stop_iteration};
13+
use crate::obj::objiter::{call_next, get_all, get_iter, new_stop_iteration};
1314
use crate::obj::objtuple::PyTuple;
1415
use crate::obj::objtype::{self, PyClassRef};
1516
use crate::pyobject::{
@@ -736,6 +737,143 @@ impl PyItertoolsTee {
736737
}
737738
}
738739

740+
#[pyclass]
741+
#[derive(Debug)]
742+
struct PyIterToolsProduct {
743+
pools: Vec<Vec<PyObjectRef>>,
744+
idxs: RefCell<Vec<usize>>,
745+
cur: RefCell<usize>,
746+
sizes: Vec<usize>,
747+
stop: RefCell<bool>,
748+
}
749+
750+
impl PyValue for PyIterToolsProduct {
751+
fn class(vm: &VirtualMachine) -> PyClassRef {
752+
vm.class("itertools", "product")
753+
}
754+
}
755+
756+
#[derive(FromArgs)]
757+
struct ProductArgs {
758+
#[pyarg(keyword_only, optional = true)]
759+
repeat: OptionalArg<PyIntRef>,
760+
}
761+
762+
#[pyimpl]
763+
impl PyIterToolsProduct {
764+
#[pyslot(new)]
765+
fn new(
766+
cls: PyClassRef,
767+
iterables: Args<PyObjectRef>,
768+
args: ProductArgs,
769+
vm: &VirtualMachine,
770+
) -> PyResult<PyRef<Self>> {
771+
let repeat = match args.repeat.into_option() {
772+
Some(int) => match int.as_bigint().sign() {
773+
Sign::Plus | Sign::NoSign => match int.as_bigint().to_usize() {
774+
Some(x) => x,
775+
None => {
776+
return Err(vm.new_overflow_error("repeat argument too large".to_string()))
777+
}
778+
},
779+
Sign::Minus => {
780+
return Err(vm.new_value_error("repeat argument cannot be negative".to_string()))
781+
}
782+
},
783+
None => 1,
784+
};
785+
786+
let mut pools = Vec::new();
787+
let mut sizes = Vec::new();
788+
for arg in iterables.into_iter() {
789+
let it = get_iter(vm, &arg)?;
790+
let pool = get_all(vm, &it)?;
791+
let size = pool.len();
792+
793+
pools.push(pool);
794+
sizes.push(size);
795+
}
796+
let pools = iter::repeat(pools)
797+
.take(repeat)
798+
.flatten()
799+
.collect::<Vec<Vec<PyObjectRef>>>();
800+
801+
let sizes = iter::repeat(sizes)
802+
.take(repeat)
803+
.flatten()
804+
.collect::<Vec<usize>>();
805+
806+
let l = pools.len();
807+
808+
PyIterToolsProduct {
809+
pools,
810+
idxs: RefCell::new(vec![0; l]),
811+
cur: RefCell::new(l - 1),
812+
sizes,
813+
stop: RefCell::new(false),
814+
}
815+
.into_ref_with_type(vm, cls)
816+
}
817+
818+
#[pymethod(name = "__next__")]
819+
fn next(&self, vm: &VirtualMachine) -> PyResult {
820+
// stop signal
821+
if *self.stop.borrow() {
822+
return Err(new_stop_iteration(vm));
823+
}
824+
825+
let pools = &self.pools;
826+
827+
for s in &self.sizes {
828+
if *s == 0 {
829+
return Err(new_stop_iteration(vm));
830+
}
831+
}
832+
833+
let res = PyTuple::from(
834+
pools
835+
.iter()
836+
.zip(self.idxs.borrow().iter())
837+
.map(|(pool, idx)| pool[*idx].clone())
838+
.collect::<Vec<PyObjectRef>>(),
839+
);
840+
841+
self.update_idxs();
842+
843+
if self.is_end() {
844+
*self.stop.borrow_mut() = true;
845+
}
846+
847+
Ok(res.into_ref(vm).into_object())
848+
}
849+
850+
fn is_end(&self) -> bool {
851+
(self.idxs.borrow()[*self.cur.borrow()] == &self.sizes[*self.cur.borrow()] - 1
852+
&& *self.cur.borrow() == 0)
853+
}
854+
855+
fn update_idxs(&self) {
856+
let lst_idx = &self.sizes[*self.cur.borrow()] - 1;
857+
858+
if self.idxs.borrow()[*self.cur.borrow()] == lst_idx {
859+
if self.is_end() {
860+
return;
861+
}
862+
self.idxs.borrow_mut()[*self.cur.borrow()] = 0;
863+
*self.cur.borrow_mut() -= 1;
864+
self.update_idxs();
865+
} else {
866+
self.idxs.borrow_mut()[*self.cur.borrow()] += 1;
867+
*self.cur.borrow_mut() = self.idxs.borrow().len() - 1;
868+
}
869+
}
870+
871+
#[pymethod(name = "__iter__")]
872+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
873+
zelf
874+
}
875+
}
876+
739877
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
740878
let ctx = &vm.ctx;
741879

@@ -767,6 +905,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
767905

768906
let tee = ctx.new_class("tee", ctx.object());
769907
PyItertoolsTee::extend_class(ctx, &tee);
908+
let product = ctx.new_class("product", ctx.object());
909+
PyIterToolsProduct::extend_class(ctx, &product);
770910

771911
py_module!(vm, "itertools", {
772912
"chain" => chain,
@@ -780,5 +920,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
780920
"filterfalse" => filterfalse,
781921
"accumulate" => accumulate,
782922
"tee" => tee,
923+
"product" => product,
783924
})
784925
}

0 commit comments

Comments
 (0)