Skip to content

[VM] Object pickling implementation for product object (python itertools) #5089

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions Lib/test/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def test_chain_setstate(self):
it = chain()
it.__setstate__((iter(['abc', 'def']), iter(['ghi'])))
self.assertEqual(list(it), ['ghi', 'a', 'b', 'c', 'd', 'e', 'f'])

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_combinations(self):
Expand Down Expand Up @@ -1165,8 +1165,7 @@ def test_product_tuple_reuse(self):
self.assertEqual(len(set(map(id, product('abc', 'def')))), 1)
self.assertNotEqual(len(set(map(id, list(product('abc', 'def'))))), 1)

# TODO: RUSTPYTHON
@unittest.expectedFailure

def test_product_pickling(self):
# check copy, deepcopy, pickle
for args, result in [
Expand Down Expand Up @@ -2297,7 +2296,7 @@ def __eq__(self, other):


class SubclassWithKwargsTest(unittest.TestCase):

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_keywords_in_subclass(self):
Expand Down
93 changes: 88 additions & 5 deletions vm/src/stdlib/itertools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub(crate) use decl::make_module;

#[pymodule(name = "itertools")]
mod decl {
use crate::stdlib::itertools::decl::int::get_value;
use crate::{
builtins::{
int, tuple::IntoPyTuple, PyGenericAlias, PyInt, PyIntRef, PyList, PyTuple, PyTupleRef,
Expand Down Expand Up @@ -110,7 +111,9 @@ mod decl {
Ok(())
}
}

impl SelfIter for PyItertoolsChain {}

impl IterNext for PyItertoolsChain {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let Some(source) = zelf.source.read().clone() else {
Expand Down Expand Up @@ -201,6 +204,7 @@ mod decl {
}

impl SelfIter for PyItertoolsCompress {}

impl IterNext for PyItertoolsCompress {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
loop {
Expand Down Expand Up @@ -268,7 +272,9 @@ mod decl {
(zelf.class().to_owned(), (zelf.cur.read().clone(),))
}
}

impl SelfIter for PyItertoolsCount {}

impl IterNext for PyItertoolsCount {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let mut cur = zelf.cur.write();
Expand Down Expand Up @@ -316,7 +322,9 @@ mod decl {

#[pyclass(with(IterNext, Iterable, Constructor), flags(BASETYPE))]
impl PyItertoolsCycle {}

impl SelfIter for PyItertoolsCycle {}

impl IterNext for PyItertoolsCycle {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let item = if let PyIterReturn::Return(item) = zelf.iter.next(vm)? {
Expand Down Expand Up @@ -401,6 +409,7 @@ mod decl {
}

impl SelfIter for PyItertoolsRepeat {}

impl IterNext for PyItertoolsRepeat {
fn next(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<PyIterReturn> {
if let Some(ref times) = zelf.times {
Expand Down Expand Up @@ -466,7 +475,9 @@ mod decl {
)
}
}

impl SelfIter for PyItertoolsStarmap {}

impl IterNext for PyItertoolsStarmap {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let obj = zelf.iterable.next(vm)?;
Expand Down Expand Up @@ -537,7 +548,9 @@ mod decl {
Ok(())
}
}

impl SelfIter for PyItertoolsTakewhile {}

impl IterNext for PyItertoolsTakewhile {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
if zelf.stop_flag.load() {
Expand Down Expand Up @@ -618,7 +631,9 @@ mod decl {
Ok(())
}
}

impl SelfIter for PyItertoolsDropwhile {}

impl IterNext for PyItertoolsDropwhile {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let predicate = &zelf.predicate;
Expand All @@ -629,7 +644,7 @@ mod decl {
let obj = match iterable.next(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => {
return Ok(PyIterReturn::StopIteration(v))
return Ok(PyIterReturn::StopIteration(v));
}
};
let pred = predicate.clone();
Expand Down Expand Up @@ -737,7 +752,9 @@ mod decl {
Ok(PyIterReturn::Return((new_value, new_key)))
}
}

impl SelfIter for PyItertoolsGroupBy {}

impl IterNext for PyItertoolsGroupBy {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let mut state = zelf.state.lock();
Expand All @@ -753,7 +770,7 @@ mod decl {
let (value, new_key) = match zelf.advance(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => {
return Ok(PyIterReturn::StopIteration(v))
return Ok(PyIterReturn::StopIteration(v));
}
};
if !vm.bool_eq(&new_key, &old_key)? {
Expand All @@ -764,7 +781,7 @@ mod decl {
match zelf.advance(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => {
return Ok(PyIterReturn::StopIteration(v))
return Ok(PyIterReturn::StopIteration(v));
}
}
};
Expand Down Expand Up @@ -797,7 +814,9 @@ mod decl {

#[pyclass(with(IterNext, Iterable))]
impl PyItertoolsGrouper {}

impl SelfIter for PyItertoolsGrouper {}

impl IterNext for PyItertoolsGrouper {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let old_key = {
Expand Down Expand Up @@ -960,6 +979,7 @@ mod decl {
}

impl SelfIter for PyItertoolsIslice {}

impl IterNext for PyItertoolsIslice {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
while zelf.cur.load() < zelf.next.load() {
Expand Down Expand Up @@ -1033,7 +1053,9 @@ mod decl {
)
}
}

impl SelfIter for PyItertoolsFilterFalse {}

impl IterNext for PyItertoolsFilterFalse {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let predicate = &zelf.predicate;
Expand Down Expand Up @@ -1142,6 +1164,7 @@ mod decl {
}

impl SelfIter for PyItertoolsAccumulate {}

impl IterNext for PyItertoolsAccumulate {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let iterable = &zelf.iterable;
Expand All @@ -1153,7 +1176,7 @@ mod decl {
None => match iterable.next(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => {
return Ok(PyIterReturn::StopIteration(v))
return Ok(PyIterReturn::StopIteration(v));
}
},
Some(obj) => obj.clone(),
Expand All @@ -1162,7 +1185,7 @@ mod decl {
let obj = match iterable.next(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => {
return Ok(PyIterReturn::StopIteration(v))
return Ok(PyIterReturn::StopIteration(v));
}
};
match &zelf.binop {
Expand Down Expand Up @@ -1348,7 +1371,60 @@ mod decl {
self.cur.store(idxs.len() - 1);
}
}

#[pymethod(magic)]
fn setstate(zelf: PyRef<Self>, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
let args = state.as_slice();
if args.len() != zelf.pools.len() {
let msg = "Invalid number of arguments".to_string();
return Err(vm.new_type_error(msg));
}
let mut idxs: PyRwLockWriteGuard<'_, Vec<usize>> = zelf.idxs.write();
idxs.clear();
for s in 0..args.len() {
let index = get_value(state.get(s).unwrap()).to_usize().unwrap();
let pool_size = zelf.pools.get(s).unwrap().len();
if pool_size == 0 {
zelf.stop.store(true);
return Ok(());
}
if index >= pool_size {
idxs.push(pool_size - 1);
} else {
idxs.push(index);
}
}
zelf.stop.store(false);
Ok(())
}

#[pymethod(magic)]
fn reduce(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyTupleRef {
let class = zelf.class().to_owned();

if zelf.stop.load() {
return vm.new_tuple((class, (vm.ctx.empty_tuple.clone(),)));
}

let mut pools: Vec<PyObjectRef> = Vec::new();
for element in zelf.pools.iter() {
pools.push(element.clone().into_pytuple(vm).into());
}

let mut indices: Vec<PyObjectRef> = Vec::new();

for item in &zelf.idxs.read()[..] {
indices.push(vm.new_pyobj(*item));
}

vm.new_tuple((
class,
pools.clone().into_pytuple(vm),
indices.into_pytuple(vm),
))
}
}

impl SelfIter for PyItertoolsProduct {}
impl IterNext for PyItertoolsProduct {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
Expand Down Expand Up @@ -1563,6 +1639,7 @@ mod decl {
impl PyItertoolsCombinationsWithReplacement {}

impl SelfIter for PyItertoolsCombinationsWithReplacement {}

impl IterNext for PyItertoolsCombinationsWithReplacement {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
// stop signal
Expand Down Expand Up @@ -1679,7 +1756,9 @@ mod decl {
))
}
}

impl SelfIter for PyItertoolsPermutations {}

impl IterNext for PyItertoolsPermutations {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
// stop signal
Expand Down Expand Up @@ -1802,7 +1881,9 @@ mod decl {
Ok(())
}
}

impl SelfIter for PyItertoolsZipLongest {}

impl IterNext for PyItertoolsZipLongest {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
if zelf.iterators.is_empty() {
Expand Down Expand Up @@ -1851,7 +1932,9 @@ mod decl {

#[pyclass(with(IterNext, Iterable, Constructor))]
impl PyItertoolsPairwise {}

impl SelfIter for PyItertoolsPairwise {}

impl IterNext for PyItertoolsPairwise {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let old = match zelf.old.read().clone() {
Expand Down