1
1
use std:: cell:: { Cell , RefCell } ;
2
2
use std:: cmp:: Ordering ;
3
+ use std:: iter;
3
4
use std:: ops:: { AddAssign , SubAssign } ;
4
5
use std:: rc:: Rc ;
5
6
6
7
use num_bigint:: BigInt ;
7
8
use num_traits:: ToPrimitive ;
8
9
9
- use crate :: function:: { OptionalArg , PyFuncArgs } ;
10
+ use crate :: function:: { Args , OptionalArg , PyFuncArgs } ;
10
11
use crate :: obj:: objbool;
11
12
use crate :: obj:: objint:: { self , PyInt , PyIntRef } ;
12
- use crate :: obj:: objiter:: { call_next, get_iter, new_stop_iteration} ;
13
+ use crate :: obj:: objiter:: { call_next, get_all , get_iter, new_stop_iteration} ;
13
14
use crate :: obj:: objtuple:: PyTuple ;
14
15
use crate :: obj:: objtype:: { self , PyClassRef } ;
15
16
use crate :: pyobject:: {
@@ -730,6 +731,123 @@ impl PyItertoolsTee {
730
731
}
731
732
}
732
733
734
+ #[ pyclass]
735
+ #[ derive( Debug ) ]
736
+ struct PyIterToolsProduct {
737
+ pools : Vec < Vec < PyObjectRef > > ,
738
+ idxs : RefCell < Vec < usize > > ,
739
+ cur : Cell < usize > ,
740
+ stop : Cell < bool > ,
741
+ }
742
+
743
+ impl PyValue for PyIterToolsProduct {
744
+ fn class ( vm : & VirtualMachine ) -> PyClassRef {
745
+ vm. class ( "itertools" , "product" )
746
+ }
747
+ }
748
+
749
+ #[ derive( FromArgs ) ]
750
+ struct ProductArgs {
751
+ #[ pyarg( keyword_only, optional = true ) ]
752
+ repeat : OptionalArg < usize > ,
753
+ }
754
+
755
+ #[ pyimpl]
756
+ impl PyIterToolsProduct {
757
+ #[ pyslot( new) ]
758
+ fn tp_new (
759
+ cls : PyClassRef ,
760
+ iterables : Args < PyObjectRef > ,
761
+ args : ProductArgs ,
762
+ vm : & VirtualMachine ,
763
+ ) -> PyResult < PyRef < Self > > {
764
+ let repeat = match args. repeat . into_option ( ) {
765
+ Some ( i) => i,
766
+ None => 1 ,
767
+ } ;
768
+
769
+ let mut pools = Vec :: new ( ) ;
770
+ for arg in iterables. into_iter ( ) {
771
+ let it = get_iter ( vm, & arg) ?;
772
+ let pool = get_all ( vm, & it) ?;
773
+
774
+ pools. push ( pool) ;
775
+ }
776
+ let pools = iter:: repeat ( pools)
777
+ . take ( repeat)
778
+ . flatten ( )
779
+ . collect :: < Vec < Vec < PyObjectRef > > > ( ) ;
780
+
781
+ let l = pools. len ( ) ;
782
+
783
+ PyIterToolsProduct {
784
+ pools,
785
+ idxs : RefCell :: new ( vec ! [ 0 ; l] ) ,
786
+ cur : Cell :: new ( l - 1 ) ,
787
+ stop : Cell :: new ( false ) ,
788
+ }
789
+ . into_ref_with_type ( vm, cls)
790
+ }
791
+
792
+ #[ pymethod( name = "__next__" ) ]
793
+ fn next ( & self , vm : & VirtualMachine ) -> PyResult {
794
+ // stop signal
795
+ if self . stop . get ( ) {
796
+ return Err ( new_stop_iteration ( vm) ) ;
797
+ }
798
+
799
+ let pools = & self . pools ;
800
+
801
+ for p in pools {
802
+ if p. is_empty ( ) {
803
+ return Err ( new_stop_iteration ( vm) ) ;
804
+ }
805
+ }
806
+
807
+ let res = PyTuple :: from (
808
+ pools
809
+ . iter ( )
810
+ . zip ( self . idxs . borrow ( ) . iter ( ) )
811
+ . map ( |( pool, idx) | pool[ * idx] . clone ( ) )
812
+ . collect :: < Vec < PyObjectRef > > ( ) ,
813
+ ) ;
814
+
815
+ self . update_idxs ( ) ;
816
+
817
+ if self . is_end ( ) {
818
+ self . stop . set ( true ) ;
819
+ }
820
+
821
+ Ok ( res. into_ref ( vm) . into_object ( ) )
822
+ }
823
+
824
+ fn is_end ( & self ) -> bool {
825
+ ( self . idxs . borrow ( ) [ self . cur . get ( ) ] == & self . pools [ self . cur . get ( ) ] . len ( ) - 1
826
+ && self . cur . get ( ) == 0 )
827
+ }
828
+
829
+ fn update_idxs ( & self ) {
830
+ let lst_idx = & self . pools [ self . cur . get ( ) ] . len ( ) - 1 ;
831
+
832
+ if self . idxs . borrow ( ) [ self . cur . get ( ) ] == lst_idx {
833
+ if self . is_end ( ) {
834
+ return ;
835
+ }
836
+ self . idxs . borrow_mut ( ) [ self . cur . get ( ) ] = 0 ;
837
+ self . cur . set ( self . cur . get ( ) - 1 ) ;
838
+ self . update_idxs ( ) ;
839
+ } else {
840
+ self . idxs . borrow_mut ( ) [ self . cur . get ( ) ] += 1 ;
841
+ self . cur . set ( self . idxs . borrow ( ) . len ( ) - 1 ) ;
842
+ }
843
+ }
844
+
845
+ #[ pymethod( name = "__iter__" ) ]
846
+ fn iter ( zelf : PyRef < Self > , _vm : & VirtualMachine ) -> PyRef < Self > {
847
+ zelf
848
+ }
849
+ }
850
+
733
851
pub fn make_module ( vm : & VirtualMachine ) -> PyObjectRef {
734
852
let ctx = & vm. ctx ;
735
853
@@ -761,6 +879,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
761
879
762
880
let tee = ctx. new_class ( "tee" , ctx. object ( ) ) ;
763
881
PyItertoolsTee :: extend_class ( ctx, & tee) ;
882
+ let product = ctx. new_class ( "product" , ctx. object ( ) ) ;
883
+ PyIterToolsProduct :: extend_class ( ctx, & product) ;
764
884
765
885
py_module ! ( vm, "itertools" , {
766
886
"chain" => chain,
@@ -774,5 +894,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
774
894
"filterfalse" => filterfalse,
775
895
"accumulate" => accumulate,
776
896
"tee" => tee,
897
+ "product" => product,
777
898
} )
778
899
}
0 commit comments