From 9ed523159c7ba840dbf403e02498eeae1b5d3ed9 Mon Sep 17 00:00:00 2001 From: Erlend Egeberg Aasland Date: Tue, 24 Aug 2021 14:24:09 +0200 Subject: [PATCH] bpo-42064: Pass module state to `sqlite3` UDF callbacks (GH-27456) - Establish common callback context struct - Convert UDF callbacks to fetch module state from callback context --- Modules/_sqlite/connection.c | 89 +++++++++++++++++++++++------------- Modules/_sqlite/connection.h | 6 +++ 2 files changed, 64 insertions(+), 31 deletions(-) diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 0645367988e..8ad5f5f061d 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -612,8 +612,10 @@ set_sqlite_error(sqlite3_context *context, const char *msg) else { sqlite3_result_error(context, msg, -1); } - pysqlite_state *state = pysqlite_get_state(NULL); - if (state->enable_callback_tracebacks) { + callback_context *ctx = (callback_context *)sqlite3_user_data(context); + assert(ctx != NULL); + assert(ctx->state != NULL); + if (ctx->state->enable_callback_tracebacks) { PyErr_Print(); } else { @@ -625,7 +627,6 @@ static void _pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv) { PyObject* args; - PyObject* py_func; PyObject* py_retval = NULL; int ok; @@ -633,11 +634,11 @@ _pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv threadstate = PyGILState_Ensure(); - py_func = (PyObject*)sqlite3_user_data(context); - args = _pysqlite_build_py_params(context, argc, argv); if (args) { - py_retval = PyObject_CallObject(py_func, args); + callback_context *ctx = (callback_context *)sqlite3_user_data(context); + assert(ctx != NULL); + py_retval = PyObject_CallObject(ctx->callable, args); Py_DECREF(args); } @@ -657,7 +658,6 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_ { PyObject* args; PyObject* function_result = NULL; - PyObject* aggregate_class; PyObject** aggregate_instance; PyObject* stepmethod = NULL; @@ -665,12 +665,12 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_ threadstate = PyGILState_Ensure(); - aggregate_class = (PyObject*)sqlite3_user_data(context); - aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*)); if (*aggregate_instance == NULL) { - *aggregate_instance = _PyObject_CallNoArg(aggregate_class); + callback_context *ctx = (callback_context *)sqlite3_user_data(context); + assert(ctx != NULL); + *aggregate_instance = _PyObject_CallNoArg(ctx->callable); if (!*aggregate_instance) { set_sqlite_error(context, "user-defined aggregate's '__init__' method raised error"); @@ -784,14 +784,35 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self) Py_SETREF(self->cursors, new_list); } +static callback_context * +create_callback_context(pysqlite_state *state, PyObject *callable) +{ + PyGILState_STATE gstate = PyGILState_Ensure(); + callback_context *ctx = PyMem_Malloc(sizeof(callback_context)); + if (ctx != NULL) { + ctx->callable = Py_NewRef(callable); + ctx->state = state; + } + PyGILState_Release(gstate); + return ctx; +} + +static void +free_callback_context(callback_context *ctx) +{ + if (ctx != NULL) { + // This function may be called without the GIL held, so we need to + // ensure that we destroy 'ctx' with the GIL held. + PyGILState_STATE gstate = PyGILState_Ensure(); + Py_DECREF(ctx->callable); + PyMem_Free(ctx); + PyGILState_Release(gstate); + } +} + static void _destructor(void* args) { - // This function may be called without the GIL held, so we need to ensure - // that we destroy 'args' with the GIL - PyGILState_STATE gstate; - gstate = PyGILState_Ensure(); - Py_DECREF((PyObject*)args); - PyGILState_Release(gstate); + free_callback_context((callback_context *)args); } /*[clinic input] @@ -833,11 +854,11 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self, flags |= SQLITE_DETERMINISTIC; #endif } - rc = sqlite3_create_function_v2(self->db, - name, - narg, - flags, - (void*)Py_NewRef(func), + callback_context *ctx = create_callback_context(self->state, func); + if (ctx == NULL) { + return NULL; + } + rc = sqlite3_create_function_v2(self->db, name, narg, flags, ctx, _pysqlite_func_callback, NULL, NULL, @@ -873,11 +894,12 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self, return NULL; } - rc = sqlite3_create_function_v2(self->db, - name, - n_arg, - SQLITE_UTF8, - (void*)Py_NewRef(aggregate_class), + callback_context *ctx = create_callback_context(self->state, + aggregate_class); + if (ctx == NULL) { + return NULL; + } + rc = sqlite3_create_function_v2(self->db, name, n_arg, SQLITE_UTF8, ctx, 0, &_pysqlite_step_callback, &_pysqlite_final_callback, @@ -1439,7 +1461,6 @@ pysqlite_collation_callback( int text1_length, const void* text1_data, int text2_length, const void* text2_data) { - PyObject* callback = (PyObject*)context; PyObject* string1 = 0; PyObject* string2 = 0; PyGILState_STATE gilstate; @@ -1459,8 +1480,10 @@ pysqlite_collation_callback( goto finally; /* failed to allocate strings */ } + callback_context *ctx = (callback_context *)context; + assert(ctx != NULL); PyObject *args[] = { string1, string2 }; // Borrowed refs. - retval = PyObject_Vectorcall(callback, args, 2, NULL); + retval = PyObject_Vectorcall(ctx->callable, args, 2, NULL); if (retval == NULL) { /* execution failed */ goto finally; @@ -1690,6 +1713,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self, return NULL; } + callback_context *ctx = NULL; int rc; int flags = SQLITE_UTF8; if (callable == Py_None) { @@ -1701,8 +1725,11 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self, PyErr_SetString(PyExc_TypeError, "parameter must be callable"); return NULL; } - rc = sqlite3_create_collation_v2(self->db, name, flags, - Py_NewRef(callable), + ctx = create_callback_context(self->state, callable); + if (ctx == NULL) { + return NULL; + } + rc = sqlite3_create_collation_v2(self->db, name, flags, ctx, &pysqlite_collation_callback, &_destructor); } @@ -1713,7 +1740,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self, * the context before returning. */ if (callable != Py_None) { - Py_DECREF(callable); + free_callback_context(ctx); } _pysqlite_seterror(self->state, self->db); return NULL; diff --git a/Modules/_sqlite/connection.h b/Modules/_sqlite/connection.h index 4f08a6d5f7b..11b3a8005e1 100644 --- a/Modules/_sqlite/connection.h +++ b/Modules/_sqlite/connection.h @@ -32,6 +32,12 @@ #include "sqlite3.h" +typedef struct _callback_context +{ + PyObject *callable; + pysqlite_state *state; +} callback_context; + typedef struct { PyObject_HEAD