Skip to content

Commit faee9e3

Browse files
Merge pull request #1458 from j30ng/itertools-tee
Implement itertools.tee
2 parents 162ff58 + 0b7da12 commit faee9e3

File tree

2 files changed

+173
-1
lines changed

2 files changed

+173
-1
lines changed

tests/snippets/stdlib_itertools.py

+56
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,59 @@ def assert_matches_seq(it, seq):
243243
assert list(itertools.compress("ABCDEF", [1,0,1])) == list("AC")
244244
assert list(itertools.compress("ABC", [0,1,1,1,1,1])) == list("BC")
245245
assert list(itertools.compress("ABCDEF", [True,False,"t","",1,9])) == list("ACEF")
246+
247+
248+
# itertools.tee
249+
t = itertools.tee([])
250+
assert len(t) == 2
251+
assert t[0] is not t[1]
252+
assert list(t[0]) == list(t[1]) == []
253+
254+
with assert_raises(TypeError):
255+
itertools.tee()
256+
257+
t = itertools.tee(range(1000))
258+
assert list(t[0]) == list(t[1]) == list(range(1000))
259+
260+
t = itertools.tee([1,22,333], 3)
261+
assert len(t) == 3
262+
assert 1 == next(t[0])
263+
assert 1 == next(t[1])
264+
assert 22 == next(t[0])
265+
assert 333 == next(t[0])
266+
assert 1 == next(t[2])
267+
assert 22 == next(t[1])
268+
assert 333 == next(t[1])
269+
assert 22 == next(t[2])
270+
with assert_raises(StopIteration):
271+
next(t[0])
272+
assert 333 == next(t[2])
273+
with assert_raises(StopIteration):
274+
next(t[2])
275+
with assert_raises(StopIteration):
276+
next(t[1])
277+
278+
t0, t1 = itertools.tee([1,2,3])
279+
tc = t0.__copy__()
280+
assert list(t0) == [1,2,3]
281+
assert list(t1) == [1,2,3]
282+
assert list(tc) == [1,2,3]
283+
284+
t0, t1 = itertools.tee([1,2,3])
285+
assert 1 == next(t0) # advance index of t0 by 1 before __copy__()
286+
t0c = t0.__copy__()
287+
t1c = t1.__copy__()
288+
assert list(t0) == [2,3]
289+
assert list(t0c) == [2,3]
290+
assert list(t1) == [1,2,3]
291+
assert list(t1c) == [1,2,3]
292+
293+
t0, t1 = itertools.tee([1,2,3])
294+
t2, t3 = itertools.tee(t0)
295+
assert list(t1) == [1,2,3]
296+
assert list(t2) == [1,2,3]
297+
assert list(t3) == [1,2,3]
298+
299+
t = itertools.tee([1,2,3])
300+
assert list(t[0]) == [1,2,3]
301+
assert list(t[0]) == []

vm/src/stdlib/itertools.rs

+117-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::cell::{Cell, RefCell};
22
use std::cmp::Ordering;
33
use std::ops::{AddAssign, SubAssign};
4+
use std::rc::Rc;
45

56
use num_bigint::BigInt;
67
use num_traits::ToPrimitive;
@@ -10,9 +11,12 @@ use crate::obj::objbool;
1011
use crate::obj::objint;
1112
use crate::obj::objint::{PyInt, PyIntRef};
1213
use crate::obj::objiter::{call_next, get_iter, new_stop_iteration};
14+
use crate::obj::objtuple::PyTuple;
1315
use crate::obj::objtype;
1416
use crate::obj::objtype::PyClassRef;
15-
use crate::pyobject::{IdProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue};
17+
use crate::pyobject::{
18+
IdProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol,
19+
};
1620
use crate::vm::VirtualMachine;
1721

1822
#[pyclass(name = "chain")]
@@ -629,6 +633,114 @@ impl PyItertoolsAccumulate {
629633
}
630634
}
631635

636+
#[derive(Debug)]
637+
struct PyItertoolsTeeData {
638+
iterable: PyObjectRef,
639+
values: RefCell<Vec<PyObjectRef>>,
640+
}
641+
642+
impl PyItertoolsTeeData {
643+
fn new(
644+
iterable: PyObjectRef,
645+
vm: &VirtualMachine,
646+
) -> Result<Rc<PyItertoolsTeeData>, PyObjectRef> {
647+
Ok(Rc::new(PyItertoolsTeeData {
648+
iterable: get_iter(vm, &iterable)?,
649+
values: RefCell::new(vec![]),
650+
}))
651+
}
652+
653+
fn get_item(&self, vm: &VirtualMachine, index: usize) -> PyResult {
654+
if self.values.borrow().len() == index {
655+
let result = call_next(vm, &self.iterable)?;
656+
self.values.borrow_mut().push(result);
657+
}
658+
Ok(self.values.borrow()[index].clone())
659+
}
660+
}
661+
662+
#[pyclass]
663+
#[derive(Debug)]
664+
struct PyItertoolsTee {
665+
tee_data: Rc<PyItertoolsTeeData>,
666+
index: Cell<usize>,
667+
}
668+
669+
impl PyValue for PyItertoolsTee {
670+
fn class(vm: &VirtualMachine) -> PyClassRef {
671+
vm.class("itertools", "tee")
672+
}
673+
}
674+
675+
#[pyimpl]
676+
impl PyItertoolsTee {
677+
fn from_iter(iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
678+
let it = get_iter(vm, &iterable)?;
679+
if it.class().is(&PyItertoolsTee::class(vm)) {
680+
return vm.call_method(&it, "__copy__", PyFuncArgs::from(vec![]));
681+
}
682+
Ok(PyItertoolsTee {
683+
tee_data: PyItertoolsTeeData::new(it, vm)?,
684+
index: Cell::from(0),
685+
}
686+
.into_ref_with_type(vm, PyItertoolsTee::class(vm))?
687+
.into_object())
688+
}
689+
690+
#[pymethod(name = "__new__")]
691+
#[allow(clippy::new_ret_no_self)]
692+
fn new(
693+
_cls: PyClassRef,
694+
iterable: PyObjectRef,
695+
n: OptionalArg<PyIntRef>,
696+
vm: &VirtualMachine,
697+
) -> PyResult<PyRef<PyTuple>> {
698+
let n = match n {
699+
OptionalArg::Present(x) => match x.as_bigint().to_usize() {
700+
Some(y) => y,
701+
None => return Err(vm.new_overflow_error(String::from("n is too big"))),
702+
},
703+
OptionalArg::Missing => 2,
704+
};
705+
706+
let copyable = if objtype::class_has_attr(&iterable.class(), "__copy__") {
707+
vm.call_method(&iterable, "__copy__", PyFuncArgs::from(vec![]))?
708+
} else {
709+
PyItertoolsTee::from_iter(iterable, vm)?
710+
};
711+
712+
let mut tee_vec: Vec<PyObjectRef> = Vec::with_capacity(n);
713+
for _ in 0..n {
714+
let no_args = PyFuncArgs::from(vec![]);
715+
tee_vec.push(vm.call_method(&copyable, "__copy__", no_args)?);
716+
}
717+
718+
Ok(PyTuple::from(tee_vec).into_ref(vm))
719+
}
720+
721+
#[pymethod(name = "__copy__")]
722+
fn copy(&self, vm: &VirtualMachine) -> PyResult {
723+
Ok(PyItertoolsTee {
724+
tee_data: Rc::clone(&self.tee_data),
725+
index: self.index.clone(),
726+
}
727+
.into_ref_with_type(vm, Self::class(vm))?
728+
.into_object())
729+
}
730+
731+
#[pymethod(name = "__next__")]
732+
fn next(&self, vm: &VirtualMachine) -> PyResult {
733+
let value = self.tee_data.get_item(vm, self.index.get())?;
734+
self.index.set(self.index.get() + 1);
735+
Ok(value)
736+
}
737+
738+
#[pymethod(name = "__iter__")]
739+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
740+
zelf
741+
}
742+
}
743+
632744
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
633745
let ctx = &vm.ctx;
634746

@@ -658,6 +770,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
658770
let accumulate = ctx.new_class("accumulate", ctx.object());
659771
PyItertoolsAccumulate::extend_class(ctx, &accumulate);
660772

773+
let tee = ctx.new_class("tee", ctx.object());
774+
PyItertoolsTee::extend_class(ctx, &tee);
775+
661776
py_module!(vm, "itertools", {
662777
"chain" => chain,
663778
"compress" => compress,
@@ -669,5 +784,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
669784
"islice" => islice,
670785
"filterfalse" => filterfalse,
671786
"accumulate" => accumulate,
787+
"tee" => tee,
672788
})
673789
}

0 commit comments

Comments
 (0)