diff --git a/vm/src/stdlib/array.rs b/vm/src/stdlib/array.rs index 12bdf7ab12..b4a437477e 100644 --- a/vm/src/stdlib/array.rs +++ b/vm/src/stdlib/array.rs @@ -6,12 +6,13 @@ use crate::obj::objtype::PyClassRef; use crate::obj::{objbool, objiter}; use crate::pyobject::{ Either, IntoPyObject, PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, - TryFromObject, + ThreadSafe, TryFromObject, }; use crate::VirtualMachine; -use std::cell::{Cell, RefCell}; +use std::cell::Cell; use std::fmt; +use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; struct ArrayTypeSpecifierError { _priv: (), @@ -223,8 +224,11 @@ def_array_enum!( #[pyclass(name = "array")] #[derive(Debug)] pub struct PyArray { - array: RefCell, + array: RwLock, } + +impl ThreadSafe for PyArray {} + pub type PyArrayRef = PyRef; impl PyValue for PyArray { @@ -235,6 +239,14 @@ impl PyValue for PyArray { #[pyimpl(flags(BASETYPE))] impl PyArray { + fn borrow_value(&self) -> RwLockReadGuard<'_, ArrayContentType> { + self.array.read().unwrap() + } + + fn borrow_value_mut(&self) -> RwLockWriteGuard<'_, ArrayContentType> { + self.array.write().unwrap() + } + #[pyslot] fn tp_new( cls: PyClassRef, @@ -253,7 +265,7 @@ impl PyArray { let array = ArrayContentType::from_char(spec).map_err(|err| vm.new_value_error(err.to_string()))?; let zelf = PyArray { - array: RefCell::new(array), + array: RwLock::new(array), }; if let OptionalArg::Present(init) = init { zelf.extend(init, vm)?; @@ -263,33 +275,33 @@ impl PyArray { #[pyproperty] fn typecode(&self) -> String { - self.array.borrow().typecode().to_string() + self.borrow_value().typecode().to_string() } #[pyproperty] fn itemsize(&self) -> usize { - self.array.borrow().itemsize() + self.borrow_value().itemsize() } #[pymethod] fn append(&self, x: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - self.array.borrow_mut().push(x, vm) + self.borrow_value_mut().push(x, vm) } #[pymethod] fn buffer_info(&self) -> (usize, usize) { - let array = self.array.borrow(); + let array = self.borrow_value(); (array.addr(), array.len()) } #[pymethod] fn count(&self, x: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.array.borrow().count(x, vm) + self.borrow_value().count(x, vm) } #[pymethod] fn extend(&self, iter: PyIterable, vm: &VirtualMachine) -> PyResult<()> { - let mut array = self.array.borrow_mut(); + let mut array = self.borrow_value_mut(); for elem in iter.iter(vm)? { array.push(elem?, vm)?; } @@ -299,20 +311,19 @@ impl PyArray { #[pymethod] fn frombytes(&self, b: PyBytesRef, vm: &VirtualMachine) -> PyResult<()> { let b = b.get_value(); - let itemsize = self.array.borrow().itemsize(); + let itemsize = self.borrow_value().itemsize(); if b.len() % itemsize != 0 { return Err(vm.new_value_error("bytes length not a multiple of item size".to_owned())); } if b.len() / itemsize > 0 { - self.array.borrow_mut().frombytes(&b); + self.borrow_value_mut().frombytes(&b); } Ok(()) } #[pymethod] fn index(&self, x: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.array - .borrow() + self.borrow_value() .index(x, vm)? .ok_or_else(|| vm.new_value_error("x not in array".to_owned())) } @@ -335,7 +346,7 @@ impl PyArray { i } }; - self.array.borrow_mut().insert(i, x, vm) + self.borrow_value_mut().insert(i, x, vm) } #[pymethod] @@ -343,19 +354,19 @@ impl PyArray { if self.len() == 0 { Err(vm.new_index_error("pop from empty array".to_owned())) } else { - let i = self.array.borrow().idx(i.unwrap_or(-1), "pop", vm)?; - self.array.borrow_mut().pop(i, vm) + let i = self.borrow_value().idx(i.unwrap_or(-1), "pop", vm)?; + self.borrow_value_mut().pop(i, vm) } } #[pymethod] pub(crate) fn tobytes(&self) -> Vec { - self.array.borrow().tobytes() + self.borrow_value().tobytes() } #[pymethod] fn tolist(&self, vm: &VirtualMachine) -> PyResult { - let array = self.array.borrow(); + let array = self.borrow_value(); let mut v = Vec::with_capacity(array.len()); for obj in array.iter(vm) { v.push(obj?); @@ -365,12 +376,12 @@ impl PyArray { #[pymethod] fn reverse(&self) { - self.array.borrow_mut().reverse() + self.borrow_value_mut().reverse() } #[pymethod(magic)] fn getitem(&self, needle: Either, vm: &VirtualMachine) -> PyResult { - self.array.borrow().getitem(needle, vm) + self.borrow_value().getitem(needle, vm) } #[pymethod(magic)] @@ -380,15 +391,15 @@ impl PyArray { obj: PyObjectRef, vm: &VirtualMachine, ) -> PyResult<()> { - self.array.borrow_mut().setitem(needle, obj, vm) + self.borrow_value_mut().setitem(needle, obj, vm) } #[pymethod(name = "__eq__")] fn eq(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult { let lhs = class_or_notimplemented!(vm, Self, lhs); let rhs = class_or_notimplemented!(vm, Self, rhs); - let lhs = lhs.array.borrow(); - let rhs = rhs.array.borrow(); + let lhs = lhs.borrow_value(); + let rhs = rhs.borrow_value(); if lhs.len() != rhs.len() { Ok(vm.new_bool(false)) } else { @@ -404,7 +415,7 @@ impl PyArray { #[pymethod(name = "__len__")] fn len(&self) -> usize { - self.array.borrow().len() + self.borrow_value().len() } #[pymethod(name = "__iter__")] @@ -433,11 +444,10 @@ impl PyValue for PyArrayIter { impl PyArrayIter { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - if self.position.get() < self.array.array.borrow().len() { + if self.position.get() < self.array.borrow_value().len() { let ret = self .array - .array - .borrow() + .borrow_value() .getitem_by_idx(self.position.get(), vm) .unwrap()?; self.position.set(self.position.get() + 1); diff --git a/vm/src/stdlib/hashlib.rs b/vm/src/stdlib/hashlib.rs index 318994e450..88319a535c 100644 --- a/vm/src/stdlib/hashlib.rs +++ b/vm/src/stdlib/hashlib.rs @@ -2,10 +2,10 @@ use crate::function::{OptionalArg, PyFuncArgs}; use crate::obj::objbytes::{PyBytes, PyBytesRef}; use crate::obj::objstr::PyStringRef; use crate::obj::objtype::PyClassRef; -use crate::pyobject::{PyClassImpl, PyObjectRef, PyResult, PyValue}; +use crate::pyobject::{PyClassImpl, PyObjectRef, PyResult, PyValue, ThreadSafe}; use crate::vm::VirtualMachine; -use std::cell::RefCell; use std::fmt; +use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use blake2::{Blake2b, Blake2s}; use digest::DynDigest; @@ -17,9 +17,11 @@ use sha3::{Sha3_224, Sha3_256, Sha3_384, Sha3_512}; // TODO: , Shake128, Shake25 #[pyclass(name = "hasher")] struct PyHasher { name: String, - buffer: RefCell, + buffer: RwLock, } +impl ThreadSafe for PyHasher {} + impl fmt::Debug for PyHasher { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "hasher {}", self.name) @@ -37,10 +39,18 @@ impl PyHasher { fn new(name: &str, d: HashWrapper) -> Self { PyHasher { name: name.to_owned(), - buffer: RefCell::new(d), + buffer: RwLock::new(d), } } + fn borrow_value(&self) -> RwLockReadGuard<'_, HashWrapper> { + self.buffer.read().unwrap() + } + + fn borrow_value_mut(&self) -> RwLockWriteGuard<'_, HashWrapper> { + self.buffer.write().unwrap() + } + #[pyslot] fn tp_new(_cls: PyClassRef, _args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { Ok(PyHasher::new("md5", HashWrapper::md5()) @@ -55,12 +65,12 @@ impl PyHasher { #[pyproperty(name = "digest_size")] fn digest_size(&self, vm: &VirtualMachine) -> PyResult { - Ok(vm.ctx.new_int(self.buffer.borrow().digest_size())) + Ok(vm.ctx.new_int(self.borrow_value().digest_size())) } #[pymethod(name = "update")] fn update(&self, data: PyBytesRef, vm: &VirtualMachine) -> PyResult { - self.buffer.borrow_mut().input(data.get_value()); + self.borrow_value_mut().input(data.get_value()); Ok(vm.get_none()) } @@ -77,7 +87,7 @@ impl PyHasher { } fn get_digest(&self) -> Vec { - self.buffer.borrow().get_digest() + self.borrow_value().get_digest() } } @@ -200,15 +210,18 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { }) } +trait ThreadSafeDynDigest: DynDigest + Sync + Send {} +impl ThreadSafeDynDigest for T where T: DynDigest + Sync + Send {} + /// Generic wrapper patching around the hashing libraries. struct HashWrapper { - inner: Box, + inner: Box, } impl HashWrapper { fn new(d: D) -> Self where - D: DynDigest + Sized, + D: ThreadSafeDynDigest, { HashWrapper { inner: Box::new(d) } } @@ -279,7 +292,7 @@ impl HashWrapper { } fn get_digest(&self) -> Vec { - let cloned = self.inner.clone(); + let cloned = self.inner.box_clone(); cloned.result().to_vec() } }