Skip to content

Commit 17944d3

Browse files
authored
Merge pull request RustPython#4272 from oow214/iter_combi_reduce
Add `combinations.__reduce__ `
2 parents 56b97fb + dd93ec3 commit 17944d3

File tree

1 file changed

+73
-29
lines changed

1 file changed

+73
-29
lines changed

vm/src/stdlib/itertools.rs

Lines changed: 73 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ pub(crate) use decl::make_module;
33
#[pymodule(name = "itertools")]
44
mod decl {
55
use crate::{
6-
builtins::{int, PyGenericAlias, PyInt, PyIntRef, PyList, PyTuple, PyTupleRef, PyTypeRef},
6+
builtins::{
7+
int, tuple::IntoPyTuple, PyGenericAlias, PyInt, PyIntRef, PyList, PyTuple, PyTupleRef,
8+
PyTypeRef,
9+
},
710
common::{
811
lock::{PyMutex, PyRwLock, PyRwLockWriteGuard},
912
rc::PyRc,
@@ -1308,6 +1311,7 @@ mod decl {
13081311
struct PyItertoolsCombinations {
13091312
pool: Vec<PyObjectRef>,
13101313
indices: PyRwLock<Vec<usize>>,
1314+
result: PyRwLock<Option<Vec<PyObjectRef>>>,
13111315
r: AtomicCell<usize>,
13121316
exhausted: AtomicCell<bool>,
13131317
}
@@ -1341,6 +1345,7 @@ mod decl {
13411345
PyItertoolsCombinations {
13421346
pool,
13431347
indices: PyRwLock::new((0..r).collect()),
1348+
result: PyRwLock::new(None),
13441349
r: AtomicCell::new(r),
13451350
exhausted: AtomicCell::new(r > n),
13461351
}
@@ -1350,7 +1355,36 @@ mod decl {
13501355
}
13511356

13521357
#[pyclass(with(IterNext, Constructor))]
1353-
impl PyItertoolsCombinations {}
1358+
impl PyItertoolsCombinations {
1359+
#[pymethod(magic)]
1360+
fn reduce(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyTupleRef {
1361+
let r = zelf.r.load();
1362+
1363+
let class = zelf.class().to_owned();
1364+
1365+
if zelf.exhausted.load() {
1366+
return vm.new_tuple((
1367+
class,
1368+
vm.new_tuple((vm.ctx.empty_tuple.clone(), vm.ctx.new_int(r))),
1369+
));
1370+
}
1371+
1372+
let tup = vm.new_tuple((zelf.pool.clone().into_pytuple(vm), vm.ctx.new_int(r)));
1373+
1374+
if zelf.result.read().is_none() {
1375+
vm.new_tuple((class, tup))
1376+
} else {
1377+
let mut indices: Vec<PyObjectRef> = Vec::new();
1378+
1379+
for item in &zelf.indices.read()[..r] {
1380+
indices.push(vm.new_pyobj(*item));
1381+
}
1382+
1383+
vm.new_tuple((class, tup, indices.into_pytuple(vm)))
1384+
}
1385+
}
1386+
}
1387+
13541388
impl IterNextIterable for PyItertoolsCombinations {}
13551389
impl IterNext for PyItertoolsCombinations {
13561390
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
@@ -1367,38 +1401,48 @@ mod decl {
13671401
return Ok(PyIterReturn::Return(vm.new_tuple(()).into()));
13681402
}
13691403

1370-
let res = vm.ctx.new_tuple(
1371-
zelf.indices
1372-
.read()
1373-
.iter()
1374-
.map(|&i| zelf.pool[i].clone())
1375-
.collect(),
1376-
);
1404+
let mut result_lock = zelf.result.write();
1405+
let result = if let Some(ref mut result) = *result_lock {
1406+
let mut indices = zelf.indices.write();
13771407

1378-
let mut indices = zelf.indices.write();
1408+
// Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
1409+
let mut idx = r as isize - 1;
1410+
while idx >= 0 && indices[idx as usize] == idx as usize + n - r {
1411+
idx -= 1;
1412+
}
13791413

1380-
// Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
1381-
let mut idx = r as isize - 1;
1382-
while idx >= 0 && indices[idx as usize] == idx as usize + n - r {
1383-
idx -= 1;
1384-
}
1414+
// If no suitable index is found, then the indices are all at
1415+
// their maximum value and we're done.
1416+
if idx < 0 {
1417+
zelf.exhausted.store(true);
1418+
return Ok(PyIterReturn::StopIteration(None));
1419+
} else {
1420+
// Increment the current index which we know is not at its
1421+
// maximum. Then move back to the right setting each index
1422+
// to its lowest possible value (one higher than the index
1423+
// to its left -- this maintains the sort order invariant).
1424+
indices[idx as usize] += 1;
1425+
for j in idx as usize + 1..r {
1426+
indices[j] = indices[j - 1] + 1;
1427+
}
13851428

1386-
// If no suitable index is found, then the indices are all at
1387-
// their maximum value and we're done.
1388-
if idx < 0 {
1389-
zelf.exhausted.store(true);
1390-
} else {
1391-
// Increment the current index which we know is not at its
1392-
// maximum. Then move back to the right setting each index
1393-
// to its lowest possible value (one higher than the index
1394-
// to its left -- this maintains the sort order invariant).
1395-
indices[idx as usize] += 1;
1396-
for j in idx as usize + 1..r {
1397-
indices[j] = indices[j - 1] + 1;
1429+
// Update the result tuple for the new indices
1430+
// starting with i, the leftmost index that changed
1431+
for i in idx as usize..r {
1432+
let index = indices[i];
1433+
let elem = &zelf.pool[index];
1434+
result[i] = elem.to_owned();
1435+
}
1436+
1437+
result.to_vec()
13981438
}
1399-
}
1439+
} else {
1440+
let res = zelf.pool[0..r].to_vec();
1441+
*result_lock = Some(res.clone());
1442+
res
1443+
};
14001444

1401-
Ok(PyIterReturn::Return(res.into()))
1445+
Ok(PyIterReturn::Return(vm.ctx.new_tuple(result).into()))
14021446
}
14031447
}
14041448

0 commit comments

Comments
 (0)