@@ -73,6 +73,59 @@ impl PyItertoolsChain {
73
73
}
74
74
}
75
75
76
+ #[ pyclass( name = "compress" ) ]
77
+ #[ derive( Debug ) ]
78
+ struct PyItertoolsCompress {
79
+ data : PyObjectRef ,
80
+ selector : PyObjectRef ,
81
+ }
82
+
83
+ impl PyValue for PyItertoolsCompress {
84
+ fn class ( vm : & VirtualMachine ) -> PyClassRef {
85
+ vm. class ( "itertools" , "compress" )
86
+ }
87
+ }
88
+
89
+ #[ pyimpl]
90
+ impl PyItertoolsCompress {
91
+ #[ pymethod( name = "__new__" ) ]
92
+ #[ allow( clippy:: new_ret_no_self) ]
93
+ fn new (
94
+ _cls : PyClassRef ,
95
+ data : PyObjectRef ,
96
+ selector : PyObjectRef ,
97
+ vm : & VirtualMachine ,
98
+ ) -> PyResult {
99
+ let data_iter = get_iter ( vm, & data) ?;
100
+ let selector_iter = get_iter ( vm, & selector) ?;
101
+
102
+ Ok ( PyItertoolsCompress {
103
+ data : data_iter,
104
+ selector : selector_iter,
105
+ }
106
+ . into_ref ( vm)
107
+ . into_object ( ) )
108
+ }
109
+
110
+ #[ pymethod( name = "__next__" ) ]
111
+ fn next ( & self , vm : & VirtualMachine ) -> PyResult {
112
+ loop {
113
+ let sel_obj = call_next ( vm, & self . selector ) ?;
114
+ let verdict = objbool:: boolval ( vm, sel_obj. clone ( ) ) ?;
115
+ let data_obj = call_next ( vm, & self . data ) ?;
116
+
117
+ if verdict {
118
+ return Ok ( data_obj) ;
119
+ }
120
+ }
121
+ }
122
+
123
+ #[ pymethod( name = "__iter__" ) ]
124
+ fn iter ( zelf : PyRef < Self > , _vm : & VirtualMachine ) -> PyRef < Self > {
125
+ zelf
126
+ }
127
+ }
128
+
76
129
#[ pyclass]
77
130
#[ derive( Debug ) ]
78
131
struct PyItertoolsCount {
@@ -577,8 +630,8 @@ impl PyItertoolsAccumulate {
577
630
let obj = call_next ( vm, iterable) ?;
578
631
579
632
let next_acc_value = match & * self . acc_value . borrow ( ) {
580
- Option :: None => obj. clone ( ) ,
581
- Option :: Some ( value) => {
633
+ None => obj. clone ( ) ,
634
+ Some ( value) => {
582
635
if self . binop . is ( & vm. get_none ( ) ) {
583
636
vm. _add ( value. clone ( ) , obj. clone ( ) ) ?
584
637
} else {
@@ -602,6 +655,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
602
655
603
656
let chain = PyItertoolsChain :: make_class ( ctx) ;
604
657
658
+ let compress = PyItertoolsCompress :: make_class ( ctx) ;
659
+
605
660
let count = ctx. new_class ( "count" , ctx. object ( ) ) ;
606
661
PyItertoolsCount :: extend_class ( ctx, & count) ;
607
662
@@ -626,6 +681,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
626
681
627
682
py_module ! ( vm, "itertools" , {
628
683
"chain" => chain,
684
+ "compress" => compress,
629
685
"count" => count,
630
686
"dropwhile" => dropwhile,
631
687
"repeat" => repeat,
0 commit comments