Skip to content

Commit 389db55

Browse files
authored
Merge pull request #1603 from dralley/itertools
Add itertools.combinations()
2 parents aa1af71 + 16b2b42 commit 389db55

File tree

2 files changed

+150
-17
lines changed

2 files changed

+150
-17
lines changed

tests/snippets/stdlib_itertools.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ def assert_matches_seq(it, seq):
301301
assert list(t[0]) == []
302302

303303
# itertools.product
304+
304305
it = itertools.product([1, 2], [3, 4])
305306
assert (1, 3) == next(it)
306307
assert (1, 4) == next(it)
@@ -321,3 +322,25 @@ def assert_matches_seq(it, seq):
321322
itertools.product(None)
322323
with assert_raises(TypeError):
323324
itertools.product([1, 2], repeat=None)
325+
326+
# itertools.combinations
327+
328+
it = itertools.combinations([1, 2, 3, 4], 2)
329+
assert list(it) == [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
330+
331+
it = itertools.combinations([1, 2, 3], 1)
332+
assert list(it) == [(1,), (2,), (3,)]
333+
334+
it = itertools.combinations([1, 2, 3], 2)
335+
assert next(it) == (1, 2)
336+
assert next(it) == (1, 3)
337+
assert next(it) == (2, 3)
338+
with assert_raises(StopIteration):
339+
next(it)
340+
341+
it = itertools.combinations([1, 2, 3], 4)
342+
with assert_raises(StopIteration):
343+
next(it)
344+
345+
with assert_raises(ValueError):
346+
itertools.combinations([1, 2, 3, 4], -2)

vm/src/stdlib/itertools.rs

Lines changed: 127 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::ops::{AddAssign, SubAssign};
55
use std::rc::Rc;
66

77
use num_bigint::BigInt;
8+
use num_traits::sign::Signed;
89
use num_traits::ToPrimitive;
910

1011
use crate::function::{Args, OptionalArg, PyFuncArgs};
@@ -733,14 +734,14 @@ impl PyItertoolsTee {
733734

734735
#[pyclass]
735736
#[derive(Debug)]
736-
struct PyIterToolsProduct {
737+
struct PyItertoolsProduct {
737738
pools: Vec<Vec<PyObjectRef>>,
738739
idxs: RefCell<Vec<usize>>,
739740
cur: Cell<usize>,
740741
stop: Cell<bool>,
741742
}
742743

743-
impl PyValue for PyIterToolsProduct {
744+
impl PyValue for PyItertoolsProduct {
744745
fn class(vm: &VirtualMachine) -> PyClassRef {
745746
vm.class("itertools", "product")
746747
}
@@ -753,7 +754,7 @@ struct ProductArgs {
753754
}
754755

755756
#[pyimpl]
756-
impl PyIterToolsProduct {
757+
impl PyItertoolsProduct {
757758
#[pyslot(new)]
758759
fn tp_new(
759760
cls: PyClassRef,
@@ -780,7 +781,7 @@ impl PyIterToolsProduct {
780781

781782
let l = pools.len();
782783

783-
PyIterToolsProduct {
784+
PyItertoolsProduct {
784785
pools,
785786
idxs: RefCell::new(vec![0; l]),
786787
cur: Cell::new(l - 1),
@@ -848,19 +849,137 @@ impl PyIterToolsProduct {
848849
}
849850
}
850851

852+
#[pyclass]
853+
#[derive(Debug)]
854+
struct PyItertoolsCombinations {
855+
pool: Vec<PyObjectRef>,
856+
indices: RefCell<Vec<usize>>,
857+
r: Cell<usize>,
858+
exhausted: Cell<bool>,
859+
}
860+
861+
impl PyValue for PyItertoolsCombinations {
862+
fn class(vm: &VirtualMachine) -> PyClassRef {
863+
vm.class("itertools", "combinations")
864+
}
865+
}
866+
867+
#[pyimpl]
868+
impl PyItertoolsCombinations {
869+
#[pyslot(new)]
870+
fn tp_new(
871+
cls: PyClassRef,
872+
iterable: PyObjectRef,
873+
r: PyIntRef,
874+
vm: &VirtualMachine,
875+
) -> PyResult<PyRef<Self>> {
876+
let iter = get_iter(vm, &iterable)?;
877+
let pool = get_all(vm, &iter)?;
878+
879+
let r = r.as_bigint();
880+
if r.is_negative() {
881+
return Err(vm.new_value_error("r must be non-negative".to_string()));
882+
}
883+
let r = r.to_usize().unwrap();
884+
885+
let n = pool.len();
886+
887+
PyItertoolsCombinations {
888+
pool,
889+
indices: RefCell::new((0..r).collect()),
890+
r: Cell::new(r),
891+
exhausted: Cell::new(r > n),
892+
}
893+
.into_ref_with_type(vm, cls)
894+
}
895+
896+
#[pymethod(name = "__iter__")]
897+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
898+
zelf
899+
}
900+
901+
#[pymethod(name = "__next__")]
902+
fn next(&self, vm: &VirtualMachine) -> PyResult {
903+
// stop signal
904+
if self.exhausted.get() {
905+
return Err(new_stop_iteration(vm));
906+
}
907+
908+
let n = self.pool.len();
909+
let r = self.r.get();
910+
911+
let res = PyTuple::from(
912+
self.pool
913+
.iter()
914+
.enumerate()
915+
.filter(|(idx, _)| self.indices.borrow().contains(&idx))
916+
.map(|(_, num)| num.clone())
917+
.collect::<Vec<PyObjectRef>>(),
918+
);
919+
920+
let mut indices = self.indices.borrow_mut();
921+
let mut sentinel = false;
922+
923+
// Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
924+
let mut idx = r - 1;
925+
loop {
926+
if indices[idx] != idx + n - r {
927+
sentinel = true;
928+
break;
929+
}
930+
931+
if idx != 0 {
932+
idx -= 1;
933+
} else {
934+
break;
935+
}
936+
}
937+
// If no suitable index is found, then the indices are all at
938+
// their maximum value and we're done.
939+
if !sentinel {
940+
self.exhausted.set(true);
941+
}
942+
943+
// Increment the current index which we know is not at its
944+
// maximum. Then move back to the right setting each index
945+
// to its lowest possible value (one higher than the index
946+
// to its left -- this maintains the sort order invariant).
947+
indices[idx] += 1;
948+
for j in idx + 1..r {
949+
indices[j] = indices[j - 1] + 1;
950+
}
951+
952+
Ok(res.into_ref(vm).into_object())
953+
}
954+
}
955+
851956
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
852957
let ctx = &vm.ctx;
853958

959+
let accumulate = ctx.new_class("accumulate", ctx.object());
960+
PyItertoolsAccumulate::extend_class(ctx, &accumulate);
961+
854962
let chain = PyItertoolsChain::make_class(ctx);
855963

856964
let compress = PyItertoolsCompress::make_class(ctx);
857965

966+
let combinations = ctx.new_class("combinations", ctx.object());
967+
PyItertoolsCombinations::extend_class(ctx, &combinations);
968+
858969
let count = ctx.new_class("count", ctx.object());
859970
PyItertoolsCount::extend_class(ctx, &count);
860971

861972
let dropwhile = ctx.new_class("dropwhile", ctx.object());
862973
PyItertoolsDropwhile::extend_class(ctx, &dropwhile);
863974

975+
let islice = PyItertoolsIslice::make_class(ctx);
976+
977+
let filterfalse = ctx.new_class("filterfalse", ctx.object());
978+
PyItertoolsFilterFalse::extend_class(ctx, &filterfalse);
979+
980+
let product = ctx.new_class("product", ctx.object());
981+
PyItertoolsProduct::extend_class(ctx, &product);
982+
864983
let repeat = ctx.new_class("repeat", ctx.object());
865984
PyItertoolsRepeat::extend_class(ctx, &repeat);
866985

@@ -869,30 +988,21 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
869988
let takewhile = ctx.new_class("takewhile", ctx.object());
870989
PyItertoolsTakewhile::extend_class(ctx, &takewhile);
871990

872-
let islice = PyItertoolsIslice::make_class(ctx);
873-
874-
let filterfalse = ctx.new_class("filterfalse", ctx.object());
875-
PyItertoolsFilterFalse::extend_class(ctx, &filterfalse);
876-
877-
let accumulate = ctx.new_class("accumulate", ctx.object());
878-
PyItertoolsAccumulate::extend_class(ctx, &accumulate);
879-
880991
let tee = ctx.new_class("tee", ctx.object());
881992
PyItertoolsTee::extend_class(ctx, &tee);
882-
let product = ctx.new_class("product", ctx.object());
883-
PyIterToolsProduct::extend_class(ctx, &product);
884993

885994
py_module!(vm, "itertools", {
995+
"accumulate" => accumulate,
886996
"chain" => chain,
887997
"compress" => compress,
998+
"combinations" => combinations,
888999
"count" => count,
8891000
"dropwhile" => dropwhile,
1001+
"islice" => islice,
1002+
"filterfalse" => filterfalse,
8901003
"repeat" => repeat,
8911004
"starmap" => starmap,
8921005
"takewhile" => takewhile,
893-
"islice" => islice,
894-
"filterfalse" => filterfalse,
895-
"accumulate" => accumulate,
8961006
"tee" => tee,
8971007
"product" => product,
8981008
})

0 commit comments

Comments
 (0)