1
1
use std:: cell:: { Cell , RefCell } ;
2
2
use std:: cmp:: Ordering ;
3
3
use std:: ops:: { AddAssign , SubAssign } ;
4
+ use std:: rc:: Rc ;
4
5
5
6
use num_bigint:: BigInt ;
6
7
use num_traits:: ToPrimitive ;
@@ -10,9 +11,12 @@ use crate::obj::objbool;
10
11
use crate :: obj:: objint;
11
12
use crate :: obj:: objint:: { PyInt , PyIntRef } ;
12
13
use crate :: obj:: objiter:: { call_next, get_iter, new_stop_iteration} ;
14
+ use crate :: obj:: objtuple:: PyTuple ;
13
15
use crate :: obj:: objtype;
14
16
use crate :: obj:: objtype:: PyClassRef ;
15
- use crate :: pyobject:: { IdProtocol , PyCallable , PyClassImpl , PyObjectRef , PyRef , PyResult , PyValue } ;
17
+ use crate :: pyobject:: {
18
+ IdProtocol , PyCallable , PyClassImpl , PyObjectRef , PyRef , PyResult , PyValue , TypeProtocol ,
19
+ } ;
16
20
use crate :: vm:: VirtualMachine ;
17
21
18
22
#[ pyclass( name = "chain" ) ]
@@ -629,6 +633,114 @@ impl PyItertoolsAccumulate {
629
633
}
630
634
}
631
635
636
+ #[ derive( Debug ) ]
637
+ struct PyItertoolsTeeData {
638
+ iterable : PyObjectRef ,
639
+ values : RefCell < Vec < PyObjectRef > > ,
640
+ }
641
+
642
+ impl PyItertoolsTeeData {
643
+ fn new (
644
+ iterable : PyObjectRef ,
645
+ vm : & VirtualMachine ,
646
+ ) -> Result < Rc < PyItertoolsTeeData > , PyObjectRef > {
647
+ Ok ( Rc :: new ( PyItertoolsTeeData {
648
+ iterable : get_iter ( vm, & iterable) ?,
649
+ values : RefCell :: new ( vec ! [ ] ) ,
650
+ } ) )
651
+ }
652
+
653
+ fn get_item ( & self , vm : & VirtualMachine , index : usize ) -> PyResult {
654
+ if self . values . borrow ( ) . len ( ) == index {
655
+ let result = call_next ( vm, & self . iterable ) ?;
656
+ self . values . borrow_mut ( ) . push ( result) ;
657
+ }
658
+ Ok ( self . values . borrow ( ) [ index] . clone ( ) )
659
+ }
660
+ }
661
+
662
+ #[ pyclass]
663
+ #[ derive( Debug ) ]
664
+ struct PyItertoolsTee {
665
+ tee_data : Rc < PyItertoolsTeeData > ,
666
+ index : Cell < usize > ,
667
+ }
668
+
669
+ impl PyValue for PyItertoolsTee {
670
+ fn class ( vm : & VirtualMachine ) -> PyClassRef {
671
+ vm. class ( "itertools" , "tee" )
672
+ }
673
+ }
674
+
675
+ #[ pyimpl]
676
+ impl PyItertoolsTee {
677
+ fn from_iter ( iterable : PyObjectRef , vm : & VirtualMachine ) -> PyResult < PyObjectRef > {
678
+ let it = get_iter ( vm, & iterable) ?;
679
+ if it. class ( ) . is ( & PyItertoolsTee :: class ( vm) ) {
680
+ return vm. call_method ( & it, "__copy__" , PyFuncArgs :: from ( vec ! [ ] ) ) ;
681
+ }
682
+ Ok ( PyItertoolsTee {
683
+ tee_data : PyItertoolsTeeData :: new ( it, vm) ?,
684
+ index : Cell :: from ( 0 ) ,
685
+ }
686
+ . into_ref_with_type ( vm, PyItertoolsTee :: class ( vm) ) ?
687
+ . into_object ( ) )
688
+ }
689
+
690
+ #[ pymethod( name = "__new__" ) ]
691
+ #[ allow( clippy:: new_ret_no_self) ]
692
+ fn new (
693
+ _cls : PyClassRef ,
694
+ iterable : PyObjectRef ,
695
+ n : OptionalArg < PyIntRef > ,
696
+ vm : & VirtualMachine ,
697
+ ) -> PyResult < PyRef < PyTuple > > {
698
+ let n = match n {
699
+ OptionalArg :: Present ( x) => match x. as_bigint ( ) . to_usize ( ) {
700
+ Some ( y) => y,
701
+ None => return Err ( vm. new_overflow_error ( String :: from ( "n is too big" ) ) ) ,
702
+ } ,
703
+ OptionalArg :: Missing => 2 ,
704
+ } ;
705
+
706
+ let copyable = if objtype:: class_has_attr ( & iterable. class ( ) , "__copy__" ) {
707
+ vm. call_method ( & iterable, "__copy__" , PyFuncArgs :: from ( vec ! [ ] ) ) ?
708
+ } else {
709
+ PyItertoolsTee :: from_iter ( iterable, vm) ?
710
+ } ;
711
+
712
+ let mut tee_vec: Vec < PyObjectRef > = Vec :: with_capacity ( n) ;
713
+ for _ in 0 ..n {
714
+ let no_args = PyFuncArgs :: from ( vec ! [ ] ) ;
715
+ tee_vec. push ( vm. call_method ( & copyable, "__copy__" , no_args) ?) ;
716
+ }
717
+
718
+ Ok ( PyTuple :: from ( tee_vec) . into_ref ( vm) )
719
+ }
720
+
721
+ #[ pymethod( name = "__copy__" ) ]
722
+ fn copy ( & self , vm : & VirtualMachine ) -> PyResult {
723
+ Ok ( PyItertoolsTee {
724
+ tee_data : Rc :: clone ( & self . tee_data ) ,
725
+ index : self . index . clone ( ) ,
726
+ }
727
+ . into_ref_with_type ( vm, Self :: class ( vm) ) ?
728
+ . into_object ( ) )
729
+ }
730
+
731
+ #[ pymethod( name = "__next__" ) ]
732
+ fn next ( & self , vm : & VirtualMachine ) -> PyResult {
733
+ let value = self . tee_data . get_item ( vm, self . index . get ( ) ) ?;
734
+ self . index . set ( self . index . get ( ) + 1 ) ;
735
+ Ok ( value)
736
+ }
737
+
738
+ #[ pymethod( name = "__iter__" ) ]
739
+ fn iter ( zelf : PyRef < Self > , _vm : & VirtualMachine ) -> PyRef < Self > {
740
+ zelf
741
+ }
742
+ }
743
+
632
744
pub fn make_module ( vm : & VirtualMachine ) -> PyObjectRef {
633
745
let ctx = & vm. ctx ;
634
746
@@ -658,6 +770,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
658
770
let accumulate = ctx. new_class ( "accumulate" , ctx. object ( ) ) ;
659
771
PyItertoolsAccumulate :: extend_class ( ctx, & accumulate) ;
660
772
773
+ let tee = ctx. new_class ( "tee" , ctx. object ( ) ) ;
774
+ PyItertoolsTee :: extend_class ( ctx, & tee) ;
775
+
661
776
py_module ! ( vm, "itertools" , {
662
777
"chain" => chain,
663
778
"compress" => compress,
@@ -669,5 +784,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
669
784
"islice" => islice,
670
785
"filterfalse" => filterfalse,
671
786
"accumulate" => accumulate,
787
+ "tee" => tee,
672
788
} )
673
789
}
0 commit comments