From 239115f686ee6d7946c90b1df940b0ce7d18f243 Mon Sep 17 00:00:00 2001
From: Hyunji Kim <dev.hjkim@linecorp.com>
Date: Sun, 22 Sep 2019 18:15:07 +0900
Subject: [PATCH] itertools.compress

---
 tests/snippets/stdlib_itertools.py |  8 ++++
 vm/src/stdlib/itertools.rs         | 60 +++++++++++++++++++++++++++++-
 2 files changed, 66 insertions(+), 2 deletions(-)

diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py
index 2c588fcf25..b83c6cfb69 100644
--- a/tests/snippets/stdlib_itertools.py
+++ b/tests/snippets/stdlib_itertools.py
@@ -235,3 +235,11 @@ def assert_matches_seq(it, seq):
 assert 0 == next(it)
 with assert_raises(StopIteration):
     next(it)
+
+# itertools.compress
+assert list(itertools.compress("ABCDEF", [1,0,1,0,1,1])) == list("ACEF")
+assert list(itertools.compress("ABCDEF", [0,0,0,0,0,0])) == list("")
+assert list(itertools.compress("ABCDEF", [1,1,1,1,1,1])) == list("ABCDEF")
+assert list(itertools.compress("ABCDEF", [1,0,1])) == list("AC")
+assert list(itertools.compress("ABC", [0,1,1,1,1,1])) == list("BC")
+assert list(itertools.compress("ABCDEF", [True,False,"t","",1,9])) == list("ACEF")
diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs
index dabbb220ca..ddac6cc263 100644
--- a/vm/src/stdlib/itertools.rs
+++ b/vm/src/stdlib/itertools.rs
@@ -73,6 +73,59 @@ impl PyItertoolsChain {
     }
 }
 
+#[pyclass(name = "compress")]
+#[derive(Debug)]
+struct PyItertoolsCompress {
+    data: PyObjectRef,
+    selector: PyObjectRef,
+}
+
+impl PyValue for PyItertoolsCompress {
+    fn class(vm: &VirtualMachine) -> PyClassRef {
+        vm.class("itertools", "compress")
+    }
+}
+
+#[pyimpl]
+impl PyItertoolsCompress {
+    #[pymethod(name = "__new__")]
+    #[allow(clippy::new_ret_no_self)]
+    fn new(
+        _cls: PyClassRef,
+        data: PyObjectRef,
+        selector: PyObjectRef,
+        vm: &VirtualMachine,
+    ) -> PyResult {
+        let data_iter = get_iter(vm, &data)?;
+        let selector_iter = get_iter(vm, &selector)?;
+
+        Ok(PyItertoolsCompress {
+            data: data_iter,
+            selector: selector_iter,
+        }
+        .into_ref(vm)
+        .into_object())
+    }
+
+    #[pymethod(name = "__next__")]
+    fn next(&self, vm: &VirtualMachine) -> PyResult {
+        loop {
+            let sel_obj = call_next(vm, &self.selector)?;
+            let verdict = objbool::boolval(vm, sel_obj.clone())?;
+            let data_obj = call_next(vm, &self.data)?;
+
+            if verdict {
+                return Ok(data_obj);
+            }
+        }
+    }
+
+    #[pymethod(name = "__iter__")]
+    fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
+        zelf
+    }
+}
+
 #[pyclass]
 #[derive(Debug)]
 struct PyItertoolsCount {
@@ -577,8 +630,8 @@ impl PyItertoolsAccumulate {
         let obj = call_next(vm, iterable)?;
 
         let next_acc_value = match &*self.acc_value.borrow() {
-            Option::None => obj.clone(),
-            Option::Some(value) => {
+            None => obj.clone(),
+            Some(value) => {
                 if self.binop.is(&vm.get_none()) {
                     vm._add(value.clone(), obj.clone())?
                 } else {
@@ -602,6 +655,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
 
     let chain = PyItertoolsChain::make_class(ctx);
 
+    let compress = PyItertoolsCompress::make_class(ctx);
+
     let count = ctx.new_class("count", ctx.object());
     PyItertoolsCount::extend_class(ctx, &count);
 
@@ -626,6 +681,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
 
     py_module!(vm, "itertools", {
         "chain" => chain,
+        "compress" => compress,
         "count" => count,
         "dropwhile" => dropwhile,
         "repeat" => repeat,