From f448654229e670ea459f71775680af823a363299 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sat, 9 May 2020 10:45:12 +0300 Subject: [PATCH 1/7] Make PyArrayIter ThreadSafe --- vm/src/stdlib/array.rs | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/vm/src/stdlib/array.rs b/vm/src/stdlib/array.rs index b4a437477e..9c96659c71 100644 --- a/vm/src/stdlib/array.rs +++ b/vm/src/stdlib/array.rs @@ -10,10 +10,11 @@ use crate::pyobject::{ }; use crate::VirtualMachine; -use std::cell::Cell; use std::fmt; use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; +use crossbeam_utils::atomic::AtomicCell; + struct ArrayTypeSpecifierError { _priv: (), } @@ -421,7 +422,7 @@ impl PyArray { #[pymethod(name = "__iter__")] fn iter(zelf: PyRef) -> PyArrayIter { PyArrayIter { - position: Cell::new(0), + position: AtomicCell::new(0), array: zelf, } } @@ -430,10 +431,12 @@ impl PyArray { #[pyclass] #[derive(Debug)] pub struct PyArrayIter { - position: Cell, + position: AtomicCell, array: PyArrayRef, } +impl ThreadSafe for PyArrayIter {} + impl PyValue for PyArrayIter { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("array", "arrayiterator") @@ -444,14 +447,9 @@ impl PyValue for PyArrayIter { impl PyArrayIter { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - if self.position.get() < self.array.borrow_value().len() { - let ret = self - .array - .borrow_value() - .getitem_by_idx(self.position.get(), vm) - .unwrap()?; - self.position.set(self.position.get() + 1); - Ok(ret) + let pos = self.position.fetch_add(1); + if let Some(item) = self.array.borrow_value().getitem_by_idx(pos, vm) { + Ok(item?) } else { Err(objiter::new_stop_iteration(vm)) } From 94e93f72625ce7ce81f9d68e735b64d00544318c Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sat, 9 May 2020 11:36:44 +0300 Subject: [PATCH 2/7] Make PyDeque, PyDequeIterator ThreadSafe --- vm/src/stdlib/collections.rs | 95 ++++++++++++++++++++---------------- 1 file changed, 53 insertions(+), 42 deletions(-) diff --git a/vm/src/stdlib/collections.rs b/vm/src/stdlib/collections.rs index 4ec3633b5c..0bfb21d2c7 100644 --- a/vm/src/stdlib/collections.rs +++ b/vm/src/stdlib/collections.rs @@ -6,23 +6,27 @@ mod _collections { use crate::obj::{objiter, objtype::PyClassRef}; use crate::pyobject::{ IdProtocol, PyArithmaticValue::*, PyClassImpl, PyComparisonValue, PyIterable, PyObjectRef, - PyRef, PyResult, PyValue, + PyRef, PyResult, PyValue, ThreadSafe, }; use crate::sequence; use crate::vm::ReprGuard; use crate::VirtualMachine; use itertools::Itertools; - use std::cell::{Cell, RefCell}; use std::collections::VecDeque; + use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; + + use crossbeam_utils::atomic::AtomicCell; #[pyclass(name = "deque")] - #[derive(Debug, Clone)] + #[derive(Debug)] struct PyDeque { - deque: RefCell>, - maxlen: Cell>, + deque: RwLock>, + maxlen: AtomicCell>, } type PyDequeRef = PyRef; + impl ThreadSafe for PyDeque {} + impl PyValue for PyDeque { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("_collections", "deque") @@ -36,8 +40,12 @@ mod _collections { } impl PyDeque { - fn borrow_deque<'a>(&'a self) -> impl std::ops::Deref> + 'a { - self.deque.borrow() + fn borrow_deque(&self) -> RwLockReadGuard<'_, VecDeque> { + self.deque.read().unwrap() + } + + fn borrow_deque_mut(&self) -> RwLockWriteGuard<'_, VecDeque> { + self.deque.write().unwrap() } } @@ -51,8 +59,8 @@ mod _collections { vm: &VirtualMachine, ) -> PyResult> { let py_deque = PyDeque { - deque: RefCell::default(), - maxlen: maxlen.into(), + deque: RwLock::default(), + maxlen: AtomicCell::new(maxlen), }; if let OptionalArg::Present(iter) = iter { py_deque.extend(iter, vm)?; @@ -62,8 +70,8 @@ mod _collections { #[pymethod] fn append(&self, obj: PyObjectRef) { - let mut deque = self.deque.borrow_mut(); - if self.maxlen.get() == Some(deque.len()) { + let mut deque = self.borrow_deque_mut(); + if self.maxlen.load() == Some(deque.len()) { deque.pop_front(); } deque.push_back(obj); @@ -71,8 +79,8 @@ mod _collections { #[pymethod] fn appendleft(&self, obj: PyObjectRef) { - let mut deque = self.deque.borrow_mut(); - if self.maxlen.get() == Some(deque.len()) { + let mut deque = self.borrow_deque_mut(); + if self.maxlen.load() == Some(deque.len()) { deque.pop_back(); } deque.push_front(obj); @@ -80,18 +88,21 @@ mod _collections { #[pymethod] fn clear(&self) { - self.deque.borrow_mut().clear() + self.borrow_deque_mut().clear() } #[pymethod] fn copy(&self) -> Self { - self.clone() + PyDeque { + deque: RwLock::new(self.borrow_deque().clone()), + maxlen: AtomicCell::new(self.maxlen.load()), + } } #[pymethod] fn count(&self, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { let mut count = 0; - for elem in self.deque.borrow().iter() { + for elem in self.borrow_deque().iter() { if vm.identical_or_equal(elem, &obj)? { count += 1; } @@ -124,7 +135,7 @@ mod _collections { stop: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - let deque = self.deque.borrow(); + let deque = self.borrow_deque(); let start = start.unwrap_or(0); let stop = stop.unwrap_or_else(|| deque.len()); for (i, elem) in deque.iter().skip(start).take(stop - start).enumerate() { @@ -141,9 +152,9 @@ mod _collections { #[pymethod] fn insert(&self, idx: i32, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - let mut deque = self.deque.borrow_mut(); + let mut deque = self.borrow_deque_mut(); - if self.maxlen.get() == Some(deque.len()) { + if self.maxlen.load() == Some(deque.len()) { return Err(vm.new_index_error("deque already at its maximum size".to_owned())); } @@ -166,23 +177,21 @@ mod _collections { #[pymethod] fn pop(&self, vm: &VirtualMachine) -> PyResult { - self.deque - .borrow_mut() + self.borrow_deque_mut() .pop_back() .ok_or_else(|| vm.new_index_error("pop from an empty deque".to_owned())) } #[pymethod] fn popleft(&self, vm: &VirtualMachine) -> PyResult { - self.deque - .borrow_mut() + self.borrow_deque_mut() .pop_front() .ok_or_else(|| vm.new_index_error("pop from an empty deque".to_owned())) } #[pymethod] fn remove(&self, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let mut deque = self.deque.borrow_mut(); + let mut deque = self.borrow_deque_mut(); let mut idx = None; for (i, elem) in deque.iter().enumerate() { if vm.identical_or_equal(elem, &obj)? { @@ -196,13 +205,13 @@ mod _collections { #[pymethod] fn reverse(&self) { - self.deque - .replace_with(|deque| deque.iter().cloned().rev().collect()); + let rev: VecDeque<_> = self.borrow_deque().iter().cloned().rev().collect(); + *self.borrow_deque_mut() = rev; } #[pymethod] fn rotate(&self, mid: OptionalArg) { - let mut deque = self.deque.borrow_mut(); + let mut deque = self.borrow_deque_mut(); let mid = mid.unwrap_or(1); if mid < 0 { deque.rotate_left(-mid as usize); @@ -213,26 +222,25 @@ mod _collections { #[pyproperty] fn maxlen(&self) -> Option { - self.maxlen.get() + self.maxlen.load() } #[pyproperty(setter)] fn set_maxlen(&self, maxlen: Option) { - self.maxlen.set(maxlen); + self.maxlen.store(maxlen); } #[pymethod(name = "__repr__")] fn repr(zelf: PyRef, vm: &VirtualMachine) -> PyResult { let repr = if let Some(_guard) = ReprGuard::enter(zelf.as_object()) { let elements = zelf - .deque - .borrow() + .borrow_deque() .iter() .map(|obj| vm.to_repr(obj)) .collect::, _>>()?; let maxlen = zelf .maxlen - .get() + .load() .map(|maxlen| format!(", maxlen={}", maxlen)) .unwrap_or_default(); format!("deque([{}]{})", elements.into_iter().format(", "), maxlen) @@ -336,29 +344,29 @@ mod _collections { #[pymethod(name = "__mul__")] fn mul(&self, n: isize) -> Self { - let deque: &VecDeque<_> = &self.deque.borrow(); + let deque: &VecDeque<_> = &self.borrow_deque(); let mul = sequence::seq_mul(deque, n); - let skipped = if let Some(maxlen) = self.maxlen.get() { + let skipped = if let Some(maxlen) = self.maxlen.load() { mul.len() - maxlen } else { 0 }; let deque = mul.skip(skipped).cloned().collect(); PyDeque { - deque: RefCell::new(deque), - maxlen: self.maxlen.clone(), + deque: RwLock::new(deque), + maxlen: AtomicCell::new(self.maxlen.load()), } } #[pymethod(name = "__len__")] fn len(&self) -> usize { - self.deque.borrow().len() + self.borrow_deque().len() } #[pymethod(name = "__iter__")] fn iter(zelf: PyRef) -> PyDequeIterator { PyDequeIterator { - position: Cell::new(0), + position: AtomicCell::new(0), deque: zelf, } } @@ -367,10 +375,12 @@ mod _collections { #[pyclass(name = "_deque_iterator")] #[derive(Debug)] struct PyDequeIterator { - position: Cell, + position: AtomicCell, deque: PyDequeRef, } + impl ThreadSafe for PyDequeIterator {} + impl PyValue for PyDequeIterator { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("_collections", "_deque_iterator") @@ -381,9 +391,10 @@ mod _collections { impl PyDequeIterator { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - if self.position.get() < self.deque.deque.borrow().len() { - let ret = self.deque.deque.borrow()[self.position.get()].clone(); - self.position.set(self.position.get() + 1); + let pos = self.position.fetch_add(1); + let deque = self.deque.borrow_deque(); + if pos < deque.len() { + let ret = deque[pos].clone(); Ok(ret) } else { Err(objiter::new_stop_iteration(vm)) From 1b585bda61c97673ff87cff510f758ddf303c960 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sat, 9 May 2020 11:41:55 +0300 Subject: [PATCH 3/7] Make Reader ThreadSafe --- vm/src/stdlib/csv.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vm/src/stdlib/csv.rs b/vm/src/stdlib/csv.rs index fc42643a2a..b2ff582268 100644 --- a/vm/src/stdlib/csv.rs +++ b/vm/src/stdlib/csv.rs @@ -1,5 +1,5 @@ -use std::cell::RefCell; use std::fmt::{self, Debug, Formatter}; +use std::sync::RwLock; use csv as rust_csv; use itertools::join; @@ -10,7 +10,7 @@ use crate::obj::objiter; use crate::obj::objstr::{self, PyString}; use crate::obj::objtype::PyClassRef; use crate::pyobject::{IntoPyObject, TryFromObject, TypeProtocol}; -use crate::pyobject::{PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe}; use crate::types::create_type; use crate::VirtualMachine; @@ -126,9 +126,11 @@ impl ReadState { #[pyclass(name = "Reader")] struct Reader { - state: RefCell, + state: RwLock, } +impl ThreadSafe for Reader {} + impl Debug for Reader { fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!(f, "_csv.reader") @@ -143,7 +145,7 @@ impl PyValue for Reader { impl Reader { fn new(iter: PyIterable, config: ReaderOption) -> Self { - let state = RefCell::new(ReadState::new(iter, config)); + let state = RwLock::new(ReadState::new(iter, config)); Reader { state } } } @@ -152,13 +154,13 @@ impl Reader { impl Reader { #[pymethod(name = "__iter__")] fn iter(this: PyRef, vm: &VirtualMachine) -> PyResult { - this.state.borrow_mut().cast_to_reader(vm)?; + this.state.write().unwrap().cast_to_reader(vm)?; this.into_pyobject(vm) } #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - let mut state = self.state.borrow_mut(); + let mut state = self.state.write().unwrap(); state.cast_to_reader(vm)?; if let ReadState::CsvIter(ref mut reader) = &mut *state { From 75af7f6b1c9f4f1fc99e1ac97a4cb8e72d973d1a Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sat, 9 May 2020 12:07:35 +0300 Subject: [PATCH 4/7] Make PyBytesIO, PyStringIO ThreadSafe --- vm/src/stdlib/io.rs | 43 +++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 3b2760b92d..0a0736802e 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -1,10 +1,11 @@ /* * I/O core tools. */ -use std::cell::{RefCell, RefMut}; use std::fs; use std::io::{self, prelude::*, Cursor, SeekFrom}; +use std::sync::{RwLock, RwLockWriteGuard}; +use crossbeam_utils::atomic::AtomicCell; use num_traits::ToPrimitive; use crate::exceptions::PyBaseExceptionRef; @@ -18,7 +19,7 @@ use crate::obj::objiter; use crate::obj::objstr::{self, PyStringRef}; use crate::obj::objtype::{self, PyClassRef}; use crate::pyobject::{ - BufferProtocol, Either, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, + BufferProtocol, Either, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, TryFromObject, }; use crate::vm::VirtualMachine; @@ -120,11 +121,14 @@ impl BufferedIO { #[derive(Debug)] struct PyStringIO { - buffer: RefCell>, + buffer: RwLock, + closed: AtomicCell, } type PyStringIORef = PyRef; +impl ThreadSafe for PyStringIO {} + impl PyValue for PyStringIO { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("io", "StringIO") @@ -132,10 +136,9 @@ impl PyValue for PyStringIO { } impl PyStringIORef { - fn buffer(&self, vm: &VirtualMachine) -> PyResult> { - let buffer = self.buffer.borrow_mut(); - if buffer.is_some() { - Ok(RefMut::map(buffer, |opt| opt.as_mut().unwrap())) + fn buffer(&self, vm: &VirtualMachine) -> PyResult> { + if !self.closed.load() { + Ok(self.buffer.write().unwrap()) } else { Err(vm.new_value_error("I/O operation on closed file.".to_owned())) } @@ -209,11 +212,11 @@ impl PyStringIORef { } fn closed(self) -> bool { - self.buffer.borrow().is_none() + self.closed.load() } fn close(self) { - self.buffer.replace(None); + self.closed.store(true); } } @@ -235,18 +238,22 @@ fn string_io_new( let input = flatten.map_or_else(Vec::new, |v| objstr::borrow_value(&v).as_bytes().to_vec()); PyStringIO { - buffer: RefCell::new(Some(BufferedIO::new(Cursor::new(input)))), + buffer: RwLock::new(BufferedIO::new(Cursor::new(input))), + closed: AtomicCell::new(false), } .into_ref_with_type(vm, cls) } #[derive(Debug)] struct PyBytesIO { - buffer: RefCell>, + buffer: RwLock, + closed: AtomicCell, } type PyBytesIORef = PyRef; +impl ThreadSafe for PyBytesIO {} + impl PyValue for PyBytesIO { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("io", "BytesIO") @@ -254,10 +261,9 @@ impl PyValue for PyBytesIO { } impl PyBytesIORef { - fn buffer(&self, vm: &VirtualMachine) -> PyResult> { - let buffer = self.buffer.borrow_mut(); - if buffer.is_some() { - Ok(RefMut::map(buffer, |opt| opt.as_mut().unwrap())) + fn buffer(&self, vm: &VirtualMachine) -> PyResult> { + if !self.closed.load() { + Ok(self.buffer.write().unwrap()) } else { Err(vm.new_value_error("I/O operation on closed file.".to_owned())) } @@ -320,11 +326,11 @@ impl PyBytesIORef { } fn closed(self) -> bool { - self.buffer.borrow().is_none() + self.closed.load() } fn close(self) { - self.buffer.replace(None); + self.closed.store(true) } } @@ -339,7 +345,8 @@ fn bytes_io_new( }; PyBytesIO { - buffer: RefCell::new(Some(BufferedIO::new(Cursor::new(raw_bytes)))), + buffer: RwLock::new(BufferedIO::new(Cursor::new(raw_bytes))), + closed: AtomicCell::new(false), } .into_ref_with_type(vm, cls) } From 4cf5178d43647b2528e25a7e9cb95e989ab0012a Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sat, 9 May 2020 17:12:27 +0300 Subject: [PATCH 5/7] Make PyItertools* ThreadSafe --- vm/src/stdlib/itertools.rs | 350 +++++++++++++++++++++---------------- 1 file changed, 196 insertions(+), 154 deletions(-) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index ea7767fb20..03573f1326 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -2,11 +2,11 @@ pub(crate) use decl::make_module; #[pymodule(name = "itertools")] mod decl { + use crossbeam_utils::atomic::AtomicCell; use num_bigint::BigInt; use num_traits::{One, Signed, ToPrimitive, Zero}; - use std::cell::{Cell, RefCell}; use std::iter; - use std::rc::Rc; + use std::sync::{Arc, RwLock, RwLockWriteGuard}; use crate::function::{Args, OptionalArg, OptionalOption, PyFuncArgs}; use crate::obj::objbool; @@ -15,7 +15,8 @@ mod decl { use crate::obj::objtuple::PyTuple; use crate::obj::objtype::{self, PyClassRef}; use crate::pyobject::{ - IdProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, + IdProtocol, PyCallable, PyClassImpl, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, + TypeProtocol, }; use crate::vm::VirtualMachine; @@ -23,9 +24,12 @@ mod decl { #[derive(Debug)] struct PyItertoolsChain { iterables: Vec, - cur: RefCell<(usize, Option)>, + cur_idx: AtomicCell, + cached_iter: RwLock>, } + impl ThreadSafe for PyItertoolsChain {} + impl PyValue for PyItertoolsChain { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "chain") @@ -38,27 +42,40 @@ mod decl { fn tp_new(cls: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult> { PyItertoolsChain { iterables: args.args, - cur: RefCell::new((0, None)), + cur_idx: AtomicCell::new(0), + cached_iter: RwLock::new(None), } .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - let (ref mut cur_idx, ref mut cur_iter) = *self.cur.borrow_mut(); - while *cur_idx < self.iterables.len() { - if cur_iter.is_none() { - *cur_iter = Some(get_iter(vm, &self.iterables[*cur_idx])?); + loop { + let pos = self.cur_idx.load(); + if pos >= self.iterables.len() { + break; } + let cur_iter = if self.cached_iter.read().unwrap().is_none() { + // We need to call "get_iter" outside of the lock. + let iter = get_iter(vm, &self.iterables[pos])?; + *self.cached_iter.write().unwrap() = Some(iter.clone()); + iter + } else { + if let Some(cached_iter) = (*(self.cached_iter.read().unwrap())).clone() { + cached_iter + } else { + // Someone changed cached iter to None since we checked. + continue; + } + }; - // can't be directly inside the 'match' clause, otherwise the borrows collide. - let obj = call_next(vm, cur_iter.as_ref().unwrap()); - match obj { + // We need to call "call_next" outside of the lock. + match call_next(vm, &cur_iter) { Ok(ok) => return Ok(ok), Err(err) => { if objtype::isinstance(&err, &vm.ctx.exceptions.stop_iteration) { - *cur_idx += 1; - *cur_iter = None; + self.cur_idx.fetch_add(1); + *self.cached_iter.write().unwrap() = None; } else { return Err(err); } @@ -85,7 +102,8 @@ mod decl { PyItertoolsChain { iterables, - cur: RefCell::new((0, None)), + cur_idx: AtomicCell::new(0), + cached_iter: RwLock::new(None), } .into_ref_with_type(vm, cls) } @@ -98,6 +116,8 @@ mod decl { selector: PyObjectRef, } + impl ThreadSafe for PyItertoolsCompress {} + impl PyValue for PyItertoolsCompress { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "compress") @@ -145,10 +165,12 @@ mod decl { #[pyclass(name = "count")] #[derive(Debug)] struct PyItertoolsCount { - cur: RefCell, + cur: RwLock, step: BigInt, } + impl ThreadSafe for PyItertoolsCount {} + impl PyValue for PyItertoolsCount { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "count") @@ -174,7 +196,7 @@ mod decl { }; PyItertoolsCount { - cur: RefCell::new(start), + cur: RwLock::new(start), step, } .into_ref_with_type(vm, cls) @@ -182,8 +204,9 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self) -> PyResult { - let result = self.cur.borrow().clone(); - *self.cur.borrow_mut() += &self.step; + let mut cur = self.cur.write().unwrap(); + let result = cur.clone(); + *cur += &self.step; Ok(PyInt::new(result)) } @@ -196,12 +219,13 @@ mod decl { #[pyclass(name = "cycle")] #[derive(Debug)] struct PyItertoolsCycle { - iter: RefCell, - saved: RefCell>, - index: Cell, - first_pass: Cell, + iter: PyObjectRef, + saved: RwLock>, + index: AtomicCell, } + impl ThreadSafe for PyItertoolsCycle {} + impl PyValue for PyItertoolsCycle { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "cycle") @@ -219,36 +243,31 @@ mod decl { let iter = get_iter(vm, &iterable)?; PyItertoolsCycle { - iter: RefCell::new(iter.clone()), - saved: RefCell::new(Vec::new()), - index: Cell::new(0), - first_pass: Cell::new(false), + iter: iter.clone(), + saved: RwLock::new(Vec::new()), + index: AtomicCell::new(0), } .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - let item = if let Some(item) = get_next_object(vm, &self.iter.borrow())? { - if self.first_pass.get() { - return Ok(item); - } - - self.saved.borrow_mut().push(item.clone()); + let item = if let Some(item) = get_next_object(vm, &self.iter)? { + self.saved.write().unwrap().push(item.clone()); item } else { - if self.saved.borrow().len() == 0 { + let saved = self.saved.read().unwrap(); + if saved.len() == 0 { return Err(new_stop_iteration(vm)); } - let last_index = self.index.get(); - self.index.set(self.index.get() + 1); + let last_index = self.index.fetch_add(1); - if self.index.get() >= self.saved.borrow().len() { - self.index.set(0); + if last_index >= saved.len() - 1 { + self.index.store(0); } - self.saved.borrow()[last_index].clone() + saved[last_index].clone() }; Ok(item) @@ -264,9 +283,11 @@ mod decl { #[derive(Debug)] struct PyItertoolsRepeat { object: PyObjectRef, - times: Option>, + times: Option>, } + impl ThreadSafe for PyItertoolsRepeat {} + impl PyValue for PyItertoolsRepeat { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "repeat") @@ -283,7 +304,7 @@ mod decl { vm: &VirtualMachine, ) -> PyResult> { let times = match times.into_option() { - Some(int) => Some(RefCell::new(int.as_bigint().clone())), + Some(int) => Some(RwLock::new(int.as_bigint().clone())), None => None, }; @@ -297,10 +318,11 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { if let Some(ref times) = self.times { - if *times.borrow() <= BigInt::zero() { + let mut times = times.write().unwrap(); + if *times <= BigInt::zero() { return Err(new_stop_iteration(vm)); } - *times.borrow_mut() -= 1; + *times -= 1; } Ok(self.object.clone()) @@ -314,7 +336,7 @@ mod decl { #[pymethod(name = "__length_hint__")] fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef { match self.times { - Some(ref times) => vm.new_int(times.borrow().clone()), + Some(ref times) => vm.new_int(times.read().unwrap().clone()), None => vm.new_int(0), } } @@ -327,6 +349,8 @@ mod decl { iter: PyObjectRef, } + impl ThreadSafe for PyItertoolsStarmap {} + impl PyValue for PyItertoolsStarmap { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "starmap") @@ -366,9 +390,11 @@ mod decl { struct PyItertoolsTakewhile { predicate: PyObjectRef, iterable: PyObjectRef, - stop_flag: RefCell, + stop_flag: AtomicCell, } + impl ThreadSafe for PyItertoolsTakewhile {} + impl PyValue for PyItertoolsTakewhile { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "takewhile") @@ -389,14 +415,14 @@ mod decl { PyItertoolsTakewhile { predicate, iterable: iter, - stop_flag: RefCell::new(false), + stop_flag: AtomicCell::new(false), } .into_ref_with_type(vm, cls) } #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - if *self.stop_flag.borrow() { + if self.stop_flag.load() { return Err(new_stop_iteration(vm)); } @@ -409,7 +435,7 @@ mod decl { if verdict { Ok(obj) } else { - *self.stop_flag.borrow_mut() = true; + self.stop_flag.store(true); Err(new_stop_iteration(vm)) } } @@ -425,9 +451,11 @@ mod decl { struct PyItertoolsDropwhile { predicate: PyCallable, iterable: PyObjectRef, - start_flag: Cell, + start_flag: AtomicCell, } + impl ThreadSafe for PyItertoolsDropwhile {} + impl PyValue for PyItertoolsDropwhile { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "dropwhile") @@ -448,7 +476,7 @@ mod decl { PyItertoolsDropwhile { predicate, iterable: iter, - start_flag: Cell::new(false), + start_flag: AtomicCell::new(false), } .into_ref_with_type(vm, cls) } @@ -458,13 +486,13 @@ mod decl { let predicate = &self.predicate; let iterable = &self.iterable; - if !self.start_flag.get() { + if !self.start_flag.load() { loop { let obj = call_next(vm, iterable)?; let pred = predicate.clone(); let pred_value = vm.invoke(&pred.into_object(), vec![obj.clone()])?; if !objbool::boolval(vm, pred_value)? { - self.start_flag.set(true); + self.start_flag.store(true); return Ok(obj); } } @@ -482,12 +510,14 @@ mod decl { #[derive(Debug)] struct PyItertoolsIslice { iterable: PyObjectRef, - cur: RefCell, - next: RefCell, + cur: AtomicCell, + next: AtomicCell, stop: Option, step: usize, } + impl ThreadSafe for PyItertoolsIslice {} + impl PyValue for PyItertoolsIslice { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "islice") @@ -567,8 +597,8 @@ mod decl { PyItertoolsIslice { iterable: iter, - cur: RefCell::new(0), - next: RefCell::new(start), + cur: AtomicCell::new(0), + next: AtomicCell::new(start), stop, step, } @@ -577,23 +607,23 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - while *self.cur.borrow() < *self.next.borrow() { + while self.cur.load() < self.next.load() { call_next(vm, &self.iterable)?; - *self.cur.borrow_mut() += 1; + self.cur.fetch_add(1); } if let Some(stop) = self.stop { - if *self.cur.borrow() >= stop { + if self.cur.load() >= stop { return Err(new_stop_iteration(vm)); } } let obj = call_next(vm, &self.iterable)?; - *self.cur.borrow_mut() += 1; + self.cur.fetch_add(1); // TODO is this overflow check required? attempts to copy CPython. - let (next, ovf) = (*self.next.borrow()).overflowing_add(self.step); - *self.next.borrow_mut() = if ovf { self.stop.unwrap() } else { next }; + let (next, ovf) = self.next.load().overflowing_add(self.step); + self.next.store(if ovf { self.stop.unwrap() } else { next }); Ok(obj) } @@ -611,6 +641,8 @@ mod decl { iterable: PyObjectRef, } + impl ThreadSafe for PyItertoolsFilterFalse {} + impl PyValue for PyItertoolsFilterFalse { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "filterfalse") @@ -665,9 +697,11 @@ mod decl { struct PyItertoolsAccumulate { iterable: PyObjectRef, binop: PyObjectRef, - acc_value: RefCell>, + acc_value: RwLock>, } + impl ThreadSafe for PyItertoolsAccumulate {} + impl PyValue for PyItertoolsAccumulate { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "accumulate") @@ -688,7 +722,7 @@ mod decl { PyItertoolsAccumulate { iterable: iter, binop: binop.unwrap_or_else(|| vm.get_none()), - acc_value: RefCell::from(Option::None), + acc_value: RwLock::new(None), } .into_ref_with_type(vm, cls) } @@ -698,7 +732,9 @@ mod decl { let iterable = &self.iterable; let obj = call_next(vm, iterable)?; - let next_acc_value = match &*self.acc_value.borrow() { + let acc_value = self.acc_value.read().unwrap().clone(); + + let next_acc_value = match acc_value { None => obj.clone(), Some(value) => { if self.binop.is(&vm.get_none()) { @@ -708,7 +744,7 @@ mod decl { } } }; - self.acc_value.replace(Option::from(next_acc_value.clone())); + *self.acc_value.write().unwrap() = Some(next_acc_value.clone()); Ok(next_acc_value) } @@ -722,33 +758,37 @@ mod decl { #[derive(Debug)] struct PyItertoolsTeeData { iterable: PyObjectRef, - values: RefCell>, + values: RwLock>, } + impl ThreadSafe for PyItertoolsTeeData {} + impl PyItertoolsTeeData { - fn new(iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult> { - Ok(Rc::new(PyItertoolsTeeData { + fn new(iterable: PyObjectRef, vm: &VirtualMachine) -> PyResult> { + Ok(Arc::new(PyItertoolsTeeData { iterable: get_iter(vm, &iterable)?, - values: RefCell::new(vec![]), + values: RwLock::new(vec![]), })) } fn get_item(&self, vm: &VirtualMachine, index: usize) -> PyResult { - if self.values.borrow().len() == index { + if self.values.read().unwrap().len() == index { let result = call_next(vm, &self.iterable)?; - self.values.borrow_mut().push(result); + self.values.write().unwrap().push(result); } - Ok(self.values.borrow()[index].clone()) + Ok(self.values.read().unwrap()[index].clone()) } } #[pyclass(name = "tee")] #[derive(Debug)] struct PyItertoolsTee { - tee_data: Rc, - index: Cell, + tee_data: Arc, + index: AtomicCell, } + impl ThreadSafe for PyItertoolsTee {} + impl PyValue for PyItertoolsTee { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "tee") @@ -764,7 +804,7 @@ mod decl { } Ok(PyItertoolsTee { tee_data: PyItertoolsTeeData::new(it, vm)?, - index: Cell::from(0), + index: AtomicCell::new(0), } .into_ref_with_type(vm, PyItertoolsTee::class(vm))? .into_object()) @@ -800,8 +840,8 @@ mod decl { #[pymethod(name = "__copy__")] fn copy(&self, vm: &VirtualMachine) -> PyResult { Ok(PyItertoolsTee { - tee_data: Rc::clone(&self.tee_data), - index: self.index.clone(), + tee_data: Arc::clone(&self.tee_data), + index: AtomicCell::new(self.index.load()), } .into_ref_with_type(vm, Self::class(vm))? .into_object()) @@ -809,8 +849,8 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - let value = self.tee_data.get_item(vm, self.index.get())?; - self.index.set(self.index.get() + 1); + let value = self.tee_data.get_item(vm, self.index.load())?; + self.index.fetch_add(1); Ok(value) } @@ -824,11 +864,13 @@ mod decl { #[derive(Debug)] struct PyItertoolsProduct { pools: Vec>, - idxs: RefCell>, - cur: Cell, - stop: Cell, + idxs: RwLock>, + cur: AtomicCell, + stop: AtomicCell, } + impl ThreadSafe for PyItertoolsProduct {} + impl PyValue for PyItertoolsProduct { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "product") @@ -871,9 +913,9 @@ mod decl { PyItertoolsProduct { pools, - idxs: RefCell::new(vec![0; l]), - cur: Cell::new(l - 1), - stop: Cell::new(false), + idxs: RwLock::new(vec![0; l]), + cur: AtomicCell::new(l - 1), + stop: AtomicCell::new(false), } .into_ref_with_type(vm, cls) } @@ -881,7 +923,7 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { // stop signal - if self.stop.get() { + if self.stop.load() { return Err(new_stop_iteration(vm)); } @@ -893,41 +935,36 @@ mod decl { } } + let idxs = self.idxs.write().unwrap(); + let res = PyTuple::from( pools .iter() - .zip(self.idxs.borrow().iter()) + .zip(idxs.iter()) .map(|(pool, idx)| pool[*idx].clone()) .collect::>(), ); - self.update_idxs(); - - if self.is_end() { - self.stop.set(true); - } + self.update_idxs(idxs); Ok(res.into_ref(vm).into_object()) } - fn is_end(&self) -> bool { - let cur = self.cur.get(); - self.idxs.borrow()[cur] == self.pools[cur].len() - 1 && cur == 0 - } - - fn update_idxs(&self) { - let lst_idx = &self.pools[self.cur.get()].len() - 1; + fn update_idxs(&self, mut idxs: RwLockWriteGuard<'_, Vec>) { + let cur = self.cur.load(); + let lst_idx = &self.pools[cur].len() - 1; - if self.idxs.borrow()[self.cur.get()] == lst_idx { - if self.is_end() { + if idxs[cur] == lst_idx { + if cur == 0 { + self.stop.store(true); return; } - self.idxs.borrow_mut()[self.cur.get()] = 0; - self.cur.set(self.cur.get() - 1); - self.update_idxs(); + idxs[cur] = 0; + self.cur.fetch_sub(1); + self.update_idxs(idxs); } else { - self.idxs.borrow_mut()[self.cur.get()] += 1; - self.cur.set(self.idxs.borrow().len() - 1); + idxs[cur] += 1; + self.cur.store(idxs.len() - 1); } } @@ -941,11 +978,13 @@ mod decl { #[derive(Debug)] struct PyItertoolsCombinations { pool: Vec, - indices: RefCell>, - r: Cell, - exhausted: Cell, + indices: RwLock>, + r: AtomicCell, + exhausted: AtomicCell, } + impl ThreadSafe for PyItertoolsCombinations {} + impl PyValue for PyItertoolsCombinations { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "combinations") @@ -974,9 +1013,9 @@ mod decl { PyItertoolsCombinations { pool, - indices: RefCell::new((0..r).collect()), - r: Cell::new(r), - exhausted: Cell::new(r > n), + indices: RwLock::new((0..r).collect()), + r: AtomicCell::new(r), + exhausted: AtomicCell::new(r > n), } .into_ref_with_type(vm, cls) } @@ -989,27 +1028,28 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { // stop signal - if self.exhausted.get() { + if self.exhausted.load() { return Err(new_stop_iteration(vm)); } let n = self.pool.len(); - let r = self.r.get(); + let r = self.r.load(); if r == 0 { - self.exhausted.set(true); + self.exhausted.store(true); return Ok(vm.ctx.new_tuple(vec![])); } let res = PyTuple::from( self.indices - .borrow() + .read() + .unwrap() .iter() .map(|&i| self.pool[i].clone()) .collect::>(), ); - let mut indices = self.indices.borrow_mut(); + let mut indices = self.indices.write().unwrap(); // Scan indices right-to-left until finding one that is not at its maximum (i + n - r). let mut idx = r as isize - 1; @@ -1020,7 +1060,7 @@ mod decl { // If no suitable index is found, then the indices are all at // their maximum value and we're done. if idx < 0 { - self.exhausted.set(true); + self.exhausted.store(true); } else { // Increment the current index which we know is not at its // maximum. Then move back to the right setting each index @@ -1040,11 +1080,13 @@ mod decl { #[derive(Debug)] struct PyItertoolsCombinationsWithReplacement { pool: Vec, - indices: RefCell>, - r: Cell, - exhausted: Cell, + indices: RwLock>, + r: AtomicCell, + exhausted: AtomicCell, } + impl ThreadSafe for PyItertoolsCombinationsWithReplacement {} + impl PyValue for PyItertoolsCombinationsWithReplacement { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "combinations_with_replacement") @@ -1073,9 +1115,9 @@ mod decl { PyItertoolsCombinationsWithReplacement { pool, - indices: RefCell::new(vec![0; r]), - r: Cell::new(r), - exhausted: Cell::new(n == 0 && r > 0), + indices: RwLock::new(vec![0; r]), + r: AtomicCell::new(r), + exhausted: AtomicCell::new(n == 0 && r > 0), } .into_ref_with_type(vm, cls) } @@ -1088,19 +1130,19 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { // stop signal - if self.exhausted.get() { + if self.exhausted.load() { return Err(new_stop_iteration(vm)); } let n = self.pool.len(); - let r = self.r.get(); + let r = self.r.load(); if r == 0 { - self.exhausted.set(true); + self.exhausted.store(true); return Ok(vm.ctx.new_tuple(vec![])); } - let mut indices = self.indices.borrow_mut(); + let mut indices = self.indices.write().unwrap(); let res = vm .ctx @@ -1115,7 +1157,7 @@ mod decl { // If no suitable index is found, then the indices are all at // their maximum value and we're done. if idx < 0 { - self.exhausted.set(true); + self.exhausted.store(true); } else { let index = indices[idx as usize] + 1; @@ -1133,14 +1175,16 @@ mod decl { #[pyclass(name = "permutations")] #[derive(Debug)] struct PyItertoolsPermutations { - pool: Vec, // Collected input iterable - indices: RefCell>, // One index per element in pool - cycles: RefCell>, // One rollover counter per element in the result - result: RefCell>>, // Indexes of the most recently returned result - r: Cell, // Size of result tuple - exhausted: Cell, // Set when the iterator is exhausted + pool: Vec, // Collected input iterable + indices: RwLock>, // One index per element in pool + cycles: RwLock>, // One rollover counter per element in the result + result: RwLock>>, // Indexes of the most recently returned result + r: AtomicCell, // Size of result tuple + exhausted: AtomicCell, // Set when the iterator is exhausted } + impl ThreadSafe for PyItertoolsPermutations {} + impl PyValue for PyItertoolsPermutations { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "permutations") @@ -1179,11 +1223,11 @@ mod decl { PyItertoolsPermutations { pool, - indices: RefCell::new((0..n).collect()), - cycles: RefCell::new((0..r).map(|i| n - i).collect()), - result: RefCell::new(None), - r: Cell::new(r), - exhausted: Cell::new(r > n), + indices: RwLock::new((0..n).collect()), + cycles: RwLock::new((0..r).map(|i| n - i).collect()), + result: RwLock::new(None), + r: AtomicCell::new(r), + exhausted: AtomicCell::new(r > n), } .into_ref_with_type(vm, cls) } @@ -1196,23 +1240,23 @@ mod decl { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { // stop signal - if self.exhausted.get() { + if self.exhausted.load() { return Err(new_stop_iteration(vm)); } let n = self.pool.len(); - let r = self.r.get(); + let r = self.r.load(); if n == 0 { - self.exhausted.set(true); + self.exhausted.store(true); return Ok(vm.ctx.new_tuple(vec![])); } - let result = &mut *self.result.borrow_mut(); + let mut result = self.result.write().unwrap(); - if let Some(ref mut result) = result { - let mut indices = self.indices.borrow_mut(); - let mut cycles = self.cycles.borrow_mut(); + if let Some(ref mut result) = *result { + let mut indices = self.indices.write().unwrap(); + let mut cycles = self.cycles.write().unwrap(); let mut sentinel = false; // Decrement rightmost cycle, moving leftward upon zero rollover @@ -1241,7 +1285,7 @@ mod decl { } } if !sentinel { - self.exhausted.set(true); + self.exhausted.store(true); return Err(new_stop_iteration(vm)); } } else { @@ -1265,9 +1309,10 @@ mod decl { struct PyItertoolsZipLongest { iterators: Vec, fillvalue: PyObjectRef, - numactive: Cell, } + impl ThreadSafe for PyItertoolsZipLongest {} + impl PyValue for PyItertoolsZipLongest { fn class(vm: &VirtualMachine) -> PyClassRef { vm.class("itertools", "zip_longest") @@ -1299,12 +1344,9 @@ mod decl { .map(|iterable| get_iter(vm, &iterable)) .collect::, _>>()?; - let numactive = Cell::new(iterators.len()); - PyItertoolsZipLongest { iterators, fillvalue, - numactive, } .into_ref_with_type(vm, cls) } @@ -1315,7 +1357,7 @@ mod decl { Err(new_stop_iteration(vm)) } else { let mut result: Vec = Vec::new(); - let mut numactive = self.numactive.get(); + let mut numactive = self.iterators.len(); for idx in 0..self.iterators.len() { let next_obj = match call_next(vm, &self.iterators[idx]) { From 8466f45f2a255d8285224615af4cac3cd793010c Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sat, 9 May 2020 23:23:18 +0300 Subject: [PATCH 6/7] Fix clippy error --- vm/src/stdlib/itertools.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index 03573f1326..742b2b9384 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -60,13 +60,11 @@ mod decl { let iter = get_iter(vm, &self.iterables[pos])?; *self.cached_iter.write().unwrap() = Some(iter.clone()); iter + } else if let Some(cached_iter) = (*(self.cached_iter.read().unwrap())).clone() { + cached_iter } else { - if let Some(cached_iter) = (*(self.cached_iter.read().unwrap())).clone() { - cached_iter - } else { - // Someone changed cached iter to None since we checked. - continue; - } + // Someone changed cached iter to None since we checked. + continue; }; // We need to call "call_next" outside of the lock. From 25913a613f8507c4e3a23bef72c0f2f8d044d743 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Sat, 9 May 2020 23:29:20 +0300 Subject: [PATCH 7/7] Remove expected failure from test_exhausted_iterator --- Lib/test/test_array.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/Lib/test/test_array.py b/Lib/test/test_array.py index 6bdbfe9f0a..7cca83d783 100644 --- a/Lib/test/test_array.py +++ b/Lib/test/test_array.py @@ -341,8 +341,6 @@ def test_iterator_pickle(self): a.fromlist(data2) self.assertEqual(list(it), []) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_exhausted_iterator(self): a = array.array(self.typecode, self.example) self.assertEqual(list(a), list(self.example))