Skip to content

Convert stdlib object to thread safe #1919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Lib/test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,6 @@ def test_iterator_pickle(self):
a.fromlist(data2)
self.assertEqual(list(it), [])

# TODO: RUSTPYTHON
@unittest.expectedFailure
Copy link
Member

@coolreader18 coolreader18 May 9, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's interesting how this "overflow" behavior is actually correct with how iterators work in CPython; we removed the expectedFailures from the same kind of tests with the other iterators when we converted them to use fetch_{add,sub}

def test_exhausted_iterator(self):
a = array.array(self.typecode, self.example)
self.assertEqual(list(a), list(self.example))
Expand Down
20 changes: 9 additions & 11 deletions vm/src/stdlib/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ use crate::pyobject::{
};
use crate::VirtualMachine;

use std::cell::Cell;
use std::fmt;
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};

use crossbeam_utils::atomic::AtomicCell;

struct ArrayTypeSpecifierError {
_priv: (),
}
Expand Down Expand Up @@ -421,7 +422,7 @@ impl PyArray {
#[pymethod(name = "__iter__")]
fn iter(zelf: PyRef<Self>) -> PyArrayIter {
PyArrayIter {
position: Cell::new(0),
position: AtomicCell::new(0),
array: zelf,
}
}
Expand All @@ -430,10 +431,12 @@ impl PyArray {
#[pyclass]
#[derive(Debug)]
pub struct PyArrayIter {
position: Cell<usize>,
position: AtomicCell<usize>,
array: PyArrayRef,
}

impl ThreadSafe for PyArrayIter {}

impl PyValue for PyArrayIter {
fn class(vm: &VirtualMachine) -> PyClassRef {
vm.class("array", "arrayiterator")
Expand All @@ -444,14 +447,9 @@ impl PyValue for PyArrayIter {
impl PyArrayIter {
#[pymethod(name = "__next__")]
fn next(&self, vm: &VirtualMachine) -> PyResult {
if self.position.get() < self.array.borrow_value().len() {
let ret = self
.array
.borrow_value()
.getitem_by_idx(self.position.get(), vm)
.unwrap()?;
self.position.set(self.position.get() + 1);
Ok(ret)
let pos = self.position.fetch_add(1);
if let Some(item) = self.array.borrow_value().getitem_by_idx(pos, vm) {
Ok(item?)
} else {
Err(objiter::new_stop_iteration(vm))
}
Expand Down
95 changes: 53 additions & 42 deletions vm/src/stdlib/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,27 @@ mod _collections {
use crate::obj::{objiter, objtype::PyClassRef};
use crate::pyobject::{
IdProtocol, PyArithmaticValue::*, PyClassImpl, PyComparisonValue, PyIterable, PyObjectRef,
PyRef, PyResult, PyValue,
PyRef, PyResult, PyValue, ThreadSafe,
};
use crate::sequence;
use crate::vm::ReprGuard;
use crate::VirtualMachine;
use itertools::Itertools;
use std::cell::{Cell, RefCell};
use std::collections::VecDeque;
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};

use crossbeam_utils::atomic::AtomicCell;

#[pyclass(name = "deque")]
#[derive(Debug, Clone)]
#[derive(Debug)]
struct PyDeque {
deque: RefCell<VecDeque<PyObjectRef>>,
maxlen: Cell<Option<usize>>,
deque: RwLock<VecDeque<PyObjectRef>>,
maxlen: AtomicCell<Option<usize>>,
}
type PyDequeRef = PyRef<PyDeque>;

impl ThreadSafe for PyDeque {}

impl PyValue for PyDeque {
fn class(vm: &VirtualMachine) -> PyClassRef {
vm.class("_collections", "deque")
Expand All @@ -36,8 +40,12 @@ mod _collections {
}

impl PyDeque {
fn borrow_deque<'a>(&'a self) -> impl std::ops::Deref<Target = VecDeque<PyObjectRef>> + 'a {
self.deque.borrow()
fn borrow_deque(&self) -> RwLockReadGuard<'_, VecDeque<PyObjectRef>> {
self.deque.read().unwrap()
}

fn borrow_deque_mut(&self) -> RwLockWriteGuard<'_, VecDeque<PyObjectRef>> {
self.deque.write().unwrap()
}
}

Expand All @@ -51,8 +59,8 @@ mod _collections {
vm: &VirtualMachine,
) -> PyResult<PyRef<Self>> {
let py_deque = PyDeque {
deque: RefCell::default(),
maxlen: maxlen.into(),
deque: RwLock::default(),
maxlen: AtomicCell::new(maxlen),
};
if let OptionalArg::Present(iter) = iter {
py_deque.extend(iter, vm)?;
Expand All @@ -62,36 +70,39 @@ mod _collections {

#[pymethod]
fn append(&self, obj: PyObjectRef) {
let mut deque = self.deque.borrow_mut();
if self.maxlen.get() == Some(deque.len()) {
let mut deque = self.borrow_deque_mut();
if self.maxlen.load() == Some(deque.len()) {
deque.pop_front();
}
deque.push_back(obj);
}

#[pymethod]
fn appendleft(&self, obj: PyObjectRef) {
let mut deque = self.deque.borrow_mut();
if self.maxlen.get() == Some(deque.len()) {
let mut deque = self.borrow_deque_mut();
if self.maxlen.load() == Some(deque.len()) {
deque.pop_back();
}
deque.push_front(obj);
}

#[pymethod]
fn clear(&self) {
self.deque.borrow_mut().clear()
self.borrow_deque_mut().clear()
}

#[pymethod]
fn copy(&self) -> Self {
self.clone()
PyDeque {
deque: RwLock::new(self.borrow_deque().clone()),
maxlen: AtomicCell::new(self.maxlen.load()),
}
}

#[pymethod]
fn count(&self, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
let mut count = 0;
for elem in self.deque.borrow().iter() {
for elem in self.borrow_deque().iter() {
if vm.identical_or_equal(elem, &obj)? {
count += 1;
}
Expand Down Expand Up @@ -124,7 +135,7 @@ mod _collections {
stop: OptionalArg<usize>,
vm: &VirtualMachine,
) -> PyResult<usize> {
let deque = self.deque.borrow();
let deque = self.borrow_deque();
let start = start.unwrap_or(0);
let stop = stop.unwrap_or_else(|| deque.len());
for (i, elem) in deque.iter().skip(start).take(stop - start).enumerate() {
Expand All @@ -141,9 +152,9 @@ mod _collections {

#[pymethod]
fn insert(&self, idx: i32, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
let mut deque = self.deque.borrow_mut();
let mut deque = self.borrow_deque_mut();

if self.maxlen.get() == Some(deque.len()) {
if self.maxlen.load() == Some(deque.len()) {
return Err(vm.new_index_error("deque already at its maximum size".to_owned()));
}

Expand All @@ -166,23 +177,21 @@ mod _collections {

#[pymethod]
fn pop(&self, vm: &VirtualMachine) -> PyResult {
self.deque
.borrow_mut()
self.borrow_deque_mut()
.pop_back()
.ok_or_else(|| vm.new_index_error("pop from an empty deque".to_owned()))
}

#[pymethod]
fn popleft(&self, vm: &VirtualMachine) -> PyResult {
self.deque
.borrow_mut()
self.borrow_deque_mut()
.pop_front()
.ok_or_else(|| vm.new_index_error("pop from an empty deque".to_owned()))
}

#[pymethod]
fn remove(&self, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult {
let mut deque = self.deque.borrow_mut();
let mut deque = self.borrow_deque_mut();
let mut idx = None;
for (i, elem) in deque.iter().enumerate() {
if vm.identical_or_equal(elem, &obj)? {
Expand All @@ -196,13 +205,13 @@ mod _collections {

#[pymethod]
fn reverse(&self) {
self.deque
.replace_with(|deque| deque.iter().cloned().rev().collect());
let rev: VecDeque<_> = self.borrow_deque().iter().cloned().rev().collect();
*self.borrow_deque_mut() = rev;
}

#[pymethod]
fn rotate(&self, mid: OptionalArg<isize>) {
let mut deque = self.deque.borrow_mut();
let mut deque = self.borrow_deque_mut();
let mid = mid.unwrap_or(1);
if mid < 0 {
deque.rotate_left(-mid as usize);
Expand All @@ -213,26 +222,25 @@ mod _collections {

#[pyproperty]
fn maxlen(&self) -> Option<usize> {
self.maxlen.get()
self.maxlen.load()
}

#[pyproperty(setter)]
fn set_maxlen(&self, maxlen: Option<usize>) {
self.maxlen.set(maxlen);
self.maxlen.store(maxlen);
}

#[pymethod(name = "__repr__")]
fn repr(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<String> {
let repr = if let Some(_guard) = ReprGuard::enter(zelf.as_object()) {
let elements = zelf
.deque
.borrow()
.borrow_deque()
.iter()
.map(|obj| vm.to_repr(obj))
.collect::<Result<Vec<_>, _>>()?;
let maxlen = zelf
.maxlen
.get()
.load()
.map(|maxlen| format!(", maxlen={}", maxlen))
.unwrap_or_default();
format!("deque([{}]{})", elements.into_iter().format(", "), maxlen)
Expand Down Expand Up @@ -336,29 +344,29 @@ mod _collections {

#[pymethod(name = "__mul__")]
fn mul(&self, n: isize) -> Self {
let deque: &VecDeque<_> = &self.deque.borrow();
let deque: &VecDeque<_> = &self.borrow_deque();
let mul = sequence::seq_mul(deque, n);
let skipped = if let Some(maxlen) = self.maxlen.get() {
let skipped = if let Some(maxlen) = self.maxlen.load() {
mul.len() - maxlen
} else {
0
};
let deque = mul.skip(skipped).cloned().collect();
PyDeque {
deque: RefCell::new(deque),
maxlen: self.maxlen.clone(),
deque: RwLock::new(deque),
maxlen: AtomicCell::new(self.maxlen.load()),
}
}

#[pymethod(name = "__len__")]
fn len(&self) -> usize {
self.deque.borrow().len()
self.borrow_deque().len()
}

#[pymethod(name = "__iter__")]
fn iter(zelf: PyRef<Self>) -> PyDequeIterator {
PyDequeIterator {
position: Cell::new(0),
position: AtomicCell::new(0),
deque: zelf,
}
}
Expand All @@ -367,10 +375,12 @@ mod _collections {
#[pyclass(name = "_deque_iterator")]
#[derive(Debug)]
struct PyDequeIterator {
position: Cell<usize>,
position: AtomicCell<usize>,
deque: PyDequeRef,
}

impl ThreadSafe for PyDequeIterator {}

impl PyValue for PyDequeIterator {
fn class(vm: &VirtualMachine) -> PyClassRef {
vm.class("_collections", "_deque_iterator")
Expand All @@ -381,9 +391,10 @@ mod _collections {
impl PyDequeIterator {
#[pymethod(name = "__next__")]
fn next(&self, vm: &VirtualMachine) -> PyResult {
if self.position.get() < self.deque.deque.borrow().len() {
let ret = self.deque.deque.borrow()[self.position.get()].clone();
self.position.set(self.position.get() + 1);
let pos = self.position.fetch_add(1);
let deque = self.deque.borrow_deque();
if pos < deque.len() {
let ret = deque[pos].clone();
Ok(ret)
} else {
Err(objiter::new_stop_iteration(vm))
Expand Down
14 changes: 8 additions & 6 deletions vm/src/stdlib/csv.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::cell::RefCell;
use std::fmt::{self, Debug, Formatter};
use std::sync::RwLock;

use csv as rust_csv;
use itertools::join;
Expand All @@ -10,7 +10,7 @@ use crate::obj::objiter;
use crate::obj::objstr::{self, PyString};
use crate::obj::objtype::PyClassRef;
use crate::pyobject::{IntoPyObject, TryFromObject, TypeProtocol};
use crate::pyobject::{PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult, PyValue};
use crate::pyobject::{PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe};
use crate::types::create_type;
use crate::VirtualMachine;

Expand Down Expand Up @@ -126,9 +126,11 @@ impl ReadState {

#[pyclass(name = "Reader")]
struct Reader {
state: RefCell<ReadState>,
state: RwLock<ReadState>,
}

impl ThreadSafe for Reader {}

impl Debug for Reader {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "_csv.reader")
Expand All @@ -143,7 +145,7 @@ impl PyValue for Reader {

impl Reader {
fn new(iter: PyIterable<PyObjectRef>, config: ReaderOption) -> Self {
let state = RefCell::new(ReadState::new(iter, config));
let state = RwLock::new(ReadState::new(iter, config));
Reader { state }
}
}
Expand All @@ -152,13 +154,13 @@ impl Reader {
impl Reader {
#[pymethod(name = "__iter__")]
fn iter(this: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
this.state.borrow_mut().cast_to_reader(vm)?;
this.state.write().unwrap().cast_to_reader(vm)?;
this.into_pyobject(vm)
}

#[pymethod(name = "__next__")]
fn next(&self, vm: &VirtualMachine) -> PyResult {
let mut state = self.state.borrow_mut();
let mut state = self.state.write().unwrap();
state.cast_to_reader(vm)?;

if let ReadState::CsvIter(ref mut reader) = &mut *state {
Expand Down
Loading