Skip to content

Commit 455d8d6

Browse files
committed
Add itertools.combinations()
re: #1361
1 parent 53b3911 commit 455d8d6

File tree

2 files changed

+129
-0
lines changed

2 files changed

+129
-0
lines changed

tests/snippets/stdlib_itertools.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ def assert_matches_seq(it, seq):
301301
assert list(t[0]) == []
302302

303303
# itertools.product
304+
304305
it = itertools.product([1, 2], [3, 4])
305306
assert (1, 3) == next(it)
306307
assert (1, 4) == next(it)
@@ -321,3 +322,25 @@ def assert_matches_seq(it, seq):
321322
itertools.product(None)
322323
with assert_raises(TypeError):
323324
itertools.product([1, 2], repeat=None)
325+
326+
# itertools.combinations
327+
328+
it = itertools.combinations([1, 2, 3, 4], 2)
329+
assert list(it) == [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
330+
331+
it = itertools.combinations([1, 2, 3], 1)
332+
assert list(it) == [(1,), (2,), (3,)]
333+
334+
it = itertools.combinations([1, 2, 3], 2)
335+
assert next(it) == (1, 2)
336+
assert next(it) == (1, 3)
337+
assert next(it) == (2, 3)
338+
with assert_raises(StopIteration):
339+
next(it)
340+
341+
it = itertools.combinations([1, 2, 3], 4)
342+
with assert_raises(StopIteration):
343+
next(it)
344+
345+
with assert_raises(ValueError):
346+
itertools.combinations([1, 2, 3, 4], -2)

vm/src/stdlib/itertools.rs

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,108 @@ impl PyItertoolsProduct {
848848
}
849849
}
850850

851+
#[pyclass]
852+
#[derive(Debug)]
853+
struct PyItertoolsCombinations {
854+
pool: Vec<PyObjectRef>,
855+
indices: RefCell<Vec<usize>>,
856+
r: Cell<usize>,
857+
exhausted: Cell<bool>,
858+
}
859+
860+
impl PyValue for PyItertoolsCombinations {
861+
fn class(vm: &VirtualMachine) -> PyClassRef {
862+
vm.class("itertools", "combinations")
863+
}
864+
}
865+
866+
#[pyimpl]
867+
impl PyItertoolsCombinations {
868+
#[pyslot(new)]
869+
fn tp_new(
870+
cls: PyClassRef,
871+
iterable: PyObjectRef,
872+
r: isize,
873+
vm: &VirtualMachine,
874+
) -> PyResult<PyRef<Self>> {
875+
let iter = get_iter(vm, &iterable)?;
876+
let pool = get_all(vm, &iter)?;
877+
878+
if r < 0 {
879+
return Err(vm.new_value_error("r must be non-negative".to_string()));
880+
}
881+
882+
let n = pool.len();
883+
884+
PyItertoolsCombinations {
885+
pool,
886+
indices: RefCell::new((0..r as usize).collect()),
887+
r: Cell::new(r as usize),
888+
exhausted: Cell::new(r as usize > n),
889+
}
890+
.into_ref_with_type(vm, cls)
891+
}
892+
893+
#[pymethod(name = "__iter__")]
894+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
895+
zelf
896+
}
897+
898+
#[pymethod(name = "__next__")]
899+
fn next(&self, vm: &VirtualMachine) -> PyResult {
900+
// stop signal
901+
if self.exhausted.get() {
902+
return Err(new_stop_iteration(vm));
903+
}
904+
905+
let n = self.pool.len();
906+
let r = self.r.get();
907+
908+
let res = PyTuple::from(
909+
self.pool
910+
.iter()
911+
.enumerate()
912+
.filter(|(idx, _)| self.indices.borrow().contains(&idx))
913+
.map(|(_, num)| num.clone())
914+
.collect::<Vec<PyObjectRef>>(),
915+
);
916+
917+
let mut indices = self.indices.borrow_mut();
918+
let mut sentinel = false;
919+
920+
// Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
921+
let mut idx = r - 1;
922+
loop {
923+
if indices[idx] != idx + n - r {
924+
sentinel = true;
925+
break;
926+
}
927+
928+
if idx != 0 {
929+
idx -= 1;
930+
} else {
931+
break;
932+
}
933+
}
934+
// If no suitable index is found, then the indices are all at
935+
// their maximum value and we're done.
936+
if !sentinel {
937+
self.exhausted.set(true);
938+
}
939+
940+
// Increment the current index which we know is not at its
941+
// maximum. Then move back to the right setting each index
942+
// to its lowest possible value (one higher than the index
943+
// to its left -- this maintains the sort order invariant).
944+
indices[idx] += 1;
945+
for j in idx + 1..r {
946+
indices[j] = indices[j - 1] + 1;
947+
}
948+
949+
Ok(res.into_ref(vm).into_object())
950+
}
951+
}
952+
851953
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
852954
let ctx = &vm.ctx;
853955

@@ -858,6 +960,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
858960

859961
let compress = PyItertoolsCompress::make_class(ctx);
860962

963+
let combinations = ctx.new_class("combinations", ctx.object());
964+
PyItertoolsCombinations::extend_class(ctx, &combinations);
965+
861966
let count = ctx.new_class("count", ctx.object());
862967
PyItertoolsCount::extend_class(ctx, &count);
863968

@@ -887,6 +992,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
887992
"accumulate" => accumulate,
888993
"chain" => chain,
889994
"compress" => compress,
995+
"combinations" => combinations,
890996
"count" => count,
891997
"dropwhile" => dropwhile,
892998
"islice" => islice,

0 commit comments

Comments
 (0)