Skip to content

Commit a0cd62a

Browse files
committed
Make objiter use atomic types
1 parent 2511b2f commit a0cd62a

File tree

3 files changed

+33
-32
lines changed

3 files changed

+33
-32
lines changed

vm/src/builtins.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
//!
33
//! Implements functions listed here: https://docs.python.org/3/library/builtins.html
44
5-
use std::cell::Cell;
65
use std::char;
76
use std::str;
87

@@ -672,12 +671,9 @@ fn builtin_reversed(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult {
672671
vm.get_method_or_type_error(obj.clone(), "__getitem__", || {
673672
"argument to reversed() must be a sequence".to_owned()
674673
})?;
675-
let len = vm.call_method(&obj.clone(), "__len__", PyFuncArgs::default())?;
676-
let obj_iterator = objiter::PySequenceIterator {
677-
position: Cell::new(objint::get_value(&len).to_isize().unwrap() - 1),
678-
obj: obj.clone(),
679-
reversed: true,
680-
};
674+
let len = vm.call_method(&obj, "__len__", PyFuncArgs::default())?;
675+
let len = objint::get_value(&len).to_isize().unwrap();
676+
let obj_iterator = objiter::PySequenceIterator::new_reversed(obj, len);
681677
Ok(obj_iterator.into_ref(vm).into_object())
682678
}
683679
}

vm/src/obj/objiter.rs

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
* Various types to support iteration.
33
*/
44

5+
use crossbeam_utils::atomic::AtomicCell;
56
use num_traits::{Signed, ToPrimitive};
6-
use std::cell::Cell;
77

88
use super::objint::PyInt;
99
use super::objsequence;
@@ -28,12 +28,9 @@ pub fn get_iter(vm: &VirtualMachine, iter_target: &PyObjectRef) -> PyResult {
2828
vm.get_method_or_type_error(iter_target.clone(), "__getitem__", || {
2929
format!("Cannot iterate over {}", iter_target.class().name)
3030
})?;
31-
let obj_iterator = PySequenceIterator {
32-
position: Cell::new(0),
33-
obj: iter_target.clone(),
34-
reversed: false,
35-
};
36-
Ok(obj_iterator.into_ref(vm).into_object())
31+
Ok(PySequenceIterator::new_forward(iter_target.clone())
32+
.into_ref(vm)
33+
.into_object())
3734
}
3835
}
3936

@@ -140,7 +137,7 @@ pub fn length_hint(vm: &VirtualMachine, iter: PyObjectRef) -> PyResult<Option<us
140137
#[pyclass]
141138
#[derive(Debug)]
142139
pub struct PySequenceIterator {
143-
pub position: Cell<isize>,
140+
pub position: AtomicCell<isize>,
144141
pub obj: PyObjectRef,
145142
pub reversed: bool,
146143
}
@@ -153,14 +150,31 @@ impl PyValue for PySequenceIterator {
153150

154151
#[pyimpl]
155152
impl PySequenceIterator {
153+
pub fn new_forward(obj: PyObjectRef) -> Self {
154+
Self {
155+
position: AtomicCell::new(0),
156+
obj,
157+
reversed: false,
158+
}
159+
}
160+
161+
pub fn new_reversed(obj: PyObjectRef, len: isize) -> Self {
162+
Self {
163+
position: AtomicCell::new(len - 1),
164+
obj,
165+
reversed: true,
166+
}
167+
}
168+
156169
#[pymethod(name = "__next__")]
157170
fn next(&self, vm: &VirtualMachine) -> PyResult {
158-
if self.position.get() >= 0 {
171+
let pos = self.position.load();
172+
if pos >= 0 {
159173
let step: isize = if self.reversed { -1 } else { 1 };
160-
let number = vm.ctx.new_int(self.position.get());
174+
let number = vm.ctx.new_int(pos);
161175
match vm.call_method(&self.obj, "__getitem__", vec![number]) {
162176
Ok(val) => {
163-
self.position.set(self.position.get() + step);
177+
self.position.store(pos + step);
164178
Ok(val)
165179
}
166180
Err(ref e) if objtype::isinstance(&e, &vm.ctx.exceptions.index_error) => {
@@ -181,7 +195,7 @@ impl PySequenceIterator {
181195

182196
#[pymethod(name = "__length_hint__")]
183197
fn length_hint(&self, vm: &VirtualMachine) -> PyResult<isize> {
184-
let pos = self.position.get();
198+
let pos = self.position.load();
185199
let hint = if self.reversed {
186200
pos + 1
187201
} else {
@@ -195,11 +209,7 @@ impl PySequenceIterator {
195209
}
196210

197211
pub fn seq_iter_method(obj: PyObjectRef) -> PySequenceIterator {
198-
PySequenceIterator {
199-
position: Cell::new(0),
200-
obj,
201-
reversed: false,
202-
}
212+
PySequenceIterator::new_forward(obj)
203213
}
204214

205215
#[pyclass(name = "callable_iterator")]

vm/src/pyobject.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use std::any::Any;
2-
use std::cell::Cell;
32
use std::collections::HashMap;
43
use std::fmt;
54
use std::marker::PhantomData;
@@ -1012,13 +1011,9 @@ where
10121011
})?;
10131012
Self::try_from_object(
10141013
vm,
1015-
objiter::PySequenceIterator {
1016-
position: Cell::new(0),
1017-
obj: obj.clone(),
1018-
reversed: false,
1019-
}
1020-
.into_ref(vm)
1021-
.into_object(),
1014+
objiter::PySequenceIterator::new_forward(obj)
1015+
.into_ref(vm)
1016+
.into_object(),
10221017
)
10231018
}
10241019
}

0 commit comments

Comments
 (0)