From 239115f686ee6d7946c90b1df940b0ce7d18f243 Mon Sep 17 00:00:00 2001 From: Hyunji Kim <dev.hjkim@linecorp.com> Date: Sun, 22 Sep 2019 18:15:07 +0900 Subject: [PATCH] itertools.compress --- tests/snippets/stdlib_itertools.py | 8 ++++ vm/src/stdlib/itertools.rs | 60 +++++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index 2c588fcf25..b83c6cfb69 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -235,3 +235,11 @@ def assert_matches_seq(it, seq): assert 0 == next(it) with assert_raises(StopIteration): next(it) + +# itertools.compress +assert list(itertools.compress("ABCDEF", [1,0,1,0,1,1])) == list("ACEF") +assert list(itertools.compress("ABCDEF", [0,0,0,0,0,0])) == list("") +assert list(itertools.compress("ABCDEF", [1,1,1,1,1,1])) == list("ABCDEF") +assert list(itertools.compress("ABCDEF", [1,0,1])) == list("AC") +assert list(itertools.compress("ABC", [0,1,1,1,1,1])) == list("BC") +assert list(itertools.compress("ABCDEF", [True,False,"t","",1,9])) == list("ACEF") diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index dabbb220ca..ddac6cc263 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -73,6 +73,59 @@ impl PyItertoolsChain { } } +#[pyclass(name = "compress")] +#[derive(Debug)] +struct PyItertoolsCompress { + data: PyObjectRef, + selector: PyObjectRef, +} + +impl PyValue for PyItertoolsCompress { + fn class(vm: &VirtualMachine) -> PyClassRef { + vm.class("itertools", "compress") + } +} + +#[pyimpl] +impl PyItertoolsCompress { + #[pymethod(name = "__new__")] + #[allow(clippy::new_ret_no_self)] + fn new( + _cls: PyClassRef, + data: PyObjectRef, + selector: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult { + let data_iter = get_iter(vm, &data)?; + let selector_iter = get_iter(vm, &selector)?; + + Ok(PyItertoolsCompress { + data: data_iter, + selector: selector_iter, + } + .into_ref(vm) + .into_object()) + } + + #[pymethod(name = "__next__")] + fn next(&self, vm: &VirtualMachine) -> PyResult { + loop { + let sel_obj = call_next(vm, &self.selector)?; + let verdict = objbool::boolval(vm, sel_obj.clone())?; + let data_obj = call_next(vm, &self.data)?; + + if verdict { + return Ok(data_obj); + } + } + } + + #[pymethod(name = "__iter__")] + fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> { + zelf + } +} + #[pyclass] #[derive(Debug)] struct PyItertoolsCount { @@ -577,8 +630,8 @@ impl PyItertoolsAccumulate { let obj = call_next(vm, iterable)?; let next_acc_value = match &*self.acc_value.borrow() { - Option::None => obj.clone(), - Option::Some(value) => { + None => obj.clone(), + Some(value) => { if self.binop.is(&vm.get_none()) { vm._add(value.clone(), obj.clone())? } else { @@ -602,6 +655,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let chain = PyItertoolsChain::make_class(ctx); + let compress = PyItertoolsCompress::make_class(ctx); + let count = ctx.new_class("count", ctx.object()); PyItertoolsCount::extend_class(ctx, &count); @@ -626,6 +681,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { py_module!(vm, "itertools", { "chain" => chain, + "compress" => compress, "count" => count, "dropwhile" => dropwhile, "repeat" => repeat,