gh-117578: Introduce _PyType_GetModuleByDef2 private function (GH-117661)

Co-authored-by: Erlend E. Aasland <erlend.aasland@protonmail.com>
Co-authored-by: Petr Viktorin <encukou@gmail.com>
This commit is contained in:
neonene 2024-04-25 20:51:31 +09:00 committed by GitHub
parent f180b31e76
commit 2c45148912
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 49 additions and 12 deletions

View File

@ -164,6 +164,7 @@ extern PyObject * _PyType_GetBases(PyTypeObject *type);
extern PyObject * _PyType_GetMRO(PyTypeObject *type); extern PyObject * _PyType_GetMRO(PyTypeObject *type);
extern PyObject* _PyType_GetSubclasses(PyTypeObject *); extern PyObject* _PyType_GetSubclasses(PyTypeObject *);
extern int _PyType_HasSubclasses(PyTypeObject *); extern int _PyType_HasSubclasses(PyTypeObject *);
PyAPI_FUNC(PyObject *) _PyType_GetModuleByDef2(PyTypeObject *, PyTypeObject *, PyModuleDef *);
// PyType_Ready() must be called if _PyType_IsReady() is false. // PyType_Ready() must be called if _PyType_IsReady() is false.
// See also the Py_TPFLAGS_READY flag. // See also the Py_TPFLAGS_READY flag.

View File

@ -32,6 +32,7 @@
#include <Python.h> #include <Python.h>
#include "pycore_long.h" // _PyLong_IsZero() #include "pycore_long.h" // _PyLong_IsZero()
#include "pycore_pystate.h" // _PyThreadState_GET() #include "pycore_pystate.h" // _PyThreadState_GET()
#include "pycore_typeobject.h"
#include "complexobject.h" #include "complexobject.h"
#include "mpdecimal.h" #include "mpdecimal.h"
@ -120,11 +121,8 @@ get_module_state_by_def(PyTypeObject *tp)
static inline decimal_state * static inline decimal_state *
find_state_left_or_right(PyObject *left, PyObject *right) find_state_left_or_right(PyObject *left, PyObject *right)
{ {
PyObject *mod = PyType_GetModuleByDef(Py_TYPE(left), &_decimal_module); PyObject *mod = _PyType_GetModuleByDef2(Py_TYPE(left), Py_TYPE(right),
if (mod == NULL) { &_decimal_module);
PyErr_Clear();
mod = PyType_GetModuleByDef(Py_TYPE(right), &_decimal_module);
}
assert(mod != NULL); assert(mod != NULL);
return get_module_state(mod); return get_module_state(mod);
} }

View File

@ -4825,11 +4825,24 @@ PyType_GetModuleState(PyTypeObject *type)
/* Get the module of the first superclass where the module has the /* Get the module of the first superclass where the module has the
* given PyModuleDef. * given PyModuleDef.
*/ */
PyObject * static inline PyObject *
PyType_GetModuleByDef(PyTypeObject *type, PyModuleDef *def) get_module_by_def(PyTypeObject *type, PyModuleDef *def)
{ {
assert(PyType_Check(type)); assert(PyType_Check(type));
if (!_PyType_HasFeature(type, Py_TPFLAGS_HEAPTYPE)) {
// type_ready_mro() ensures that no heap type is
// contained in a static type MRO.
return NULL;
}
else {
PyHeapTypeObject *ht = (PyHeapTypeObject*)type;
PyObject *module = ht->ht_module;
if (module && _PyModule_GetDef(module) == def) {
return module;
}
}
PyObject *res = NULL; PyObject *res = NULL;
BEGIN_TYPE_LOCK() BEGIN_TYPE_LOCK()
@ -4837,12 +4850,14 @@ PyType_GetModuleByDef(PyTypeObject *type, PyModuleDef *def)
// The type must be ready // The type must be ready
assert(mro != NULL); assert(mro != NULL);
assert(PyTuple_Check(mro)); assert(PyTuple_Check(mro));
// mro_invoke() ensures that the type MRO cannot be empty, so we don't have // mro_invoke() ensures that the type MRO cannot be empty.
// to check i < PyTuple_GET_SIZE(mro) at the first loop iteration.
assert(PyTuple_GET_SIZE(mro) >= 1); assert(PyTuple_GET_SIZE(mro) >= 1);
// Also, the first item in the MRO is the type itself, which
// we already checked above. We skip it in the loop.
assert(PyTuple_GET_ITEM(mro, 0) == (PyObject *)type);
Py_ssize_t n = PyTuple_GET_SIZE(mro); Py_ssize_t n = PyTuple_GET_SIZE(mro);
for (Py_ssize_t i = 0; i < n; i++) { for (Py_ssize_t i = 1; i < n; i++) {
PyObject *super = PyTuple_GET_ITEM(mro, i); PyObject *super = PyTuple_GET_ITEM(mro, i);
if(!_PyType_HasFeature((PyTypeObject *)super, Py_TPFLAGS_HEAPTYPE)) { if(!_PyType_HasFeature((PyTypeObject *)super, Py_TPFLAGS_HEAPTYPE)) {
// Static types in the MRO need to be skipped // Static types in the MRO need to be skipped
@ -4857,14 +4872,37 @@ PyType_GetModuleByDef(PyTypeObject *type, PyModuleDef *def)
} }
} }
END_TYPE_LOCK() END_TYPE_LOCK()
return res;
}
if (res == NULL) { PyObject *
PyType_GetModuleByDef(PyTypeObject *type, PyModuleDef *def)
{
PyObject *module = get_module_by_def(type, def);
if (module == NULL) {
PyErr_Format( PyErr_Format(
PyExc_TypeError, PyExc_TypeError,
"PyType_GetModuleByDef: No superclass of '%s' has the given module", "PyType_GetModuleByDef: No superclass of '%s' has the given module",
type->tp_name); type->tp_name);
} }
return res; return module;
}
PyObject *
_PyType_GetModuleByDef2(PyTypeObject *left, PyTypeObject *right,
PyModuleDef *def)
{
PyObject *module = get_module_by_def(left, def);
if (module == NULL) {
module = get_module_by_def(right, def);
if (module == NULL) {
PyErr_Format(
PyExc_TypeError,
"PyType_GetModuleByDef: No superclass of '%s' nor '%s' has "
"the given module", left->tp_name, right->tp_name);
}
}
return module;
} }
void * void *