Skip to content

Commit 094fc9d

Browse files
committed
Make PyArray ThreadSafe
1 parent b728a41 commit 094fc9d

File tree

1 file changed

+38
-28
lines changed

1 file changed

+38
-28
lines changed

vm/src/stdlib/array.rs

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@ use crate::obj::objtype::PyClassRef;
66
use crate::obj::{objbool, objiter};
77
use crate::pyobject::{
88
Either, IntoPyObject, PyClassImpl, PyIterable, PyObjectRef, PyRef, PyResult, PyValue,
9-
TryFromObject,
9+
ThreadSafe, TryFromObject,
1010
};
1111
use crate::VirtualMachine;
1212

13-
use std::cell::{Cell, RefCell};
13+
use std::cell::Cell;
1414
use std::fmt;
15+
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
1516

1617
struct ArrayTypeSpecifierError {
1718
_priv: (),
@@ -223,8 +224,11 @@ def_array_enum!(
223224
#[pyclass(name = "array")]
224225
#[derive(Debug)]
225226
pub struct PyArray {
226-
array: RefCell<ArrayContentType>,
227+
array: RwLock<ArrayContentType>,
227228
}
229+
230+
impl ThreadSafe for PyArray {}
231+
228232
pub type PyArrayRef = PyRef<PyArray>;
229233

230234
impl PyValue for PyArray {
@@ -235,6 +239,14 @@ impl PyValue for PyArray {
235239

236240
#[pyimpl(flags(BASETYPE))]
237241
impl PyArray {
242+
fn borrow_value(&self) -> RwLockReadGuard<'_, ArrayContentType> {
243+
self.array.read().unwrap()
244+
}
245+
246+
fn borrow_value_mut(&self) -> RwLockWriteGuard<'_, ArrayContentType> {
247+
self.array.write().unwrap()
248+
}
249+
238250
#[pyslot]
239251
fn tp_new(
240252
cls: PyClassRef,
@@ -253,7 +265,7 @@ impl PyArray {
253265
let array =
254266
ArrayContentType::from_char(spec).map_err(|err| vm.new_value_error(err.to_string()))?;
255267
let zelf = PyArray {
256-
array: RefCell::new(array),
268+
array: RwLock::new(array),
257269
};
258270
if let OptionalArg::Present(init) = init {
259271
zelf.extend(init, vm)?;
@@ -263,33 +275,33 @@ impl PyArray {
263275

264276
#[pyproperty]
265277
fn typecode(&self) -> String {
266-
self.array.borrow().typecode().to_string()
278+
self.borrow_value().typecode().to_string()
267279
}
268280

269281
#[pyproperty]
270282
fn itemsize(&self) -> usize {
271-
self.array.borrow().itemsize()
283+
self.borrow_value().itemsize()
272284
}
273285

274286
#[pymethod]
275287
fn append(&self, x: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
276-
self.array.borrow_mut().push(x, vm)
288+
self.borrow_value_mut().push(x, vm)
277289
}
278290

279291
#[pymethod]
280292
fn buffer_info(&self) -> (usize, usize) {
281-
let array = self.array.borrow();
293+
let array = self.borrow_value();
282294
(array.addr(), array.len())
283295
}
284296

285297
#[pymethod]
286298
fn count(&self, x: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
287-
self.array.borrow().count(x, vm)
299+
self.borrow_value().count(x, vm)
288300
}
289301

290302
#[pymethod]
291303
fn extend(&self, iter: PyIterable, vm: &VirtualMachine) -> PyResult<()> {
292-
let mut array = self.array.borrow_mut();
304+
let mut array = self.borrow_value_mut();
293305
for elem in iter.iter(vm)? {
294306
array.push(elem?, vm)?;
295307
}
@@ -299,20 +311,19 @@ impl PyArray {
299311
#[pymethod]
300312
fn frombytes(&self, b: PyBytesRef, vm: &VirtualMachine) -> PyResult<()> {
301313
let b = b.get_value();
302-
let itemsize = self.array.borrow().itemsize();
314+
let itemsize = self.borrow_value().itemsize();
303315
if b.len() % itemsize != 0 {
304316
return Err(vm.new_value_error("bytes length not a multiple of item size".to_owned()));
305317
}
306318
if b.len() / itemsize > 0 {
307-
self.array.borrow_mut().frombytes(&b);
319+
self.borrow_value_mut().frombytes(&b);
308320
}
309321
Ok(())
310322
}
311323

312324
#[pymethod]
313325
fn index(&self, x: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
314-
self.array
315-
.borrow()
326+
self.borrow_value()
316327
.index(x, vm)?
317328
.ok_or_else(|| vm.new_value_error("x not in array".to_owned()))
318329
}
@@ -335,27 +346,27 @@ impl PyArray {
335346
i
336347
}
337348
};
338-
self.array.borrow_mut().insert(i, x, vm)
349+
self.borrow_value_mut().insert(i, x, vm)
339350
}
340351

341352
#[pymethod]
342353
fn pop(&self, i: OptionalArg<isize>, vm: &VirtualMachine) -> PyResult {
343354
if self.len() == 0 {
344355
Err(vm.new_index_error("pop from empty array".to_owned()))
345356
} else {
346-
let i = self.array.borrow().idx(i.unwrap_or(-1), "pop", vm)?;
347-
self.array.borrow_mut().pop(i, vm)
357+
let i = self.borrow_value().idx(i.unwrap_or(-1), "pop", vm)?;
358+
self.borrow_value_mut().pop(i, vm)
348359
}
349360
}
350361

351362
#[pymethod]
352363
pub(crate) fn tobytes(&self) -> Vec<u8> {
353-
self.array.borrow().tobytes()
364+
self.borrow_value().tobytes()
354365
}
355366

356367
#[pymethod]
357368
fn tolist(&self, vm: &VirtualMachine) -> PyResult {
358-
let array = self.array.borrow();
369+
let array = self.borrow_value();
359370
let mut v = Vec::with_capacity(array.len());
360371
for obj in array.iter(vm) {
361372
v.push(obj?);
@@ -365,12 +376,12 @@ impl PyArray {
365376

366377
#[pymethod]
367378
fn reverse(&self) {
368-
self.array.borrow_mut().reverse()
379+
self.borrow_value_mut().reverse()
369380
}
370381

371382
#[pymethod(magic)]
372383
fn getitem(&self, needle: Either<isize, PySliceRef>, vm: &VirtualMachine) -> PyResult {
373-
self.array.borrow().getitem(needle, vm)
384+
self.borrow_value().getitem(needle, vm)
374385
}
375386

376387
#[pymethod(magic)]
@@ -380,15 +391,15 @@ impl PyArray {
380391
obj: PyObjectRef,
381392
vm: &VirtualMachine,
382393
) -> PyResult<()> {
383-
self.array.borrow_mut().setitem(needle, obj, vm)
394+
self.borrow_value_mut().setitem(needle, obj, vm)
384395
}
385396

386397
#[pymethod(name = "__eq__")]
387398
fn eq(lhs: PyObjectRef, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
388399
let lhs = class_or_notimplemented!(vm, Self, lhs);
389400
let rhs = class_or_notimplemented!(vm, Self, rhs);
390-
let lhs = lhs.array.borrow();
391-
let rhs = rhs.array.borrow();
401+
let lhs = lhs.borrow_value();
402+
let rhs = rhs.borrow_value();
392403
if lhs.len() != rhs.len() {
393404
Ok(vm.new_bool(false))
394405
} else {
@@ -404,7 +415,7 @@ impl PyArray {
404415

405416
#[pymethod(name = "__len__")]
406417
fn len(&self) -> usize {
407-
self.array.borrow().len()
418+
self.borrow_value().len()
408419
}
409420

410421
#[pymethod(name = "__iter__")]
@@ -433,11 +444,10 @@ impl PyValue for PyArrayIter {
433444
impl PyArrayIter {
434445
#[pymethod(name = "__next__")]
435446
fn next(&self, vm: &VirtualMachine) -> PyResult {
436-
if self.position.get() < self.array.array.borrow().len() {
447+
if self.position.get() < self.array.borrow_value().len() {
437448
let ret = self
438449
.array
439-
.array
440-
.borrow()
450+
.borrow_value()
441451
.getitem_by_idx(self.position.get(), vm)
442452
.unwrap()?;
443453
self.position.set(self.position.get() + 1);

0 commit comments

Comments
 (0)