From 2c1ae09764446beda5248759fb99c859e14f1b25 Mon Sep 17 00:00:00 2001 From: Erlend Egeberg Aasland Date: Thu, 24 Jun 2021 13:56:56 +0200 Subject: [PATCH] bpo-43553: Improve `sqlite3` test coverage (GH-26886) --- Lib/sqlite3/test/dbapi.py | 26 ++++++++++++++++++++-- Lib/sqlite3/test/factory.py | 2 ++ Lib/sqlite3/test/types.py | 37 +++++++++++++++++++++++++++++++ Lib/sqlite3/test/userfunctions.py | 37 +++++++++++++++++++++++++++++++ 4 files changed, 100 insertions(+), 2 deletions(-) diff --git a/Lib/sqlite3/test/dbapi.py b/Lib/sqlite3/test/dbapi.py index 7e44cac76f5..1a4b44188bd 100644 --- a/Lib/sqlite3/test/dbapi.py +++ b/Lib/sqlite3/test/dbapi.py @@ -26,9 +26,8 @@ import sys import threading import unittest -from test.support import check_disallow_instantiation +from test.support import check_disallow_instantiation, threading_helper from test.support.os_helper import TESTFN, unlink -from test.support import threading_helper # Helper for tests using TESTFN @@ -110,6 +109,10 @@ class ModuleTests(unittest.TestCase): cx = sqlite.connect(":memory:") check_disallow_instantiation(self, type(cx("select 1"))) + def test_complete_statement(self): + self.assertFalse(sqlite.complete_statement("select t")) + self.assertTrue(sqlite.complete_statement("create table t(t);")) + class ConnectionTests(unittest.TestCase): @@ -225,6 +228,20 @@ class ConnectionTests(unittest.TestCase): self.assertTrue(hasattr(self.cx, exc)) self.assertIs(getattr(sqlite, exc), getattr(self.cx, exc)) + def test_interrupt_on_closed_db(self): + cx = sqlite.connect(":memory:") + cx.close() + with self.assertRaises(sqlite.ProgrammingError): + cx.interrupt() + + def test_interrupt(self): + self.assertIsNone(self.cx.interrupt()) + + def test_drop_unused_refs(self): + for n in range(500): + cu = self.cx.execute(f"select {n}") + self.assertEqual(cu.fetchone()[0], n) + class OpenTests(unittest.TestCase): _sql = "create table test(id integer)" @@ -594,6 +611,11 @@ class CursorTests(unittest.TestCase): new_count = len(res.description) self.assertEqual(new_count - old_count, 1) + def test_same_query_in_multiple_cursors(self): + cursors = [self.cx.execute("select 1") for _ in range(3)] + for cu in cursors: + self.assertEqual(cu.fetchall(), [(1,)]) + class ThreadTests(unittest.TestCase): def setUp(self): diff --git a/Lib/sqlite3/test/factory.py b/Lib/sqlite3/test/factory.py index 87642849754..7faa9ac8c1f 100644 --- a/Lib/sqlite3/test/factory.py +++ b/Lib/sqlite3/test/factory.py @@ -123,6 +123,8 @@ class RowFactoryTests(unittest.TestCase): row[-3] with self.assertRaises(IndexError): row[2**1000] + with self.assertRaises(IndexError): + row[complex()] # index must be int or string def test_sqlite_row_index_unicode(self): self.con.row_factory = sqlite.Row diff --git a/Lib/sqlite3/test/types.py b/Lib/sqlite3/test/types.py index 4bb1de80887..4f0e4f6d268 100644 --- a/Lib/sqlite3/test/types.py +++ b/Lib/sqlite3/test/types.py @@ -381,6 +381,43 @@ class ObjectAdaptationTests(unittest.TestCase): val = self.cur.fetchone()[0] self.assertEqual(type(val), float) + def test_missing_adapter(self): + with self.assertRaises(sqlite.ProgrammingError): + sqlite.adapt(1.) # No float adapter registered + + def test_missing_protocol(self): + with self.assertRaises(sqlite.ProgrammingError): + sqlite.adapt(1, None) + + def test_defect_proto(self): + class DefectProto(): + def __adapt__(self): + return None + with self.assertRaises(sqlite.ProgrammingError): + sqlite.adapt(1., DefectProto) + + def test_defect_self_adapt(self): + class DefectSelfAdapt(float): + def __conform__(self, _): + return None + with self.assertRaises(sqlite.ProgrammingError): + sqlite.adapt(DefectSelfAdapt(1.)) + + def test_custom_proto(self): + class CustomProto(): + def __adapt__(self): + return "adapted" + self.assertEqual(sqlite.adapt(1., CustomProto), "adapted") + + def test_adapt(self): + val = 42 + self.assertEqual(float(val), sqlite.adapt(val)) + + def test_adapt_alt(self): + alt = "other" + self.assertEqual(alt, sqlite.adapt(1., None, alt)) + + @unittest.skipUnless(zlib, "requires zlib") class BinaryConverterTests(unittest.TestCase): def convert(s): diff --git a/Lib/sqlite3/test/userfunctions.py b/Lib/sqlite3/test/userfunctions.py index 42908907249..dc900f6486f 100644 --- a/Lib/sqlite3/test/userfunctions.py +++ b/Lib/sqlite3/test/userfunctions.py @@ -21,11 +21,36 @@ # misrepresented as being the original software. # 3. This notice may not be removed or altered from any source distribution. +import contextlib +import functools +import io import unittest import unittest.mock import gc import sqlite3 as sqlite +def with_tracebacks(strings): + """Convenience decorator for testing callback tracebacks.""" + strings.append('Traceback') + + def decorator(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + # First, run the test with traceback enabled. + sqlite.enable_callback_tracebacks(True) + buf = io.StringIO() + with contextlib.redirect_stderr(buf): + func(self, *args, **kwargs) + tb = buf.getvalue() + for s in strings: + self.assertIn(s, tb) + + # Then run the test with traceback disabled. + sqlite.enable_callback_tracebacks(False) + func(self, *args, **kwargs) + return wrapper + return decorator + def func_returntext(): return "foo" def func_returnunicode(): @@ -228,6 +253,7 @@ class FunctionTests(unittest.TestCase): val = cur.fetchone()[0] self.assertEqual(val, 1<<31) + @with_tracebacks(['func_raiseexception', '5/0', 'ZeroDivisionError']) def test_func_exception(self): cur = self.con.cursor() with self.assertRaises(sqlite.OperationalError) as cm: @@ -387,6 +413,7 @@ class AggregateTests(unittest.TestCase): val = cur.fetchone()[0] self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error") + @with_tracebacks(['__init__', '5/0', 'ZeroDivisionError']) def test_aggr_exception_in_init(self): cur = self.con.cursor() with self.assertRaises(sqlite.OperationalError) as cm: @@ -394,6 +421,7 @@ class AggregateTests(unittest.TestCase): val = cur.fetchone()[0] self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error") + @with_tracebacks(['step', '5/0', 'ZeroDivisionError']) def test_aggr_exception_in_step(self): cur = self.con.cursor() with self.assertRaises(sqlite.OperationalError) as cm: @@ -401,6 +429,7 @@ class AggregateTests(unittest.TestCase): val = cur.fetchone()[0] self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error") + @with_tracebacks(['finalize', '5/0', 'ZeroDivisionError']) def test_aggr_exception_in_finalize(self): cur = self.con.cursor() with self.assertRaises(sqlite.OperationalError) as cm: @@ -502,6 +531,14 @@ class AuthorizerRaiseExceptionTests(AuthorizerTests): raise ValueError return sqlite.SQLITE_OK + @with_tracebacks(['authorizer_cb', 'ValueError']) + def test_table_access(self): + super().test_table_access() + + @with_tracebacks(['authorizer_cb', 'ValueError']) + def test_column_access(self): + super().test_table_access() + class AuthorizerIllegalTypeTests(AuthorizerTests): @staticmethod def authorizer_cb(action, arg1, arg2, dbname, source):