Skip to content

Commit 15c88c7

Browse files
authored
Merge pull request #1913 from palaviv/coro-threading
Change more object to thread safe
2 parents 049412c + b7d8bd6 commit 15c88c7

File tree

5 files changed

+106
-84
lines changed

5 files changed

+106
-84
lines changed

vm/src/exceptions.rs

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,32 @@ use crate::obj::objtuple::{PyTuple, PyTupleRef};
66
use crate::obj::objtype::{self, PyClass, PyClassRef};
77
use crate::py_serde;
88
use crate::pyobject::{
9-
PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject,
10-
TypeProtocol,
9+
PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe,
10+
TryFromObject, TypeProtocol,
1111
};
1212
use crate::slots::PyTpFlags;
1313
use crate::types::create_type;
1414
use crate::VirtualMachine;
15+
1516
use itertools::Itertools;
16-
use std::cell::{Cell, RefCell};
1717
use std::fmt;
1818
use std::fs::File;
1919
use std::io::{self, BufRead, BufReader, Write};
20+
use std::sync::RwLock;
21+
22+
use crossbeam_utils::atomic::AtomicCell;
2023

2124
#[pyclass]
2225
pub struct PyBaseException {
23-
traceback: RefCell<Option<PyTracebackRef>>,
24-
cause: RefCell<Option<PyBaseExceptionRef>>,
25-
context: RefCell<Option<PyBaseExceptionRef>>,
26-
suppress_context: Cell<bool>,
27-
args: RefCell<PyTupleRef>,
26+
traceback: RwLock<Option<PyTracebackRef>>,
27+
cause: RwLock<Option<PyBaseExceptionRef>>,
28+
context: RwLock<Option<PyBaseExceptionRef>>,
29+
suppress_context: AtomicCell<bool>,
30+
args: RwLock<PyTupleRef>,
2831
}
2932

33+
impl ThreadSafe for PyBaseException {}
34+
3035
impl fmt::Debug for PyBaseException {
3136
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
3237
// TODO: implement more detailed, non-recursive Debug formatter
@@ -48,11 +53,11 @@ impl PyValue for PyBaseException {
4853
impl PyBaseException {
4954
pub(crate) fn new(args: Vec<PyObjectRef>, vm: &VirtualMachine) -> PyBaseException {
5055
PyBaseException {
51-
traceback: RefCell::new(None),
52-
cause: RefCell::new(None),
53-
context: RefCell::new(None),
54-
suppress_context: Cell::new(false),
55-
args: RefCell::new(PyTuple::from(args).into_ref(vm)),
56+
traceback: RwLock::new(None),
57+
cause: RwLock::new(None),
58+
context: RwLock::new(None),
59+
suppress_context: AtomicCell::new(false),
60+
args: RwLock::new(PyTuple::from(args).into_ref(vm)),
5661
}
5762
}
5863

@@ -63,65 +68,65 @@ impl PyBaseException {
6368

6469
#[pymethod(name = "__init__")]
6570
fn init(&self, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult<()> {
66-
self.args.replace(PyTuple::from(args.args).into_ref(vm));
71+
*self.args.write().unwrap() = PyTuple::from(args.args).into_ref(vm);
6772
Ok(())
6873
}
6974

7075
#[pyproperty]
7176
pub fn args(&self) -> PyTupleRef {
72-
self.args.borrow().clone()
77+
self.args.read().unwrap().clone()
7378
}
7479

7580
#[pyproperty(setter)]
7681
fn set_args(&self, args: PyIterable, vm: &VirtualMachine) -> PyResult<()> {
7782
let args = args.iter(vm)?.collect::<PyResult<Vec<_>>>()?;
78-
self.args.replace(PyTuple::from(args).into_ref(vm));
83+
*self.args.write().unwrap() = PyTuple::from(args).into_ref(vm);
7984
Ok(())
8085
}
8186

8287
#[pyproperty(name = "__traceback__")]
8388
pub fn traceback(&self) -> Option<PyTracebackRef> {
84-
self.traceback.borrow().clone()
89+
self.traceback.read().unwrap().clone()
8590
}
8691

8792
#[pyproperty(name = "__traceback__", setter)]
8893
pub fn set_traceback(&self, traceback: Option<PyTracebackRef>) {
89-
self.traceback.replace(traceback);
94+
*self.traceback.write().unwrap() = traceback;
9095
}
9196

9297
#[pyproperty(name = "__cause__")]
9398
pub fn cause(&self) -> Option<PyBaseExceptionRef> {
94-
self.cause.borrow().clone()
99+
self.cause.read().unwrap().clone()
95100
}
96101

97102
#[pyproperty(name = "__cause__", setter)]
98103
pub fn set_cause(&self, cause: Option<PyBaseExceptionRef>) {
99-
self.cause.replace(cause);
104+
*self.cause.write().unwrap() = cause;
100105
}
101106

102107
#[pyproperty(name = "__context__")]
103108
pub fn context(&self) -> Option<PyBaseExceptionRef> {
104-
self.context.borrow().clone()
109+
self.context.read().unwrap().clone()
105110
}
106111

107112
#[pyproperty(name = "__context__", setter)]
108113
pub fn set_context(&self, context: Option<PyBaseExceptionRef>) {
109-
self.context.replace(context);
114+
*self.context.write().unwrap() = context;
110115
}
111116

112117
#[pyproperty(name = "__suppress_context__")]
113118
fn get_suppress_context(&self) -> bool {
114-
self.suppress_context.get()
119+
self.suppress_context.load()
115120
}
116121

117122
#[pyproperty(name = "__suppress_context__", setter)]
118123
fn set_suppress_context(&self, suppress_context: bool) {
119-
self.suppress_context.set(suppress_context);
124+
self.suppress_context.store(suppress_context);
120125
}
121126

122127
#[pymethod]
123128
fn with_traceback(zelf: PyRef<Self>, tb: Option<PyTracebackRef>) -> PyResult {
124-
zelf.traceback.replace(tb);
129+
*zelf.traceback.write().unwrap() = tb;
125130
Ok(zelf.as_object().clone())
126131
}
127132

@@ -213,7 +218,7 @@ pub fn write_exception_inner<W: Write>(
213218
vm: &VirtualMachine,
214219
exc: &PyBaseExceptionRef,
215220
) -> io::Result<()> {
216-
if let Some(tb) = exc.traceback.borrow().clone() {
221+
if let Some(tb) = exc.traceback.read().unwrap().clone() {
217222
writeln!(output, "Traceback (most recent call last):")?;
218223
for tb in tb.iter() {
219224
write_traceback_entry(output, &tb)?;
@@ -605,7 +610,8 @@ fn none_getter(_obj: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
605610
fn make_arg_getter(idx: usize) -> impl Fn(PyBaseExceptionRef, &VirtualMachine) -> PyObjectRef {
606611
move |exc, vm| {
607612
exc.args
608-
.borrow()
613+
.read()
614+
.unwrap()
609615
.as_slice()
610616
.get(idx)
611617
.cloned()
@@ -716,7 +722,7 @@ impl serde::Serialize for SerializeException<'_> {
716722
"context",
717723
&self.exc.context().as_ref().map(|e| Self::new(self.vm, e)),
718724
)?;
719-
struc.serialize_field("suppress_context", &self.exc.suppress_context.get())?;
725+
struc.serialize_field("suppress_context", &self.exc.suppress_context.load())?;
720726

721727
let args = {
722728
struct Args<'vm>(&'vm VirtualMachine, PyTupleRef);

vm/src/obj/objasyncgenerator.rs

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@ use super::objtype::{self, PyClassRef};
44
use crate::exceptions::PyBaseExceptionRef;
55
use crate::frame::FrameRef;
66
use crate::function::OptionalArg;
7-
use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue};
7+
use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe};
88
use crate::vm::VirtualMachine;
99

10-
use std::cell::Cell;
10+
use crossbeam_utils::atomic::AtomicCell;
1111

1212
#[pyclass(name = "async_generator")]
1313
#[derive(Debug)]
1414
pub struct PyAsyncGen {
1515
inner: Coro,
16-
running_async: Cell<bool>,
16+
running_async: AtomicCell<bool>,
1717
}
1818
pub type PyAsyncGenRef = PyRef<PyAsyncGen>;
19+
impl ThreadSafe for PyAsyncGen {}
1920

2021
impl PyValue for PyAsyncGen {
2122
fn class(vm: &VirtualMachine) -> PyClassRef {
@@ -32,7 +33,7 @@ impl PyAsyncGen {
3233
pub fn new(frame: FrameRef, vm: &VirtualMachine) -> PyAsyncGenRef {
3334
PyAsyncGen {
3435
inner: Coro::new(frame, Variant::AsyncGen),
35-
running_async: Cell::new(false),
36+
running_async: AtomicCell::new(false),
3637
}
3738
.into_ref(vm)
3839
}
@@ -57,7 +58,7 @@ impl PyAsyncGen {
5758
fn asend(zelf: PyRef<Self>, value: PyObjectRef, _vm: &VirtualMachine) -> PyAsyncGenASend {
5859
PyAsyncGenASend {
5960
ag: zelf,
60-
state: Cell::new(AwaitableState::Init),
61+
state: AtomicCell::new(AwaitableState::Init),
6162
value,
6263
}
6364
}
@@ -73,7 +74,7 @@ impl PyAsyncGen {
7374
PyAsyncGenAThrow {
7475
ag: zelf,
7576
aclose: false,
76-
state: Cell::new(AwaitableState::Init),
77+
state: AtomicCell::new(AwaitableState::Init),
7778
value: (
7879
exc_type,
7980
exc_val.unwrap_or_else(|| vm.get_none()),
@@ -87,7 +88,7 @@ impl PyAsyncGen {
8788
PyAsyncGenAThrow {
8889
ag: zelf,
8990
aclose: true,
90-
state: Cell::new(AwaitableState::Init),
91+
state: AtomicCell::new(AwaitableState::Init),
9192
value: (
9293
vm.ctx.exceptions.generator_exit.clone().into_object(),
9394
vm.get_none(),
@@ -129,15 +130,15 @@ impl PyAsyncGenWrappedValue {
129130
if objtype::isinstance(&e, &vm.ctx.exceptions.stop_async_iteration)
130131
|| objtype::isinstance(&e, &vm.ctx.exceptions.generator_exit)
131132
{
132-
ag.inner.closed.set(true);
133+
ag.inner.closed.store(true);
133134
}
134-
ag.running_async.set(false);
135+
ag.running_async.store(false);
135136
}
136137
let val = val?;
137138

138139
match_class!(match val {
139140
val @ Self => {
140-
ag.running_async.set(false);
141+
ag.running_async.store(false);
141142
Err(vm.new_exception(
142143
vm.ctx.exceptions.stop_iteration.clone(),
143144
vec![val.0.clone()],
@@ -159,10 +160,12 @@ enum AwaitableState {
159160
#[derive(Debug)]
160161
struct PyAsyncGenASend {
161162
ag: PyAsyncGenRef,
162-
state: Cell<AwaitableState>,
163+
state: AtomicCell<AwaitableState>,
163164
value: PyObjectRef,
164165
}
165166

167+
impl ThreadSafe for PyAsyncGenASend {}
168+
166169
impl PyValue for PyAsyncGenASend {
167170
fn class(vm: &VirtualMachine) -> PyClassRef {
168171
vm.ctx.types.async_generator_asend.clone()
@@ -187,21 +190,21 @@ impl PyAsyncGenASend {
187190

188191
#[pymethod]
189192
fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
190-
let val = match self.state.get() {
193+
let val = match self.state.load() {
191194
AwaitableState::Closed => {
192195
return Err(vm.new_runtime_error(
193196
"cannot reuse already awaited __anext__()/asend()".to_owned(),
194197
))
195198
}
196199
AwaitableState::Iter => val, // already running, all good
197200
AwaitableState::Init => {
198-
if self.ag.running_async.get() {
201+
if self.ag.running_async.load() {
199202
return Err(vm.new_runtime_error(
200203
"anext(): asynchronous generator is already running".to_owned(),
201204
));
202205
}
203-
self.ag.running_async.set(true);
204-
self.state.set(AwaitableState::Iter);
206+
self.ag.running_async.store(true);
207+
self.state.store(AwaitableState::Iter);
205208
if vm.is_none(&val) {
206209
self.value.clone()
207210
} else {
@@ -225,7 +228,7 @@ impl PyAsyncGenASend {
225228
exc_tb: OptionalArg,
226229
vm: &VirtualMachine,
227230
) -> PyResult {
228-
if let AwaitableState::Closed = self.state.get() {
231+
if let AwaitableState::Closed = self.state.load() {
229232
return Err(
230233
vm.new_runtime_error("cannot reuse already awaited __anext__()/asend()".to_owned())
231234
);
@@ -246,7 +249,7 @@ impl PyAsyncGenASend {
246249

247250
#[pymethod]
248251
fn close(&self) {
249-
self.state.set(AwaitableState::Closed);
252+
self.state.store(AwaitableState::Closed);
250253
}
251254
}
252255

@@ -255,10 +258,12 @@ impl PyAsyncGenASend {
255258
struct PyAsyncGenAThrow {
256259
ag: PyAsyncGenRef,
257260
aclose: bool,
258-
state: Cell<AwaitableState>,
261+
state: AtomicCell<AwaitableState>,
259262
value: (PyObjectRef, PyObjectRef, PyObjectRef),
260263
}
261264

265+
impl ThreadSafe for PyAsyncGenAThrow {}
266+
262267
impl PyValue for PyAsyncGenAThrow {
263268
fn class(vm: &VirtualMachine) -> PyClassRef {
264269
vm.ctx.types.async_generator_athrow.clone()
@@ -283,14 +288,14 @@ impl PyAsyncGenAThrow {
283288

284289
#[pymethod]
285290
fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
286-
match self.state.get() {
291+
match self.state.load() {
287292
AwaitableState::Closed => {
288293
Err(vm
289294
.new_runtime_error("cannot reuse already awaited aclose()/athrow()".to_owned()))
290295
}
291296
AwaitableState::Init => {
292-
if self.ag.running_async.get() {
293-
self.state.set(AwaitableState::Closed);
297+
if self.ag.running_async.load() {
298+
self.state.store(AwaitableState::Closed);
294299
let msg = if self.aclose {
295300
"aclose(): asynchronous generator is already running"
296301
} else {
@@ -299,16 +304,16 @@ impl PyAsyncGenAThrow {
299304
return Err(vm.new_runtime_error(msg.to_owned()));
300305
}
301306
if self.ag.inner.closed() {
302-
self.state.set(AwaitableState::Closed);
307+
self.state.store(AwaitableState::Closed);
303308
return Err(vm.new_exception_empty(vm.ctx.exceptions.stop_iteration.clone()));
304309
}
305310
if !vm.is_none(&val) {
306311
return Err(vm.new_runtime_error(
307312
"can't send non-None value to a just-started async generator".to_owned(),
308313
));
309314
}
310-
self.state.set(AwaitableState::Iter);
311-
self.ag.running_async.set(true);
315+
self.state.store(AwaitableState::Iter);
316+
self.ag.running_async.store(true);
312317

313318
let (ty, val, tb) = self.value.clone();
314319
let ret = self.ag.inner.throw(ty, val, tb, vm);
@@ -368,21 +373,21 @@ impl PyAsyncGenAThrow {
368373

369374
#[pymethod]
370375
fn close(&self) {
371-
self.state.set(AwaitableState::Closed);
376+
self.state.store(AwaitableState::Closed);
372377
}
373378

374379
fn ignored_close(&self, res: &PyResult) -> bool {
375380
res.as_ref()
376381
.map_or(false, |v| v.payload_is::<PyAsyncGenWrappedValue>())
377382
}
378383
fn yield_close(&self, vm: &VirtualMachine) -> PyBaseExceptionRef {
379-
self.ag.running_async.set(false);
380-
self.state.set(AwaitableState::Closed);
384+
self.ag.running_async.store(false);
385+
self.state.store(AwaitableState::Closed);
381386
vm.new_runtime_error("async generator ignored GeneratorExit".to_owned())
382387
}
383388
fn check_error(&self, exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyBaseExceptionRef {
384-
self.ag.running_async.set(false);
385-
self.state.set(AwaitableState::Closed);
389+
self.ag.running_async.store(false);
390+
self.state.store(AwaitableState::Closed);
386391
if self.aclose
387392
&& (objtype::isinstance(&exc, &vm.ctx.exceptions.stop_async_iteration)
388393
|| objtype::isinstance(&exc, &vm.ctx.exceptions.generator_exit))

0 commit comments

Comments
 (0)