Skip to content

Commit c8f4a91

Browse files
committed
Make PyHasher ThreadSafe
1 parent 094fc9d commit c8f4a91

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

vm/src/stdlib/hashlib.rs

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ use crate::function::{OptionalArg, PyFuncArgs};
22
use crate::obj::objbytes::{PyBytes, PyBytesRef};
33
use crate::obj::objstr::PyStringRef;
44
use crate::obj::objtype::PyClassRef;
5-
use crate::pyobject::{PyClassImpl, PyObjectRef, PyResult, PyValue};
5+
use crate::pyobject::{PyClassImpl, PyObjectRef, PyResult, PyValue, ThreadSafe};
66
use crate::vm::VirtualMachine;
7-
use std::cell::RefCell;
87
use std::fmt;
8+
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
99

1010
use blake2::{Blake2b, Blake2s};
1111
use digest::DynDigest;
@@ -17,9 +17,11 @@ use sha3::{Sha3_224, Sha3_256, Sha3_384, Sha3_512}; // TODO: , Shake128, Shake25
1717
#[pyclass(name = "hasher")]
1818
struct PyHasher {
1919
name: String,
20-
buffer: RefCell<HashWrapper>,
20+
buffer: RwLock<HashWrapper>,
2121
}
2222

23+
impl ThreadSafe for PyHasher {}
24+
2325
impl fmt::Debug for PyHasher {
2426
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2527
write!(f, "hasher {}", self.name)
@@ -37,10 +39,18 @@ impl PyHasher {
3739
fn new(name: &str, d: HashWrapper) -> Self {
3840
PyHasher {
3941
name: name.to_owned(),
40-
buffer: RefCell::new(d),
42+
buffer: RwLock::new(d),
4143
}
4244
}
4345

46+
fn borrow_value(&self) -> RwLockReadGuard<'_, HashWrapper> {
47+
self.buffer.read().unwrap()
48+
}
49+
50+
fn borrow_value_mut(&self) -> RwLockWriteGuard<'_, HashWrapper> {
51+
self.buffer.write().unwrap()
52+
}
53+
4454
#[pyslot]
4555
fn tp_new(_cls: PyClassRef, _args: PyFuncArgs, vm: &VirtualMachine) -> PyResult {
4656
Ok(PyHasher::new("md5", HashWrapper::md5())
@@ -55,12 +65,12 @@ impl PyHasher {
5565

5666
#[pyproperty(name = "digest_size")]
5767
fn digest_size(&self, vm: &VirtualMachine) -> PyResult {
58-
Ok(vm.ctx.new_int(self.buffer.borrow().digest_size()))
68+
Ok(vm.ctx.new_int(self.borrow_value().digest_size()))
5969
}
6070

6171
#[pymethod(name = "update")]
6272
fn update(&self, data: PyBytesRef, vm: &VirtualMachine) -> PyResult {
63-
self.buffer.borrow_mut().input(data.get_value());
73+
self.borrow_value_mut().input(data.get_value());
6474
Ok(vm.get_none())
6575
}
6676

@@ -77,7 +87,7 @@ impl PyHasher {
7787
}
7888

7989
fn get_digest(&self) -> Vec<u8> {
80-
self.buffer.borrow().get_digest()
90+
self.borrow_value().get_digest()
8191
}
8292
}
8393

@@ -200,15 +210,18 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
200210
})
201211
}
202212

213+
trait ThreadSafeDynDigest: DynDigest + Sync + Send {}
214+
impl<T> ThreadSafeDynDigest for T where T: DynDigest + Sync + Send {}
215+
203216
/// Generic wrapper patching around the hashing libraries.
204217
struct HashWrapper {
205-
inner: Box<dyn DynDigest>,
218+
inner: Box<dyn ThreadSafeDynDigest>,
206219
}
207220

208221
impl HashWrapper {
209222
fn new<D: 'static>(d: D) -> Self
210223
where
211-
D: DynDigest + Sized,
224+
D: ThreadSafeDynDigest,
212225
{
213226
HashWrapper { inner: Box::new(d) }
214227
}
@@ -279,7 +292,7 @@ impl HashWrapper {
279292
}
280293

281294
fn get_digest(&self) -> Vec<u8> {
282-
let cloned = self.inner.clone();
295+
let cloned = self.inner.box_clone();
283296
cloned.result().to_vec()
284297
}
285298
}

0 commit comments

Comments
 (0)