From 5b25f1d03100e2283c1b129d461ba68ac0169a14 Mon Sep 17 00:00:00 2001 From: Sergey Fedoseev Date: Wed, 5 Dec 2018 22:50:26 +0500 Subject: [PATCH] bpo-34052: Prevent SQLite functions from setting callbacks on exceptions. (GH-8113) --- Lib/sqlite3/test/regression.py | 83 ++++++++++++++----- .../2018-07-24-16-37-40.bpo-34052.VbbFAE.rst | 7 ++ Modules/_sqlite/connection.c | 31 +++---- 3 files changed, 84 insertions(+), 37 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2018-07-24-16-37-40.bpo-34052.VbbFAE.rst diff --git a/Lib/sqlite3/test/regression.py b/Lib/sqlite3/test/regression.py index 34cd233535d..1c59a3cd31c 100644 --- a/Lib/sqlite3/test/regression.py +++ b/Lib/sqlite3/test/regression.py @@ -256,24 +256,6 @@ class RegressionTests(unittest.TestCase): cur.execute("pragma page_size") row = cur.fetchone() - def CheckSetDict(self): - """ - See http://bugs.python.org/issue7478 - - It was possible to successfully register callbacks that could not be - hashed. Return codes of PyDict_SetItem were not checked properly. - """ - class NotHashable: - def __call__(self, *args, **kw): - pass - def __hash__(self): - raise TypeError() - var = NotHashable() - self.assertRaises(TypeError, self.con.create_function, var) - self.assertRaises(TypeError, self.con.create_aggregate, var) - self.assertRaises(TypeError, self.con.set_authorizer, var) - self.assertRaises(TypeError, self.con.set_progress_handler, var) - def CheckConnectionCall(self): """ Call a connection with a non-string SQL request: check error handling @@ -398,9 +380,72 @@ class RegressionTests(unittest.TestCase): support.gc_collect() +class UnhashableFunc: + __hash__ = None + + def __init__(self, return_value=None): + self.calls = 0 + self.return_value = return_value + + def __call__(self, *args, **kwargs): + self.calls += 1 + return self.return_value + + +class UnhashableCallbacksTestCase(unittest.TestCase): + """ + https://bugs.python.org/issue34052 + + Registering unhashable callbacks raises TypeError, callbacks are not + registered in SQLite after such registration attempt. + """ + def setUp(self): + self.con = sqlite.connect(':memory:') + + def tearDown(self): + self.con.close() + + def test_progress_handler(self): + f = UnhashableFunc(return_value=0) + with self.assertRaisesRegex(TypeError, 'unhashable type'): + self.con.set_progress_handler(f, 1) + self.con.execute('SELECT 1') + self.assertFalse(f.calls) + + def test_func(self): + func_name = 'func_name' + f = UnhashableFunc() + with self.assertRaisesRegex(TypeError, 'unhashable type'): + self.con.create_function(func_name, 0, f) + msg = 'no such function: %s' % func_name + with self.assertRaisesRegex(sqlite.OperationalError, msg): + self.con.execute('SELECT %s()' % func_name) + self.assertFalse(f.calls) + + def test_authorizer(self): + f = UnhashableFunc(return_value=sqlite.SQLITE_DENY) + with self.assertRaisesRegex(TypeError, 'unhashable type'): + self.con.set_authorizer(f) + self.con.execute('SELECT 1') + self.assertFalse(f.calls) + + def test_aggr(self): + class UnhashableType(type): + __hash__ = None + aggr_name = 'aggr_name' + with self.assertRaisesRegex(TypeError, 'unhashable type'): + self.con.create_aggregate(aggr_name, 0, UnhashableType('Aggr', (), {})) + msg = 'no such function: %s' % aggr_name + with self.assertRaisesRegex(sqlite.OperationalError, msg): + self.con.execute('SELECT %s()' % aggr_name) + + def suite(): regression_suite = unittest.makeSuite(RegressionTests, "Check") - return unittest.TestSuite((regression_suite,)) + return unittest.TestSuite(( + regression_suite, + unittest.makeSuite(UnhashableCallbacksTestCase), + )) def test(): runner = unittest.TextTestRunner() diff --git a/Misc/NEWS.d/next/Library/2018-07-24-16-37-40.bpo-34052.VbbFAE.rst b/Misc/NEWS.d/next/Library/2018-07-24-16-37-40.bpo-34052.VbbFAE.rst new file mode 100644 index 00000000000..5aa3cc9a81d --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-07-24-16-37-40.bpo-34052.VbbFAE.rst @@ -0,0 +1,7 @@ +:meth:`sqlite3.Connection.create_aggregate`, +:meth:`sqlite3.Connection.create_function`, +:meth:`sqlite3.Connection.set_authorizer`, +:meth:`sqlite3.Connection.set_progress_handler` methods raises TypeError +when unhashable objects are passed as callable. These methods now don't pass +such objects to SQLite API. Previous behavior could lead to segfaults. Patch +by Sergey Fedoseev. diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index b59d7d28cfd..fe0d03bb2b0 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -843,7 +843,9 @@ PyObject* pysqlite_connection_create_function(pysqlite_Connection* self, PyObjec flags |= SQLITE_DETERMINISTIC; #endif } - + if (PyDict_SetItem(self->function_pinboard, func, Py_None) == -1) { + return NULL; + } rc = sqlite3_create_function(self->db, name, narg, @@ -857,12 +859,8 @@ PyObject* pysqlite_connection_create_function(pysqlite_Connection* self, PyObjec /* Workaround for SQLite bug: no error code or string is available here */ PyErr_SetString(pysqlite_OperationalError, "Error creating function"); return NULL; - } else { - if (PyDict_SetItem(self->function_pinboard, func, Py_None) == -1) - return NULL; - - Py_RETURN_NONE; } + Py_RETURN_NONE; } PyObject* pysqlite_connection_create_aggregate(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) @@ -883,17 +881,16 @@ PyObject* pysqlite_connection_create_aggregate(pysqlite_Connection* self, PyObje return NULL; } + if (PyDict_SetItem(self->function_pinboard, aggregate_class, Py_None) == -1) { + return NULL; + } rc = sqlite3_create_function(self->db, name, n_arg, SQLITE_UTF8, (void*)aggregate_class, 0, &_pysqlite_step_callback, &_pysqlite_final_callback); if (rc != SQLITE_OK) { /* Workaround for SQLite bug: no error code or string is available here */ PyErr_SetString(pysqlite_OperationalError, "Error creating aggregate"); return NULL; - } else { - if (PyDict_SetItem(self->function_pinboard, aggregate_class, Py_None) == -1) - return NULL; - - Py_RETURN_NONE; } + Py_RETURN_NONE; } static int _authorizer_callback(void* user_arg, int action, const char* arg1, const char* arg2 , const char* dbname, const char* access_attempt_source) @@ -1006,17 +1003,15 @@ static PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, P return NULL; } + if (PyDict_SetItem(self->function_pinboard, authorizer_cb, Py_None) == -1) { + return NULL; + } rc = sqlite3_set_authorizer(self->db, _authorizer_callback, (void*)authorizer_cb); - if (rc != SQLITE_OK) { PyErr_SetString(pysqlite_OperationalError, "Error setting authorizer callback"); return NULL; - } else { - if (PyDict_SetItem(self->function_pinboard, authorizer_cb, Py_None) == -1) - return NULL; - - Py_RETURN_NONE; } + Py_RETURN_NONE; } static PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) @@ -1039,9 +1034,9 @@ static PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* s /* None clears the progress handler previously set */ sqlite3_progress_handler(self->db, 0, 0, (void*)0); } else { - sqlite3_progress_handler(self->db, n, _progress_handler, progress_handler); if (PyDict_SetItem(self->function_pinboard, progress_handler, Py_None) == -1) return NULL; + sqlite3_progress_handler(self->db, n, _progress_handler, progress_handler); } Py_RETURN_NONE;