diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py index 2bbad2dcc9..c9b0b5eda6 100644 --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -1382,8 +1382,6 @@ def test__all__(self): class InterruptMainTests(unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_interrupt_main_subthread(self): # Calling start_new_thread with a function that executes interrupt_main # should raise KeyboardInterrupt upon completion. @@ -1395,16 +1393,12 @@ def call_interrupt(): t.join() t.join() - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_interrupt_main_mainthread(self): # Make sure that if interrupt_main is called in main thread that # KeyboardInterrupt is raised instantly. with self.assertRaises(KeyboardInterrupt): _thread.interrupt_main() - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_interrupt_main_noerror(self): handler = signal.getsignal(signal.SIGINT) try: diff --git a/vm/src/signal.rs b/vm/src/signal.rs index 2d46c18497..bca7963426 100644 --- a/vm/src/signal.rs +++ b/vm/src/signal.rs @@ -55,6 +55,33 @@ pub(crate) fn set_triggered() { ANY_TRIGGERED.store(true, Ordering::Release); } +pub fn assert_in_range(signum: i32, vm: &VirtualMachine) -> PyResult<()> { + if (1..NSIG as i32).contains(&signum) { + Ok(()) + } else { + Err(vm.new_value_error("signal number out of range".to_owned())) + } +} + +/// Similar to `PyErr_SetInterruptEx` in CPython +/// +/// Missing signal handler for the given signal number is silently ignored. +#[allow(dead_code)] +#[cfg(not(target_arch = "wasm32"))] +pub fn set_interrupt_ex(signum: i32, vm: &VirtualMachine) -> PyResult<()> { + use crate::stdlib::signal::_signal::{run_signal, SIG_DFL, SIG_IGN}; + assert_in_range(signum, vm)?; + + match signum as usize { + SIG_DFL | SIG_IGN => Ok(()), + _ => { + // interrupt the main thread with given signal number + run_signal(signum); + Ok(()) + } + } +} + pub type UserSignal = Box PyResult<()> + Send>; #[derive(Clone, Debug)] diff --git a/vm/src/stdlib/signal.rs b/vm/src/stdlib/signal.rs index 6212ab3e26..173dee28c3 100644 --- a/vm/src/stdlib/signal.rs +++ b/vm/src/stdlib/signal.rs @@ -33,23 +33,23 @@ pub(crate) mod _signal { } #[cfg(unix)] - use nix::unistd::alarm as sig_alarm; + pub use nix::unistd::alarm as sig_alarm; #[cfg(not(windows))] - use libc::SIG_ERR; + pub use libc::SIG_ERR; #[cfg(not(windows))] #[pyattr] - use libc::{SIG_DFL, SIG_IGN}; + pub use libc::{SIG_DFL, SIG_IGN}; #[cfg(windows)] #[pyattr] - const SIG_DFL: libc::sighandler_t = 0; + pub const SIG_DFL: libc::sighandler_t = 0; #[cfg(windows)] #[pyattr] - const SIG_IGN: libc::sighandler_t = 1; + pub const SIG_IGN: libc::sighandler_t = 1; #[cfg(windows)] - const SIG_ERR: libc::sighandler_t = !0; + pub const SIG_ERR: libc::sighandler_t = !0; #[cfg(all(unix, not(target_os = "redox")))] extern "C" { @@ -60,7 +60,7 @@ pub(crate) mod _signal { use crate::signal::NSIG; #[pyattr] - use libc::{SIGABRT, SIGFPE, SIGILL, SIGINT, SIGSEGV, SIGTERM}; + pub use libc::{SIGABRT, SIGFPE, SIGILL, SIGINT, SIGSEGV, SIGTERM}; #[cfg(unix)] #[pyattr] @@ -112,7 +112,7 @@ pub(crate) mod _signal { handler: PyObjectRef, vm: &VirtualMachine, ) -> PyResult> { - assert_in_range(signalnum, vm)?; + signal::assert_in_range(signalnum, vm)?; let signal_handlers = vm .signal_handlers .as_deref() @@ -148,7 +148,7 @@ pub(crate) mod _signal { #[pyfunction] fn getsignal(signalnum: i32, vm: &VirtualMachine) -> PyResult { - assert_in_range(signalnum, vm)?; + signal::assert_in_range(signalnum, vm)?; let signal_handlers = vm .signal_handlers .as_deref() @@ -246,7 +246,7 @@ pub(crate) mod _signal { #[cfg(all(unix, not(target_os = "redox")))] #[pyfunction(name = "siginterrupt")] fn py_siginterrupt(signum: i32, flag: i32, vm: &VirtualMachine) -> PyResult<()> { - assert_in_range(signum, vm)?; + signal::assert_in_range(signum, vm)?; let res = unsafe { siginterrupt(signum, flag) }; if res < 0 { Err(crate::stdlib::os::errno_err(vm)) @@ -255,7 +255,7 @@ pub(crate) mod _signal { } } - extern "C" fn run_signal(signum: i32) { + pub extern "C" fn run_signal(signum: i32) { signal::TRIGGERS[signum as usize].store(true, Ordering::Relaxed); signal::set_triggered(); let wakeup_fd = WAKEUP.load(Ordering::Relaxed); @@ -271,12 +271,4 @@ pub(crate) mod _signal { // TODO: handle _res < 1, support warn_on_full_buffer } } - - fn assert_in_range(signum: i32, vm: &VirtualMachine) -> PyResult<()> { - if (1..NSIG as i32).contains(&signum) { - Ok(()) - } else { - Err(vm.new_value_error("signal number out of range".to_owned())) - } - } } diff --git a/vm/src/stdlib/thread.rs b/vm/src/stdlib/thread.rs index 95faea8095..576ba6a9a8 100644 --- a/vm/src/stdlib/thread.rs +++ b/vm/src/stdlib/thread.rs @@ -304,6 +304,12 @@ pub(crate) mod _thread { vm.state.thread_count.fetch_sub(1); } + #[cfg(not(target_arch = "wasm32"))] + #[pyfunction] + fn interrupt_main(signum: OptionalArg, vm: &VirtualMachine) -> PyResult<()> { + crate::signal::set_interrupt_ex(signum.unwrap_or(libc::SIGINT), vm) + } + #[pyfunction] fn exit(vm: &VirtualMachine) -> PyResult { Err(vm.new_exception_empty(vm.ctx.exceptions.system_exit.to_owned()))