From bfa74dda87765698dba5df1df19f16856b592303 Mon Sep 17 00:00:00 2001 From: dvermd <315743+dvermd@users.noreply.github.com> Date: Tue, 4 Oct 2022 23:59:27 +0200 Subject: [PATCH] Add dropwhile.__reduce__ --- extra_tests/snippets/stdlib_itertools.py | 12 +++++++++ vm/src/stdlib/itertools.rs | 32 ++++++++++++++++++------ 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/extra_tests/snippets/stdlib_itertools.py b/extra_tests/snippets/stdlib_itertools.py index a5b91d0cde..8a230ec096 100644 --- a/extra_tests/snippets/stdlib_itertools.py +++ b/extra_tests/snippets/stdlib_itertools.py @@ -2,6 +2,7 @@ from testutils import assert_raises +import pickle # itertools.chain tests chain = itertools.chain @@ -279,6 +280,17 @@ def assert_matches_seq(it, seq): with assert_raises(StopIteration): next(it) +def underten(x): + return x<10 + +it = itertools.dropwhile(underten, [1, 3, 5, 20, 2, 4, 6, 8]) +assert pickle.dumps(it, 0) == b'citertools\ndropwhile\np0\n(c__main__\nunderten\np1\nc__builtin__\niter\np2\n((lp3\nI1\naI3\naI5\naI20\naI2\naI4\naI6\naI8\natp4\nRp5\nI0\nbtp6\nRp7\nI0\nb.' +assert pickle.dumps(it, 1) == b'citertools\ndropwhile\nq\x00(c__main__\nunderten\nq\x01c__builtin__\niter\nq\x02(]q\x03(K\x01K\x03K\x05K\x14K\x02K\x04K\x06K\x08etq\x04Rq\x05K\x00btq\x06Rq\x07K\x00b.' +assert pickle.dumps(it, 2) == b'\x80\x02citertools\ndropwhile\nq\x00c__main__\nunderten\nq\x01c__builtin__\niter\nq\x02]q\x03(K\x01K\x03K\x05K\x14K\x02K\x04K\x06K\x08e\x85q\x04Rq\x05K\x00b\x86q\x06Rq\x07K\x00b.' +assert pickle.dumps(it, 3) == b'\x80\x03citertools\ndropwhile\nq\x00c__main__\nunderten\nq\x01cbuiltins\niter\nq\x02]q\x03(K\x01K\x03K\x05K\x14K\x02K\x04K\x06K\x08e\x85q\x04Rq\x05K\x00b\x86q\x06Rq\x07K\x00b.' +assert pickle.dumps(it, 4) == b'\x80\x04\x95i\x00\x00\x00\x00\x00\x00\x00\x8c\titertools\x94\x8c\tdropwhile\x94\x93\x94\x8c\x08__main__\x94\x8c\x08underten\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94]\x94(K\x01K\x03K\x05K\x14K\x02K\x04K\x06K\x08e\x85\x94R\x94K\x00b\x86\x94R\x94K\x00b.' +assert pickle.dumps(it, 5) == b'\x80\x05\x95i\x00\x00\x00\x00\x00\x00\x00\x8c\titertools\x94\x8c\tdropwhile\x94\x93\x94\x8c\x08__main__\x94\x8c\x08underten\x94\x93\x94\x8c\x08builtins\x94\x8c\x04iter\x94\x93\x94]\x94(K\x01K\x03K\x05K\x14K\x02K\x04K\x06K\x08e\x85\x94R\x94K\x00b\x86\x94R\x94K\x00b.' + # itertools.accumulate it = itertools.accumulate([6, 3, 7, 1, 0, 9, 8, 8]) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index ddb976590e..5d5682f2f5 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -2,21 +2,23 @@ pub(crate) use decl::make_module; #[pymodule(name = "itertools")] mod decl { - use crate::common::{ - lock::{PyMutex, PyRwLock, PyRwLockWriteGuard}, - rc::PyRc, - }; use crate::{ builtins::{int, PyGenericAlias, PyInt, PyIntRef, PyList, PyTuple, PyTupleRef, PyTypeRef}, + common::{ + lock::{PyMutex, PyRwLock, PyRwLockWriteGuard}, + rc::PyRc, + }, convert::ToPyObject, - function::{ArgCallable, FuncArgs, OptionalArg, OptionalOption, PosArgs}, + function::{ArgCallable, ArgIntoBool, FuncArgs, OptionalArg, OptionalOption, PosArgs}, identifier, protocol::{PyIter, PyIterReturn, PyNumber}, stdlib::sys, types::{Constructor, IterNext, IterNextIterable}, - AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, PyWeakRef, VirtualMachine, + AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, PyWeakRef, TryFromObject, + VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; + use num_bigint::BigInt; use num_traits::{Signed, ToPrimitive}; use std::fmt; @@ -540,7 +542,23 @@ mod decl { } #[pyclass(with(IterNext, Constructor), flags(BASETYPE))] - impl PyItertoolsDropwhile {} + impl PyItertoolsDropwhile { + #[pymethod(magic)] + fn reduce(zelf: PyRef) -> (PyTypeRef, (PyObjectRef, PyIter), BigInt) { + ( + zelf.class().clone(), + (zelf.predicate.clone().into(), zelf.iterable.clone()), + (if zelf.start_flag.load() { 1 } else { 0 }).into(), + ) + } + #[pymethod(magic)] + fn setstate(zelf: PyRef, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + if let Ok(obj) = ArgIntoBool::try_from_object(vm, state) { + zelf.start_flag.store(*obj); + } + Ok(()) + } + } impl IterNextIterable for PyItertoolsDropwhile {} impl IterNext for PyItertoolsDropwhile { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult {