2
2
* Various types to support iteration.
3
3
*/
4
4
5
+ use crossbeam_utils:: atomic:: AtomicCell ;
5
6
use num_traits:: { Signed , ToPrimitive } ;
6
- use std:: cell:: Cell ;
7
7
8
8
use super :: objint:: PyInt ;
9
9
use super :: objsequence;
@@ -28,12 +28,9 @@ pub fn get_iter(vm: &VirtualMachine, iter_target: &PyObjectRef) -> PyResult {
28
28
vm. get_method_or_type_error ( iter_target. clone ( ) , "__getitem__" , || {
29
29
format ! ( "Cannot iterate over {}" , iter_target. class( ) . name)
30
30
} ) ?;
31
- let obj_iterator = PySequenceIterator {
32
- position : Cell :: new ( 0 ) ,
33
- obj : iter_target. clone ( ) ,
34
- reversed : false ,
35
- } ;
36
- Ok ( obj_iterator. into_ref ( vm) . into_object ( ) )
31
+ Ok ( PySequenceIterator :: new_forward ( iter_target. clone ( ) )
32
+ . into_ref ( vm)
33
+ . into_object ( ) )
37
34
}
38
35
}
39
36
@@ -140,7 +137,7 @@ pub fn length_hint(vm: &VirtualMachine, iter: PyObjectRef) -> PyResult<Option<us
140
137
#[ pyclass]
141
138
#[ derive( Debug ) ]
142
139
pub struct PySequenceIterator {
143
- pub position : Cell < isize > ,
140
+ pub position : AtomicCell < isize > ,
144
141
pub obj : PyObjectRef ,
145
142
pub reversed : bool ,
146
143
}
@@ -153,14 +150,31 @@ impl PyValue for PySequenceIterator {
153
150
154
151
#[ pyimpl]
155
152
impl PySequenceIterator {
153
+ pub fn new_forward ( obj : PyObjectRef ) -> Self {
154
+ Self {
155
+ position : AtomicCell :: new ( 0 ) ,
156
+ obj,
157
+ reversed : false ,
158
+ }
159
+ }
160
+
161
+ pub fn new_reversed ( obj : PyObjectRef , len : isize ) -> Self {
162
+ Self {
163
+ position : AtomicCell :: new ( len - 1 ) ,
164
+ obj,
165
+ reversed : true ,
166
+ }
167
+ }
168
+
156
169
#[ pymethod( name = "__next__" ) ]
157
170
fn next ( & self , vm : & VirtualMachine ) -> PyResult {
158
- if self . position . get ( ) >= 0 {
171
+ let pos = self . position . load ( ) ;
172
+ if pos >= 0 {
159
173
let step: isize = if self . reversed { -1 } else { 1 } ;
160
- let number = vm. ctx . new_int ( self . position . get ( ) ) ;
174
+ let number = vm. ctx . new_int ( pos ) ;
161
175
match vm. call_method ( & self . obj , "__getitem__" , vec ! [ number] ) {
162
176
Ok ( val) => {
163
- self . position . set ( self . position . get ( ) + step) ;
177
+ self . position . store ( pos + step) ;
164
178
Ok ( val)
165
179
}
166
180
Err ( ref e) if objtype:: isinstance ( & e, & vm. ctx . exceptions . index_error ) => {
@@ -181,7 +195,7 @@ impl PySequenceIterator {
181
195
182
196
#[ pymethod( name = "__length_hint__" ) ]
183
197
fn length_hint ( & self , vm : & VirtualMachine ) -> PyResult < isize > {
184
- let pos = self . position . get ( ) ;
198
+ let pos = self . position . load ( ) ;
185
199
let hint = if self . reversed {
186
200
pos + 1
187
201
} else {
@@ -195,11 +209,7 @@ impl PySequenceIterator {
195
209
}
196
210
197
211
pub fn seq_iter_method ( obj : PyObjectRef ) -> PySequenceIterator {
198
- PySequenceIterator {
199
- position : Cell :: new ( 0 ) ,
200
- obj,
201
- reversed : false ,
202
- }
212
+ PySequenceIterator :: new_forward ( obj)
203
213
}
204
214
205
215
#[ pyclass( name = "callable_iterator" ) ]
0 commit comments