From 7be477a0919a6600e2304036c55cbfcdea32214f Mon Sep 17 00:00:00 2001 From: Daniel Alley Date: Tue, 24 Dec 2019 20:13:50 -0500 Subject: [PATCH 1/2] Add itertools.cycle() --- tests/snippets/stdlib_itertools.py | 26 +++++++++++ vm/src/stdlib/itertools.rs | 73 +++++++++++++++++++++++++++++- 2 files changed, 98 insertions(+), 1 deletion(-) diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index e876f23125..c525e5f07b 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -76,6 +76,32 @@ # assert next(c) == 1.5 +# itertools.cycle tests + +r = itertools.cycle([1, 2, 3]) +assert next(r) == 1 +assert next(r) == 2 +assert next(r) == 3 +assert next(r) == 1 +assert next(r) == 2 +assert next(r) == 3 +assert next(r) == 1 + +r = itertools.cycle([1]) +assert next(r) == 1 +assert next(r) == 1 +assert next(r) == 1 + +r = itertools.cycle([]) +with assert_raises(StopIteration): + next(r) + +with assert_raises(TypeError): + itertools.cycle(None) + +with assert_raises(TypeError): + itertools.cycle(10) + # itertools.repeat tests # no times diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 394cb299fa..0013c99131 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -11,7 +11,7 @@ use num_traits::ToPrimitive; use crate::function::{Args, OptionalArg, OptionalOption, PyFuncArgs}; use crate::obj::objbool; use crate::obj::objint::{self, PyInt, PyIntRef}; -use crate::obj::objiter::{call_next, get_all, get_iter, new_stop_iteration}; +use crate::obj::objiter::{call_next, get_all, get_iter, get_next_object, new_stop_iteration}; use crate::obj::objtuple::PyTuple; use crate::obj::objtype::{self, PyClassRef}; use crate::pyobject::{ @@ -177,6 +177,73 @@ impl PyItertoolsCount { } } +#[pyclass] +#[derive(Debug)] +struct PyItertoolsCycle { + iter: RefCell, + saved: RefCell>, + index: Cell, + first_pass: Cell, +} + +impl PyValue for PyItertoolsCycle { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "cycle") + } +} + +#[pyimpl] +impl PyItertoolsCycle { + #[pyslot(new)] + fn tp_new( + cls: PyClassRef, + iterable: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult> { + let iter = get_iter(vm, &iterable)?; + + PyItertoolsCycle { + iter: RefCell::new(iter.clone()), + saved: RefCell::new(Vec::new()), + index: Cell::new(0), + first_pass: Cell::new(false), + } + .into_ref_with_type(vm, cls) + } + + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + let item = if let Some(item) = get_next_object(vm, &self.iter.borrow())? { + if self.first_pass.get() { + return Ok(item); + } + + self.saved.borrow_mut().push(item.clone()); + item + } else { + if self.saved.borrow().len() == 0 { + return Err(new_stop_iteration(vm)); + } + + let last_index = self.index.get(); + self.index.set(self.index.get() + 1); + + if self.index.get() >= self.saved.borrow().len() { + self.index.set(0); + } + + self.saved.borrow()[last_index].clone() + }; + + Ok(item) + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { + zelf + } +} + #[pyclass] #[derive(Debug)] struct PyItertoolsRepeat { @@ -1177,6 +1244,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let count = ctx.new_class("count", ctx.object()); PyItertoolsCount::extend_class(ctx, &count); + let cycle = ctx.new_class("cycle", ctx.object()); + PyItertoolsCycle::extend_class(ctx, &cycle); + let dropwhile = ctx.new_class("dropwhile", ctx.object()); PyItertoolsDropwhile::extend_class(ctx, &dropwhile); @@ -1211,6 +1281,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "compress" => compress, "combinations" => combinations, "count" => count, + "cycle" => cycle, "dropwhile" => dropwhile, "islice" => islice, "filterfalse" => filterfalse, From bf82caed4b51834c02af21d861721320ffb242ea Mon Sep 17 00:00:00 2001 From: Daniel Alley Date: Tue, 24 Dec 2019 22:53:11 -0500 Subject: [PATCH 2/2] Implement itertools.chain.from_iterable() --- tests/snippets/stdlib_itertools.py | 32 +++++++++++++++++++++++++++++- vm/src/stdlib/itertools.rs | 16 +++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index c525e5f07b..b8bc0ae3f6 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -21,6 +21,36 @@ with assert_raises(TypeError): next(x) +# empty +with assert_raises(TypeError): + chain.from_iterable() + +with assert_raises(TypeError): + chain.from_iterable("abc", "def") + +with assert_raises(TypeError): + # iterables are lazily evaluated -- can be constructed but will fail to execute + list(chain.from_iterable([1, 2, 3])) + +with assert_raises(TypeError): + list(chain(1)) + +args = ["abc", "def"] +assert list(chain.from_iterable(args)) == ['a', 'b', 'c', 'd', 'e', 'f'] + +args = [[], "", b"", ()] +assert list(chain.from_iterable(args)) == [] + +args = ["ab", "cd", (), 'e'] +assert list(chain.from_iterable(args)) == ['a', 'b', 'c', 'd', 'e'] + +x = chain.from_iterable(["ab", 1]) +assert next(x) == 'a' +assert next(x) == 'b' +with assert_raises(TypeError): + next(x) + + # itertools.count tests # default arguments @@ -117,7 +147,7 @@ with assert_raises(StopIteration): next(r) -# timees = 0 +# times = 0 r = itertools.repeat(1, 0) with assert_raises(StopIteration): next(r) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 0013c99131..84b690eb8c 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -73,6 +73,22 @@ impl PyItertoolsChain { fn iter(zelf: PyRef, _vm: &VirtualMachine) -> PyRef { zelf } + + #[pyclassmethod(name = "from_iterable")] + fn from_iterable( + cls: PyClassRef, + iterable: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult> { + let it = get_iter(vm, &iterable)?; + let iterables = get_all(vm, &it)?; + + PyItertoolsChain { + iterables, + cur: RefCell::new((0, None)), + } + .into_ref_with_type(vm, cls) + } } #[pyclass(name = "compress")]