1
1
use std:: cell:: { Cell , RefCell } ;
2
2
use std:: cmp:: Ordering ;
3
+ use std:: iter;
3
4
use std:: ops:: { AddAssign , SubAssign } ;
4
5
use std:: rc:: Rc ;
5
6
6
7
use num_bigint:: BigInt ;
7
8
use num_traits:: ToPrimitive ;
8
9
9
- use crate :: function:: { OptionalArg , PyFuncArgs } ;
10
+ use crate :: function:: { Args , OptionalArg , PyFuncArgs } ;
10
11
use crate :: obj:: objbool;
11
12
use crate :: obj:: objint:: { self , PyInt , PyIntRef } ;
12
- use crate :: obj:: objiter:: { call_next, get_iter, new_stop_iteration} ;
13
+ use crate :: obj:: objiter:: { call_next, get_all , get_iter, new_stop_iteration} ;
13
14
use crate :: obj:: objtuple:: PyTuple ;
14
15
use crate :: obj:: objtype:: { self , PyClassRef } ;
15
16
use crate :: pyobject:: {
@@ -736,6 +737,123 @@ impl PyItertoolsTee {
736
737
}
737
738
}
738
739
740
+ #[ pyclass]
741
+ #[ derive( Debug ) ]
742
+ struct PyIterToolsProduct {
743
+ pools : Vec < Vec < PyObjectRef > > ,
744
+ idxs : RefCell < Vec < usize > > ,
745
+ cur : Cell < usize > ,
746
+ stop : Cell < bool > ,
747
+ }
748
+
749
+ impl PyValue for PyIterToolsProduct {
750
+ fn class ( vm : & VirtualMachine ) -> PyClassRef {
751
+ vm. class ( "itertools" , "product" )
752
+ }
753
+ }
754
+
755
+ #[ derive( FromArgs ) ]
756
+ struct ProductArgs {
757
+ #[ pyarg( keyword_only, optional = true ) ]
758
+ repeat : OptionalArg < usize > ,
759
+ }
760
+
761
+ #[ pyimpl]
762
+ impl PyIterToolsProduct {
763
+ #[ pyslot( new) ]
764
+ fn new (
765
+ cls : PyClassRef ,
766
+ iterables : Args < PyObjectRef > ,
767
+ args : ProductArgs ,
768
+ vm : & VirtualMachine ,
769
+ ) -> PyResult < PyRef < Self > > {
770
+ let repeat = match args. repeat . into_option ( ) {
771
+ Some ( i) => i,
772
+ None => 1 ,
773
+ } ;
774
+
775
+ let mut pools = Vec :: new ( ) ;
776
+ for arg in iterables. into_iter ( ) {
777
+ let it = get_iter ( vm, & arg) ?;
778
+ let pool = get_all ( vm, & it) ?;
779
+
780
+ pools. push ( pool) ;
781
+ }
782
+ let pools = iter:: repeat ( pools)
783
+ . take ( repeat)
784
+ . flatten ( )
785
+ . collect :: < Vec < Vec < PyObjectRef > > > ( ) ;
786
+
787
+ let l = pools. len ( ) ;
788
+
789
+ PyIterToolsProduct {
790
+ pools,
791
+ idxs : RefCell :: new ( vec ! [ 0 ; l] ) ,
792
+ cur : Cell :: new ( l - 1 ) ,
793
+ stop : Cell :: new ( false ) ,
794
+ }
795
+ . into_ref_with_type ( vm, cls)
796
+ }
797
+
798
+ #[ pymethod( name = "__next__" ) ]
799
+ fn next ( & self , vm : & VirtualMachine ) -> PyResult {
800
+ // stop signal
801
+ if self . stop . get ( ) {
802
+ return Err ( new_stop_iteration ( vm) ) ;
803
+ }
804
+
805
+ let pools = & self . pools ;
806
+
807
+ for p in pools {
808
+ if p. is_empty ( ) {
809
+ return Err ( new_stop_iteration ( vm) ) ;
810
+ }
811
+ }
812
+
813
+ let res = PyTuple :: from (
814
+ pools
815
+ . iter ( )
816
+ . zip ( self . idxs . borrow ( ) . iter ( ) )
817
+ . map ( |( pool, idx) | pool[ * idx] . clone ( ) )
818
+ . collect :: < Vec < PyObjectRef > > ( ) ,
819
+ ) ;
820
+
821
+ self . update_idxs ( ) ;
822
+
823
+ if self . is_end ( ) {
824
+ self . stop . set ( true ) ;
825
+ }
826
+
827
+ Ok ( res. into_ref ( vm) . into_object ( ) )
828
+ }
829
+
830
+ fn is_end ( & self ) -> bool {
831
+ ( self . idxs . borrow ( ) [ self . cur . get ( ) ] == & self . pools [ self . cur . get ( ) ] . len ( ) - 1
832
+ && self . cur . get ( ) == 0 )
833
+ }
834
+
835
+ fn update_idxs ( & self ) {
836
+ let lst_idx = & self . pools [ self . cur . get ( ) ] . len ( ) - 1 ;
837
+
838
+ if self . idxs . borrow ( ) [ self . cur . get ( ) ] == lst_idx {
839
+ if self . is_end ( ) {
840
+ return ;
841
+ }
842
+ self . idxs . borrow_mut ( ) [ self . cur . get ( ) ] = 0 ;
843
+ self . cur . set ( self . cur . get ( ) - 1 ) ;
844
+ self . update_idxs ( ) ;
845
+ } else {
846
+ self . idxs . borrow_mut ( ) [ self . cur . get ( ) ] += 1 ;
847
+ self . cur . set ( self . idxs . borrow ( ) . len ( ) - 1 ) ;
848
+ }
849
+ }
850
+
851
+ #[ pymethod( name = "__iter__" ) ]
852
+ fn iter ( zelf : PyRef < Self > , _vm : & VirtualMachine ) -> PyRef < Self > {
853
+ zelf
854
+ }
855
+ }
856
+
739
857
pub fn make_module ( vm : & VirtualMachine ) -> PyObjectRef {
740
858
let ctx = & vm. ctx ;
741
859
@@ -767,6 +885,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
767
885
768
886
let tee = ctx. new_class ( "tee" , ctx. object ( ) ) ;
769
887
PyItertoolsTee :: extend_class ( ctx, & tee) ;
888
+ let product = ctx. new_class ( "product" , ctx. object ( ) ) ;
889
+ PyIterToolsProduct :: extend_class ( ctx, & product) ;
770
890
771
891
py_module ! ( vm, "itertools" , {
772
892
"chain" => chain,
@@ -780,5 +900,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
780
900
"filterfalse" => filterfalse,
781
901
"accumulate" => accumulate,
782
902
"tee" => tee,
903
+ "product" => product,
783
904
} )
784
905
}
0 commit comments