@@ -5,6 +5,7 @@ use std::ops::{AddAssign, SubAssign};
5
5
use std:: rc:: Rc ;
6
6
7
7
use num_bigint:: BigInt ;
8
+ use num_traits:: sign:: Signed ;
8
9
use num_traits:: ToPrimitive ;
9
10
10
11
use crate :: function:: { Args , OptionalArg , PyFuncArgs } ;
@@ -733,14 +734,14 @@ impl PyItertoolsTee {
733
734
734
735
#[ pyclass]
735
736
#[ derive( Debug ) ]
736
- struct PyIterToolsProduct {
737
+ struct PyItertoolsProduct {
737
738
pools : Vec < Vec < PyObjectRef > > ,
738
739
idxs : RefCell < Vec < usize > > ,
739
740
cur : Cell < usize > ,
740
741
stop : Cell < bool > ,
741
742
}
742
743
743
- impl PyValue for PyIterToolsProduct {
744
+ impl PyValue for PyItertoolsProduct {
744
745
fn class ( vm : & VirtualMachine ) -> PyClassRef {
745
746
vm. class ( "itertools" , "product" )
746
747
}
@@ -753,7 +754,7 @@ struct ProductArgs {
753
754
}
754
755
755
756
#[ pyimpl]
756
- impl PyIterToolsProduct {
757
+ impl PyItertoolsProduct {
757
758
#[ pyslot( new) ]
758
759
fn tp_new (
759
760
cls : PyClassRef ,
@@ -780,7 +781,7 @@ impl PyIterToolsProduct {
780
781
781
782
let l = pools. len ( ) ;
782
783
783
- PyIterToolsProduct {
784
+ PyItertoolsProduct {
784
785
pools,
785
786
idxs : RefCell :: new ( vec ! [ 0 ; l] ) ,
786
787
cur : Cell :: new ( l - 1 ) ,
@@ -848,19 +849,137 @@ impl PyIterToolsProduct {
848
849
}
849
850
}
850
851
852
+ #[ pyclass]
853
+ #[ derive( Debug ) ]
854
+ struct PyItertoolsCombinations {
855
+ pool : Vec < PyObjectRef > ,
856
+ indices : RefCell < Vec < usize > > ,
857
+ r : Cell < usize > ,
858
+ exhausted : Cell < bool > ,
859
+ }
860
+
861
+ impl PyValue for PyItertoolsCombinations {
862
+ fn class ( vm : & VirtualMachine ) -> PyClassRef {
863
+ vm. class ( "itertools" , "combinations" )
864
+ }
865
+ }
866
+
867
+ #[ pyimpl]
868
+ impl PyItertoolsCombinations {
869
+ #[ pyslot( new) ]
870
+ fn tp_new (
871
+ cls : PyClassRef ,
872
+ iterable : PyObjectRef ,
873
+ r : PyIntRef ,
874
+ vm : & VirtualMachine ,
875
+ ) -> PyResult < PyRef < Self > > {
876
+ let iter = get_iter ( vm, & iterable) ?;
877
+ let pool = get_all ( vm, & iter) ?;
878
+
879
+ let r = r. as_bigint ( ) ;
880
+ if r. is_negative ( ) {
881
+ return Err ( vm. new_value_error ( "r must be non-negative" . to_string ( ) ) ) ;
882
+ }
883
+ let r = r. to_usize ( ) . unwrap ( ) ;
884
+
885
+ let n = pool. len ( ) ;
886
+
887
+ PyItertoolsCombinations {
888
+ pool,
889
+ indices : RefCell :: new ( ( 0 ..r) . collect ( ) ) ,
890
+ r : Cell :: new ( r) ,
891
+ exhausted : Cell :: new ( r > n) ,
892
+ }
893
+ . into_ref_with_type ( vm, cls)
894
+ }
895
+
896
+ #[ pymethod( name = "__iter__" ) ]
897
+ fn iter ( zelf : PyRef < Self > , _vm : & VirtualMachine ) -> PyRef < Self > {
898
+ zelf
899
+ }
900
+
901
+ #[ pymethod( name = "__next__" ) ]
902
+ fn next ( & self , vm : & VirtualMachine ) -> PyResult {
903
+ // stop signal
904
+ if self . exhausted . get ( ) {
905
+ return Err ( new_stop_iteration ( vm) ) ;
906
+ }
907
+
908
+ let n = self . pool . len ( ) ;
909
+ let r = self . r . get ( ) ;
910
+
911
+ let res = PyTuple :: from (
912
+ self . pool
913
+ . iter ( )
914
+ . enumerate ( )
915
+ . filter ( |( idx, _) | self . indices . borrow ( ) . contains ( & idx) )
916
+ . map ( |( _, num) | num. clone ( ) )
917
+ . collect :: < Vec < PyObjectRef > > ( ) ,
918
+ ) ;
919
+
920
+ let mut indices = self . indices . borrow_mut ( ) ;
921
+ let mut sentinel = false ;
922
+
923
+ // Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
924
+ let mut idx = r - 1 ;
925
+ loop {
926
+ if indices[ idx] != idx + n - r {
927
+ sentinel = true ;
928
+ break ;
929
+ }
930
+
931
+ if idx != 0 {
932
+ idx -= 1 ;
933
+ } else {
934
+ break ;
935
+ }
936
+ }
937
+ // If no suitable index is found, then the indices are all at
938
+ // their maximum value and we're done.
939
+ if !sentinel {
940
+ self . exhausted . set ( true ) ;
941
+ }
942
+
943
+ // Increment the current index which we know is not at its
944
+ // maximum. Then move back to the right setting each index
945
+ // to its lowest possible value (one higher than the index
946
+ // to its left -- this maintains the sort order invariant).
947
+ indices[ idx] += 1 ;
948
+ for j in idx + 1 ..r {
949
+ indices[ j] = indices[ j - 1 ] + 1 ;
950
+ }
951
+
952
+ Ok ( res. into_ref ( vm) . into_object ( ) )
953
+ }
954
+ }
955
+
851
956
pub fn make_module ( vm : & VirtualMachine ) -> PyObjectRef {
852
957
let ctx = & vm. ctx ;
853
958
959
+ let accumulate = ctx. new_class ( "accumulate" , ctx. object ( ) ) ;
960
+ PyItertoolsAccumulate :: extend_class ( ctx, & accumulate) ;
961
+
854
962
let chain = PyItertoolsChain :: make_class ( ctx) ;
855
963
856
964
let compress = PyItertoolsCompress :: make_class ( ctx) ;
857
965
966
+ let combinations = ctx. new_class ( "combinations" , ctx. object ( ) ) ;
967
+ PyItertoolsCombinations :: extend_class ( ctx, & combinations) ;
968
+
858
969
let count = ctx. new_class ( "count" , ctx. object ( ) ) ;
859
970
PyItertoolsCount :: extend_class ( ctx, & count) ;
860
971
861
972
let dropwhile = ctx. new_class ( "dropwhile" , ctx. object ( ) ) ;
862
973
PyItertoolsDropwhile :: extend_class ( ctx, & dropwhile) ;
863
974
975
+ let islice = PyItertoolsIslice :: make_class ( ctx) ;
976
+
977
+ let filterfalse = ctx. new_class ( "filterfalse" , ctx. object ( ) ) ;
978
+ PyItertoolsFilterFalse :: extend_class ( ctx, & filterfalse) ;
979
+
980
+ let product = ctx. new_class ( "product" , ctx. object ( ) ) ;
981
+ PyItertoolsProduct :: extend_class ( ctx, & product) ;
982
+
864
983
let repeat = ctx. new_class ( "repeat" , ctx. object ( ) ) ;
865
984
PyItertoolsRepeat :: extend_class ( ctx, & repeat) ;
866
985
@@ -869,30 +988,21 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
869
988
let takewhile = ctx. new_class ( "takewhile" , ctx. object ( ) ) ;
870
989
PyItertoolsTakewhile :: extend_class ( ctx, & takewhile) ;
871
990
872
- let islice = PyItertoolsIslice :: make_class ( ctx) ;
873
-
874
- let filterfalse = ctx. new_class ( "filterfalse" , ctx. object ( ) ) ;
875
- PyItertoolsFilterFalse :: extend_class ( ctx, & filterfalse) ;
876
-
877
- let accumulate = ctx. new_class ( "accumulate" , ctx. object ( ) ) ;
878
- PyItertoolsAccumulate :: extend_class ( ctx, & accumulate) ;
879
-
880
991
let tee = ctx. new_class ( "tee" , ctx. object ( ) ) ;
881
992
PyItertoolsTee :: extend_class ( ctx, & tee) ;
882
- let product = ctx. new_class ( "product" , ctx. object ( ) ) ;
883
- PyIterToolsProduct :: extend_class ( ctx, & product) ;
884
993
885
994
py_module ! ( vm, "itertools" , {
995
+ "accumulate" => accumulate,
886
996
"chain" => chain,
887
997
"compress" => compress,
998
+ "combinations" => combinations,
888
999
"count" => count,
889
1000
"dropwhile" => dropwhile,
1001
+ "islice" => islice,
1002
+ "filterfalse" => filterfalse,
890
1003
"repeat" => repeat,
891
1004
"starmap" => starmap,
892
1005
"takewhile" => takewhile,
893
- "islice" => islice,
894
- "filterfalse" => filterfalse,
895
- "accumulate" => accumulate,
896
1006
"tee" => tee,
897
1007
"product" => product,
898
1008
} )
0 commit comments