Skip to content

Commit 73d9bac

Browse files
committed
Make PyAsyncGen, PyAsyncGenASend, PyAsyncGenAThrow ThreadSafe
1 parent ab628a3 commit 73d9bac

File tree

1 file changed

+33
-28
lines changed

1 file changed

+33
-28
lines changed

vm/src/obj/objasyncgenerator.rs

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@ use super::objtype::{self, PyClassRef};
44
use crate::exceptions::PyBaseExceptionRef;
55
use crate::frame::FrameRef;
66
use crate::function::OptionalArg;
7-
use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue};
7+
use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, ThreadSafe};
88
use crate::vm::VirtualMachine;
99

10-
use std::cell::Cell;
10+
use crossbeam_utils::atomic::AtomicCell;
1111

1212
#[pyclass(name = "async_generator")]
1313
#[derive(Debug)]
1414
pub struct PyAsyncGen {
1515
inner: Coro,
16-
running_async: Cell<bool>,
16+
running_async: AtomicCell<bool>,
1717
}
1818
pub type PyAsyncGenRef = PyRef<PyAsyncGen>;
19+
impl ThreadSafe for PyAsyncGen {}
1920

2021
impl PyValue for PyAsyncGen {
2122
fn class(vm: &VirtualMachine) -> PyClassRef {
@@ -32,7 +33,7 @@ impl PyAsyncGen {
3233
pub fn new(frame: FrameRef, vm: &VirtualMachine) -> PyAsyncGenRef {
3334
PyAsyncGen {
3435
inner: Coro::new(frame, Variant::AsyncGen),
35-
running_async: Cell::new(false),
36+
running_async: AtomicCell::new(false),
3637
}
3738
.into_ref(vm)
3839
}
@@ -57,7 +58,7 @@ impl PyAsyncGen {
5758
fn asend(zelf: PyRef<Self>, value: PyObjectRef, _vm: &VirtualMachine) -> PyAsyncGenASend {
5859
PyAsyncGenASend {
5960
ag: zelf,
60-
state: Cell::new(AwaitableState::Init),
61+
state: AtomicCell::new(AwaitableState::Init),
6162
value,
6263
}
6364
}
@@ -73,7 +74,7 @@ impl PyAsyncGen {
7374
PyAsyncGenAThrow {
7475
ag: zelf,
7576
aclose: false,
76-
state: Cell::new(AwaitableState::Init),
77+
state: AtomicCell::new(AwaitableState::Init),
7778
value: (
7879
exc_type,
7980
exc_val.unwrap_or_else(|| vm.get_none()),
@@ -87,7 +88,7 @@ impl PyAsyncGen {
8788
PyAsyncGenAThrow {
8889
ag: zelf,
8990
aclose: true,
90-
state: Cell::new(AwaitableState::Init),
91+
state: AtomicCell::new(AwaitableState::Init),
9192
value: (
9293
vm.ctx.exceptions.generator_exit.clone().into_object(),
9394
vm.get_none(),
@@ -131,13 +132,13 @@ impl PyAsyncGenWrappedValue {
131132
{
132133
ag.inner.closed.store(true);
133134
}
134-
ag.running_async.set(false);
135+
ag.running_async.store(false);
135136
}
136137
let val = val?;
137138

138139
match_class!(match val {
139140
val @ Self => {
140-
ag.running_async.set(false);
141+
ag.running_async.store(false);
141142
Err(vm.new_exception(
142143
vm.ctx.exceptions.stop_iteration.clone(),
143144
vec![val.0.clone()],
@@ -159,10 +160,12 @@ enum AwaitableState {
159160
#[derive(Debug)]
160161
struct PyAsyncGenASend {
161162
ag: PyAsyncGenRef,
162-
state: Cell<AwaitableState>,
163+
state: AtomicCell<AwaitableState>,
163164
value: PyObjectRef,
164165
}
165166

167+
impl ThreadSafe for PyAsyncGenASend {}
168+
166169
impl PyValue for PyAsyncGenASend {
167170
fn class(vm: &VirtualMachine) -> PyClassRef {
168171
vm.ctx.types.async_generator_asend.clone()
@@ -187,21 +190,21 @@ impl PyAsyncGenASend {
187190

188191
#[pymethod]
189192
fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
190-
let val = match self.state.get() {
193+
let val = match self.state.load() {
191194
AwaitableState::Closed => {
192195
return Err(vm.new_runtime_error(
193196
"cannot reuse already awaited __anext__()/asend()".to_owned(),
194197
))
195198
}
196199
AwaitableState::Iter => val, // already running, all good
197200
AwaitableState::Init => {
198-
if self.ag.running_async.get() {
201+
if self.ag.running_async.load() {
199202
return Err(vm.new_runtime_error(
200203
"anext(): asynchronous generator is already running".to_owned(),
201204
));
202205
}
203-
self.ag.running_async.set(true);
204-
self.state.set(AwaitableState::Iter);
206+
self.ag.running_async.store(true);
207+
self.state.store(AwaitableState::Iter);
205208
if vm.is_none(&val) {
206209
self.value.clone()
207210
} else {
@@ -225,7 +228,7 @@ impl PyAsyncGenASend {
225228
exc_tb: OptionalArg,
226229
vm: &VirtualMachine,
227230
) -> PyResult {
228-
if let AwaitableState::Closed = self.state.get() {
231+
if let AwaitableState::Closed = self.state.load() {
229232
return Err(
230233
vm.new_runtime_error("cannot reuse already awaited __anext__()/asend()".to_owned())
231234
);
@@ -246,7 +249,7 @@ impl PyAsyncGenASend {
246249

247250
#[pymethod]
248251
fn close(&self) {
249-
self.state.set(AwaitableState::Closed);
252+
self.state.store(AwaitableState::Closed);
250253
}
251254
}
252255

@@ -255,10 +258,12 @@ impl PyAsyncGenASend {
255258
struct PyAsyncGenAThrow {
256259
ag: PyAsyncGenRef,
257260
aclose: bool,
258-
state: Cell<AwaitableState>,
261+
state: AtomicCell<AwaitableState>,
259262
value: (PyObjectRef, PyObjectRef, PyObjectRef),
260263
}
261264

265+
impl ThreadSafe for PyAsyncGenAThrow {}
266+
262267
impl PyValue for PyAsyncGenAThrow {
263268
fn class(vm: &VirtualMachine) -> PyClassRef {
264269
vm.ctx.types.async_generator_athrow.clone()
@@ -283,14 +288,14 @@ impl PyAsyncGenAThrow {
283288

284289
#[pymethod]
285290
fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
286-
match self.state.get() {
291+
match self.state.load() {
287292
AwaitableState::Closed => {
288293
Err(vm
289294
.new_runtime_error("cannot reuse already awaited aclose()/athrow()".to_owned()))
290295
}
291296
AwaitableState::Init => {
292-
if self.ag.running_async.get() {
293-
self.state.set(AwaitableState::Closed);
297+
if self.ag.running_async.load() {
298+
self.state.store(AwaitableState::Closed);
294299
let msg = if self.aclose {
295300
"aclose(): asynchronous generator is already running"
296301
} else {
@@ -299,16 +304,16 @@ impl PyAsyncGenAThrow {
299304
return Err(vm.new_runtime_error(msg.to_owned()));
300305
}
301306
if self.ag.inner.closed() {
302-
self.state.set(AwaitableState::Closed);
307+
self.state.store(AwaitableState::Closed);
303308
return Err(vm.new_exception_empty(vm.ctx.exceptions.stop_iteration.clone()));
304309
}
305310
if !vm.is_none(&val) {
306311
return Err(vm.new_runtime_error(
307312
"can't send non-None value to a just-started async generator".to_owned(),
308313
));
309314
}
310-
self.state.set(AwaitableState::Iter);
311-
self.ag.running_async.set(true);
315+
self.state.store(AwaitableState::Iter);
316+
self.ag.running_async.store(true);
312317

313318
let (ty, val, tb) = self.value.clone();
314319
let ret = self.ag.inner.throw(ty, val, tb, vm);
@@ -368,21 +373,21 @@ impl PyAsyncGenAThrow {
368373

369374
#[pymethod]
370375
fn close(&self) {
371-
self.state.set(AwaitableState::Closed);
376+
self.state.store(AwaitableState::Closed);
372377
}
373378

374379
fn ignored_close(&self, res: &PyResult) -> bool {
375380
res.as_ref()
376381
.map_or(false, |v| v.payload_is::<PyAsyncGenWrappedValue>())
377382
}
378383
fn yield_close(&self, vm: &VirtualMachine) -> PyBaseExceptionRef {
379-
self.ag.running_async.set(false);
380-
self.state.set(AwaitableState::Closed);
384+
self.ag.running_async.store(false);
385+
self.state.store(AwaitableState::Closed);
381386
vm.new_runtime_error("async generator ignored GeneratorExit".to_owned())
382387
}
383388
fn check_error(&self, exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyBaseExceptionRef {
384-
self.ag.running_async.set(false);
385-
self.state.set(AwaitableState::Closed);
389+
self.ag.running_async.store(false);
390+
self.state.store(AwaitableState::Closed);
386391
if self.aclose
387392
&& (objtype::isinstance(&exc, &vm.ctx.exceptions.stop_async_iteration)
388393
|| objtype::isinstance(&exc, &vm.ctx.exceptions.generator_exit))

0 commit comments

Comments
 (0)