Skip to content

gh-112075: Fix race in constructing dict for instance #118499

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Include/internal/pycore_dict.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,10 @@ PyAPI_FUNC(PyObject *)_PyDict_LoadGlobal(PyDictObject *, PyDictObject *, PyObjec

/* Consumes references to key and value */
PyAPI_FUNC(int) _PyDict_SetItem_Take2(PyDictObject *op, PyObject *key, PyObject *value);
extern int _PyObjectDict_SetItem(PyTypeObject *tp, PyObject **dictptr, PyObject *name, PyObject *value);
extern int _PyDict_SetItem_LockHeld(PyDictObject *dict, PyObject *name, PyObject *value);
extern int _PyDict_GetItemRef_Unicode_LockHeld(PyDictObject *op, PyObject *key, PyObject **result);
extern int _PyDict_GetItemRef_KnownHash(PyDictObject *op, PyObject *key, Py_hash_t hash, PyObject **result);
extern int _PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr, PyObject *name, PyObject *value);

extern int _PyDict_Pop_KnownHash(
PyDictObject *dict,
Expand Down
141 changes: 141 additions & 0 deletions Lib/test/test_free_threading/test_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import gc
import time
import unittest
import weakref

from ast import Or
from functools import partial
from threading import Thread
from unittest import TestCase

from test.support import threading_helper


@threading_helper.requires_working_threading()
class TestDict(TestCase):
def test_racing_creation_shared_keys(self):
"""Verify that creating dictionaries is thread safe when we
have a type with shared keys"""
class C(int):
pass

self.racing_creation(C)

def test_racing_creation_no_shared_keys(self):
"""Verify that creating dictionaries is thread safe when we
have a type with an ordinary dict"""
self.racing_creation(Or)

def test_racing_creation_inline_values_invalid(self):
"""Verify that re-creating a dict after we have invalid inline values
is thread safe"""
class C:
pass

def make_obj():
a = C()
# Make object, make inline values invalid, and then delete dict
a.__dict__ = {}
del a.__dict__
return a

self.racing_creation(make_obj)

def test_racing_creation_nonmanaged_dict(self):
"""Verify that explicit creation of an unmanaged dict is thread safe
outside of the normal attribute setting code path"""
def make_obj():
def f(): pass
return f

def set(func, name, val):
# Force creation of the dict via PyObject_GenericGetDict
func.__dict__[name] = val

self.racing_creation(make_obj, set)

def racing_creation(self, cls, set=setattr):
objects = []
processed = []

OBJECT_COUNT = 100
THREAD_COUNT = 10
CUR = 0

for i in range(OBJECT_COUNT):
objects.append(cls())

def writer_func(name):
last = -1
while True:
if CUR == last:
continue
elif CUR == OBJECT_COUNT:
break

obj = objects[CUR]
set(obj, name, name)
last = CUR
processed.append(name)

writers = []
for x in range(THREAD_COUNT):
writer = Thread(target=partial(writer_func, f"a{x:02}"))
writers.append(writer)
writer.start()

for i in range(OBJECT_COUNT):
CUR = i
while len(processed) != THREAD_COUNT:
time.sleep(0.001)
processed.clear()

CUR = OBJECT_COUNT

for writer in writers:
writer.join()

for obj_idx, obj in enumerate(objects):
assert (
len(obj.__dict__) == THREAD_COUNT
), f"{len(obj.__dict__)} {obj.__dict__!r} {obj_idx}"
for i in range(THREAD_COUNT):
assert f"a{i:02}" in obj.__dict__, f"a{i:02} missing at {obj_idx}"

def test_racing_set_dict(self):
"""Races assigning to __dict__ should be thread safe"""

def f(): pass
l = []
THREAD_COUNT = 10
class MyDict(dict): pass

def writer_func(l):
for i in range(1000):
d = MyDict()
l.append(weakref.ref(d))
f.__dict__ = d

lists = []
writers = []
for x in range(THREAD_COUNT):
thread_list = []
lists.append(thread_list)
writer = Thread(target=partial(writer_func, thread_list))
writers.append(writer)

for writer in writers:
writer.start()

for writer in writers:
writer.join()

f.__dict__ = {}
gc.collect()

for thread_list in lists:
for ref in thread_list:
self.assertIsNone(ref())

if __name__ == "__main__":
unittest.main()
140 changes: 71 additions & 69 deletions Objects/dictobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -924,16 +924,15 @@ new_dict(PyInterpreterState *interp,
return (PyObject *)mp;
}

/* Consumes a reference to the keys object */
static PyObject *
new_dict_with_shared_keys(PyInterpreterState *interp, PyDictKeysObject *keys)
{
size_t size = shared_keys_usable_size(keys);
PyDictValues *values = new_values(size);
if (values == NULL) {
dictkeys_decref(interp, keys, false);
return PyErr_NoMemory();
}
dictkeys_incref(keys);
for (size_t i = 0; i < size; i++) {
values->values[i] = NULL;
}
Expand Down Expand Up @@ -6693,8 +6692,6 @@ materialize_managed_dict_lock_held(PyObject *obj)
{
_Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(obj);

OBJECT_STAT_INC(dict_materialized_on_request);

PyDictValues *values = _PyObject_InlineValues(obj);
PyInterpreterState *interp = _PyInterpreterState_GET();
PyDictKeysObject *keys = CACHED_KEYS(Py_TYPE(obj));
Expand Down Expand Up @@ -7186,35 +7183,77 @@ _PyDict_DetachFromObject(PyDictObject *mp, PyObject *obj)
return 0;
}

PyObject *
PyObject_GenericGetDict(PyObject *obj, void *context)
static inline PyObject *
ensure_managed_dict(PyObject *obj)
{
PyInterpreterState *interp = _PyInterpreterState_GET();
PyTypeObject *tp = Py_TYPE(obj);
PyDictObject *dict;
if (_PyType_HasFeature(tp, Py_TPFLAGS_MANAGED_DICT)) {
dict = _PyObject_GetManagedDict(obj);
if (dict == NULL &&
(tp->tp_flags & Py_TPFLAGS_INLINE_VALUES) &&
PyDictObject *dict = _PyObject_GetManagedDict(obj);
if (dict == NULL) {
PyTypeObject *tp = Py_TYPE(obj);
if ((tp->tp_flags & Py_TPFLAGS_INLINE_VALUES) &&
FT_ATOMIC_LOAD_UINT8(_PyObject_InlineValues(obj)->valid)) {
dict = _PyObject_MaterializeManagedDict(obj);
}
else if (dict == NULL) {
Py_BEGIN_CRITICAL_SECTION(obj);

else {
#ifdef Py_GIL_DISABLED
// Check again that we're not racing with someone else creating the dict
Py_BEGIN_CRITICAL_SECTION(obj);
dict = _PyObject_GetManagedDict(obj);
if (dict == NULL) {
OBJECT_STAT_INC(dict_materialized_on_request);
dictkeys_incref(CACHED_KEYS(tp));
dict = (PyDictObject *)new_dict_with_shared_keys(interp, CACHED_KEYS(tp));
FT_ATOMIC_STORE_PTR_RELEASE(_PyObject_ManagedDictPointer(obj)->dict,
(PyDictObject *)dict);
if (dict != NULL) {
goto done;
}
#endif
dict = (PyDictObject *)new_dict_with_shared_keys(_PyInterpreterState_GET(),
CACHED_KEYS(tp));
FT_ATOMIC_STORE_PTR_RELEASE(_PyObject_ManagedDictPointer(obj)->dict,
(PyDictObject *)dict);

#ifdef Py_GIL_DISABLED
done:
Py_END_CRITICAL_SECTION();
#endif
}
return Py_XNewRef((PyObject *)dict);
}
return (PyObject *)dict;
}

static inline PyObject *
ensure_nonmanaged_dict(PyObject *obj, PyObject **dictptr)
{
PyDictKeysObject *cached;

PyObject *dict = FT_ATOMIC_LOAD_PTR_ACQUIRE(*dictptr);
if (dict == NULL) {
#ifdef Py_GIL_DISABLED
Py_BEGIN_CRITICAL_SECTION(obj);
dict = *dictptr;
if (dict != NULL) {
goto done;
}
#endif
PyTypeObject *tp = Py_TYPE(obj);
if (_PyType_HasFeature(tp, Py_TPFLAGS_HEAPTYPE) && (cached = CACHED_KEYS(tp))) {
PyInterpreterState *interp = _PyInterpreterState_GET();
assert(!_PyType_HasFeature(tp, Py_TPFLAGS_INLINE_VALUES));
dict = new_dict_with_shared_keys(interp, cached);
}
else {
dict = PyDict_New();
}
FT_ATOMIC_STORE_PTR_RELEASE(*dictptr, dict);
#ifdef Py_GIL_DISABLED
done:
Py_END_CRITICAL_SECTION();
#endif
}
return dict;
}

PyObject *
PyObject_GenericGetDict(PyObject *obj, void *context)
{
PyTypeObject *tp = Py_TYPE(obj);
if (_PyType_HasFeature(tp, Py_TPFLAGS_MANAGED_DICT)) {
return Py_XNewRef(ensure_managed_dict(obj));
}
else {
PyObject **dictptr = _PyObject_ComputedDictPointer(obj);
Expand All @@ -7223,65 +7262,28 @@ PyObject_GenericGetDict(PyObject *obj, void *context)
"This object has no __dict__");
return NULL;
}
PyObject *dict = *dictptr;
if (dict == NULL) {
PyTypeObject *tp = Py_TYPE(obj);
if (_PyType_HasFeature(tp, Py_TPFLAGS_HEAPTYPE) && CACHED_KEYS(tp)) {
dictkeys_incref(CACHED_KEYS(tp));
*dictptr = dict = new_dict_with_shared_keys(
interp, CACHED_KEYS(tp));
}
else {
*dictptr = dict = PyDict_New();
}
}
return Py_XNewRef(dict);

return Py_XNewRef(ensure_nonmanaged_dict(obj, dictptr));
}
}

int
_PyObjectDict_SetItem(PyTypeObject *tp, PyObject **dictptr,
_PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr,
PyObject *key, PyObject *value)
{
PyObject *dict;
int res;
PyDictKeysObject *cached;
PyInterpreterState *interp = _PyInterpreterState_GET();

assert(dictptr != NULL);
if ((tp->tp_flags & Py_TPFLAGS_HEAPTYPE) && (cached = CACHED_KEYS(tp))) {
assert(dictptr != NULL);
dict = *dictptr;
if (dict == NULL) {
assert(!_PyType_HasFeature(tp, Py_TPFLAGS_INLINE_VALUES));
dictkeys_incref(cached);
dict = new_dict_with_shared_keys(interp, cached);
if (dict == NULL)
return -1;
*dictptr = dict;
}
if (value == NULL) {
res = PyDict_DelItem(dict, key);
}
else {
res = PyDict_SetItem(dict, key, value);
}
} else {
dict = *dictptr;
if (dict == NULL) {
dict = PyDict_New();
if (dict == NULL)
return -1;
*dictptr = dict;
}
if (value == NULL) {
res = PyDict_DelItem(dict, key);
} else {
res = PyDict_SetItem(dict, key, value);
}
dict = ensure_nonmanaged_dict(obj, dictptr);
if (dict == NULL) {
return -1;
}

Py_BEGIN_CRITICAL_SECTION(dict);
res = _PyDict_SetItem_LockHeld((PyDictObject *)dict, key, value);
ASSERT_CONSISTENT(dict);
Py_END_CRITICAL_SECTION();
return res;
}

Expand Down
4 changes: 3 additions & 1 deletion Objects/object.c
Original file line number Diff line number Diff line change
Expand Up @@ -1731,7 +1731,7 @@ _PyObject_GenericSetAttrWithDict(PyObject *obj, PyObject *name,
goto done;
}
else {
res = _PyObjectDict_SetItem(tp, dictptr, name, value);
res = _PyObjectDict_SetItem(tp, obj, dictptr, name, value);
}
}
else {
Expand Down Expand Up @@ -1789,7 +1789,9 @@ PyObject_GenericSetDict(PyObject *obj, PyObject *value, void *context)
"not a '%.200s'", Py_TYPE(value)->tp_name);
return -1;
}
Py_BEGIN_CRITICAL_SECTION(obj);
Py_XSETREF(*dictptr, Py_NewRef(value));
Py_END_CRITICAL_SECTION();
return 0;
}

Expand Down
Loading