gh-116738: Make _abc module thread-safe (#117488)

A collection of small changes aimed at making the `_abc` module safe to
use in a free-threaded build.
This commit is contained in:
Brett Simmers 2024-04-11 15:13:25 -07:00 committed by GitHub
parent 1b10efad66
commit f268e328ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 205 additions and 115 deletions

View File

@ -152,6 +152,18 @@ PyAPI_FUNC(PyObject*) _PySuper_Lookup(PyTypeObject *su_type, PyObject *su_obj,
extern PyObject* _PyType_GetFullyQualifiedName(PyTypeObject *type, char sep); extern PyObject* _PyType_GetFullyQualifiedName(PyTypeObject *type, char sep);
// Perform the following operation, in a thread-safe way when required by the
// build mode.
//
// self->tp_flags = (self->tp_flags & ~mask) | flags;
extern void _PyType_SetFlags(PyTypeObject *self, unsigned long mask,
unsigned long flags);
// Like _PyType_SetFlags(), but apply the operation to self and any of its
// subclasses without Py_TPFLAGS_IMMUTABLETYPE set.
extern void _PyType_SetFlagsRecursive(PyTypeObject *self, unsigned long mask,
unsigned long flags);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -21,7 +21,7 @@ PyDoc_STRVAR(_abc__doc__,
typedef struct { typedef struct {
PyTypeObject *_abc_data_type; PyTypeObject *_abc_data_type;
unsigned long long abc_invalidation_counter; uint64_t abc_invalidation_counter;
} _abcmodule_state; } _abcmodule_state;
static inline _abcmodule_state* static inline _abcmodule_state*
@ -32,17 +32,61 @@ get_abc_state(PyObject *module)
return (_abcmodule_state *)state; return (_abcmodule_state *)state;
} }
static inline uint64_t
get_invalidation_counter(_abcmodule_state *state)
{
#ifdef Py_GIL_DISABLED
return _Py_atomic_load_uint64(&state->abc_invalidation_counter);
#else
return state->abc_invalidation_counter;
#endif
}
static inline void
increment_invalidation_counter(_abcmodule_state *state)
{
#ifdef Py_GIL_DISABLED
_Py_atomic_add_uint64(&state->abc_invalidation_counter, 1);
#else
state->abc_invalidation_counter++;
#endif
}
/* This object stores internal state for ABCs. /* This object stores internal state for ABCs.
Note that we can use normal sets for caches, Note that we can use normal sets for caches,
since they are never iterated over. */ since they are never iterated over. */
typedef struct { typedef struct {
PyObject_HEAD PyObject_HEAD
/* These sets of weak references are lazily created. Once created, they
will point to the same sets until the ABCMeta object is destroyed or
cleared, both of which will only happen while the object is visible to a
single thread. */
PyObject *_abc_registry; PyObject *_abc_registry;
PyObject *_abc_cache; /* Normal set of weak references. */ PyObject *_abc_cache;
PyObject *_abc_negative_cache; /* Normal set of weak references. */ PyObject *_abc_negative_cache;
unsigned long long _abc_negative_cache_version; uint64_t _abc_negative_cache_version;
} _abc_data; } _abc_data;
static inline uint64_t
get_cache_version(_abc_data *impl)
{
#ifdef Py_GIL_DISABLED
return _Py_atomic_load_uint64(&impl->_abc_negative_cache_version);
#else
return impl->_abc_negative_cache_version;
#endif
}
static inline void
set_cache_version(_abc_data *impl, uint64_t version)
{
#ifdef Py_GIL_DISABLED
_Py_atomic_store_uint64(&impl->_abc_negative_cache_version, version);
#else
impl->_abc_negative_cache_version = version;
#endif
}
static int static int
abc_data_traverse(_abc_data *self, visitproc visit, void *arg) abc_data_traverse(_abc_data *self, visitproc visit, void *arg)
{ {
@ -90,7 +134,7 @@ abc_data_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
self->_abc_registry = NULL; self->_abc_registry = NULL;
self->_abc_cache = NULL; self->_abc_cache = NULL;
self->_abc_negative_cache = NULL; self->_abc_negative_cache = NULL;
self->_abc_negative_cache_version = state->abc_invalidation_counter; self->_abc_negative_cache_version = get_invalidation_counter(state);
return (PyObject *) self; return (PyObject *) self;
} }
@ -130,8 +174,12 @@ _get_impl(PyObject *module, PyObject *self)
} }
static int static int
_in_weak_set(PyObject *set, PyObject *obj) _in_weak_set(_abc_data *impl, PyObject **pset, PyObject *obj)
{ {
PyObject *set;
Py_BEGIN_CRITICAL_SECTION(impl);
set = *pset;
Py_END_CRITICAL_SECTION();
if (set == NULL || PySet_GET_SIZE(set) == 0) { if (set == NULL || PySet_GET_SIZE(set) == 0) {
return 0; return 0;
} }
@ -168,16 +216,19 @@ static PyMethodDef _destroy_def = {
}; };
static int static int
_add_to_weak_set(PyObject **pset, PyObject *obj) _add_to_weak_set(_abc_data *impl, PyObject **pset, PyObject *obj)
{ {
if (*pset == NULL) { PyObject *set;
*pset = PySet_New(NULL); Py_BEGIN_CRITICAL_SECTION(impl);
if (*pset == NULL) { set = *pset;
return -1; if (set == NULL) {
} set = *pset = PySet_New(NULL);
}
Py_END_CRITICAL_SECTION();
if (set == NULL) {
return -1;
} }
PyObject *set = *pset;
PyObject *ref, *wr; PyObject *ref, *wr;
PyObject *destroy_cb; PyObject *destroy_cb;
wr = PyWeakref_NewRef(set, NULL); wr = PyWeakref_NewRef(set, NULL);
@ -220,7 +271,11 @@ _abc__reset_registry(PyObject *module, PyObject *self)
if (impl == NULL) { if (impl == NULL) {
return NULL; return NULL;
} }
if (impl->_abc_registry != NULL && PySet_Clear(impl->_abc_registry) < 0) { PyObject *registry;
Py_BEGIN_CRITICAL_SECTION(impl);
registry = impl->_abc_registry;
Py_END_CRITICAL_SECTION();
if (registry != NULL && PySet_Clear(registry) < 0) {
Py_DECREF(impl); Py_DECREF(impl);
return NULL; return NULL;
} }
@ -247,13 +302,17 @@ _abc__reset_caches(PyObject *module, PyObject *self)
if (impl == NULL) { if (impl == NULL) {
return NULL; return NULL;
} }
if (impl->_abc_cache != NULL && PySet_Clear(impl->_abc_cache) < 0) { PyObject *cache, *negative_cache;
Py_BEGIN_CRITICAL_SECTION(impl);
cache = impl->_abc_cache;
negative_cache = impl->_abc_negative_cache;
Py_END_CRITICAL_SECTION();
if (cache != NULL && PySet_Clear(cache) < 0) {
Py_DECREF(impl); Py_DECREF(impl);
return NULL; return NULL;
} }
/* also the second cache */ /* also the second cache */
if (impl->_abc_negative_cache != NULL && if (negative_cache != NULL && PySet_Clear(negative_cache) < 0) {
PySet_Clear(impl->_abc_negative_cache) < 0) {
Py_DECREF(impl); Py_DECREF(impl);
return NULL; return NULL;
} }
@ -282,11 +341,14 @@ _abc__get_dump(PyObject *module, PyObject *self)
if (impl == NULL) { if (impl == NULL) {
return NULL; return NULL;
} }
PyObject *res = Py_BuildValue("NNNK", PyObject *res;
PySet_New(impl->_abc_registry), Py_BEGIN_CRITICAL_SECTION(impl);
PySet_New(impl->_abc_cache), res = Py_BuildValue("NNNK",
PySet_New(impl->_abc_negative_cache), PySet_New(impl->_abc_registry),
impl->_abc_negative_cache_version); PySet_New(impl->_abc_cache),
PySet_New(impl->_abc_negative_cache),
get_cache_version(impl));
Py_END_CRITICAL_SECTION();
Py_DECREF(impl); Py_DECREF(impl);
return res; return res;
} }
@ -453,58 +515,29 @@ _abc__abc_init(PyObject *module, PyObject *self)
if (PyType_Check(self)) { if (PyType_Check(self)) {
PyTypeObject *cls = (PyTypeObject *)self; PyTypeObject *cls = (PyTypeObject *)self;
PyObject *dict = _PyType_GetDict(cls); PyObject *dict = _PyType_GetDict(cls);
PyObject *flags = PyDict_GetItemWithError(dict, PyObject *flags = NULL;
&_Py_ID(__abc_tpflags__)); if (PyDict_Pop(dict, &_Py_ID(__abc_tpflags__), &flags) < 0) {
if (flags == NULL) { return NULL;
if (PyErr_Occurred()) {
return NULL;
}
} }
else { if (flags == NULL || !PyLong_CheckExact(flags)) {
if (PyLong_CheckExact(flags)) { Py_XDECREF(flags);
long val = PyLong_AsLong(flags); Py_RETURN_NONE;
if (val == -1 && PyErr_Occurred()) {
return NULL;
}
if ((val & COLLECTION_FLAGS) == COLLECTION_FLAGS) {
PyErr_SetString(PyExc_TypeError, "__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING");
return NULL;
}
((PyTypeObject *)self)->tp_flags |= (val & COLLECTION_FLAGS);
}
if (PyDict_DelItem(dict, &_Py_ID(__abc_tpflags__)) < 0) {
return NULL;
}
} }
long val = PyLong_AsLong(flags);
Py_DECREF(flags);
if (val == -1 && PyErr_Occurred()) {
return NULL;
}
if ((val & COLLECTION_FLAGS) == COLLECTION_FLAGS) {
PyErr_SetString(PyExc_TypeError, "__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING");
return NULL;
}
_PyType_SetFlags((PyTypeObject *)self, 0, val & COLLECTION_FLAGS);
} }
Py_RETURN_NONE; Py_RETURN_NONE;
} }
static void
set_collection_flag_recursive(PyTypeObject *child, unsigned long flag)
{
assert(flag == Py_TPFLAGS_MAPPING || flag == Py_TPFLAGS_SEQUENCE);
if (PyType_HasFeature(child, Py_TPFLAGS_IMMUTABLETYPE) ||
(child->tp_flags & COLLECTION_FLAGS) == flag)
{
return;
}
child->tp_flags &= ~COLLECTION_FLAGS;
child->tp_flags |= flag;
PyObject *grandchildren = _PyType_GetSubclasses(child);
if (grandchildren == NULL) {
return;
}
for (Py_ssize_t i = 0; i < PyList_GET_SIZE(grandchildren); i++) {
PyObject *grandchild = PyList_GET_ITEM(grandchildren, i);
set_collection_flag_recursive((PyTypeObject *)grandchild, flag);
}
Py_DECREF(grandchildren);
}
/*[clinic input] /*[clinic input]
_abc._abc_register _abc._abc_register
@ -545,20 +578,23 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
if (impl == NULL) { if (impl == NULL) {
return NULL; return NULL;
} }
if (_add_to_weak_set(&impl->_abc_registry, subclass) < 0) { if (_add_to_weak_set(impl, &impl->_abc_registry, subclass) < 0) {
Py_DECREF(impl); Py_DECREF(impl);
return NULL; return NULL;
} }
Py_DECREF(impl); Py_DECREF(impl);
/* Invalidate negative cache */ /* Invalidate negative cache */
get_abc_state(module)->abc_invalidation_counter++; increment_invalidation_counter(get_abc_state(module));
/* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */ /* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */
if (PyType_Check(self)) { if (PyType_Check(self)) {
unsigned long collection_flag = ((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS; unsigned long collection_flag =
PyType_GetFlags((PyTypeObject *)self) & COLLECTION_FLAGS;
if (collection_flag) { if (collection_flag) {
set_collection_flag_recursive((PyTypeObject *)subclass, collection_flag); _PyType_SetFlagsRecursive((PyTypeObject *)subclass,
COLLECTION_FLAGS,
collection_flag);
} }
} }
return Py_NewRef(subclass); return Py_NewRef(subclass);
@ -592,7 +628,7 @@ _abc__abc_instancecheck_impl(PyObject *module, PyObject *self,
return NULL; return NULL;
} }
/* Inline the cache checking. */ /* Inline the cache checking. */
int incache = _in_weak_set(impl->_abc_cache, subclass); int incache = _in_weak_set(impl, &impl->_abc_cache, subclass);
if (incache < 0) { if (incache < 0) {
goto end; goto end;
} }
@ -602,8 +638,8 @@ _abc__abc_instancecheck_impl(PyObject *module, PyObject *self,
} }
subtype = (PyObject *)Py_TYPE(instance); subtype = (PyObject *)Py_TYPE(instance);
if (subtype == subclass) { if (subtype == subclass) {
if (impl->_abc_negative_cache_version == get_abc_state(module)->abc_invalidation_counter) { if (get_cache_version(impl) == get_invalidation_counter(get_abc_state(module))) {
incache = _in_weak_set(impl->_abc_negative_cache, subclass); incache = _in_weak_set(impl, &impl->_abc_negative_cache, subclass);
if (incache < 0) { if (incache < 0) {
goto end; goto end;
} }
@ -681,7 +717,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
} }
/* 1. Check cache. */ /* 1. Check cache. */
incache = _in_weak_set(impl->_abc_cache, subclass); incache = _in_weak_set(impl, &impl->_abc_cache, subclass);
if (incache < 0) { if (incache < 0) {
goto end; goto end;
} }
@ -692,17 +728,20 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
state = get_abc_state(module); state = get_abc_state(module);
/* 2. Check negative cache; may have to invalidate. */ /* 2. Check negative cache; may have to invalidate. */
if (impl->_abc_negative_cache_version < state->abc_invalidation_counter) { uint64_t invalidation_counter = get_invalidation_counter(state);
if (get_cache_version(impl) < invalidation_counter) {
/* Invalidate the negative cache. */ /* Invalidate the negative cache. */
if (impl->_abc_negative_cache != NULL && PyObject *negative_cache;
PySet_Clear(impl->_abc_negative_cache) < 0) Py_BEGIN_CRITICAL_SECTION(impl);
{ negative_cache = impl->_abc_negative_cache;
Py_END_CRITICAL_SECTION();
if (negative_cache != NULL && PySet_Clear(negative_cache) < 0) {
goto end; goto end;
} }
impl->_abc_negative_cache_version = state->abc_invalidation_counter; set_cache_version(impl, invalidation_counter);
} }
else { else {
incache = _in_weak_set(impl->_abc_negative_cache, subclass); incache = _in_weak_set(impl, &impl->_abc_negative_cache, subclass);
if (incache < 0) { if (incache < 0) {
goto end; goto end;
} }
@ -720,7 +759,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
} }
if (ok == Py_True) { if (ok == Py_True) {
Py_DECREF(ok); Py_DECREF(ok);
if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) { if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
goto end; goto end;
} }
result = Py_True; result = Py_True;
@ -728,7 +767,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
} }
if (ok == Py_False) { if (ok == Py_False) {
Py_DECREF(ok); Py_DECREF(ok);
if (_add_to_weak_set(&impl->_abc_negative_cache, subclass) < 0) { if (_add_to_weak_set(impl, &impl->_abc_negative_cache, subclass) < 0) {
goto end; goto end;
} }
result = Py_False; result = Py_False;
@ -744,7 +783,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
/* 4. Check if it's a direct subclass. */ /* 4. Check if it's a direct subclass. */
if (PyType_IsSubtype((PyTypeObject *)subclass, (PyTypeObject *)self)) { if (PyType_IsSubtype((PyTypeObject *)subclass, (PyTypeObject *)self)) {
if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) { if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
goto end; goto end;
} }
result = Py_True; result = Py_True;
@ -767,12 +806,14 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
goto end; goto end;
} }
for (pos = 0; pos < PyList_GET_SIZE(subclasses); pos++) { for (pos = 0; pos < PyList_GET_SIZE(subclasses); pos++) {
PyObject *scls = PyList_GET_ITEM(subclasses, pos); PyObject *scls = PyList_GetItemRef(subclasses, pos);
Py_INCREF(scls); if (scls == NULL) {
goto end;
}
int r = PyObject_IsSubclass(subclass, scls); int r = PyObject_IsSubclass(subclass, scls);
Py_DECREF(scls); Py_DECREF(scls);
if (r > 0) { if (r > 0) {
if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) { if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
goto end; goto end;
} }
result = Py_True; result = Py_True;
@ -784,7 +825,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
} }
/* No dice; update negative cache. */ /* No dice; update negative cache. */
if (_add_to_weak_set(&impl->_abc_negative_cache, subclass) < 0) { if (_add_to_weak_set(impl, &impl->_abc_negative_cache, subclass) < 0) {
goto end; goto end;
} }
result = Py_False; result = Py_False;
@ -801,7 +842,7 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass,
PyObject **result) PyObject **result)
{ {
// Fast path: check subclass is in weakref directly. // Fast path: check subclass is in weakref directly.
int ret = _in_weak_set(impl->_abc_registry, subclass); int ret = _in_weak_set(impl, &impl->_abc_registry, subclass);
if (ret < 0) { if (ret < 0) {
*result = NULL; *result = NULL;
return -1; return -1;
@ -811,33 +852,27 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass,
return 1; return 1;
} }
if (impl->_abc_registry == NULL) { PyObject *registry_shared;
Py_BEGIN_CRITICAL_SECTION(impl);
registry_shared = impl->_abc_registry;
Py_END_CRITICAL_SECTION();
if (registry_shared == NULL) {
return 0; return 0;
} }
Py_ssize_t registry_size = PySet_Size(impl->_abc_registry);
if (registry_size == 0) { // Make a local copy of the registry to protect against concurrent
return 0; // modifications of _abc_registry.
} PyObject *registry = PySet_New(registry_shared);
// Weakref callback may remove entry from set. if (registry == NULL) {
// So we take snapshot of registry first.
PyObject **copy = PyMem_Malloc(sizeof(PyObject*) * registry_size);
if (copy == NULL) {
PyErr_NoMemory();
return -1; return -1;
} }
PyObject *key; PyObject *key;
Py_ssize_t pos = 0; Py_ssize_t pos = 0;
Py_hash_t hash; Py_hash_t hash;
Py_ssize_t i = 0;
while (_PySet_NextEntry(impl->_abc_registry, &pos, &key, &hash)) { while (_PySet_NextEntry(registry, &pos, &key, &hash)) {
copy[i++] = Py_NewRef(key);
}
assert(i == registry_size);
for (i = 0; i < registry_size; i++) {
PyObject *rkey; PyObject *rkey;
if (PyWeakref_GetRef(copy[i], &rkey) < 0) { if (PyWeakref_GetRef(key, &rkey) < 0) {
// Someone inject non-weakref type in the registry. // Someone inject non-weakref type in the registry.
ret = -1; ret = -1;
break; break;
@ -853,7 +888,7 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass,
break; break;
} }
if (r > 0) { if (r > 0) {
if (_add_to_weak_set(&impl->_abc_cache, subclass) < 0) { if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
ret = -1; ret = -1;
break; break;
} }
@ -863,10 +898,7 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass,
} }
} }
for (i = 0; i < registry_size; i++) { Py_DECREF(registry);
Py_DECREF(copy[i]);
}
PyMem_Free(copy);
return ret; return ret;
} }
@ -885,7 +917,7 @@ _abc_get_cache_token_impl(PyObject *module)
/*[clinic end generated code: output=c7d87841e033dacc input=70413d1c423ad9f9]*/ /*[clinic end generated code: output=c7d87841e033dacc input=70413d1c423ad9f9]*/
{ {
_abcmodule_state *state = get_abc_state(module); _abcmodule_state *state = get_abc_state(module);
return PyLong_FromUnsignedLongLong(state->abc_invalidation_counter); return PyLong_FromUnsignedLongLong(get_invalidation_counter(state));
} }
static struct PyMethodDef _abcmodule_methods[] = { static struct PyMethodDef _abcmodule_methods[] = {

View File

@ -5117,6 +5117,52 @@ _PyType_LookupId(PyTypeObject *type, _Py_Identifier *name)
return _PyType_Lookup(type, oname); return _PyType_Lookup(type, oname);
} }
static void
set_flags(PyTypeObject *self, unsigned long mask, unsigned long flags)
{
ASSERT_TYPE_LOCK_HELD();
self->tp_flags = (self->tp_flags & ~mask) | flags;
}
void
_PyType_SetFlags(PyTypeObject *self, unsigned long mask, unsigned long flags)
{
BEGIN_TYPE_LOCK();
set_flags(self, mask, flags);
END_TYPE_LOCK();
}
static void
set_flags_recursive(PyTypeObject *self, unsigned long mask, unsigned long flags)
{
if (PyType_HasFeature(self, Py_TPFLAGS_IMMUTABLETYPE) ||
(self->tp_flags & mask) == flags)
{
return;
}
set_flags(self, mask, flags);
PyObject *children = _PyType_GetSubclasses(self);
if (children == NULL) {
return;
}
for (Py_ssize_t i = 0; i < PyList_GET_SIZE(children); i++) {
PyObject *child = PyList_GET_ITEM(children, i);
set_flags_recursive((PyTypeObject *)child, mask, flags);
}
Py_DECREF(children);
}
void
_PyType_SetFlagsRecursive(PyTypeObject *self, unsigned long mask, unsigned long flags)
{
BEGIN_TYPE_LOCK();
set_flags_recursive(self, mask, flags);
END_TYPE_LOCK();
}
/* This is similar to PyObject_GenericGetAttr(), /* This is similar to PyObject_GenericGetAttr(),
but uses _PyType_Lookup() instead of just looking in type->tp_dict. but uses _PyType_Lookup() instead of just looking in type->tp_dict.