gh-112075: Fix race in constructing dict for instance (#118499)

This commit is contained in:
Dino Viehland 2024-05-06 16:31:09 -07:00 committed by GitHub
parent 430945db4c
commit 636b8d94c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 216 additions and 71 deletions

View File

@ -105,10 +105,10 @@ PyAPI_FUNC(PyObject *)_PyDict_LoadGlobal(PyDictObject *, PyDictObject *, PyObjec
/* Consumes references to key and value */ /* Consumes references to key and value */
PyAPI_FUNC(int) _PyDict_SetItem_Take2(PyDictObject *op, PyObject *key, PyObject *value); PyAPI_FUNC(int) _PyDict_SetItem_Take2(PyDictObject *op, PyObject *key, PyObject *value);
extern int _PyObjectDict_SetItem(PyTypeObject *tp, PyObject **dictptr, PyObject *name, PyObject *value);
extern int _PyDict_SetItem_LockHeld(PyDictObject *dict, PyObject *name, PyObject *value); extern int _PyDict_SetItem_LockHeld(PyDictObject *dict, PyObject *name, PyObject *value);
extern int _PyDict_GetItemRef_Unicode_LockHeld(PyDictObject *op, PyObject *key, PyObject **result); extern int _PyDict_GetItemRef_Unicode_LockHeld(PyDictObject *op, PyObject *key, PyObject **result);
extern int _PyDict_GetItemRef_KnownHash(PyDictObject *op, PyObject *key, Py_hash_t hash, PyObject **result); extern int _PyDict_GetItemRef_KnownHash(PyDictObject *op, PyObject *key, Py_hash_t hash, PyObject **result);
extern int _PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr, PyObject *name, PyObject *value);
extern int _PyDict_Pop_KnownHash( extern int _PyDict_Pop_KnownHash(
PyDictObject *dict, PyDictObject *dict,

View File

@ -0,0 +1,141 @@
import gc
import time
import unittest
import weakref
from ast import Or
from functools import partial
from threading import Thread
from unittest import TestCase
from test.support import threading_helper
@threading_helper.requires_working_threading()
class TestDict(TestCase):
def test_racing_creation_shared_keys(self):
"""Verify that creating dictionaries is thread safe when we
have a type with shared keys"""
class C(int):
pass
self.racing_creation(C)
def test_racing_creation_no_shared_keys(self):
"""Verify that creating dictionaries is thread safe when we
have a type with an ordinary dict"""
self.racing_creation(Or)
def test_racing_creation_inline_values_invalid(self):
"""Verify that re-creating a dict after we have invalid inline values
is thread safe"""
class C:
pass
def make_obj():
a = C()
# Make object, make inline values invalid, and then delete dict
a.__dict__ = {}
del a.__dict__
return a
self.racing_creation(make_obj)
def test_racing_creation_nonmanaged_dict(self):
"""Verify that explicit creation of an unmanaged dict is thread safe
outside of the normal attribute setting code path"""
def make_obj():
def f(): pass
return f
def set(func, name, val):
# Force creation of the dict via PyObject_GenericGetDict
func.__dict__[name] = val
self.racing_creation(make_obj, set)
def racing_creation(self, cls, set=setattr):
objects = []
processed = []
OBJECT_COUNT = 100
THREAD_COUNT = 10
CUR = 0
for i in range(OBJECT_COUNT):
objects.append(cls())
def writer_func(name):
last = -1
while True:
if CUR == last:
continue
elif CUR == OBJECT_COUNT:
break
obj = objects[CUR]
set(obj, name, name)
last = CUR
processed.append(name)
writers = []
for x in range(THREAD_COUNT):
writer = Thread(target=partial(writer_func, f"a{x:02}"))
writers.append(writer)
writer.start()
for i in range(OBJECT_COUNT):
CUR = i
while len(processed) != THREAD_COUNT:
time.sleep(0.001)
processed.clear()
CUR = OBJECT_COUNT
for writer in writers:
writer.join()
for obj_idx, obj in enumerate(objects):
assert (
len(obj.__dict__) == THREAD_COUNT
), f"{len(obj.__dict__)} {obj.__dict__!r} {obj_idx}"
for i in range(THREAD_COUNT):
assert f"a{i:02}" in obj.__dict__, f"a{i:02} missing at {obj_idx}"
def test_racing_set_dict(self):
"""Races assigning to __dict__ should be thread safe"""
def f(): pass
l = []
THREAD_COUNT = 10
class MyDict(dict): pass
def writer_func(l):
for i in range(1000):
d = MyDict()
l.append(weakref.ref(d))
f.__dict__ = d
lists = []
writers = []
for x in range(THREAD_COUNT):
thread_list = []
lists.append(thread_list)
writer = Thread(target=partial(writer_func, thread_list))
writers.append(writer)
for writer in writers:
writer.start()
for writer in writers:
writer.join()
f.__dict__ = {}
gc.collect()
for thread_list in lists:
for ref in thread_list:
self.assertIsNone(ref())
if __name__ == "__main__":
unittest.main()

View File

@ -924,16 +924,15 @@ new_dict(PyInterpreterState *interp,
return (PyObject *)mp; return (PyObject *)mp;
} }
/* Consumes a reference to the keys object */
static PyObject * static PyObject *
new_dict_with_shared_keys(PyInterpreterState *interp, PyDictKeysObject *keys) new_dict_with_shared_keys(PyInterpreterState *interp, PyDictKeysObject *keys)
{ {
size_t size = shared_keys_usable_size(keys); size_t size = shared_keys_usable_size(keys);
PyDictValues *values = new_values(size); PyDictValues *values = new_values(size);
if (values == NULL) { if (values == NULL) {
dictkeys_decref(interp, keys, false);
return PyErr_NoMemory(); return PyErr_NoMemory();
} }
dictkeys_incref(keys);
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
values->values[i] = NULL; values->values[i] = NULL;
} }
@ -6693,8 +6692,6 @@ materialize_managed_dict_lock_held(PyObject *obj)
{ {
_Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(obj); _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(obj);
OBJECT_STAT_INC(dict_materialized_on_request);
PyDictValues *values = _PyObject_InlineValues(obj); PyDictValues *values = _PyObject_InlineValues(obj);
PyInterpreterState *interp = _PyInterpreterState_GET(); PyInterpreterState *interp = _PyInterpreterState_GET();
PyDictKeysObject *keys = CACHED_KEYS(Py_TYPE(obj)); PyDictKeysObject *keys = CACHED_KEYS(Py_TYPE(obj));
@ -7186,35 +7183,77 @@ _PyDict_DetachFromObject(PyDictObject *mp, PyObject *obj)
return 0; return 0;
} }
PyObject * static inline PyObject *
PyObject_GenericGetDict(PyObject *obj, void *context) ensure_managed_dict(PyObject *obj)
{ {
PyInterpreterState *interp = _PyInterpreterState_GET(); PyDictObject *dict = _PyObject_GetManagedDict(obj);
if (dict == NULL) {
PyTypeObject *tp = Py_TYPE(obj); PyTypeObject *tp = Py_TYPE(obj);
PyDictObject *dict; if ((tp->tp_flags & Py_TPFLAGS_INLINE_VALUES) &&
if (_PyType_HasFeature(tp, Py_TPFLAGS_MANAGED_DICT)) {
dict = _PyObject_GetManagedDict(obj);
if (dict == NULL &&
(tp->tp_flags & Py_TPFLAGS_INLINE_VALUES) &&
FT_ATOMIC_LOAD_UINT8(_PyObject_InlineValues(obj)->valid)) { FT_ATOMIC_LOAD_UINT8(_PyObject_InlineValues(obj)->valid)) {
dict = _PyObject_MaterializeManagedDict(obj); dict = _PyObject_MaterializeManagedDict(obj);
} }
else if (dict == NULL) { else {
Py_BEGIN_CRITICAL_SECTION(obj); #ifdef Py_GIL_DISABLED
// Check again that we're not racing with someone else creating the dict // Check again that we're not racing with someone else creating the dict
Py_BEGIN_CRITICAL_SECTION(obj);
dict = _PyObject_GetManagedDict(obj); dict = _PyObject_GetManagedDict(obj);
if (dict == NULL) { if (dict != NULL) {
OBJECT_STAT_INC(dict_materialized_on_request); goto done;
dictkeys_incref(CACHED_KEYS(tp)); }
dict = (PyDictObject *)new_dict_with_shared_keys(interp, CACHED_KEYS(tp)); #endif
dict = (PyDictObject *)new_dict_with_shared_keys(_PyInterpreterState_GET(),
CACHED_KEYS(tp));
FT_ATOMIC_STORE_PTR_RELEASE(_PyObject_ManagedDictPointer(obj)->dict, FT_ATOMIC_STORE_PTR_RELEASE(_PyObject_ManagedDictPointer(obj)->dict,
(PyDictObject *)dict); (PyDictObject *)dict);
#ifdef Py_GIL_DISABLED
done:
Py_END_CRITICAL_SECTION();
#endif
}
}
return (PyObject *)dict;
} }
Py_END_CRITICAL_SECTION(); static inline PyObject *
ensure_nonmanaged_dict(PyObject *obj, PyObject **dictptr)
{
PyDictKeysObject *cached;
PyObject *dict = FT_ATOMIC_LOAD_PTR_ACQUIRE(*dictptr);
if (dict == NULL) {
#ifdef Py_GIL_DISABLED
Py_BEGIN_CRITICAL_SECTION(obj);
dict = *dictptr;
if (dict != NULL) {
goto done;
} }
return Py_XNewRef((PyObject *)dict); #endif
PyTypeObject *tp = Py_TYPE(obj);
if (_PyType_HasFeature(tp, Py_TPFLAGS_HEAPTYPE) && (cached = CACHED_KEYS(tp))) {
PyInterpreterState *interp = _PyInterpreterState_GET();
assert(!_PyType_HasFeature(tp, Py_TPFLAGS_INLINE_VALUES));
dict = new_dict_with_shared_keys(interp, cached);
}
else {
dict = PyDict_New();
}
FT_ATOMIC_STORE_PTR_RELEASE(*dictptr, dict);
#ifdef Py_GIL_DISABLED
done:
Py_END_CRITICAL_SECTION();
#endif
}
return dict;
}
PyObject *
PyObject_GenericGetDict(PyObject *obj, void *context)
{
PyTypeObject *tp = Py_TYPE(obj);
if (_PyType_HasFeature(tp, Py_TPFLAGS_MANAGED_DICT)) {
return Py_XNewRef(ensure_managed_dict(obj));
} }
else { else {
PyObject **dictptr = _PyObject_ComputedDictPointer(obj); PyObject **dictptr = _PyObject_ComputedDictPointer(obj);
@ -7223,65 +7262,28 @@ PyObject_GenericGetDict(PyObject *obj, void *context)
"This object has no __dict__"); "This object has no __dict__");
return NULL; return NULL;
} }
PyObject *dict = *dictptr;
if (dict == NULL) { return Py_XNewRef(ensure_nonmanaged_dict(obj, dictptr));
PyTypeObject *tp = Py_TYPE(obj);
if (_PyType_HasFeature(tp, Py_TPFLAGS_HEAPTYPE) && CACHED_KEYS(tp)) {
dictkeys_incref(CACHED_KEYS(tp));
*dictptr = dict = new_dict_with_shared_keys(
interp, CACHED_KEYS(tp));
}
else {
*dictptr = dict = PyDict_New();
}
}
return Py_XNewRef(dict);
} }
} }
int int
_PyObjectDict_SetItem(PyTypeObject *tp, PyObject **dictptr, _PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr,
PyObject *key, PyObject *value) PyObject *key, PyObject *value)
{ {
PyObject *dict; PyObject *dict;
int res; int res;
PyDictKeysObject *cached;
PyInterpreterState *interp = _PyInterpreterState_GET();
assert(dictptr != NULL); assert(dictptr != NULL);
if ((tp->tp_flags & Py_TPFLAGS_HEAPTYPE) && (cached = CACHED_KEYS(tp))) { dict = ensure_nonmanaged_dict(obj, dictptr);
assert(dictptr != NULL);
dict = *dictptr;
if (dict == NULL) { if (dict == NULL) {
assert(!_PyType_HasFeature(tp, Py_TPFLAGS_INLINE_VALUES));
dictkeys_incref(cached);
dict = new_dict_with_shared_keys(interp, cached);
if (dict == NULL)
return -1; return -1;
*dictptr = dict;
}
if (value == NULL) {
res = PyDict_DelItem(dict, key);
}
else {
res = PyDict_SetItem(dict, key, value);
}
} else {
dict = *dictptr;
if (dict == NULL) {
dict = PyDict_New();
if (dict == NULL)
return -1;
*dictptr = dict;
}
if (value == NULL) {
res = PyDict_DelItem(dict, key);
} else {
res = PyDict_SetItem(dict, key, value);
}
} }
Py_BEGIN_CRITICAL_SECTION(dict);
res = _PyDict_SetItem_LockHeld((PyDictObject *)dict, key, value);
ASSERT_CONSISTENT(dict); ASSERT_CONSISTENT(dict);
Py_END_CRITICAL_SECTION();
return res; return res;
} }

View File

@ -1731,7 +1731,7 @@ _PyObject_GenericSetAttrWithDict(PyObject *obj, PyObject *name,
goto done; goto done;
} }
else { else {
res = _PyObjectDict_SetItem(tp, dictptr, name, value); res = _PyObjectDict_SetItem(tp, obj, dictptr, name, value);
} }
} }
else { else {
@ -1789,7 +1789,9 @@ PyObject_GenericSetDict(PyObject *obj, PyObject *value, void *context)
"not a '%.200s'", Py_TYPE(value)->tp_name); "not a '%.200s'", Py_TYPE(value)->tp_name);
return -1; return -1;
} }
Py_BEGIN_CRITICAL_SECTION(obj);
Py_XSETREF(*dictptr, Py_NewRef(value)); Py_XSETREF(*dictptr, Py_NewRef(value));
Py_END_CRITICAL_SECTION();
return 0; return 0;
} }