From 4bbca2bed28d020cc6c2d66e6ead7912bedbe552 Mon Sep 17 00:00:00 2001 From: Daniel Alley Date: Wed, 25 Dec 2019 22:07:02 -0500 Subject: [PATCH] Add itertools.combinations_with_replacement() --- tests/snippets/stdlib_itertools.py | 19 ++++++ vm/src/stdlib/itertools.rs | 99 ++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+) diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index b8bc0ae3f6..05376d4b05 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -405,6 +405,25 @@ def assert_matches_seq(it, seq): with assert_raises(TypeError): itertools.combinations([1, 2, 3, 4], None) +# itertools.combinations +it = itertools.combinations_with_replacement([1, 2, 3], 0) +assert list(it) == [()] + +it = itertools.combinations_with_replacement([1, 2, 3], 1) +assert list(it) == [(1,), (2,), (3,)] + +it = itertools.combinations_with_replacement([1, 2, 3], 2) +assert list(it) == [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)] + +it = itertools.combinations_with_replacement([1, 2], 3) +assert list(it) == [(1, 1, 1), (1, 1, 2), (1, 2, 2), (2, 2, 2)] + +with assert_raises(ValueError): + itertools.combinations_with_replacement([1, 2, 3, 4], -2) + +with assert_raises(TypeError): + itertools.combinations_with_replacement([1, 2, 3, 4], None) + # itertools.permutations it = itertools.permutations([1, 2, 3]) assert list(it) == [(1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1)] diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 84b690eb8c..1247b38a87 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -1031,6 +1031,100 @@ impl PyItertoolsCombinations { } } +#[pyclass] +#[derive(Debug)] +struct PyItertoolsCombinationsWithReplacement { + pool: Vec, + indices: RefCell>, + r: Cell, + exhausted: Cell, +} + +impl PyValue for PyItertoolsCombinationsWithReplacement { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "combinations_with_replacement") + } +} + +#[pyimpl] +impl PyItertoolsCombinationsWithReplacement { + #[pyslot(new)] + fn tp_new( + cls: PyClassRef, + iterable: PyObjectRef, + r: PyIntRef, + vm: &VirtualMachine, + ) -> PyResult> { + let iter = get_iter(vm, &iterable)?; + let pool = get_all(vm, &iter)?; + + let r = r.as_bigint(); + if r.is_negative() { + return Err(vm.new_value_error("r must be non-negative".to_string())); + } + let r = r.to_usize().unwrap(); + + let n = pool.len(); + + PyItertoolsCombinationsWithReplacement { + pool, + indices: RefCell::new(vec![0; r]), + r: Cell::new(r), + exhausted: Cell::new(n == 0 && r > 0), + } + .into_ref_with_type(vm, cls) + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf + } + + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + // stop signal + if self.exhausted.get() { + return Err(new_stop_iteration(vm)); + } + + let n = self.pool.len(); + let r = self.r.get(); + + if r == 0 { + self.exhausted.set(true); + return Ok(vm.ctx.new_tuple(vec![])); + } + + let mut indices = self.indices.borrow_mut(); + + let res = vm + .ctx + .new_tuple(indices.iter().map(|&i| self.pool[i].clone()).collect()); + + // Scan indices right-to-left until finding one that is not at its maximum (i + n - r). + let mut idx = r as isize - 1; + while idx >= 0 && indices[idx as usize] == n - 1 { + idx -= 1; + } + + // If no suitable index is found, then the indices are all at + // their maximum value and we're done. + if idx < 0 { + self.exhausted.set(true); + } else { + let index = indices[idx as usize] + 1; + + // Increment the current index which we know is not at its + // maximum. Then set all to the right to the same value. + for j in idx as usize..r { + indices[j as usize] = index as usize; + } + } + + Ok(res) + } +} + #[pyclass] #[derive(Debug)] struct PyItertoolsPermutations { @@ -1257,6 +1351,10 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let combinations = ctx.new_class("combinations", ctx.object()); PyItertoolsCombinations::extend_class(ctx, &combinations); + let combinations_with_replacement = + ctx.new_class("combinations_with_replacement", ctx.object()); + PyItertoolsCombinationsWithReplacement::extend_class(ctx, &combinations_with_replacement); + let count = ctx.new_class("count", ctx.object()); PyItertoolsCount::extend_class(ctx, &count); @@ -1296,6 +1394,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "chain" => chain, "compress" => compress, "combinations" => combinations, + "combinations_with_replacement" => combinations_with_replacement, "count" => count, "cycle" => cycle, "dropwhile" => dropwhile,