gh-108240: Add _PyCapsule_SetTraverse() internal function (#108339)

The _socket extension uses _PyCapsule_SetTraverse() to visit and clear
the socket type in the garbage collector. So the _socket.socket type
can be cleared in some corner cases when it wasn't possible before.
This commit is contained in:
Victor Stinner 2023-08-24 00:19:11 +02:00 committed by GitHub
parent b6be18812c
commit 513c89d012
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 136 additions and 73 deletions

View File

@ -48,6 +48,10 @@ PyAPI_FUNC(int) PyCapsule_SetName(PyObject *capsule, const char *name);
PyAPI_FUNC(int) PyCapsule_SetContext(PyObject *capsule, void *context); PyAPI_FUNC(int) PyCapsule_SetContext(PyObject *capsule, void *context);
#ifdef Py_BUILD_CORE
PyAPI_FUNC(int) _PyCapsule_SetTraverse(PyObject *op, traverseproc traverse_func, inquiry clear_func);
#endif
PyAPI_FUNC(void *) PyCapsule_Import( PyAPI_FUNC(void *) PyCapsule_Import(
const char *name, /* UTF-8 encoded string */ const char *name, /* UTF-8 encoded string */
int no_block); int no_block);

View File

@ -7314,20 +7314,39 @@ os_init(void)
} }
#endif #endif
static void static int
sock_free_api(PySocketModule_APIObject *capi) sock_capi_traverse(PyObject *capsule, visitproc visit, void *arg)
{ {
Py_DECREF(capi->Sock_Type); PySocketModule_APIObject *capi = PyCapsule_GetPointer(capsule, PySocket_CAPSULE_NAME);
assert(capi != NULL);
Py_VISIT(capi->Sock_Type);
return 0;
}
static int
sock_capi_clear(PyObject *capsule)
{
PySocketModule_APIObject *capi = PyCapsule_GetPointer(capsule, PySocket_CAPSULE_NAME);
assert(capi != NULL);
Py_CLEAR(capi->Sock_Type);
return 0;
}
static void
sock_capi_free(PySocketModule_APIObject *capi)
{
Py_XDECREF(capi->Sock_Type); // sock_capi_free() can clear it
Py_DECREF(capi->error); Py_DECREF(capi->error);
Py_DECREF(capi->timeout_error); Py_DECREF(capi->timeout_error);
PyMem_Free(capi); PyMem_Free(capi);
} }
static void static void
sock_destroy_api(PyObject *capsule) sock_capi_destroy(PyObject *capsule)
{ {
void *capi = PyCapsule_GetPointer(capsule, PySocket_CAPSULE_NAME); void *capi = PyCapsule_GetPointer(capsule, PySocket_CAPSULE_NAME);
sock_free_api(capi); assert(capi != NULL);
sock_capi_free(capi);
} }
static PySocketModule_APIObject * static PySocketModule_APIObject *
@ -7432,11 +7451,17 @@ socket_exec(PyObject *m)
} }
PyObject *capsule = PyCapsule_New(capi, PyObject *capsule = PyCapsule_New(capi,
PySocket_CAPSULE_NAME, PySocket_CAPSULE_NAME,
sock_destroy_api); sock_capi_destroy);
if (capsule == NULL) { if (capsule == NULL) {
sock_free_api(capi); sock_capi_free(capi);
goto error; goto error;
} }
if (_PyCapsule_SetTraverse(capsule,
sock_capi_traverse, sock_capi_clear) < 0) {
sock_capi_free(capi);
goto error;
}
if (PyModule_Add(m, PySocket_CAPI_NAME, capsule) < 0) { if (PyModule_Add(m, PySocket_CAPI_NAME, capsule) < 0) {
goto error; goto error;
} }

View File

@ -9,18 +9,28 @@ typedef struct {
const char *name; const char *name;
void *context; void *context;
PyCapsule_Destructor destructor; PyCapsule_Destructor destructor;
traverseproc traverse_func;
inquiry clear_func;
} PyCapsule; } PyCapsule;
static int static int
_is_legal_capsule(PyCapsule *capsule, const char *invalid_capsule) _is_legal_capsule(PyObject *op, const char *invalid_capsule)
{ {
if (!capsule || !PyCapsule_CheckExact(capsule) || capsule->pointer == NULL) { if (!op || !PyCapsule_CheckExact(op)) {
PyErr_SetString(PyExc_ValueError, invalid_capsule); goto error;
return 0; }
PyCapsule *capsule = (PyCapsule *)op;
if (capsule->pointer == NULL) {
goto error;
} }
return 1; return 1;
error:
PyErr_SetString(PyExc_ValueError, invalid_capsule);
return 0;
} }
#define is_legal_capsule(capsule, name) \ #define is_legal_capsule(capsule, name) \
@ -50,7 +60,7 @@ PyCapsule_New(void *pointer, const char *name, PyCapsule_Destructor destructor)
return NULL; return NULL;
} }
capsule = PyObject_New(PyCapsule, &PyCapsule_Type); capsule = PyObject_GC_New(PyCapsule, &PyCapsule_Type);
if (capsule == NULL) { if (capsule == NULL) {
return NULL; return NULL;
} }
@ -59,15 +69,18 @@ PyCapsule_New(void *pointer, const char *name, PyCapsule_Destructor destructor)
capsule->name = name; capsule->name = name;
capsule->context = NULL; capsule->context = NULL;
capsule->destructor = destructor; capsule->destructor = destructor;
capsule->traverse_func = NULL;
capsule->clear_func = NULL;
// Only track the capsule if _PyCapsule_SetTraverse() is called
return (PyObject *)capsule; return (PyObject *)capsule;
} }
int int
PyCapsule_IsValid(PyObject *o, const char *name) PyCapsule_IsValid(PyObject *op, const char *name)
{ {
PyCapsule *capsule = (PyCapsule *)o; PyCapsule *capsule = (PyCapsule *)op;
return (capsule != NULL && return (capsule != NULL &&
PyCapsule_CheckExact(capsule) && PyCapsule_CheckExact(capsule) &&
@ -77,13 +90,12 @@ PyCapsule_IsValid(PyObject *o, const char *name)
void * void *
PyCapsule_GetPointer(PyObject *o, const char *name) PyCapsule_GetPointer(PyObject *op, const char *name)
{ {
PyCapsule *capsule = (PyCapsule *)o; if (!is_legal_capsule(op, "PyCapsule_GetPointer")) {
if (!is_legal_capsule(capsule, "PyCapsule_GetPointer")) {
return NULL; return NULL;
} }
PyCapsule *capsule = (PyCapsule *)op;
if (!name_matches(name, capsule->name)) { if (!name_matches(name, capsule->name)) {
PyErr_SetString(PyExc_ValueError, "PyCapsule_GetPointer called with incorrect name"); PyErr_SetString(PyExc_ValueError, "PyCapsule_GetPointer called with incorrect name");
@ -95,68 +107,63 @@ PyCapsule_GetPointer(PyObject *o, const char *name)
const char * const char *
PyCapsule_GetName(PyObject *o) PyCapsule_GetName(PyObject *op)
{ {
PyCapsule *capsule = (PyCapsule *)o; if (!is_legal_capsule(op, "PyCapsule_GetName")) {
if (!is_legal_capsule(capsule, "PyCapsule_GetName")) {
return NULL; return NULL;
} }
PyCapsule *capsule = (PyCapsule *)op;
return capsule->name; return capsule->name;
} }
PyCapsule_Destructor PyCapsule_Destructor
PyCapsule_GetDestructor(PyObject *o) PyCapsule_GetDestructor(PyObject *op)
{ {
PyCapsule *capsule = (PyCapsule *)o; if (!is_legal_capsule(op, "PyCapsule_GetDestructor")) {
if (!is_legal_capsule(capsule, "PyCapsule_GetDestructor")) {
return NULL; return NULL;
} }
PyCapsule *capsule = (PyCapsule *)op;
return capsule->destructor; return capsule->destructor;
} }
void * void *
PyCapsule_GetContext(PyObject *o) PyCapsule_GetContext(PyObject *op)
{ {
PyCapsule *capsule = (PyCapsule *)o; if (!is_legal_capsule(op, "PyCapsule_GetContext")) {
if (!is_legal_capsule(capsule, "PyCapsule_GetContext")) {
return NULL; return NULL;
} }
PyCapsule *capsule = (PyCapsule *)op;
return capsule->context; return capsule->context;
} }
int int
PyCapsule_SetPointer(PyObject *o, void *pointer) PyCapsule_SetPointer(PyObject *op, void *pointer)
{ {
PyCapsule *capsule = (PyCapsule *)o; if (!is_legal_capsule(op, "PyCapsule_SetPointer")) {
return -1;
}
PyCapsule *capsule = (PyCapsule *)op;
if (!pointer) { if (!pointer) {
PyErr_SetString(PyExc_ValueError, "PyCapsule_SetPointer called with null pointer"); PyErr_SetString(PyExc_ValueError, "PyCapsule_SetPointer called with null pointer");
return -1; return -1;
} }
if (!is_legal_capsule(capsule, "PyCapsule_SetPointer")) {
return -1;
}
capsule->pointer = pointer; capsule->pointer = pointer;
return 0; return 0;
} }
int int
PyCapsule_SetName(PyObject *o, const char *name) PyCapsule_SetName(PyObject *op, const char *name)
{ {
PyCapsule *capsule = (PyCapsule *)o; if (!is_legal_capsule(op, "PyCapsule_SetName")) {
if (!is_legal_capsule(capsule, "PyCapsule_SetName")) {
return -1; return -1;
} }
PyCapsule *capsule = (PyCapsule *)op;
capsule->name = name; capsule->name = name;
return 0; return 0;
@ -164,13 +171,12 @@ PyCapsule_SetName(PyObject *o, const char *name)
int int
PyCapsule_SetDestructor(PyObject *o, PyCapsule_Destructor destructor) PyCapsule_SetDestructor(PyObject *op, PyCapsule_Destructor destructor)
{ {
PyCapsule *capsule = (PyCapsule *)o; if (!is_legal_capsule(op, "PyCapsule_SetDestructor")) {
if (!is_legal_capsule(capsule, "PyCapsule_SetDestructor")) {
return -1; return -1;
} }
PyCapsule *capsule = (PyCapsule *)op;
capsule->destructor = destructor; capsule->destructor = destructor;
return 0; return 0;
@ -178,19 +184,36 @@ PyCapsule_SetDestructor(PyObject *o, PyCapsule_Destructor destructor)
int int
PyCapsule_SetContext(PyObject *o, void *context) PyCapsule_SetContext(PyObject *op, void *context)
{ {
PyCapsule *capsule = (PyCapsule *)o; if (!is_legal_capsule(op, "PyCapsule_SetContext")) {
if (!is_legal_capsule(capsule, "PyCapsule_SetContext")) {
return -1; return -1;
} }
PyCapsule *capsule = (PyCapsule *)op;
capsule->context = context; capsule->context = context;
return 0; return 0;
} }
int
_PyCapsule_SetTraverse(PyObject *op, traverseproc traverse_func, inquiry clear_func)
{
if (!is_legal_capsule(op, "_PyCapsule_SetTraverse")) {
return -1;
}
PyCapsule *capsule = (PyCapsule *)op;
if (!PyObject_GC_IsTracked(op)) {
PyObject_GC_Track(op);
}
capsule->traverse_func = traverse_func;
capsule->clear_func = clear_func;
return 0;
}
void * void *
PyCapsule_Import(const char *name, int no_block) PyCapsule_Import(const char *name, int no_block)
{ {
@ -249,13 +272,14 @@ EXIT:
static void static void
capsule_dealloc(PyObject *o) capsule_dealloc(PyObject *op)
{ {
PyCapsule *capsule = (PyCapsule *)o; PyCapsule *capsule = (PyCapsule *)op;
PyObject_GC_UnTrack(op);
if (capsule->destructor) { if (capsule->destructor) {
capsule->destructor(o); capsule->destructor(op);
} }
PyObject_Free(o); PyObject_GC_Del(op);
} }
@ -279,6 +303,29 @@ capsule_repr(PyObject *o)
} }
static int
capsule_traverse(PyCapsule *capsule, visitproc visit, void *arg)
{
if (capsule->traverse_func) {
return capsule->traverse_func((PyObject*)capsule, visit, arg);
}
else {
return 0;
}
}
static int
capsule_clear(PyCapsule *capsule)
{
if (capsule->clear_func) {
return capsule->clear_func((PyObject*)capsule);
}
else {
return 0;
}
}
PyDoc_STRVAR(PyCapsule_Type__doc__, PyDoc_STRVAR(PyCapsule_Type__doc__,
"Capsule objects let you wrap a C \"void *\" pointer in a Python\n\ "Capsule objects let you wrap a C \"void *\" pointer in a Python\n\
@ -293,27 +340,14 @@ Python import mechanism to link to one another.\n\
PyTypeObject PyCapsule_Type = { PyTypeObject PyCapsule_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0) PyVarObject_HEAD_INIT(&PyType_Type, 0)
"PyCapsule", /*tp_name*/ .tp_name = "PyCapsule",
sizeof(PyCapsule), /*tp_basicsize*/ .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
0, /*tp_itemsize*/ .tp_basicsize = sizeof(PyCapsule),
/* methods */ .tp_dealloc = capsule_dealloc,
capsule_dealloc, /*tp_dealloc*/ .tp_repr = capsule_repr,
0, /*tp_vectorcall_offset*/ .tp_doc = PyCapsule_Type__doc__,
0, /*tp_getattr*/ .tp_traverse = (traverseproc)capsule_traverse,
0, /*tp_setattr*/ .tp_clear = (inquiry)capsule_clear,
0, /*tp_as_async*/
capsule_repr, /*tp_repr*/
0, /*tp_as_number*/
0, /*tp_as_sequence*/
0, /*tp_as_mapping*/
0, /*tp_hash*/
0, /*tp_call*/
0, /*tp_str*/
0, /*tp_getattro*/
0, /*tp_setattro*/
0, /*tp_as_buffer*/
0, /*tp_flags*/
PyCapsule_Type__doc__ /*tp_doc*/
}; };