Skip to content

Make array and hashlib thread safe #1860

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 38 additions & 28 deletions vm/src/stdlib/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ use crate::obj::objtype::PyClassRef;
use crate::obj::{objbool, objiter};
use crate::pyobject::{
Either, IntoPyObject, PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult, PyValue,
TryFromObject,
ThreadSafe, TryFromObject,
};
use crate::VirtualMachine;

use std::cell::{Cell, RefCell};
use std::cell::Cell;
use std::fmt;
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};

struct ArrayTypeSpecifierError {
_priv: (),
Expand Down Expand Up @@ -223,8 +224,11 @@ def_array_enum!(
#[pyclass(name = "array")]
#[derive(Debug)]
pub struct PyArray {
array: RefCell<ArrayContentType>,
array: RwLock<ArrayContentType>,
}

impl ThreadSafe for PyArray {}

pub type PyArrayRef = PyRef<PyArray>;

impl PyValue for PyArray {
Expand All @@ -235,6 +239,14 @@ impl PyValue for PyArray {

#[pyimpl(flags(BASETYPE))]
impl PyArray {
fn borrow_value(&self) -> RwLockReadGuard<'_, ArrayContentType> {
self.array.read().unwrap()
}

fn borrow_value_mut(&self) -> RwLockWriteGuard<'_, ArrayContentType> {
self.array.write().unwrap()
}

#[pyslot]
fn tp_new(
cls: PyClassRef,
Expand All @@ -253,7 +265,7 @@ impl PyArray {
let array =
ArrayContentType::from_char(spec).map_err(|err| vm.new_value_error(err.to_string()))?;
let zelf = PyArray {
array: RefCell::new(array),
array: RwLock::new(array),
};
if let OptionalArg::Present(init) = init {
zelf.extend(init, vm)?;
Expand All @@ -263,33 +275,33 @@ impl PyArray {

#[pyproperty]
fn typecode(&self) -> String {
self.array.borrow().typecode().to_string()
self.borrow_value().typecode().to_string()
}

#[pyproperty]
fn itemsize(&self) -> usize {
self.array.borrow().itemsize()
self.borrow_value().itemsize()
}

#[pymethod]
fn append(&self, x: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.array.borrow_mut().push(x, vm)
self.borrow_value_mut().push(x, vm)
}

#[pymethod]
fn buffer_info(&self) -> (usize, usize) {
let array = self.array.borrow();
let array = self.borrow_value();
(array.addr(), array.len())
}

#[pymethod]
fn count(&self, x: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
self.array.borrow().count(x, vm)
self.borrow_value().count(x, vm)
}

#[pymethod]
fn extend(&self, iter: PyIterable, vm: &VirtualMachine) -> PyResult<()> {
let mut array = self.array.borrow_mut();
let mut array = self.borrow_value_mut();
for elem in iter.iter(vm)? {
array.push(elem?, vm)?;
}
Expand All @@ -299,20 +311,19 @@ impl PyArray {
#[pymethod]
fn frombytes(&self, b: PyBytesRef, vm: &VirtualMachine) -> PyResult<()> {
let b = b.get_value();
let itemsize = self.array.borrow().itemsize();
let itemsize = self.borrow_value().itemsize();
if b.len() % itemsize != 0 {
return Err(vm.new_value_error("bytes length not a multiple of item size".to_owned()));
}
if b.len() / itemsize > 0 {
self.array.borrow_mut().frombytes(&b);
self.borrow_value_mut().frombytes(&b);
}
Ok(())
}

#[pymethod]
fn index(&self, x: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
self.array
.borrow()
self.borrow_value()
.index(x, vm)?
.ok_or_else(|| vm.new_value_error("x not in array".to_owned()))
}
Expand All @@ -335,27 +346,27 @@ impl PyArray {
i
}
};
self.array.borrow_mut().insert(i, x, vm)
self.borrow_value_mut().insert(i, x, vm)
}

#[pymethod]
fn pop(&self, i: OptionalArg<isize>, vm: &VirtualMachine) -> PyResult {
if self.len() == 0 {
Err(vm.new_index_error("pop from empty array".to_owned()))
} else {
let i = self.array.borrow().idx(i.unwrap_or(-1), "pop", vm)?;
self.array.borrow_mut().pop(i, vm)
let i = self.borrow_value().idx(i.unwrap_or(-1), "pop", vm)?;
self.borrow_value_mut().pop(i, vm)
}
}

#[pymethod]
pub(crate) fn tobytes(&self) -> Vec<u8> {
self.array.borrow().tobytes()
self.borrow_value().tobytes()
}

#[pymethod]
fn tolist(&self, vm: &VirtualMachine) -> PyResult {
let array = self.array.borrow();
let array = self.borrow_value();
let mut v = Vec::with_capacity(array.len());
for obj in array.iter(vm) {
v.push(obj?);
Expand All @@ -365,12 +376,12 @@ impl PyArray {

#[pymethod]
fn reverse(&self) {
self.array.borrow_mut().reverse()
self.borrow_value_mut().reverse()
}

#[pymethod(magic)]
fn getitem(&self, needle: Either<isize, PySliceRef>, vm: &VirtualMachine) -> PyResult {
self.array.borrow().getitem(needle, vm)
self.borrow_value().getitem(needle, vm)
}

#[pymethod(magic)]
Expand All @@ -380,15 +391,15 @@ impl PyArray {
obj: PyObjectRef,
vm: &VirtualMachine,
) -> PyResult<()> {
self.array.borrow_mut().setitem(needle, obj, vm)
self.borrow_value_mut().setitem(needle, obj, vm)
}

#[pymethod(name = "__eq__")]
fn eq(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
let lhs = class_or_notimplemented!(vm, Self, lhs);
let rhs = class_or_notimplemented!(vm, Self, rhs);
let lhs = lhs.array.borrow();
let rhs = rhs.array.borrow();
let lhs = lhs.borrow_value();
let rhs = rhs.borrow_value();
if lhs.len() != rhs.len() {
Ok(vm.new_bool(false))
} else {
Expand All @@ -404,7 +415,7 @@ impl PyArray {

#[pymethod(name = "__len__")]
fn len(&self) -> usize {
self.array.borrow().len()
self.borrow_value().len()
}

#[pymethod(name = "__iter__")]
Expand Down Expand Up @@ -433,11 +444,10 @@ impl PyValue for PyArrayIter {
impl PyArrayIter {
#[pymethod(name = "__next__")]
fn next(&self, vm: &VirtualMachine) -> PyResult {
if self.position.get() < self.array.array.borrow().len() {
if self.position.get() < self.array.borrow_value().len() {
let ret = self
.array
.array
.borrow()
.borrow_value()
.getitem_by_idx(self.position.get(), vm)
.unwrap()?;
self.position.set(self.position.get() + 1);
Expand Down
33 changes: 23 additions & 10 deletions vm/src/stdlib/hashlib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use crate::function::{OptionalArg, PyFuncArgs};
use crate::obj::objbytes::{PyBytes, PyBytesRef};
use crate::obj::objstr::PyStringRef;
use crate::obj::objtype::PyClassRef;
use crate::pyobject::{PyClassImpl, PyObjectRef, PyResult, PyValue};
use crate::pyobject::{PyClassImpl, PyObjectRef, PyResult, PyValue, ThreadSafe};
use crate::vm::VirtualMachine;
use std::cell::RefCell;
use std::fmt;
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};

use blake2::{Blake2b, Blake2s};
use digest::DynDigest;
Expand All @@ -17,9 +17,11 @@ use sha3::{Sha3_224, Sha3_256, Sha3_384, Sha3_512}; // TODO: , Shake128, Shake25
#[pyclass(name = "hasher")]
struct PyHasher {
name: String,
buffer: RefCell<HashWrapper>,
buffer: RwLock<HashWrapper>,
}

impl ThreadSafe for PyHasher {}

impl fmt::Debug for PyHasher {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "hasher {}", self.name)
Expand All @@ -37,10 +39,18 @@ impl PyHasher {
fn new(name: &str, d: HashWrapper) -> Self {
PyHasher {
name: name.to_owned(),
buffer: RefCell::new(d),
buffer: RwLock::new(d),
}
}

fn borrow_value(&self) -> RwLockReadGuard<'_, HashWrapper> {
self.buffer.read().unwrap()
}

fn borrow_value_mut(&self) -> RwLockWriteGuard<'_, HashWrapper> {
self.buffer.write().unwrap()
}

#[pyslot]
fn tp_new(_cls: PyClassRef, _args: PyFuncArgs, vm: &VirtualMachine) -> PyResult {
Ok(PyHasher::new("md5", HashWrapper::md5())
Expand All @@ -55,12 +65,12 @@ impl PyHasher {

#[pyproperty(name = "digest_size")]
fn digest_size(&self, vm: &VirtualMachine) -> PyResult {
Ok(vm.ctx.new_int(self.buffer.borrow().digest_size()))
Ok(vm.ctx.new_int(self.borrow_value().digest_size()))
}

#[pymethod(name = "update")]
fn update(&self, data: PyBytesRef, vm: &VirtualMachine) -> PyResult {
self.buffer.borrow_mut().input(data.get_value());
self.borrow_value_mut().input(data.get_value());
Ok(vm.get_none())
}

Expand All @@ -77,7 +87,7 @@ impl PyHasher {
}

fn get_digest(&self) -> Vec<u8> {
self.buffer.borrow().get_digest()
self.borrow_value().get_digest()
}
}

Expand Down Expand Up @@ -200,15 +210,18 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
})
}

trait ThreadSafeDynDigest: DynDigest + Sync + Send {}
impl<T> ThreadSafeDynDigest for T where T: DynDigest + Sync + Send {}

/// Generic wrapper patching around the hashing libraries.
struct HashWrapper {
inner: Box<dyn DynDigest>,
inner: Box<dyn ThreadSafeDynDigest>,
}

impl HashWrapper {
fn new<D: 'static>(d: D) -> Self
where
D: DynDigest + Sized,
D: ThreadSafeDynDigest,
{
HashWrapper { inner: Box::new(d) }
}
Expand Down Expand Up @@ -279,7 +292,7 @@ impl HashWrapper {
}

fn get_digest(&self) -> Vec<u8> {
let cloned = self.inner.clone();
let cloned = self.inner.box_clone();
cloned.result().to_vec()
}
}