From bcd518633a783caa3de423fa1466ae15d8144332 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Thu, 9 Apr 2020 11:54:59 +0300 Subject: [PATCH 1/7] Add ThreadSafe trait --- vm/src/pyobject.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index aed76576f0..2ace93e5eb 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -1204,6 +1204,9 @@ pub trait PyValue: fmt::Debug + Sized + 'static { } } +// Temporary trait to follow the progress of threading conversion +pub trait ThreadSafe: Send + Sync {} + pub trait PyObjectPayload: Any + fmt::Debug + 'static { fn as_any(&self) -> &dyn Any; } From ad883375d75943c56dd4a740659ce14c6ad4ca73 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Thu, 9 Apr 2020 11:55:13 +0300 Subject: [PATCH 2/7] Mark PyInt as ThreadSafe --- vm/src/obj/objint.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index 42ff7677da..b633849c03 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -19,7 +19,7 @@ use crate::function::{OptionalArg, PyFuncArgs}; use crate::pyhash; use crate::pyobject::{ IdProtocol, IntoPyObject, PyArithmaticValue, PyClassImpl, PyComparisonValue, PyContext, - PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, + PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, TryFromObject, TypeProtocol, }; use crate::stdlib::array::PyArray; use crate::vm::VirtualMachine; @@ -44,6 +44,8 @@ pub struct PyInt { value: BigInt, } +impl ThreadSafe for PyInt {} + impl fmt::Display for PyInt { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { BigInt::fmt(&self.value, f) From cd42d04922927c05fd1abacdf5442c3d4db333c6 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Thu, 9 Apr 2020 11:57:37 +0300 Subject: [PATCH 3/7] Mark PyFloat as ThreadSafe --- vm/src/obj/objfloat.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vm/src/obj/objfloat.rs b/vm/src/obj/objfloat.rs index cbb5196d79..b29e2cbc21 100644 --- a/vm/src/obj/objfloat.rs +++ b/vm/src/obj/objfloat.rs @@ -13,7 +13,7 @@ use crate::function::{OptionalArg, OptionalOption}; use crate::pyhash; use crate::pyobject::{ IntoPyObject, PyArithmaticValue::*, PyClassImpl, PyComparisonValue, PyContext, PyObjectRef, - PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, + PyRef, PyResult, PyValue, ThreadSafe, TryFromObject, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -24,6 +24,8 @@ pub struct PyFloat { value: f64, } +impl ThreadSafe for PyFloat {} + impl PyFloat { pub fn to_f64(self) -> f64 { self.value From 2c5237bec86f321c94871da3471aa43815c7d0c7 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Thu, 9 Apr 2020 11:58:45 +0300 Subject: [PATCH 4/7] Mark PyComplex as ThreadSafe --- vm/src/obj/objcomplex.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index c428820470..6a05ebc776 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -7,7 +7,7 @@ use super::objtype::PyClassRef; use crate::function::OptionalArg; use crate::pyhash; use crate::pyobject::{ - IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, + IntoPyObject, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe, }; use crate::vm::VirtualMachine; @@ -19,6 +19,9 @@ use crate::vm::VirtualMachine; pub struct PyComplex { value: Complex64, } + +impl ThreadSafe for PyComplex {} + type PyComplexRef = PyRef; impl PyValue for PyComplex { From 6e89c6e014f48d167b519bb9157efaaf72b69052 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Thu, 9 Apr 2020 12:04:26 +0300 Subject: [PATCH 5/7] Mark PyByteInner and PyBytes as ThreadSafe --- vm/src/obj/objbyteinner.rs | 5 ++++- vm/src/obj/objbytes.rs | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/vm/src/obj/objbyteinner.rs b/vm/src/obj/objbyteinner.rs index 47bceeb0c7..a24dd7e380 100644 --- a/vm/src/obj/objbyteinner.rs +++ b/vm/src/obj/objbyteinner.rs @@ -18,7 +18,8 @@ use super::objtuple::PyTupleRef; use crate::function::OptionalArg; use crate::pyhash; use crate::pyobject::{ - Either, PyComparisonValue, PyIterable, PyObjectRef, PyResult, TryFromObject, TypeProtocol, + Either, PyComparisonValue, PyIterable, PyObjectRef, PyResult, ThreadSafe, TryFromObject, + TypeProtocol, }; use crate::vm::VirtualMachine; @@ -27,6 +28,8 @@ pub struct PyByteInner { pub elements: Vec, } +impl ThreadSafe for PyByteInner {} + impl TryFromObject for PyByteInner { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { match_class!(match obj { diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index 18d8a7698d..7e277e4a4a 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -21,7 +21,7 @@ use crate::pyobject::{ Either, IntoPyObject, PyArithmaticValue::{self, *}, PyClassImpl, PyComparisonValue, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, - TryFromObject, TypeProtocol, + ThreadSafe, TryFromObject, TypeProtocol, }; use crate::vm::VirtualMachine; use std::str::FromStr; @@ -40,6 +40,9 @@ use std::str::FromStr; pub struct PyBytes { inner: PyByteInner, } + +impl ThreadSafe for PyBytes {} + pub type PyBytesRef = PyRef; impl PyBytes { From bcf875b5ca2712078d472c299990dbf9ef931820 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Thu, 9 Apr 2020 12:36:36 +0300 Subject: [PATCH 6/7] Make PyByteArray ThreadSafe --- vm/src/obj/objbytearray.rs | 167 ++++++++++++++++++++----------------- 1 file changed, 92 insertions(+), 75 deletions(-) diff --git a/vm/src/obj/objbytearray.rs b/vm/src/obj/objbytearray.rs index 5ecc2f98ca..0113ad957c 100644 --- a/vm/src/obj/objbytearray.rs +++ b/vm/src/obj/objbytearray.rs @@ -1,6 +1,7 @@ //! Implementation of the python bytearray object. -use std::cell::{Cell, RefCell}; +use std::cell::Cell; use std::convert::TryFrom; +use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use super::objbyteinner::{ ByteInnerExpandtabsOptions, ByteInnerFindOptions, ByteInnerNewOptions, ByteInnerPaddingOptions, @@ -18,7 +19,7 @@ use crate::function::OptionalArg; use crate::obj::objstr::do_cformat_string; use crate::pyobject::{ Either, PyClassImpl, PyComparisonValue, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, - PyValue, TryFromObject, TypeProtocol, + PyValue, ThreadSafe, TryFromObject, TypeProtocol, }; use crate::vm::VirtualMachine; use std::mem::size_of; @@ -36,31 +37,34 @@ use std::str::FromStr; /// - any object implementing the buffer API.\n \ /// - an integer"; #[pyclass(name = "bytearray")] -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct PyByteArray { - inner: RefCell, + inner: RwLock, } + +impl ThreadSafe for PyByteArray {} + pub type PyByteArrayRef = PyRef; impl PyByteArray { pub fn new(data: Vec) -> Self { PyByteArray { - inner: RefCell::new(PyByteInner { elements: data }), + inner: RwLock::new(PyByteInner { elements: data }), } } fn from_inner(inner: PyByteInner) -> Self { PyByteArray { - inner: RefCell::new(inner), + inner: RwLock::new(inner), } } - pub fn borrow_value(&self) -> std::cell::Ref<'_, PyByteInner> { - self.inner.borrow() + pub fn borrow_value(&self) -> RwLockReadGuard<'_, PyByteInner> { + self.inner.read().unwrap() } - pub fn borrow_value_mut(&self) -> std::cell::RefMut<'_, PyByteInner> { - self.inner.borrow_mut() + pub fn borrow_value_mut(&self) -> RwLockWriteGuard<'_, PyByteInner> { + self.inner.write().unwrap() } } @@ -100,42 +104,45 @@ impl PyByteArray { #[pymethod(name = "__repr__")] fn repr(&self) -> PyResult { - Ok(format!("bytearray(b'{}')", self.inner.borrow().repr()?)) + Ok(format!( + "bytearray(b'{}')", + self.inner.read().unwrap().repr()? + )) } #[pymethod(name = "__len__")] fn len(&self) -> usize { - self.inner.borrow().len() + self.inner.read().unwrap().len() } #[pymethod(name = "__sizeof__")] fn sizeof(&self) -> usize { - size_of::() + self.inner.borrow().len() * size_of::() + size_of::() + self.inner.read().unwrap().len() * size_of::() } #[pymethod(name = "__eq__")] fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { - self.inner.borrow().eq(other, vm) + self.inner.read().unwrap().eq(other, vm) } #[pymethod(name = "__ge__")] fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { - self.inner.borrow().ge(other, vm) + self.inner.read().unwrap().ge(other, vm) } #[pymethod(name = "__le__")] fn le(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { - self.inner.borrow().le(other, vm) + self.inner.read().unwrap().le(other, vm) } #[pymethod(name = "__gt__")] fn gt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { - self.inner.borrow().gt(other, vm) + self.inner.read().unwrap().gt(other, vm) } #[pymethod(name = "__lt__")] fn lt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { - self.inner.borrow().lt(other, vm) + self.inner.read().unwrap().lt(other, vm) } #[pymethod(name = "__hash__")] @@ -154,7 +161,7 @@ impl PyByteArray { #[pymethod(name = "__add__")] fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { if let Ok(other) = PyByteInner::try_from_object(vm, other) { - Ok(vm.ctx.new_bytearray(self.inner.borrow().add(other))) + Ok(vm.ctx.new_bytearray(self.inner.read().unwrap().add(other))) } else { Ok(vm.ctx.not_implemented()) } @@ -166,12 +173,12 @@ impl PyByteArray { needle: Either, vm: &VirtualMachine, ) -> PyResult { - self.inner.borrow().contains(needle, vm) + self.inner.read().unwrap().contains(needle, vm) } #[pymethod(name = "__getitem__")] fn getitem(&self, needle: Either, vm: &VirtualMachine) -> PyResult { - self.inner.borrow().getitem(needle, vm) + self.inner.read().unwrap().getitem(needle, vm) } #[pymethod(name = "__setitem__")] @@ -181,77 +188,77 @@ impl PyByteArray { value: PyObjectRef, vm: &VirtualMachine, ) -> PyResult { - self.inner.borrow_mut().setitem(needle, value, vm) + self.inner.write().unwrap().setitem(needle, value, vm) } #[pymethod(name = "__delitem__")] fn delitem(&self, needle: Either, vm: &VirtualMachine) -> PyResult<()> { - self.inner.borrow_mut().delitem(needle, vm) + self.inner.write().unwrap().delitem(needle, vm) } #[pymethod(name = "isalnum")] fn isalnum(&self) -> bool { - self.inner.borrow().isalnum() + self.inner.read().unwrap().isalnum() } #[pymethod(name = "isalpha")] fn isalpha(&self) -> bool { - self.inner.borrow().isalpha() + self.inner.read().unwrap().isalpha() } #[pymethod(name = "isascii")] fn isascii(&self) -> bool { - self.inner.borrow().isascii() + self.inner.read().unwrap().isascii() } #[pymethod(name = "isdigit")] fn isdigit(&self) -> bool { - self.inner.borrow().isdigit() + self.inner.read().unwrap().isdigit() } #[pymethod(name = "islower")] fn islower(&self) -> bool { - self.inner.borrow().islower() + self.inner.read().unwrap().islower() } #[pymethod(name = "isspace")] fn isspace(&self) -> bool { - self.inner.borrow().isspace() + self.inner.read().unwrap().isspace() } #[pymethod(name = "isupper")] fn isupper(&self) -> bool { - self.inner.borrow().isupper() + self.inner.read().unwrap().isupper() } #[pymethod(name = "istitle")] fn istitle(&self) -> bool { - self.inner.borrow().istitle() + self.inner.read().unwrap().istitle() } #[pymethod(name = "lower")] fn lower(&self) -> PyByteArray { - self.inner.borrow().lower().into() + self.inner.read().unwrap().lower().into() } #[pymethod(name = "upper")] fn upper(&self) -> PyByteArray { - self.inner.borrow().upper().into() + self.inner.read().unwrap().upper().into() } #[pymethod(name = "capitalize")] fn capitalize(&self) -> PyByteArray { - self.inner.borrow().capitalize().into() + self.inner.read().unwrap().capitalize().into() } #[pymethod(name = "swapcase")] fn swapcase(&self) -> PyByteArray { - self.inner.borrow().swapcase().into() + self.inner.read().unwrap().swapcase().into() } #[pymethod(name = "hex")] fn hex(&self) -> String { - self.inner.borrow().hex() + self.inner.read().unwrap().hex() } #[pymethod] @@ -265,7 +272,7 @@ impl PyByteArray { options: ByteInnerPaddingOptions, vm: &VirtualMachine, ) -> PyResult { - Ok(self.inner.borrow().center(options, vm)?.into()) + Ok(self.inner.read().unwrap().center(options, vm)?.into()) } #[pymethod(name = "ljust")] @@ -274,7 +281,7 @@ impl PyByteArray { options: ByteInnerPaddingOptions, vm: &VirtualMachine, ) -> PyResult { - Ok(self.inner.borrow().ljust(options, vm)?.into()) + Ok(self.inner.read().unwrap().ljust(options, vm)?.into()) } #[pymethod(name = "rjust")] @@ -283,17 +290,17 @@ impl PyByteArray { options: ByteInnerPaddingOptions, vm: &VirtualMachine, ) -> PyResult { - Ok(self.inner.borrow().rjust(options, vm)?.into()) + Ok(self.inner.read().unwrap().rjust(options, vm)?.into()) } #[pymethod(name = "count")] fn count(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - self.inner.borrow().count(options, vm) + self.inner.read().unwrap().count(options, vm) } #[pymethod(name = "join")] fn join(&self, iter: PyIterable, vm: &VirtualMachine) -> PyResult { - Ok(self.inner.borrow().join(iter, vm)?.into()) + Ok(self.inner.read().unwrap().join(iter, vm)?.into()) } #[pymethod(name = "endswith")] @@ -305,7 +312,8 @@ impl PyByteArray { vm: &VirtualMachine, ) -> PyResult { self.inner - .borrow() + .read() + .unwrap() .startsendswith(suffix, start, end, true, vm) } @@ -318,31 +326,32 @@ impl PyByteArray { vm: &VirtualMachine, ) -> PyResult { self.inner - .borrow() + .read() + .unwrap() .startsendswith(prefix, start, end, false, vm) } #[pymethod(name = "find")] fn find(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - let index = self.inner.borrow().find(options, false, vm)?; + let index = self.inner.read().unwrap().find(options, false, vm)?; Ok(index.map_or(-1, |v| v as isize)) } #[pymethod(name = "index")] fn index(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - let index = self.inner.borrow().find(options, false, vm)?; + let index = self.inner.read().unwrap().find(options, false, vm)?; index.ok_or_else(|| vm.new_value_error("substring not found".to_owned())) } #[pymethod(name = "rfind")] fn rfind(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - let index = self.inner.borrow().find(options, true, vm)?; + let index = self.inner.read().unwrap().find(options, true, vm)?; Ok(index.map_or(-1, |v| v as isize)) } #[pymethod(name = "rindex")] fn rindex(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - let index = self.inner.borrow().find(options, true, vm)?; + let index = self.inner.read().unwrap().find(options, true, vm)?; index.ok_or_else(|| vm.new_value_error("substring not found".to_owned())) } @@ -350,7 +359,7 @@ impl PyByteArray { fn remove(&self, x: PyIntRef, vm: &VirtualMachine) -> PyResult<()> { let x = x.as_bigint().byte_or(vm)?; - let bytes = &mut self.inner.borrow_mut().elements; + let bytes = &mut self.inner.write().unwrap().elements; let pos = bytes .iter() .position(|b| *b == x) @@ -367,14 +376,15 @@ impl PyByteArray { options: ByteInnerTranslateOptions, vm: &VirtualMachine, ) -> PyResult { - Ok(self.inner.borrow().translate(options, vm)?.into()) + Ok(self.inner.read().unwrap().translate(options, vm)?.into()) } #[pymethod(name = "strip")] fn strip(&self, chars: OptionalArg) -> PyResult { Ok(self .inner - .borrow() + .read() + .unwrap() .strip(chars, ByteInnerPosition::All)? .into()) } @@ -383,7 +393,8 @@ impl PyByteArray { fn lstrip(&self, chars: OptionalArg) -> PyResult { Ok(self .inner - .borrow() + .read() + .unwrap() .strip(chars, ByteInnerPosition::Left)? .into()) } @@ -392,7 +403,8 @@ impl PyByteArray { fn rstrip(&self, chars: OptionalArg) -> PyResult { Ok(self .inner - .borrow() + .read() + .unwrap() .strip(chars, ByteInnerPosition::Right)? .into()) } @@ -401,7 +413,8 @@ impl PyByteArray { fn split(&self, options: ByteInnerSplitOptions, vm: &VirtualMachine) -> PyResult { let as_bytes = self .inner - .borrow() + .read() + .unwrap() .split(options, false)? .iter() .map(|x| vm.ctx.new_bytearray(x.to_vec())) @@ -413,7 +426,8 @@ impl PyByteArray { fn rsplit(&self, options: ByteInnerSplitOptions, vm: &VirtualMachine) -> PyResult { let as_bytes = self .inner - .borrow() + .read() + .unwrap() .split(options, true)? .iter() .map(|x| vm.ctx.new_bytearray(x.to_vec())) @@ -425,7 +439,7 @@ impl PyByteArray { fn partition(&self, sep: PyByteInner, vm: &VirtualMachine) -> PyResult { // sep ALWAYS converted to bytearray even it's bytes or memoryview // so its ok to accept PyByteInner - let (left, right) = self.inner.borrow().partition(&sep, false)?; + let (left, right) = self.inner.read().unwrap().partition(&sep, false)?; Ok(vm.ctx.new_tuple(vec![ vm.ctx.new_bytearray(left), vm.ctx.new_bytearray(sep.elements), @@ -435,7 +449,7 @@ impl PyByteArray { #[pymethod(name = "rpartition")] fn rpartition(&self, sep: PyByteInner, vm: &VirtualMachine) -> PyResult { - let (left, right) = self.inner.borrow().partition(&sep, true)?; + let (left, right) = self.inner.read().unwrap().partition(&sep, true)?; Ok(vm.ctx.new_tuple(vec![ vm.ctx.new_bytearray(left), vm.ctx.new_bytearray(sep.elements), @@ -445,14 +459,15 @@ impl PyByteArray { #[pymethod(name = "expandtabs")] fn expandtabs(&self, options: ByteInnerExpandtabsOptions) -> PyByteArray { - self.inner.borrow().expandtabs(options).into() + self.inner.read().unwrap().expandtabs(options).into() } #[pymethod(name = "splitlines")] fn splitlines(&self, options: ByteInnerSplitlinesOptions, vm: &VirtualMachine) -> PyResult { let as_bytes = self .inner - .borrow() + .read() + .unwrap() .splitlines(options) .iter() .map(|x| vm.ctx.new_bytearray(x.to_vec())) @@ -462,7 +477,7 @@ impl PyByteArray { #[pymethod(name = "zfill")] fn zfill(&self, width: PyIntRef) -> PyByteArray { - self.inner.borrow().zfill(width).into() + self.inner.read().unwrap().zfill(width).into() } #[pymethod(name = "replace")] @@ -472,23 +487,24 @@ impl PyByteArray { new: PyByteInner, count: OptionalArg, ) -> PyResult { - Ok(self.inner.borrow().replace(old, new, count)?.into()) + Ok(self.inner.read().unwrap().replace(old, new, count)?.into()) } #[pymethod(name = "clear")] fn clear(&self) { - self.inner.borrow_mut().elements.clear(); + self.inner.write().unwrap().elements.clear(); } #[pymethod(name = "copy")] fn copy(&self) -> PyByteArray { - self.inner.borrow().elements.clone().into() + self.inner.read().unwrap().elements.clone().into() } #[pymethod(name = "append")] fn append(&self, x: PyIntRef, vm: &VirtualMachine) -> PyResult<()> { self.inner - .borrow_mut() + .write() + .unwrap() .elements .push(x.as_bigint().byte_or(vm)?); Ok(()) @@ -496,7 +512,7 @@ impl PyByteArray { #[pymethod(name = "extend")] fn extend(&self, iterable_of_ints: PyIterable, vm: &VirtualMachine) -> PyResult<()> { - let mut inner = self.inner.borrow_mut(); + let mut inner = self.inner.write().unwrap(); for x in iterable_of_ints.iter(vm)? { let x = x?; @@ -510,7 +526,7 @@ impl PyByteArray { #[pymethod(name = "insert")] fn insert(&self, mut index: isize, x: PyIntRef, vm: &VirtualMachine) -> PyResult<()> { - let bytes = &mut self.inner.borrow_mut().elements; + let bytes = &mut self.inner.write().unwrap().elements; let len = isize::try_from(bytes.len()) .map_err(|_e| vm.new_overflow_error("bytearray too big".to_owned()))?; @@ -536,7 +552,7 @@ impl PyByteArray { #[pymethod(name = "pop")] fn pop(&self, vm: &VirtualMachine) -> PyResult { - let bytes = &mut self.inner.borrow_mut().elements; + let bytes = &mut self.inner.write().unwrap().elements; bytes .pop() .ok_or_else(|| vm.new_index_error("pop from empty bytearray".to_owned())) @@ -544,12 +560,12 @@ impl PyByteArray { #[pymethod(name = "title")] fn title(&self) -> PyByteArray { - self.inner.borrow().title().into() + self.inner.read().unwrap().title().into() } #[pymethod(name = "__mul__")] fn repeat(&self, n: isize) -> PyByteArray { - self.inner.borrow().repeat(n).into() + self.inner.read().unwrap().repeat(n).into() } #[pymethod(name = "__rmul__")] @@ -559,7 +575,7 @@ impl PyByteArray { #[pymethod(name = "__imul__")] fn irepeat(&self, n: isize) { - self.inner.borrow_mut().irepeat(n) + self.inner.write().unwrap().irepeat(n) } fn do_cformat( @@ -574,9 +590,10 @@ impl PyByteArray { #[pymethod(name = "__mod__")] fn modulo(&self, values: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let format_string = - CFormatString::from_str(std::str::from_utf8(&self.inner.borrow().elements).unwrap()) - .map_err(|err| vm.new_value_error(err.to_string()))?; + let format_string = CFormatString::from_str( + std::str::from_utf8(&self.inner.read().unwrap().elements).unwrap(), + ) + .map_err(|err| vm.new_value_error(err.to_string()))?; self.do_cformat(vm, format_string, values.clone()) } @@ -587,7 +604,7 @@ impl PyByteArray { #[pymethod(name = "reverse")] fn reverse(&self) -> PyResult<()> { - self.inner.borrow_mut().elements.reverse(); + self.inner.write().unwrap().elements.reverse(); Ok(()) } @@ -633,8 +650,8 @@ impl PyValue for PyByteArrayIterator { impl PyByteArrayIterator { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - if self.position.get() < self.bytearray.inner.borrow().len() { - let ret = self.bytearray.inner.borrow().elements[self.position.get()]; + if self.position.get() < self.bytearray.inner.read().unwrap().len() { + let ret = self.bytearray.inner.read().unwrap().elements[self.position.get()]; self.position.set(self.position.get() + 1); Ok(ret) } else { From e960c4dd3f5cae36aeaa6c0465ce6b5c1a7096a7 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Thu, 9 Apr 2020 19:15:21 +0300 Subject: [PATCH 7/7] Use utility functions --- vm/src/obj/objbytearray.rs | 152 ++++++++++++++++--------------------- 1 file changed, 65 insertions(+), 87 deletions(-) diff --git a/vm/src/obj/objbytearray.rs b/vm/src/obj/objbytearray.rs index 0113ad957c..634124c7f7 100644 --- a/vm/src/obj/objbytearray.rs +++ b/vm/src/obj/objbytearray.rs @@ -104,45 +104,42 @@ impl PyByteArray { #[pymethod(name = "__repr__")] fn repr(&self) -> PyResult { - Ok(format!( - "bytearray(b'{}')", - self.inner.read().unwrap().repr()? - )) + Ok(format!("bytearray(b'{}')", self.borrow_value().repr()?)) } #[pymethod(name = "__len__")] fn len(&self) -> usize { - self.inner.read().unwrap().len() + self.borrow_value().len() } #[pymethod(name = "__sizeof__")] fn sizeof(&self) -> usize { - size_of::() + self.inner.read().unwrap().len() * size_of::() + size_of::() + self.borrow_value().len() * size_of::() } #[pymethod(name = "__eq__")] fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { - self.inner.read().unwrap().eq(other, vm) + self.borrow_value().eq(other, vm) } #[pymethod(name = "__ge__")] fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { - self.inner.read().unwrap().ge(other, vm) + self.borrow_value().ge(other, vm) } #[pymethod(name = "__le__")] fn le(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { - self.inner.read().unwrap().le(other, vm) + self.borrow_value().le(other, vm) } #[pymethod(name = "__gt__")] fn gt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { - self.inner.read().unwrap().gt(other, vm) + self.borrow_value().gt(other, vm) } #[pymethod(name = "__lt__")] fn lt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyComparisonValue { - self.inner.read().unwrap().lt(other, vm) + self.borrow_value().lt(other, vm) } #[pymethod(name = "__hash__")] @@ -161,7 +158,7 @@ impl PyByteArray { #[pymethod(name = "__add__")] fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { if let Ok(other) = PyByteInner::try_from_object(vm, other) { - Ok(vm.ctx.new_bytearray(self.inner.read().unwrap().add(other))) + Ok(vm.ctx.new_bytearray(self.borrow_value().add(other))) } else { Ok(vm.ctx.not_implemented()) } @@ -173,12 +170,12 @@ impl PyByteArray { needle: Either, vm: &VirtualMachine, ) -> PyResult { - self.inner.read().unwrap().contains(needle, vm) + self.borrow_value().contains(needle, vm) } #[pymethod(name = "__getitem__")] fn getitem(&self, needle: Either, vm: &VirtualMachine) -> PyResult { - self.inner.read().unwrap().getitem(needle, vm) + self.borrow_value().getitem(needle, vm) } #[pymethod(name = "__setitem__")] @@ -188,77 +185,77 @@ impl PyByteArray { value: PyObjectRef, vm: &VirtualMachine, ) -> PyResult { - self.inner.write().unwrap().setitem(needle, value, vm) + self.borrow_value_mut().setitem(needle, value, vm) } #[pymethod(name = "__delitem__")] fn delitem(&self, needle: Either, vm: &VirtualMachine) -> PyResult<()> { - self.inner.write().unwrap().delitem(needle, vm) + self.borrow_value_mut().delitem(needle, vm) } #[pymethod(name = "isalnum")] fn isalnum(&self) -> bool { - self.inner.read().unwrap().isalnum() + self.borrow_value().isalnum() } #[pymethod(name = "isalpha")] fn isalpha(&self) -> bool { - self.inner.read().unwrap().isalpha() + self.borrow_value().isalpha() } #[pymethod(name = "isascii")] fn isascii(&self) -> bool { - self.inner.read().unwrap().isascii() + self.borrow_value().isascii() } #[pymethod(name = "isdigit")] fn isdigit(&self) -> bool { - self.inner.read().unwrap().isdigit() + self.borrow_value().isdigit() } #[pymethod(name = "islower")] fn islower(&self) -> bool { - self.inner.read().unwrap().islower() + self.borrow_value().islower() } #[pymethod(name = "isspace")] fn isspace(&self) -> bool { - self.inner.read().unwrap().isspace() + self.borrow_value().isspace() } #[pymethod(name = "isupper")] fn isupper(&self) -> bool { - self.inner.read().unwrap().isupper() + self.borrow_value().isupper() } #[pymethod(name = "istitle")] fn istitle(&self) -> bool { - self.inner.read().unwrap().istitle() + self.borrow_value().istitle() } #[pymethod(name = "lower")] fn lower(&self) -> PyByteArray { - self.inner.read().unwrap().lower().into() + self.borrow_value().lower().into() } #[pymethod(name = "upper")] fn upper(&self) -> PyByteArray { - self.inner.read().unwrap().upper().into() + self.borrow_value().upper().into() } #[pymethod(name = "capitalize")] fn capitalize(&self) -> PyByteArray { - self.inner.read().unwrap().capitalize().into() + self.borrow_value().capitalize().into() } #[pymethod(name = "swapcase")] fn swapcase(&self) -> PyByteArray { - self.inner.read().unwrap().swapcase().into() + self.borrow_value().swapcase().into() } #[pymethod(name = "hex")] fn hex(&self) -> String { - self.inner.read().unwrap().hex() + self.borrow_value().hex() } #[pymethod] @@ -272,7 +269,7 @@ impl PyByteArray { options: ByteInnerPaddingOptions, vm: &VirtualMachine, ) -> PyResult { - Ok(self.inner.read().unwrap().center(options, vm)?.into()) + Ok(self.borrow_value().center(options, vm)?.into()) } #[pymethod(name = "ljust")] @@ -281,7 +278,7 @@ impl PyByteArray { options: ByteInnerPaddingOptions, vm: &VirtualMachine, ) -> PyResult { - Ok(self.inner.read().unwrap().ljust(options, vm)?.into()) + Ok(self.borrow_value().ljust(options, vm)?.into()) } #[pymethod(name = "rjust")] @@ -290,17 +287,17 @@ impl PyByteArray { options: ByteInnerPaddingOptions, vm: &VirtualMachine, ) -> PyResult { - Ok(self.inner.read().unwrap().rjust(options, vm)?.into()) + Ok(self.borrow_value().rjust(options, vm)?.into()) } #[pymethod(name = "count")] fn count(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - self.inner.read().unwrap().count(options, vm) + self.borrow_value().count(options, vm) } #[pymethod(name = "join")] fn join(&self, iter: PyIterable, vm: &VirtualMachine) -> PyResult { - Ok(self.inner.read().unwrap().join(iter, vm)?.into()) + Ok(self.borrow_value().join(iter, vm)?.into()) } #[pymethod(name = "endswith")] @@ -311,9 +308,7 @@ impl PyByteArray { end: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - self.inner - .read() - .unwrap() + self.borrow_value() .startsendswith(suffix, start, end, true, vm) } @@ -325,33 +320,31 @@ impl PyByteArray { end: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - self.inner - .read() - .unwrap() + self.borrow_value() .startsendswith(prefix, start, end, false, vm) } #[pymethod(name = "find")] fn find(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - let index = self.inner.read().unwrap().find(options, false, vm)?; + let index = self.borrow_value().find(options, false, vm)?; Ok(index.map_or(-1, |v| v as isize)) } #[pymethod(name = "index")] fn index(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - let index = self.inner.read().unwrap().find(options, false, vm)?; + let index = self.borrow_value().find(options, false, vm)?; index.ok_or_else(|| vm.new_value_error("substring not found".to_owned())) } #[pymethod(name = "rfind")] fn rfind(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - let index = self.inner.read().unwrap().find(options, true, vm)?; + let index = self.borrow_value().find(options, true, vm)?; Ok(index.map_or(-1, |v| v as isize)) } #[pymethod(name = "rindex")] fn rindex(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - let index = self.inner.read().unwrap().find(options, true, vm)?; + let index = self.borrow_value().find(options, true, vm)?; index.ok_or_else(|| vm.new_value_error("substring not found".to_owned())) } @@ -359,7 +352,7 @@ impl PyByteArray { fn remove(&self, x: PyIntRef, vm: &VirtualMachine) -> PyResult<()> { let x = x.as_bigint().byte_or(vm)?; - let bytes = &mut self.inner.write().unwrap().elements; + let bytes = &mut self.borrow_value_mut().elements; let pos = bytes .iter() .position(|b| *b == x) @@ -376,15 +369,13 @@ impl PyByteArray { options: ByteInnerTranslateOptions, vm: &VirtualMachine, ) -> PyResult { - Ok(self.inner.read().unwrap().translate(options, vm)?.into()) + Ok(self.borrow_value().translate(options, vm)?.into()) } #[pymethod(name = "strip")] fn strip(&self, chars: OptionalArg) -> PyResult { Ok(self - .inner - .read() - .unwrap() + .borrow_value() .strip(chars, ByteInnerPosition::All)? .into()) } @@ -392,9 +383,7 @@ impl PyByteArray { #[pymethod(name = "lstrip")] fn lstrip(&self, chars: OptionalArg) -> PyResult { Ok(self - .inner - .read() - .unwrap() + .borrow_value() .strip(chars, ByteInnerPosition::Left)? .into()) } @@ -402,9 +391,7 @@ impl PyByteArray { #[pymethod(name = "rstrip")] fn rstrip(&self, chars: OptionalArg) -> PyResult { Ok(self - .inner - .read() - .unwrap() + .borrow_value() .strip(chars, ByteInnerPosition::Right)? .into()) } @@ -412,9 +399,7 @@ impl PyByteArray { #[pymethod(name = "split")] fn split(&self, options: ByteInnerSplitOptions, vm: &VirtualMachine) -> PyResult { let as_bytes = self - .inner - .read() - .unwrap() + .borrow_value() .split(options, false)? .iter() .map(|x| vm.ctx.new_bytearray(x.to_vec())) @@ -425,9 +410,7 @@ impl PyByteArray { #[pymethod(name = "rsplit")] fn rsplit(&self, options: ByteInnerSplitOptions, vm: &VirtualMachine) -> PyResult { let as_bytes = self - .inner - .read() - .unwrap() + .borrow_value() .split(options, true)? .iter() .map(|x| vm.ctx.new_bytearray(x.to_vec())) @@ -439,7 +422,7 @@ impl PyByteArray { fn partition(&self, sep: PyByteInner, vm: &VirtualMachine) -> PyResult { // sep ALWAYS converted to bytearray even it's bytes or memoryview // so its ok to accept PyByteInner - let (left, right) = self.inner.read().unwrap().partition(&sep, false)?; + let (left, right) = self.borrow_value().partition(&sep, false)?; Ok(vm.ctx.new_tuple(vec![ vm.ctx.new_bytearray(left), vm.ctx.new_bytearray(sep.elements), @@ -449,7 +432,7 @@ impl PyByteArray { #[pymethod(name = "rpartition")] fn rpartition(&self, sep: PyByteInner, vm: &VirtualMachine) -> PyResult { - let (left, right) = self.inner.read().unwrap().partition(&sep, true)?; + let (left, right) = self.borrow_value().partition(&sep, true)?; Ok(vm.ctx.new_tuple(vec![ vm.ctx.new_bytearray(left), vm.ctx.new_bytearray(sep.elements), @@ -459,15 +442,13 @@ impl PyByteArray { #[pymethod(name = "expandtabs")] fn expandtabs(&self, options: ByteInnerExpandtabsOptions) -> PyByteArray { - self.inner.read().unwrap().expandtabs(options).into() + self.borrow_value().expandtabs(options).into() } #[pymethod(name = "splitlines")] fn splitlines(&self, options: ByteInnerSplitlinesOptions, vm: &VirtualMachine) -> PyResult { let as_bytes = self - .inner - .read() - .unwrap() + .borrow_value() .splitlines(options) .iter() .map(|x| vm.ctx.new_bytearray(x.to_vec())) @@ -477,7 +458,7 @@ impl PyByteArray { #[pymethod(name = "zfill")] fn zfill(&self, width: PyIntRef) -> PyByteArray { - self.inner.read().unwrap().zfill(width).into() + self.borrow_value().zfill(width).into() } #[pymethod(name = "replace")] @@ -487,24 +468,22 @@ impl PyByteArray { new: PyByteInner, count: OptionalArg, ) -> PyResult { - Ok(self.inner.read().unwrap().replace(old, new, count)?.into()) + Ok(self.borrow_value().replace(old, new, count)?.into()) } #[pymethod(name = "clear")] fn clear(&self) { - self.inner.write().unwrap().elements.clear(); + self.borrow_value_mut().elements.clear(); } #[pymethod(name = "copy")] fn copy(&self) -> PyByteArray { - self.inner.read().unwrap().elements.clone().into() + self.borrow_value().elements.clone().into() } #[pymethod(name = "append")] fn append(&self, x: PyIntRef, vm: &VirtualMachine) -> PyResult<()> { - self.inner - .write() - .unwrap() + self.borrow_value_mut() .elements .push(x.as_bigint().byte_or(vm)?); Ok(()) @@ -512,7 +491,7 @@ impl PyByteArray { #[pymethod(name = "extend")] fn extend(&self, iterable_of_ints: PyIterable, vm: &VirtualMachine) -> PyResult<()> { - let mut inner = self.inner.write().unwrap(); + let mut inner = self.borrow_value_mut(); for x in iterable_of_ints.iter(vm)? { let x = x?; @@ -526,7 +505,7 @@ impl PyByteArray { #[pymethod(name = "insert")] fn insert(&self, mut index: isize, x: PyIntRef, vm: &VirtualMachine) -> PyResult<()> { - let bytes = &mut self.inner.write().unwrap().elements; + let bytes = &mut self.borrow_value_mut().elements; let len = isize::try_from(bytes.len()) .map_err(|_e| vm.new_overflow_error("bytearray too big".to_owned()))?; @@ -552,7 +531,7 @@ impl PyByteArray { #[pymethod(name = "pop")] fn pop(&self, vm: &VirtualMachine) -> PyResult { - let bytes = &mut self.inner.write().unwrap().elements; + let bytes = &mut self.borrow_value_mut().elements; bytes .pop() .ok_or_else(|| vm.new_index_error("pop from empty bytearray".to_owned())) @@ -560,12 +539,12 @@ impl PyByteArray { #[pymethod(name = "title")] fn title(&self) -> PyByteArray { - self.inner.read().unwrap().title().into() + self.borrow_value().title().into() } #[pymethod(name = "__mul__")] fn repeat(&self, n: isize) -> PyByteArray { - self.inner.read().unwrap().repeat(n).into() + self.borrow_value().repeat(n).into() } #[pymethod(name = "__rmul__")] @@ -575,7 +554,7 @@ impl PyByteArray { #[pymethod(name = "__imul__")] fn irepeat(&self, n: isize) { - self.inner.write().unwrap().irepeat(n) + self.borrow_value_mut().irepeat(n) } fn do_cformat( @@ -590,10 +569,9 @@ impl PyByteArray { #[pymethod(name = "__mod__")] fn modulo(&self, values: PyObjectRef, vm: &VirtualMachine) -> PyResult { - let format_string = CFormatString::from_str( - std::str::from_utf8(&self.inner.read().unwrap().elements).unwrap(), - ) - .map_err(|err| vm.new_value_error(err.to_string()))?; + let format_string = + CFormatString::from_str(std::str::from_utf8(&self.borrow_value().elements).unwrap()) + .map_err(|err| vm.new_value_error(err.to_string()))?; self.do_cformat(vm, format_string, values.clone()) } @@ -604,7 +582,7 @@ impl PyByteArray { #[pymethod(name = "reverse")] fn reverse(&self) -> PyResult<()> { - self.inner.write().unwrap().elements.reverse(); + self.borrow_value_mut().elements.reverse(); Ok(()) } @@ -650,8 +628,8 @@ impl PyValue for PyByteArrayIterator { impl PyByteArrayIterator { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - if self.position.get() < self.bytearray.inner.read().unwrap().len() { - let ret = self.bytearray.inner.read().unwrap().elements[self.position.get()]; + if self.position.get() < self.bytearray.borrow_value().len() { + let ret = self.bytearray.borrow_value().elements[self.position.get()]; self.position.set(self.position.get() + 1); Ok(ret) } else {