Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions tests/snippets/stdlib_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ def assert_matches_seq(it, seq):
assert list(t[0]) == []

# itertools.product

it = itertools.product([1, 2], [3, 4])
assert (1, 3) == next(it)
assert (1, 4) == next(it)
Expand All @@ -321,3 +322,25 @@ def assert_matches_seq(it, seq):
itertools.product(None)
with assert_raises(TypeError):
itertools.product([1, 2], repeat=None)

# itertools.combinations

it = itertools.combinations([1, 2, 3, 4], 2)
assert list(it) == [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]

it = itertools.combinations([1, 2, 3], 1)
assert list(it) == [(1,), (2,), (3,)]

it = itertools.combinations([1, 2, 3], 2)
assert next(it) == (1, 2)
assert next(it) == (1, 3)
assert next(it) == (2, 3)
with assert_raises(StopIteration):
next(it)

it = itertools.combinations([1, 2, 3], 4)
with assert_raises(StopIteration):
next(it)

with assert_raises(ValueError):
itertools.combinations([1, 2, 3, 4], -2)
144 changes: 127 additions & 17 deletions vm/src/stdlib/itertools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::ops::{AddAssign, SubAssign};
use std::rc::Rc;

use num_bigint::BigInt;
use num_traits::sign::Signed;
use num_traits::ToPrimitive;

use crate::function::{Args, OptionalArg, PyFuncArgs};
Expand Down Expand Up @@ -733,14 +734,14 @@ impl PyItertoolsTee {

#[pyclass]
#[derive(Debug)]
struct PyIterToolsProduct {
struct PyItertoolsProduct {
pools: Vec<Vec<PyObjectRef>>,
idxs: RefCell<Vec<usize>>,
cur: Cell<usize>,
stop: Cell<bool>,
}

impl PyValue for PyIterToolsProduct {
impl PyValue for PyItertoolsProduct {
fn class(vm: &VirtualMachine) -> PyClassRef {
vm.class("itertools", "product")
}
Expand All @@ -753,7 +754,7 @@ struct ProductArgs {
}

#[pyimpl]
impl PyIterToolsProduct {
impl PyItertoolsProduct {
#[pyslot(new)]
fn tp_new(
cls: PyClassRef,
Expand All @@ -780,7 +781,7 @@ impl PyIterToolsProduct {

let l = pools.len();

PyIterToolsProduct {
PyItertoolsProduct {
pools,
idxs: RefCell::new(vec![0; l]),
cur: Cell::new(l - 1),
Expand Down Expand Up @@ -848,19 +849,137 @@ impl PyIterToolsProduct {
}
}

#[pyclass]
#[derive(Debug)]
struct PyItertoolsCombinations {
pool: Vec<PyObjectRef>,
indices: RefCell<Vec<usize>>,
r: Cell<usize>,
exhausted: Cell<bool>,
}

impl PyValue for PyItertoolsCombinations {
fn class(vm: &VirtualMachine) -> PyClassRef {
vm.class("itertools", "combinations")
}
}

#[pyimpl]
impl PyItertoolsCombinations {
#[pyslot(new)]
fn tp_new(
cls: PyClassRef,
iterable: PyObjectRef,
r: PyIntRef,
vm: &VirtualMachine,
) -> PyResult<PyRef<Self>> {
let iter = get_iter(vm, &iterable)?;
let pool = get_all(vm, &iter)?;

let r = r.as_bigint();
if r.is_negative() {
return Err(vm.new_value_error("r must be non-negative".to_string()));
}
let r = r.to_usize().unwrap();

let n = pool.len();

PyItertoolsCombinations {
pool,
indices: RefCell::new((0..r).collect()),
r: Cell::new(r),
exhausted: Cell::new(r > n),
}
.into_ref_with_type(vm, cls)
}

#[pymethod(name = "__iter__")]
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
zelf
}

#[pymethod(name = "__next__")]
fn next(&self, vm: &VirtualMachine) -> PyResult {
// stop signal
if self.exhausted.get() {
return Err(new_stop_iteration(vm));
}

let n = self.pool.len();
let r = self.r.get();

let res = PyTuple::from(
self.pool
.iter()
.enumerate()
.filter(|(idx, _)| self.indices.borrow().contains(&idx))
.map(|(_, num)| num.clone())
.collect::<Vec<PyObjectRef>>(),
);

let mut indices = self.indices.borrow_mut();
let mut sentinel = false;

// Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
let mut idx = r - 1;
loop {
if indices[idx] != idx + n - r {
sentinel = true;
break;
}

if idx != 0 {
idx -= 1;
} else {
break;
}
}
// If no suitable index is found, then the indices are all at
// their maximum value and we're done.
if !sentinel {
self.exhausted.set(true);
}

// Increment the current index which we know is not at its
// maximum. Then move back to the right setting each index
// to its lowest possible value (one higher than the index
// to its left -- this maintains the sort order invariant).
indices[idx] += 1;
for j in idx + 1..r {
indices[j] = indices[j - 1] + 1;
}

Ok(res.into_ref(vm).into_object())
}
}

pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
let ctx = &vm.ctx;

let accumulate = ctx.new_class("accumulate", ctx.object());
PyItertoolsAccumulate::extend_class(ctx, &accumulate);

let chain = PyItertoolsChain::make_class(ctx);

let compress = PyItertoolsCompress::make_class(ctx);

let combinations = ctx.new_class("combinations", ctx.object());
PyItertoolsCombinations::extend_class(ctx, &combinations);

let count = ctx.new_class("count", ctx.object());
PyItertoolsCount::extend_class(ctx, &count);

let dropwhile = ctx.new_class("dropwhile", ctx.object());
PyItertoolsDropwhile::extend_class(ctx, &dropwhile);

let islice = PyItertoolsIslice::make_class(ctx);

let filterfalse = ctx.new_class("filterfalse", ctx.object());
PyItertoolsFilterFalse::extend_class(ctx, &filterfalse);

let product = ctx.new_class("product", ctx.object());
PyItertoolsProduct::extend_class(ctx, &product);

let repeat = ctx.new_class("repeat", ctx.object());
PyItertoolsRepeat::extend_class(ctx, &repeat);

Expand All @@ -869,30 +988,21 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
let takewhile = ctx.new_class("takewhile", ctx.object());
PyItertoolsTakewhile::extend_class(ctx, &takewhile);

let islice = PyItertoolsIslice::make_class(ctx);

let filterfalse = ctx.new_class("filterfalse", ctx.object());
PyItertoolsFilterFalse::extend_class(ctx, &filterfalse);

let accumulate = ctx.new_class("accumulate", ctx.object());
PyItertoolsAccumulate::extend_class(ctx, &accumulate);

let tee = ctx.new_class("tee", ctx.object());
PyItertoolsTee::extend_class(ctx, &tee);
let product = ctx.new_class("product", ctx.object());
PyIterToolsProduct::extend_class(ctx, &product);

py_module!(vm, "itertools", {
"accumulate" => accumulate,
"chain" => chain,
"compress" => compress,
"combinations" => combinations,
"count" => count,
"dropwhile" => dropwhile,
"islice" => islice,
"filterfalse" => filterfalse,
"repeat" => repeat,
"starmap" => starmap,
"takewhile" => takewhile,
"islice" => islice,
"filterfalse" => filterfalse,
"accumulate" => accumulate,
"tee" => tee,
"product" => product,
})
Expand Down