diff --git a/Doc/library/concurrent.futures.rst b/Doc/library/concurrent.futures.rst index 3c8d9ab111e09e..c8c8a45623ef6b 100644 --- a/Doc/library/concurrent.futures.rst +++ b/Doc/library/concurrent.futures.rst @@ -265,7 +265,7 @@ Each worker's interpreter is isolated from all the other interpreters. "Isolated" means each interpreter has its own runtime state and operates completely independently. For example, if you redirect :data:`sys.stdout` in one interpreter, it will not be automatically -redirected any other interpreter. If you import a module in one +redirected to any other interpreter. If you import a module in one interpreter, it is not automatically imported in any other. You would need to import the module separately in interpreter where you need it. In fact, each module imported in an interpreter is @@ -304,10 +304,6 @@ the bytes over a shared :mod:`socket ` or and *initargs* using :mod:`pickle` when sending them to the worker's interpreter. - .. note:: - Functions defined in the ``__main__`` module cannot be pickled - and thus cannot be used. - .. note:: The executor may replace uncaught exceptions from *initializer* with :class:`~concurrent.futures.interpreter.ExecutionFailed`. @@ -326,10 +322,6 @@ except the worker serializes the callable and arguments using :mod:`pickle` when sending them to its interpreter. The worker likewise serializes the return value when sending it back. -.. note:: - Functions defined in the ``__main__`` module cannot be pickled - and thus cannot be used. - When a worker's current task raises an uncaught exception, the worker always tries to preserve the exception as-is. If that is successful then it also sets the ``__cause__`` to a corresponding diff --git a/Lib/concurrent/futures/interpreter.py b/Lib/concurrent/futures/interpreter.py index a2c4fbfd3fb831..537fe34171f99e 100644 --- a/Lib/concurrent/futures/interpreter.py +++ b/Lib/concurrent/futures/interpreter.py @@ -45,12 +45,8 @@ def resolve_task(fn, args, kwargs): # XXX Circle back to this later. raise TypeError('scripts not supported') else: - # Functions defined in the __main__ module can't be pickled, - # so they can't be used here. In the future, we could possibly - # borrow from multiprocessing to work around this. task = (fn, args, kwargs) - data = pickle.dumps(task) - return data + return task if initializer is not None: try: @@ -65,35 +61,6 @@ def create_context(): return cls(initdata, shared) return create_context, resolve_task - @classmethod - @contextlib.contextmanager - def _capture_exc(cls, resultsid): - try: - yield - except BaseException as exc: - # Send the captured exception out on the results queue, - # but still leave it unhandled for the interpreter to handle. - _interpqueues.put(resultsid, (None, exc)) - raise # re-raise - - @classmethod - def _send_script_result(cls, resultsid): - _interpqueues.put(resultsid, (None, None)) - - @classmethod - def _call(cls, func, args, kwargs, resultsid): - with cls._capture_exc(resultsid): - res = func(*args or (), **kwargs or {}) - # Send the result back. - with cls._capture_exc(resultsid): - _interpqueues.put(resultsid, (res, None)) - - @classmethod - def _call_pickled(cls, pickled, resultsid): - with cls._capture_exc(resultsid): - fn, args, kwargs = pickle.loads(pickled) - cls._call(fn, args, kwargs, resultsid) - def __init__(self, initdata, shared=None): self.initdata = initdata self.shared = dict(shared) if shared else None @@ -104,11 +71,56 @@ def __del__(self): if self.interpid is not None: self.finalize() - def _exec(self, script): - assert self.interpid is not None - excinfo = _interpreters.exec(self.interpid, script, restrict=True) + def _call(self, fn, args, kwargs): + def do_call(resultsid, func, *args, **kwargs): + try: + return func(*args, **kwargs) + except BaseException as exc: + # Avoid relying on globals. + import _interpreters + import _interpqueues + # Send the captured exception out on the results queue, + # but still leave it unhandled for the interpreter to handle. + try: + _interpqueues.put(resultsid, exc) + except _interpreters.NotShareableError: + # The exception is not shareable. + import sys + import traceback + print('exception is not shareable:', file=sys.stderr) + traceback.print_exception(exc) + _interpqueues.put(resultsid, None) + raise # re-raise + + args = (self.resultsid, fn, *args) + res, excinfo = _interpreters.call(self.interpid, do_call, args, kwargs) if excinfo is not None: raise ExecutionFailed(excinfo) + return res + + def _get_exception(self): + # Wait for the exception data to show up. + while True: + try: + excdata = _interpqueues.get(self.resultsid) + except _interpqueues.QueueNotFoundError: + raise # re-raise + except _interpqueues.QueueError as exc: + if exc.__cause__ is not None or exc.__context__ is not None: + raise # re-raise + if str(exc).endswith(' is empty'): + continue + else: + raise # re-raise + except ModuleNotFoundError: + # interpreters.queues doesn't exist, which means + # QueueEmpty doesn't. Act as though it does. + continue + else: + break + exc, unboundop = excdata + assert unboundop is None, unboundop + return exc def initialize(self): assert self.interpid is None, self.interpid @@ -119,8 +131,6 @@ def initialize(self): maxsize = 0 self.resultsid = _interpqueues.create(maxsize) - self._exec(f'from {__name__} import WorkerContext') - if self.shared: _interpreters.set___main___attrs( self.interpid, self.shared, restrict=True) @@ -148,37 +158,15 @@ def finalize(self): pass def run(self, task): - data = task - script = f'WorkerContext._call_pickled({data!r}, {self.resultsid})' - + fn, args, kwargs = task try: - self._exec(script) - except ExecutionFailed as exc: - exc_wrapper = exc - else: - exc_wrapper = None - - # Return the result, or raise the exception. - while True: - try: - obj = _interpqueues.get(self.resultsid) - except _interpqueues.QueueNotFoundError: + return self._call(fn, args, kwargs) + except ExecutionFailed as wrapper: + exc = self._get_exception() + if exc is None: + # The exception must have been not shareable. raise # re-raise - except _interpqueues.QueueError: - continue - except ModuleNotFoundError: - # interpreters.queues doesn't exist, which means - # QueueEmpty doesn't. Act as though it does. - continue - else: - break - (res, exc), unboundop = obj - assert unboundop is None, unboundop - if exc is not None: - assert res is None, res - assert exc_wrapper is not None - raise exc from exc_wrapper - return res + raise exc from wrapper class BrokenInterpreterPool(_thread.BrokenThreadPool): diff --git a/Lib/test/test_concurrent_futures/test_init.py b/Lib/test/test_concurrent_futures/test_init.py index df640929309318..6b8484c0d5f197 100644 --- a/Lib/test/test_concurrent_futures/test_init.py +++ b/Lib/test/test_concurrent_futures/test_init.py @@ -20,6 +20,10 @@ def init(x): global INITIALIZER_STATUS INITIALIZER_STATUS = x + # InterpreterPoolInitializerTest.test_initializer fails + # if we don't have a LOAD_GLOBAL. (It could be any global.) + # We will address this separately. + INITIALIZER_STATUS def get_init_status(): return INITIALIZER_STATUS diff --git a/Lib/test/test_concurrent_futures/test_interpreter_pool.py b/Lib/test/test_concurrent_futures/test_interpreter_pool.py index f6c62ae4b2021b..54ee96501eee7a 100644 --- a/Lib/test/test_concurrent_futures/test_interpreter_pool.py +++ b/Lib/test/test_concurrent_futures/test_interpreter_pool.py @@ -2,7 +2,7 @@ import contextlib import io import os -import pickle +import sys import time import unittest from concurrent.futures.interpreter import ( @@ -10,6 +10,8 @@ ) import _interpreters from test import support +from test.support import os_helper +from test.support import script_helper import test.test_asyncio.utils as testasyncio_utils from test.support.interpreters import queues @@ -17,20 +19,62 @@ from .util import BaseTestCase, InterpreterPoolMixin, setup_module +WINDOWS = sys.platform.startswith('win') + + +@contextlib.contextmanager +def nonblocking(fd): + blocking = os.get_blocking(fd) + if blocking: + os.set_blocking(fd, False) + try: + yield + finally: + if blocking: + os.set_blocking(fd, blocking) + + +def read_file_with_timeout(fd, nbytes, timeout): + with nonblocking(fd): + end = time.time() + timeout + try: + return os.read(fd, nbytes) + except BlockingIOError: + pass + while time.time() < end: + try: + return os.read(fd, nbytes) + except BlockingIOError: + continue + else: + raise TimeoutError('nothing to read') + + +if not WINDOWS: + import select + def read_file_with_timeout(fd, nbytes, timeout): + r, _, _ = select.select([fd], [], [], timeout) + if fd not in r: + raise TimeoutError('nothing to read') + return os.read(fd, nbytes) + + def noop(): pass def write_msg(fd, msg): + import os os.write(fd, msg + b'\0') -def read_msg(fd): +def read_msg(fd, timeout=10.0): msg = b'' - while ch := os.read(fd, 1): - if ch == b'\0': - return msg + ch = read_file_with_timeout(fd, 1, timeout) + while ch != b'\0': msg += ch + ch = os.read(fd, 1) + return msg def get_current_name(): @@ -113,6 +157,38 @@ def test_init_func(self): self.assertEqual(before, b'\0') self.assertEqual(after, msg) + def test_init_with___main___global(self): + # See https://github.com/python/cpython/pull/133957#issuecomment-2927415311. + text = """if True: + from concurrent.futures import InterpreterPoolExecutor + + INITIALIZER_STATUS = 'uninitialized' + + def init(x): + global INITIALIZER_STATUS + INITIALIZER_STATUS = x + INITIALIZER_STATUS + + def get_init_status(): + return INITIALIZER_STATUS + + if __name__ == "__main__": + exe = InterpreterPoolExecutor(initializer=init, + initargs=('initialized',)) + fut = exe.submit(get_init_status) + print(fut.result()) # 'initialized' + exe.shutdown(wait=True) + print(INITIALIZER_STATUS) # 'uninitialized' + """ + with os_helper.temp_dir() as tempdir: + filename = script_helper.make_script(tempdir, 'my-script', text) + res = script_helper.assert_python_ok(filename) + stdout = res.out.decode('utf-8').strip() + self.assertEqual(stdout.splitlines(), [ + 'initialized', + 'uninitialized', + ]) + def test_init_closure(self): count = 0 def init1(): @@ -121,10 +197,19 @@ def init2(): nonlocal count count += 1 - with self.assertRaises(pickle.PicklingError): - self.executor_type(initializer=init1) - with self.assertRaises(pickle.PicklingError): - self.executor_type(initializer=init2) + with contextlib.redirect_stderr(io.StringIO()) as stderr: + with self.executor_type(initializer=init1) as executor: + fut = executor.submit(lambda: None) + self.assertIn('NotShareableError', stderr.getvalue()) + with self.assertRaises(BrokenInterpreterPool): + fut.result() + + with contextlib.redirect_stderr(io.StringIO()) as stderr: + with self.executor_type(initializer=init2) as executor: + fut = executor.submit(lambda: None) + self.assertIn('NotShareableError', stderr.getvalue()) + with self.assertRaises(BrokenInterpreterPool): + fut.result() def test_init_instance_method(self): class Spam: @@ -132,8 +217,12 @@ def initializer(self): raise NotImplementedError spam = Spam() - with self.assertRaises(pickle.PicklingError): - self.executor_type(initializer=spam.initializer) + with contextlib.redirect_stderr(io.StringIO()) as stderr: + with self.executor_type(initializer=spam.initializer) as executor: + fut = executor.submit(lambda: None) + self.assertIn('NotShareableError', stderr.getvalue()) + with self.assertRaises(BrokenInterpreterPool): + fut.result() def test_init_shared(self): msg = b'eggs' @@ -178,8 +267,6 @@ def test_init_exception_in_func(self): stderr = stderr.getvalue() self.assertIn('ExecutionFailed: Exception: spam', stderr) self.assertIn('Uncaught in the interpreter:', stderr) - self.assertIn('The above exception was the direct cause of the following exception:', - stderr) @unittest.expectedFailure def test_submit_script(self): @@ -208,10 +295,14 @@ def task2(): return spam executor = self.executor_type() - with self.assertRaises(pickle.PicklingError): - executor.submit(task1) - with self.assertRaises(pickle.PicklingError): - executor.submit(task2) + + fut = executor.submit(task1) + with self.assertRaises(_interpreters.NotShareableError): + fut.result() + + fut = executor.submit(task2) + with self.assertRaises(_interpreters.NotShareableError): + fut.result() def test_submit_local_instance(self): class Spam: @@ -219,8 +310,9 @@ def __init__(self): self.value = True executor = self.executor_type() - with self.assertRaises(pickle.PicklingError): - executor.submit(Spam) + fut = executor.submit(Spam) + with self.assertRaises(_interpreters.NotShareableError): + fut.result() def test_submit_instance_method(self): class Spam: @@ -229,8 +321,9 @@ def run(self): spam = Spam() executor = self.executor_type() - with self.assertRaises(pickle.PicklingError): - executor.submit(spam.run) + fut = executor.submit(spam.run) + with self.assertRaises(_interpreters.NotShareableError): + fut.result() def test_submit_func_globals(self): executor = self.executor_type() @@ -242,6 +335,7 @@ def test_submit_func_globals(self): @unittest.expectedFailure def test_submit_exception_in_script(self): + # Scripts are not supported currently. fut = self.executor.submit('raise Exception("spam")') with self.assertRaises(Exception) as captured: fut.result() @@ -289,13 +383,21 @@ def test_idle_thread_reuse(self): executor.shutdown(wait=True) def test_pickle_errors_propagate(self): - # GH-125864: Pickle errors happen before the script tries to execute, so the - # queue used to wait infinitely. - + # GH-125864: Pickle errors happen before the script tries to execute, + # so the queue used to wait infinitely. fut = self.executor.submit(PickleShenanigans(0)) - with self.assertRaisesRegex(RuntimeError, "gotcha"): + expected = _interpreters.NotShareableError + with self.assertRaisesRegex(expected, 'unpickled'): fut.result() + def test_no_stale_references(self): + # Weak references don't cross between interpreters. + raise unittest.SkipTest('not applicable') + + def test_free_reference(self): + # Weak references don't cross between interpreters. + raise unittest.SkipTest('not applicable') + class AsyncioTest(InterpretersMixin, testasyncio_utils.TestCase): diff --git a/Python/crossinterp.c b/Python/crossinterp.c index 5e73ab28f2b663..bfb14f2cf88b6b 100644 --- a/Python/crossinterp.c +++ b/Python/crossinterp.c @@ -540,6 +540,50 @@ sync_module_clear(struct sync_module *data) } +static PyObject * +get_cached_module_ns(PyThreadState *tstate, + const char *modname, const char *filename) +{ + // Load the module from the original file. + assert(filename != NULL); + PyObject *loaded = NULL; + + const char *run_modname = modname; + if (strcmp(modname, "__main__") == 0) { + // We don't want to trigger "if __name__ == '__main__':". + run_modname = ""; + } + + // First try the per-interpreter cache. + PyObject *interpns = PyInterpreterState_GetDict(tstate->interp); + assert(interpns != NULL); + PyObject *key = PyUnicode_FromFormat("CACHED_MODULE_NS_%s", modname); + if (key == NULL) { + return NULL; + } + if (PyDict_GetItemRef(interpns, key, &loaded) < 0) { + goto finally; + } + if (loaded != NULL) { + goto finally; + } + + // It wasn't already loaded from file. + loaded = runpy_run_path(filename, run_modname); + if (loaded == NULL) { + goto finally; + } + if (PyDict_SetItem(interpns, key, loaded) < 0) { + Py_CLEAR(loaded); + goto finally; + } + +finally: + Py_DECREF(key); + return loaded; +} + + struct _unpickle_context { PyThreadState *tstate; // We only special-case the __main__ module, @@ -574,37 +618,40 @@ _unpickle_context_set_module(struct _unpickle_context *ctx, struct sync_module_result res = {0}; struct sync_module_result *cached = NULL; const char *filename = NULL; - const char *run_modname = modname; if (strcmp(modname, "__main__") == 0) { cached = &ctx->main.cached; filename = ctx->main.filename; - // We don't want to trigger "if __name__ == '__main__':". - run_modname = ""; } else { res.failed = PyExc_NotImplementedError; - goto finally; + goto error; } res.module = import_get_module(ctx->tstate, modname); if (res.module == NULL) { - res.failed = _PyErr_GetRaisedException(ctx->tstate); - assert(res.failed != NULL); - goto finally; + goto error; } + // Load the module ns from the original file and cache it. + // Note that functions will use the cached ns for __globals__, + // not res.module. if (filename == NULL) { - Py_CLEAR(res.module); res.failed = PyExc_NotImplementedError; - goto finally; + goto error; } - res.loaded = runpy_run_path(filename, run_modname); + res.loaded = get_cached_module_ns(ctx->tstate, modname, filename); if (res.loaded == NULL) { - Py_CLEAR(res.module); + goto error; + } + goto finally; + +error: + Py_CLEAR(res.module); + if (res.failed == NULL) { res.failed = _PyErr_GetRaisedException(ctx->tstate); assert(res.failed != NULL); - goto finally; } + assert(!_PyErr_Occurred(ctx->tstate)); finally: if (cached != NULL) { @@ -629,7 +676,8 @@ _handle_unpickle_missing_attr(struct _unpickle_context *ctx, PyObject *exc) } // Get the module. - struct sync_module_result mod = _unpickle_context_get_module(ctx, info.modname); + struct sync_module_result mod = + _unpickle_context_get_module(ctx, info.modname); if (mod.failed != NULL) { // It must have failed previously. return -1;