Skip to content

Commit 239115f

Browse files
author
Hyunji Kim
committed
itertools.compress
1 parent a3c2eea commit 239115f

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

tests/snippets/stdlib_itertools.py

+8
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,11 @@ def assert_matches_seq(it, seq):
235235
assert 0 == next(it)
236236
with assert_raises(StopIteration):
237237
next(it)
238+
239+
# itertools.compress
240+
assert list(itertools.compress("ABCDEF", [1,0,1,0,1,1])) == list("ACEF")
241+
assert list(itertools.compress("ABCDEF", [0,0,0,0,0,0])) == list("")
242+
assert list(itertools.compress("ABCDEF", [1,1,1,1,1,1])) == list("ABCDEF")
243+
assert list(itertools.compress("ABCDEF", [1,0,1])) == list("AC")
244+
assert list(itertools.compress("ABC", [0,1,1,1,1,1])) == list("BC")
245+
assert list(itertools.compress("ABCDEF", [True,False,"t","",1,9])) == list("ACEF")

vm/src/stdlib/itertools.rs

+58-2
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,59 @@ impl PyItertoolsChain {
7373
}
7474
}
7575

76+
#[pyclass(name = "compress")]
77+
#[derive(Debug)]
78+
struct PyItertoolsCompress {
79+
data: PyObjectRef,
80+
selector: PyObjectRef,
81+
}
82+
83+
impl PyValue for PyItertoolsCompress {
84+
fn class(vm: &VirtualMachine) -> PyClassRef {
85+
vm.class("itertools", "compress")
86+
}
87+
}
88+
89+
#[pyimpl]
90+
impl PyItertoolsCompress {
91+
#[pymethod(name = "__new__")]
92+
#[allow(clippy::new_ret_no_self)]
93+
fn new(
94+
_cls: PyClassRef,
95+
data: PyObjectRef,
96+
selector: PyObjectRef,
97+
vm: &VirtualMachine,
98+
) -> PyResult {
99+
let data_iter = get_iter(vm, &data)?;
100+
let selector_iter = get_iter(vm, &selector)?;
101+
102+
Ok(PyItertoolsCompress {
103+
data: data_iter,
104+
selector: selector_iter,
105+
}
106+
.into_ref(vm)
107+
.into_object())
108+
}
109+
110+
#[pymethod(name = "__next__")]
111+
fn next(&self, vm: &VirtualMachine) -> PyResult {
112+
loop {
113+
let sel_obj = call_next(vm, &self.selector)?;
114+
let verdict = objbool::boolval(vm, sel_obj.clone())?;
115+
let data_obj = call_next(vm, &self.data)?;
116+
117+
if verdict {
118+
return Ok(data_obj);
119+
}
120+
}
121+
}
122+
123+
#[pymethod(name = "__iter__")]
124+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
125+
zelf
126+
}
127+
}
128+
76129
#[pyclass]
77130
#[derive(Debug)]
78131
struct PyItertoolsCount {
@@ -577,8 +630,8 @@ impl PyItertoolsAccumulate {
577630
let obj = call_next(vm, iterable)?;
578631

579632
let next_acc_value = match &*self.acc_value.borrow() {
580-
Option::None => obj.clone(),
581-
Option::Some(value) => {
633+
None => obj.clone(),
634+
Some(value) => {
582635
if self.binop.is(&vm.get_none()) {
583636
vm._add(value.clone(), obj.clone())?
584637
} else {
@@ -602,6 +655,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
602655

603656
let chain = PyItertoolsChain::make_class(ctx);
604657

658+
let compress = PyItertoolsCompress::make_class(ctx);
659+
605660
let count = ctx.new_class("count", ctx.object());
606661
PyItertoolsCount::extend_class(ctx, &count);
607662

@@ -626,6 +681,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
626681

627682
py_module!(vm, "itertools", {
628683
"chain" => chain,
684+
"compress" => compress,
629685
"count" => count,
630686
"dropwhile" => dropwhile,
631687
"repeat" => repeat,

0 commit comments

Comments
 (0)