gh-116738: Make _abc module thread-safe (#117488)

A collection of small changes aimed at making the `_abc` module safe to
use in a free-threaded build.
This commit is contained in:
Brett Simmers 2024-04-11 15:13:25 -07:00 committed by GitHub
parent 1b10efad66
commit f268e328ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 205 additions and 115 deletions

View File

@ -152,6 +152,18 @@ PyAPI_FUNC(PyObject*) _PySuper_Lookup(PyTypeObject *su_type, PyObject *su_obj,
extern PyObject* _PyType_GetFullyQualifiedName(PyTypeObject *type, char sep);
// Perform the following operation, in a thread-safe way when required by the
// build mode.
//
// self->tp_flags = (self->tp_flags & ~mask) | flags;
extern void _PyType_SetFlags(PyTypeObject *self, unsigned long mask,
unsigned long flags);
// Like _PyType_SetFlags(), but apply the operation to self and any of its
// subclasses without Py_TPFLAGS_IMMUTABLETYPE set.
extern void _PyType_SetFlagsRecursive(PyTypeObject *self, unsigned long mask,
unsigned long flags);
#ifdef __cplusplus
}

View File

@ -21,7 +21,7 @@ PyDoc_STRVAR(_abc__doc__,
typedef struct {
PyTypeObject *_abc_data_type;
unsigned long long abc_invalidation_counter;
uint64_t abc_invalidation_counter;
} _abcmodule_state;
static inline _abcmodule_state*
@ -32,17 +32,61 @@ get_abc_state(PyObject *module)
return (_abcmodule_state *)state;
}
static inline uint64_t
get_invalidation_counter(_abcmodule_state *state)
{
#ifdef Py_GIL_DISABLED
return _Py_atomic_load_uint64(&state->abc_invalidation_counter);
#else
return state->abc_invalidation_counter;
#endif
}
static inline void
increment_invalidation_counter(_abcmodule_state *state)
{
#ifdef Py_GIL_DISABLED
_Py_atomic_add_uint64(&state->abc_invalidation_counter, 1);
#else
state->abc_invalidation_counter++;
#endif
}
/* This object stores internal state for ABCs.
Note that we can use normal sets for caches,
since they are never iterated over. */
typedef struct {
PyObject_HEAD
/* These sets of weak references are lazily created. Once created, they
will point to the same sets until the ABCMeta object is destroyed or
cleared, both of which will only happen while the object is visible to a
single thread. */
PyObject *_abc_registry;
PyObject *_abc_cache; /* Normal set of weak references. */
PyObject *_abc_negative_cache; /* Normal set of weak references. */
unsigned long long _abc_negative_cache_version;
PyObject *_abc_cache;
PyObject *_abc_negative_cache;
uint64_t _abc_negative_cache_version;
} _abc_data;
static inline uint64_t
get_cache_version(_abc_data *impl)
{
#ifdef Py_GIL_DISABLED
return _Py_atomic_load_uint64(&impl->_abc_negative_cache_version);
#else
return impl->_abc_negative_cache_version;
#endif
}
static inline void
set_cache_version(_abc_data *impl, uint64_t version)
{
#ifdef Py_GIL_DISABLED
_Py_atomic_store_uint64(&impl->_abc_negative_cache_version, version);
#else
impl->_abc_negative_cache_version = version;
#endif
}
static int
abc_data_traverse(_abc_data *self, visitproc visit, void *arg)
{
@ -90,7 +134,7 @@ abc_data_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
self->_abc_registry = NULL;
self->_abc_cache = NULL;
self->_abc_negative_cache = NULL;
self->_abc_negative_cache_version = state->abc_invalidation_counter;
self->_abc_negative_cache_version = get_invalidation_counter(state);
return (PyObject *) self;
}
@ -130,8 +174,12 @@ _get_impl(PyObject *module, PyObject *self)
}
static int
_in_weak_set(PyObject *set, PyObject *obj)
_in_weak_set(_abc_data *impl, PyObject **pset, PyObject *obj)
{
PyObject *set;
Py_BEGIN_CRITICAL_SECTION(impl);
set = *pset;
Py_END_CRITICAL_SECTION();
if (set == NULL || PySet_GET_SIZE(set) == 0) {
return 0;
}
@ -168,16 +216,19 @@ static PyMethodDef _destroy_def = {
};
static int
_add_to_weak_set(PyObject **pset, PyObject *obj)
_add_to_weak_set(_abc_data *impl, PyObject **pset, PyObject *obj)
{
if (*pset == NULL) {
*pset = PySet_New(NULL);
if (*pset == NULL) {
return -1;
}
PyObject *set;
Py_BEGIN_CRITICAL_SECTION(impl);
set = *pset;
if (set == NULL) {
set = *pset = PySet_New(NULL);
}
Py_END_CRITICAL_SECTION();
if (set == NULL) {
return -1;
}
PyObject *set = *pset;
PyObject *ref, *wr;
PyObject *destroy_cb;
wr = PyWeakref_NewRef(set, NULL);
@ -220,7 +271,11 @@ _abc__reset_registry(PyObject *module, PyObject *self)
if (impl == NULL) {
return NULL;
}
if (impl->_abc_registry != NULL && PySet_Clear(impl->_abc_registry) < 0) {
PyObject *registry;
Py_BEGIN_CRITICAL_SECTION(impl);
registry = impl->_abc_registry;
Py_END_CRITICAL_SECTION();
if (registry != NULL && PySet_Clear(registry) < 0) {
Py_DECREF(impl);
return NULL;
}
@ -247,13 +302,17 @@ _abc__reset_caches(PyObject *module, PyObject *self)
if (impl == NULL) {
return NULL;
}
if (impl->_abc_cache != NULL && PySet_Clear(impl->_abc_cache) < 0) {
PyObject *cache, *negative_cache;
Py_BEGIN_CRITICAL_SECTION(impl);
cache = impl->_abc_cache;
negative_cache = impl->_abc_negative_cache;
Py_END_CRITICAL_SECTION();
if (cache != NULL && PySet_Clear(cache) < 0) {
Py_DECREF(impl);
return NULL;
}
/* also the second cache */
if (impl->_abc_negative_cache != NULL &&
PySet_Clear(impl->_abc_negative_cache) < 0) {
if (negative_cache != NULL && PySet_Clear(negative_cache) < 0) {
Py_DECREF(impl);
return NULL;
}
@ -282,11 +341,14 @@ _abc__get_dump(PyObject *module, PyObject *self)
if (impl == NULL) {
return NULL;
}
PyObject *res = Py_BuildValue("NNNK",
PySet_New(impl->_abc_registry),
PySet_New(impl->_abc_cache),
PySet_New(impl->_abc_negative_cache),
impl->_abc_negative_cache_version);
PyObject *res;
Py_BEGIN_CRITICAL_SECTION(impl);
res = Py_BuildValue("NNNK",
PySet_New(impl->_abc_registry),
PySet_New(impl->_abc_cache),
PySet_New(impl->_abc_negative_cache),
get_cache_version(impl));
Py_END_CRITICAL_SECTION();
Py_DECREF(impl);
return res;
}
@ -453,58 +515,29 @@ _abc__abc_init(PyObject *module, PyObject *self)
if (PyType_Check(self)) {
PyTypeObject *cls = (PyTypeObject *)self;
PyObject *dict = _PyType_GetDict(cls);
PyObject *flags = PyDict_GetItemWithError(dict,
&_Py_ID(__abc_tpflags__));
if (flags == NULL) {
if (PyErr_Occurred()) {
return NULL;
}
PyObject *flags = NULL;
if (PyDict_Pop(dict, &_Py_ID(__abc_tpflags__), &flags) < 0) {
return NULL;
}
else {
if (PyLong_CheckExact(flags)) {
long val = PyLong_AsLong(flags);
if (val == -1 && PyErr_Occurred()) {
return NULL;
}
if ((val & COLLECTION_FLAGS) == COLLECTION_FLAGS) {
PyErr_SetString(PyExc_TypeError, "__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING");
return NULL;
}
((PyTypeObject *)self)->tp_flags |= (val & COLLECTION_FLAGS);
}
if (PyDict_DelItem(dict, &_Py_ID(__abc_tpflags__)) < 0) {
return NULL;
}
if (flags == NULL || !PyLong_CheckExact(flags)) {
Py_XDECREF(flags);
Py_RETURN_NONE;
}
long val = PyLong_AsLong(flags);
Py_DECREF(flags);
if (val == -1 && PyErr_Occurred()) {
return NULL;
}
if ((val & COLLECTION_FLAGS) == COLLECTION_FLAGS) {
PyErr_SetString(PyExc_TypeError, "__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING");
return NULL;
}
_PyType_SetFlags((PyTypeObject *)self, 0, val & COLLECTION_FLAGS);
}
Py_RETURN_NONE;
}
static void
set_collection_flag_recursive(PyTypeObject *child, unsigned long flag)
{
assert(flag == Py_TPFLAGS_MAPPING || flag == Py_TPFLAGS_SEQUENCE);
if (PyType_HasFeature(child, Py_TPFLAGS_IMMUTABLETYPE) ||
(child->tp_flags & COLLECTION_FLAGS) == flag)
{
return;
}
child->tp_flags &= ~COLLECTION_FLAGS;
child->tp_flags |= flag;
PyObject *grandchildren = _PyType_GetSubclasses(child);
if (grandchildren == NULL) {
return;
}
for (Py_ssize_t i = 0; i < PyList_GET_SIZE(grandchildren); i++) {
PyObject *grandchild = PyList_GET_ITEM(grandchildren, i);
set_collection_flag_recursive((PyTypeObject *)grandchild, flag);
}
Py_DECREF(grandchildren);
}
/*[clinic input]
_abc._abc_register
@ -545,20 +578,23 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
if (impl == NULL) {
return NULL;
}
if (_add_to_weak_set(&impl->_abc_registry, subclass) < 0) {
if (_add_to_weak_set(impl, &impl->_abc_registry, subclass) < 0) {
Py_DECREF(impl);
return NULL;
}
Py_DECREF(impl);
/* Invalidate negative cache */
get_abc_state(module)->abc_invalidation_counter++;
increment_invalidation_counter(get_abc_state(module));
/* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */
/* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */
if (PyType_Check(self)) {
unsigned long collection_flag = ((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS;
unsigned long collection_flag =
PyType_GetFlags((PyTypeObject *)self) & COLLECTION_FLAGS;
if (collection_flag) {
set_collection_flag_recursive((PyTypeObject *)subclass, collection_flag);
_PyType_SetFlagsRecursive((PyTypeObject *)subclass,
COLLECTION_FLAGS,
collection_flag);
}
}
return Py_NewRef(subclass);
@ -592,7 +628,7 @@ _abc__abc_instancecheck_impl(PyObject *module, PyObject *self,
return NULL;
}
/* Inline the cache checking. */
int incache = _in_weak_set(impl->_abc_cache, subclass);
int incache = _in_weak_set(impl, &impl->_abc_cache, subclass);
if (incache < 0) {
goto end;
}
@ -602,8 +638,8 @@ _abc__abc_instancecheck_impl(PyObject *module, PyObject *self,
}
subtype = (PyObject *)Py_TYPE(instance);
if (subtype == subclass) {
if (impl->_abc_negative_cache_version == get_abc_state(module)->abc_invalidation_counter) {
incache = _in_weak_set(impl->_abc_negative_cache, subclass);
if (get_cache_version(impl) == get_invalidation_counter(get_abc_state(module))) {
incache = _in_weak_set(impl, &impl->_abc_negative_cache, subclass);
if (incache < 0) {
goto end;
}
@ -681,7 +717,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
}
/* 1. Check cache. */
incache = _in_weak_set(impl->_abc_cache, subclass);
incache = _in_weak_set(impl, &impl->_abc_cache, subclass);
if (incache < 0) {
goto end;
}
@ -692,17 +728,20 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
state = get_abc_state(module);
/* 2. Check negative cache; may have to invalidate. */
if (impl->_abc_negative_cache_version < state->abc_invalidation_counter) {
uint64_t invalidation_counter = get_invalidation_counter(state);
if (get_cache_version(impl) < invalidation_counter) {
/* Invalidate the negative cache. */
if (impl->_abc_negative_cache != NULL &&
PySet_Clear(impl->_abc_negative_cache) < 0)
{
PyObject *negative_cache;
Py_BEGIN_CRITICAL_SECTION(impl);
negative_cache = impl->_abc_negative_cache;
Py_END_CRITICAL_SECTION();
if (negative_cache != NULL && PySet_Clear(negative_cache) < 0) {
goto end;
}
impl->_abc_negative_cache_version = state->abc_invalidation_counter;
set_cache_version(impl, invalidation_counter);
}
else {
incache = _in_weak_set(impl->_abc_negative_cache, subclass);
incache = _in_weak_set(impl, &impl->_abc_negative_cache, subclass);
if (incache < 0) {
goto end;
}
@ -720,7 +759,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
}
if (ok == Py_True) {
Py_DECREF(ok);
if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) {
if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
goto end;
}
result = Py_True;
@ -728,7 +767,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
}
if (ok == Py_False) {
Py_DECREF(ok);
if (_add_to_weak_set(&impl->_abc_negative_cache, subclass) < 0) {
if (_add_to_weak_set(impl, &impl->_abc_negative_cache, subclass) < 0) {
goto end;
}
result = Py_False;
@ -744,7 +783,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
/* 4. Check if it's a direct subclass. */
if (PyType_IsSubtype((PyTypeObject *)subclass, (PyTypeObject *)self)) {
if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) {
if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
goto end;
}
result = Py_True;
@ -767,12 +806,14 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
goto end;
}
for (pos = 0; pos < PyList_GET_SIZE(subclasses); pos++) {
PyObject *scls = PyList_GET_ITEM(subclasses, pos);
Py_INCREF(scls);
PyObject *scls = PyList_GetItemRef(subclasses, pos);
if (scls == NULL) {
goto end;
}
int r = PyObject_IsSubclass(subclass, scls);
Py_DECREF(scls);
if (r > 0) {
if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) {
if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
goto end;
}
result = Py_True;
@ -784,7 +825,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
}
/* No dice; update negative cache. */
if (_add_to_weak_set(&impl->_abc_negative_cache, subclass) < 0) {
if (_add_to_weak_set(impl, &impl->_abc_negative_cache, subclass) < 0) {
goto end;
}
result = Py_False;
@ -801,7 +842,7 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass,
PyObject **result)
{
// Fast path: check subclass is in weakref directly.
int ret = _in_weak_set(impl->_abc_registry, subclass);
int ret = _in_weak_set(impl, &impl->_abc_registry, subclass);
if (ret < 0) {
*result = NULL;
return -1;
@ -811,33 +852,27 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass,
return 1;
}
if (impl->_abc_registry == NULL) {
PyObject *registry_shared;
Py_BEGIN_CRITICAL_SECTION(impl);
registry_shared = impl->_abc_registry;
Py_END_CRITICAL_SECTION();
if (registry_shared == NULL) {
return 0;
}
Py_ssize_t registry_size = PySet_Size(impl->_abc_registry);
if (registry_size == 0) {
return 0;
}
// Weakref callback may remove entry from set.
// So we take snapshot of registry first.
PyObject **copy = PyMem_Malloc(sizeof(PyObject*) * registry_size);
if (copy == NULL) {
PyErr_NoMemory();
// Make a local copy of the registry to protect against concurrent
// modifications of _abc_registry.
PyObject *registry = PySet_New(registry_shared);
if (registry == NULL) {
return -1;
}
PyObject *key;
Py_ssize_t pos = 0;
Py_hash_t hash;
Py_ssize_t i = 0;
while (_PySet_NextEntry(impl->_abc_registry, &pos, &key, &hash)) {
copy[i++] = Py_NewRef(key);
}
assert(i == registry_size);
for (i = 0; i < registry_size; i++) {
while (_PySet_NextEntry(registry, &pos, &key, &hash)) {
PyObject *rkey;
if (PyWeakref_GetRef(copy[i], &rkey) < 0) {
if (PyWeakref_GetRef(key, &rkey) < 0) {
// Someone inject non-weakref type in the registry.
ret = -1;
break;
@ -853,7 +888,7 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass,
break;
}
if (r > 0) {
if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) {
if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
ret = -1;
break;
}
@ -863,10 +898,7 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass,
}
}
for (i = 0; i < registry_size; i++) {
Py_DECREF(copy[i]);
}
PyMem_Free(copy);
Py_DECREF(registry);
return ret;
}
@ -885,7 +917,7 @@ _abc_get_cache_token_impl(PyObject *module)
/*[clinic end generated code: output=c7d87841e033dacc input=70413d1c423ad9f9]*/
{
_abcmodule_state *state = get_abc_state(module);
return PyLong_FromUnsignedLongLong(state->abc_invalidation_counter);
return PyLong_FromUnsignedLongLong(get_invalidation_counter(state));
}
static struct PyMethodDef _abcmodule_methods[] = {

View File

@ -5117,6 +5117,52 @@ _PyType_LookupId(PyTypeObject *type, _Py_Identifier *name)
return _PyType_Lookup(type, oname);
}
static void
set_flags(PyTypeObject *self, unsigned long mask, unsigned long flags)
{
ASSERT_TYPE_LOCK_HELD();
self->tp_flags = (self->tp_flags & ~mask) | flags;
}
void
_PyType_SetFlags(PyTypeObject *self, unsigned long mask, unsigned long flags)
{
BEGIN_TYPE_LOCK();
set_flags(self, mask, flags);
END_TYPE_LOCK();
}
static void
set_flags_recursive(PyTypeObject *self, unsigned long mask, unsigned long flags)
{
if (PyType_HasFeature(self, Py_TPFLAGS_IMMUTABLETYPE) ||
(self->tp_flags & mask) == flags)
{
return;
}
set_flags(self, mask, flags);
PyObject *children = _PyType_GetSubclasses(self);
if (children == NULL) {
return;
}
for (Py_ssize_t i = 0; i < PyList_GET_SIZE(children); i++) {
PyObject *child = PyList_GET_ITEM(children, i);
set_flags_recursive((PyTypeObject *)child, mask, flags);
}
Py_DECREF(children);
}
void
_PyType_SetFlagsRecursive(PyTypeObject *self, unsigned long mask, unsigned long flags)
{
BEGIN_TYPE_LOCK();
set_flags_recursive(self, mask, flags);
END_TYPE_LOCK();
}
/* This is similar to PyObject_GenericGetAttr(),
but uses _PyType_Lookup() instead of just looking in type->tp_dict.