diff --git a/vm/src/stdlib/sys.rs b/vm/src/stdlib/sys.rs index d6b10f4d2d..bc629dc22a 100644 --- a/vm/src/stdlib/sys.rs +++ b/vm/src/stdlib/sys.rs @@ -10,10 +10,10 @@ mod sys { ascii, hash::{PyHash, PyUHash}, }, + convert::ToPyObject, frame::FrameRef, function::{FuncArgs, OptionalArg, PosArgs}, - stdlib::builtins, - stdlib::warnings::warn, + stdlib::{builtins, warnings::warn}, types::PyStructSequence, version, vm::{Settings, VirtualMachine}, @@ -706,6 +706,68 @@ mod sys { crate::vm::thread::COROUTINE_ORIGIN_TRACKING_DEPTH.with(|cell| cell.get()) as _ } + #[derive(FromArgs)] + struct SetAsyncgenHooksArgs { + #[pyarg(any, optional)] + firstiter: OptionalArg>, + #[pyarg(any, optional)] + finalizer: OptionalArg>, + } + + #[pyfunction] + fn set_asyncgen_hooks(args: SetAsyncgenHooksArgs, vm: &VirtualMachine) -> PyResult<()> { + if let Some(Some(finalizer)) = args.finalizer.as_option() { + if !finalizer.is_callable() { + return Err(vm.new_type_error(format!( + "callable finalizer expected, got {:.50}", + finalizer.class().name() + ))); + } + } + + if let Some(Some(firstiter)) = args.firstiter.as_option() { + if !firstiter.is_callable() { + return Err(vm.new_type_error(format!( + "callable firstiter expected, got {:.50}", + firstiter.class().name() + ))); + } + } + + if let Some(finalizer) = args.finalizer.into_option() { + crate::vm::thread::ASYNC_GEN_FINALIZER.with(|cell| { + cell.replace(finalizer); + }); + } + if let Some(firstiter) = args.firstiter.into_option() { + crate::vm::thread::ASYNC_GEN_FIRSTITER.with(|cell| { + cell.replace(firstiter); + }); + } + + Ok(()) + } + + #[pyclass(no_attr, name = "asyncgen_hooks")] + #[derive(PyStructSequence)] + pub(super) struct PyAsyncgenHooks { + firstiter: PyObjectRef, + finalizer: PyObjectRef, + } + + #[pyclass(with(PyStructSequence))] + impl PyAsyncgenHooks {} + + #[pyfunction] + fn get_asyncgen_hooks(vm: &VirtualMachine) -> PyAsyncgenHooks { + PyAsyncgenHooks { + firstiter: crate::vm::thread::ASYNC_GEN_FIRSTITER + .with(|cell| cell.borrow().clone().to_pyobject(vm)), + finalizer: crate::vm::thread::ASYNC_GEN_FINALIZER + .with(|cell| cell.borrow().clone().to_pyobject(vm)), + } + } + /// sys.flags /// /// Flags provided through command line arguments or environment vars. diff --git a/vm/src/vm/thread.rs b/vm/src/vm/thread.rs index 82110f73da..6ed422ab79 100644 --- a/vm/src/vm/thread.rs +++ b/vm/src/vm/thread.rs @@ -1,4 +1,4 @@ -use crate::{AsObject, PyObject, VirtualMachine}; +use crate::{AsObject, PyObject, PyObjectRef, VirtualMachine}; use itertools::Itertools; use std::{ cell::{Cell, RefCell}, @@ -11,6 +11,8 @@ thread_local! { static VM_CURRENT: RefCell<*const VirtualMachine> = std::ptr::null::().into(); pub(crate) static COROUTINE_ORIGIN_TRACKING_DEPTH: Cell = const { Cell::new(0) }; + pub(crate) static ASYNC_GEN_FINALIZER: RefCell> = const { RefCell::new(None) }; + pub(crate) static ASYNC_GEN_FIRSTITER: RefCell> = const { RefCell::new(None) }; } pub fn with_current_vm(f: impl FnOnce(&VirtualMachine) -> R) -> R {