import sys import unittest from contextlib import closing from functools import partial from pathlib import Path from test.support import import_helper, os_helper dbm_sqlite3 = import_helper.import_module("dbm.sqlite3") # N.B. The test will fail on some platforms without sqlite3 # if the sqlite3 import is above the import of dbm.sqlite3. # This is deliberate: if the import helper managed to import dbm.sqlite3, # we must inevitably be able to import sqlite3. Else, we have a problem. import sqlite3 from dbm.sqlite3 import _normalize_uri class _SQLiteDbmTests(unittest.TestCase): def setUp(self): self.filename = os_helper.TESTFN db = dbm_sqlite3.open(self.filename, "c") db.close() def tearDown(self): for suffix in "", "-wal", "-shm": os_helper.unlink(self.filename + suffix) class URI(unittest.TestCase): def test_uri_substitutions(self): dataset = ( ("/absolute/////b/c", "/absolute/b/c"), ("PRE#MID##END", "PRE%23MID%23%23END"), ("%#?%%#", "%25%23%3F%25%25%23"), ) for path, normalized in dataset: with self.subTest(path=path, normalized=normalized): self.assertTrue(_normalize_uri(path).endswith(normalized)) @unittest.skipUnless(sys.platform == "win32", "requires Windows") def test_uri_windows(self): dataset = ( # Relative subdir. (r"2018\January.xlsx", "2018/January.xlsx"), # Absolute with drive letter. (r"C:\Projects\apilibrary\apilibrary.sln", "/C:/Projects/apilibrary/apilibrary.sln"), # Relative with drive letter. (r"C:Projects\apilibrary\apilibrary.sln", "/C:Projects/apilibrary/apilibrary.sln"), ) for path, normalized in dataset: with self.subTest(path=path, normalized=normalized): if not Path(path).is_absolute(): self.skipTest(f"skipping relative path: {path!r}") self.assertTrue(_normalize_uri(path).endswith(normalized)) class ReadOnly(_SQLiteDbmTests): def setUp(self): super().setUp() with dbm_sqlite3.open(self.filename, "w") as db: db[b"key1"] = "value1" db[b"key2"] = "value2" self.db = dbm_sqlite3.open(self.filename, "r") def tearDown(self): self.db.close() super().tearDown() def test_readonly_read(self): self.assertEqual(self.db[b"key1"], b"value1") self.assertEqual(self.db[b"key2"], b"value2") def test_readonly_write(self): with self.assertRaises(dbm_sqlite3.error): self.db[b"new"] = "value" def test_readonly_delete(self): with self.assertRaises(dbm_sqlite3.error): del self.db[b"key1"] def test_readonly_keys(self): self.assertEqual(self.db.keys(), [b"key1", b"key2"]) def test_readonly_iter(self): self.assertEqual([k for k in self.db], [b"key1", b"key2"]) class ReadWrite(_SQLiteDbmTests): def setUp(self): super().setUp() self.db = dbm_sqlite3.open(self.filename, "w") def tearDown(self): self.db.close() super().tearDown() def db_content(self): with closing(sqlite3.connect(self.filename)) as cx: keys = [r[0] for r in cx.execute("SELECT key FROM Dict")] vals = [r[0] for r in cx.execute("SELECT value FROM Dict")] return keys, vals def test_readwrite_unique_key(self): self.db["key"] = "value" self.db["key"] = "other" keys, vals = self.db_content() self.assertEqual(keys, [b"key"]) self.assertEqual(vals, [b"other"]) def test_readwrite_delete(self): self.db["key"] = "value" self.db["new"] = "other" del self.db[b"new"] keys, vals = self.db_content() self.assertEqual(keys, [b"key"]) self.assertEqual(vals, [b"value"]) del self.db[b"key"] keys, vals = self.db_content() self.assertEqual(keys, []) self.assertEqual(vals, []) def test_readwrite_null_key(self): with self.assertRaises(dbm_sqlite3.error): self.db[None] = "value" def test_readwrite_null_value(self): with self.assertRaises(dbm_sqlite3.error): self.db[b"key"] = None class Misuse(_SQLiteDbmTests): def setUp(self): super().setUp() self.db = dbm_sqlite3.open(self.filename, "w") def tearDown(self): self.db.close() super().tearDown() def test_misuse_double_create(self): self.db["key"] = "value" with dbm_sqlite3.open(self.filename, "c") as db: self.assertEqual(db[b"key"], b"value") def test_misuse_double_close(self): self.db.close() def test_misuse_invalid_flag(self): regex = "must be.*'r'.*'w'.*'c'.*'n', not 'invalid'" with self.assertRaisesRegex(ValueError, regex): dbm_sqlite3.open(self.filename, flag="invalid") def test_misuse_double_delete(self): self.db["key"] = "value" del self.db[b"key"] with self.assertRaises(KeyError): del self.db[b"key"] def test_misuse_invalid_key(self): with self.assertRaises(KeyError): self.db[b"key"] def test_misuse_iter_close1(self): self.db["1"] = 1 it = iter(self.db) self.db.close() with self.assertRaises(dbm_sqlite3.error): next(it) def test_misuse_iter_close2(self): self.db["1"] = 1 self.db["2"] = 2 it = iter(self.db) next(it) self.db.close() with self.assertRaises(dbm_sqlite3.error): next(it) def test_misuse_use_after_close(self): self.db.close() with self.assertRaises(dbm_sqlite3.error): self.db[b"read"] with self.assertRaises(dbm_sqlite3.error): self.db[b"write"] = "value" with self.assertRaises(dbm_sqlite3.error): del self.db[b"del"] with self.assertRaises(dbm_sqlite3.error): len(self.db) with self.assertRaises(dbm_sqlite3.error): self.db.keys() def test_misuse_reinit(self): with self.assertRaises(dbm_sqlite3.error): self.db.__init__("new.db", flag="n", mode=0o666) def test_misuse_empty_filename(self): for flag in "r", "w", "c", "n": with self.assertRaises(dbm_sqlite3.error): db = dbm_sqlite3.open("", flag="c") class DataTypes(_SQLiteDbmTests): dataset = ( # (raw, coerced) (42, b"42"), (3.14, b"3.14"), ("string", b"string"), (b"bytes", b"bytes"), ) def setUp(self): super().setUp() self.db = dbm_sqlite3.open(self.filename, "w") def tearDown(self): self.db.close() super().tearDown() def test_datatypes_values(self): for raw, coerced in self.dataset: with self.subTest(raw=raw, coerced=coerced): self.db["key"] = raw self.assertEqual(self.db[b"key"], coerced) def test_datatypes_keys(self): for raw, coerced in self.dataset: with self.subTest(raw=raw, coerced=coerced): self.db[raw] = "value" self.assertEqual(self.db[coerced], b"value") # Raw keys are silently coerced to bytes. self.assertEqual(self.db[raw], b"value") del self.db[raw] def test_datatypes_replace_coerced(self): self.db["10"] = "value" self.db[b"10"] = "value" self.db[10] = "value" self.assertEqual(self.db.keys(), [b"10"]) class CorruptDatabase(_SQLiteDbmTests): """Verify that database exceptions are raised as dbm.sqlite3.error.""" def setUp(self): super().setUp() with closing(sqlite3.connect(self.filename)) as cx: with cx: cx.execute("DROP TABLE IF EXISTS Dict") cx.execute("CREATE TABLE Dict (invalid_schema)") def check(self, flag, fn, should_succeed=False): with closing(dbm_sqlite3.open(self.filename, flag)) as db: with self.assertRaises(dbm_sqlite3.error): fn(db) @staticmethod def read(db): return db["key"] @staticmethod def write(db): db["key"] = "value" @staticmethod def iter(db): next(iter(db)) @staticmethod def keys(db): db.keys() @staticmethod def del_(db): del db["key"] @staticmethod def len_(db): len(db) def test_corrupt_readwrite(self): for flag in "r", "w", "c": with self.subTest(flag=flag): check = partial(self.check, flag=flag) check(fn=self.read) check(fn=self.write) check(fn=self.iter) check(fn=self.keys) check(fn=self.del_) check(fn=self.len_) def test_corrupt_force_new(self): with closing(dbm_sqlite3.open(self.filename, "n")) as db: db["foo"] = "write" _ = db[b"foo"] next(iter(db)) del db[b"foo"] if __name__ == "__main__": unittest.main()