From 88567a997005c9388137cd18c5d7f4483423dac3 Mon Sep 17 00:00:00 2001 From: Erlend Egeberg Aasland Date: Thu, 3 Mar 2022 14:54:36 +0100 Subject: [PATCH] bpo-46874: Speed up sqlite3 user-defined aggregate 'step' method (GH-31604) --- Lib/test/test_sqlite3/test_userfunctions.py | 6 ++++-- Modules/_sqlite/connection.c | 10 ++++++---- Modules/_sqlite/module.c | 2 ++ Modules/_sqlite/module.h | 1 + 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/Lib/test/test_sqlite3/test_userfunctions.py b/Lib/test/test_sqlite3/test_userfunctions.py index 23ecfb4e8a6..2588cae3d1f 100644 --- a/Lib/test/test_sqlite3/test_userfunctions.py +++ b/Lib/test/test_sqlite3/test_userfunctions.py @@ -502,11 +502,13 @@ class AggregateTests(unittest.TestCase): with self.assertRaises(sqlite.OperationalError): self.con.create_function("bla", -100, AggrSum) + @with_tracebacks(AttributeError, name="AggrNoStep") def test_aggr_no_step(self): cur = self.con.cursor() - with self.assertRaises(AttributeError) as cm: + with self.assertRaises(sqlite.OperationalError) as cm: cur.execute("select nostep(t) from test") - self.assertEqual(str(cm.exception), "'AggrNoStep' object has no attribute 'step'") + self.assertEqual(str(cm.exception), + "user-defined aggregate's 'step' method not defined") def test_aggr_no_finalize(self): cur = self.con.cursor() diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 0efb5ae35a7..9f12e691f89 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -734,11 +734,11 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params) PyObject** aggregate_instance; PyObject* stepmethod = NULL; - aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*)); + callback_context *ctx = (callback_context *)sqlite3_user_data(context); + assert(ctx != NULL); + aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*)); if (*aggregate_instance == NULL) { - callback_context *ctx = (callback_context *)sqlite3_user_data(context); - assert(ctx != NULL); *aggregate_instance = PyObject_CallNoArgs(ctx->callable); if (!*aggregate_instance) { set_sqlite_error(context, @@ -747,8 +747,10 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params) } } - stepmethod = PyObject_GetAttrString(*aggregate_instance, "step"); + stepmethod = PyObject_GetAttr(*aggregate_instance, ctx->state->str_step); if (!stepmethod) { + set_sqlite_error(context, + "user-defined aggregate's 'step' method not defined"); goto error; } diff --git a/Modules/_sqlite/module.c b/Modules/_sqlite/module.c index 70fde4910f6..563105c6391 100644 --- a/Modules/_sqlite/module.c +++ b/Modules/_sqlite/module.c @@ -627,6 +627,7 @@ module_clear(PyObject *module) Py_CLEAR(state->str___conform__); Py_CLEAR(state->str_executescript); Py_CLEAR(state->str_finalize); + Py_CLEAR(state->str_step); Py_CLEAR(state->str_upper); return 0; @@ -713,6 +714,7 @@ module_exec(PyObject *module) ADD_INTERNED(state, __conform__); ADD_INTERNED(state, executescript); ADD_INTERNED(state, finalize); + ADD_INTERNED(state, step); ADD_INTERNED(state, upper); /* Set error constants */ diff --git a/Modules/_sqlite/module.h b/Modules/_sqlite/module.h index 35c6f385526..cca52d1e04b 100644 --- a/Modules/_sqlite/module.h +++ b/Modules/_sqlite/module.h @@ -64,6 +64,7 @@ typedef struct { PyObject *str___conform__; PyObject *str_executescript; PyObject *str_finalize; + PyObject *str_step; PyObject *str_upper; } pysqlite_state;