From 94444ea45a86ae46e0c11d1917c43c4c271018cd Mon Sep 17 00:00:00 2001 From: Donghee Na Date: Fri, 19 Apr 2024 00:18:22 +0900 Subject: [PATCH] gh-112069: Add _PySet_NextEntryRef to be thread-safe. (gh-117990) --- Include/internal/pycore_setobject.h | 9 +++++++- Modules/_abc.c | 2 +- Modules/_pickle.c | 33 +++++++++++++++++------------ Modules/_testinternalcapi/set.c | 10 ++++++--- Objects/dictobject.c | 5 +++-- Objects/listobject.c | 3 +-- Objects/setobject.c | 18 +++++++++++++++- Python/marshal.c | 29 +++++++++++++++---------- Python/pylifecycle.c | 1 + 9 files changed, 76 insertions(+), 34 deletions(-) diff --git a/Include/internal/pycore_setobject.h b/Include/internal/pycore_setobject.h index c4ec3ceb17e..41b351ead25 100644 --- a/Include/internal/pycore_setobject.h +++ b/Include/internal/pycore_setobject.h @@ -8,13 +8,20 @@ extern "C" { # error "this header requires Py_BUILD_CORE define" #endif -// Export for '_pickle' shared extension +// Export for '_abc' shared extension PyAPI_FUNC(int) _PySet_NextEntry( PyObject *set, Py_ssize_t *pos, PyObject **key, Py_hash_t *hash); +// Export for '_pickle' shared extension +PyAPI_FUNC(int) _PySet_NextEntryRef( + PyObject *set, + Py_ssize_t *pos, + PyObject **key, + Py_hash_t *hash); + // Export for '_pickle' shared extension PyAPI_FUNC(int) _PySet_Update(PyObject *set, PyObject *iterable); diff --git a/Modules/_abc.c b/Modules/_abc.c index ad28035843f..f2a523e6f2f 100644 --- a/Modules/_abc.c +++ b/Modules/_abc.c @@ -862,7 +862,7 @@ subclasscheck_check_registry(_abc_data *impl, PyObject *subclass, // Make a local copy of the registry to protect against concurrent // modifications of _abc_registry. 
- PyObject *registry = PySet_New(registry_shared); + PyObject *registry = PyFrozenSet_New(registry_shared); if (registry == NULL) { return -1; } diff --git a/Modules/_pickle.c b/Modules/_pickle.c index 0d832611681..d7ffb04c28c 100644 --- a/Modules/_pickle.c +++ b/Modules/_pickle.c @@ -9,15 +9,16 @@ #endif #include "Python.h" -#include "pycore_bytesobject.h" // _PyBytesWriter -#include "pycore_ceval.h" // _Py_EnterRecursiveCall() -#include "pycore_long.h" // _PyLong_AsByteArray() -#include "pycore_moduleobject.h" // _PyModule_GetState() -#include "pycore_object.h" // _PyNone_Type -#include "pycore_pystate.h" // _PyThreadState_GET() -#include "pycore_runtime.h" // _Py_ID() -#include "pycore_setobject.h" // _PySet_NextEntry() -#include "pycore_sysmodule.h" // _PySys_GetAttr() +#include "pycore_bytesobject.h" // _PyBytesWriter +#include "pycore_ceval.h" // _Py_EnterRecursiveCall() +#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION() +#include "pycore_long.h" // _PyLong_AsByteArray() +#include "pycore_moduleobject.h" // _PyModule_GetState() +#include "pycore_object.h" // _PyNone_Type +#include "pycore_pystate.h" // _PyThreadState_GET() +#include "pycore_runtime.h" // _Py_ID() +#include "pycore_setobject.h" // _PySet_NextEntry() +#include "pycore_sysmodule.h" // _PySys_GetAttr() #include <stdlib.h> // strtol() @@ -3413,15 +3414,21 @@ save_set(PickleState *state, PicklerObject *self, PyObject *obj) i = 0; if (_Pickler_Write(self, &mark_op, 1) < 0) return -1; - while (_PySet_NextEntry(obj, &ppos, &item, &hash)) { - Py_INCREF(item); - int err = save(state, self, item, 0); + + int err = 0; + Py_BEGIN_CRITICAL_SECTION(obj); + while (_PySet_NextEntryRef(obj, &ppos, &item, &hash)) { + err = save(state, self, item, 0); Py_CLEAR(item); if (err < 0) - return -1; + break; if (++i == BATCHSIZE) break; } + Py_END_CRITICAL_SECTION(); + if (err < 0) { + return -1; + } if (_Pickler_Write(self, &additems_op, 1) < 0) return -1; if (PySet_GET_SIZE(obj) != set_size) { diff --git 
a/Modules/_testinternalcapi/set.c b/Modules/_testinternalcapi/set.c index 0305a7885d2..01aab03cc10 100644 --- a/Modules/_testinternalcapi/set.c +++ b/Modules/_testinternalcapi/set.c @@ -1,6 +1,7 @@ #include "parts.h" #include "../_testcapi/util.h" // NULLABLE, RETURN_INT +#include "pycore_critical_section.h" #include "pycore_setobject.h" @@ -27,10 +28,13 @@ set_next_entry(PyObject *self, PyObject *args) return NULL; } NULLABLE(set); - - rc = _PySet_NextEntry(set, &pos, &item, &hash); + Py_BEGIN_CRITICAL_SECTION(set); + rc = _PySet_NextEntryRef(set, &pos, &item, &hash); + Py_END_CRITICAL_SECTION(); if (rc == 1) { - return Py_BuildValue("innO", rc, pos, hash, item); + PyObject *ret = Py_BuildValue("innO", rc, pos, hash, item); + Py_DECREF(item); + return ret; } assert(item == UNINITIALIZED_PTR); assert(hash == (Py_hash_t)UNINITIALIZED_SIZE); diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 003a03fd741..58f34c32a87 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -2979,8 +2979,9 @@ dict_set_fromkeys(PyInterpreterState *interp, PyDictObject *mp, return NULL; } - while (_PySet_NextEntry(iterable, &pos, &key, &hash)) { - if (insertdict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value))) { + _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(iterable); + while (_PySet_NextEntryRef(iterable, &pos, &key, &hash)) { + if (insertdict(interp, mp, key, hash, Py_NewRef(value))) { Py_DECREF(mp); return NULL; } diff --git a/Objects/listobject.c b/Objects/listobject.c index 472c471d996..4eaf20033fa 100644 --- a/Objects/listobject.c +++ b/Objects/listobject.c @@ -1287,8 +1287,7 @@ list_extend_set(PyListObject *self, PySetObject *other) Py_hash_t hash; PyObject *key; PyObject **dest = self->ob_item + m; - while (_PySet_NextEntry((PyObject *)other, &setpos, &key, &hash)) { - Py_INCREF(key); + while (_PySet_NextEntryRef((PyObject *)other, &setpos, &key, &hash)) { FT_ATOMIC_STORE_PTR_RELEASE(*dest, key); dest++; } diff --git a/Objects/setobject.c 
b/Objects/setobject.c index 66ca80e8fc2..7af0ae166f9 100644 --- a/Objects/setobject.c +++ b/Objects/setobject.c @@ -2661,7 +2661,6 @@ PySet_Add(PyObject *anyset, PyObject *key) return rv; } -// TODO: Make thread-safe in free-threaded builds int _PySet_NextEntry(PyObject *set, Py_ssize_t *pos, PyObject **key, Py_hash_t *hash) { @@ -2678,6 +2677,23 @@ _PySet_NextEntry(PyObject *set, Py_ssize_t *pos, PyObject **key, Py_hash_t *hash return 1; } +int +_PySet_NextEntryRef(PyObject *set, Py_ssize_t *pos, PyObject **key, Py_hash_t *hash) +{ + setentry *entry; + + if (!PyAnySet_Check(set)) { + PyErr_BadInternalCall(); + return -1; + } + _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(set); + if (set_next((PySetObject *)set, pos, &entry) == 0) + return 0; + *key = Py_NewRef(entry->key); + *hash = entry->hash; + return 1; +} + PyObject * PySet_Pop(PyObject *set) { diff --git a/Python/marshal.c b/Python/marshal.c index 21d242bbb97..4274f90206b 100644 --- a/Python/marshal.c +++ b/Python/marshal.c @@ -7,12 +7,13 @@ and sharing. 
*/ #include "Python.h" -#include "pycore_call.h" // _PyObject_CallNoArgs() -#include "pycore_code.h" // _PyCode_New() -#include "pycore_hashtable.h" // _Py_hashtable_t -#include "pycore_long.h" // _PyLong_DigitCount -#include "pycore_setobject.h" // _PySet_NextEntry() -#include "marshal.h" // Py_MARSHAL_VERSION +#include "pycore_call.h" // _PyObject_CallNoArgs() +#include "pycore_code.h" // _PyCode_New() +#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION() +#include "pycore_hashtable.h" // _Py_hashtable_t +#include "pycore_long.h" // _PyLong_DigitCount +#include "pycore_setobject.h" // _PySet_NextEntry() +#include "marshal.h" // Py_MARSHAL_VERSION #ifdef __APPLE__ # include "TargetConditionals.h" @@ -531,23 +532,29 @@ w_complex_object(PyObject *v, char flag, WFILE *p) return; } Py_ssize_t i = 0; - while (_PySet_NextEntry(v, &pos, &value, &hash)) { + Py_BEGIN_CRITICAL_SECTION(v); + while (_PySet_NextEntryRef(v, &pos, &value, &hash)) { PyObject *dump = _PyMarshal_WriteObjectToString(value, p->version, p->allow_code); if (dump == NULL) { p->error = WFERR_UNMARSHALLABLE; - Py_DECREF(pairs); - return; + Py_DECREF(value); + break; } PyObject *pair = PyTuple_Pack(2, dump, value); Py_DECREF(dump); + Py_DECREF(value); if (pair == NULL) { p->error = WFERR_NOMEMORY; - Py_DECREF(pairs); - return; + break; } PyList_SET_ITEM(pairs, i++, pair); } + Py_END_CRITICAL_SECTION(); + if (p->error == WFERR_UNMARSHALLABLE || p->error == WFERR_NOMEMORY) { + Py_DECREF(pairs); + return; + } assert(i == n); if (PyList_Sort(pairs)) { p->error = WFERR_NOMEMORY; diff --git a/Python/pylifecycle.c b/Python/pylifecycle.c index efb25878312..cc1824634e7 100644 --- a/Python/pylifecycle.c +++ b/Python/pylifecycle.c @@ -2910,6 +2910,7 @@ _Py_DumpExtensionModules(int fd, PyInterpreterState *interp) Py_ssize_t i = 0; PyObject *item; Py_hash_t hash; + // if stdlib_module_names is not NULL, it is always a frozenset. 
while (_PySet_NextEntry(stdlib_module_names, &i, &item, &hash)) { if (PyUnicode_Check(item) && PyUnicode_Compare(key, item) == 0)