Skip to content

Commit fcd39a3

Browse files
authored
Merge pull request #1372 from j30ng/itertools-accumulate
Implement itertools.accumulate
2 parents 986f61e + d4e5d76 commit fcd39a3

File tree

2 files changed

+90
-1
lines changed

2 files changed

+90
-1
lines changed

tests/snippets/stdlib_itertools.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -209,4 +209,29 @@ def assert_matches_seq(it, seq):
209209
assert 4 == next(it)
210210
assert 1 == next(it)
211211
with assertRaises(StopIteration):
212-
next(it)
212+
next(it)
213+
214+
215+
# itertools.accumulate
216+
it = itertools.accumulate([6, 3, 7, 1, 0, 9, 8, 8])
217+
assert 6 == next(it)
218+
assert 9 == next(it)
219+
assert 16 == next(it)
220+
assert 17 == next(it)
221+
assert 17 == next(it)
222+
assert 26 == next(it)
223+
assert 34 == next(it)
224+
assert 42 == next(it)
225+
with assertRaises(StopIteration):
226+
next(it)
227+
228+
it = itertools.accumulate([3, 2, 4, 1, 0, 5, 8], lambda a, v: a*v)
229+
assert 3 == next(it)
230+
assert 6 == next(it)
231+
assert 24 == next(it)
232+
assert 24 == next(it)
233+
assert 0 == next(it)
234+
assert 0 == next(it)
235+
assert 0 == next(it)
236+
with assertRaises(StopIteration):
237+
next(it)

vm/src/stdlib/itertools.rs

+64
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,66 @@ impl PyItertoolsFilterFalse {
537537
}
538538
}
539539

540+
#[pyclass]
541+
#[derive(Debug)]
542+
struct PyItertoolsAccumulate {
543+
iterable: PyObjectRef,
544+
binop: PyObjectRef,
545+
acc_value: RefCell<Option<PyObjectRef>>,
546+
}
547+
548+
impl PyValue for PyItertoolsAccumulate {
549+
fn class(vm: &VirtualMachine) -> PyClassRef {
550+
vm.class("itertools", "accumulate")
551+
}
552+
}
553+
554+
#[pyimpl]
555+
impl PyItertoolsAccumulate {
556+
#[pymethod(name = "__new__")]
557+
#[allow(clippy::new_ret_no_self)]
558+
fn new(
559+
cls: PyClassRef,
560+
iterable: PyObjectRef,
561+
binop: OptionalArg<PyObjectRef>,
562+
vm: &VirtualMachine,
563+
) -> PyResult<PyRef<PyItertoolsAccumulate>> {
564+
let iter = get_iter(vm, &iterable)?;
565+
566+
PyItertoolsAccumulate {
567+
iterable: iter,
568+
binop: binop.unwrap_or_else(|| vm.get_none()),
569+
acc_value: RefCell::from(Option::None),
570+
}
571+
.into_ref_with_type(vm, cls)
572+
}
573+
574+
#[pymethod(name = "__next__")]
575+
fn next(&self, vm: &VirtualMachine) -> PyResult {
576+
let iterable = &self.iterable;
577+
let obj = call_next(vm, iterable)?;
578+
579+
let next_acc_value = match &*self.acc_value.borrow() {
580+
Option::None => obj.clone(),
581+
Option::Some(value) => {
582+
if self.binop.is(&vm.get_none()) {
583+
vm._add(value.clone(), obj.clone())?
584+
} else {
585+
vm.invoke(&self.binop, vec![value.clone(), obj.clone()])?
586+
}
587+
}
588+
};
589+
self.acc_value.replace(Option::from(next_acc_value.clone()));
590+
591+
Ok(next_acc_value)
592+
}
593+
594+
#[pymethod(name = "__iter__")]
595+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
596+
zelf
597+
}
598+
}
599+
540600
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
541601
let ctx = &vm.ctx;
542602

@@ -561,6 +621,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
561621
let filterfalse = ctx.new_class("filterfalse", ctx.object());
562622
PyItertoolsFilterFalse::extend_class(ctx, &filterfalse);
563623

624+
let accumulate = ctx.new_class("accumulate", ctx.object());
625+
PyItertoolsAccumulate::extend_class(ctx, &accumulate);
626+
564627
py_module!(vm, "itertools", {
565628
"chain" => chain,
566629
"count" => count,
@@ -570,5 +633,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
570633
"takewhile" => takewhile,
571634
"islice" => islice,
572635
"filterfalse" => filterfalse,
636+
"accumulate" => accumulate,
573637
})
574638
}

0 commit comments

Comments
 (0)