mirror of https://github.com/python/cpython
gh-89301: Fix regression with bound values in traced SQLite statements (#92053)
This commit is contained in:
parent
6811bdef63
commit
721aa96540
|
@ -20,12 +20,16 @@
|
|||
# misrepresented as being the original software.
|
||||
# 3. This notice may not be removed or altered from any source distribution.
|
||||
|
||||
import unittest
|
||||
import contextlib
|
||||
import sqlite3 as sqlite
|
||||
import unittest
|
||||
|
||||
from test.support.os_helper import TESTFN, unlink
|
||||
|
||||
from test.test_sqlite3.test_dbapi import memory_database, cx_limit
|
||||
from test.test_sqlite3.test_userfunctions import with_tracebacks
|
||||
|
||||
|
||||
class CollationTests(unittest.TestCase):
|
||||
def test_create_collation_not_string(self):
|
||||
con = sqlite.connect(":memory:")
|
||||
|
@ -224,6 +228,16 @@ class ProgressTests(unittest.TestCase):
|
|||
|
||||
|
||||
class TraceCallbackTests(unittest.TestCase):
|
||||
@contextlib.contextmanager
|
||||
def check_stmt_trace(self, cx, expected):
|
||||
try:
|
||||
traced = []
|
||||
cx.set_trace_callback(lambda stmt: traced.append(stmt))
|
||||
yield
|
||||
finally:
|
||||
self.assertEqual(traced, expected)
|
||||
cx.set_trace_callback(None)
|
||||
|
||||
def test_trace_callback_used(self):
|
||||
"""
|
||||
Test that the trace callback is invoked once it is set.
|
||||
|
@ -289,6 +303,52 @@ class TraceCallbackTests(unittest.TestCase):
|
|||
con2.close()
|
||||
self.assertEqual(traced_statements, queries)
|
||||
|
||||
def test_trace_expanded_sql(self):
|
||||
expected = [
|
||||
"create table t(t)",
|
||||
"BEGIN ",
|
||||
"insert into t values(0)",
|
||||
"insert into t values(1)",
|
||||
"insert into t values(2)",
|
||||
"COMMIT",
|
||||
]
|
||||
with memory_database() as cx, self.check_stmt_trace(cx, expected):
|
||||
with cx:
|
||||
cx.execute("create table t(t)")
|
||||
cx.executemany("insert into t values(?)", ((v,) for v in range(3)))
|
||||
|
||||
@with_tracebacks(
|
||||
sqlite.DataError,
|
||||
regex="Expanded SQL string exceeds the maximum string length"
|
||||
)
|
||||
def test_trace_too_much_expanded_sql(self):
|
||||
# If the expanded string is too large, we'll fall back to the
|
||||
# unexpanded SQL statement (for SQLite 3.14.0 and newer).
|
||||
# The resulting string length is limited by the runtime limit
|
||||
# SQLITE_LIMIT_LENGTH.
|
||||
template = "select 1 as a where a="
|
||||
category = sqlite.SQLITE_LIMIT_LENGTH
|
||||
with memory_database() as cx, cx_limit(cx, category=category) as lim:
|
||||
ok_param = "a"
|
||||
bad_param = "a" * lim
|
||||
|
||||
unexpanded_query = template + "?"
|
||||
expected = [unexpanded_query]
|
||||
if sqlite.sqlite_version_info < (3, 14, 0):
|
||||
expected = []
|
||||
with self.check_stmt_trace(cx, expected):
|
||||
cx.execute(unexpanded_query, (bad_param,))
|
||||
|
||||
expanded_query = f"{template}'{ok_param}'"
|
||||
with self.check_stmt_trace(cx, [expanded_query]):
|
||||
cx.execute(unexpanded_query, (ok_param,))
|
||||
|
||||
@with_tracebacks(ZeroDivisionError, regex="division by zero")
|
||||
def test_trace_bad_handler(self):
|
||||
with memory_database() as cx:
|
||||
cx.set_trace_callback(lambda stmt: 5/0)
|
||||
cx.execute("select 1")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
Fix a regression in the :mod:`sqlite3` trace callback where bound parameters
|
||||
were not expanded in the passed statement string. The regression was introduced
|
||||
in Python 3.10 by :issue:`40318`. Patch by Erlend E. Aasland.
|
|
@ -1332,11 +1332,10 @@ progress_callback(void *ctx)
|
|||
* to ensure future compatibility.
|
||||
*/
|
||||
static int
|
||||
trace_callback(unsigned int type, void *ctx, void *prepared_statement,
|
||||
void *statement_string)
|
||||
trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
|
||||
#else
|
||||
static void
|
||||
trace_callback(void *ctx, const char *statement_string)
|
||||
trace_callback(void *ctx, const char *sql)
|
||||
#endif
|
||||
{
|
||||
#ifdef HAVE_TRACE_V2
|
||||
|
@ -1347,24 +1346,51 @@ trace_callback(void *ctx, const char *statement_string)
|
|||
|
||||
PyGILState_STATE gilstate = PyGILState_Ensure();
|
||||
|
||||
PyObject *py_statement = NULL;
|
||||
PyObject *ret = NULL;
|
||||
py_statement = PyUnicode_DecodeUTF8(statement_string,
|
||||
strlen(statement_string), "replace");
|
||||
assert(ctx != NULL);
|
||||
if (py_statement) {
|
||||
PyObject *callable = ((callback_context *)ctx)->callable;
|
||||
ret = PyObject_CallOneArg(callable, py_statement);
|
||||
Py_DECREF(py_statement);
|
||||
}
|
||||
pysqlite_state *state = ((callback_context *)ctx)->state;
|
||||
assert(state != NULL);
|
||||
|
||||
if (ret) {
|
||||
Py_DECREF(ret);
|
||||
PyObject *py_statement = NULL;
|
||||
#ifdef HAVE_TRACE_V2
|
||||
const char *expanded_sql = sqlite3_expanded_sql((sqlite3_stmt *)stmt);
|
||||
if (expanded_sql == NULL) {
|
||||
sqlite3 *db = sqlite3_db_handle((sqlite3_stmt *)stmt);
|
||||
if (sqlite3_errcode(db) == SQLITE_NOMEM) {
|
||||
(void)PyErr_NoMemory();
|
||||
goto exit;
|
||||
}
|
||||
|
||||
PyErr_SetString(state->DataError,
|
||||
"Expanded SQL string exceeds the maximum string length");
|
||||
print_or_clear_traceback((callback_context *)ctx);
|
||||
|
||||
// Fall back to unexpanded sql
|
||||
py_statement = PyUnicode_FromString((const char *)sql);
|
||||
}
|
||||
else {
|
||||
print_or_clear_traceback(ctx);
|
||||
py_statement = PyUnicode_FromString(expanded_sql);
|
||||
sqlite3_free((void *)expanded_sql);
|
||||
}
|
||||
#else
|
||||
if (sql == NULL) {
|
||||
PyErr_SetString(state->DataError,
|
||||
"Expanded SQL string exceeds the maximum string length");
|
||||
print_or_clear_traceback((callback_context *)ctx);
|
||||
goto exit;
|
||||
}
|
||||
py_statement = PyUnicode_FromString(sql);
|
||||
#endif
|
||||
if (py_statement) {
|
||||
PyObject *callable = ((callback_context *)ctx)->callable;
|
||||
PyObject *ret = PyObject_CallOneArg(callable, py_statement);
|
||||
Py_DECREF(py_statement);
|
||||
Py_XDECREF(ret);
|
||||
}
|
||||
if (PyErr_Occurred()) {
|
||||
print_or_clear_traceback((callback_context *)ctx);
|
||||
}
|
||||
|
||||
exit:
|
||||
PyGILState_Release(gilstate);
|
||||
#ifdef HAVE_TRACE_V2
|
||||
return 0;
|
||||
|
|
Loading…
Reference in New Issue