@@ -848,6 +848,108 @@ impl PyItertoolsProduct {
848
848
}
849
849
}
850
850
851
+ #[ pyclass]
852
+ #[ derive( Debug ) ]
853
+ struct PyItertoolsCombinations {
854
+ pool : Vec < PyObjectRef > ,
855
+ indices : RefCell < Vec < usize > > ,
856
+ r : Cell < usize > ,
857
+ exhausted : Cell < bool > ,
858
+ }
859
+
860
+ impl PyValue for PyItertoolsCombinations {
861
+ fn class ( vm : & VirtualMachine ) -> PyClassRef {
862
+ vm. class ( "itertools" , "combinations" )
863
+ }
864
+ }
865
+
866
+ #[ pyimpl]
867
+ impl PyItertoolsCombinations {
868
+ #[ pyslot( new) ]
869
+ fn tp_new (
870
+ cls : PyClassRef ,
871
+ iterable : PyObjectRef ,
872
+ r : isize ,
873
+ vm : & VirtualMachine ,
874
+ ) -> PyResult < PyRef < Self > > {
875
+ let iter = get_iter ( vm, & iterable) ?;
876
+ let pool = get_all ( vm, & iter) ?;
877
+
878
+ if r < 0 {
879
+ return Err ( vm. new_value_error ( "r must be non-negative" . to_string ( ) ) ) ;
880
+ }
881
+
882
+ let n = pool. len ( ) ;
883
+
884
+ PyItertoolsCombinations {
885
+ pool,
886
+ indices : RefCell :: new ( ( 0 ..r as usize ) . collect ( ) ) ,
887
+ r : Cell :: new ( r as usize ) ,
888
+ exhausted : Cell :: new ( r as usize > n) ,
889
+ }
890
+ . into_ref_with_type ( vm, cls)
891
+ }
892
+
893
+ #[ pymethod( name = "__iter__" ) ]
894
+ fn iter ( zelf : PyRef < Self > , _vm : & VirtualMachine ) -> PyRef < Self > {
895
+ zelf
896
+ }
897
+
898
+ #[ pymethod( name = "__next__" ) ]
899
+ fn next ( & self , vm : & VirtualMachine ) -> PyResult {
900
+ // stop signal
901
+ if self . exhausted . get ( ) {
902
+ return Err ( new_stop_iteration ( vm) ) ;
903
+ }
904
+
905
+ let n = self . pool . len ( ) ;
906
+ let r = self . r . get ( ) ;
907
+
908
+ let res = PyTuple :: from (
909
+ self . pool
910
+ . iter ( )
911
+ . enumerate ( )
912
+ . filter ( |( idx, _) | self . indices . borrow ( ) . contains ( & idx) )
913
+ . map ( |( _, num) | num. clone ( ) )
914
+ . collect :: < Vec < PyObjectRef > > ( ) ,
915
+ ) ;
916
+
917
+ let mut indices = self . indices . borrow_mut ( ) ;
918
+ let mut sentinel = false ;
919
+
920
+ // Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
921
+ let mut idx = r - 1 ;
922
+ loop {
923
+ if indices[ idx] != idx + n - r {
924
+ sentinel = true ;
925
+ break ;
926
+ }
927
+
928
+ if idx != 0 {
929
+ idx -= 1 ;
930
+ } else {
931
+ break ;
932
+ }
933
+ }
934
+ // If no suitable index is found, then the indices are all at
935
+ // their maximum value and we're done.
936
+ if !sentinel {
937
+ self . exhausted . set ( true ) ;
938
+ }
939
+
940
+ // Increment the current index which we know is not at its
941
+ // maximum. Then move back to the right setting each index
942
+ // to its lowest possible value (one higher than the index
943
+ // to its left -- this maintains the sort order invariant).
944
+ indices[ idx] += 1 ;
945
+ for j in idx + 1 ..r {
946
+ indices[ j] = indices[ j - 1 ] + 1 ;
947
+ }
948
+
949
+ Ok ( res. into_ref ( vm) . into_object ( ) )
950
+ }
951
+ }
952
+
851
953
pub fn make_module ( vm : & VirtualMachine ) -> PyObjectRef {
852
954
let ctx = & vm. ctx ;
853
955
@@ -858,6 +960,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
858
960
859
961
let compress = PyItertoolsCompress :: make_class ( ctx) ;
860
962
963
+ let combinations = ctx. new_class ( "combinations" , ctx. object ( ) ) ;
964
+ PyItertoolsCombinations :: extend_class ( ctx, & combinations) ;
965
+
861
966
let count = ctx. new_class ( "count" , ctx. object ( ) ) ;
862
967
PyItertoolsCount :: extend_class ( ctx, & count) ;
863
968
@@ -887,6 +992,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
887
992
"accumulate" => accumulate,
888
993
"chain" => chain,
889
994
"compress" => compress,
995
+ "combinations" => combinations,
890
996
"count" => count,
891
997
"dropwhile" => dropwhile,
892
998
"islice" => islice,
0 commit comments