diff --git a/vm/src/builtins/asyncgenerator.rs b/vm/src/builtins/asyncgenerator.rs index d8365178fa..26a8eb7f77 100644 --- a/vm/src/builtins/asyncgenerator.rs +++ b/vm/src/builtins/asyncgenerator.rs @@ -4,7 +4,7 @@ use crate::{ coroutine::{Coro, Variant}, frame::FrameRef, function::OptionalArg, - slots::{IteratorIterable, PyIter}, + slots::{IteratorIterable, SlotIterator}, IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, VirtualMachine, }; @@ -182,7 +182,7 @@ impl PyValue for PyAsyncGenASend { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl PyAsyncGenASend { #[pymethod(name = "__await__")] fn r#await(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { @@ -255,7 +255,7 @@ impl PyAsyncGenASend { } impl IteratorIterable for PyAsyncGenASend {} -impl PyIter for PyAsyncGenASend { +impl SlotIterator for PyAsyncGenASend { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.send(vm.ctx.none(), vm) } @@ -276,7 +276,7 @@ impl PyValue for PyAsyncGenAThrow { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl PyAsyncGenAThrow { #[pymethod(name = "__await__")] fn r#await(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { @@ -397,7 +397,7 @@ impl PyAsyncGenAThrow { } impl IteratorIterable for PyAsyncGenAThrow {} -impl PyIter for PyAsyncGenAThrow { +impl SlotIterator for PyAsyncGenAThrow { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.send(vm.ctx.none(), vm) } diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index cc1c1228b5..7abd30bfbb 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -19,7 +19,7 @@ use crate::{ sliceable::{PySliceableSequence, PySliceableSequenceMut, SequenceIndex}, slots::{ AsBuffer, Callable, Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, - PyIter, Unhashable, + SlotIterator, Unhashable, }, utils::Either, IdProtocol, IntoPyObject, PyClassDef, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, @@ -741,10 +741,10 @@ impl PyValue for PyByteArrayIterator { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl PyByteArrayIterator {} impl IteratorIterable for PyByteArrayIterator {} -impl PyIter for PyByteArrayIterator { +impl SlotIterator for PyByteArrayIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let pos = zelf.position.fetch_add(1); if let Some(&ret) = zelf.bytearray.borrow_buf().get(pos) { diff --git a/vm/src/builtins/bytes.rs b/vm/src/builtins/bytes.rs index d052423670..ed6ea9089f 100644 --- a/vm/src/builtins/bytes.rs +++ b/vm/src/builtins/bytes.rs @@ -10,7 +10,7 @@ use crate::{ protocol::{BufferInternal, BufferOptions, PyBuffer}, slots::{ AsBuffer, Callable, Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, - PyIter, SlotConstructor, + SlotConstructor, SlotIterator, }, utils::Either, IdProtocol, IntoPyObject, IntoPyResult, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, @@ -594,10 +594,10 @@ impl PyValue for PyBytesIterator { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl PyBytesIterator {} impl IteratorIterable for PyBytesIterator {} -impl PyIter for PyBytesIterator { +impl SlotIterator for PyBytesIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let pos = zelf.position.fetch_add(1); if let Some(&ret) = zelf.bytes.as_bytes().get(pos) { diff --git a/vm/src/builtins/coroutine.rs b/vm/src/builtins/coroutine.rs index 0f9671abbf..c3b36a86c4 100644 --- a/vm/src/builtins/coroutine.rs +++ b/vm/src/builtins/coroutine.rs @@ -3,7 +3,7 @@ use crate::{ coroutine::{Coro, Variant}, frame::FrameRef, function::OptionalArg, - slots::{IteratorIterable, PyIter}, + slots::{IteratorIterable, SlotIterator}, IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, VirtualMachine, }; @@ -19,7 +19,7 @@ impl PyValue for PyCoroutine { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl PyCoroutine { pub fn as_coro(&self) -> &Coro { &self.inner @@ -102,7 +102,7 @@ impl PyCoroutine { } impl IteratorIterable for PyCoroutine {} -impl PyIter for PyCoroutine { +impl SlotIterator for PyCoroutine { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.send(vm.ctx.none(), vm) } @@ -120,7 +120,7 @@ impl PyValue for PyCoroutineWrapper { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl PyCoroutineWrapper { #[pymethod] fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult { @@ -140,7 +140,7 @@ impl PyCoroutineWrapper { } impl IteratorIterable for PyCoroutineWrapper {} -impl PyIter for PyCoroutineWrapper { +impl SlotIterator for PyCoroutineWrapper { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.send(vm.ctx.none(), vm) } diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index 9eb94f8eea..e9b5b549bb 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -5,7 +5,9 @@ use crate::{ dictdatatype::{self, DictKey}, function::{ArgIterable, FuncArgs, KwArgs, OptionalArg}, iterator, - slots::{Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, PyIter, Unhashable}, + slots::{ + Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, SlotIterator, Unhashable, + }, vm::{ReprGuard, VirtualMachine}, IdProtocol, IntoPyObject, ItemProtocol, PyArithmaticValue::*, @@ -83,13 +85,13 @@ impl PyDict { return Err(vm.new_runtime_error("dict mutated during update".to_owned())); } } else if let Some(keys) = vm.get_method(dict_obj.clone(), "keys") { - let keys = iterator::get_iter(vm, vm.invoke(&keys?, ())?)?; + let keys = vm.invoke(&keys?, ())?.get_iter(vm)?; while let Some(key) = iterator::get_next_object(vm, &keys)? { let val = dict_obj.get_item(key.clone(), vm)?; dict.insert(vm, key, val)?; } } else { - let iter = iterator::get_iter(vm, dict_obj)?; + let iter = dict_obj.get_iter(vm)?; loop { fn err(vm: &VirtualMachine) -> PyBaseExceptionRef { vm.new_value_error("Iterator must have exactly two elements".to_owned()) @@ -98,7 +100,7 @@ impl PyDict { Some(obj) => obj, None => break, }; - let elem_iter = iterator::get_iter(vm, element)?; + let elem_iter = element.get_iter(vm)?; let key = iterator::get_next_object(vm, &elem_iter)?.ok_or_else(|| err(vm))?; let value = iterator::get_next_object(vm, &elem_iter)?.ok_or_else(|| err(vm))?; @@ -708,7 +710,7 @@ macro_rules! dict_iterator { } } - #[pyimpl(with(PyIter))] + #[pyimpl(with(SlotIterator))] impl $iter_name { fn new(dict: PyDictRef) -> Self { $iter_name { @@ -730,7 +732,7 @@ macro_rules! dict_iterator { } impl IteratorIterable for $iter_name {} - impl PyIter for $iter_name { + impl SlotIterator for $iter_name { #[allow(clippy::redundant_closure_call)] fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { match zelf.status.load() { @@ -769,7 +771,7 @@ macro_rules! dict_iterator { } } - #[pyimpl(with(PyIter))] + #[pyimpl(with(SlotIterator))] impl $reverse_iter_name { fn new(dict: PyDictRef) -> Self { $reverse_iter_name { @@ -791,7 +793,7 @@ macro_rules! dict_iterator { } impl IteratorIterable for $reverse_iter_name {} - impl PyIter for $reverse_iter_name { + impl SlotIterator for $reverse_iter_name { #[allow(clippy::redundant_closure_call)] fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { match zelf.status.load() { diff --git a/vm/src/builtins/enumerate.rs b/vm/src/builtins/enumerate.rs index 6de1e454e8..76bf4bccee 100644 --- a/vm/src/builtins/enumerate.rs +++ b/vm/src/builtins/enumerate.rs @@ -6,8 +6,8 @@ use super::{ use crate::common::lock::PyRwLock; use crate::{ function::OptionalArg, - iterator, - slots::{IteratorIterable, PyIter, SlotConstructor}, + protocol::PyIter, + slots::{IteratorIterable, SlotConstructor, SlotIterator}, IntoPyObject, ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, VirtualMachine, }; @@ -19,7 +19,7 @@ use num_traits::Zero; #[derive(Debug)] pub struct PyEnumerate { counter: PyRwLock, - iterator: PyObjectRef, + iterator: PyIter, } impl PyValue for PyEnumerate { @@ -30,7 +30,7 @@ impl PyValue for PyEnumerate { #[derive(FromArgs)] pub struct EnumerateArgs { - iterable: PyObjectRef, + iterator: PyIter, #[pyarg(any, optional)] start: OptionalArg, } @@ -38,11 +38,12 @@ pub struct EnumerateArgs { impl SlotConstructor for PyEnumerate { type Args = EnumerateArgs; - fn py_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult { - let counter = args - .start - .map_or_else(BigInt::zero, |start| start.as_bigint().clone()); - let iterator = iterator::get_iter(vm, args.iterable)?; + fn py_new( + cls: PyTypeRef, + Self::Args { iterator, start }: Self::Args, + vm: &VirtualMachine, + ) -> PyResult { + let counter = start.map_or_else(BigInt::zero, |start| start.as_bigint().clone()); PyEnumerate { counter: PyRwLock::new(counter), iterator, @@ -51,13 +52,13 @@ impl SlotConstructor for PyEnumerate { } } -#[pyimpl(with(PyIter, SlotConstructor), flags(BASETYPE))] +#[pyimpl(with(SlotIterator, SlotConstructor), flags(BASETYPE))] impl PyEnumerate {} impl IteratorIterable for PyEnumerate {} -impl PyIter for PyEnumerate { +impl SlotIterator for PyEnumerate { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let next_obj = iterator::call_next(vm, &zelf.iterator)?; + let next_obj = zelf.iterator.next(vm)?; let mut counter = zelf.counter.write(); let position = counter.clone(); *counter += 1; @@ -79,7 +80,7 @@ impl PyValue for PyReverseSequenceIterator { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl PyReverseSequenceIterator { pub fn new(obj: PyObjectRef, len: usize) -> Self { Self { @@ -137,7 +138,7 @@ impl PyReverseSequenceIterator { } impl IteratorIterable for PyReverseSequenceIterator {} -impl PyIter for PyReverseSequenceIterator { +impl SlotIterator for PyReverseSequenceIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { if let Exhausted = zelf.status.load() { return Err(vm.new_stop_iteration()); diff --git a/vm/src/builtins/filter.rs b/vm/src/builtins/filter.rs index 4737a61b48..802605b389 100644 --- a/vm/src/builtins/filter.rs +++ b/vm/src/builtins/filter.rs @@ -1,7 +1,7 @@ use super::PyTypeRef; use crate::{ - iterator, - slots::{IteratorIterable, PyIter, SlotConstructor}, + protocol::PyIter, + slots::{IteratorIterable, SlotConstructor, SlotIterator}, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, VirtualMachine, }; @@ -13,7 +13,7 @@ use crate::{ #[derive(Debug)] pub struct PyFilter { predicate: PyObjectRef, - iterator: PyObjectRef, + iterator: PyIter, } impl PyValue for PyFilter { @@ -22,24 +22,10 @@ impl PyValue for PyFilter { } } -#[derive(FromArgs)] -pub struct FilterArgs { - #[pyarg(positional)] - function: PyObjectRef, - #[pyarg(positional)] - iterable: PyObjectRef, -} - impl SlotConstructor for PyFilter { - type Args = FilterArgs; - - fn py_new( - cls: PyTypeRef, - Self::Args { function, iterable }: Self::Args, - vm: &VirtualMachine, - ) -> PyResult { - let iterator = iterator::get_iter(vm, iterable)?; + type Args = (PyObjectRef, PyIter); + fn py_new(cls: PyTypeRef, (function, iterator): Self::Args, vm: &VirtualMachine) -> PyResult { Self { predicate: function, iterator, @@ -48,16 +34,15 @@ impl SlotConstructor for PyFilter { } } -#[pyimpl(with(PyIter, SlotConstructor), flags(BASETYPE))] +#[pyimpl(with(SlotIterator, SlotConstructor), flags(BASETYPE))] impl PyFilter {} impl IteratorIterable for PyFilter {} -impl PyIter for PyFilter { +impl SlotIterator for PyFilter { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let predicate = &zelf.predicate; - let iterator = &zelf.iterator; loop { - let next_obj = iterator::call_next(vm, iterator)?; + let next_obj = zelf.iterator.next(vm)?; let predicate_value = if vm.is_none(predicate) { next_obj.clone() } else { diff --git a/vm/src/builtins/generator.rs b/vm/src/builtins/generator.rs index 6348dba042..cd1de09d98 100644 --- a/vm/src/builtins/generator.rs +++ b/vm/src/builtins/generator.rs @@ -7,7 +7,7 @@ use crate::{ coroutine::{Coro, Variant}, frame::FrameRef, function::OptionalArg, - slots::{IteratorIterable, PyIter}, + slots::{IteratorIterable, SlotIterator}, IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, VirtualMachine, }; @@ -23,7 +23,7 @@ impl PyValue for PyGenerator { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl PyGenerator { pub fn as_coro(&self) -> &Coro { &self.inner @@ -95,7 +95,7 @@ impl PyGenerator { } impl IteratorIterable for PyGenerator {} -impl PyIter for PyGenerator { +impl SlotIterator for PyGenerator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.send(vm.ctx.none(), vm) } diff --git a/vm/src/builtins/iter.rs b/vm/src/builtins/iter.rs index 918c41c990..a3e386fa3d 100644 --- a/vm/src/builtins/iter.rs +++ b/vm/src/builtins/iter.rs @@ -5,7 +5,7 @@ use super::{int, PyInt, PyTypeRef}; use crate::{ function::ArgCallable, - slots::{IteratorIterable, PyIter}, + slots::{IteratorIterable, SlotIterator}, ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, VirtualMachine, }; @@ -34,7 +34,7 @@ impl PyValue for PySequenceIterator { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl PySequenceIterator { pub fn new(obj: PyObjectRef) -> Self { Self { @@ -91,7 +91,7 @@ impl PySequenceIterator { } impl IteratorIterable for PySequenceIterator {} -impl PyIter for PySequenceIterator { +impl SlotIterator for PySequenceIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { if let IterStatus::Exhausted = zelf.status.load() { return Err(vm.new_stop_iteration()); @@ -122,7 +122,7 @@ impl PyValue for PyCallableIterator { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl PyCallableIterator { pub fn new(callable: ArgCallable, sentinel: PyObjectRef) -> Self { Self { @@ -134,7 +134,7 @@ impl PyCallableIterator { } impl IteratorIterable for PyCallableIterator {} -impl PyIter for PyCallableIterator { +impl SlotIterator for PyCallableIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { if let IterStatus::Exhausted = zelf.status.load() { return Err(vm.new_stop_iteration()); diff --git a/vm/src/builtins/list.rs b/vm/src/builtins/list.rs index fa1dfe9072..11735a5fd3 100644 --- a/vm/src/builtins/list.rs +++ b/vm/src/builtins/list.rs @@ -10,7 +10,9 @@ use crate::{ function::{ArgIterable, FuncArgs, OptionalArg}, sequence::{self, SimpleSeq}, sliceable::{PySliceableSequence, PySliceableSequenceMut, SequenceIndex}, - slots::{Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, PyIter, Unhashable}, + slots::{ + Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, SlotIterator, Unhashable, + }, utils::Either, vm::{ReprGuard, VirtualMachine}, PyClassDef, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, @@ -484,7 +486,7 @@ impl PyValue for PyListIterator { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl PyListIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { @@ -521,7 +523,7 @@ impl PyListIterator { } impl IteratorIterable for PyListIterator {} -impl PyIter for PyListIterator { +impl SlotIterator for PyListIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { if let Exhausted = zelf.status.load() { return Err(vm.new_stop_iteration()); @@ -551,7 +553,7 @@ impl PyValue for PyListReverseIterator { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl PyListReverseIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { @@ -595,7 +597,7 @@ impl PyListReverseIterator { } impl IteratorIterable for PyListReverseIterator {} -impl PyIter for PyListReverseIterator { +impl SlotIterator for PyListReverseIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { if let Exhausted = zelf.status.load() { return Err(vm.new_stop_iteration()); diff --git a/vm/src/builtins/make_module.rs b/vm/src/builtins/make_module.rs index f798872ea3..964da58e6c 100644 --- a/vm/src/builtins/make_module.rs +++ b/vm/src/builtins/make_module.rs @@ -25,7 +25,8 @@ mod decl { ArgBytesLike, ArgCallable, ArgIterable, FuncArgs, KwArgs, OptionalArg, OptionalOption, PosArgs, }, - iterator, py_io, + protocol::PyIter, + py_io, readline::{Readline, ReadlineResult}, scope::Scope, slots::PyComparisonOp, @@ -413,14 +414,15 @@ mod decl { iter_target: PyObjectRef, sentinel: OptionalArg, vm: &VirtualMachine, - ) -> PyResult { + ) -> PyResult { if let OptionalArg::Present(sentinel) = sentinel { let callable = ArgCallable::try_from_object(vm, iter_target)?; - Ok(PyCallableIterator::new(callable, sentinel) + let iterator = PyCallableIterator::new(callable, sentinel) .into_ref(vm) - .into_object()) + .into_object(); + Ok(PyIter::new(iterator)) } else { - iterator::get_iter(vm, iter_target) + iter_target.get_iter(vm) } } @@ -511,7 +513,7 @@ mod decl { default_value: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - iterator::call_next(vm, &iterator).or_else(|err| { + PyIter::new(iterator).next(vm).or_else(|err| { if err.isinstance(&vm.ctx.exceptions.stop_iteration) { default_value.ok_or(err) } else { diff --git a/vm/src/builtins/map.rs b/vm/src/builtins/map.rs index 9a57517a46..7d7b174f8e 100644 --- a/vm/src/builtins/map.rs +++ b/vm/src/builtins/map.rs @@ -2,7 +2,8 @@ use super::PyTypeRef; use crate::{ function::PosArgs, iterator, - slots::{IteratorIterable, PyIter, SlotConstructor}, + protocol::PyIter, + slots::{IteratorIterable, SlotConstructor, SlotIterator}, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, VirtualMachine, }; @@ -14,7 +15,7 @@ use crate::{ #[derive(Debug)] pub struct PyMap { mapper: PyObjectRef, - iterators: Vec, + iterators: Vec, } impl PyValue for PyMap { @@ -24,27 +25,20 @@ impl PyValue for PyMap { } impl SlotConstructor for PyMap { - type Args = (PyObjectRef, PosArgs); + type Args = (PyObjectRef, PosArgs); - fn py_new(cls: PyTypeRef, (function, iterables): Self::Args, vm: &VirtualMachine) -> PyResult { - let iterators = iterables - .into_iter() - .map(|iterable| iterator::get_iter(vm, iterable)) - .collect::, _>>()?; - PyMap { - mapper: function, - iterators, - } - .into_pyresult_with_type(vm, cls) + fn py_new(cls: PyTypeRef, (mapper, iterators): Self::Args, vm: &VirtualMachine) -> PyResult { + let iterators = iterators.into_vec(); + PyMap { mapper, iterators }.into_pyresult_with_type(vm, cls) } } -#[pyimpl(with(PyIter, SlotConstructor), flags(BASETYPE))] +#[pyimpl(with(SlotIterator, SlotConstructor), flags(BASETYPE))] impl PyMap { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyResult { self.iterators.iter().try_fold(0, |prev, cur| { - let cur = iterator::length_hint(vm, cur.clone())?.unwrap_or(0); + let cur = iterator::length_hint(vm, cur.as_object().clone())?.unwrap_or(0); let max = std::cmp::max(prev, cur); Ok(max) }) @@ -52,12 +46,12 @@ impl PyMap { } impl IteratorIterable for PyMap {} -impl PyIter for PyMap { +impl SlotIterator for PyMap { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let next_objs = zelf .iterators .iter() - .map(|iterator| iterator::call_next(vm, iterator)) + .map(|iterator| iterator.next(vm)) .collect::, _>>()?; // the mapper itself can raise StopIteration which does stop the map iteration diff --git a/vm/src/builtins/mappingproxy.rs b/vm/src/builtins/mappingproxy.rs index c625f334e9..fcf7d90716 100644 --- a/vm/src/builtins/mappingproxy.rs +++ b/vm/src/builtins/mappingproxy.rs @@ -1,7 +1,6 @@ use super::{PyDict, PyStrRef, PyTypeRef}; use crate::{ function::OptionalArg, - iterator, slots::{Iterable, SlotConstructor}, IntoPyObject, ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, VirtualMachine, @@ -137,7 +136,8 @@ impl Iterable for PyMappingProxy { PyDict::from_attributes(c.attributes.read().clone(), vm)?.into_pyobject(vm) } }; - iterator::get_iter(vm, obj) + let iter = obj.get_iter(vm)?; + Ok(iter.into_object()) } } diff --git a/vm/src/builtins/pystr.rs b/vm/src/builtins/pystr.rs index 425b30e03f..e29939296b 100644 --- a/vm/src/builtins/pystr.rs +++ b/vm/src/builtins/pystr.rs @@ -10,7 +10,8 @@ use crate::{ function::{ArgIterable, FuncArgs, OptionalArg, OptionalOption}, sliceable::PySliceableSequence, slots::{ - Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, PyIter, SlotConstructor, + Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, SlotConstructor, + SlotIterator, }, utils::Either, IdProtocol, IntoPyObject, ItemProtocol, PyClassDef, PyClassImpl, PyComparisonValue, PyContext, @@ -179,7 +180,7 @@ impl PyValue for PyStrIterator { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl PyStrIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { @@ -228,7 +229,7 @@ impl PyStrIterator { } impl IteratorIterable for PyStrIterator {} -impl PyIter for PyStrIterator { +impl SlotIterator for PyStrIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { if let Exhausted = zelf.status.load() { return Err(vm.new_stop_iteration()); diff --git a/vm/src/builtins/range.rs b/vm/src/builtins/range.rs index a6510bbf95..bfb371a4bf 100644 --- a/vm/src/builtins/range.rs +++ b/vm/src/builtins/range.rs @@ -3,7 +3,7 @@ use crate::common::hash::PyHash; use crate::{ function::{FuncArgs, OptionalArg}, iterator, - slots::{Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, PyIter}, + slots::{Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, SlotIterator}, IdProtocol, IntoPyRef, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, VirtualMachine, }; @@ -30,7 +30,7 @@ fn iter_search( vm: &VirtualMachine, ) -> PyResult { let mut count = 0; - let iter = iterator::get_iter(vm, obj)?; + let iter = obj.get_iter(vm)?; while let Some(element) = iterator::get_next_object(vm, &iter)? { if vm.bool_eq(&item, &element)? { match flag { @@ -478,7 +478,7 @@ impl PyValue for PyLongRangeIterator { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl PyLongRangeIterator { #[pyslot] fn slot_new(_cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { @@ -514,7 +514,7 @@ impl PyLongRangeIterator { } impl IteratorIterable for PyLongRangeIterator {} -impl PyIter for PyLongRangeIterator { +impl SlotIterator for PyLongRangeIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { // TODO: In pathological case (index == usize::MAX) this can wrap around // (since fetch_add wraps). This would result in the iterator spinning again @@ -547,7 +547,7 @@ impl PyValue for PyRangeIterator { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl PyRangeIterator { #[pyslot] fn slot_new(_cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { @@ -584,7 +584,7 @@ impl PyRangeIterator { } impl IteratorIterable for PyRangeIterator {} -impl PyIter for PyRangeIterator { +impl SlotIterator for PyRangeIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { // TODO: In pathological case (index == usize::MAX) this can wrap around // (since fetch_add wraps). This would result in the iterator spinning again diff --git a/vm/src/builtins/set.rs b/vm/src/builtins/set.rs index f151cc6a76..7ce164d7a7 100644 --- a/vm/src/builtins/set.rs +++ b/vm/src/builtins/set.rs @@ -7,8 +7,8 @@ use crate::{ dictdatatype::{self, DictSize}, function::{ArgIterable, FuncArgs, OptionalArg, PosArgs}, slots::{ - Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, PyIter, SlotConstructor, - Unhashable, + Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, SlotConstructor, + SlotIterator, Unhashable, }, vm::{ReprGuard, VirtualMachine}, IdProtocol, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, @@ -833,7 +833,7 @@ impl PyValue for PySetIterator { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl PySetIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { @@ -862,7 +862,7 @@ impl PySetIterator { } impl IteratorIterable for PySetIterator {} -impl PyIter for PySetIterator { +impl SlotIterator for PySetIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { match zelf.status.load() { IterStatus::Exhausted => Err(vm.new_stop_iteration()), diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index dcf95108fa..c10f30420b 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -9,7 +9,8 @@ use crate::{ sequence::{self, SimpleSeq}, sliceable::PySliceableSequence, slots::{ - Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, PyIter, SlotConstructor, + Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, SlotConstructor, + SlotIterator, }, utils::Either, vm::{ReprGuard, VirtualMachine}, @@ -330,7 +331,7 @@ impl PyValue for PyTupleIterator { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl PyTupleIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { @@ -376,7 +377,7 @@ impl PyTupleIterator { } impl IteratorIterable for PyTupleIterator {} -impl PyIter for PyTupleIterator { +impl SlotIterator for PyTupleIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { if let Exhausted = zelf.status.load() { return Err(vm.new_stop_iteration()); diff --git a/vm/src/builtins/zip.rs b/vm/src/builtins/zip.rs index 2abb5c1833..8412803d96 100644 --- a/vm/src/builtins/zip.rs +++ b/vm/src/builtins/zip.rs @@ -1,15 +1,15 @@ use super::PyTypeRef; use crate::{ function::PosArgs, - iterator, - slots::{IteratorIterable, PyIter, SlotConstructor}, - PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, VirtualMachine, + protocol::PyIter, + slots::{IteratorIterable, SlotConstructor, SlotIterator}, + PyClassImpl, PyContext, PyRef, PyResult, PyValue, VirtualMachine, }; #[pyclass(module = false, name = "zip")] #[derive(Debug)] pub struct PyZip { - iterators: Vec, + iterators: Vec, } impl PyValue for PyZip { @@ -19,22 +19,19 @@ impl PyValue for PyZip { } impl SlotConstructor for PyZip { - type Args = PosArgs; + type Args = PosArgs; - fn py_new(cls: PyTypeRef, iterables: Self::Args, vm: &VirtualMachine) -> PyResult { - let iterators = iterables - .into_iter() - .map(|iterable| iterator::get_iter(vm, iterable)) - .collect::, _>>()?; + fn py_new(cls: PyTypeRef, iterators: Self::Args, vm: &VirtualMachine) -> PyResult { + let iterators = iterators.into_vec(); PyZip { iterators }.into_pyresult_with_type(vm, cls) } } -#[pyimpl(with(PyIter, SlotConstructor), flags(BASETYPE))] +#[pyimpl(with(SlotIterator, SlotConstructor), flags(BASETYPE))] impl PyZip {} impl IteratorIterable for PyZip {} -impl PyIter for PyZip { +impl SlotIterator for PyZip { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { if zelf.iterators.is_empty() { Err(vm.new_stop_iteration()) @@ -42,7 +39,7 @@ impl PyIter for PyZip { let next_objs = zelf .iterators .iter() - .map(|iterator| iterator::call_next(vm, iterator)) + .map(|iterator| iterator.next(vm)) .collect::, _>>()?; Ok(vm.ctx.new_tuple(next_objs)) diff --git a/vm/src/frame.rs b/vm/src/frame.rs index dc3ac2a019..9c7c28dab3 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -17,6 +17,7 @@ use crate::{ exceptions::{self, ExceptionCtor}, function::FuncArgs, iterator, + protocol::PyIter, scope::Scope, slots::PyComparisonOp, IdProtocol, ItemProtocol, PyMethod, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, @@ -857,8 +858,8 @@ impl ExecutingFrame<'_> { } bytecode::Instruction::GetIter => { let iterated_obj = self.pop_value(); - let iter_obj = iterator::get_iter(vm, iterated_obj)?; - self.push_value(iter_obj); + let iter_obj = iterated_obj.get_iter(vm)?; + self.push_value(iter_obj.into_object()); Ok(None) } bytecode::Instruction::GetAwaitable => { @@ -1442,7 +1443,7 @@ impl ExecutingFrame<'_> { fn _send(&self, coro: &PyObjectRef, val: PyObjectRef, vm: &VirtualMachine) -> PyResult { match self.builtin_coro(coro) { Some(coro) => coro.send(val, vm), - None if vm.is_none(&val) => iterator::call_next(vm, coro), + None if vm.is_none(&val) => PyIter::new(coro).next(vm), None => { let meth = vm.get_attribute(coro.clone(), "send")?; vm.invoke(&meth, (val,)) @@ -1512,7 +1513,7 @@ impl ExecutingFrame<'_> { /// The top of stack contains the iterator, lets push it forward fn execute_for_iter(&mut self, vm: &VirtualMachine, target: bytecode::Label) -> FrameResult { - let top_of_stack = self.last_value(); + let top_of_stack = PyIter::new(self.last_value()); let next_obj = iterator::get_next_object(vm, &top_of_stack); // Check the next object: diff --git a/vm/src/function/argument.rs b/vm/src/function/argument.rs index 8b7a1d1adc..bc87e1dbf5 100644 --- a/vm/src/function/argument.rs +++ b/vm/src/function/argument.rs @@ -1,7 +1,7 @@ use super::IntoFuncArgs; -use crate::builtins::iter::PySequenceIterator; use crate::{ - iterator, PyObjectRef, PyResult, PyValue, TryFromObject, TypeProtocol, VirtualMachine, + builtins::iter::PySequenceIterator, iterator, protocol::PyIter, PyObjectRef, PyResult, PyValue, + TryFromObject, TypeProtocol, VirtualMachine, }; use std::marker::PhantomData; @@ -102,7 +102,7 @@ where type Item = PyResult; fn next(&mut self) -> Option { - iterator::get_next_object(self.vm, &self.obj) + iterator::get_next_object(self.vm, &PyIter::new(&self.obj)) .transpose() .map(|x| x.and_then(|obj| T::try_from_object(self.vm, obj))) } diff --git a/vm/src/iterator.rs b/vm/src/iterator.rs index 2d8406356b..3bbb0d0435 100644 --- a/vm/src/iterator.rs +++ b/vm/src/iterator.rs @@ -2,65 +2,24 @@ * utilities to support iteration. */ -use crate::builtins::int::{self, PyInt}; -use crate::builtins::iter::PySequenceIterator; -use crate::builtins::PyBaseExceptionRef; -use crate::vm::VirtualMachine; -use crate::{IdProtocol, PyObjectRef, PyResult, PyValue, TypeProtocol}; +use crate::{ + builtins::{int, PyBaseExceptionRef, PyInt}, + protocol::PyIter, + IdProtocol, PyObjectRef, PyResult, TypeProtocol, VirtualMachine, +}; use num_traits::Signed; -/* - * This helper function is called at multiple places. First, it is called - * in the vm when a for loop is entered. Next, it is used when the builtin - * function 'iter' is called. - */ -pub fn get_iter(vm: &VirtualMachine, iter_target: PyObjectRef) -> PyResult { - let getiter = { - let cls = iter_target.class(); - cls.mro_find_map(|x| x.slots.iter.load()) - }; - if let Some(getiter) = getiter { - let iter = getiter(iter_target, vm)?; - let cls = iter.class(); - let is_iter = cls.iter_mro().any(|x| x.slots.iternext.load().is_some()); - if is_iter { - drop(cls); - Ok(iter) - } else { - Err(vm.new_type_error(format!( - "iter() returned non-iterator of type '{}'", - cls.name() - ))) - } - } else { - vm.get_method_or_type_error(iter_target.clone(), "__getitem__", || { - format!("'{}' object is not iterable", iter_target.class().name()) - })?; - Ok(PySequenceIterator::new(iter_target) - .into_ref(vm) - .into_object()) - } -} - -pub fn call_next(vm: &VirtualMachine, iter_obj: &PyObjectRef) -> PyResult { - let iternext = { - let cls = iter_obj.class(); - cls.mro_find_map(|x| x.slots.iternext.load()) - .ok_or_else(|| { - vm.new_type_error(format!("'{}' object is not an iterator", cls.name())) - })? - }; - iternext(iter_obj, vm) -} - /* * Helper function to retrieve the next object (or none) from an iterator. */ -pub fn get_next_object( +pub fn get_next_object( vm: &VirtualMachine, - iter_obj: &PyObjectRef, -) -> PyResult> { - let next_obj: PyResult = call_next(vm, iter_obj); + iter_obj: &PyIter, +) -> PyResult> +where + T: std::borrow::Borrow, +{ + let next_obj: PyResult = iter_obj.next(vm); match next_obj { Ok(value) => Ok(Some(value)), @@ -75,12 +34,7 @@ pub fn get_next_object( } } -pub fn try_map( - vm: &VirtualMachine, - iter_obj: &PyObjectRef, - cap: usize, - mut f: F, -) -> PyResult> +pub fn try_map(vm: &VirtualMachine, iter: &PyIter, cap: usize, mut f: F) -> PyResult> where F: FnMut(PyObjectRef) -> PyResult, { @@ -90,7 +44,7 @@ where return Ok(Vec::new()); } let mut results = Vec::with_capacity(cap); - while let Some(element) = get_next_object(vm, iter_obj)? { + while let Some(element) = get_next_object(vm, iter)? { results.push(f(element)?); } results.shrink_to_fit(); diff --git a/vm/src/protocol/iter.rs b/vm/src/protocol/iter.rs new file mode 100644 index 0000000000..e19f7e525f --- /dev/null +++ b/vm/src/protocol/iter.rs @@ -0,0 +1,118 @@ +use crate::IntoPyObject; +use crate::{ + builtins::iter::PySequenceIterator, PyObjectRef, PyResult, PyValue, TryFromObject, + TypeProtocol, VirtualMachine, +}; +use std::borrow::Borrow; +use std::ops::Deref; + +/// Iterator Protocol +// https://docs.python.org/3/c-api/iter.html +#[derive(Debug, Clone)] +#[repr(transparent)] +pub struct PyIter(T) +where + T: Borrow; + +impl PyIter { + pub fn into_object(self) -> PyObjectRef { + self.0 + } +} + +impl PyIter +where + T: Borrow, +{ + pub fn new(obj: T) -> Self { + Self(obj) + } + pub fn as_object(&self) -> &PyObjectRef { + self.0.borrow() + } + pub fn next(&self, vm: &VirtualMachine) -> PyResult { + let iternext = { + self.0 + .borrow() + .class() + .mro_find_map(|x| x.slots.iternext.load()) + .ok_or_else(|| { + vm.new_type_error(format!( + "'{}' object is not an iterator", + self.0.borrow().class().name() + )) + })? + }; + iternext(self.0.borrow(), vm) + } +} + +impl Borrow for PyIter +where + T: Borrow, +{ + fn borrow(&self) -> &PyObjectRef { + self.0.borrow() + } +} + +impl Deref for PyIter +where + T: Borrow, +{ + type Target = PyObjectRef; + fn deref(&self) -> &Self::Target { + self.0.borrow() + } +} + +impl IntoPyObject for PyIter { + fn into_pyobject(self, _vm: &VirtualMachine) -> PyObjectRef { + self.into_object() + } +} + +impl TryFromObject for PyIter { + // This helper function is called at multiple places. First, it is called + // in the vm when a for loop is entered. Next, it is used when the builtin + // function 'iter' is called. + fn try_from_object(vm: &VirtualMachine, iter_target: PyObjectRef) -> PyResult { + let getiter = { + let cls = iter_target.class(); + cls.mro_find_map(|x| x.slots.iter.load()) + }; + if let Some(getiter) = getiter { + let iter = getiter(iter_target, vm)?; + let cls = iter.class(); + let is_iter = cls.iter_mro().any(|x| x.slots.iternext.load().is_some()); + if is_iter { + drop(cls); + Ok(Self(iter)) + } else { + Err(vm.new_type_error(format!( + "iter() returned non-iterator of type '{}'", + cls.name() + ))) + } + } else { + vm.get_method_or_type_error(iter_target.clone(), "__getitem__", || { + format!("'{}' object is not iterable", iter_target.class().name()) + })?; + Ok(Self( + PySequenceIterator::new(iter_target) + .into_ref(vm) + .into_object(), + )) + } + } +} + +impl PyObjectRef { + /// Takes an object and returns an iterator for it. + /// This is typically a new iterator but if the argument is an iterator, this + /// returns itself. + pub fn get_iter(self, vm: &VirtualMachine) -> PyResult { + // PyObject_GetIter + PyIter::try_from_object(vm, self) + } +} diff --git a/vm/src/protocol/mod.rs b/vm/src/protocol/mod.rs index a77e6134c1..ce4906da5e 100644 --- a/vm/src/protocol/mod.rs +++ b/vm/src/protocol/mod.rs @@ -1,3 +1,5 @@ mod buffer; +mod iter; pub(crate) use buffer::{BufferInternal, BufferOptions, PyBuffer, ResizeGuard}; +pub use iter::PyIter; diff --git a/vm/src/slots.rs b/vm/src/slots.rs index 633211f43e..4bf7933623 100644 --- a/vm/src/slots.rs +++ b/vm/src/slots.rs @@ -548,7 +548,7 @@ pub trait Iterable: PyValue { } #[pyimpl(with(Iterable))] -pub trait PyIter: PyValue + Iterable { +pub trait SlotIterator: PyValue + Iterable { #[pyslot] fn slot_iternext(zelf: &PyObjectRef, vm: &VirtualMachine) -> PyResult { if let Some(zelf) = zelf.downcast_ref() { diff --git a/vm/src/stdlib/array.rs b/vm/src/stdlib/array.rs index 16e5ed39b2..58175cb3f5 100644 --- a/vm/src/stdlib/array.rs +++ b/vm/src/stdlib/array.rs @@ -19,8 +19,8 @@ mod array { protocol::{BufferInternal, BufferOptions, PyBuffer, ResizeGuard}, sliceable::{saturate_index, PySliceableSequence, PySliceableSequenceMut, SequenceIndex}, slots::{ - AsBuffer, Comparable, Iterable, IteratorIterable, PyComparisonOp, PyIter, - SlotConstructor, + AsBuffer, Comparable, Iterable, IteratorIterable, PyComparisonOp, SlotConstructor, + SlotIterator, }, IdProtocol, IntoPyObject, IntoPyResult, PyComparisonValue, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, VirtualMachine, @@ -1193,11 +1193,11 @@ mod array { array: PyArrayRef, } - #[pyimpl(with(PyIter))] + #[pyimpl(with(SlotIterator))] impl PyArrayIter {} impl IteratorIterable for PyArrayIter {} - impl PyIter for PyArrayIter { + impl SlotIterator for PyArrayIter { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let pos = zelf.position.fetch_add(1); if let Some(item) = zelf.array.read().getitem_by_idx(pos, vm)? { diff --git a/vm/src/stdlib/collections.rs b/vm/src/stdlib/collections.rs index 62fab4cf45..4a24b32197 100644 --- a/vm/src/stdlib/collections.rs +++ b/vm/src/stdlib/collections.rs @@ -11,8 +11,8 @@ mod _collections { function::{FuncArgs, KwArgs, OptionalArg}, sequence, sliceable, slots::{ - Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, PyIter, - SlotConstructor, Unhashable, + Comparable, Hashable, Iterable, IteratorIterable, PyComparisonOp, SlotConstructor, + SlotIterator, Unhashable, }, vm::ReprGuard, PyComparisonValue, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, VirtualMachine, @@ -609,7 +609,7 @@ mod _collections { } } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyDequeIterator { pub(crate) fn new(deque: PyDequeRef) -> Self { PyDequeIterator { @@ -641,7 +641,7 @@ mod _collections { } impl IteratorIterable for PyDequeIterator {} - impl PyIter for PyDequeIterator { + impl SlotIterator for PyDequeIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { match zelf.status.load() { Exhausted => Err(vm.new_stop_iteration()), @@ -698,7 +698,7 @@ mod _collections { } } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyReverseDequeIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { @@ -721,7 +721,7 @@ mod _collections { } impl IteratorIterable for PyReverseDequeIterator {} - impl PyIter for PyReverseDequeIterator { + impl SlotIterator for PyReverseDequeIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { match zelf.status.load() { Exhausted => Err(vm.new_stop_iteration()), diff --git a/vm/src/stdlib/csv.rs b/vm/src/stdlib/csv.rs index 0eb8ad170e..fab0a02ef1 100644 --- a/vm/src/stdlib/csv.rs +++ b/vm/src/stdlib/csv.rs @@ -2,8 +2,8 @@ use crate::common::lock::PyMutex; use crate::{ builtins::{PyStr, PyStrRef}, function::{ArgIterable, ArgumentError, FromArgs, FuncArgs}, - iterator, - slots::{IteratorIterable, PyIter}, + protocol::PyIter, + slots::{IteratorIterable, SlotIterator}, types::create_simple_type, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, VirtualMachine, @@ -85,7 +85,7 @@ struct ReadState { #[pyclass(module = "_csv", name = "reader")] #[derive(PyValue)] struct Reader { - iter: PyObjectRef, + iter: PyIter, state: PyMutex, } @@ -95,12 +95,12 @@ impl fmt::Debug for Reader { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl Reader {} impl IteratorIterable for Reader {} -impl PyIter for Reader { +impl SlotIterator for Reader { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let string = iterator::call_next(vm, &zelf.iter)?; + let string = zelf.iter.next(vm)?; let string = string.downcast::().map_err(|obj| { vm.new_type_error(format!( "iterator should return strings, not {} (the file should be opened in text mode)", @@ -163,13 +163,12 @@ impl PyIter for Reader { } fn _csv_reader( - iter: PyObjectRef, + iter: PyIter, options: FormatOptions, // TODO: handle quote style, etc _rest: FuncArgs, - vm: &VirtualMachine, + _vm: &VirtualMachine, ) -> PyResult { - let iter = iterator::get_iter(vm, iter)?; Ok(Reader { iter, state: PyMutex::new(ReadState { diff --git a/vm/src/stdlib/functools.rs b/vm/src/stdlib/functools.rs index d7def1c153..08047e7920 100644 --- a/vm/src/stdlib/functools.rs +++ b/vm/src/stdlib/functools.rs @@ -3,23 +3,21 @@ pub(crate) use _functools::make_module; #[pymodule] mod _functools { use crate::function::OptionalArg; - use crate::iterator; + use crate::protocol::PyIter; use crate::vm::VirtualMachine; use crate::{PyObjectRef, PyResult, TypeProtocol}; #[pyfunction] fn reduce( function: PyObjectRef, - sequence: PyObjectRef, + iterator: PyIter, start_value: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - let iterator = iterator::get_iter(vm, sequence)?; - let start_value = if let OptionalArg::Present(val) = start_value { val } else { - iterator::call_next(vm, &iterator).map_err(|err| { + iterator.next(vm).map_err(|err| { if err.isinstance(&vm.ctx.exceptions.stop_iteration) { let exc_type = vm.ctx.exceptions.type_error.clone(); vm.new_exception_msg( @@ -34,7 +32,7 @@ mod _functools { let mut accumulator = start_value; - while let Ok(next_obj) = iterator::call_next(vm, &iterator) { + while let Ok(next_obj) = iterator.next(vm) { accumulator = vm.invoke(&function, vec![accumulator, next_obj])? } diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index c9392b3ab3..548b652058 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -86,7 +86,7 @@ mod _io { ArgBytesLike, ArgIterable, ArgMemoryBuffer, FuncArgs, OptionalArg, OptionalOption, }, protocol::{BufferInternal, BufferOptions, PyBuffer, ResizeGuard}, - slots::{Iterable, PyIter, SlotConstructor}, + slots::{Iterable, SlotConstructor, SlotIterator}, utils::Either, vm::{ReprGuard, VirtualMachine}, IdProtocol, IntoPyObject, PyContext, PyObjectRef, PyRef, PyResult, PyValue, StaticType, @@ -344,7 +344,7 @@ mod _io { #[derive(Debug, PyValue)] struct _IOBase; - #[pyimpl(with(PyIter), flags(BASETYPE, HAS_DICT))] + #[pyimpl(with(SlotIterator), flags(BASETYPE, HAS_DICT))] impl _IOBase { #[pymethod] fn seek( @@ -523,7 +523,7 @@ mod _io { } } - impl PyIter for _IOBase { + impl SlotIterator for _IOBase { fn slot_iternext(zelf: &PyObjectRef, vm: &VirtualMachine) -> PyResult { let line = vm.call_method(zelf, "readline", ())?; if !line.clone().try_to_bool(vm)? { diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 278142d6eb..d0dcb1274c 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -9,8 +9,9 @@ mod decl { use crate::{ builtins::{int, PyInt, PyIntRef, PyTupleRef, PyTypeRef}, function::{ArgCallable, FuncArgs, OptionalArg, OptionalOption, PosArgs}, - iterator::{call_next, get_iter, get_next_object}, - slots::{IteratorIterable, PyIter, SlotConstructor}, + iterator::get_next_object, + protocol::PyIter, + slots::{IteratorIterable, SlotConstructor, SlotIterator}, IdProtocol, IntoPyObject, PyObjectRef, PyRef, PyResult, PyValue, PyWeakRef, TypeProtocol, VirtualMachine, }; @@ -25,10 +26,10 @@ mod decl { struct PyItertoolsChain { iterables: Vec, cur_idx: AtomicCell, - cached_iter: PyRwLock>, + cached_iter: PyRwLock>, } - #[pyimpl(with(PyIter))] + #[pyimpl(with(SlotIterator))] impl PyItertoolsChain { #[pyslot] fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { @@ -55,7 +56,7 @@ mod decl { } } impl IteratorIterable for PyItertoolsChain {} - impl PyIter for PyItertoolsChain { + impl SlotIterator for PyItertoolsChain { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { loop { let pos = zelf.cur_idx.load(); @@ -64,7 +65,7 @@ mod decl { } let cur_iter = if zelf.cached_iter.read().is_none() { // We need to call "get_iter" outside of the lock. - let iter = get_iter(vm, zelf.iterables[pos].clone())?; + let iter = zelf.iterables[pos].clone().get_iter(vm)?; *zelf.cached_iter.write() = Some(iter.clone()); iter } else if let Some(cached_iter) = (*zelf.cached_iter.read()).clone() { @@ -74,8 +75,8 @@ mod decl { continue; }; - // We need to call "call_next" outside of the lock. - match call_next(vm, &cur_iter) { + // We need to call "next" outside of the lock. + match cur_iter.next(vm) { Ok(ok) => return Ok(ok), Err(err) => { if err.isinstance(&vm.ctx.exceptions.stop_iteration) { @@ -96,16 +97,16 @@ mod decl { #[pyclass(name = "compress")] #[derive(Debug, PyValue)] struct PyItertoolsCompress { - data: PyObjectRef, - selector: PyObjectRef, + data: PyIter, + selector: PyIter, } #[derive(FromArgs)] struct CompressNewArgs { #[pyarg(positional)] - data: PyObjectRef, + data: PyIter, #[pyarg(positional)] - selector: PyObjectRef, + selector: PyIter, } impl SlotConstructor for PyItertoolsCompress { @@ -116,27 +117,20 @@ mod decl { Self::Args { data, selector }: Self::Args, vm: &VirtualMachine, ) -> PyResult { - let data_iter = get_iter(vm, data)?; - let selector_iter = get_iter(vm, selector)?; - - PyItertoolsCompress { - data: data_iter, - selector: selector_iter, - } - .into_pyresult_with_type(vm, cls) + PyItertoolsCompress { data, selector }.into_pyresult_with_type(vm, cls) } } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyItertoolsCompress {} impl IteratorIterable for PyItertoolsCompress {} - impl PyIter for PyItertoolsCompress { + impl SlotIterator for PyItertoolsCompress { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { loop { - let sel_obj = call_next(vm, &zelf.selector)?; + let sel_obj = zelf.selector.next(vm)?; let verdict = sel_obj.clone().try_to_bool(vm)?; - let data_obj = call_next(vm, &zelf.data)?; + let data_obj = zelf.data.next(vm)?; if verdict { return Ok(data_obj); @@ -187,10 +181,10 @@ mod decl { } } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyItertoolsCount {} impl IteratorIterable for PyItertoolsCount {} - impl PyIter for PyItertoolsCount { + impl SlotIterator for PyItertoolsCount { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let mut cur = zelf.cur.write(); let result = cur.clone(); @@ -203,18 +197,16 @@ mod decl { #[pyclass(name = "cycle")] #[derive(Debug, PyValue)] struct PyItertoolsCycle { - iter: PyObjectRef, + iter: PyIter, saved: PyRwLock>, index: AtomicCell, } impl SlotConstructor for PyItertoolsCycle { - type Args = PyObjectRef; + type Args = PyIter; - fn py_new(cls: PyTypeRef, iterable: Self::Args, vm: &VirtualMachine) -> PyResult { - let iter = get_iter(vm, iterable)?; - - PyItertoolsCycle { + fn py_new(cls: PyTypeRef, iter: Self::Args, vm: &VirtualMachine) -> PyResult { + Self { iter, saved: PyRwLock::new(Vec::new()), index: AtomicCell::new(0), @@ -223,10 +215,10 @@ mod decl { } } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyItertoolsCycle {} impl IteratorIterable for PyItertoolsCycle {} - impl PyIter for PyItertoolsCycle { + impl SlotIterator for PyItertoolsCycle { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let item = if let Some(item) = get_next_object(vm, &zelf.iter)? { zelf.saved.write().push(item.clone()); @@ -288,7 +280,7 @@ mod decl { } } - #[pyimpl(with(PyIter, SlotConstructor), flags(BASETYPE))] + #[pyimpl(with(SlotIterator, SlotConstructor), flags(BASETYPE))] impl PyItertoolsRepeat { #[pymethod(magic)] fn length_hint(&self, vm: &VirtualMachine) -> PyResult { @@ -327,7 +319,7 @@ mod decl { } impl IteratorIterable for PyItertoolsRepeat {} - impl PyIter for PyItertoolsRepeat { + impl SlotIterator for PyItertoolsRepeat { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { if let Some(ref times) = zelf.times { let mut times = times.write(); @@ -345,7 +337,7 @@ mod decl { #[derive(Debug, PyValue)] struct PyItertoolsStarmap { function: PyObjectRef, - iter: PyObjectRef, + iter: PyIter, } #[derive(FromArgs)] @@ -353,7 +345,7 @@ mod decl { #[pyarg(positional)] function: PyObjectRef, #[pyarg(positional)] - iterable: PyObjectRef, + iterable: PyIter, } impl SlotConstructor for PyItertoolsStarmap { @@ -364,18 +356,17 @@ mod decl { Self::Args { function, iterable }: Self::Args, vm: &VirtualMachine, ) -> PyResult { - let iter = get_iter(vm, iterable)?; - + let iter = iterable; PyItertoolsStarmap { function, iter }.into_pyresult_with_type(vm, cls) } } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyItertoolsStarmap {} impl IteratorIterable for PyItertoolsStarmap {} - impl PyIter for PyItertoolsStarmap { + impl SlotIterator for PyItertoolsStarmap { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { - let obj = call_next(vm, &zelf.iter)?; + let obj = zelf.iter.next(vm)?; let function = &zelf.function; vm.invoke(function, vm.extract_elements(&obj)?) @@ -387,7 +378,7 @@ mod decl { #[derive(Debug, PyValue)] struct PyItertoolsTakewhile { predicate: PyObjectRef, - iterable: PyObjectRef, + iterable: PyIter, stop_flag: AtomicCell, } @@ -396,7 +387,7 @@ mod decl { #[pyarg(positional)] predicate: PyObjectRef, #[pyarg(positional)] - iterable: PyObjectRef, + iterable: PyIter, } impl SlotConstructor for PyItertoolsTakewhile { @@ -410,28 +401,26 @@ mod decl { }: Self::Args, vm: &VirtualMachine, ) -> PyResult { - let iter = get_iter(vm, iterable)?; - PyItertoolsTakewhile { predicate, - iterable: iter, + iterable, stop_flag: AtomicCell::new(false), } .into_pyresult_with_type(vm, cls) } } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyItertoolsTakewhile {} impl IteratorIterable for PyItertoolsTakewhile {} - impl PyIter for PyItertoolsTakewhile { + impl SlotIterator for PyItertoolsTakewhile { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { if zelf.stop_flag.load() { return Err(vm.new_stop_iteration()); } // might be StopIteration or anything else, which is propagated upwards - let obj = call_next(vm, &zelf.iterable)?; + let obj = zelf.iterable.next(vm)?; let predicate = &zelf.predicate; let verdict = vm.invoke(predicate, (obj.clone(),))?; @@ -450,7 +439,7 @@ mod decl { #[derive(Debug, PyValue)] struct PyItertoolsDropwhile { predicate: ArgCallable, - iterable: PyObjectRef, + iterable: PyIter, start_flag: AtomicCell, } @@ -459,7 +448,7 @@ mod decl { #[pyarg(positional)] predicate: ArgCallable, #[pyarg(positional)] - iterable: PyObjectRef, + iterable: PyIter, } impl SlotConstructor for PyItertoolsDropwhile { @@ -473,28 +462,26 @@ mod decl { }: Self::Args, vm: &VirtualMachine, ) -> PyResult { - let iter = get_iter(vm, iterable)?; - PyItertoolsDropwhile { predicate, - iterable: iter, + iterable, start_flag: AtomicCell::new(false), } .into_pyresult_with_type(vm, cls) } } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyItertoolsDropwhile {} impl IteratorIterable for PyItertoolsDropwhile {} - impl PyIter for PyItertoolsDropwhile { + impl SlotIterator for PyItertoolsDropwhile { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let predicate = &zelf.predicate; let iterable = &zelf.iterable; if !zelf.start_flag.load() { loop { - let obj = call_next(vm, iterable)?; + let obj = iterable.next(vm)?; let pred = predicate.clone(); let pred_value = vm.invoke(&pred.into_object(), (obj.clone(),))?; if !pred_value.try_to_bool(vm)? { @@ -503,7 +490,7 @@ mod decl { } } } - call_next(vm, iterable) + iterable.next(vm) } } @@ -537,7 +524,7 @@ mod decl { #[pyclass(name = "groupby")] #[derive(PyValue)] struct PyItertoolsGroupBy { - iterable: PyObjectRef, + iterable: PyIter, key_func: Option, state: PyMutex, } @@ -554,7 +541,7 @@ mod decl { #[derive(FromArgs)] struct GroupByArgs { - iterable: PyObjectRef, + iterable: PyIter, #[pyarg(any, optional)] key: OptionalOption, } @@ -562,12 +549,14 @@ mod decl { impl SlotConstructor for PyItertoolsGroupBy { type Args = GroupByArgs; - fn py_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult { - let iter = get_iter(vm, args.iterable)?; - + fn py_new( + cls: PyTypeRef, + Self::Args { iterable, key }: Self::Args, + vm: &VirtualMachine, + ) -> PyResult { PyItertoolsGroupBy { - iterable: iter, - key_func: args.key.flatten(), + iterable, + key_func: key.flatten(), state: PyMutex::new(GroupByState { current_key: None, current_value: None, @@ -579,10 +568,10 @@ mod decl { } } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyItertoolsGroupBy { pub(super) fn advance(&self, vm: &VirtualMachine) -> PyResult<(PyObjectRef, PyObjectRef)> { - let new_value = call_next(vm, &self.iterable)?; + let new_value = self.iterable.next(vm)?; let new_key = if let Some(ref kf) = self.key_func { vm.invoke(kf, vec![new_value.clone()])? } else { @@ -592,7 +581,7 @@ mod decl { } } impl IteratorIterable for PyItertoolsGroupBy {} - impl PyIter for PyItertoolsGroupBy { + impl SlotIterator for PyItertoolsGroupBy { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let mut state = zelf.state.lock(); state.grouper = None; @@ -639,10 +628,10 @@ mod decl { type PyItertoolsGrouperRef = PyRef; - #[pyimpl(with(PyIter))] + #[pyimpl(with(SlotIterator))] impl PyItertoolsGrouper {} impl IteratorIterable for PyItertoolsGrouper {} - impl PyIter for PyItertoolsGrouper { + impl SlotIterator for PyItertoolsGrouper { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let old_key = { let mut state = zelf.groupby.state.lock(); @@ -676,7 +665,7 @@ mod decl { #[pyclass(name = "islice")] #[derive(Debug, PyValue)] struct PyItertoolsIslice { - iterable: PyObjectRef, + iterable: PyIter, cur: AtomicCell, next: AtomicCell, stop: Option, @@ -707,7 +696,7 @@ mod decl { ))); } - #[pyimpl(with(PyIter))] + #[pyimpl(with(SlotIterator))] impl PyItertoolsIslice { #[pyslot] fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { @@ -758,7 +747,7 @@ mod decl { None }; - let iter = get_iter(vm, iter)?; + let iter = iter.get_iter(vm)?; PyItertoolsIslice { iterable: iter, @@ -772,10 +761,10 @@ mod decl { } impl IteratorIterable for PyItertoolsIslice {} - impl PyIter for PyItertoolsIslice { + impl SlotIterator for PyItertoolsIslice { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { while zelf.cur.load() < zelf.next.load() { - call_next(vm, &zelf.iterable)?; + zelf.iterable.next(vm)?; zelf.cur.fetch_add(1); } @@ -785,7 +774,7 @@ mod decl { } } - let obj = call_next(vm, &zelf.iterable)?; + let obj = zelf.iterable.next(vm)?; zelf.cur.fetch_add(1); // TODO is this overflow check required? attempts to copy CPython. @@ -801,7 +790,7 @@ mod decl { #[derive(Debug, PyValue)] struct PyItertoolsFilterFalse { predicate: PyObjectRef, - iterable: PyObjectRef, + iterable: PyIter, } #[derive(FromArgs)] @@ -809,7 +798,7 @@ mod decl { #[pyarg(positional)] predicate: PyObjectRef, #[pyarg(positional)] - iterable: PyObjectRef, + iterable: PyIter, } impl SlotConstructor for PyItertoolsFilterFalse { @@ -823,26 +812,24 @@ mod decl { }: Self::Args, vm: &VirtualMachine, ) -> PyResult { - let iter = get_iter(vm, iterable)?; - PyItertoolsFilterFalse { predicate, - iterable: iter, + iterable, } .into_pyresult_with_type(vm, cls) } } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyItertoolsFilterFalse {} impl IteratorIterable for PyItertoolsFilterFalse {} - impl PyIter for PyItertoolsFilterFalse { + impl SlotIterator for PyItertoolsFilterFalse { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let predicate = &zelf.predicate; let iterable = &zelf.iterable; loop { - let obj = call_next(vm, iterable)?; + let obj = iterable.next(vm)?; let pred_value = if vm.is_none(predicate) { obj.clone() } else { @@ -860,7 +847,7 @@ mod decl { #[pyclass(name = "accumulate")] #[derive(Debug, PyValue)] struct PyItertoolsAccumulate { - iterable: PyObjectRef, + iterable: PyIter, binop: Option, initial: Option, acc_value: PyRwLock>, @@ -868,7 +855,7 @@ mod decl { #[derive(FromArgs)] struct AccumulateArgs { - iterable: PyObjectRef, + iterable: PyIter, #[pyarg(any, optional)] func: OptionalOption, #[pyarg(named, optional)] @@ -879,10 +866,8 @@ mod decl { type Args = AccumulateArgs; fn py_new(cls: PyTypeRef, args: AccumulateArgs, vm: &VirtualMachine) -> PyResult { - let iter = get_iter(vm, args.iterable)?; - PyItertoolsAccumulate { - iterable: iter, + iterable: args.iterable, binop: args.func.flatten(), initial: args.initial.flatten(), acc_value: PyRwLock::new(None), @@ -891,11 +876,11 @@ mod decl { } } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyItertoolsAccumulate {} impl IteratorIterable for PyItertoolsAccumulate {} - impl PyIter for PyItertoolsAccumulate { + impl SlotIterator for PyItertoolsAccumulate { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let iterable = &zelf.iterable; @@ -903,11 +888,11 @@ mod decl { let next_acc_value = match acc_value { None => match &zelf.initial { - None => call_next(vm, iterable)?, + None => iterable.next(vm)?, Some(obj) => obj.clone(), }, Some(value) => { - let obj = call_next(vm, iterable)?; + let obj = iterable.next(vm)?; match &zelf.binop { None => vm._add(&value, &obj)?, Some(op) => vm.invoke(op, vec![value, obj])?, @@ -922,21 +907,21 @@ mod decl { #[derive(Debug)] struct PyItertoolsTeeData { - iterable: PyObjectRef, + iterable: PyIter, values: PyRwLock>, } impl PyItertoolsTeeData { - fn new(iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult> { + fn new(iterable: PyIter, _vm: &VirtualMachine) -> PyResult> { Ok(PyRc::new(PyItertoolsTeeData { - iterable: get_iter(vm, iterable)?, + iterable, values: PyRwLock::new(vec![]), })) } fn get_item(&self, vm: &VirtualMachine, index: usize) -> PyResult { if self.values.read().len() == index { - let result = call_next(vm, &self.iterable)?; + let result = self.iterable.next(vm)?; self.values.write().push(result); } Ok(self.values.read()[index].clone()) @@ -954,7 +939,7 @@ mod decl { #[derive(FromArgs)] struct TeeNewArgs { #[pyarg(positional)] - iterable: PyObjectRef, + iterable: PyIter, #[pyarg(positional, optional)] n: OptionalArg, } @@ -987,16 +972,15 @@ mod decl { } } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyItertoolsTee { - fn from_iter(iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn from_iter(iterator: PyIter, vm: &VirtualMachine) -> PyResult { let class = PyItertoolsTee::class(vm); - let it = get_iter(vm, iterable)?; - if it.class().is(PyItertoolsTee::class(vm)) { - return vm.call_method(&it, "__copy__", ()); + if iterator.class().is(PyItertoolsTee::class(vm)) { + return vm.call_method(&iterator, "__copy__", ()); } Ok(PyItertoolsTee { - tee_data: PyItertoolsTeeData::new(it, vm)?, + tee_data: PyItertoolsTeeData::new(iterator, vm)?, index: AtomicCell::new(0), } .into_ref_with_type(vm, class.clone())? @@ -1014,7 +998,7 @@ mod decl { } } impl IteratorIterable for PyItertoolsTee {} - impl PyIter for PyItertoolsTee { + impl SlotIterator for PyItertoolsTee { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let value = zelf.tee_data.get_item(vm, zelf.index.load())?; zelf.index.fetch_add(1); @@ -1064,7 +1048,7 @@ mod decl { } } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyItertoolsProduct { fn update_idxs(&self, mut idxs: PyRwLockWriteGuard<'_, Vec>) { if idxs.len() == 0 { @@ -1090,7 +1074,7 @@ mod decl { } } impl IteratorIterable for PyItertoolsProduct {} - impl PyIter for PyItertoolsProduct { + impl SlotIterator for PyItertoolsProduct { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { // stop signal if zelf.stop.load() { @@ -1166,10 +1150,10 @@ mod decl { } } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyItertoolsCombinations {} impl IteratorIterable for PyItertoolsCombinations {} - impl PyIter for PyItertoolsCombinations { + impl SlotIterator for PyItertoolsCombinations { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { // stop signal if zelf.exhausted.load() { @@ -1256,11 +1240,11 @@ mod decl { } } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyItertoolsCombinationsWithReplacement {} impl IteratorIterable for PyItertoolsCombinationsWithReplacement {} - impl PyIter for PyItertoolsCombinationsWithReplacement { + impl SlotIterator for PyItertoolsCombinationsWithReplacement { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { // stop signal if zelf.exhausted.load() { @@ -1365,10 +1349,10 @@ mod decl { } } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyItertoolsPermutations {} impl IteratorIterable for PyItertoolsPermutations {} - impl PyIter for PyItertoolsPermutations { + impl SlotIterator for PyItertoolsPermutations { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { // stop signal if zelf.exhausted.load() { @@ -1442,15 +1426,11 @@ mod decl { } impl SlotConstructor for PyItertoolsZipLongest { - type Args = (PosArgs, ZipLongestArgs); + type Args = (PosArgs, ZipLongestArgs); - fn py_new(cls: PyTypeRef, (iterables, args): Self::Args, vm: &VirtualMachine) -> PyResult { + fn py_new(cls: PyTypeRef, (iterators, args): Self::Args, vm: &VirtualMachine) -> PyResult { let fillvalue = args.fillvalue.unwrap_or_none(vm); - let iterators = iterables - .into_iter() - .map(|iterable| get_iter(vm, iterable)) - .collect::, _>>()?; - + let iterators = iterators.into_vec(); PyItertoolsZipLongest { iterators, fillvalue, @@ -1463,14 +1443,14 @@ mod decl { #[pyclass(name = "zip_longest")] #[derive(Debug, PyValue)] struct PyItertoolsZipLongest { - iterators: Vec, + iterators: Vec, fillvalue: PyObjectRef, } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyItertoolsZipLongest {} impl IteratorIterable for PyItertoolsZipLongest {} - impl PyIter for PyItertoolsZipLongest { + impl SlotIterator for PyItertoolsZipLongest { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { if zelf.iterators.is_empty() { Err(vm.new_stop_iteration()) @@ -1479,7 +1459,7 @@ mod decl { let mut numactive = zelf.iterators.len(); for idx in 0..zelf.iterators.len() { - let next_obj = match call_next(vm, &zelf.iterators[idx]) { + let next_obj = match zelf.iterators[idx].next(vm) { Ok(obj) => obj, Err(err) => { if !err.isinstance(&vm.ctx.exceptions.stop_iteration) { @@ -1503,16 +1483,14 @@ mod decl { #[pyclass(name = "pairwise")] #[derive(Debug, PyValue)] struct PyItertoolsPairwise { - iterator: PyObjectRef, + iterator: PyIter, old: PyRwLock>, } impl SlotConstructor for PyItertoolsPairwise { - type Args = PyObjectRef; - - fn py_new(cls: PyTypeRef, iterable: Self::Args, vm: &VirtualMachine) -> PyResult { - let iterator = get_iter(vm, iterable)?; + type Args = PyIter; + fn py_new(cls: PyTypeRef, iterator: Self::Args, vm: &VirtualMachine) -> PyResult { PyItertoolsPairwise { iterator, old: PyRwLock::new(None), @@ -1521,16 +1499,16 @@ mod decl { } } - #[pyimpl(with(PyIter, SlotConstructor))] + #[pyimpl(with(SlotIterator, SlotConstructor))] impl PyItertoolsPairwise {} impl IteratorIterable for PyItertoolsPairwise {} - impl PyIter for PyItertoolsPairwise { + impl SlotIterator for PyItertoolsPairwise { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let old = match zelf.old.read().clone() { - None => call_next(vm, &zelf.iterator)?, + None => zelf.iterator.next(vm)?, Some(obj) => obj, }; - let new = call_next(vm, &zelf.iterator)?; + let new = zelf.iterator.next(vm)?; *zelf.old.write() = Some(new.clone()); Ok(vm.ctx.new_tuple(vec![old, new])) } diff --git a/vm/src/stdlib/operator.rs b/vm/src/stdlib/operator.rs index e2214224b9..d71359b8ae 100644 --- a/vm/src/stdlib/operator.rs +++ b/vm/src/stdlib/operator.rs @@ -14,6 +14,7 @@ mod _operator { builtins::{PyInt, PyIntRef, PyStrRef, PyTypeRef}, function::{ArgBytesLike, FuncArgs, KwArgs, OptionalArg}, iterator, + protocol::PyIter, slots::{ Callable, PyComparisonOp::{Eq, Ge, Gt, Le, Lt, Ne}, @@ -217,10 +218,9 @@ mod _operator { /// Return the number of occurrences of b in a. #[pyfunction(name = "countOf")] - fn count_of(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn count_of(a: PyIter, b: PyObjectRef, vm: &VirtualMachine) -> PyResult { let mut count: usize = 0; - let iter = iterator::get_iter(vm, a)?; - while let Some(element) = iterator::get_next_object(vm, &iter)? { + while let Some(element) = iterator::get_next_object(vm, &a)? { if element.is(&b) || vm.bool_eq(&b, &element)? { count += 1; } @@ -242,10 +242,9 @@ mod _operator { /// Return the number of occurrences of b in a. #[pyfunction(name = "indexOf")] - fn index_of(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn index_of(a: PyIter, b: PyObjectRef, vm: &VirtualMachine) -> PyResult { let mut index: usize = 0; - let iter = iterator::get_iter(vm, a)?; - while let Some(element) = iterator::get_next_object(vm, &iter)? { + while let Some(element) = iterator::get_next_object(vm, &a)? { if element.is(&b) || vm.bool_eq(&b, &element)? { return Ok(index); } diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index 3210318960..3a652678d6 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -406,7 +406,7 @@ pub(super) mod _os { crt_fd::{Fd, Offset}, exceptions::IntoPyException, function::{ArgBytesLike, FuncArgs, OptionalArg}, - slots::{IteratorIterable, PyIter}, + slots::{IteratorIterable, SlotIterator}, suppress_iph, utils::Either, vm::{ReprGuard, VirtualMachine}, @@ -855,7 +855,7 @@ pub(super) mod _os { mode: OutputMode, } - #[pyimpl(with(PyIter))] + #[pyimpl(with(SlotIterator))] impl ScandirIterator { #[pymethod] fn close(&self) { @@ -873,7 +873,7 @@ pub(super) mod _os { } } impl IteratorIterable for ScandirIterator {} - impl PyIter for ScandirIterator { + impl SlotIterator for ScandirIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { if zelf.exhausted.load() { return Err(vm.new_stop_iteration()); diff --git a/vm/src/stdlib/pystruct.rs b/vm/src/stdlib/pystruct.rs index 501391ea38..a85a0f3f7d 100644 --- a/vm/src/stdlib/pystruct.rs +++ b/vm/src/stdlib/pystruct.rs @@ -18,7 +18,7 @@ pub(crate) mod _struct { }, common::str::wchar_t, function::{ArgBytesLike, ArgMemoryBuffer, PosArgs}, - slots::{IteratorIterable, PyIter, SlotConstructor}, + slots::{IteratorIterable, SlotConstructor, SlotIterator}, utils::Either, IntoPyObject, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, VirtualMachine, }; @@ -823,7 +823,7 @@ pub(crate) mod _struct { } } - #[pyimpl(with(PyIter))] + #[pyimpl(with(SlotIterator))] impl UnpackIterator { #[pymethod(magic)] fn length_hint(&self) -> usize { @@ -831,7 +831,7 @@ pub(crate) mod _struct { } } impl IteratorIterable for UnpackIterator {} - impl PyIter for UnpackIterator { + impl SlotIterator for UnpackIterator { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { let size = zelf.format_spec.size; let offset = zelf.offset.fetch_add(size); diff --git a/vm/src/vm.rs b/vm/src/vm.rs index d9c473c1f2..45361bfb99 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -1270,7 +1270,7 @@ impl VirtualMachine { .map(|obj| func(obj.clone())) .collect() } else { - let iter = iterator::get_iter(self, value.clone())?; + let iter = value.clone().get_iter(self)?; let cap = match iterator::length_hint(self, value.clone()) { Err(e) if e.class().is(&self.ctx.exceptions.runtime_error) => return Err(e), Ok(Some(value)) => value, @@ -1318,7 +1318,7 @@ impl VirtualMachine { ref t @ PyTuple => Ok(t.as_slice().iter().cloned().map(f).collect()), // TODO: put internal iterable type obj => { - let iter = iterator::get_iter(self, obj.clone())?; + let iter = obj.clone().get_iter(self)?; let cap = match iterator::length_hint(self, obj.clone()) { Err(e) if e.class().is(&self.ctx.exceptions.runtime_error) => { return Ok(Err(e)) @@ -1926,7 +1926,7 @@ impl VirtualMachine { // https://docs.python.org/3/reference/expressions.html#membership-test-operations fn _membership_iter_search(&self, haystack: PyObjectRef, needle: PyObjectRef) -> PyResult { - let iter = iterator::get_iter(self, haystack)?; + let iter = haystack.get_iter(self)?; loop { if let Some(element) = iterator::get_next_object(self, &iter)? { if self.bool_eq(&needle, &element)? { diff --git a/wasm/lib/src/js_module.rs b/wasm/lib/src/js_module.rs index a99efdd63f..b880422412 100644 --- a/wasm/lib/src/js_module.rs +++ b/wasm/lib/src/js_module.rs @@ -9,7 +9,7 @@ use wasm_bindgen_futures::{future_to_promise, JsFuture}; use rustpython_vm::builtins::PyBaseExceptionRef; use rustpython_vm::builtins::{PyFloatRef, PyStrRef, PyTypeRef}; use rustpython_vm::function::{OptionalArg, OptionalOption, PosArgs}; -use rustpython_vm::slots::{IteratorIterable, PyIter}; +use rustpython_vm::slots::{IteratorIterable, SlotIterator}; use rustpython_vm::types::create_simple_type; use rustpython_vm::VirtualMachine; use rustpython_vm::{ @@ -549,7 +549,7 @@ impl fmt::Debug for AwaitPromise { } } -#[pyimpl(with(PyIter))] +#[pyimpl(with(SlotIterator))] impl AwaitPromise { #[pymethod] fn send(&self, val: Option, vm: &VirtualMachine) -> PyResult { @@ -588,7 +588,7 @@ impl AwaitPromise { } impl IteratorIterable for AwaitPromise {} -impl PyIter for AwaitPromise { +impl SlotIterator for AwaitPromise { fn next(zelf: &PyRef, vm: &VirtualMachine) -> PyResult { zelf.send(None, vm) }