From f268e328ed5d7b2df5bdad39691f6e4789a2fcde Mon Sep 17 00:00:00 2001 From: Brett Simmers Date: Thu, 11 Apr 2024 15:13:25 -0700 Subject: [PATCH] gh-116738: Make _abc module thread-safe (#117488) A collection of small changes aimed at making the `_abc` module safe to use in a free-threaded build. --- Include/internal/pycore_typeobject.h | 12 ++ Modules/_abc.c | 262 +++++++++++++++------------ Objects/typeobject.c | 46 +++++ 3 files changed, 205 insertions(+), 115 deletions(-) diff --git a/Include/internal/pycore_typeobject.h b/Include/internal/pycore_typeobject.h index 8a25935f308..1693119ffec 100644 --- a/Include/internal/pycore_typeobject.h +++ b/Include/internal/pycore_typeobject.h @@ -152,6 +152,18 @@ PyAPI_FUNC(PyObject*) _PySuper_Lookup(PyTypeObject *su_type, PyObject *su_obj, extern PyObject* _PyType_GetFullyQualifiedName(PyTypeObject *type, char sep); +// Perform the following operation, in a thread-safe way when required by the +// build mode. +// +// self->tp_flags = (self->tp_flags & ~mask) | flags; +extern void _PyType_SetFlags(PyTypeObject *self, unsigned long mask, + unsigned long flags); + +// Like _PyType_SetFlags(), but apply the operation to self and any of its +// subclasses without Py_TPFLAGS_IMMUTABLETYPE set. +extern void _PyType_SetFlagsRecursive(PyTypeObject *self, unsigned long mask, + unsigned long flags); + #ifdef __cplusplus } diff --git a/Modules/_abc.c b/Modules/_abc.c index 399ecbbd6a2..ad28035843f 100644 --- a/Modules/_abc.c +++ b/Modules/_abc.c @@ -21,7 +21,7 @@ PyDoc_STRVAR(_abc__doc__, typedef struct { PyTypeObject *_abc_data_type; - unsigned long long abc_invalidation_counter; + uint64_t abc_invalidation_counter; } _abcmodule_state; static inline _abcmodule_state* @@ -32,17 +32,61 @@ get_abc_state(PyObject *module) return (_abcmodule_state *)state; } +static inline uint64_t +get_invalidation_counter(_abcmodule_state *state) +{ +#ifdef Py_GIL_DISABLED + return _Py_atomic_load_uint64(&state->abc_invalidation_counter); +#else + return state->abc_invalidation_counter; +#endif +} + +static inline void +increment_invalidation_counter(_abcmodule_state *state) +{ +#ifdef Py_GIL_DISABLED + _Py_atomic_add_uint64(&state->abc_invalidation_counter, 1); +#else + state->abc_invalidation_counter++; +#endif +} + /* This object stores internal state for ABCs. Note that we can use normal sets for caches, since they are never iterated over. */ typedef struct { PyObject_HEAD + /* These sets of weak references are lazily created. Once created, they + will point to the same sets until the ABCMeta object is destroyed or + cleared, both of which will only happen while the object is visible to a + single thread. */ PyObject *_abc_registry; - PyObject *_abc_cache; /* Normal set of weak references. */ - PyObject *_abc_negative_cache; /* Normal set of weak references. */ - unsigned long long _abc_negative_cache_version; + PyObject *_abc_cache; + PyObject *_abc_negative_cache; + uint64_t _abc_negative_cache_version; } _abc_data; +static inline uint64_t +get_cache_version(_abc_data *impl) +{ +#ifdef Py_GIL_DISABLED + return _Py_atomic_load_uint64(&impl->_abc_negative_cache_version); +#else + return impl->_abc_negative_cache_version; +#endif +} + +static inline void +set_cache_version(_abc_data *impl, uint64_t version) +{ +#ifdef Py_GIL_DISABLED + _Py_atomic_store_uint64(&impl->_abc_negative_cache_version, version); +#else + impl->_abc_negative_cache_version = version; +#endif +} + static int abc_data_traverse(_abc_data *self, visitproc visit, void *arg) { @@ -90,7 +134,7 @@ abc_data_new(PyTypeObject *type, PyObject *args, PyObject *kwds) self->_abc_registry = NULL; self->_abc_cache = NULL; self->_abc_negative_cache = NULL; - self->_abc_negative_cache_version = state->abc_invalidation_counter; + self->_abc_negative_cache_version = get_invalidation_counter(state); return (PyObject *) self; } @@ -130,8 +174,12 @@ _get_impl(PyObject *module, PyObject *self) } static int -_in_weak_set(PyObject *set, PyObject *obj) +_in_weak_set(_abc_data *impl, PyObject **pset, PyObject *obj) { + PyObject *set; + Py_BEGIN_CRITICAL_SECTION(impl); + set = *pset; + Py_END_CRITICAL_SECTION(); if (set == NULL || PySet_GET_SIZE(set) == 0) { return 0; } @@ -168,16 +216,19 @@ static PyMethodDef _destroy_def = { }; static int -_add_to_weak_set(PyObject **pset, PyObject *obj) +_add_to_weak_set(_abc_data *impl, PyObject **pset, PyObject *obj) { - if (*pset == NULL) { - *pset = PySet_New(NULL); - if (*pset == NULL) { - return -1; - } + PyObject *set; + Py_BEGIN_CRITICAL_SECTION(impl); + set = *pset; + if (set == NULL) { + set = *pset = PySet_New(NULL); + } + Py_END_CRITICAL_SECTION(); + if (set == NULL) { + return -1; } - PyObject *set = *pset; PyObject *ref, *wr; PyObject *destroy_cb; wr = PyWeakref_NewRef(set, NULL); @@ -220,7 +271,11 @@ _abc__reset_registry(PyObject *module, PyObject *self) if (impl == NULL) { return NULL; } - if (impl->_abc_registry != NULL && PySet_Clear(impl->_abc_registry) < 0) { + PyObject *registry; + Py_BEGIN_CRITICAL_SECTION(impl); + registry = impl->_abc_registry; + Py_END_CRITICAL_SECTION(); + if (registry != NULL && PySet_Clear(registry) < 0) { Py_DECREF(impl); return NULL; } @@ -247,13 +302,17 @@ _abc__reset_caches(PyObject *module, PyObject *self) if (impl == NULL) { return NULL; } - if (impl->_abc_cache != NULL && PySet_Clear(impl->_abc_cache) < 0) { + PyObject *cache, *negative_cache; + Py_BEGIN_CRITICAL_SECTION(impl); + cache = impl->_abc_cache; + negative_cache = impl->_abc_negative_cache; + Py_END_CRITICAL_SECTION(); + if (cache != NULL && PySet_Clear(cache) < 0) { Py_DECREF(impl); return NULL; } /* also the second cache */ - if (impl->_abc_negative_cache != NULL && - PySet_Clear(impl->_abc_negative_cache) < 0) { + if (negative_cache != NULL && PySet_Clear(negative_cache) < 0) { Py_DECREF(impl); return NULL; } @@ -282,11 +341,14 @@ _abc__get_dump(PyObject *module, PyObject *self) if (impl == NULL) { return NULL; } - PyObject *res = Py_BuildValue("NNNK", - PySet_New(impl->_abc_registry), - PySet_New(impl->_abc_cache), - PySet_New(impl->_abc_negative_cache), - impl->_abc_negative_cache_version); + PyObject *res; + Py_BEGIN_CRITICAL_SECTION(impl); + res = Py_BuildValue("NNNK", + PySet_New(impl->_abc_registry), + PySet_New(impl->_abc_cache), + PySet_New(impl->_abc_negative_cache), + get_cache_version(impl)); + Py_END_CRITICAL_SECTION(); Py_DECREF(impl); return res; } @@ -453,58 +515,29 @@ _abc__abc_init(PyObject *module, PyObject *self) if (PyType_Check(self)) { PyTypeObject *cls = (PyTypeObject *)self; PyObject *dict = _PyType_GetDict(cls); - PyObject *flags = PyDict_GetItemWithError(dict, - &_Py_ID(__abc_tpflags__)); - if (flags == NULL) { - if (PyErr_Occurred()) { - return NULL; - } + PyObject *flags = NULL; + if (PyDict_Pop(dict, &_Py_ID(__abc_tpflags__), &flags) < 0) { + return NULL; } - else { - if (PyLong_CheckExact(flags)) { - long val = PyLong_AsLong(flags); - if (val == -1 && PyErr_Occurred()) { - return NULL; - } - if ((val & COLLECTION_FLAGS) == COLLECTION_FLAGS) { - PyErr_SetString(PyExc_TypeError, "__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING"); - return NULL; - } - ((PyTypeObject *)self)->tp_flags |= (val & COLLECTION_FLAGS); - } - if (PyDict_DelItem(dict, &_Py_ID(__abc_tpflags__)) < 0) { - return NULL; - } + if (flags == NULL || !PyLong_CheckExact(flags)) { + Py_XDECREF(flags); + Py_RETURN_NONE; } + + long val = PyLong_AsLong(flags); + Py_DECREF(flags); + if (val == -1 && PyErr_Occurred()) { + return NULL; + } + if ((val & COLLECTION_FLAGS) == COLLECTION_FLAGS) { + PyErr_SetString(PyExc_TypeError, "__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING"); + return NULL; + } + _PyType_SetFlags((PyTypeObject *)self, 0, val & COLLECTION_FLAGS); } Py_RETURN_NONE; } -static void -set_collection_flag_recursive(PyTypeObject *child, unsigned long flag) -{ - assert(flag == Py_TPFLAGS_MAPPING || flag == Py_TPFLAGS_SEQUENCE); - if (PyType_HasFeature(child, Py_TPFLAGS_IMMUTABLETYPE) || - (child->tp_flags & COLLECTION_FLAGS) == flag) - { - return; - } - - child->tp_flags &= ~COLLECTION_FLAGS; - child->tp_flags |= flag; - - PyObject *grandchildren = _PyType_GetSubclasses(child); - if (grandchildren == NULL) { - return; - } - - for (Py_ssize_t i = 0; i < PyList_GET_SIZE(grandchildren); i++) { - PyObject *grandchild = PyList_GET_ITEM(grandchildren, i); - set_collection_flag_recursive((PyTypeObject *)grandchild, flag); - } - Py_DECREF(grandchildren); -} - /*[clinic input] _abc._abc_register @@ -545,20 +578,23 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass) if (impl == NULL) { return NULL; } - if (_add_to_weak_set(&impl->_abc_registry, subclass) < 0) { + if (_add_to_weak_set(impl, &impl->_abc_registry, subclass) < 0) { Py_DECREF(impl); return NULL; } Py_DECREF(impl); /* Invalidate negative cache */ - get_abc_state(module)->abc_invalidation_counter++; + increment_invalidation_counter(get_abc_state(module)); - /* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */ + /* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */ if (PyType_Check(self)) { - unsigned long collection_flag = ((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS; + unsigned long collection_flag = + PyType_GetFlags((PyTypeObject *)self) & COLLECTION_FLAGS; if (collection_flag) { - set_collection_flag_recursive((PyTypeObject *)subclass, collection_flag); + _PyType_SetFlagsRecursive((PyTypeObject *)subclass, + COLLECTION_FLAGS, + collection_flag); } } return Py_NewRef(subclass); @@ -592,7 +628,7 @@ _abc__abc_instancecheck_impl(PyObject *module, PyObject *self, return NULL; } /* Inline the cache checking. */ - int incache = _in_weak_set(impl->_abc_cache, subclass); + int incache = _in_weak_set(impl, &impl->_abc_cache, subclass); if (incache < 0) { goto end; } @@ -602,8 +638,8 @@ _abc__abc_instancecheck_impl(PyObject *module, PyObject *self, } subtype = (PyObject *)Py_TYPE(instance); if (subtype == subclass) { - if (impl->_abc_negative_cache_version == get_abc_state(module)->abc_invalidation_counter) { - incache = _in_weak_set(impl->_abc_negative_cache, subclass); + if (get_cache_version(impl) == get_invalidation_counter(get_abc_state(module))) { + incache = _in_weak_set(impl, &impl->_abc_negative_cache, subclass); if (incache < 0) { goto end; } @@ -681,7 +717,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self, } /* 1. Check cache. */ - incache = _in_weak_set(impl->_abc_cache, subclass); + incache = _in_weak_set(impl, &impl->_abc_cache, subclass); if (incache < 0) { goto end; } @@ -692,17 +728,20 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self, state = get_abc_state(module); /* 2. Check negative cache; may have to invalidate. */ - if (impl->_abc_negative_cache_version < state->abc_invalidation_counter) { + uint64_t invalidation_counter = get_invalidation_counter(state); + if (get_cache_version(impl) < invalidation_counter) { /* Invalidate the negative cache. */ - if (impl->_abc_negative_cache != NULL && - PySet_Clear(impl->_abc_negative_cache) < 0) - { + PyObject *negative_cache; + Py_BEGIN_CRITICAL_SECTION(impl); + negative_cache = impl->_abc_negative_cache; + Py_END_CRITICAL_SECTION(); + if (negative_cache != NULL && PySet_Clear(negative_cache) < 0) { goto end; } - impl->_abc_negative_cache_version = state->abc_invalidation_counter; + set_cache_version(impl, invalidation_counter); } else { - incache = _in_weak_set(impl->_abc_negative_cache, subclass); + incache = _in_weak_set(impl, &impl->_abc_negative_cache, subclass); if (incache < 0) { goto end; } @@ -720,7 +759,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self, } if (ok == Py_True) { Py_DECREF(ok); - if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) { + if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) { goto end; } result = Py_True; @@ -728,7 +767,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self, } if (ok == Py_False) { Py_DECREF(ok); - if (_add_to_weak_set(&impl->_abc_negative_cache, subclass) < 0) { + if (_add_to_weak_set(impl, &impl->_abc_negative_cache, subclass) < 0) { goto end; } result = Py_False; @@ -744,7 +783,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self, /* 4. Check if it's a direct subclass. */ if (PyType_IsSubtype((PyTypeObject *)subclass, (PyTypeObject *)self)) { - if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) { + if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) { goto end; } result = Py_True; @@ -767,12 +806,14 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self, goto end; } for (pos = 0; pos < PyList_GET_SIZE(subclasses); pos++) { - PyObject *scls = PyList_GET_ITEM(subclasses, pos); - Py_INCREF(scls); + PyObject *scls = PyList_GetItemRef(subclasses, pos); + if (scls == NULL) { + goto end; + } int r = PyObject_IsSubclass(subclass, scls); Py_DECREF(scls); if (r > 0) { - if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) { + if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) { goto end; } result = Py_True; @@ -784,7 +825,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self, } /* No dice; update negative cache. */ - if (_add_to_weak_set(&impl->_abc_negative_cache, subclass) < 0) { + if (_add_to_weak_set(impl, &impl->_abc_negative_cache, subclass) < 0) { goto end; } result = Py_False; @@ -801,7 +842,7 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass, PyObject **result) { // Fast path: check subclass is in weakref directly. - int ret = _in_weak_set(impl->_abc_registry, subclass); + int ret = _in_weak_set(impl, &impl->_abc_registry, subclass); if (ret < 0) { *result = NULL; return -1; @@ -811,33 +852,27 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass, return 1; } - if (impl->_abc_registry == NULL) { + PyObject *registry_shared; + Py_BEGIN_CRITICAL_SECTION(impl); + registry_shared = impl->_abc_registry; + Py_END_CRITICAL_SECTION(); + if (registry_shared == NULL) { return 0; } - Py_ssize_t registry_size = PySet_Size(impl->_abc_registry); - if (registry_size == 0) { - return 0; - } - // Weakref callback may remove entry from set. - // So we take snapshot of registry first. - PyObject **copy = PyMem_Malloc(sizeof(PyObject*) * registry_size); - if (copy == NULL) { - PyErr_NoMemory(); + + // Make a local copy of the registry to protect against concurrent + // modifications of _abc_registry. + PyObject *registry = PySet_New(registry_shared); + if (registry == NULL) { return -1; } PyObject *key; Py_ssize_t pos = 0; Py_hash_t hash; - Py_ssize_t i = 0; - while (_PySet_NextEntry(impl->_abc_registry, &pos, &key, &hash)) { - copy[i++] = Py_NewRef(key); - } - assert(i == registry_size); - - for (i = 0; i < registry_size; i++) { + while (_PySet_NextEntry(registry, &pos, &key, &hash)) { PyObject *rkey; - if (PyWeakref_GetRef(copy[i], &rkey) < 0) { + if (PyWeakref_GetRef(key, &rkey) < 0) { // Someone inject non-weakref type in the registry. ret = -1; break; @@ -853,7 +888,7 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass, break; } if (r > 0) { - if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) { + if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) { ret = -1; break; } @@ -863,10 +898,7 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass, } } - for (i = 0; i < registry_size; i++) { - Py_DECREF(copy[i]); - } - PyMem_Free(copy); + Py_DECREF(registry); return ret; } @@ -885,7 +917,7 @@ _abc_get_cache_token_impl(PyObject *module) /*[clinic end generated code: output=c7d87841e033dacc input=70413d1c423ad9f9]*/ { _abcmodule_state *state = get_abc_state(module); - return PyLong_FromUnsignedLongLong(state->abc_invalidation_counter); + return PyLong_FromUnsignedLongLong(get_invalidation_counter(state)); } static struct PyMethodDef _abcmodule_methods[] = { diff --git a/Objects/typeobject.c b/Objects/typeobject.c index e9f2d2577e9..3f38abfcfe5 100644 --- a/Objects/typeobject.c +++ b/Objects/typeobject.c @@ -5117,6 +5117,52 @@ _PyType_LookupId(PyTypeObject *type, _Py_Identifier *name) return _PyType_Lookup(type, oname); } +static void +set_flags(PyTypeObject *self, unsigned long mask, unsigned long flags) +{ + ASSERT_TYPE_LOCK_HELD(); + self->tp_flags = (self->tp_flags & ~mask) | flags; +} + +void +_PyType_SetFlags(PyTypeObject *self, unsigned long mask, unsigned long flags) +{ + BEGIN_TYPE_LOCK(); + set_flags(self, mask, flags); + END_TYPE_LOCK(); +} + +static void +set_flags_recursive(PyTypeObject *self, unsigned long mask, unsigned long flags) +{ + if (PyType_HasFeature(self, Py_TPFLAGS_IMMUTABLETYPE) || + (self->tp_flags & mask) == flags) + { + return; + } + + set_flags(self, mask, flags); + + PyObject *children = _PyType_GetSubclasses(self); + if (children == NULL) { + return; + } + + for (Py_ssize_t i = 0; i < PyList_GET_SIZE(children); i++) { + PyObject *child = PyList_GET_ITEM(children, i); + set_flags_recursive((PyTypeObject *)child, mask, flags); + } + Py_DECREF(children); +} + +void +_PyType_SetFlagsRecursive(PyTypeObject *self, unsigned long mask, unsigned long flags) +{ + BEGIN_TYPE_LOCK(); + set_flags_recursive(self, mask, flags); + END_TYPE_LOCK(); +} + /* This is similar to PyObject_GenericGetAttr(), but uses _PyType_Lookup() instead of just looking in type->tp_dict.