Skip to content

Add iter(callable, sentinel) #1842

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions Lib/test/test_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,6 @@ def __iter__(self):
self.assertRaises(TypeError, iter, IterClass())

# Test two-argument iter() with callable instance
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_iter_callable(self):
class C:
def __init__(self):
Expand All @@ -250,8 +248,6 @@ def __call__(self):
self.check_iterator(iter(C(), 10), list(range(10)), pickle=False)

# Test two-argument iter() with function
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_iter_function(self):
def spam(state=[0]):
i = state[0]
Expand All @@ -260,8 +256,6 @@ def spam(state=[0]):
self.check_iterator(iter(spam, 10), list(range(10)), pickle=False)

# Test two-argument iter() with function that raises StopIteration
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_iter_function_stop(self):
def spam(state=[0]):
i = state[0]
Expand All @@ -272,8 +266,6 @@ def spam(state=[0]):
self.check_iterator(iter(spam, 20), list(range(10)), pickle=False)

# Test exception propagation through function iterator
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_exception_function(self):
def spam(state=[0]):
i = state[0]
Expand Down Expand Up @@ -962,8 +954,6 @@ def test_sinkstate_sequence(self):
a.n = 10
self.assertEqual(list(b), [])

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_sinkstate_callable(self):
# This used to fail
def spam(state=[0]):
Expand Down
4 changes: 4 additions & 0 deletions tests/snippets/iterable.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,7 @@ def __getitem__(self, x):
class C: pass
assert_raises(TypeError, lambda: 5 in C())
assert_raises(TypeError, iter, C)

it = iter([1,2,3,4,5])
call_it = iter(lambda: next(it), 4)
assert list(call_it) == [1,2,3]
19 changes: 15 additions & 4 deletions vm/src/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ use crate::obj::objstr::{PyString, PyStringRef};
use crate::obj::objtype::{self, PyClassRef};
use crate::pyhash;
use crate::pyobject::{
Either, IdProtocol, ItemProtocol, PyIterable, PyObjectRef, PyResult, PyValue, TryFromObject,
TypeProtocol,
Either, IdProtocol, ItemProtocol, PyCallable, PyIterable, PyObjectRef, PyResult, PyValue,
TryFromObject, TypeProtocol,
};
use crate::readline::{Readline, ReadlineResult};
use crate::scope::Scope;
Expand Down Expand Up @@ -404,8 +404,19 @@ fn builtin_issubclass(
)
}

fn builtin_iter(iter_target: PyObjectRef, vm: &VirtualMachine) -> PyResult {
objiter::get_iter(vm, &iter_target)
fn builtin_iter(
iter_target: PyObjectRef,
sentinel: OptionalArg<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult {
if let OptionalArg::Present(sentinel) = sentinel {
let callable = PyCallable::try_from_object(vm, iter_target)?;
Ok(objiter::PyCallableIterator::new(callable, sentinel)
.into_ref(vm)
.into_object())
} else {
objiter::get_iter(vm, &iter_target)
}
}

fn builtin_len(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
Expand Down
50 changes: 49 additions & 1 deletion vm/src/obj/objiter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ use super::objsequence;
use super::objtype::{self, PyClassRef};
use crate::exceptions::PyBaseExceptionRef;
use crate::pyobject::{
PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol,
PyCallable, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject,
TypeProtocol,
};
use crate::vm::VirtualMachine;

Expand Down Expand Up @@ -201,6 +202,53 @@ pub fn seq_iter_method(obj: PyObjectRef) -> PySequenceIterator {
}
}

#[pyclass(name = "callable_iterator")]
#[derive(Debug)]
pub struct PyCallableIterator {
callable: PyCallable,
sentinel: PyObjectRef,
done: Cell<bool>,
}

impl PyValue for PyCallableIterator {
fn class(vm: &VirtualMachine) -> PyClassRef {
vm.ctx.types.callable_iterator.clone()
}
}

#[pyimpl]
impl PyCallableIterator {
pub fn new(callable: PyCallable, sentinel: PyObjectRef) -> Self {
Self {
callable,
sentinel,
done: Cell::new(false),
}
}

#[pymethod(magic)]
fn next(&self, vm: &VirtualMachine) -> PyResult {
if self.done.get() {
return Err(new_stop_iteration(vm));
}

let ret = self.callable.invoke(vec![], vm)?;

if vm.bool_eq(ret.clone(), self.sentinel.clone())? {
self.done.set(true);
Err(new_stop_iteration(vm))
} else {
Ok(ret)
}
}

#[pymethod(magic)]
fn iter(zelf: PyRef<Self>) -> PyRef<Self> {
zelf
}
}

pub fn init(context: &PyContext) {
PySequenceIterator::extend_class(context, &context.types.iter_type);
PyCallableIterator::extend_class(context, &context.types.callable_iterator);
}
3 changes: 3 additions & 0 deletions vm/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub struct TypeZoo {
pub bytearray_type: PyClassRef,
pub bytearrayiterator_type: PyClassRef,
pub bool_type: PyClassRef,
pub callable_iterator: PyClassRef,
pub classmethod_type: PyClassRef,
pub code_type: PyClassRef,
pub coroutine_type: PyClassRef,
Expand Down Expand Up @@ -184,6 +185,7 @@ impl TypeZoo {
let slice_type = create_type("slice", &type_type, &object_type);
let mappingproxy_type = create_type("mappingproxy", &type_type, &object_type);
let traceback_type = create_type("traceback", &type_type, &object_type);
let callable_iterator = create_type("callable_iterator", &type_type, &object_type);

Self {
async_generator,
Expand All @@ -196,6 +198,7 @@ impl TypeZoo {
bytearrayiterator_type,
bytes_type,
bytesiterator_type,
callable_iterator,
code_type,
coroutine_type,
coroutine_wrapper_type,
Expand Down