1
1
use std:: cell:: { Cell , RefCell } ;
2
2
use std:: cmp:: Ordering ;
3
+ use std:: iter;
3
4
use std:: ops:: { AddAssign , SubAssign } ;
4
5
use std:: rc:: Rc ;
5
6
6
- use num_bigint:: BigInt ;
7
+ use num_bigint:: { BigInt , Sign } ;
7
8
use num_traits:: ToPrimitive ;
8
9
9
- use crate :: function:: { OptionalArg , PyFuncArgs } ;
10
+ use crate :: function:: { Args , OptionalArg , PyFuncArgs } ;
10
11
use crate :: obj:: objbool;
11
12
use crate :: obj:: objint:: { self , PyInt , PyIntRef } ;
12
- use crate :: obj:: objiter:: { call_next, get_iter, new_stop_iteration} ;
13
+ use crate :: obj:: objiter:: { call_next, get_all , get_iter, new_stop_iteration} ;
13
14
use crate :: obj:: objtuple:: PyTuple ;
14
15
use crate :: obj:: objtype:: { self , PyClassRef } ;
15
16
use crate :: pyobject:: {
@@ -736,6 +737,143 @@ impl PyItertoolsTee {
736
737
}
737
738
}
738
739
740
+ #[ pyclass]
741
+ #[ derive( Debug ) ]
742
+ struct PyIterToolsProduct {
743
+ pools : Vec < Vec < PyObjectRef > > ,
744
+ idxs : RefCell < Vec < usize > > ,
745
+ cur : RefCell < usize > ,
746
+ sizes : Vec < usize > ,
747
+ stop : RefCell < bool > ,
748
+ }
749
+
750
+ impl PyValue for PyIterToolsProduct {
751
+ fn class ( vm : & VirtualMachine ) -> PyClassRef {
752
+ vm. class ( "itertools" , "product" )
753
+ }
754
+ }
755
+
756
+ #[ derive( FromArgs ) ]
757
+ struct ProductArgs {
758
+ #[ pyarg( keyword_only, optional = true ) ]
759
+ repeat : OptionalArg < PyIntRef > ,
760
+ }
761
+
762
+ #[ pyimpl]
763
+ impl PyIterToolsProduct {
764
+ #[ pyslot( new) ]
765
+ fn new (
766
+ cls : PyClassRef ,
767
+ iterables : Args < PyObjectRef > ,
768
+ args : ProductArgs ,
769
+ vm : & VirtualMachine ,
770
+ ) -> PyResult < PyRef < Self > > {
771
+ let repeat = match args. repeat . into_option ( ) {
772
+ Some ( int) => match int. as_bigint ( ) . sign ( ) {
773
+ Sign :: Plus | Sign :: NoSign => match int. as_bigint ( ) . to_usize ( ) {
774
+ Some ( x) => x,
775
+ None => {
776
+ return Err ( vm. new_overflow_error ( "repeat argument too large" . to_string ( ) ) )
777
+ }
778
+ } ,
779
+ Sign :: Minus => {
780
+ return Err ( vm. new_value_error ( "repeat argument cannot be negative" . to_string ( ) ) )
781
+ }
782
+ } ,
783
+ None => 1 ,
784
+ } ;
785
+
786
+ let mut pools = Vec :: new ( ) ;
787
+ let mut sizes = Vec :: new ( ) ;
788
+ for arg in iterables. into_iter ( ) {
789
+ let it = get_iter ( vm, & arg) ?;
790
+ let pool = get_all ( vm, & it) ?;
791
+ let size = pool. len ( ) ;
792
+
793
+ pools. push ( pool) ;
794
+ sizes. push ( size) ;
795
+ }
796
+ let pools = iter:: repeat ( pools)
797
+ . take ( repeat)
798
+ . flatten ( )
799
+ . collect :: < Vec < Vec < PyObjectRef > > > ( ) ;
800
+
801
+ let sizes = iter:: repeat ( sizes)
802
+ . take ( repeat)
803
+ . flatten ( )
804
+ . collect :: < Vec < usize > > ( ) ;
805
+
806
+ let l = pools. len ( ) ;
807
+
808
+ PyIterToolsProduct {
809
+ pools,
810
+ idxs : RefCell :: new ( vec ! [ 0 ; l] ) ,
811
+ cur : RefCell :: new ( l - 1 ) ,
812
+ sizes,
813
+ stop : RefCell :: new ( false ) ,
814
+ }
815
+ . into_ref_with_type ( vm, cls)
816
+ }
817
+
818
+ #[ pymethod( name = "__next__" ) ]
819
+ fn next ( & self , vm : & VirtualMachine ) -> PyResult {
820
+ // stop signal
821
+ if * self . stop . borrow ( ) {
822
+ return Err ( new_stop_iteration ( vm) ) ;
823
+ }
824
+
825
+ let pools = & self . pools ;
826
+
827
+ for s in & self . sizes {
828
+ if * s == 0 {
829
+ return Err ( new_stop_iteration ( vm) ) ;
830
+ }
831
+ }
832
+
833
+ let res = PyTuple :: from (
834
+ pools
835
+ . iter ( )
836
+ . zip ( self . idxs . borrow ( ) . iter ( ) )
837
+ . map ( |( pool, idx) | pool[ * idx] . clone ( ) )
838
+ . collect :: < Vec < PyObjectRef > > ( ) ,
839
+ ) ;
840
+
841
+ self . update_idxs ( ) ;
842
+
843
+ if self . is_end ( ) {
844
+ * self . stop . borrow_mut ( ) = true ;
845
+ }
846
+
847
+ Ok ( res. into_ref ( vm) . into_object ( ) )
848
+ }
849
+
850
+ fn is_end ( & self ) -> bool {
851
+ ( self . idxs . borrow ( ) [ * self . cur . borrow ( ) ] == & self . sizes [ * self . cur . borrow ( ) ] - 1
852
+ && * self . cur . borrow ( ) == 0 )
853
+ }
854
+
855
+ fn update_idxs ( & self ) {
856
+ let lst_idx = & self . sizes [ * self . cur . borrow ( ) ] - 1 ;
857
+
858
+ if self . idxs . borrow ( ) [ * self . cur . borrow ( ) ] == lst_idx {
859
+ if self . is_end ( ) {
860
+ return ;
861
+ }
862
+ self . idxs . borrow_mut ( ) [ * self . cur . borrow ( ) ] = 0 ;
863
+ * self . cur . borrow_mut ( ) -= 1 ;
864
+ self . update_idxs ( ) ;
865
+ } else {
866
+ self . idxs . borrow_mut ( ) [ * self . cur . borrow ( ) ] += 1 ;
867
+ * self . cur . borrow_mut ( ) = self . idxs . borrow ( ) . len ( ) - 1 ;
868
+ }
869
+ }
870
+
871
+ #[ pymethod( name = "__iter__" ) ]
872
+ fn iter ( zelf : PyRef < Self > , _vm : & VirtualMachine ) -> PyRef < Self > {
873
+ zelf
874
+ }
875
+ }
876
+
739
877
pub fn make_module ( vm : & VirtualMachine ) -> PyObjectRef {
740
878
let ctx = & vm. ctx ;
741
879
@@ -767,6 +905,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
767
905
768
906
let tee = ctx. new_class ( "tee" , ctx. object ( ) ) ;
769
907
PyItertoolsTee :: extend_class ( ctx, & tee) ;
908
+ let product = ctx. new_class ( "product" , ctx. object ( ) ) ;
909
+ PyIterToolsProduct :: extend_class ( ctx, & product) ;
770
910
771
911
py_module ! ( vm, "itertools" , {
772
912
"chain" => chain,
@@ -780,5 +920,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
780
920
"filterfalse" => filterfalse,
781
921
"accumulate" => accumulate,
782
922
"tee" => tee,
923
+ "product" => product,
783
924
} )
784
925
}
0 commit comments