Skip to content

Commit ee0d6f6

Browse files
authored
Merge pull request #1860 from palaviv/theading-mark
Make array and hashlib thread safe
2 parents a12671d + c8f4a91 commit ee0d6f6

File tree

2 files changed

+61
-38
lines changed

2 files changed

+61
-38
lines changed

vm/src/stdlib/array.rs

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@ use crate::obj::objtype::PyClassRef;
66
use crate::obj::{objbool, objiter};
77
use crate::pyobject::{
88
Either, IntoPyObject, PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult, PyValue,
9-
TryFromObject,
9+
ThreadSafe, TryFromObject,
1010
};
1111
use crate::VirtualMachine;
1212

13-
use std::cell::{Cell, RefCell};
13+
use std::cell::Cell;
1414
use std::fmt;
15+
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
1516

1617
struct ArrayTypeSpecifierError {
1718
_priv: (),
@@ -223,8 +224,11 @@ def_array_enum!(
223224
#[pyclass(name = "array")]
224225
#[derive(Debug)]
225226
pub struct PyArray {
226-
array: RefCell<ArrayContentType>,
227+
array: RwLock<ArrayContentType>,
227228
}
229+
230+
impl ThreadSafe for PyArray {}
231+
228232
pub type PyArrayRef = PyRef<PyArray>;
229233

230234
impl PyValue for PyArray {
@@ -235,6 +239,14 @@ impl PyValue for PyArray {
235239

236240
#[pyimpl(flags(BASETYPE))]
237241
impl PyArray {
242+
fn borrow_value(&self) -> RwLockReadGuard<'_, ArrayContentType> {
243+
self.array.read().unwrap()
244+
}
245+
246+
fn borrow_value_mut(&self) -> RwLockWriteGuard<'_, ArrayContentType> {
247+
self.array.write().unwrap()
248+
}
249+
238250
#[pyslot]
239251
fn tp_new(
240252
cls: PyClassRef,
@@ -253,7 +265,7 @@ impl PyArray {
253265
let array =
254266
ArrayContentType::from_char(spec).map_err(|err| vm.new_value_error(err.to_string()))?;
255267
let zelf = PyArray {
256-
array: RefCell::new(array),
268+
array: RwLock::new(array),
257269
};
258270
if let OptionalArg::Present(init) = init {
259271
zelf.extend(init, vm)?;
@@ -263,33 +275,33 @@ impl PyArray {
263275

264276
#[pyproperty]
265277
fn typecode(&self) -> String {
266-
self.array.borrow().typecode().to_string()
278+
self.borrow_value().typecode().to_string()
267279
}
268280

269281
#[pyproperty]
270282
fn itemsize(&self) -> usize {
271-
self.array.borrow().itemsize()
283+
self.borrow_value().itemsize()
272284
}
273285

274286
#[pymethod]
275287
fn append(&self, x: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
276-
self.array.borrow_mut().push(x, vm)
288+
self.borrow_value_mut().push(x, vm)
277289
}
278290

279291
#[pymethod]
280292
fn buffer_info(&self) -> (usize, usize) {
281-
let array = self.array.borrow();
293+
let array = self.borrow_value();
282294
(array.addr(), array.len())
283295
}
284296

285297
#[pymethod]
286298
fn count(&self, x: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
287-
self.array.borrow().count(x, vm)
299+
self.borrow_value().count(x, vm)
288300
}
289301

290302
#[pymethod]
291303
fn extend(&self, iter: PyIterable, vm: &VirtualMachine) -> PyResult<()> {
292-
let mut array = self.array.borrow_mut();
304+
let mut array = self.borrow_value_mut();
293305
for elem in iter.iter(vm)? {
294306
array.push(elem?, vm)?;
295307
}
@@ -299,20 +311,19 @@ impl PyArray {
299311
#[pymethod]
300312
fn frombytes(&self, b: PyBytesRef, vm: &VirtualMachine) -> PyResult<()> {
301313
let b = b.get_value();
302-
let itemsize = self.array.borrow().itemsize();
314+
let itemsize = self.borrow_value().itemsize();
303315
if b.len() % itemsize != 0 {
304316
return Err(vm.new_value_error("bytes length not a multiple of item size".to_owned()));
305317
}
306318
if b.len() / itemsize > 0 {
307-
self.array.borrow_mut().frombytes(&b);
319+
self.borrow_value_mut().frombytes(&b);
308320
}
309321
Ok(())
310322
}
311323

312324
#[pymethod]
313325
fn index(&self, x: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
314-
self.array
315-
.borrow()
326+
self.borrow_value()
316327
.index(x, vm)?
317328
.ok_or_else(|| vm.new_value_error("x not in array".to_owned()))
318329
}
@@ -335,27 +346,27 @@ impl PyArray {
335346
i
336347
}
337348
};
338-
self.array.borrow_mut().insert(i, x, vm)
349+
self.borrow_value_mut().insert(i, x, vm)
339350
}
340351

341352
#[pymethod]
342353
fn pop(&self, i: OptionalArg<isize>, vm: &VirtualMachine) -> PyResult {
343354
if self.len() == 0 {
344355
Err(vm.new_index_error("pop from empty array".to_owned()))
345356
} else {
346-
let i = self.array.borrow().idx(i.unwrap_or(-1), "pop", vm)?;
347-
self.array.borrow_mut().pop(i, vm)
357+
let i = self.borrow_value().idx(i.unwrap_or(-1), "pop", vm)?;
358+
self.borrow_value_mut().pop(i, vm)
348359
}
349360
}
350361

351362
#[pymethod]
352363
pub(crate) fn tobytes(&self) -> Vec<u8> {
353-
self.array.borrow().tobytes()
364+
self.borrow_value().tobytes()
354365
}
355366

356367
#[pymethod]
357368
fn tolist(&self, vm: &VirtualMachine) -> PyResult {
358-
let array = self.array.borrow();
369+
let array = self.borrow_value();
359370
let mut v = Vec::with_capacity(array.len());
360371
for obj in array.iter(vm) {
361372
v.push(obj?);
@@ -365,12 +376,12 @@ impl PyArray {
365376

366377
#[pymethod]
367378
fn reverse(&self) {
368-
self.array.borrow_mut().reverse()
379+
self.borrow_value_mut().reverse()
369380
}
370381

371382
#[pymethod(magic)]
372383
fn getitem(&self, needle: Either<isize, PySliceRef>, vm: &VirtualMachine) -> PyResult {
373-
self.array.borrow().getitem(needle, vm)
384+
self.borrow_value().getitem(needle, vm)
374385
}
375386

376387
#[pymethod(magic)]
@@ -380,15 +391,15 @@ impl PyArray {
380391
obj: PyObjectRef,
381392
vm: &VirtualMachine,
382393
) -> PyResult<()> {
383-
self.array.borrow_mut().setitem(needle, obj, vm)
394+
self.borrow_value_mut().setitem(needle, obj, vm)
384395
}
385396

386397
#[pymethod(name = "__eq__")]
387398
fn eq(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
388399
let lhs = class_or_notimplemented!(vm, Self, lhs);
389400
let rhs = class_or_notimplemented!(vm, Self, rhs);
390-
let lhs = lhs.array.borrow();
391-
let rhs = rhs.array.borrow();
401+
let lhs = lhs.borrow_value();
402+
let rhs = rhs.borrow_value();
392403
if lhs.len() != rhs.len() {
393404
Ok(vm.new_bool(false))
394405
} else {
@@ -404,7 +415,7 @@ impl PyArray {
404415

405416
#[pymethod(name = "__len__")]
406417
fn len(&self) -> usize {
407-
self.array.borrow().len()
418+
self.borrow_value().len()
408419
}
409420

410421
#[pymethod(name = "__iter__")]
@@ -433,11 +444,10 @@ impl PyValue for PyArrayIter {
433444
impl PyArrayIter {
434445
#[pymethod(name = "__next__")]
435446
fn next(&self, vm: &VirtualMachine) -> PyResult {
436-
if self.position.get() < self.array.array.borrow().len() {
447+
if self.position.get() < self.array.borrow_value().len() {
437448
let ret = self
438449
.array
439-
.array
440-
.borrow()
450+
.borrow_value()
441451
.getitem_by_idx(self.position.get(), vm)
442452
.unwrap()?;
443453
self.position.set(self.position.get() + 1);

vm/src/stdlib/hashlib.rs

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ use crate::function::{OptionalArg, PyFuncArgs};
22
use crate::obj::objbytes::{PyBytes, PyBytesRef};
33
use crate::obj::objstr::PyStringRef;
44
use crate::obj::objtype::PyClassRef;
5-
use crate::pyobject::{PyClassImpl, PyObjectRef, PyResult, PyValue};
5+
use crate::pyobject::{PyClassImpl, PyObjectRef, PyResult, PyValue, ThreadSafe};
66
use crate::vm::VirtualMachine;
7-
use std::cell::RefCell;
87
use std::fmt;
8+
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
99

1010
use blake2::{Blake2b, Blake2s};
1111
use digest::DynDigest;
@@ -17,9 +17,11 @@ use sha3::{Sha3_224, Sha3_256, Sha3_384, Sha3_512}; // TODO: , Shake128, Shake25
1717
#[pyclass(name = "hasher")]
1818
struct PyHasher {
1919
name: String,
20-
buffer: RefCell<HashWrapper>,
20+
buffer: RwLock<HashWrapper>,
2121
}
2222

23+
impl ThreadSafe for PyHasher {}
24+
2325
impl fmt::Debug for PyHasher {
2426
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2527
write!(f, "hasher {}", self.name)
@@ -37,10 +39,18 @@ impl PyHasher {
3739
fn new(name: &str, d: HashWrapper) -> Self {
3840
PyHasher {
3941
name: name.to_owned(),
40-
buffer: RefCell::new(d),
42+
buffer: RwLock::new(d),
4143
}
4244
}
4345

46+
fn borrow_value(&self) -> RwLockReadGuard<'_, HashWrapper> {
47+
self.buffer.read().unwrap()
48+
}
49+
50+
fn borrow_value_mut(&self) -> RwLockWriteGuard<'_, HashWrapper> {
51+
self.buffer.write().unwrap()
52+
}
53+
4454
#[pyslot]
4555
fn tp_new(_cls: PyClassRef, _args: PyFuncArgs, vm: &VirtualMachine) -> PyResult {
4656
Ok(PyHasher::new("md5", HashWrapper::md5())
@@ -55,12 +65,12 @@ impl PyHasher {
5565

5666
#[pyproperty(name = "digest_size")]
5767
fn digest_size(&self, vm: &VirtualMachine) -> PyResult {
58-
Ok(vm.ctx.new_int(self.buffer.borrow().digest_size()))
68+
Ok(vm.ctx.new_int(self.borrow_value().digest_size()))
5969
}
6070

6171
#[pymethod(name = "update")]
6272
fn update(&self, data: PyBytesRef, vm: &VirtualMachine) -> PyResult {
63-
self.buffer.borrow_mut().input(data.get_value());
73+
self.borrow_value_mut().input(data.get_value());
6474
Ok(vm.get_none())
6575
}
6676

@@ -77,7 +87,7 @@ impl PyHasher {
7787
}
7888

7989
fn get_digest(&self) -> Vec<u8> {
80-
self.buffer.borrow().get_digest()
90+
self.borrow_value().get_digest()
8191
}
8292
}
8393

@@ -200,15 +210,18 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
200210
})
201211
}
202212

213+
trait ThreadSafeDynDigest: DynDigest + Sync + Send {}
214+
impl<T> ThreadSafeDynDigest for T where T: DynDigest + Sync + Send {}
215+
203216
/// Generic wrapper patching around the hashing libraries.
204217
struct HashWrapper {
205-
inner: Box<dyn DynDigest>,
218+
inner: Box<dyn ThreadSafeDynDigest>,
206219
}
207220

208221
impl HashWrapper {
209222
fn new<D: 'static>(d: D) -> Self
210223
where
211-
D: DynDigest + Sized,
224+
D: ThreadSafeDynDigest,
212225
{
213226
HashWrapper { inner: Box::new(d) }
214227
}
@@ -279,7 +292,7 @@ impl HashWrapper {
279292
}
280293

281294
fn get_digest(&self) -> Vec<u8> {
282-
let cloned = self.inner.clone();
295+
let cloned = self.inner.box_clone();
283296
cloned.result().to_vec()
284297
}
285298
}

0 commit comments

Comments
 (0)