@@ -4,18 +4,19 @@ use super::objtype::{self, PyClassRef};
4
4
use crate :: exceptions:: PyBaseExceptionRef ;
5
5
use crate :: frame:: FrameRef ;
6
6
use crate :: function:: OptionalArg ;
7
- use crate :: pyobject:: { PyClassImpl , PyContext , PyObjectRef , PyRef , PyResult , PyValue } ;
7
+ use crate :: pyobject:: { PyClassImpl , PyContext , PyObjectRef , PyRef , PyResult , PyValue , ThreadSafe } ;
8
8
use crate :: vm:: VirtualMachine ;
9
9
10
- use std :: cell :: Cell ;
10
+ use crossbeam_utils :: atomic :: AtomicCell ;
11
11
12
12
#[ pyclass( name = "async_generator" ) ]
13
13
#[ derive( Debug ) ]
14
14
pub struct PyAsyncGen {
15
15
inner : Coro ,
16
- running_async : Cell < bool > ,
16
+ running_async : AtomicCell < bool > ,
17
17
}
18
18
pub type PyAsyncGenRef = PyRef < PyAsyncGen > ;
19
+ impl ThreadSafe for PyAsyncGen { }
19
20
20
21
impl PyValue for PyAsyncGen {
21
22
fn class ( vm : & VirtualMachine ) -> PyClassRef {
@@ -32,7 +33,7 @@ impl PyAsyncGen {
32
33
pub fn new ( frame : FrameRef , vm : & VirtualMachine ) -> PyAsyncGenRef {
33
34
PyAsyncGen {
34
35
inner : Coro :: new ( frame, Variant :: AsyncGen ) ,
35
- running_async : Cell :: new ( false ) ,
36
+ running_async : AtomicCell :: new ( false ) ,
36
37
}
37
38
. into_ref ( vm)
38
39
}
@@ -57,7 +58,7 @@ impl PyAsyncGen {
57
58
fn asend ( zelf : PyRef < Self > , value : PyObjectRef , _vm : & VirtualMachine ) -> PyAsyncGenASend {
58
59
PyAsyncGenASend {
59
60
ag : zelf,
60
- state : Cell :: new ( AwaitableState :: Init ) ,
61
+ state : AtomicCell :: new ( AwaitableState :: Init ) ,
61
62
value,
62
63
}
63
64
}
@@ -73,7 +74,7 @@ impl PyAsyncGen {
73
74
PyAsyncGenAThrow {
74
75
ag : zelf,
75
76
aclose : false ,
76
- state : Cell :: new ( AwaitableState :: Init ) ,
77
+ state : AtomicCell :: new ( AwaitableState :: Init ) ,
77
78
value : (
78
79
exc_type,
79
80
exc_val. unwrap_or_else ( || vm. get_none ( ) ) ,
@@ -87,7 +88,7 @@ impl PyAsyncGen {
87
88
PyAsyncGenAThrow {
88
89
ag : zelf,
89
90
aclose : true ,
90
- state : Cell :: new ( AwaitableState :: Init ) ,
91
+ state : AtomicCell :: new ( AwaitableState :: Init ) ,
91
92
value : (
92
93
vm. ctx . exceptions . generator_exit . clone ( ) . into_object ( ) ,
93
94
vm. get_none ( ) ,
@@ -129,15 +130,15 @@ impl PyAsyncGenWrappedValue {
129
130
if objtype:: isinstance ( & e, & vm. ctx . exceptions . stop_async_iteration )
130
131
|| objtype:: isinstance ( & e, & vm. ctx . exceptions . generator_exit )
131
132
{
132
- ag. inner . closed . set ( true ) ;
133
+ ag. inner . closed . store ( true ) ;
133
134
}
134
- ag. running_async . set ( false ) ;
135
+ ag. running_async . store ( false ) ;
135
136
}
136
137
let val = val?;
137
138
138
139
match_class ! ( match val {
139
140
val @ Self => {
140
- ag. running_async. set ( false ) ;
141
+ ag. running_async. store ( false ) ;
141
142
Err ( vm. new_exception(
142
143
vm. ctx. exceptions. stop_iteration. clone( ) ,
143
144
vec![ val. 0 . clone( ) ] ,
@@ -159,10 +160,12 @@ enum AwaitableState {
159
160
#[ derive( Debug ) ]
160
161
struct PyAsyncGenASend {
161
162
ag : PyAsyncGenRef ,
162
- state : Cell < AwaitableState > ,
163
+ state : AtomicCell < AwaitableState > ,
163
164
value : PyObjectRef ,
164
165
}
165
166
167
+ impl ThreadSafe for PyAsyncGenASend { }
168
+
166
169
impl PyValue for PyAsyncGenASend {
167
170
fn class ( vm : & VirtualMachine ) -> PyClassRef {
168
171
vm. ctx . types . async_generator_asend . clone ( )
@@ -187,21 +190,21 @@ impl PyAsyncGenASend {
187
190
188
191
#[ pymethod]
189
192
fn send ( & self , val : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
190
- let val = match self . state . get ( ) {
193
+ let val = match self . state . load ( ) {
191
194
AwaitableState :: Closed => {
192
195
return Err ( vm. new_runtime_error (
193
196
"cannot reuse already awaited __anext__()/asend()" . to_owned ( ) ,
194
197
) )
195
198
}
196
199
AwaitableState :: Iter => val, // already running, all good
197
200
AwaitableState :: Init => {
198
- if self . ag . running_async . get ( ) {
201
+ if self . ag . running_async . load ( ) {
199
202
return Err ( vm. new_runtime_error (
200
203
"anext(): asynchronous generator is already running" . to_owned ( ) ,
201
204
) ) ;
202
205
}
203
- self . ag . running_async . set ( true ) ;
204
- self . state . set ( AwaitableState :: Iter ) ;
206
+ self . ag . running_async . store ( true ) ;
207
+ self . state . store ( AwaitableState :: Iter ) ;
205
208
if vm. is_none ( & val) {
206
209
self . value . clone ( )
207
210
} else {
@@ -225,7 +228,7 @@ impl PyAsyncGenASend {
225
228
exc_tb : OptionalArg ,
226
229
vm : & VirtualMachine ,
227
230
) -> PyResult {
228
- if let AwaitableState :: Closed = self . state . get ( ) {
231
+ if let AwaitableState :: Closed = self . state . load ( ) {
229
232
return Err (
230
233
vm. new_runtime_error ( "cannot reuse already awaited __anext__()/asend()" . to_owned ( ) )
231
234
) ;
@@ -246,7 +249,7 @@ impl PyAsyncGenASend {
246
249
247
250
#[ pymethod]
248
251
fn close ( & self ) {
249
- self . state . set ( AwaitableState :: Closed ) ;
252
+ self . state . store ( AwaitableState :: Closed ) ;
250
253
}
251
254
}
252
255
@@ -255,10 +258,12 @@ impl PyAsyncGenASend {
255
258
struct PyAsyncGenAThrow {
256
259
ag : PyAsyncGenRef ,
257
260
aclose : bool ,
258
- state : Cell < AwaitableState > ,
261
+ state : AtomicCell < AwaitableState > ,
259
262
value : ( PyObjectRef , PyObjectRef , PyObjectRef ) ,
260
263
}
261
264
265
+ impl ThreadSafe for PyAsyncGenAThrow { }
266
+
262
267
impl PyValue for PyAsyncGenAThrow {
263
268
fn class ( vm : & VirtualMachine ) -> PyClassRef {
264
269
vm. ctx . types . async_generator_athrow . clone ( )
@@ -283,14 +288,14 @@ impl PyAsyncGenAThrow {
283
288
284
289
#[ pymethod]
285
290
fn send ( & self , val : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
286
- match self . state . get ( ) {
291
+ match self . state . load ( ) {
287
292
AwaitableState :: Closed => {
288
293
Err ( vm
289
294
. new_runtime_error ( "cannot reuse already awaited aclose()/athrow()" . to_owned ( ) ) )
290
295
}
291
296
AwaitableState :: Init => {
292
- if self . ag . running_async . get ( ) {
293
- self . state . set ( AwaitableState :: Closed ) ;
297
+ if self . ag . running_async . load ( ) {
298
+ self . state . store ( AwaitableState :: Closed ) ;
294
299
let msg = if self . aclose {
295
300
"aclose(): asynchronous generator is already running"
296
301
} else {
@@ -299,16 +304,16 @@ impl PyAsyncGenAThrow {
299
304
return Err ( vm. new_runtime_error ( msg. to_owned ( ) ) ) ;
300
305
}
301
306
if self . ag . inner . closed ( ) {
302
- self . state . set ( AwaitableState :: Closed ) ;
307
+ self . state . store ( AwaitableState :: Closed ) ;
303
308
return Err ( vm. new_exception_empty ( vm. ctx . exceptions . stop_iteration . clone ( ) ) ) ;
304
309
}
305
310
if !vm. is_none ( & val) {
306
311
return Err ( vm. new_runtime_error (
307
312
"can't send non-None value to a just-started async generator" . to_owned ( ) ,
308
313
) ) ;
309
314
}
310
- self . state . set ( AwaitableState :: Iter ) ;
311
- self . ag . running_async . set ( true ) ;
315
+ self . state . store ( AwaitableState :: Iter ) ;
316
+ self . ag . running_async . store ( true ) ;
312
317
313
318
let ( ty, val, tb) = self . value . clone ( ) ;
314
319
let ret = self . ag . inner . throw ( ty, val, tb, vm) ;
@@ -368,21 +373,21 @@ impl PyAsyncGenAThrow {
368
373
369
374
#[ pymethod]
370
375
fn close ( & self ) {
371
- self . state . set ( AwaitableState :: Closed ) ;
376
+ self . state . store ( AwaitableState :: Closed ) ;
372
377
}
373
378
374
379
fn ignored_close ( & self , res : & PyResult ) -> bool {
375
380
res. as_ref ( )
376
381
. map_or ( false , |v| v. payload_is :: < PyAsyncGenWrappedValue > ( ) )
377
382
}
378
383
fn yield_close ( & self , vm : & VirtualMachine ) -> PyBaseExceptionRef {
379
- self . ag . running_async . set ( false ) ;
380
- self . state . set ( AwaitableState :: Closed ) ;
384
+ self . ag . running_async . store ( false ) ;
385
+ self . state . store ( AwaitableState :: Closed ) ;
381
386
vm. new_runtime_error ( "async generator ignored GeneratorExit" . to_owned ( ) )
382
387
}
383
388
fn check_error ( & self , exc : PyBaseExceptionRef , vm : & VirtualMachine ) -> PyBaseExceptionRef {
384
- self . ag . running_async . set ( false ) ;
385
- self . state . set ( AwaitableState :: Closed ) ;
389
+ self . ag . running_async . store ( false ) ;
390
+ self . state . store ( AwaitableState :: Closed ) ;
386
391
if self . aclose
387
392
&& ( objtype:: isinstance ( & exc, & vm. ctx . exceptions . stop_async_iteration )
388
393
|| objtype:: isinstance ( & exc, & vm. ctx . exceptions . generator_exit ) )
0 commit comments