diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 142fa04e38..cf1107c45a 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -208,7 +208,7 @@ def test_chain_setstate(self): it = chain() it.__setstate__((iter(['abc', 'def']), iter(['ghi']))) self.assertEqual(list(it), ['ghi', 'a', 'b', 'c', 'd', 'e', 'f']) - + # TODO: RUSTPYTHON @unittest.expectedFailure def test_combinations(self): @@ -1165,8 +1165,7 @@ def test_product_tuple_reuse(self): self.assertEqual(len(set(map(id, product('abc', 'def')))), 1) self.assertNotEqual(len(set(map(id, list(product('abc', 'def'))))), 1) - # TODO: RUSTPYTHON - @unittest.expectedFailure + def test_product_pickling(self): # check copy, deepcopy, pickle for args, result in [ @@ -2297,7 +2296,7 @@ def __eq__(self, other): class SubclassWithKwargsTest(unittest.TestCase): - + # TODO: RUSTPYTHON @unittest.expectedFailure def test_keywords_in_subclass(self): diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 526ab61af3..3bc26c0a8f 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -2,6 +2,7 @@ pub(crate) use decl::make_module; #[pymodule(name = "itertools")] mod decl { + use crate::stdlib::itertools::decl::int::get_value; use crate::{ builtins::{ int, tuple::IntoPyTuple, PyGenericAlias, PyInt, PyIntRef, PyList, PyTuple, PyTupleRef, @@ -110,7 +111,9 @@ mod decl { Ok(()) } } + impl SelfIter for PyItertoolsChain {} + impl IterNext for PyItertoolsChain { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let Some(source) = zelf.source.read().clone() else { @@ -201,6 +204,7 @@ mod decl { } impl SelfIter for PyItertoolsCompress {} + impl IterNext for PyItertoolsCompress { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { loop { @@ -268,7 +272,9 @@ mod decl { (zelf.class().to_owned(), (zelf.cur.read().clone(),)) } } + impl SelfIter for PyItertoolsCount {} + impl IterNext for PyItertoolsCount { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let mut cur = zelf.cur.write(); @@ -316,7 +322,9 @@ mod decl { #[pyclass(with(IterNext, Iterable, Constructor), flags(BASETYPE))] impl PyItertoolsCycle {} + impl SelfIter for PyItertoolsCycle {} + impl IterNext for PyItertoolsCycle { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let item = if let PyIterReturn::Return(item) = zelf.iter.next(vm)? { @@ -401,6 +409,7 @@ mod decl { } impl SelfIter for PyItertoolsRepeat {} + impl IterNext for PyItertoolsRepeat { fn next(zelf: &Py, _vm: &VirtualMachine) -> PyResult { if let Some(ref times) = zelf.times { @@ -466,7 +475,9 @@ mod decl { ) } } + impl SelfIter for PyItertoolsStarmap {} + impl IterNext for PyItertoolsStarmap { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let obj = zelf.iterable.next(vm)?; @@ -537,7 +548,9 @@ mod decl { Ok(()) } } + impl SelfIter for PyItertoolsTakewhile {} + impl IterNext for PyItertoolsTakewhile { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { if zelf.stop_flag.load() { @@ -618,7 +631,9 @@ mod decl { Ok(()) } } + impl SelfIter for PyItertoolsDropwhile {} + impl IterNext for PyItertoolsDropwhile { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let predicate = &zelf.predicate; @@ -629,7 +644,7 @@ mod decl { let obj = match iterable.next(vm)? { PyIterReturn::Return(obj) => obj, PyIterReturn::StopIteration(v) => { - return Ok(PyIterReturn::StopIteration(v)) + return Ok(PyIterReturn::StopIteration(v)); } }; let pred = predicate.clone(); @@ -737,7 +752,9 @@ mod decl { Ok(PyIterReturn::Return((new_value, new_key))) } } + impl SelfIter for PyItertoolsGroupBy {} + impl IterNext for PyItertoolsGroupBy { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let mut state = zelf.state.lock(); @@ -753,7 +770,7 @@ mod decl { let (value, new_key) = match zelf.advance(vm)? { PyIterReturn::Return(obj) => obj, PyIterReturn::StopIteration(v) => { - return Ok(PyIterReturn::StopIteration(v)) + return Ok(PyIterReturn::StopIteration(v)); } }; if !vm.bool_eq(&new_key, &old_key)? { @@ -764,7 +781,7 @@ mod decl { match zelf.advance(vm)? { PyIterReturn::Return(obj) => obj, PyIterReturn::StopIteration(v) => { - return Ok(PyIterReturn::StopIteration(v)) + return Ok(PyIterReturn::StopIteration(v)); } } }; @@ -797,7 +814,9 @@ mod decl { #[pyclass(with(IterNext, Iterable))] impl PyItertoolsGrouper {} + impl SelfIter for PyItertoolsGrouper {} + impl IterNext for PyItertoolsGrouper { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let old_key = { @@ -960,6 +979,7 @@ mod decl { } impl SelfIter for PyItertoolsIslice {} + impl IterNext for PyItertoolsIslice { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { while zelf.cur.load() < zelf.next.load() { @@ -1033,7 +1053,9 @@ mod decl { ) } } + impl SelfIter for PyItertoolsFilterFalse {} + impl IterNext for PyItertoolsFilterFalse { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let predicate = &zelf.predicate; @@ -1142,6 +1164,7 @@ mod decl { } impl SelfIter for PyItertoolsAccumulate {} + impl IterNext for PyItertoolsAccumulate { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let iterable = &zelf.iterable; @@ -1153,7 +1176,7 @@ mod decl { None => match iterable.next(vm)? { PyIterReturn::Return(obj) => obj, PyIterReturn::StopIteration(v) => { - return Ok(PyIterReturn::StopIteration(v)) + return Ok(PyIterReturn::StopIteration(v)); } }, Some(obj) => obj.clone(), @@ -1162,7 +1185,7 @@ mod decl { let obj = match iterable.next(vm)? { PyIterReturn::Return(obj) => obj, PyIterReturn::StopIteration(v) => { - return Ok(PyIterReturn::StopIteration(v)) + return Ok(PyIterReturn::StopIteration(v)); } }; match &zelf.binop { @@ -1348,7 +1371,60 @@ mod decl { self.cur.store(idxs.len() - 1); } } + + #[pymethod(magic)] + fn setstate(zelf: PyRef, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> { + let args = state.as_slice(); + if args.len() != zelf.pools.len() { + let msg = "Invalid number of arguments".to_string(); + return Err(vm.new_type_error(msg)); + } + let mut idxs: PyRwLockWriteGuard<'_, Vec> = zelf.idxs.write(); + idxs.clear(); + for s in 0..args.len() { + let index = get_value(state.get(s).unwrap()).to_usize().unwrap(); + let pool_size = zelf.pools.get(s).unwrap().len(); + if pool_size == 0 { + zelf.stop.store(true); + return Ok(()); + } + if index >= pool_size { + idxs.push(pool_size - 1); + } else { + idxs.push(index); + } + } + zelf.stop.store(false); + Ok(()) + } + + #[pymethod(magic)] + fn reduce(zelf: PyRef, vm: &VirtualMachine) -> PyTupleRef { + let class = zelf.class().to_owned(); + + if zelf.stop.load() { + return vm.new_tuple((class, (vm.ctx.empty_tuple.clone(),))); + } + + let mut pools: Vec = Vec::new(); + for element in zelf.pools.iter() { + pools.push(element.clone().into_pytuple(vm).into()); + } + + let mut indices: Vec = Vec::new(); + + for item in &zelf.idxs.read()[..] { + indices.push(vm.new_pyobj(*item)); + } + + vm.new_tuple(( + class, + pools.clone().into_pytuple(vm), + indices.into_pytuple(vm), + )) + } } + impl SelfIter for PyItertoolsProduct {} impl IterNext for PyItertoolsProduct { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { @@ -1563,6 +1639,7 @@ mod decl { impl PyItertoolsCombinationsWithReplacement {} impl SelfIter for PyItertoolsCombinationsWithReplacement {} + impl IterNext for PyItertoolsCombinationsWithReplacement { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { // stop signal @@ -1679,7 +1756,9 @@ mod decl { )) } } + impl SelfIter for PyItertoolsPermutations {} + impl IterNext for PyItertoolsPermutations { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { // stop signal @@ -1802,7 +1881,9 @@ mod decl { Ok(()) } } + impl SelfIter for PyItertoolsZipLongest {} + impl IterNext for PyItertoolsZipLongest { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { if zelf.iterators.is_empty() { @@ -1851,7 +1932,9 @@ mod decl { #[pyclass(with(IterNext, Iterable, Constructor))] impl PyItertoolsPairwise {} + impl SelfIter for PyItertoolsPairwise {} + impl IterNext for PyItertoolsPairwise { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let old = match zelf.old.read().clone() {