diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index ad41bdc87e..0a617e090e 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -1308,6 +1308,7 @@ mod decl { struct PyItertoolsCombinations { pool: Vec, indices: PyRwLock>, + result: PyRwLock>>, r: AtomicCell, exhausted: AtomicCell, } @@ -1341,6 +1342,7 @@ mod decl { PyItertoolsCombinations { pool, indices: PyRwLock::new((0..r).collect()), + result: PyRwLock::new(None), r: AtomicCell::new(r), exhausted: AtomicCell::new(r > n), } @@ -1350,7 +1352,39 @@ mod decl { } #[pyclass(with(IterNext, Constructor))] - impl PyItertoolsCombinations {} + impl PyItertoolsCombinations { + #[pymethod(magic)] + fn reduce(zelf: PyRef, vm: &VirtualMachine) -> PyTupleRef { + let result = zelf.result.read(); + if let Some(result) = &*result { + if zelf.exhausted.load() { + vm.new_tuple(( + zelf.class().to_owned(), + vm.new_tuple((vm.new_tuple(()), vm.ctx.new_int(zelf.r.load()))), + )) + } else { + vm.new_tuple(( + zelf.class().to_owned(), + vm.new_tuple(( + vm.new_tuple(zelf.pool.clone()), + vm.ctx.new_int(zelf.r.load()), + )), + vm.ctx + .new_tuple(result.iter().map(|&i| zelf.pool[i].clone()).collect()), + )) + } + } else { + vm.new_tuple(( + zelf.class().to_owned(), + vm.new_tuple(( + vm.new_tuple(zelf.pool.clone()), + vm.ctx.new_int(zelf.r.load()), + )), + )) + } + } + } + impl IterNextIterable for PyItertoolsCombinations {} impl IterNext for PyItertoolsCombinations { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { @@ -1367,38 +1401,51 @@ mod decl { return Ok(PyIterReturn::Return(vm.new_tuple(()).into())); } - let res = vm.ctx.new_tuple( - zelf.indices - .read() - .iter() - .map(|&i| zelf.pool[i].clone()) - .collect(), - ); + let mut result = zelf.result.write(); - let mut indices = zelf.indices.write(); + if let Some(ref mut result) = *result { + let mut indices = zelf.indices.write(); - // Scan indices right-to-left until finding one that is not at its maximum (i + n - r). - let mut idx = r as isize - 1; - while idx >= 0 && indices[idx as usize] == idx as usize + n - r { - idx -= 1; - } + // Scan indices right-to-left until finding one that is not at its maximum (i + n - r). + let mut idx = r as isize - 1; + while idx >= 0 && indices[idx as usize] == idx as usize + n - r { + idx -= 1; + } - // If no suitable index is found, then the indices are all at - // their maximum value and we're done. - if idx < 0 { - zelf.exhausted.store(true); - } else { - // Increment the current index which we know is not at its - // maximum. Then move back to the right setting each index - // to its lowest possible value (one higher than the index - // to its left -- this maintains the sort order invariant). - indices[idx as usize] += 1; - for j in idx as usize + 1..r { - indices[j] = indices[j - 1] + 1; + // If no suitable index is found, then the indices are all at + // their maximum value and we're done. + if idx < 0 { + zelf.exhausted.store(true); + return Ok(PyIterReturn::StopIteration(None)); + } else { + // Increment the current index which we know is not at its + // maximum. Then move back to the right setting each index + // to its lowest possible value (one higher than the index + // to its left -- this maintains the sort order invariant). + indices[idx as usize] += 1; + for j in idx as usize + 1..r { + indices[j] = indices[j - 1] + 1; + } + for j in 0..r { + result[j] = indices[j]; + } } + } else { + *result = Some((0..r).collect()); } - Ok(PyIterReturn::Return(res.into())) + Ok(PyIterReturn::Return( + vm.ctx + .new_tuple( + result + .as_ref() + .unwrap() + .iter() + .map(|&i| zelf.pool[i].clone()) + .collect(), + ) + .into(), + )) } }