From 6a87edc44295db7bb5104124d7bd20ad6037e0c7 Mon Sep 17 00:00:00 2001 From: Sam Gross Date: Fri, 2 Aug 2024 09:32:08 -0400 Subject: [PATCH] gh-120974: Make asyncio `swap_current_task` safe in free-threaded build (GH-122317) * gh-120974: Make asyncio `swap_current_task` safe in free-threaded build (cherry picked from commit b5e6fb39a246bf7ee470d58632cdf588bb9d0298) Co-authored-by: Sam Gross --- Include/internal/pycore_dict.h | 7 ++++- Modules/_asynciomodule.c | 37 ++++++++++++++--------- Objects/dictobject.c | 54 ++++++++++++++++++++++++---------- 3 files changed, 67 insertions(+), 31 deletions(-) diff --git a/Include/internal/pycore_dict.h b/Include/internal/pycore_dict.h index 1c4a63c2271a93..36da498db2c3e1 100644 --- a/Include/internal/pycore_dict.h +++ b/Include/internal/pycore_dict.h @@ -110,8 +110,13 @@ PyAPI_FUNC(PyObject *)_PyDict_LoadGlobal(PyDictObject *, PyDictObject *, PyObjec /* Consumes references to key and value */ PyAPI_FUNC(int) _PyDict_SetItem_Take2(PyDictObject *op, PyObject *key, PyObject *value); extern int _PyDict_SetItem_LockHeld(PyDictObject *dict, PyObject *name, PyObject *value); -extern int _PyDict_GetItemRef_Unicode_LockHeld(PyDictObject *op, PyObject *key, PyObject **result); +// Export for '_asyncio' shared extension +PyAPI_FUNC(int) _PyDict_SetItem_KnownHash_LockHeld(PyDictObject *mp, PyObject *key, + PyObject *value, Py_hash_t hash); +// Export for '_asyncio' shared extension +PyAPI_FUNC(int) _PyDict_GetItemRef_KnownHash_LockHeld(PyDictObject *op, PyObject *key, Py_hash_t hash, PyObject **result); extern int _PyDict_GetItemRef_KnownHash(PyDictObject *op, PyObject *key, Py_hash_t hash, PyObject **result); +extern int _PyDict_GetItemRef_Unicode_LockHeld(PyDictObject *op, PyObject *key, PyObject **result); extern int _PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr, PyObject *name, PyObject *value); extern int _PyDict_Pop_KnownHash( diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c index 28641b85451763..6e87de5e954826 100644 --- a/Modules/_asynciomodule.c +++ b/Modules/_asynciomodule.c @@ -1977,6 +1977,24 @@ leave_task(asyncio_state *state, PyObject *loop, PyObject *task) return res; } +static PyObject * +swap_current_task_lock_held(PyDictObject *current_tasks, PyObject *loop, + Py_hash_t hash, PyObject *task) +{ + PyObject *prev_task; + if (_PyDict_GetItemRef_KnownHash_LockHeld(current_tasks, loop, hash, &prev_task) < 0) { + return NULL; + } + if (_PyDict_SetItem_KnownHash_LockHeld(current_tasks, loop, task, hash) < 0) { + Py_XDECREF(prev_task); + return NULL; + } + if (prev_task == NULL) { + Py_RETURN_NONE; + } + return prev_task; +} + static PyObject * swap_current_task(asyncio_state *state, PyObject *loop, PyObject *task) { @@ -1992,24 +2010,15 @@ swap_current_task(asyncio_state *state, PyObject *loop, PyObject *task) return prev_task; } - Py_hash_t hash; - hash = PyObject_Hash(loop); + Py_hash_t hash = PyObject_Hash(loop); if (hash == -1) { return NULL; } - prev_task = _PyDict_GetItem_KnownHash(state->current_tasks, loop, hash); - if (prev_task == NULL) { - if (PyErr_Occurred()) { - return NULL; - } - prev_task = Py_None; - } - Py_INCREF(prev_task); - if (_PyDict_SetItem_KnownHash(state->current_tasks, loop, task, hash) == -1) { - Py_DECREF(prev_task); - return NULL; - } + PyDictObject *current_tasks = (PyDictObject *)state->current_tasks; + Py_BEGIN_CRITICAL_SECTION(current_tasks); + prev_task = swap_current_task_lock_held(current_tasks, loop, hash, task); + Py_END_CRITICAL_SECTION(); return prev_task; } diff --git a/Objects/dictobject.c b/Objects/dictobject.c index dedcd232483575..05600772db7026 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -2275,6 +2275,29 @@ _PyDict_GetItem_KnownHash(PyObject *op, PyObject *key, Py_hash_t hash) return value; // borrowed reference } +/* Gets an item and provides a new reference if the value is present. + * Returns 1 if the key is present, 0 if the key is missing, and -1 if an + * exception occurred. +*/ +int +_PyDict_GetItemRef_KnownHash_LockHeld(PyDictObject *op, PyObject *key, + Py_hash_t hash, PyObject **result) +{ + PyObject *value; + Py_ssize_t ix = _Py_dict_lookup(op, key, hash, &value); + assert(ix >= 0 || value == NULL); + if (ix == DKIX_ERROR) { + *result = NULL; + return -1; + } + if (value == NULL) { + *result = NULL; + return 0; // missing key + } + *result = Py_NewRef(value); + return 1; // key is present +} + /* Gets an item and provides a new reference if the value is present. * Returns 1 if the key is present, 0 if the key is missing, and -1 if an * exception occurred. @@ -2519,11 +2542,21 @@ setitem_lock_held(PyDictObject *mp, PyObject *key, PyObject *value) int -_PyDict_SetItem_KnownHash(PyObject *op, PyObject *key, PyObject *value, - Py_hash_t hash) +_PyDict_SetItem_KnownHash_LockHeld(PyDictObject *mp, PyObject *key, PyObject *value, + Py_hash_t hash) { - PyDictObject *mp; + PyInterpreterState *interp = _PyInterpreterState_GET(); + if (mp->ma_keys == Py_EMPTY_KEYS) { + return insert_to_emptydict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value)); + } + /* insertdict() handles any resizing that might be necessary */ + return insertdict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value)); +} +int +_PyDict_SetItem_KnownHash(PyObject *op, PyObject *key, PyObject *value, + Py_hash_t hash) +{ if (!PyDict_Check(op)) { PyErr_BadInternalCall(); return -1; @@ -2531,21 +2564,10 @@ _PyDict_SetItem_KnownHash(PyObject *op, PyObject *key, PyObject *value, assert(key); assert(value); assert(hash != -1); - mp = (PyDictObject *)op; int res; - PyInterpreterState *interp = _PyInterpreterState_GET(); - - Py_BEGIN_CRITICAL_SECTION(mp); - - if (mp->ma_keys == Py_EMPTY_KEYS) { - res = insert_to_emptydict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value)); - } - else { - /* insertdict() handles any resizing that might be necessary */ - res = insertdict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value)); - } - + Py_BEGIN_CRITICAL_SECTION(op); + res = _PyDict_SetItem_KnownHash_LockHeld((PyDictObject *)op, key, value, hash); Py_END_CRITICAL_SECTION(); return res; }