diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index 5a400db714..117ccb9380 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -344,3 +344,46 @@ def assert_matches_seq(it, seq): with assert_raises(ValueError): itertools.combinations([1, 2, 3, 4], -2) + +# itertools.zip_longest tests +zl = itertools.zip_longest +assert list(zl(['a', 'b', 'c'], range(3), [9, 8, 7])) \ + == [('a', 0, 9), ('b', 1, 8), ('c', 2, 7)] +assert list(zl(['a', 'b', 'c'], range(3), [9, 8, 7, 99])) \ + == [('a', 0, 9), ('b', 1, 8), ('c', 2, 7), (None, None, 99)] +assert list(zl(['a', 'b', 'c'], range(3), [9, 8, 7, 99], fillvalue='d')) \ + == [('a', 0, 9), ('b', 1, 8), ('c', 2, 7), ('d', 'd', 99)] + +assert list(zl(['a', 'b', 'c'])) == [('a',), ('b',), ('c',)] +assert list(zl()) == [] + +assert list(zl(*zl(['a', 'b', 'c'], range(1, 4)))) \ + == [('a', 'b', 'c'), (1, 2, 3)] +assert list(zl(*zl(['a', 'b', 'c'], range(1, 5)))) \ + == [('a', 'b', 'c', None), (1, 2, 3, 4)] +assert list(zl(*zl(['a', 'b', 'c'], range(1, 5), fillvalue=100))) \ + == [('a', 'b', 'c', 100), (1, 2, 3, 4)] + + +# test infinite iterator +class Counter(object): + def __init__(self, counter=0): + self.counter = counter + + def __next__(self): + self.counter += 1 + return self.counter + + def __iter__(self): + return self + + +it = zl(Counter(), Counter(3)) +assert next(it) == (1, 4) +assert next(it) == (2, 5) + +it = zl([1,2], [3]) +assert next(it) == (1, 3) +assert next(it) == (2, None) +with assert_raises(StopIteration): + next(it) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 0d176029dc..f114d80765 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -953,6 +953,89 @@ impl PyItertoolsCombinations { } } +#[pyclass] +#[derive(Debug)] +struct PyItertoolsZiplongest { + iterators: Vec<PyObjectRef>, + fillvalue: PyObjectRef, + numactive: Cell<usize>, +} + +impl PyValue for PyItertoolsZiplongest { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "zip_longest") + } +} + +#[derive(FromArgs)] +struct ZiplongestArgs { + #[pyarg(keyword_only, optional = true)] + fillvalue: OptionalArg<PyObjectRef>, +} + +#[pyimpl] +impl PyItertoolsZiplongest { + #[pyslot(new)] + fn tp_new( + cls: PyClassRef, + iterables: Args, + args: ZiplongestArgs, + vm: &VirtualMachine, + ) -> PyResult<PyRef<Self>> { + let fillvalue = match args.fillvalue.into_option() { + Some(i) => i, + None => vm.get_none(), + }; + + let iterators = iterables + .into_iter() + .map(|iterable| get_iter(vm, &iterable)) + .collect::<Result<Vec<_>, _>>()?; + + let numactive = Cell::new(iterators.len()); + + PyItertoolsZiplongest { + iterators, + fillvalue, + numactive, + } + .into_ref_with_type(vm, cls) + } + + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + if self.iterators.is_empty() { + Err(new_stop_iteration(vm)) + } else { + let mut result: Vec<PyObjectRef> = Vec::new(); + let mut numactive = self.numactive.get(); + + for idx in 0..self.iterators.len() { + let next_obj = match call_next(vm, &self.iterators[idx]) { + Ok(obj) => obj, + Err(err) => { + if !objtype::isinstance(&err, &vm.ctx.exceptions.stop_iteration) { + return Err(err); + } + numactive -= 1; + if numactive == 0 { + return Err(new_stop_iteration(vm)); + } + self.fillvalue.clone() + } + }; + result.push(next_obj); + } + Ok(vm.ctx.new_tuple(result)) + } + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> { + zelf + } +} + pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ctx = &vm.ctx; @@ -991,6 +1074,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let tee = ctx.new_class("tee", ctx.object()); PyItertoolsTee::extend_class(ctx, &tee); + let zip_longest = ctx.new_class("zip_longest", ctx.object()); + PyItertoolsZiplongest::extend_class(ctx, &zip_longest); + py_module!(vm, "itertools", { "accumulate" => accumulate, "chain" => chain, @@ -1005,5 +1091,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { "takewhile" => takewhile, "tee" => tee, "product" => product, + "zip_longest" => zip_longest, }) }