diff --git a/Lib/test/test_iter.py b/Lib/test/test_iter.py index 7bdcc27b25..f5789e09f4 100644 --- a/Lib/test/test_iter.py +++ b/Lib/test/test_iter.py @@ -235,8 +235,6 @@ def __iter__(self): self.assertRaises(TypeError, iter, IterClass()) # Test two-argument iter() with callable instance - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_iter_callable(self): class C: def __init__(self): @@ -250,8 +248,6 @@ def __call__(self): self.check_iterator(iter(C(), 10), list(range(10)), pickle=False) # Test two-argument iter() with function - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_iter_function(self): def spam(state=[0]): i = state[0] @@ -260,8 +256,6 @@ def spam(state=[0]): self.check_iterator(iter(spam, 10), list(range(10)), pickle=False) # Test two-argument iter() with function that raises StopIteration - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_iter_function_stop(self): def spam(state=[0]): i = state[0] @@ -272,8 +266,6 @@ def spam(state=[0]): self.check_iterator(iter(spam, 20), list(range(10)), pickle=False) # Test exception propagation through function iterator - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_exception_function(self): def spam(state=[0]): i = state[0] @@ -962,8 +954,6 @@ def test_sinkstate_sequence(self): a.n = 10 self.assertEqual(list(b), []) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_sinkstate_callable(self): # This used to fail def spam(state=[0]): diff --git a/tests/snippets/iterable.py b/tests/snippets/iterable.py index bd3e1dcd7d..7158296c38 100644 --- a/tests/snippets/iterable.py +++ b/tests/snippets/iterable.py @@ -30,3 +30,7 @@ def __getitem__(self, x): class C: pass assert_raises(TypeError, lambda: 5 in C()) assert_raises(TypeError, iter, C) + +it = iter([1,2,3,4,5]) +call_it = iter(lambda: next(it), 4) +assert list(call_it) == [1,2,3] diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index afb0baaf43..f81804d1a5 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -29,8 +29,8 @@ use crate::obj::objstr::{PyString, PyStringRef}; use crate::obj::objtype::{self, PyClassRef}; use crate::pyhash; use crate::pyobject::{ - Either, IdProtocol, ItemProtocol, PyIterable, PyObjectRef, PyResult, PyValue, TryFromObject, - TypeProtocol, + Either, IdProtocol, ItemProtocol, PyCallable, PyIterable, PyObjectRef, PyResult, PyValue, + TryFromObject, TypeProtocol, }; use crate::readline::{Readline, ReadlineResult}; use crate::scope::Scope; @@ -404,8 +404,19 @@ fn builtin_issubclass( ) } -fn builtin_iter(iter_target: PyObjectRef, vm: &VirtualMachine) -> PyResult { - objiter::get_iter(vm, &iter_target) +fn builtin_iter( + iter_target: PyObjectRef, + sentinel: OptionalArg, + vm: &VirtualMachine, +) -> PyResult { + if let OptionalArg::Present(sentinel) = sentinel { + let callable = PyCallable::try_from_object(vm, iter_target)?; + Ok(objiter::PyCallableIterator::new(callable, sentinel) + .into_ref(vm) + .into_object()) + } else { + objiter::get_iter(vm, &iter_target) + } } fn builtin_len(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { diff --git a/vm/src/obj/objiter.rs b/vm/src/obj/objiter.rs index 5c1e06a6bf..bfaa10e779 100644 --- a/vm/src/obj/objiter.rs +++ b/vm/src/obj/objiter.rs @@ -10,7 +10,8 @@ use super::objsequence; use super::objtype::{self, PyClassRef}; use crate::exceptions::PyBaseExceptionRef; use crate::pyobject::{ - PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, + PyCallable, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, + TypeProtocol, }; use crate::vm::VirtualMachine; @@ -201,6 +202,53 @@ pub fn seq_iter_method(obj: PyObjectRef) -> PySequenceIterator { } } +#[pyclass(name = "callable_iterator")] +#[derive(Debug)] +pub struct PyCallableIterator { + callable: PyCallable, + sentinel: PyObjectRef, + done: Cell, +} + +impl PyValue for PyCallableIterator { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.ctx.types.callable_iterator.clone() + } +} + +#[pyimpl] +impl PyCallableIterator { + pub fn new(callable: PyCallable, sentinel: PyObjectRef) -> Self { + Self { + callable, + sentinel, + done: Cell::new(false), + } + } + + #[pymethod(magic)] + fn next(&self, vm: &VirtualMachine) -> PyResult { + if self.done.get() { + return Err(new_stop_iteration(vm)); + } + + let ret = self.callable.invoke(vec![], vm)?; + + if vm.bool_eq(ret.clone(), self.sentinel.clone())? { + self.done.set(true); + Err(new_stop_iteration(vm)) + } else { + Ok(ret) + } + } + + #[pymethod(magic)] + fn iter(zelf: PyRef) -> PyRef { + zelf + } +} + pub fn init(context: &PyContext) { PySequenceIterator::extend_class(context, &context.types.iter_type); + PyCallableIterator::extend_class(context, &context.types.callable_iterator); } diff --git a/vm/src/types.rs b/vm/src/types.rs index f2cc4a1845..b66849bf4d 100644 --- a/vm/src/types.rs +++ b/vm/src/types.rs @@ -57,6 +57,7 @@ pub struct TypeZoo { pub bytearray_type: PyClassRef, pub bytearrayiterator_type: PyClassRef, pub bool_type: PyClassRef, + pub callable_iterator: PyClassRef, pub classmethod_type: PyClassRef, pub code_type: PyClassRef, pub coroutine_type: PyClassRef, @@ -184,6 +185,7 @@ impl TypeZoo { let slice_type = create_type("slice", &type_type, &object_type); let mappingproxy_type = create_type("mappingproxy", &type_type, &object_type); let traceback_type = create_type("traceback", &type_type, &object_type); + let callable_iterator = create_type("callable_iterator", &type_type, &object_type); Self { async_generator, @@ -196,6 +198,7 @@ impl TypeZoo { bytearrayiterator_type, bytes_type, bytesiterator_type, + callable_iterator, code_type, coroutine_type, coroutine_wrapper_type,