From 8cc649813748db66ea26c8604a6dd1699b03539f Mon Sep 17 00:00:00 2001
From: Sam Gross <colesbury@gmail.com>
Date: Mon, 22 Jul 2024 20:00:38 +0000
Subject: [PATCH 1/3] gh-120974: Make _asyncio._leave_task atomic in the
 free-threaded build

Update `_PyDict_DelItemIf` to allow for an argument to be passed to the
predicate.
---
 Include/internal/pycore_dict.h |  6 ++++--
 Modules/_asynciomodule.c       | 21 ++++++++++-----------
 Modules/_weakref.c             | 13 +++----------
 Objects/dictobject.c           | 22 ++++++++++------------
 4 files changed, 27 insertions(+), 35 deletions(-)

diff --git a/Include/internal/pycore_dict.h b/Include/internal/pycore_dict.h
index a4bdf0d7ad8283..3caeb417a290f9 100644
--- a/Include/internal/pycore_dict.h
+++ b/Include/internal/pycore_dict.h
@@ -14,8 +14,10 @@ extern "C" {
 // Unsafe flavor of PyDict_GetItemWithError(): no error checking
 extern PyObject* _PyDict_GetItemWithError(PyObject *dp, PyObject *key);
 
-extern int _PyDict_DelItemIf(PyObject *mp, PyObject *key,
-                             int (*predicate)(PyObject *value));
+// Export for '_asyncio' shared extension
+PyAPI_FUNC(int) _PyDict_DelItemIf(PyObject *mp, PyObject *key,
+                                  int (*predicate)(PyObject *value, void *arg),
+                                  void *arg);
 
 // "KnownHash" variants
 // Export for '_asyncio' shared extension
diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c
index 372f19794be0dd..c3ede487a5b88b 100644
--- a/Modules/_asynciomodule.c
+++ b/Modules/_asynciomodule.c
@@ -2031,18 +2031,9 @@ enter_task(asyncio_state *state, PyObject *loop, PyObject *task)
     return _PyDict_SetItem_KnownHash(state->current_tasks, loop, task, hash);
 }
 
-
 static int
-leave_task(asyncio_state *state, PyObject *loop, PyObject *task)
-/*[clinic end generated code: output=0ebf6db4b858fb41 input=51296a46313d1ad8]*/
+leave_task_predicate(PyObject *item, void *task)
 {
-    PyObject *item;
-    Py_hash_t hash;
-    hash = PyObject_Hash(loop);
-    if (hash == -1) {
-        return -1;
-    }
-    item = _PyDict_GetItem_KnownHash(state->current_tasks, loop, hash);
     if (item != task) {
         if (item == NULL) {
             /* Not entered, replace with None */
@@ -2054,7 +2045,15 @@ leave_task(asyncio_state *state, PyObject *loop, PyObject *task)
             task, item, NULL);
         return -1;
     }
-    return _PyDict_DelItem_KnownHash(state->current_tasks, loop, hash);
+    return 1;
+}
+
+static int
+leave_task(asyncio_state *state, PyObject *loop, PyObject *task)
+/*[clinic end generated code: output=0ebf6db4b858fb41 input=51296a46313d1ad8]*/
+{
+    return _PyDict_DelItemIf(state->current_tasks, loop, leave_task_predicate,
+                             task);
 }
 
 static PyObject *
diff --git a/Modules/_weakref.c b/Modules/_weakref.c
index a5c15c0f10b930..ecaa08ff60f203 100644
--- a/Modules/_weakref.c
+++ b/Modules/_weakref.c
@@ -31,7 +31,7 @@ _weakref_getweakrefcount_impl(PyObject *module, PyObject *object)
 
 
 static int
-is_dead_weakref(PyObject *value)
+is_dead_weakref(PyObject *value, void *unused)
 {
     if (!PyWeakref_Check(value)) {
         PyErr_SetString(PyExc_TypeError, "not a weakref");
@@ -56,15 +56,8 @@ _weakref__remove_dead_weakref_impl(PyObject *module, PyObject *dct,
                                    PyObject *key)
 /*[clinic end generated code: output=d9ff53061fcb875c input=19fc91f257f96a1d]*/
 {
-    if (_PyDict_DelItemIf(dct, key, is_dead_weakref) < 0) {
-        if (PyErr_ExceptionMatches(PyExc_KeyError))
-            /* This function is meant to allow safe weak-value dicts
-               with GC in another thread (see issue #28427), so it's
-               ok if the key doesn't exist anymore.
-               */
-            PyErr_Clear();
-        else
-            return NULL;
+    if (_PyDict_DelItemIf(dct, key, is_dead_weakref, NULL) < 0) {
+        return NULL;
     }
     Py_RETURN_NONE;
 }
diff --git a/Objects/dictobject.c b/Objects/dictobject.c
index 7310c3c8e13b5b..f054a145f34d84 100644
--- a/Objects/dictobject.c
+++ b/Objects/dictobject.c
@@ -2608,7 +2608,8 @@ _PyDict_DelItem_KnownHash(PyObject *op, PyObject *key, Py_hash_t hash)
 
 static int
 delitemif_lock_held(PyObject *op, PyObject *key,
-                    int (*predicate)(PyObject *value))
+                    int (*predicate)(PyObject *value, void *arg),
+                    void *arg)
 {
     Py_ssize_t ix;
     PyDictObject *mp;
@@ -2618,10 +2619,6 @@ delitemif_lock_held(PyObject *op, PyObject *key,
 
     ASSERT_DICT_LOCKED(op);
 
-    if (!PyDict_Check(op)) {
-        PyErr_BadInternalCall();
-        return -1;
-    }
     assert(key);
     hash = PyObject_Hash(key);
     if (hash == -1)
@@ -2630,12 +2627,8 @@ delitemif_lock_held(PyObject *op, PyObject *key,
     ix = _Py_dict_lookup(mp, key, hash, &old_value);
     if (ix == DKIX_ERROR)
         return -1;
-    if (ix == DKIX_EMPTY || old_value == NULL) {
-        _PyErr_SetKeyError(key);
-        return -1;
-    }
 
-    res = predicate(old_value);
+    res = predicate(old_value, arg);
     if (res == -1)
         return -1;
 
@@ -2655,11 +2648,16 @@ delitemif_lock_held(PyObject *op, PyObject *key,
  */
 int
 _PyDict_DelItemIf(PyObject *op, PyObject *key,
-                  int (*predicate)(PyObject *value))
+                  int (*predicate)(PyObject *value, void *arg),
+                  void *arg)
 {
+    if (!PyDict_Check(op)) {
+        PyErr_BadInternalCall();
+        return -1;
+    }
     int res;
     Py_BEGIN_CRITICAL_SECTION(op);
-    res = delitemif_lock_held(op, key, predicate);
+    res = delitemif_lock_held(op, key, predicate, arg);
     Py_END_CRITICAL_SECTION();
     return res;
 }

From 01f8052dd079755b9cb79e8b5db9c7186d1a35a4 Mon Sep 17 00:00:00 2001
From: Sam Gross <colesbury@gmail.com>
Date: Mon, 22 Jul 2024 20:24:06 +0000
Subject: [PATCH 2/3] Fix handling of case where key is missing

---
 Include/internal/pycore_dict.h |  2 ++
 Modules/_asynciomodule.c       | 29 ++++++++++++++++++-----------
 Objects/dictobject.c           | 15 ++++++++++-----
 3 files changed, 30 insertions(+), 16 deletions(-)

diff --git a/Include/internal/pycore_dict.h b/Include/internal/pycore_dict.h
index 3caeb417a290f9..fc304aca7fea10 100644
--- a/Include/internal/pycore_dict.h
+++ b/Include/internal/pycore_dict.h
@@ -14,6 +14,8 @@ extern "C" {
 // Unsafe flavor of PyDict_GetItemWithError(): no error checking
 extern PyObject* _PyDict_GetItemWithError(PyObject *dp, PyObject *key);
 
+// Delete an item from a dict if a predicate is true
+// Returns -1 on error, 1 if the item was deleted, 0 otherwise
 // Export for '_asyncio' shared extension
 PyAPI_FUNC(int) _PyDict_DelItemIf(PyObject *mp, PyObject *key,
                                   int (*predicate)(PyObject *value, void *arg),
diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c
index c3ede487a5b88b..6fa3d1c0fe92d4 100644
--- a/Modules/_asynciomodule.c
+++ b/Modules/_asynciomodule.c
@@ -2031,19 +2031,21 @@ enter_task(asyncio_state *state, PyObject *loop, PyObject *task)
     return _PyDict_SetItem_KnownHash(state->current_tasks, loop, task, hash);
 }
 
+static int
+err_leave_task(PyObject *item, PyObject *task)
+{
+    PyErr_Format(
+        PyExc_RuntimeError,
+        "Leaving task %R does not match the current task %R.",
+        task, item);
+    return -1;
+}
+
 static int
 leave_task_predicate(PyObject *item, void *task)
 {
     if (item != task) {
-        if (item == NULL) {
-            /* Not entered, replace with None */
-            item = Py_None;
-        }
-        PyErr_Format(
-            PyExc_RuntimeError,
-            "Leaving task %R does not match the current task %R.",
-            task, item, NULL);
-        return -1;
+        return err_leave_task(item, (PyObject *)task);
     }
     return 1;
 }
@@ -2052,8 +2054,13 @@ static int
 leave_task(asyncio_state *state, PyObject *loop, PyObject *task)
 /*[clinic end generated code: output=0ebf6db4b858fb41 input=51296a46313d1ad8]*/
 {
-    return _PyDict_DelItemIf(state->current_tasks, loop, leave_task_predicate,
-                             task);
+    int res = _PyDict_DelItemIf(state->current_tasks, loop,
+                                leave_task_predicate, task);
+    if (res == 0) {
+        // task was not found
+        return err_leave_task(Py_None, task);
+    }
+    return res;
 }
 
 static PyObject *
diff --git a/Objects/dictobject.c b/Objects/dictobject.c
index f054a145f34d84..4bb2824c4dcf68 100644
--- a/Objects/dictobject.c
+++ b/Objects/dictobject.c
@@ -2508,7 +2508,7 @@ delete_index_from_values(PyDictValues *values, Py_ssize_t ix)
     values->size = size;
 }
 
-static int
+static void
 delitem_common(PyDictObject *mp, Py_hash_t hash, Py_ssize_t ix,
                PyObject *old_value, uint64_t new_version)
 {
@@ -2550,7 +2550,6 @@ delitem_common(PyDictObject *mp, Py_hash_t hash, Py_ssize_t ix,
     Py_DECREF(old_value);
 
     ASSERT_CONSISTENT(mp);
-    return 0;
 }
 
 int
@@ -2593,7 +2592,8 @@ delitem_knownhash_lock_held(PyObject *op, PyObject *key, Py_hash_t hash)
     PyInterpreterState *interp = _PyInterpreterState_GET();
     uint64_t new_version = _PyDict_NotifyEvent(
             interp, PyDict_EVENT_DELETED, mp, key, NULL);
-    return delitem_common(mp, hash, ix, old_value, new_version);
+    delitem_common(mp, hash, ix, old_value, new_version);
+    return 0;
 }
 
 int
@@ -2625,8 +2625,12 @@ delitemif_lock_held(PyObject *op, PyObject *key,
         return -1;
     mp = (PyDictObject *)op;
     ix = _Py_dict_lookup(mp, key, hash, &old_value);
-    if (ix == DKIX_ERROR)
+    if (ix == DKIX_ERROR) {
         return -1;
+    }
+    if (ix == DKIX_EMPTY || old_value == NULL) {
+        return 0;
+    }
 
     res = predicate(old_value, arg);
     if (res == -1)
@@ -2636,7 +2640,8 @@ delitemif_lock_held(PyObject *op, PyObject *key,
         PyInterpreterState *interp = _PyInterpreterState_GET();
         uint64_t new_version = _PyDict_NotifyEvent(
                 interp, PyDict_EVENT_DELETED, mp, key, NULL);
-        return delitem_common(mp, hash, ix, old_value, new_version);
+        delitem_common(mp, hash, ix, old_value, new_version);
+        return 1;
     } else {
         return 0;
     }

From 5c3a8cf1093024483fec8e80069b64889a1c7453 Mon Sep 17 00:00:00 2001
From: Sam Gross <colesbury@gmail.com>
Date: Tue, 23 Jul 2024 16:39:28 +0000
Subject: [PATCH 3/3] Changes from review

---
 Objects/dictobject.c | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/Objects/dictobject.c b/Objects/dictobject.c
index 4bb2824c4dcf68..ee88576cc77dec 100644
--- a/Objects/dictobject.c
+++ b/Objects/dictobject.c
@@ -2656,10 +2656,7 @@ _PyDict_DelItemIf(PyObject *op, PyObject *key,
                   int (*predicate)(PyObject *value, void *arg),
                   void *arg)
 {
-    if (!PyDict_Check(op)) {
-        PyErr_BadInternalCall();
-        return -1;
-    }
+    assert(PyDict_Check(op));
     int res;
     Py_BEGIN_CRITICAL_SECTION(op);
     res = delitemif_lock_held(op, key, predicate, arg);