Skip to content

Commit 9241e2e

Browse files
[VM] Object pickling implementation for product object (python itertools) (#5089)
* Implemented __reduce__, __setstate__ in product object
1 parent 830389f commit 9241e2e

File tree

2 files changed

+91
-9
lines changed

2 files changed

+91
-9
lines changed

Lib/test/test_itertools.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def test_chain_setstate(self):
208208
it = chain()
209209
it.__setstate__((iter(['abc', 'def']), iter(['ghi'])))
210210
self.assertEqual(list(it), ['ghi', 'a', 'b', 'c', 'd', 'e', 'f'])
211-
211+
212212
# TODO: RUSTPYTHON
213213
@unittest.expectedFailure
214214
def test_combinations(self):
@@ -1165,8 +1165,7 @@ def test_product_tuple_reuse(self):
11651165
self.assertEqual(len(set(map(id, product('abc', 'def')))), 1)
11661166
self.assertNotEqual(len(set(map(id, list(product('abc', 'def'))))), 1)
11671167

1168-
# TODO: RUSTPYTHON
1169-
@unittest.expectedFailure
1168+
11701169
def test_product_pickling(self):
11711170
# check copy, deepcopy, pickle
11721171
for args, result in [
@@ -2297,7 +2296,7 @@ def __eq__(self, other):
22972296

22982297

22992298
class SubclassWithKwargsTest(unittest.TestCase):
2300-
2299+
23012300
# TODO: RUSTPYTHON
23022301
@unittest.expectedFailure
23032302
def test_keywords_in_subclass(self):

vm/src/stdlib/itertools.rs

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ pub(crate) use decl::make_module;
22

33
#[pymodule(name = "itertools")]
44
mod decl {
5+
use crate::stdlib::itertools::decl::int::get_value;
56
use crate::{
67
builtins::{
78
int, tuple::IntoPyTuple, PyGenericAlias, PyInt, PyIntRef, PyList, PyTuple, PyTupleRef,
@@ -110,7 +111,9 @@ mod decl {
110111
Ok(())
111112
}
112113
}
114+
113115
impl SelfIter for PyItertoolsChain {}
116+
114117
impl IterNext for PyItertoolsChain {
115118
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
116119
let Some(source) = zelf.source.read().clone() else {
@@ -201,6 +204,7 @@ mod decl {
201204
}
202205

203206
impl SelfIter for PyItertoolsCompress {}
207+
204208
impl IterNext for PyItertoolsCompress {
205209
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
206210
loop {
@@ -268,7 +272,9 @@ mod decl {
268272
(zelf.class().to_owned(), (zelf.cur.read().clone(),))
269273
}
270274
}
275+
271276
impl SelfIter for PyItertoolsCount {}
277+
272278
impl IterNext for PyItertoolsCount {
273279
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
274280
let mut cur = zelf.cur.write();
@@ -316,7 +322,9 @@ mod decl {
316322

317323
#[pyclass(with(IterNext, Iterable, Constructor), flags(BASETYPE))]
318324
impl PyItertoolsCycle {}
325+
319326
impl SelfIter for PyItertoolsCycle {}
327+
320328
impl IterNext for PyItertoolsCycle {
321329
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
322330
let item = if let PyIterReturn::Return(item) = zelf.iter.next(vm)? {
@@ -401,6 +409,7 @@ mod decl {
401409
}
402410

403411
impl SelfIter for PyItertoolsRepeat {}
412+
404413
impl IterNext for PyItertoolsRepeat {
405414
fn next(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<PyIterReturn> {
406415
if let Some(ref times) = zelf.times {
@@ -466,7 +475,9 @@ mod decl {
466475
)
467476
}
468477
}
478+
469479
impl SelfIter for PyItertoolsStarmap {}
480+
470481
impl IterNext for PyItertoolsStarmap {
471482
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
472483
let obj = zelf.iterable.next(vm)?;
@@ -537,7 +548,9 @@ mod decl {
537548
Ok(())
538549
}
539550
}
551+
540552
impl SelfIter for PyItertoolsTakewhile {}
553+
541554
impl IterNext for PyItertoolsTakewhile {
542555
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
543556
if zelf.stop_flag.load() {
@@ -618,7 +631,9 @@ mod decl {
618631
Ok(())
619632
}
620633
}
634+
621635
impl SelfIter for PyItertoolsDropwhile {}
636+
622637
impl IterNext for PyItertoolsDropwhile {
623638
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
624639
let predicate = &zelf.predicate;
@@ -629,7 +644,7 @@ mod decl {
629644
let obj = match iterable.next(vm)? {
630645
PyIterReturn::Return(obj) => obj,
631646
PyIterReturn::StopIteration(v) => {
632-
return Ok(PyIterReturn::StopIteration(v))
647+
return Ok(PyIterReturn::StopIteration(v));
633648
}
634649
};
635650
let pred = predicate.clone();
@@ -737,7 +752,9 @@ mod decl {
737752
Ok(PyIterReturn::Return((new_value, new_key)))
738753
}
739754
}
755+
740756
impl SelfIter for PyItertoolsGroupBy {}
757+
741758
impl IterNext for PyItertoolsGroupBy {
742759
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
743760
let mut state = zelf.state.lock();
@@ -753,7 +770,7 @@ mod decl {
753770
let (value, new_key) = match zelf.advance(vm)? {
754771
PyIterReturn::Return(obj) => obj,
755772
PyIterReturn::StopIteration(v) => {
756-
return Ok(PyIterReturn::StopIteration(v))
773+
return Ok(PyIterReturn::StopIteration(v));
757774
}
758775
};
759776
if !vm.bool_eq(&new_key, &old_key)? {
@@ -764,7 +781,7 @@ mod decl {
764781
match zelf.advance(vm)? {
765782
PyIterReturn::Return(obj) => obj,
766783
PyIterReturn::StopIteration(v) => {
767-
return Ok(PyIterReturn::StopIteration(v))
784+
return Ok(PyIterReturn::StopIteration(v));
768785
}
769786
}
770787
};
@@ -797,7 +814,9 @@ mod decl {
797814

798815
#[pyclass(with(IterNext, Iterable))]
799816
impl PyItertoolsGrouper {}
817+
800818
impl SelfIter for PyItertoolsGrouper {}
819+
801820
impl IterNext for PyItertoolsGrouper {
802821
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
803822
let old_key = {
@@ -960,6 +979,7 @@ mod decl {
960979
}
961980

962981
impl SelfIter for PyItertoolsIslice {}
982+
963983
impl IterNext for PyItertoolsIslice {
964984
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
965985
while zelf.cur.load() < zelf.next.load() {
@@ -1033,7 +1053,9 @@ mod decl {
10331053
)
10341054
}
10351055
}
1056+
10361057
impl SelfIter for PyItertoolsFilterFalse {}
1058+
10371059
impl IterNext for PyItertoolsFilterFalse {
10381060
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
10391061
let predicate = &zelf.predicate;
@@ -1142,6 +1164,7 @@ mod decl {
11421164
}
11431165

11441166
impl SelfIter for PyItertoolsAccumulate {}
1167+
11451168
impl IterNext for PyItertoolsAccumulate {
11461169
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
11471170
let iterable = &zelf.iterable;
@@ -1153,7 +1176,7 @@ mod decl {
11531176
None => match iterable.next(vm)? {
11541177
PyIterReturn::Return(obj) => obj,
11551178
PyIterReturn::StopIteration(v) => {
1156-
return Ok(PyIterReturn::StopIteration(v))
1179+
return Ok(PyIterReturn::StopIteration(v));
11571180
}
11581181
},
11591182
Some(obj) => obj.clone(),
@@ -1162,7 +1185,7 @@ mod decl {
11621185
let obj = match iterable.next(vm)? {
11631186
PyIterReturn::Return(obj) => obj,
11641187
PyIterReturn::StopIteration(v) => {
1165-
return Ok(PyIterReturn::StopIteration(v))
1188+
return Ok(PyIterReturn::StopIteration(v));
11661189
}
11671190
};
11681191
match &zelf.binop {
@@ -1348,7 +1371,60 @@ mod decl {
13481371
self.cur.store(idxs.len() - 1);
13491372
}
13501373
}
1374+
1375+
#[pymethod(magic)]
1376+
fn setstate(zelf: PyRef<Self>, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
1377+
let args = state.as_slice();
1378+
if args.len() != zelf.pools.len() {
1379+
let msg = "Invalid number of arguments".to_string();
1380+
return Err(vm.new_type_error(msg));
1381+
}
1382+
let mut idxs: PyRwLockWriteGuard<'_, Vec<usize>> = zelf.idxs.write();
1383+
idxs.clear();
1384+
for s in 0..args.len() {
1385+
let index = get_value(state.get(s).unwrap()).to_usize().unwrap();
1386+
let pool_size = zelf.pools.get(s).unwrap().len();
1387+
if pool_size == 0 {
1388+
zelf.stop.store(true);
1389+
return Ok(());
1390+
}
1391+
if index >= pool_size {
1392+
idxs.push(pool_size - 1);
1393+
} else {
1394+
idxs.push(index);
1395+
}
1396+
}
1397+
zelf.stop.store(false);
1398+
Ok(())
1399+
}
1400+
1401+
#[pymethod(magic)]
1402+
fn reduce(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyTupleRef {
1403+
let class = zelf.class().to_owned();
1404+
1405+
if zelf.stop.load() {
1406+
return vm.new_tuple((class, (vm.ctx.empty_tuple.clone(),)));
1407+
}
1408+
1409+
let mut pools: Vec<PyObjectRef> = Vec::new();
1410+
for element in zelf.pools.iter() {
1411+
pools.push(element.clone().into_pytuple(vm).into());
1412+
}
1413+
1414+
let mut indices: Vec<PyObjectRef> = Vec::new();
1415+
1416+
for item in &zelf.idxs.read()[..] {
1417+
indices.push(vm.new_pyobj(*item));
1418+
}
1419+
1420+
vm.new_tuple((
1421+
class,
1422+
pools.clone().into_pytuple(vm),
1423+
indices.into_pytuple(vm),
1424+
))
1425+
}
13511426
}
1427+
13521428
impl SelfIter for PyItertoolsProduct {}
13531429
impl IterNext for PyItertoolsProduct {
13541430
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
@@ -1563,6 +1639,7 @@ mod decl {
15631639
impl PyItertoolsCombinationsWithReplacement {}
15641640

15651641
impl SelfIter for PyItertoolsCombinationsWithReplacement {}
1642+
15661643
impl IterNext for PyItertoolsCombinationsWithReplacement {
15671644
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
15681645
// stop signal
@@ -1679,7 +1756,9 @@ mod decl {
16791756
))
16801757
}
16811758
}
1759+
16821760
impl SelfIter for PyItertoolsPermutations {}
1761+
16831762
impl IterNext for PyItertoolsPermutations {
16841763
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
16851764
// stop signal
@@ -1802,7 +1881,9 @@ mod decl {
18021881
Ok(())
18031882
}
18041883
}
1884+
18051885
impl SelfIter for PyItertoolsZipLongest {}
1886+
18061887
impl IterNext for PyItertoolsZipLongest {
18071888
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
18081889
if zelf.iterators.is_empty() {
@@ -1851,7 +1932,9 @@ mod decl {
18511932

18521933
#[pyclass(with(IterNext, Iterable, Constructor))]
18531934
impl PyItertoolsPairwise {}
1935+
18541936
impl SelfIter for PyItertoolsPairwise {}
1937+
18551938
impl IterNext for PyItertoolsPairwise {
18561939
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
18571940
let old = match zelf.old.read().clone() {

0 commit comments

Comments
 (0)