1
- use std:: cell:: { Cell , RefCell } ;
1
+ use std:: cell:: RefCell ;
2
2
use std:: fmt;
3
3
use std:: mem:: size_of;
4
4
use std:: ops:: Range ;
5
5
6
+ use crossbeam_utils:: atomic:: AtomicCell ;
6
7
use num_bigint:: { BigInt , ToBigInt } ;
7
8
use num_traits:: { One , Signed , ToPrimitive , Zero } ;
8
9
@@ -28,6 +29,7 @@ use crate::vm::{ReprGuard, VirtualMachine};
28
29
#[ pyclass]
29
30
#[ derive( Default ) ]
30
31
pub struct PyList {
32
+ // TODO: make this a RwLock at the same time as PyObjectRef is Send + Sync
31
33
elements : RefCell < Vec < PyObjectRef > > ,
32
34
}
33
35
@@ -234,7 +236,7 @@ impl PyList {
234
236
fn reversed ( zelf : PyRef < Self > ) -> PyListReverseIterator {
235
237
let final_position = zelf. elements . borrow ( ) . len ( ) ;
236
238
PyListReverseIterator {
237
- position : Cell :: new ( final_position) ,
239
+ position : AtomicCell :: new ( final_position) ,
238
240
list : zelf,
239
241
}
240
242
}
@@ -252,7 +254,7 @@ impl PyList {
252
254
#[ pymethod( name = "__iter__" ) ]
253
255
fn iter ( zelf : PyRef < Self > ) -> PyListIterator {
254
256
PyListIterator {
255
- position : Cell :: new ( 0 ) ,
257
+ position : AtomicCell :: new ( 0 ) ,
256
258
list : zelf,
257
259
}
258
260
}
@@ -844,7 +846,7 @@ fn do_sort(
844
846
#[ pyclass]
845
847
#[ derive( Debug ) ]
846
848
pub struct PyListIterator {
847
- pub position : Cell < usize > ,
849
+ pub position : AtomicCell < usize > ,
848
850
pub list : PyListRef ,
849
851
}
850
852
@@ -858,10 +860,11 @@ impl PyValue for PyListIterator {
858
860
impl PyListIterator {
859
861
#[ pymethod( name = "__next__" ) ]
860
862
fn next ( & self , vm : & VirtualMachine ) -> PyResult {
861
- if self . position . get ( ) < self . list . elements . borrow ( ) . len ( ) {
862
- let ret = self . list . elements . borrow ( ) [ self . position . get ( ) ] . clone ( ) ;
863
- self . position . set ( self . position . get ( ) + 1 ) ;
864
- Ok ( ret)
863
+ let list = self . list . elements . borrow ( ) ;
864
+ let pos = self . position . load ( ) ;
865
+ if let Some ( obj) = list. get ( pos) {
866
+ self . position . store ( pos + 1 ) ;
867
+ Ok ( obj. clone ( ) )
865
868
} else {
866
869
Err ( objiter:: new_stop_iteration ( vm) )
867
870
}
@@ -874,14 +877,16 @@ impl PyListIterator {
874
877
875
878
#[ pymethod( name = "__length_hint__" ) ]
876
879
fn length_hint ( & self ) -> usize {
877
- self . list . elements . borrow ( ) . len ( ) - self . position . get ( )
880
+ let list = self . list . elements . borrow ( ) ;
881
+ let pos = self . position . load ( ) ;
882
+ list. len ( ) - pos
878
883
}
879
884
}
880
885
881
886
#[ pyclass]
882
887
#[ derive( Debug ) ]
883
888
pub struct PyListReverseIterator {
884
- pub position : Cell < usize > ,
889
+ pub position : AtomicCell < usize > ,
885
890
pub list : PyListRef ,
886
891
}
887
892
@@ -895,10 +900,12 @@ impl PyValue for PyListReverseIterator {
895
900
impl PyListReverseIterator {
896
901
#[ pymethod( name = "__next__" ) ]
897
902
fn next ( & self , vm : & VirtualMachine ) -> PyResult {
898
- if self . position . get ( ) > 0 {
899
- let position: usize = self . position . get ( ) - 1 ;
900
- let ret = self . list . elements . borrow ( ) [ position] . clone ( ) ;
901
- self . position . set ( position) ;
903
+ let pos = self . position . load ( ) ;
904
+ if pos > 0 {
905
+ let pos = pos - 1 ;
906
+ let list = self . list . elements . borrow ( ) ;
907
+ let ret = list[ pos] . clone ( ) ;
908
+ self . position . store ( pos) ;
902
909
Ok ( ret)
903
910
} else {
904
911
Err ( objiter:: new_stop_iteration ( vm) )
@@ -912,7 +919,7 @@ impl PyListReverseIterator {
912
919
913
920
#[ pymethod( name = "__length_hint__" ) ]
914
921
fn length_hint ( & self ) -> usize {
915
- self . position . get ( )
922
+ self . position . load ( )
916
923
}
917
924
}
918
925
0 commit comments