Skip to content

Commit 408032a

Browse files
authored
Merge pull request RustPython#1842 from RustPython/coolreader18/callable-iter
Add iter(callable, sentinel)
2 parents 3df6c04 + e83c23b commit 408032a

File tree

5 files changed

+71
-15
lines changed

5 files changed

+71
-15
lines changed

Lib/test/test_iter.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,6 @@ def __iter__(self):
235235
self.assertRaises(TypeError, iter, IterClass())
236236

237237
# Test two-argument iter() with callable instance
238-
# TODO: RUSTPYTHON
239-
@unittest.expectedFailure
240238
def test_iter_callable(self):
241239
class C:
242240
def __init__(self):
@@ -250,8 +248,6 @@ def __call__(self):
250248
self.check_iterator(iter(C(), 10), list(range(10)), pickle=False)
251249

252250
# Test two-argument iter() with function
253-
# TODO: RUSTPYTHON
254-
@unittest.expectedFailure
255251
def test_iter_function(self):
256252
def spam(state=[0]):
257253
i = state[0]
@@ -260,8 +256,6 @@ def spam(state=[0]):
260256
self.check_iterator(iter(spam, 10), list(range(10)), pickle=False)
261257

262258
# Test two-argument iter() with function that raises StopIteration
263-
# TODO: RUSTPYTHON
264-
@unittest.expectedFailure
265259
def test_iter_function_stop(self):
266260
def spam(state=[0]):
267261
i = state[0]
@@ -272,8 +266,6 @@ def spam(state=[0]):
272266
self.check_iterator(iter(spam, 20), list(range(10)), pickle=False)
273267

274268
# Test exception propagation through function iterator
275-
# TODO: RUSTPYTHON
276-
@unittest.expectedFailure
277269
def test_exception_function(self):
278270
def spam(state=[0]):
279271
i = state[0]
@@ -962,8 +954,6 @@ def test_sinkstate_sequence(self):
962954
a.n = 10
963955
self.assertEqual(list(b), [])
964956

965-
# TODO: RUSTPYTHON
966-
@unittest.expectedFailure
967957
def test_sinkstate_callable(self):
968958
# This used to fail
969959
def spam(state=[0]):

tests/snippets/iterable.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,7 @@ def __getitem__(self, x):
3030
class C: pass
3131
assert_raises(TypeError, lambda: 5 in C())
3232
assert_raises(TypeError, iter, C)
33+
34+
it = iter([1,2,3,4,5])
35+
call_it = iter(lambda: next(it), 4)
36+
assert list(call_it) == [1,2,3]

vm/src/builtins.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ use crate::obj::objstr::{PyString, PyStringRef};
2929
use crate::obj::objtype::{self, PyClassRef};
3030
use crate::pyhash;
3131
use crate::pyobject::{
32-
Either, IdProtocol, ItemProtocol, PyIterable, PyObjectRef, PyResult, PyValue, TryFromObject,
33-
TypeProtocol,
32+
Either, IdProtocol, ItemProtocol, PyCallable, PyIterable, PyObjectRef, PyResult, PyValue,
33+
TryFromObject, TypeProtocol,
3434
};
3535
use crate::readline::{Readline, ReadlineResult};
3636
use crate::scope::Scope;
@@ -404,8 +404,19 @@ fn builtin_issubclass(
404404
)
405405
}
406406

407-
fn builtin_iter(iter_target: PyObjectRef, vm: &VirtualMachine) -> PyResult {
408-
objiter::get_iter(vm, &iter_target)
407+
fn builtin_iter(
408+
iter_target: PyObjectRef,
409+
sentinel: OptionalArg<PyObjectRef>,
410+
vm: &VirtualMachine,
411+
) -> PyResult {
412+
if let OptionalArg::Present(sentinel) = sentinel {
413+
let callable = PyCallable::try_from_object(vm, iter_target)?;
414+
Ok(objiter::PyCallableIterator::new(callable, sentinel)
415+
.into_ref(vm)
416+
.into_object())
417+
} else {
418+
objiter::get_iter(vm, &iter_target)
419+
}
409420
}
410421

411422
fn builtin_len(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {

vm/src/obj/objiter.rs

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ use super::objsequence;
1010
use super::objtype::{self, PyClassRef};
1111
use crate::exceptions::PyBaseExceptionRef;
1212
use crate::pyobject::{
13-
PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol,
13+
PyCallable, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject,
14+
TypeProtocol,
1415
};
1516
use crate::vm::VirtualMachine;
1617

@@ -201,6 +202,53 @@ pub fn seq_iter_method(obj: PyObjectRef) -> PySequenceIterator {
201202
}
202203
}
203204

205+
#[pyclass(name = "callable_iterator")]
206+
#[derive(Debug)]
207+
pub struct PyCallableIterator {
208+
callable: PyCallable,
209+
sentinel: PyObjectRef,
210+
done: Cell<bool>,
211+
}
212+
213+
impl PyValue for PyCallableIterator {
214+
fn class(vm: &VirtualMachine) -> PyClassRef {
215+
vm.ctx.types.callable_iterator.clone()
216+
}
217+
}
218+
219+
#[pyimpl]
220+
impl PyCallableIterator {
221+
pub fn new(callable: PyCallable, sentinel: PyObjectRef) -> Self {
222+
Self {
223+
callable,
224+
sentinel,
225+
done: Cell::new(false),
226+
}
227+
}
228+
229+
#[pymethod(magic)]
230+
fn next(&self, vm: &VirtualMachine) -> PyResult {
231+
if self.done.get() {
232+
return Err(new_stop_iteration(vm));
233+
}
234+
235+
let ret = self.callable.invoke(vec![], vm)?;
236+
237+
if vm.bool_eq(ret.clone(), self.sentinel.clone())? {
238+
self.done.set(true);
239+
Err(new_stop_iteration(vm))
240+
} else {
241+
Ok(ret)
242+
}
243+
}
244+
245+
#[pymethod(magic)]
246+
fn iter(zelf: PyRef<Self>) -> PyRef<Self> {
247+
zelf
248+
}
249+
}
250+
204251
pub fn init(context: &PyContext) {
205252
PySequenceIterator::extend_class(context, &context.types.iter_type);
253+
PyCallableIterator::extend_class(context, &context.types.callable_iterator);
206254
}

vm/src/types.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ pub struct TypeZoo {
5757
pub bytearray_type: PyClassRef,
5858
pub bytearrayiterator_type: PyClassRef,
5959
pub bool_type: PyClassRef,
60+
pub callable_iterator: PyClassRef,
6061
pub classmethod_type: PyClassRef,
6162
pub code_type: PyClassRef,
6263
pub coroutine_type: PyClassRef,
@@ -184,6 +185,7 @@ impl TypeZoo {
184185
let slice_type = create_type("slice", &type_type, &object_type);
185186
let mappingproxy_type = create_type("mappingproxy", &type_type, &object_type);
186187
let traceback_type = create_type("traceback", &type_type, &object_type);
188+
let callable_iterator = create_type("callable_iterator", &type_type, &object_type);
187189

188190
Self {
189191
async_generator,
@@ -196,6 +198,7 @@ impl TypeZoo {
196198
bytearrayiterator_type,
197199
bytes_type,
198200
bytesiterator_type,
201+
callable_iterator,
199202
code_type,
200203
coroutine_type,
201204
coroutine_wrapper_type,

0 commit comments

Comments
 (0)