bpo-42064: Pass module state to `sqlite3` UDF callbacks (GH-27456)

- Establish common callback context struct
- Convert UDF callbacks to fetch module state from callback context
This commit is contained in:
Erlend Egeberg Aasland 2021-08-24 14:24:09 +02:00 committed by GitHub
parent 7179930ab5
commit 9ed523159c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 64 additions and 31 deletions

View File

@ -612,8 +612,10 @@ set_sqlite_error(sqlite3_context *context, const char *msg)
else { else {
sqlite3_result_error(context, msg, -1); sqlite3_result_error(context, msg, -1);
} }
pysqlite_state *state = pysqlite_get_state(NULL); callback_context *ctx = (callback_context *)sqlite3_user_data(context);
if (state->enable_callback_tracebacks) { assert(ctx != NULL);
assert(ctx->state != NULL);
if (ctx->state->enable_callback_tracebacks) {
PyErr_Print(); PyErr_Print();
} }
else { else {
@ -625,7 +627,6 @@ static void
_pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv) _pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv)
{ {
PyObject* args; PyObject* args;
PyObject* py_func;
PyObject* py_retval = NULL; PyObject* py_retval = NULL;
int ok; int ok;
@ -633,11 +634,11 @@ _pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv
threadstate = PyGILState_Ensure(); threadstate = PyGILState_Ensure();
py_func = (PyObject*)sqlite3_user_data(context);
args = _pysqlite_build_py_params(context, argc, argv); args = _pysqlite_build_py_params(context, argc, argv);
if (args) { if (args) {
py_retval = PyObject_CallObject(py_func, args); callback_context *ctx = (callback_context *)sqlite3_user_data(context);
assert(ctx != NULL);
py_retval = PyObject_CallObject(ctx->callable, args);
Py_DECREF(args); Py_DECREF(args);
} }
@ -657,7 +658,6 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_
{ {
PyObject* args; PyObject* args;
PyObject* function_result = NULL; PyObject* function_result = NULL;
PyObject* aggregate_class;
PyObject** aggregate_instance; PyObject** aggregate_instance;
PyObject* stepmethod = NULL; PyObject* stepmethod = NULL;
@ -665,12 +665,12 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_
threadstate = PyGILState_Ensure(); threadstate = PyGILState_Ensure();
aggregate_class = (PyObject*)sqlite3_user_data(context);
aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*)); aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*));
if (*aggregate_instance == NULL) { if (*aggregate_instance == NULL) {
*aggregate_instance = _PyObject_CallNoArg(aggregate_class); callback_context *ctx = (callback_context *)sqlite3_user_data(context);
assert(ctx != NULL);
*aggregate_instance = _PyObject_CallNoArg(ctx->callable);
if (!*aggregate_instance) { if (!*aggregate_instance) {
set_sqlite_error(context, set_sqlite_error(context,
"user-defined aggregate's '__init__' method raised error"); "user-defined aggregate's '__init__' method raised error");
@ -784,14 +784,35 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self)
Py_SETREF(self->cursors, new_list); Py_SETREF(self->cursors, new_list);
} }
static callback_context *
create_callback_context(pysqlite_state *state, PyObject *callable)
{
PyGILState_STATE gstate = PyGILState_Ensure();
callback_context *ctx = PyMem_Malloc(sizeof(callback_context));
if (ctx != NULL) {
ctx->callable = Py_NewRef(callable);
ctx->state = state;
}
PyGILState_Release(gstate);
return ctx;
}
static void
free_callback_context(callback_context *ctx)
{
if (ctx != NULL) {
// This function may be called without the GIL held, so we need to
// ensure that we destroy 'ctx' with the GIL held.
PyGILState_STATE gstate = PyGILState_Ensure();
Py_DECREF(ctx->callable);
PyMem_Free(ctx);
PyGILState_Release(gstate);
}
}
static void _destructor(void* args) static void _destructor(void* args)
{ {
// This function may be called without the GIL held, so we need to ensure free_callback_context((callback_context *)args);
// that we destroy 'args' with the GIL
PyGILState_STATE gstate;
gstate = PyGILState_Ensure();
Py_DECREF((PyObject*)args);
PyGILState_Release(gstate);
} }
/*[clinic input] /*[clinic input]
@ -833,11 +854,11 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self,
flags |= SQLITE_DETERMINISTIC; flags |= SQLITE_DETERMINISTIC;
#endif #endif
} }
rc = sqlite3_create_function_v2(self->db, callback_context *ctx = create_callback_context(self->state, func);
name, if (ctx == NULL) {
narg, return NULL;
flags, }
(void*)Py_NewRef(func), rc = sqlite3_create_function_v2(self->db, name, narg, flags, ctx,
_pysqlite_func_callback, _pysqlite_func_callback,
NULL, NULL,
NULL, NULL,
@ -873,11 +894,12 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
return NULL; return NULL;
} }
rc = sqlite3_create_function_v2(self->db, callback_context *ctx = create_callback_context(self->state,
name, aggregate_class);
n_arg, if (ctx == NULL) {
SQLITE_UTF8, return NULL;
(void*)Py_NewRef(aggregate_class), }
rc = sqlite3_create_function_v2(self->db, name, n_arg, SQLITE_UTF8, ctx,
0, 0,
&_pysqlite_step_callback, &_pysqlite_step_callback,
&_pysqlite_final_callback, &_pysqlite_final_callback,
@ -1439,7 +1461,6 @@ pysqlite_collation_callback(
int text1_length, const void* text1_data, int text1_length, const void* text1_data,
int text2_length, const void* text2_data) int text2_length, const void* text2_data)
{ {
PyObject* callback = (PyObject*)context;
PyObject* string1 = 0; PyObject* string1 = 0;
PyObject* string2 = 0; PyObject* string2 = 0;
PyGILState_STATE gilstate; PyGILState_STATE gilstate;
@ -1459,8 +1480,10 @@ pysqlite_collation_callback(
goto finally; /* failed to allocate strings */ goto finally; /* failed to allocate strings */
} }
callback_context *ctx = (callback_context *)context;
assert(ctx != NULL);
PyObject *args[] = { string1, string2 }; // Borrowed refs. PyObject *args[] = { string1, string2 }; // Borrowed refs.
retval = PyObject_Vectorcall(callback, args, 2, NULL); retval = PyObject_Vectorcall(ctx->callable, args, 2, NULL);
if (retval == NULL) { if (retval == NULL) {
/* execution failed */ /* execution failed */
goto finally; goto finally;
@ -1690,6 +1713,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
return NULL; return NULL;
} }
callback_context *ctx = NULL;
int rc; int rc;
int flags = SQLITE_UTF8; int flags = SQLITE_UTF8;
if (callable == Py_None) { if (callable == Py_None) {
@ -1701,8 +1725,11 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
PyErr_SetString(PyExc_TypeError, "parameter must be callable"); PyErr_SetString(PyExc_TypeError, "parameter must be callable");
return NULL; return NULL;
} }
rc = sqlite3_create_collation_v2(self->db, name, flags, ctx = create_callback_context(self->state, callable);
Py_NewRef(callable), if (ctx == NULL) {
return NULL;
}
rc = sqlite3_create_collation_v2(self->db, name, flags, ctx,
&pysqlite_collation_callback, &pysqlite_collation_callback,
&_destructor); &_destructor);
} }
@ -1713,7 +1740,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
* the context before returning. * the context before returning.
*/ */
if (callable != Py_None) { if (callable != Py_None) {
Py_DECREF(callable); free_callback_context(ctx);
} }
_pysqlite_seterror(self->state, self->db); _pysqlite_seterror(self->state, self->db);
return NULL; return NULL;

View File

@ -32,6 +32,12 @@
#include "sqlite3.h" #include "sqlite3.h"
typedef struct _callback_context
{
PyObject *callable;
pysqlite_state *state;
} callback_context;
typedef struct typedef struct
{ {
PyObject_HEAD PyObject_HEAD