diff --git a/Doc/c-api/contextvars.rst b/Doc/c-api/contextvars.rst
index 59e74ba1ac7022..90f56578319e21 100644
--- a/Doc/c-api/contextvars.rst
+++ b/Doc/c-api/contextvars.rst
@@ -123,33 +123,22 @@ Context object management functions:
Enumeration of possible context object watcher events:
- - ``Py_CONTEXT_EVENT_ENTER``: A context has been entered, causing the
- :term:`current context` to switch to it. The object passed to the watch
- callback is the now-current :class:`contextvars.Context` object. Each
- enter event will eventually have a corresponding exit event for the same
- context object after any subsequently entered contexts have themselves been
- exited.
- - ``Py_CONTEXT_EVENT_EXIT``: A context is about to be exited, which will
- cause the :term:`current context` to switch back to what it was before the
- context was entered. The object passed to the watch callback is the
- still-current :class:`contextvars.Context` object.
+ - ``Py_CONTEXT_SWITCHED``: The :term:`current context` has switched to a
+ different context. The object passed to the watch callback is the
+ now-current :class:`contextvars.Context` object, or None if no context is
+ current.
.. versionadded:: 3.14
-.. c:type:: int (*PyContext_WatchCallback)(PyContextEvent event, PyContext* ctx)
+.. c:type:: void (*PyContext_WatchCallback)(PyContextEvent event, PyObject* obj)
Context object watcher callback function. The object passed to the callback
is event-specific; see :c:type:`PyContextEvent` for details.
- If the callback returns with an exception set, it must return ``-1``; this
- exception will be printed as an unraisable exception using
- :c:func:`PyErr_FormatUnraisable`. Otherwise it should return ``0``.
+ Any pending exception is cleared before the callback is called and restored
+ after the callback returns.
- There may already be a pending exception set on entry to the callback. In
- this case, the callback should return ``0`` with the same exception still
- set. This means the callback may not call any other API that can set an
- exception unless it saves and clears the exception state first, and restores
- it before returning.
+ If the callback raises an exception it will be ignored.
.. versionadded:: 3.14
diff --git a/Doc/library/contextvars.rst b/Doc/library/contextvars.rst
index 2b1fb9fdd29cd8..faf1e69c58533f 100644
--- a/Doc/library/contextvars.rst
+++ b/Doc/library/contextvars.rst
@@ -150,20 +150,24 @@ Manual Context Management
considered to be *entered*.
*Entering* a context, which can be done by calling its :meth:`~Context.run`
- method, makes the context the current context by pushing it onto the top of
- the current thread's context stack.
+ method or by using it as a :term:`context manager`, makes the context the
+ current context by pushing it onto the top of the current thread's context
+ stack.
*Exiting* from the current context, which can be done by returning from the
- callback passed to the :meth:`~Context.run` method, restores the current
- context to what it was before the context was entered by popping the context
- off the top of the context stack.
+ callback passed to :meth:`~Context.run` or by exiting the :keyword:`with`
+ statement suite, restores the current context to what it was before the
+ context was entered by popping the context off the top of the context stack.
Since each thread has its own context stack, :class:`ContextVar` objects
behave in a similar fashion to :func:`threading.local` when values are
assigned in different threads.
- Attempting to enter an already entered context, including contexts entered in
- other threads, raises a :exc:`RuntimeError`.
+ Attempting to do either of the following raises a :exc:`RuntimeError`:
+
+ * Entering an already entered context, including contexts entered in
+ other threads.
+ * Exiting from a context that is not the current context.
After exiting a context, it can later be re-entered (from any thread).
@@ -176,6 +180,50 @@ Manual Context Management
Context implements the :class:`collections.abc.Mapping` interface.
+ .. versionadded:: 3.14
+ Added support for the :term:`context management protocol`.
+
+ When used as a :term:`context manager`, the value bound to the identifier
+ given in the :keyword:`with` statement's :keyword:`!as` clause (if present)
+ is the :class:`!Context` object itself.
+
+ Example:
+
+ .. testcode::
+
+ import contextvars
+
+ var = contextvars.ContextVar("var")
+ var.set("initial")
+ print(var.get()) # 'initial'
+
+ # Copy the current Context and enter the copy.
+ with contextvars.copy_context() as ctx:
+ var.set("updated")
+ print(var in ctx) # 'True'
+ print(ctx[var]) # 'updated'
+ print(var.get()) # 'updated'
+
+ # Exited ctx, so the observed value of var has reverted.
+ print(var.get()) # 'initial'
+ # But the updated value is still recorded in ctx.
+ print(ctx[var]) # 'updated'
+
+ # Re-entering ctx restores the updated value of var.
+ with ctx:
+ print(var.get()) # 'updated'
+
+ .. testoutput::
+ :hide:
+
+ initial
+ True
+ updated
+ updated
+ initial
+ updated
+ updated
+
.. method:: run(callable, *args, **kwargs)
Enters the Context, executes ``callable(*args, **kwargs)``, then exits the
diff --git a/Doc/whatsnew/3.14.rst b/Doc/whatsnew/3.14.rst
index c62a3ca5872eef..eea6d96361a1ac 100644
--- a/Doc/whatsnew/3.14.rst
+++ b/Doc/whatsnew/3.14.rst
@@ -226,6 +226,13 @@ ast
(Contributed by Tomas R in :gh:`116022`.)
+contextvars
+-----------
+
+* Added support for the :term:`context management protocol` to
+ :class:`contextvars.Context`. (Contributed by Richard Hansen in :gh:`99634`.)
+
+
ctypes
------
diff --git a/Include/cpython/context.h b/Include/cpython/context.h
index d722b4d93134f7..841376bc79f2d5 100644
--- a/Include/cpython/context.h
+++ b/Include/cpython/context.h
@@ -29,30 +29,23 @@ PyAPI_FUNC(int) PyContext_Exit(PyObject *);
typedef enum {
/*
- * A context has been entered, causing the "current context" to switch to
- * it. The object passed to the watch callback is the now-current
- * contextvars.Context object. Each enter event will eventually have a
- * corresponding exit event for the same context object after any
- * subsequently entered contexts have themselves been exited.
+ * The current context has switched to a different context. The object
+ * passed to the watch callback is the now-current contextvars.Context
+ * object, or None if no context is current.
*/
- Py_CONTEXT_EVENT_ENTER,
- /*
- * A context is about to be exited, which will cause the "current context"
- * to switch back to what it was before the context was entered. The
- * object passed to the watch callback is the still-current
- * contextvars.Context object.
- */
- Py_CONTEXT_EVENT_EXIT,
+ Py_CONTEXT_SWITCHED = 1,
} PyContextEvent;
/*
* Context object watcher callback function. The object passed to the callback
* is event-specific; see PyContextEvent for details.
*
- * if the callback returns with an exception set, it must return -1. Otherwise
- * it should return 0
+ * Any pending exception is cleared before the callback is called and restored
+ * after the callback returns.
+ *
+ * If the callback raises an exception it will be ignored.
*/
-typedef int (*PyContext_WatchCallback)(PyContextEvent, PyContext *);
+typedef void (*PyContext_WatchCallback)(PyContextEvent, PyObject *);
/*
* Register a per-interpreter callback that will be invoked for context object
diff --git a/Include/cpython/pystate.h b/Include/cpython/pystate.h
index 32f68378ea5d72..79e23435c2b621 100644
--- a/Include/cpython/pystate.h
+++ b/Include/cpython/pystate.h
@@ -164,7 +164,6 @@ struct _ts {
PyObject *async_gen_firstiter;
PyObject *async_gen_finalizer;
- PyObject *context;
uint64_t context_ver;
/* Unique thread state id. */
diff --git a/Include/internal/pycore_context.h b/Include/internal/pycore_context.h
index c2b98d15da68fa..08c59d8c27d321 100644
--- a/Include/internal/pycore_context.h
+++ b/Include/internal/pycore_context.h
@@ -5,7 +5,11 @@
# error "this header requires Py_BUILD_CORE define"
#endif
+#include "cpython/context.h"
+#include "cpython/genobject.h" // PyGenObject
+#include "pycore_genobject.h" // PyGenObject
#include "pycore_hamt.h" // PyHamtObject
+#include "pycore_tstate.h" // _PyThreadStateImpl
#define CONTEXT_MAX_WATCHERS 8
@@ -24,12 +28,81 @@ typedef struct {
struct _pycontextobject {
PyObject_HEAD
- PyContext *ctx_prev;
+ PyObject *ctx_prev;
PyHamtObject *ctx_vars;
PyObject *ctx_weakreflist;
int ctx_entered;
};
+// Resets a coroutine's independent context stack to ctx. If ctx is NULL or
+// Py_None, the coroutine will be a dependent coroutine (its context stack will
+// be empty) upon successful return. Otherwise, the coroutine will be an
+// independent coroutine upon successful return, with ctx as the sole item on
+// its context stack.
+//
+// The coroutine's existing stack must be empty (NULL) or contain only a single
+// entry (from a previous call to this function). If the coroutine is
+// currently executing, this function must be called from the coroutine's
+// thread.
+//
+// Unless ctx already equals the coroutine's existing context stack, the
+// context on the existing stack (if one exists) is immediately exited and ctx
+// (if non-NULL) is immediately entered.
+int _PyGen_ResetContext(PyThreadState *ts, PyGenObject *self, PyObject *ctx);
+
+void _PyGen_ActivateContextImpl(_PyThreadStateImpl *tsi, PyGenObject *self);
+void _PyGen_DeactivateContextImpl(_PyThreadStateImpl *tsi, PyGenObject *self);
+
+// Makes the given coroutine's context stack the active context stack for the
+// thread, shadowing (temporarily deactivating) the thread's previously active
+// context stack. The context stack remains active until deactivated with a
+// call to _PyGen_DeactivateContext, as long as it is not shadowed by another
+// activated context stack.
+//
+// Each activated context stack must eventually be deactivated by calling
+// _PyGen_DeactivateContext. The same context stack cannot be activated again
+// until deactivated.
+//
+// If the coroutine's context stack is empty this function has no effect.
+//
+// This is called each time a value is sent to a coroutine, so it is inlined to
+// avoid function call overhead in the common case of dependent coroutines.
+static inline void
+_PyGen_ActivateContext(PyThreadState *ts, PyGenObject *self)
+{
+ _PyThreadStateImpl *tsi = (_PyThreadStateImpl *)ts;
+ assert(self->_ctx_chain.prev == NULL);
+ if (self->_ctx_chain.ctx == NULL) {
+ return;
+ }
+ _PyGen_ActivateContextImpl(tsi, self);
+}
+
+// Deactivates the given coroutine's context stack, un-shadowing (reactivating)
+// the thread's previously active context stack. Does not affect any contexts
+// in the coroutine's context stack (they remain entered).
+//
+// Must not be called if a different context stack is currently shadowing the
+// coroutine's context stack.
+//
+// If the coroutine's context stack is not the active context stack this
+// function has no effect.
+//
+// This is called each time a value is sent to a coroutine, so it is inlined to
+// avoid function call overhead in the common case of dependent coroutines.
+static inline void
+_PyGen_DeactivateContext(PyThreadState *ts, PyGenObject *self)
+{
+ _PyThreadStateImpl *tsi = (_PyThreadStateImpl *)ts;
+ if (tsi->_ctx_chain.prev != &self->_ctx_chain) {
+ assert(self->_ctx_chain.ctx == NULL);
+ assert(self->_ctx_chain.prev == NULL);
+ return;
+ }
+ assert(self->_ctx_chain.ctx != NULL);
+ _PyGen_DeactivateContextImpl(tsi, self);
+}
+
struct _pycontextvarobject {
PyObject_HEAD
diff --git a/Include/internal/pycore_contextchain.h b/Include/internal/pycore_contextchain.h
new file mode 100644
index 00000000000000..5b02726f080dfd
--- /dev/null
+++ b/Include/internal/pycore_contextchain.h
@@ -0,0 +1,79 @@
+#ifndef Py_INTERNAL_CONTEXTCHAIN_H
+#define Py_INTERNAL_CONTEXTCHAIN_H
+
+#ifndef Py_BUILD_CORE
+# error "this header requires Py_BUILD_CORE define"
+#endif
+
+#include "pytypedefs.h" // PyObject
+
+
+// Circularly linked chain of multiple independent context stacks, used to give
+// coroutines (including generators) their own (optional) independent context
+// stacks.
+//
+// Detailed notes on how this chain is used:
+// * The chain is circular simply to save a pointer's worth of memory in
+// _PyThreadStateImpl. It is actually used as an ordinary linear linked
+// list. It is called "chain" instead of "stack" or "list" to evoke "call
+// chain", which it is related to, and to avoid confusion with "context
+// stack".
+// * There is one chain per thread, and _PyThreadStateImpl::_ctx_chain::prev
+// points to the head of the thread's chain.
+// * A thread's chain is never empty.
+// * _PyThreadStateImpl::_ctx_chain is always the tail entry of the thread's
+// chain.
+// * _PyThreadStateImpl::_ctx_chain is usually the only link in the thread's
+// chain, so _PyThreadStateImpl::_ctx_chain::prev usually points to the
+// _PyThreadStateImpl::_ctx_chain itself.
+// * The "active context stack" is always at the head link in a thread's
+// context chain. Contexts are entered by pushing onto the active context
+// stack and exited by popping off of the active context stack.
+// * The "current context" is the top context in the active context stack.
+// Context variable accesses (reads/writes) use the current context.
+// * A *dependent* coroutine or generator is a coroutine or generator that
+// does not have its own independent context stack. When a dependent
+// coroutine starts or resumes execution, the current context -- as
+// observed by the coroutine -- is the same context that was current just
+// before the coroutine's `send` method was called. This means that the
+// current context as observed by a dependent coroutine can change
+// arbitrarily during a yield/await. Dependent coroutines are so-named
+// because they depend on their senders to enter the appropriate context
+// before each send. Coroutines and generators are dependent by default
+// for backwards compatibility.
+// * The purpose of the context chain is to enable *independent* coroutines
+// and generators, which have their own context stacks. Whenever an
+// independent coroutine starts or resumes execution, the current context
+// automatically switches to the context associated with the coroutine.
+// This is accomplished by linking the coroutine's chain link (at
+// PyGenObject::_ctx_chain) to the head of the thread's chain. Independent
+// coroutines are so-named because they do not depend on their senders to
+// enter the appropriate context before each send.
+// * The head link is unlinked from the thread's chain when its associated
+// independent coroutine or generator stops executing (yields, awaits,
+// returns, or throws).
+// * A running dependent coroutine's chain link is linked into the thread's
+// chain if the coroutine is upgraded from dependent to independent by
+// assigning a context to the coroutine's `_context` property. The chain
+// link is inserted at the position corresponding to the coroutine's
+// position in the call chain relative to any other currently running
+// independent coroutines. For example, if dependent coroutine `coro_a`
+// calls function `func_b` which resumes independent coroutine `coro_c`
+// which assigns a context to `coro_a._context`, then `coro_a` becomes an
+// independent coroutine with its chain link inserted after `coro_c`'s
+// chain link (which remains the head link).
+// * A running independent coroutine's chain link is unlinked from the
+// thread's chain if the coroutine is downgraded from independent to
+// dependent by assigning `None` to its `_context` property.
+// * The references to the object at the `prev` link in the chain are
+// implicit (borrowed).
+typedef struct _PyContextChain {
+ // NULL for dependent coroutines/generators, non-NULL for independent
+ // coroutines/generators.
+ PyObject *ctx;
+ // NULL if unlinked from the thread's context chain, non-NULL otherwise.
+ struct _PyContextChain *prev;
+} _PyContextChain;
+
+
+#endif /* !Py_INTERNAL_CONTEXTCHAIN_H */
diff --git a/Include/internal/pycore_genobject.h b/Include/internal/pycore_genobject.h
index f6d7e6d367177b..9ef4a9e79d3b55 100644
--- a/Include/internal/pycore_genobject.h
+++ b/Include/internal/pycore_genobject.h
@@ -1,5 +1,8 @@
#ifndef Py_INTERNAL_GENOBJECT_H
#define Py_INTERNAL_GENOBJECT_H
+
+#include "pycore_contextchain.h" // _PyContextChain
+
#ifdef __cplusplus
extern "C" {
#endif
@@ -22,6 +25,7 @@ extern "C" {
PyObject *prefix##_qualname; \
_PyErr_StackItem prefix##_exc_state; \
PyObject *prefix##_origin_or_finalizer; \
+ _PyContextChain _ctx_chain; \
char prefix##_hooks_inited; \
char prefix##_closed; \
char prefix##_running_async; \
diff --git a/Include/internal/pycore_tstate.h b/Include/internal/pycore_tstate.h
index a72ef4493b77ca..1e511297e0b61e 100644
--- a/Include/internal/pycore_tstate.h
+++ b/Include/internal/pycore_tstate.h
@@ -9,6 +9,7 @@ extern "C" {
#endif
#include "pycore_brc.h" // struct _brc_thread_state
+#include "pycore_contextchain.h" // _PyContextChain
#include "pycore_freelist_state.h" // struct _Py_freelists
#include "pycore_mimalloc.h" // struct _mimalloc_thread_state
#include "pycore_qsbr.h" // struct qsbr
@@ -21,6 +22,9 @@ typedef struct _PyThreadStateImpl {
// semi-public fields are in PyThreadState.
PyThreadState base;
+ // Lazily initialized (must be zeroed at startup).
+ _PyContextChain _ctx_chain;
+
PyObject *asyncio_running_loop; // Strong reference
struct _qsbr_thread_state *qsbr; // only used by free-threaded build
diff --git a/Lib/test/test_capi/test_watchers.py b/Lib/test/test_capi/test_watchers.py
index f21d2627c6094b..085266554cb8f6 100644
--- a/Lib/test/test_capi/test_watchers.py
+++ b/Lib/test/test_capi/test_watchers.py
@@ -577,68 +577,72 @@ class TestContextObjectWatchers(unittest.TestCase):
def context_watcher(self, which_watcher):
wid = _testcapi.add_context_watcher(which_watcher)
try:
- yield wid
+ switches = _testcapi.get_context_switches(which_watcher)
+ except ValueError:
+ switches = None
+ try:
+ yield switches
finally:
_testcapi.clear_context_watcher(wid)
- def assert_event_counts(self, exp_enter_0, exp_exit_0,
- exp_enter_1, exp_exit_1):
- self.assertEqual(
- exp_enter_0, _testcapi.get_context_watcher_num_enter_events(0))
- self.assertEqual(
- exp_exit_0, _testcapi.get_context_watcher_num_exit_events(0))
- self.assertEqual(
- exp_enter_1, _testcapi.get_context_watcher_num_enter_events(1))
- self.assertEqual(
- exp_exit_1, _testcapi.get_context_watcher_num_exit_events(1))
+ def assert_event_counts(self, want_0, want_1):
+ self.assertEqual(len(_testcapi.get_context_switches(0)), want_0)
+ self.assertEqual(len(_testcapi.get_context_switches(1)), want_1)
def test_context_object_events_dispatched(self):
# verify that all counts are zero before any watchers are registered
- self.assert_event_counts(0, 0, 0, 0)
+ self.assert_event_counts(0, 0)
# verify that all counts remain zero when a context object is
# entered and exited with no watchers registered
ctx = contextvars.copy_context()
- ctx.run(self.assert_event_counts, 0, 0, 0, 0)
- self.assert_event_counts(0, 0, 0, 0)
+ ctx.run(self.assert_event_counts, 0, 0)
+ self.assert_event_counts(0, 0)
# verify counts are as expected when first watcher is registered
with self.context_watcher(0):
- self.assert_event_counts(0, 0, 0, 0)
- ctx.run(self.assert_event_counts, 1, 0, 0, 0)
- self.assert_event_counts(1, 1, 0, 0)
+ self.assert_event_counts(0, 0)
+ ctx.run(self.assert_event_counts, 1, 0)
+ self.assert_event_counts(2, 0)
# again with second watcher registered
with self.context_watcher(1):
- self.assert_event_counts(1, 1, 0, 0)
- ctx.run(self.assert_event_counts, 2, 1, 1, 0)
- self.assert_event_counts(2, 2, 1, 1)
+ self.assert_event_counts(2, 0)
+ ctx.run(self.assert_event_counts, 3, 1)
+ self.assert_event_counts(4, 2)
# verify counts are reset and don't change after both watchers are cleared
- ctx.run(self.assert_event_counts, 0, 0, 0, 0)
- self.assert_event_counts(0, 0, 0, 0)
-
- def test_enter_error(self):
+ ctx.run(self.assert_event_counts, 0, 0)
+ self.assert_event_counts(0, 0)
+
+ def test_callback_error(self):
+ ctx_outer = contextvars.copy_context()
+ ctx_inner = contextvars.copy_context()
+ unraisables = []
+
+ def _in_outer():
+ with self.context_watcher(2):
+ with catch_unraisable_exception() as cm:
+ ctx_inner.run(lambda: unraisables.append(cm.unraisable))
+ unraisables.append(cm.unraisable)
+
+ ctx_outer.run(_in_outer)
+ self.assertEqual([x.err_msg for x in unraisables],
+ ["Exception ignored in Py_CONTEXT_SWITCHED "
+ f"watcher callback for {ctx!r}"
+ for ctx in [ctx_inner, ctx_outer]])
+ self.assertEqual([str(x.exc_value) for x in unraisables],
+ ["boom!", "boom!"])
+
+ def test_exception_save(self):
with self.context_watcher(2):
with catch_unraisable_exception() as cm:
- ctx = contextvars.copy_context()
- ctx.run(int, 0)
- self.assertEqual(
- cm.unraisable.err_msg,
- "Exception ignored in "
- f"Py_CONTEXT_EVENT_EXIT watcher callback for {ctx!r}"
- )
- self.assertEqual(str(cm.unraisable.exc_value), "boom!")
-
- def test_exit_error(self):
- ctx = contextvars.copy_context()
- def _in_context(stack):
- stack.enter_context(self.context_watcher(2))
+ def _in_context():
+ raise RuntimeError("test")
- with catch_unraisable_exception() as cm:
- with ExitStack() as stack:
- ctx.run(_in_context, stack)
- self.assertEqual(str(cm.unraisable.exc_value), "boom!")
+ with self.assertRaisesRegex(RuntimeError, "test"):
+ contextvars.copy_context().run(_in_context)
+ self.assertEqual(str(cm.unraisable.exc_value), "boom!")
def test_clear_out_of_range_watcher_id(self):
with self.assertRaisesRegex(ValueError, r"Invalid context watcher ID -1"):
@@ -654,5 +658,12 @@ def test_allocate_too_many_watchers(self):
with self.assertRaisesRegex(RuntimeError, r"no more context watcher IDs available"):
_testcapi.allocate_too_many_context_watchers()
+ def test_exit_base_context(self):
+ ctx = contextvars.Context()
+ _testcapi.clear_context_stack()
+ with self.context_watcher(0) as switches:
+ ctx.run(lambda: None)
+ self.assertEqual(switches, [ctx, None])
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_context.py b/Lib/test/test_context.py
index b06b9df9f5b0b8..04d4d1f42ed209 100644
--- a/Lib/test/test_context.py
+++ b/Lib/test/test_context.py
@@ -1,8 +1,11 @@
+import asyncio
import concurrent.futures
+import contextlib
import contextvars
import functools
import gc
import random
+import threading
import time
import unittest
import weakref
@@ -369,6 +372,498 @@ def sub(num):
tp.shutdown()
self.assertEqual(results, list(range(10)))
+ @isolated_context
+ def test_context_manager(self):
+ cvar = contextvars.ContextVar('cvar', default='initial')
+ self.assertEqual(cvar.get(), 'initial')
+ with contextvars.copy_context():
+ self.assertEqual(cvar.get(), 'initial')
+ cvar.set('updated')
+ self.assertEqual(cvar.get(), 'updated')
+ self.assertEqual(cvar.get(), 'initial')
+
+ def test_context_manager_as_binding(self):
+ ctx = contextvars.copy_context()
+ with ctx as ctx_as_binding:
+ self.assertIs(ctx_as_binding, ctx)
+
+ @isolated_context
+ def test_context_manager_nested(self):
+ cvar = contextvars.ContextVar('cvar', default='default')
+ with contextvars.copy_context() as outer_ctx:
+ cvar.set('outer')
+ with contextvars.copy_context() as inner_ctx:
+ self.assertIsNot(outer_ctx, inner_ctx)
+ self.assertEqual(cvar.get(), 'outer')
+ cvar.set('inner')
+ self.assertEqual(outer_ctx[cvar], 'outer')
+ self.assertEqual(cvar.get(), 'inner')
+ self.assertEqual(cvar.get(), 'outer')
+ self.assertEqual(cvar.get(), 'default')
+
+ @isolated_context
+ def test_context_manager_enter_again_after_exit(self):
+ cvar = contextvars.ContextVar('cvar', default='initial')
+ self.assertEqual(cvar.get(), 'initial')
+ with contextvars.copy_context() as ctx:
+ cvar.set('updated')
+ self.assertEqual(cvar.get(), 'updated')
+ self.assertEqual(cvar.get(), 'initial')
+ with ctx:
+ self.assertEqual(cvar.get(), 'updated')
+ self.assertEqual(cvar.get(), 'initial')
+
+ @threading_helper.requires_working_threading()
+ def test_context_manager_rejects_exit_from_different_thread(self):
+ ctx = contextvars.copy_context()
+ thread = threading.Thread(target=ctx.__enter__)
+ thread.start()
+ thread.join()
+ with self.assertRaises(RuntimeError):
+ ctx.__exit__(None, None, None)
+
+ def test_context_manager_is_not_reentrant(self):
+ with self.subTest('context manager then context manager'):
+ with contextvars.copy_context() as ctx:
+ with self.assertRaises(RuntimeError):
+ with ctx:
+ pass
+ with self.subTest('context manager then run method'):
+ with contextvars.copy_context() as ctx:
+ with self.assertRaises(RuntimeError):
+ ctx.run(lambda: None)
+ with self.subTest('run method then context manager'):
+ ctx = contextvars.copy_context()
+
+ def fn():
+ with self.assertRaises(RuntimeError):
+ with ctx:
+ pass
+
+ ctx.run(fn)
+
+ def test_context_manager_rejects_noncurrent_exit(self):
+ with contextvars.copy_context() as outer_ctx:
+ with contextvars.copy_context() as inner_ctx:
+ self.assertIsNot(outer_ctx, inner_ctx)
+ with self.assertRaises(RuntimeError):
+ outer_ctx.__exit__(None, None, None)
+
+ def test_context_manager_rejects_nonentered_exit(self):
+ ctx = contextvars.copy_context()
+ with self.assertRaises(RuntimeError):
+ ctx.__exit__(None, None, None)
+
+
+class GeneratorContextTest(unittest.TestCase):
+ def test_default_is_none(self):
+ def makegen():
+ yield 1
+
+ gen = makegen()
+ self.assertIsNone(gen._context)
+
+ def test_none_is_dependent(self):
+ """Test behavior when the generator's context is set to None.
+
+ The generator should use the thread's context whenever it starts or
+ resumes execution. This means that the current context as observed by
+ the generator can change arbitrarily during a yield. This is the
+ behavior of older versions of Python, so for backwards compatibility it
+ should remain the default behavior.
+ """
+ cvar = contextvars.ContextVar('cvar', default='initial')
+
+ def makegen():
+ while True:
+ yield cvar.get()
+
+ gen = makegen()
+ self.assertEqual(next(gen), 'initial')
+ cvar.set('updated outer')
+ self.assertEqual(next(gen), 'updated outer')
+
+ def cb():
+ cvar.set('updated inner')
+ return next(gen)
+
+ self.assertEqual(contextvars.copy_context().run(cb), 'updated inner')
+ self.assertEqual(next(gen), 'updated outer')
+
+ def test_dependent_to_dependent(self):
+ """Test resetting an already-dependent generator's context to None."""
+ def makegen():
+ yield 1
+
+ gen = makegen()
+ gen._context = None
+ self.assertIsNone(gen._context)
+
+ def test_dependent_to_independent(self):
+ """Test upgrading a dependent generator to independent."""
+ cvar = contextvars.ContextVar('cvar', default='initial')
+
+ def makegen():
+ while True:
+ yield cvar.get()
+
+ gen = makegen()
+ ctx = contextvars.copy_context()
+ ctx.run(lambda: cvar.set('independent'))
+ gen._context = ctx
+ self.assertIs(gen._context, ctx)
+ with self.assertRaisesRegex(RuntimeError, 'already entered'):
+ ctx.run(lambda: None)
+ self.assertEqual(next(gen), 'independent')
+
+ def cb():
+ cvar.set('new context')
+ return next(gen)
+
+ self.assertEqual(contextvars.copy_context().run(cb), 'independent')
+ self.assertEqual(next(gen), 'independent')
+
+ def test_independent_to_dependent(self):
+ """Test downgrading an independent generator to dependent."""
+ cvar = contextvars.ContextVar('cvar', default='initial')
+
+ def makegen():
+ while True:
+ yield cvar.get()
+
+ gen = makegen()
+ ctx = contextvars.copy_context()
+ ctx.run(lambda: cvar.set('independent'))
+ gen._context = ctx
+ gen._context = None
+ ctx.run(lambda: cvar.set('independent not entered anymore'))
+
+ def cb():
+ cvar.set('dependent')
+ return next(gen)
+
+ self.assertEqual(contextvars.copy_context().run(cb), 'dependent')
+ self.assertEqual(next(gen), 'initial')
+
+ def test_independent_to_independent_same(self):
+ """Test resetting an independent generator's ctx to the same ctx."""
+ cvar = contextvars.ContextVar('cvar', default='initial')
+
+ def makegen():
+ while True:
+ yield cvar.get()
+
+ gen = makegen()
+ ctx = contextvars.copy_context()
+ ctx.run(lambda: cvar.set('independent'))
+ gen._context = ctx
+ gen._context = ctx
+ self.assertIs(gen._context, ctx)
+ with self.assertRaisesRegex(RuntimeError, 'already entered'):
+ ctx.run(lambda: None)
+ self.assertEqual(next(gen), 'independent')
+ self.assertEqual(contextvars.copy_context().run(lambda: next(gen)),
+ 'independent')
+
+ def test_independent_to_independent_different(self):
+ """Test resetting an independent generator's ctx to a different ctx."""
+ cvar = contextvars.ContextVar('cvar', default='initial')
+
+ def makegen():
+ while True:
+ yield cvar.get()
+
+ gen = makegen()
+
+ ctx1 = contextvars.copy_context()
+ ctx1.run(lambda: cvar.set('independent1'))
+ gen._context = ctx1
+ self.assertIs(gen._context, ctx1)
+ self.assertEqual(next(gen), 'independent1')
+ with self.assertRaisesRegex(RuntimeError, 'already entered'):
+ ctx1.run(lambda: None)
+
+ ctx2 = contextvars.copy_context()
+ ctx2.run(lambda: cvar.set('independent2'))
+ gen._context = ctx2
+ self.assertIs(gen._context, ctx2)
+ self.assertEqual(next(gen), 'independent2')
+ with self.assertRaisesRegex(RuntimeError, 'already entered'):
+ ctx2.run(lambda: None)
+
+ ctx1.run(lambda: None) # Check that ctx1 is no longer entered.
+
+ def test_entering_updates__context(self):
+ """Entering another ctx from an indep generator updates _context."""
+ ctx1 = contextvars.copy_context()
+ ctx2 = contextvars.copy_context()
+
+ def makegen():
+ gen = yield
+ yield gen._context
+ yield ctx2.run(lambda: gen._context)
+
+ gen = makegen()
+ gen._context = ctx1
+ gen.send(None)
+ self.assertIs(gen.send(gen), ctx1)
+ self.assertIs(gen.send(None), ctx2)
+ self.assertIs(gen._context, ctx1)
+
+ def test_reset_while_another_entered_is_error(self):
+ """Resetting indep gen's ctx while ctx stack non-empty is an error."""
+ cvar = contextvars.ContextVar('cvar', default='initial')
+ ctx1_outer = contextvars.copy_context()
+ ctx1_outer.run(lambda: cvar.set('independent1 outer'))
+ ctx2 = contextvars.copy_context()
+ ctx2.run(lambda: cvar.set('independent2'))
+
+ def makegen():
+ gen = yield cvar.get()
+ ctx1_inner = ctx1_outer.copy()
+ ctx1_inner.run(lambda: cvar.set('independent1 inner'))
+
+ def cb():
+ self.assertIs(gen._context, ctx1_inner)
+ with self.assertRaisesRegex(RuntimeError, 'cannot reset'):
+ gen._context = ctx2
+ ctx2.run(lambda: None) # Check that ctx2 is still not entered.
+ return cvar.get()
+
+ yield ctx1_inner.run(cb)
+
+ gen = makegen()
+ gen._context = ctx1_outer
+ self.assertIs(gen._context, ctx1_outer)
+ with self.assertRaisesRegex(RuntimeError, 'already entered'):
+ ctx1_outer.run(lambda: None)
+ self.assertEqual(next(gen), 'independent1 outer')
+ self.assertEqual(gen.send(gen), 'independent1 inner')
+ self.assertIs(gen._context, ctx1_outer)
+
+ def test_generator_calls_generator(self):
+ """Stresses deep shadowing/unshadowing of context stacks."""
+ cvar = contextvars.ContextVar('cvar', default='initial')
+ ctx_inner = contextvars.copy_context()
+ ctx_inner.run(lambda: cvar.set('inner'))
+ ctx_outer = contextvars.copy_context()
+ ctx_outer.run(lambda: cvar.set('outer'))
+
+ def makegen_inner():
+ while True:
+ yield cvar.get()
+
+ gen_inner = makegen_inner()
+ gen_inner._context = ctx_inner
+
+ def makegen_outer():
+ while True:
+ yield cvar.get(), next(gen_inner)
+
+ gen_outer = makegen_outer()
+ gen_outer._context = ctx_outer
+
+ for _ in range(5):
+ self.assertEqual(cvar.get(), 'initial')
+ self.assertEqual(next(gen_inner), 'inner')
+ self.assertEqual(next(gen_outer), ('outer', 'inner'))
+
+ def test_dependent_to_independent_from_called_generator(self):
+ """Upgrade generator when it is not the top indep gen in the call chain.
+
+ Upgrading a running generator from dependent to independent usually
+ causes its context stack to immediately become the visible context stack
+ by shadowing the previously visible context stack. However, if the
+ upgraded generator is not the topmost independent generator in the call
+ chain (the upgraded generator is running another independent generator),
+ its context stack should not become visible. Only when all generators
+ in the call chain above it have returned/yielded/thrown should its
+ context stack finally become visible.
+
+ Summary:
+ * thread runs dependent gen_outer which runs independent gen_inner:
+ - gen_inner's context stack shadows thread's context stack
+ * thread runs dependent gen_outer which runs independent gen_inner
+ which upgrades gen_outer to independent:
+ - gen_inner's context stack shadows gen_outer's context stack
+ which shadows thread's context stack
+
+ """
+ cvar = contextvars.ContextVar('cvar', default='initial')
+ ctx_inner = contextvars.copy_context()
+ ctx_inner.run(lambda: cvar.set('inner'))
+ ctx_outer = contextvars.copy_context()
+ ctx_outer.run(lambda: cvar.set('outer'))
+
+ def makegen_inner():
+ gen_outer = yield
+ while True:
+ gen_outer._context = yield cvar.get()
+
+ gen_inner = makegen_inner()
+ gen_inner._context = ctx_inner
+
+ def makegen_outer():
+ ctx = None
+ while True:
+ # Send the context to the inner generator before reading the
+ # context variable's value from this outer generator so that the
+ # inner generator can reset this outer generator's context
+ # before the read.
+ ctx = yield gen_inner.send(ctx), cvar.get()
+
+ gen_outer = makegen_outer()
+ gen_outer._context = None # Intentionally dependent.
+ gen_inner.send(None)
+ self.assertEqual(gen_inner.send(gen_outer), 'inner')
+
+ for _ in range(5):
+ self.assertEqual(gen_outer.send(None), ('inner', 'initial'))
+ cvar.set('updated')
+ for _ in range(5):
+ self.assertEqual(gen_outer.send(None), ('inner', 'updated'))
+ for _ in range(5):
+ self.assertEqual(gen_outer.send(ctx_outer), ('inner', 'outer'))
+ self.assertEqual(cvar.get(), 'updated')
+
+
+class AsyncContextTest(unittest.IsolatedAsyncioTestCase):
+ async def test_asyncio_independent_contexts(self):
+ """Check that coroutines are run with independent contexts.
+
+ Changes to context variables outside a coroutine should not affect the
+ values seen inside the coroutine and vice-versa. (This might be
+ implemented by manually setting the context before executing each step
+ of (send to) a coroutine, or by ensuring that the coroutine is an
+ independent coroutine before executing any steps.)
+ """
+ cvar = contextvars.ContextVar('cvar', default='A')
+ updated1 = asyncio.Event()
+ updated2 = asyncio.Event()
+
+ async def task1():
+ self.assertIs(cvar.get(), 'A')
+ await asyncio.sleep(0)
+ cvar.set('B')
+ await asyncio.sleep(0)
+ updated1.set()
+ await updated2.wait()
+ self.assertIs(cvar.get(), 'B')
+
+ async def task2():
+ await updated1.wait()
+ self.assertIs(cvar.get(), 'A')
+ await asyncio.sleep(0)
+ cvar.set('C')
+ await asyncio.sleep(0)
+ updated2.set()
+ await asyncio.sleep(0)
+ self.assertIs(cvar.get(), 'C')
+
+ async with asyncio.TaskGroup() as tg:
+ tg.create_task(task1())
+ tg.create_task(task2())
+
+ self.assertIs(cvar.get(), 'A')
+
+ async def test_asynccontextmanager_is_dependent_by_default(self):
+ """Async generator in asynccontextmanager is dependent by default.
+
+ Context switches during the yield of a generator wrapped with
+ contextlib.asynccontextmanager should be visible to the generator by
+ default (for backwards compatibility).
+ """
+ cvar = contextvars.ContextVar('cvar', default='A')
+
+ @contextlib.asynccontextmanager
+ async def makecm():
+ await asyncio.sleep(0)
+ self.assertEqual(cvar.get(), 'A')
+ await asyncio.sleep(0)
+ # Everything above runs during __aenter__.
+ yield cvar.get()
+ # Everything below runs during __aexit__.
+ await asyncio.sleep(0)
+ self.assertEqual(cvar.get(), 'C')
+ await asyncio.sleep(0)
+ cvar.set('D')
+ await asyncio.sleep(0)
+
+ cm = makecm()
+ val = await cm.__aenter__()
+ self.assertEqual(val, 'A')
+ self.assertEqual(cvar.get(), 'A')
+ cvar.set('B')
+
+ with contextvars.copy_context():
+ cvar.set('C')
+ await cm.__aexit__(None, None, None)
+ self.assertEqual(cvar.get(), 'D')
+ self.assertEqual(cvar.get(), 'B')
+
+ async def test_asynccontextmanager_independent(self):
+ cvar = contextvars.ContextVar('cvar', default='A')
+
+ @contextlib.asynccontextmanager
+ async def makecm():
+ # Context.__enter__ called from a generator makes the generator
+ # independent while the `with` statement suite runs.
+ # (Alternatively we could have set the generator's _context
+ # property.)
+ with contextvars.copy_context():
+ await asyncio.sleep(0)
+ self.assertEqual(cvar.get(), 'A')
+ await asyncio.sleep(0)
+ # Everything above runs during __aenter__.
+ yield cvar.get()
+ # Everything below runs during __aexit__.
+ await asyncio.sleep(0)
+ self.assertEqual(cvar.get(), 'A')
+ await asyncio.sleep(0)
+ cvar.set('D')
+ await asyncio.sleep(0)
+
+ cm = makecm()
+ val = await cm.__aenter__()
+ self.assertEqual(val, 'A')
+ self.assertEqual(cvar.get(), 'A')
+ cvar.set('B')
+ with contextvars.copy_context():
+ cvar.set('C')
+ await cm.__aexit__(None, None, None)
+ self.assertEqual(cvar.get(), 'C')
+ self.assertEqual(cvar.get(), 'B')
+
+ async def test_generator_switch_between_independent_dependent(self):
+ cvar = contextvars.ContextVar('cvar', default='default')
+ with contextvars.copy_context() as ctx1:
+ cvar.set('in ctx1')
+ with contextvars.copy_context() as ctx2:
+ cvar.set('in ctx2')
+ with contextvars.copy_context() as ctx3:
+ cvar.set('in ctx3')
+
+ async def makegen():
+ await asyncio.sleep(0)
+ yield cvar.get()
+ await asyncio.sleep(0)
+ yield cvar.get()
+ await asyncio.sleep(0)
+ with ctx2:
+ yield cvar.get()
+ await asyncio.sleep(0)
+ yield cvar.get()
+ await asyncio.sleep(0)
+ yield cvar.get()
+
+ gen = makegen()
+ self.assertEqual(await anext(gen), 'default')
+ with ctx1:
+ self.assertEqual(await anext(gen), 'in ctx1')
+ self.assertEqual(await anext(gen), 'in ctx2')
+ with ctx3:
+ self.assertEqual(await anext(gen), 'in ctx2')
+ self.assertEqual(await anext(gen), 'in ctx1')
# HAMT Tests
diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py
index 9689ef8e96e072..0cb3da6cfda049 100644
--- a/Lib/test/test_sys.py
+++ b/Lib/test/test_sys.py
@@ -1617,7 +1617,7 @@ def bar(cls):
check(bar, size('PP'))
# generator
def get_gen(): yield 1
- check(get_gen(), size('6P4c' + INTERPRETER_FRAME + 'P'))
+ check(get_gen(), size('6P2P4c' + INTERPRETER_FRAME + 'P'))
# iterator
check(iter('abc'), size('lP'))
# callable-iterator
diff --git a/Makefile.pre.in b/Makefile.pre.in
index 07c8a4d20142db..3378313a1fb55d 100644
--- a/Makefile.pre.in
+++ b/Makefile.pre.in
@@ -1191,6 +1191,7 @@ PYTHON_HEADERS= \
$(srcdir)/Include/internal/pycore_complexobject.h \
$(srcdir)/Include/internal/pycore_condvar.h \
$(srcdir)/Include/internal/pycore_context.h \
+ $(srcdir)/Include/internal/pycore_contextchain.h \
$(srcdir)/Include/internal/pycore_critical_section.h \
$(srcdir)/Include/internal/pycore_crossinterp.h \
$(srcdir)/Include/internal/pycore_descrobject.h \
diff --git a/Misc/ACKS b/Misc/ACKS
index d94cbacf888468..69df31799737c7 100644
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -716,6 +716,7 @@ Michael Handler
Andreas Hangauer
Milton L. Hankins
Carl Bordum Hansen
+Richard Hansen
Stephen Hansen
Barry Hantman
Lynda Hardman
diff --git a/Misc/NEWS.d/next/Library/2022-11-21-01-24-46.gh-issue-99633.vhrNRe.rst b/Misc/NEWS.d/next/Library/2022-11-21-01-24-46.gh-issue-99633.vhrNRe.rst
new file mode 100644
index 00000000000000..1c7191e1b22e6d
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2022-11-21-01-24-46.gh-issue-99633.vhrNRe.rst
@@ -0,0 +1,2 @@
+Added support for the :term:`context management protocol` to
+:class:`contextvars.Context`. Patch by Richard Hansen.
diff --git a/Modules/_testcapi/watchers.c b/Modules/_testcapi/watchers.c
index 689863d098ad8a..713975d59b3a43 100644
--- a/Modules/_testcapi/watchers.c
+++ b/Modules/_testcapi/watchers.c
@@ -9,6 +9,7 @@
#include "pycore_function.h" // FUNC_MAX_WATCHERS
#include "pycore_code.h" // CODE_MAX_WATCHERS
#include "pycore_context.h" // CONTEXT_MAX_WATCHERS
+#include "pycore_tstate.h" // _PyThreadStateImpl::_ctx_chain
/*[clinic input]
module _testcapi
@@ -626,72 +627,62 @@ allocate_too_many_func_watchers(PyObject *self, PyObject *args)
// Test contexct object watchers
#define NUM_CONTEXT_WATCHERS 2
static int context_watcher_ids[NUM_CONTEXT_WATCHERS] = {-1, -1};
-static int num_context_object_enter_events[NUM_CONTEXT_WATCHERS] = {0, 0};
-static int num_context_object_exit_events[NUM_CONTEXT_WATCHERS] = {0, 0};
+static PyObject *context_switches[NUM_CONTEXT_WATCHERS];
-static int
-handle_context_watcher_event(int which_watcher, PyContextEvent event, PyContext *ctx) {
- if (event == Py_CONTEXT_EVENT_ENTER) {
- num_context_object_enter_events[which_watcher]++;
- }
- else if (event == Py_CONTEXT_EVENT_EXIT) {
- num_context_object_exit_events[which_watcher]++;
+static void
+handle_context_watcher_event(int which_watcher, PyContextEvent event, PyObject *ctx) {
+ if (event == Py_CONTEXT_SWITCHED) {
+ PyList_Append(context_switches[which_watcher], ctx);
}
else {
- return -1;
+ Py_UNREACHABLE();
}
- return 0;
}
-static int
-first_context_watcher_callback(PyContextEvent event, PyContext *ctx) {
- return handle_context_watcher_event(0, event, ctx);
+static void
+first_context_watcher_callback(PyContextEvent event, PyObject *ctx) {
+ handle_context_watcher_event(0, event, ctx);
}
-static int
-second_context_watcher_callback(PyContextEvent event, PyContext *ctx) {
- return handle_context_watcher_event(1, event, ctx);
+static void
+second_context_watcher_callback(PyContextEvent event, PyObject *ctx) {
+ handle_context_watcher_event(1, event, ctx);
}
-static int
-noop_context_event_handler(PyContextEvent event, PyContext *ctx) {
- return 0;
+static void
+noop_context_event_handler(PyContextEvent event, PyObject *ctx) {
}
-static int
-error_context_event_handler(PyContextEvent event, PyContext *ctx) {
+static void
+error_context_event_handler(PyContextEvent event, PyObject *ctx) {
PyErr_SetString(PyExc_RuntimeError, "boom!");
- return -1;
}
static PyObject *
add_context_watcher(PyObject *self, PyObject *which_watcher)
{
- int watcher_id;
+ static const PyContext_WatchCallback callbacks[] = {
+ &first_context_watcher_callback,
+ &second_context_watcher_callback,
+ &error_context_event_handler,
+ };
assert(PyLong_Check(which_watcher));
long which_l = PyLong_AsLong(which_watcher);
- if (which_l == 0) {
- watcher_id = PyContext_AddWatcher(first_context_watcher_callback);
- context_watcher_ids[0] = watcher_id;
- num_context_object_enter_events[0] = 0;
- num_context_object_exit_events[0] = 0;
- }
- else if (which_l == 1) {
- watcher_id = PyContext_AddWatcher(second_context_watcher_callback);
- context_watcher_ids[1] = watcher_id;
- num_context_object_enter_events[1] = 0;
- num_context_object_exit_events[1] = 0;
- }
- else if (which_l == 2) {
- watcher_id = PyContext_AddWatcher(error_context_event_handler);
- }
- else {
+ if (which_l < 0 || which_l >= (long)Py_ARRAY_LENGTH(callbacks)) {
PyErr_Format(PyExc_ValueError, "invalid watcher %d", which_l);
return NULL;
}
+ int watcher_id = PyContext_AddWatcher(callbacks[which_l]);
if (watcher_id < 0) {
return NULL;
}
+ if (which_l >= 0 && which_l < NUM_CONTEXT_WATCHERS) {
+ context_watcher_ids[which_l] = watcher_id;
+ Py_XSETREF(context_switches[which_l], PyList_New(0));
+ if (context_switches[which_l] == NULL) {
+ return NULL;
+ }
+ }
return PyLong_FromLong(watcher_id);
}
@@ -708,8 +699,7 @@ clear_context_watcher(PyObject *self, PyObject *watcher_id)
for (int i = 0; i < NUM_CONTEXT_WATCHERS; i++) {
if (watcher_id_l == context_watcher_ids[i]) {
context_watcher_ids[i] = -1;
- num_context_object_enter_events[i] = 0;
- num_context_object_exit_events[i] = 0;
+ Py_CLEAR(context_switches[i]);
}
}
}
@@ -717,21 +707,48 @@ clear_context_watcher(PyObject *self, PyObject *watcher_id)
}
static PyObject *
-get_context_watcher_num_enter_events(PyObject *self, PyObject *watcher_id)
+clear_context_stack(PyObject *self, PyObject *args)
{
- assert(PyLong_Check(watcher_id));
- long watcher_id_l = PyLong_AsLong(watcher_id);
- assert(watcher_id_l >= 0 && watcher_id_l < NUM_CONTEXT_WATCHERS);
- return PyLong_FromLong(num_context_object_enter_events[watcher_id_l]);
+ // Ensure that _ctx_chain is initialized.
+ PyObject *ctx = PyContext_CopyCurrent();
+ if (ctx == NULL) {
+ return NULL;
+ }
+ Py_CLEAR(ctx);
+
+ _PyThreadStateImpl *tsi = (_PyThreadStateImpl *)PyThreadState_Get();
+ if (tsi->_ctx_chain.prev != &tsi->_ctx_chain) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "must not be called from a coroutine or generator");
+ }
+ if (tsi->_ctx_chain.prev->ctx == NULL) {
+ Py_RETURN_NONE;
+ }
+ if (((PyContext *)tsi->_ctx_chain.ctx)->ctx_prev != NULL) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "must first exit all non-base contexts");
+ return NULL;
+ }
+ if (PyContext_Exit(tsi->_ctx_chain.prev->ctx)) {
+ return NULL;
+ }
+ assert(tsi->_ctx_chain.prev->ctx == NULL);
+ Py_RETURN_NONE;
}
static PyObject *
-get_context_watcher_num_exit_events(PyObject *self, PyObject *watcher_id)
+get_context_switches(PyObject *self, PyObject *watcher_id)
{
assert(PyLong_Check(watcher_id));
long watcher_id_l = PyLong_AsLong(watcher_id);
- assert(watcher_id_l >= 0 && watcher_id_l < NUM_CONTEXT_WATCHERS);
- return PyLong_FromLong(num_context_object_exit_events[watcher_id_l]);
+ if (watcher_id_l < 0 || watcher_id_l >= NUM_CONTEXT_WATCHERS) {
+ PyErr_Format(PyExc_ValueError, "invalid watcher %d", watcher_id_l);
+ return NULL;
+ }
+ if (context_switches[watcher_id_l] == NULL) {
+ return PyList_New(0);
+ }
+ return Py_NewRef(context_switches[watcher_id_l]);
}
static PyObject *
@@ -835,10 +852,8 @@ static PyMethodDef test_methods[] = {
// Code object watchers.
{"add_context_watcher", add_context_watcher, METH_O, NULL},
{"clear_context_watcher", clear_context_watcher, METH_O, NULL},
- {"get_context_watcher_num_enter_events",
- get_context_watcher_num_enter_events, METH_O, NULL},
- {"get_context_watcher_num_exit_events",
- get_context_watcher_num_exit_events, METH_O, NULL},
+ {"clear_context_stack", clear_context_stack, METH_NOARGS, NULL},
+ {"get_context_switches", get_context_switches, METH_O, NULL},
{"allocate_too_many_context_watchers",
(PyCFunction) allocate_too_many_context_watchers, METH_NOARGS, NULL},
{NULL},
diff --git a/Objects/genobject.c b/Objects/genobject.c
index 19c2c4e3331a89..d78d1609852285 100644
--- a/Objects/genobject.c
+++ b/Objects/genobject.c
@@ -58,6 +58,7 @@ gen_traverse(PyObject *self, visitproc visit, void *arg)
PyGenObject *gen = _PyGen_CAST(self);
Py_VISIT(gen->gi_name);
Py_VISIT(gen->gi_qualname);
+ Py_VISIT(gen->_ctx_chain.ctx);
if (gen->gi_frame_state != FRAME_CLEARED) {
_PyInterpreterFrame *frame = &gen->gi_iframe;
assert(frame->frame_obj == NULL ||
@@ -129,6 +130,14 @@ _PyGen_Finalize(PyObject *self)
Py_DECREF(res);
}
}
+ if (_PyGen_ResetContext(_PyThreadState_GET(), gen, NULL)) {
+ // This can happen if the contextvars API is misused (the coroutine or a
+ // function it called entered a context but did not exit the context
+ // before the coroutine concluded). The coroutine's base context, and
+ // the entered contexts on top of it, will remain marked as entered but
+ // will otherwise behave normally.
+ PyErr_WriteUnraisable(self);
+ }
/* Restore the saved exception. */
PyErr_SetRaisedException(exc);
@@ -170,6 +179,7 @@ gen_dealloc(PyObject *self)
PyStackRef_CLEAR(gen->gi_iframe.f_executable);
Py_CLEAR(gen->gi_name);
Py_CLEAR(gen->gi_qualname);
+ Py_CLEAR(gen->_ctx_chain.ctx);
PyObject_GC_Del(gen);
}
@@ -242,7 +252,9 @@ gen_send_ex2(PyGenObject *gen, PyObject *arg, PyObject **presult,
gen->gi_frame_state = FRAME_EXECUTING;
EVAL_CALL_STAT_INC(EVAL_CALL_GENERATOR);
+ _PyGen_ActivateContext(tstate, gen);
PyObject *result = _PyEval_EvalFrame(tstate, frame, exc);
+ _PyGen_DeactivateContext(tstate, gen);
assert(tstate->exc_info == prev_exc_info);
assert(gen->gi_exc_state.previous_item == NULL);
assert(gen->gi_frame_state != FRAME_EXECUTING);
@@ -733,6 +745,22 @@ gen_set_qualname(PyObject *self, PyObject *value, void *Py_UNUSED(ignored))
return 0;
}
+static PyObject *
+gen_get_context(PyObject *self, void *Py_UNUSED(ignored))
+{
+ PyObject *ctx = _PyGen_CAST(self)->_ctx_chain.ctx;
+ if (ctx == NULL) {
+ Py_RETURN_NONE;
+ }
+ return Py_NewRef(ctx);
+}
+
+static int
+gen_set_context(PyObject *self, PyObject *ctx, void *Py_UNUSED(ignored))
+{
+ return _PyGen_ResetContext(_PyThreadState_GET(), _PyGen_CAST(self), ctx);
+}
+
static PyObject *
gen_getyieldfrom(PyObject *gen, void *Py_UNUSED(ignored))
{
@@ -801,6 +829,10 @@ static PyGetSetDef gen_getsetlist[] = {
PyDoc_STR("name of the generator")},
{"__qualname__", gen_get_qualname, gen_set_qualname,
PyDoc_STR("qualified name of the generator")},
+ {"_context", gen_get_context, gen_set_context,
+ PyDoc_STR("the generator's observed \"current context\", or None if the "
+ "generator uses the thread's context (which can change during a "
+ "yield) as its current context")},
{"gi_yieldfrom", gen_getyieldfrom, NULL,
PyDoc_STR("object being iterated by yield from, or None")},
{"gi_running", gen_getrunning, NULL, NULL},
@@ -914,6 +946,7 @@ make_gen(PyTypeObject *type, PyFunctionObject *func)
gen->gi_name = Py_NewRef(func->func_name);
assert(func->func_qualname != NULL);
gen->gi_qualname = Py_NewRef(func->func_qualname);
+ gen->_ctx_chain = (_PyContextChain){0};
_PyObject_GC_TRACK(gen);
return (PyObject *)gen;
}
@@ -1001,6 +1034,7 @@ gen_new_with_qualname(PyTypeObject *type, PyFrameObject *f,
gen->gi_qualname = Py_NewRef(qualname);
else
gen->gi_qualname = Py_NewRef(_PyGen_GetCode(gen)->co_qualname);
+ gen->_ctx_chain = (_PyContextChain){0};
_PyObject_GC_TRACK(gen);
return (PyObject *)gen;
}
@@ -1157,6 +1191,10 @@ static PyGetSetDef coro_getsetlist[] = {
PyDoc_STR("name of the coroutine")},
{"__qualname__", gen_get_qualname, gen_set_qualname,
PyDoc_STR("qualified name of the coroutine")},
+ {"_context", gen_get_context, gen_set_context,
+ PyDoc_STR("the coroutine's observed \"current context\", or None if the "
+ "coroutine uses the thread's context (which can change during "
+ "an await) as its current context")},
{"cr_await", coro_get_cr_await, NULL,
PyDoc_STR("object being awaited on, or None")},
{"cr_running", cr_getrunning, NULL, NULL},
@@ -1588,6 +1626,10 @@ static PyGetSetDef async_gen_getsetlist[] = {
PyDoc_STR("name of the async generator")},
{"__qualname__", gen_get_qualname, gen_set_qualname,
PyDoc_STR("qualified name of the async generator")},
+ {"_context", gen_get_context, gen_set_context,
+ PyDoc_STR("the generator's observed \"current context\", or None if the "
+ "generator uses the thread's context (which can change during a "
+ "yield) as its current context")},
{"ag_await", coro_get_cr_await, NULL,
PyDoc_STR("object being awaited on, or None")},
{"ag_frame", ag_getframe, NULL, NULL},
diff --git a/PCbuild/pythoncore.vcxproj b/PCbuild/pythoncore.vcxproj
index 3b33c6bf6bb91d..a7264ee18718f0 100644
--- a/PCbuild/pythoncore.vcxproj
+++ b/PCbuild/pythoncore.vcxproj
@@ -225,6 +225,7 @@
+
diff --git a/PCbuild/pythoncore.vcxproj.filters b/PCbuild/pythoncore.vcxproj.filters
index ee2930b10439a9..b2de7df9466385 100644
--- a/PCbuild/pythoncore.vcxproj.filters
+++ b/PCbuild/pythoncore.vcxproj.filters
@@ -597,6 +597,9 @@
Include\internal
+
+ Include\internal
+
Include\internal
diff --git a/Python/clinic/context.c.h b/Python/clinic/context.c.h
index 997ac6f63384a9..45d8639333c591 100644
--- a/Python/clinic/context.c.h
+++ b/Python/clinic/context.c.h
@@ -4,6 +4,75 @@ preserve
#include "pycore_modsupport.h" // _PyArg_CheckPositional()
+PyDoc_STRVAR(_contextvars_Context___enter____doc__,
+"__enter__($self, /)\n"
+"--\n"
+"\n"
+"Context manager enter.\n"
+"\n"
+"Automatically called by the \'with\' statement. Using the Context object as a\n"
+"context manager is an alternative to calling the Context.run() method.\n"
+"\n"
+"Example:\n"
+"\n"
+" var = contextvars.ContextVar(\'var\')\n"
+" var.set(\'initial\')\n"
+"\n"
+" with contextvars.copy_context():\n"
+" var.set(\'updated\')\n"
+" print(var.get()) # \'updated\'\n"
+"\n"
+" print(var.get()) # \'initial\'");
+
+#define _CONTEXTVARS_CONTEXT___ENTER___METHODDEF \
+ {"__enter__", (PyCFunction)_contextvars_Context___enter__, METH_NOARGS, _contextvars_Context___enter____doc__},
+
+static PyObject *
+_contextvars_Context___enter___impl(PyContext *self);
+
+static PyObject *
+_contextvars_Context___enter__(PyContext *self, PyObject *Py_UNUSED(ignored))
+{
+ return _contextvars_Context___enter___impl(self);
+}
+
+PyDoc_STRVAR(_contextvars_Context___exit____doc__,
+"__exit__($self, exc_type, exc_val, exc_tb, /)\n"
+"--\n"
+"\n"
+"Context manager exit.\n"
+"\n"
+"Automatically called at the conclusion of a \'with\' statement when the Context\n"
+"is used as a context manager. See the Context.__enter__() method for more\n"
+"details.");
+
+#define _CONTEXTVARS_CONTEXT___EXIT___METHODDEF \
+ {"__exit__", _PyCFunction_CAST(_contextvars_Context___exit__), METH_FASTCALL, _contextvars_Context___exit____doc__},
+
+static PyObject *
+_contextvars_Context___exit___impl(PyContext *self, PyObject *exc_type,
+ PyObject *exc_val, PyObject *exc_tb);
+
+static PyObject *
+_contextvars_Context___exit__(PyContext *self, PyObject *const *args, Py_ssize_t nargs)
+{
+ PyObject *return_value = NULL;
+ PyObject *exc_type;
+ PyObject *exc_val;
+ PyObject *exc_tb;
+
+ if (!_PyArg_CheckPositional("__exit__", nargs, 3, 3)) {
+ goto exit;
+ }
+ exc_type = args[0];
+ exc_val = args[1];
+ exc_tb = args[2];
+ return_value = _contextvars_Context___exit___impl(self, exc_type, exc_val, exc_tb);
+
+exit:
+ return return_value;
+}
+
PyDoc_STRVAR(_contextvars_Context_get__doc__,
"get($self, key, default=None, /)\n"
"--\n"
@@ -179,4 +248,4 @@ PyDoc_STRVAR(_contextvars_ContextVar_reset__doc__,
#define _CONTEXTVARS_CONTEXTVAR_RESET_METHODDEF \
{"reset", (PyCFunction)_contextvars_ContextVar_reset, METH_O, _contextvars_ContextVar_reset__doc__},
-/*[clinic end generated code: output=b667826178444c3f input=a9049054013a1b77]*/
+/*[clinic end generated code: output=68e3b8eb96ff5dc8 input=a9049054013a1b77]*/
diff --git a/Python/context.c b/Python/context.c
index 36e2677c398f59..c774dc4ec44875 100644
--- a/Python/context.c
+++ b/Python/context.c
@@ -43,6 +43,40 @@ module _contextvars
/////////////////////////// Context API
+// Returns the head of the context chain, which holds the "active" context
+// stack. Always succeeds.
+static _PyContextChain *
+contextchain_head(_PyThreadStateImpl *tsi)
+{
+ assert(tsi != NULL);
+ // Lazy initialization.
+ if (tsi->_ctx_chain.prev == NULL) {
+ assert(tsi->_ctx_chain.ctx == NULL);
+ tsi->_ctx_chain.prev = &tsi->_ctx_chain;
+ }
+ return tsi->_ctx_chain.prev;
+}
+
+// Inserts prev before next in the context chain. Always succeeds.
+static inline void
+contextchain_link(_PyContextChain *prev, _PyContextChain *next)
+{
+ assert(next->prev != NULL);
+ assert(prev->prev == NULL);
+ prev->prev = next->prev;
+ next->prev = prev;
+}
+
+// Removes prev from the context chain. Always succeeds.
+static inline void
+contextchain_unlink(_PyContextChain *prev, _PyContextChain *next)
+{
+ assert(next->prev == prev);
+ assert(prev->prev != NULL);
+ next->prev = prev->prev;
+ prev->prev = NULL;
+}
+
static PyContext *
context_new_empty(void);
@@ -102,18 +136,24 @@ PyContext_CopyCurrent(void)
static const char *
context_event_name(PyContextEvent event) {
switch (event) {
- case Py_CONTEXT_EVENT_ENTER:
- return "Py_CONTEXT_EVENT_ENTER";
- case Py_CONTEXT_EVENT_EXIT:
- return "Py_CONTEXT_EVENT_EXIT";
+ case Py_CONTEXT_SWITCHED:
+ return "Py_CONTEXT_SWITCHED";
default:
return "?";
}
Py_UNREACHABLE();
}
-static void notify_context_watchers(PyContextEvent event, PyContext *ctx, PyThreadState *ts)
+static void
+notify_context_watchers(PyThreadState *ts, PyContextEvent event, PyObject *ctx)
{
+ if (ctx == NULL) {
+ // This will happen after exiting the last context in the stack, which
+ // can occur if context_get was never called before entering a context
+ // (e.g., called `contextvars.Context().run()` on a fresh thread, as
+ // PyContext_Enter doesn't call context_get).
+ ctx = Py_None;
+ }
assert(Py_REFCNT(ctx) > 0);
PyInterpreterState *interp = ts->interp;
assert(interp->_initialized);
@@ -124,11 +164,14 @@ static void notify_context_watchers(PyContextEvent event, PyContext *ctx, PyThre
if (bits & 1) {
PyContext_WatchCallback cb = interp->context_watchers[i];
assert(cb != NULL);
- if (cb(event, ctx) < 0) {
+ PyObject *exc = _PyErr_GetRaisedException(ts);
+ cb(event, ctx);
+ if (_PyErr_Occurred(ts) != NULL) {
PyErr_FormatUnraisable(
"Exception ignored in %s watcher callback for %R",
context_event_name(event), ctx);
}
+ _PyErr_SetRaisedException(ts, exc);
}
i++;
bits >>= 1;
@@ -174,25 +217,31 @@ PyContext_ClearWatcher(int watcher_id)
}
-static int
-_PyContext_Enter(PyThreadState *ts, PyObject *octx)
+static inline void
+context_switched(_PyThreadStateImpl *tsi)
{
- ENSURE_Context(octx, -1)
- PyContext *ctx = (PyContext *)octx;
+ tsi->base.context_ver++;
+ // contextchain_head(tsi)->ctx is used instead of context_get() because if
+ // tsi->_ctx_chain.ctx is NULL, context_get() will either call
+ // context_switched -- causing a double notification -- or throw.
+ notify_context_watchers(
+ &tsi->base, Py_CONTEXT_SWITCHED, contextchain_head(tsi)->ctx);
+}
+
+static int
+_PyContext_Enter(PyObject **stack, PyContext *ctx)
+{
+ assert(stack != NULL);
if (ctx->ctx_entered) {
- _PyErr_Format(ts, PyExc_RuntimeError,
- "cannot enter context: %R is already entered", ctx);
+ PyErr_Format(PyExc_RuntimeError,
+ "cannot enter context: %R is already entered", ctx);
return -1;
}
- ctx->ctx_prev = (PyContext *)ts->context; /* borrow */
ctx->ctx_entered = 1;
-
- ts->context = Py_NewRef(ctx);
- ts->context_ver++;
-
- notify_context_watchers(Py_CONTEXT_EVENT_ENTER, ctx, ts);
+ ctx->ctx_prev = *stack; /* steal */
+ *stack = Py_NewRef(ctx);
return 0;
}
@@ -200,50 +249,172 @@ _PyContext_Enter(PyThreadState *ts, PyObject *octx)
int
PyContext_Enter(PyObject *octx)
{
- PyThreadState *ts = _PyThreadState_GET();
- assert(ts != NULL);
- return _PyContext_Enter(ts, octx);
+ _PyThreadStateImpl *tsi = (_PyThreadStateImpl *)_PyThreadState_GET();
+ assert(tsi != NULL);
+ ENSURE_Context(octx, -1)
+ if (_PyContext_Enter(&contextchain_head(tsi)->ctx, (PyContext *)octx)) {
+ return -1;
+ }
+ context_switched(tsi);
+ return 0;
}
static int
-_PyContext_Exit(PyThreadState *ts, PyObject *octx)
+_PyContext_Exit(PyObject **stack, PyContext *ctx)
{
- ENSURE_Context(octx, -1)
- PyContext *ctx = (PyContext *)octx;
-
+ assert(stack != NULL);
if (!ctx->ctx_entered) {
PyErr_Format(PyExc_RuntimeError,
"cannot exit context: %R has not been entered", ctx);
return -1;
}
- if (ts->context != (PyObject *)ctx) {
- /* Can only happen if someone misuses the C API */
+ if (*stack != (PyObject *)ctx) {
PyErr_SetString(PyExc_RuntimeError,
- "cannot exit context: thread state references "
- "a different context object");
+ "cannot exit context: not the current context");
return -1;
}
- notify_context_watchers(Py_CONTEXT_EVENT_EXIT, ctx, ts);
- Py_SETREF(ts->context, (PyObject *)ctx->ctx_prev);
- ts->context_ver++;
-
+ Py_SETREF(*stack, ctx->ctx_prev); /* steal */
ctx->ctx_prev = NULL;
ctx->ctx_entered = 0;
-
return 0;
}
int
PyContext_Exit(PyObject *octx)
{
- PyThreadState *ts = _PyThreadState_GET();
- assert(ts != NULL);
- return _PyContext_Exit(ts, octx);
+ _PyThreadStateImpl *tsi = (_PyThreadStateImpl *)_PyThreadState_GET();
+ assert(tsi != NULL);
+ ENSURE_Context(octx, -1)
+ _PyContextChain *active = contextchain_head(tsi);
+ if (_PyContext_Exit(&active->ctx, (PyContext *)octx)) {
+ return -1;
+ }
+ if (active->ctx == NULL && active != &tsi->_ctx_chain) {
+ contextchain_unlink(active, &tsi->_ctx_chain);
+ }
+ context_switched(tsi);
+ return 0;
+}
+
+static _PyContextChain *
+gen_find_next_contextchain(_PyThreadStateImpl *tsi, PyGenObject *self)
+{
+ assert(self->gi_frame_state == FRAME_EXECUTING);
+ assert(tsi != NULL);
+ _PyContextChain *nlink = &tsi->_ctx_chain;
+ _PyInterpreterFrame *frame = _PyThreadState_GetFrame((PyThreadState *)tsi);
+ while (frame != NULL) {
+ if (frame->owner == FRAME_OWNED_BY_GENERATOR) {
+ PyGenObject *gen = _PyGen_GetGeneratorFromFrame(frame);
+ if (gen == self) {
+ break;
+ }
+ if (gen->_ctx_chain.ctx != NULL) {
+ assert(nlink->prev == &gen->_ctx_chain);
+ nlink = &gen->_ctx_chain;
+ assert(nlink->prev != NULL);
+ }
+ }
+ frame = _PyFrame_GetFirstComplete(frame->previous);
+ }
+ if (frame == NULL) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "coroutine is running but not in the current thread");
+ return NULL;
+ }
+ return nlink;
+}
+
+int
+_PyGen_ResetContext(PyThreadState *ts, PyGenObject *self, PyObject *ctx)
+{
+ _PyThreadStateImpl *tsi = (_PyThreadStateImpl *)ts;
+ if (ctx == Py_None) {
+ ctx = NULL;
+ }
+ if (ctx != NULL && !PyContext_CheckExact(ctx)) {
+ PyErr_SetString(PyExc_TypeError,
+ "a coroutine's base context must be a context.Context "
+ "object or None");
+ return -1;
+ }
+ PyObject *old_stack = self->_ctx_chain.ctx;
+ assert(old_stack == NULL || PyContext_CheckExact(old_stack));
+ if (old_stack != NULL && ((PyContext *)old_stack)->ctx_prev != NULL) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "cannot reset a coroutine's base context until the "
+ "coroutine has exited all of its non-base contexts");
+ return -1;
+ }
+ if (ctx == old_stack) {
+ return 0;
+ }
+ assert(self->_ctx_chain.ctx != NULL || self->_ctx_chain.prev == NULL);
+ assert(self->_ctx_chain.prev == NULL ||
+ (self->gi_frame_state == FRAME_EXECUTING &&
+ self->_ctx_chain.ctx != NULL));
+ assert(self->gi_frame_state != FRAME_EXECUTING ||
+ self->_ctx_chain.prev != NULL || self->_ctx_chain.ctx == NULL);
+ // contextchain_head(tsi)->ctx is used instead of context_get() because
+ // context_get can throw, and we don't need tsi->_ctx_chain.ctx to be
+ // initialized if currently NULL.
+ PyObject *old_ctx = contextchain_head(tsi)->ctx;
+ // Enter the new context (and activate/deactivate the context stack if
+ // necessary) before exiting the old context in case there is a problem
+ // entering the new context. (Exiting the old should always succeed.)
+ PyObject *new_stack = NULL;
+ if (ctx != NULL && _PyContext_Enter(&new_stack, (PyContext *)ctx)) {
+ return -1;
+ }
+ assert(new_stack == ctx);
+ if (self->gi_frame_state == FRAME_EXECUTING &&
+ (old_stack == NULL) != (new_stack == NULL)) {
+ _PyContextChain *nlink = gen_find_next_contextchain(tsi, self);
+ if (nlink == NULL) {
+ if (new_stack != NULL
+ && _PyContext_Exit(&new_stack, (PyContext *)new_stack)) {
+ Py_UNREACHABLE();
+ }
+ assert(new_stack == NULL);
+ return -1;
+ }
+ if (new_stack != NULL) {
+ contextchain_link(&self->_ctx_chain, nlink);
+ } else {
+ contextchain_unlink(&self->_ctx_chain, nlink);
+ }
+ }
+ if (old_stack != NULL && _PyContext_Exit(&old_stack,
+ (PyContext *)old_stack)) {
+ Py_UNREACHABLE();
+ }
+ assert(old_stack == NULL);
+ self->_ctx_chain.ctx = new_stack;
+ // contextchain_head(tsi)->ctx is used instead of context_get() because
+ // context_get can throw, and we don't need tsi->_ctx_chain.ctx to be
+ // initialized if currently NULL.
+ if (contextchain_head(tsi)->ctx != old_ctx) {
+ context_switched(tsi);
+ }
+ return 0;
+}
+
+void
+_PyGen_ActivateContextImpl(_PyThreadStateImpl *tsi, PyGenObject *self)
+{
+ contextchain_link(&self->_ctx_chain, &tsi->_ctx_chain);
+ context_switched(tsi);
}
+void
+_PyGen_DeactivateContextImpl(_PyThreadStateImpl *tsi, PyGenObject *self)
+{
+ contextchain_unlink(&self->_ctx_chain, &tsi->_ctx_chain);
+ context_switched(tsi);
+}
PyObject *
PyContextVar_New(const char *name, PyObject *def)
@@ -265,8 +436,13 @@ PyContextVar_Get(PyObject *ovar, PyObject *def, PyObject **val)
PyContextVar *var = (PyContextVar *)ovar;
PyThreadState *ts = _PyThreadState_GET();
+ _PyThreadStateImpl *tsi = (_PyThreadStateImpl *)ts;
assert(ts != NULL);
- if (ts->context == NULL) {
+ // contextchain_head(tsi)->ctx is used instead of context_get() because
+ // context_get can throw, and we don't need tsi->_ctx_chain.ctx to be
+ // initialized if currently NULL.
+ PyContext *ctx = (PyContext *)contextchain_head(tsi)->ctx;
+ if (ctx == NULL) {
goto not_found;
}
@@ -280,8 +456,8 @@ PyContextVar_Get(PyObject *ovar, PyObject *def, PyObject **val)
}
#endif
- assert(PyContext_CheckExact(ts->context));
- PyHamtObject *vars = ((PyContext *)ts->context)->ctx_vars;
+ assert(PyContext_CheckExact(ctx));
+ PyHamtObject *vars = ctx->ctx_vars;
PyObject *found = NULL;
int res = _PyHamt_Find(vars, (PyObject*)var, &found);
@@ -464,17 +640,25 @@ context_new_from_vars(PyHamtObject *vars)
static inline PyContext *
context_get(void)
{
- PyThreadState *ts = _PyThreadState_GET();
- assert(ts != NULL);
- PyContext *current_ctx = (PyContext *)ts->context;
- if (current_ctx == NULL) {
- current_ctx = context_new_empty();
- if (current_ctx == NULL) {
- return NULL;
+ _PyThreadStateImpl *tsi = (_PyThreadStateImpl *)_PyThreadState_GET();
+ assert(tsi != NULL);
+ _PyContextChain *active = contextchain_head(tsi);
+ if (active->ctx == NULL) {
+ assert(active == &tsi->_ctx_chain);
+ PyContext *ctx = context_new_empty();
+ if (ctx != NULL && _PyContext_Enter(&active->ctx, ctx)) {
+ Py_UNREACHABLE();
+ }
+ assert(active->ctx == (PyObject *)ctx);
+ if (ctx != NULL) {
+ context_switched(tsi);
}
- ts->context = (PyObject *)current_ctx;
+ Py_CLEAR(ctx); // _PyContext_Enter created its own ref.
}
- return current_ctx;
+ // The current context may be NULL if the above context_new_empty() call
+ // failed.
+ assert(active->ctx == NULL || PyContext_CheckExact(active->ctx));
+ return (PyContext *)active->ctx;
}
static int
@@ -597,6 +781,84 @@ context_tp_contains(PyContext *self, PyObject *key)
}
+/*[clinic input]
+_contextvars.Context.__enter__
+
+Context manager enter.
+
+Automatically called by the 'with' statement. Using the Context object as a
+context manager is an alternative to calling the Context.run() method.
+
+Example:
+
+ var = contextvars.ContextVar('var')
+ var.set('initial')
+
+ with contextvars.copy_context():
+ # Changes to context variables will be rolled back upon exiting the
+ # `with` statement.
+ var.set('updated')
+ print(var.get()) # 'updated'
+
+ # The context variable value has been rolled back.
+ print(var.get()) # 'initial'
+[clinic start generated code]*/
+
+static PyObject *
+_contextvars_Context___enter___impl(PyContext *self)
+/*[clinic end generated code: output=7374aea8983b777a input=fffe71e56ca17ee4]*/
+{
+ PyThreadState *ts = _PyThreadState_GET();
+ _PyInterpreterFrame *frame = ts->current_frame;
+ PyGenObject *gen = frame->owner == FRAME_OWNED_BY_GENERATOR
+ ? _PyGen_GetGeneratorFromFrame(frame) : NULL;
+ if (gen == NULL || gen->_ctx_chain.ctx != NULL) {
+ assert(gen == NULL || (contextchain_head((_PyThreadStateImpl *)ts)
+ == &gen->_ctx_chain));
+ if (PyContext_Enter((PyObject *)self)) {
+ return NULL;
+ }
+ } else if (_PyGen_ResetContext(ts, gen, (PyObject *)self)) {
+ return NULL;
+ }
+ // The new ref added here is for the `with` statement's `as` binding. It
+ // is decremented when the variable goes out of scope, which can be before
+ // or after `PyContext_Exit` is called. (The binding can go out of scope
+ // immediately -- before the `with` suite even runs -- if there is no `as`
+ // clause. Or it can go out of scope long after the `with` suite completes
+ // because `with` does not have its own scope.) Because of this timing,
+ // two references are needed: the one added in the `PyContext_Enter` or
+ // `_PyGen_ResetContext` call and the one returned here.
+ return Py_NewRef(self);
+}
+
+
+/*[clinic input]
+_contextvars.Context.__exit__
+ exc_type: object
+ exc_val: object
+ exc_tb: object
+ /
+
+Context manager exit.
+
+Automatically called at the conclusion of a 'with' statement when the Context
+is used as a context manager. See the Context.__enter__() method for more
+details.
+[clinic start generated code]*/
+
+static PyObject *
+_contextvars_Context___exit___impl(PyContext *self, PyObject *exc_type,
+ PyObject *exc_val, PyObject *exc_tb)
+/*[clinic end generated code: output=4608fa9151f968f1 input=aff87cd8f5c864b0]*/
+{
+ if (PyContext_Exit((PyObject *)self)) {
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+
+
/*[clinic input]
_contextvars.Context.get
key: object
@@ -689,25 +951,23 @@ _contextvars_Context_copy_impl(PyContext *self)
static PyObject *
-context_run(PyContext *self, PyObject *const *args,
+context_run(PyObject *self, PyObject *const *args,
Py_ssize_t nargs, PyObject *kwnames)
{
- PyThreadState *ts = _PyThreadState_GET();
-
if (nargs < 1) {
- _PyErr_SetString(ts, PyExc_TypeError,
- "run() missing 1 required positional argument");
+ PyErr_SetString(PyExc_TypeError,
+ "run() missing 1 required positional argument");
return NULL;
}
- if (_PyContext_Enter(ts, (PyObject *)self)) {
+ if (PyContext_Enter(self)) {
return NULL;
}
- PyObject *call_result = _PyObject_VectorcallTstate(
- ts, args[0], args + 1, nargs - 1, kwnames);
+ PyObject *call_result =
+ PyObject_Vectorcall(args[0], args + 1, nargs - 1, kwnames);
- if (_PyContext_Exit(ts, (PyObject *)self)) {
+ if (PyContext_Exit(self)) {
Py_XDECREF(call_result);
return NULL;
}
@@ -717,6 +977,8 @@ context_run(PyContext *self, PyObject *const *args,
static PyMethodDef PyContext_methods[] = {
+ _CONTEXTVARS_CONTEXT___ENTER___METHODDEF
+ _CONTEXTVARS_CONTEXT___EXIT___METHODDEF
_CONTEXTVARS_CONTEXT_GET_METHODDEF
_CONTEXTVARS_CONTEXT_ITEMS_METHODDEF
_CONTEXTVARS_CONTEXT_KEYS_METHODDEF
diff --git a/Python/pystate.c b/Python/pystate.c
index 45e79ade7b6035..8ef4a84c33da46 100644
--- a/Python/pystate.c
+++ b/Python/pystate.c
@@ -1737,7 +1737,7 @@ PyThreadState_Clear(PyThreadState *tstate)
Py_CLEAR(tstate->async_gen_firstiter);
Py_CLEAR(tstate->async_gen_finalizer);
- Py_CLEAR(tstate->context);
+ Py_CLEAR(((_PyThreadStateImpl *)tstate)->_ctx_chain.ctx);
#ifdef Py_GIL_DISABLED
// Each thread should clear own freelists in free-threading builds.
diff --git a/Tools/c-analyzer/cpython/ignored.tsv b/Tools/c-analyzer/cpython/ignored.tsv
index e6c599a2ac4a46..2605825d3d0078 100644
--- a/Tools/c-analyzer/cpython/ignored.tsv
+++ b/Tools/c-analyzer/cpython/ignored.tsv
@@ -455,8 +455,8 @@ Modules/_testcapi/watchers.c - pyfunc_watchers -
Modules/_testcapi/watchers.c - func_watcher_ids -
Modules/_testcapi/watchers.c - func_watcher_callbacks -
Modules/_testcapi/watchers.c - context_watcher_ids -
-Modules/_testcapi/watchers.c - num_context_object_enter_events -
-Modules/_testcapi/watchers.c - num_context_object_exit_events -
+Modules/_testcapi/watchers.c - context_switches -
+Modules/_testcapi/watchers.c add_context_watcher callbacks -
Modules/_testcapimodule.c - BasicStaticTypes -
Modules/_testcapimodule.c - num_basic_static_types_used -
Modules/_testcapimodule.c - ContainerNoGC_members -