Skip to content

Commit 2511b2f

Browse files
committed
Make some of objlist use atomic types
1 parent c82b986 commit 2511b2f

File tree

2 files changed

+24
-17
lines changed

2 files changed

+24
-17
lines changed

vm/src/obj/objlist.rs

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
use std::cell::{Cell, RefCell};
1+
use std::cell::RefCell;
22
use std::fmt;
33
use std::mem::size_of;
44
use std::ops::Range;
55

6+
use crossbeam_utils::atomic::AtomicCell;
67
use num_bigint::{BigInt, ToBigInt};
78
use num_traits::{One, Signed, ToPrimitive, Zero};
89

@@ -28,6 +29,7 @@ use crate::vm::{ReprGuard, VirtualMachine};
2829
#[pyclass]
2930
#[derive(Default)]
3031
pub struct PyList {
32+
// TODO: make this a RwLock at the same time as PyObjectRef is Send + Sync
3133
elements: RefCell<Vec<PyObjectRef>>,
3234
}
3335

@@ -234,7 +236,7 @@ impl PyList {
234236
fn reversed(zelf: PyRef<Self>) -> PyListReverseIterator {
235237
let final_position = zelf.elements.borrow().len();
236238
PyListReverseIterator {
237-
position: Cell::new(final_position),
239+
position: AtomicCell::new(final_position),
238240
list: zelf,
239241
}
240242
}
@@ -252,7 +254,7 @@ impl PyList {
252254
#[pymethod(name = "__iter__")]
253255
fn iter(zelf: PyRef<Self>) -> PyListIterator {
254256
PyListIterator {
255-
position: Cell::new(0),
257+
position: AtomicCell::new(0),
256258
list: zelf,
257259
}
258260
}
@@ -844,7 +846,7 @@ fn do_sort(
844846
#[pyclass]
845847
#[derive(Debug)]
846848
pub struct PyListIterator {
847-
pub position: Cell<usize>,
849+
pub position: AtomicCell<usize>,
848850
pub list: PyListRef,
849851
}
850852

@@ -858,10 +860,11 @@ impl PyValue for PyListIterator {
858860
impl PyListIterator {
859861
#[pymethod(name = "__next__")]
860862
fn next(&self, vm: &VirtualMachine) -> PyResult {
861-
if self.position.get() < self.list.elements.borrow().len() {
862-
let ret = self.list.elements.borrow()[self.position.get()].clone();
863-
self.position.set(self.position.get() + 1);
864-
Ok(ret)
863+
let list = self.list.elements.borrow();
864+
let pos = self.position.load();
865+
if let Some(obj) = list.get(pos) {
866+
self.position.store(pos + 1);
867+
Ok(obj.clone())
865868
} else {
866869
Err(objiter::new_stop_iteration(vm))
867870
}
@@ -874,14 +877,16 @@ impl PyListIterator {
874877

875878
#[pymethod(name = "__length_hint__")]
876879
fn length_hint(&self) -> usize {
877-
self.list.elements.borrow().len() - self.position.get()
880+
let list = self.list.elements.borrow();
881+
let pos = self.position.load();
882+
list.len() - pos
878883
}
879884
}
880885

881886
#[pyclass]
882887
#[derive(Debug)]
883888
pub struct PyListReverseIterator {
884-
pub position: Cell<usize>,
889+
pub position: AtomicCell<usize>,
885890
pub list: PyListRef,
886891
}
887892

@@ -895,10 +900,12 @@ impl PyValue for PyListReverseIterator {
895900
impl PyListReverseIterator {
896901
#[pymethod(name = "__next__")]
897902
fn next(&self, vm: &VirtualMachine) -> PyResult {
898-
if self.position.get() > 0 {
899-
let position: usize = self.position.get() - 1;
900-
let ret = self.list.elements.borrow()[position].clone();
901-
self.position.set(position);
903+
let pos = self.position.load();
904+
if pos > 0 {
905+
let pos = pos - 1;
906+
let list = self.list.elements.borrow();
907+
let ret = list[pos].clone();
908+
self.position.store(pos);
902909
Ok(ret)
903910
} else {
904911
Err(objiter::new_stop_iteration(vm))
@@ -912,7 +919,7 @@ impl PyListReverseIterator {
912919

913920
#[pymethod(name = "__length_hint__")]
914921
fn length_hint(&self) -> usize {
915-
self.position.get()
922+
self.position.load()
916923
}
917924
}
918925

vm/src/obj/objset.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
* Builtin set type with a sequence of unique items.
33
*/
44

5-
use std::cell::{Cell, RefCell};
5+
use std::cell::RefCell;
66
use std::fmt;
77

88
use super::objlist::PyListIterator;
@@ -244,7 +244,7 @@ impl PySetInner {
244244
let items = self.content.keys().collect();
245245
let set_list = vm.ctx.new_list(items);
246246
PyListIterator {
247-
position: Cell::new(0),
247+
position: crossbeam_utils::atomic::AtomicCell::new(0),
248248
list: set_list.downcast().unwrap(),
249249
}
250250
}

0 commit comments

Comments
 (0)