diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index ad41bdc87e..9d45b26d80 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -3,7 +3,10 @@ pub(crate) use decl::make_module; #[pymodule(name = "itertools")] mod decl { use crate::{ - builtins::{int, PyGenericAlias, PyInt, PyIntRef, PyList, PyTuple, PyTupleRef, PyTypeRef}, + builtins::{ + int, tuple::IntoPyTuple, PyGenericAlias, PyInt, PyIntRef, PyList, PyTuple, PyTupleRef, + PyTypeRef, + }, common::{ lock::{PyMutex, PyRwLock, PyRwLockWriteGuard}, rc::PyRc, @@ -1308,6 +1311,7 @@ mod decl { struct PyItertoolsCombinations { pool: Vec, indices: PyRwLock>, + result: PyRwLock>>, r: AtomicCell, exhausted: AtomicCell, } @@ -1341,6 +1345,7 @@ mod decl { PyItertoolsCombinations { pool, indices: PyRwLock::new((0..r).collect()), + result: PyRwLock::new(None), r: AtomicCell::new(r), exhausted: AtomicCell::new(r > n), } @@ -1350,7 +1355,36 @@ mod decl { } #[pyclass(with(IterNext, Constructor))] - impl PyItertoolsCombinations {} + impl PyItertoolsCombinations { + #[pymethod(magic)] + fn reduce(zelf: PyRef, vm: &VirtualMachine) -> PyTupleRef { + let r = zelf.r.load(); + + let class = zelf.class().to_owned(); + + if zelf.exhausted.load() { + return vm.new_tuple(( + class, + vm.new_tuple((vm.ctx.empty_tuple.clone(), vm.ctx.new_int(r))), + )); + } + + let tup = vm.new_tuple((zelf.pool.clone().into_pytuple(vm), vm.ctx.new_int(r))); + + if zelf.result.read().is_none() { + vm.new_tuple((class, tup)) + } else { + let mut indices: Vec = Vec::new(); + + for item in &zelf.indices.read()[..r] { + indices.push(vm.new_pyobj(*item)); + } + + vm.new_tuple((class, tup, indices.into_pytuple(vm))) + } + } + } + impl IterNextIterable for PyItertoolsCombinations {} impl IterNext for PyItertoolsCombinations { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { @@ -1367,38 +1401,48 @@ mod decl { return Ok(PyIterReturn::Return(vm.new_tuple(()).into())); } - let res = vm.ctx.new_tuple( - zelf.indices - .read() - .iter() - .map(|&i| zelf.pool[i].clone()) - .collect(), - ); + let mut result_lock = zelf.result.write(); + let result = if let Some(ref mut result) = *result_lock { + let mut indices = zelf.indices.write(); - let mut indices = zelf.indices.write(); + // Scan indices right-to-left until finding one that is not at its maximum (i + n - r). + let mut idx = r as isize - 1; + while idx >= 0 && indices[idx as usize] == idx as usize + n - r { + idx -= 1; + } - // Scan indices right-to-left until finding one that is not at its maximum (i + n - r). - let mut idx = r as isize - 1; - while idx >= 0 && indices[idx as usize] == idx as usize + n - r { - idx -= 1; - } + // If no suitable index is found, then the indices are all at + // their maximum value and we're done. + if idx < 0 { + zelf.exhausted.store(true); + return Ok(PyIterReturn::StopIteration(None)); + } else { + // Increment the current index which we know is not at its + // maximum. Then move back to the right setting each index + // to its lowest possible value (one higher than the index + // to its left -- this maintains the sort order invariant). + indices[idx as usize] += 1; + for j in idx as usize + 1..r { + indices[j] = indices[j - 1] + 1; + } - // If no suitable index is found, then the indices are all at - // their maximum value and we're done. - if idx < 0 { - zelf.exhausted.store(true); - } else { - // Increment the current index which we know is not at its - // maximum. Then move back to the right setting each index - // to its lowest possible value (one higher than the index - // to its left -- this maintains the sort order invariant). - indices[idx as usize] += 1; - for j in idx as usize + 1..r { - indices[j] = indices[j - 1] + 1; + // Update the result tuple for the new indices + // starting with i, the leftmost index that changed + for i in idx as usize..r { + let index = indices[i]; + let elem = &zelf.pool[index]; + result[i] = elem.to_owned(); + } + + result.to_vec() } - } + } else { + let res = zelf.pool[0..r].to_vec(); + *result_lock = Some(res.clone()); + res + }; - Ok(PyIterReturn::Return(res.into())) + Ok(PyIterReturn::Return(vm.ctx.new_tuple(result).into())) } }