From b17acad68ea21c60dbc2088644f2934032304628 Mon Sep 17 00:00:00 2001 From: Georg Brandl Date: Wed, 28 May 2008 08:43:17 +0000 Subject: [PATCH] Make db modules' error classes inherit IOError. Stop dbm from importing every dbm module when imported. --- Lib/dbm/__init__.py | 42 +++++++++++++++-------------------- Lib/dbm/bsd.py | 3 ++- Lib/test/test_dbm.py | 51 +++++++++++++++++++++++-------------------- Modules/_dbmmodule.c | 3 ++- Modules/_gdbmmodule.c | 2 +- 5 files changed, 50 insertions(+), 51 deletions(-) diff --git a/Lib/dbm/__init__.py b/Lib/dbm/__init__.py index 9fdd4145cc1..2082e073357 100644 --- a/Lib/dbm/__init__.py +++ b/Lib/dbm/__init__.py @@ -48,27 +48,26 @@ class error(Exception): pass _names = ['dbm.bsd', 'dbm.gnu', 'dbm.ndbm', 'dbm.dumb'] -_errors = [error] _defaultmod = None _modules = {} -for _name in _names: - try: - _mod = __import__(_name, fromlist=['open']) - except ImportError: - continue - if not _defaultmod: - _defaultmod = _mod - _modules[_name] = _mod - _errors.append(_mod.error) - -if not _defaultmod: - raise ImportError("no dbm clone found; tried %s" % _names) - -error = tuple(_errors) +error = (error, IOError) def open(file, flag = 'r', mode = 0o666): + global _defaultmod + if _defaultmod is None: + for name in _names: + try: + mod = __import__(name, fromlist=['open']) + except ImportError: + continue + if not _defaultmod: + _defaultmod = mod + _modules[name] = mod + if not _defaultmod: + raise ImportError("no dbm clone found; tried %s" % _names) + # guess the type of an existing database result = whichdb(file) if result is None: @@ -81,19 +80,14 @@ def open(file, flag = 'r', mode = 0o666): elif result == "": # db type cannot be determined raise error("db type could not be determined") + elif result not in _modules: + raise error("db type is {0}, but the module is not " + "available".format(result)) else: mod = _modules[result] return mod.open(file, flag, mode) -try: - from dbm import ndbm - _dbmerror = ndbm.error -except ImportError: - ndbm = None - # just some sort of valid exception which might be raised in the ndbm test - _dbmerror = IOError - def whichdb(filename): """Guess which db package to use to open a db file. @@ -129,7 +123,7 @@ def whichdb(filename): d = ndbm.open(filename) d.close() return "dbm.ndbm" - except (IOError, _dbmerror): + except IOError: pass # Check for dumbdbm next -- this has a .dir and a .dat file diff --git a/Lib/dbm/bsd.py b/Lib/dbm/bsd.py index 8353f503766..2dccadb8b1e 100644 --- a/Lib/dbm/bsd.py +++ b/Lib/dbm/bsd.py @@ -4,7 +4,8 @@ import bsddb __all__ = ["error", "open"] -error = bsddb.error +class error(bsddb.error, IOError): + pass def open(file, flag = 'r', mode=0o666): return bsddb.hashopen(file, flag, mode) diff --git a/Lib/test/test_dbm.py b/Lib/test/test_dbm.py index aab1388d791..41c37cbea51 100644 --- a/Lib/test/test_dbm.py +++ b/Lib/test/test_dbm.py @@ -14,11 +14,13 @@ _fname = test.support.TESTFN # setting dbm to use each in turn, and yielding that module # def dbm_iterator(): - old_default = dbm._defaultmod - for module in dbm._modules.values(): - dbm._defaultmod = module - yield module - dbm._defaultmod = old_default + for name in dbm._names: + try: + mod = __import__(name, fromlist=['open']) + except ImportError: + continue + dbm._modules[name] = mod + yield mod # # Clean up all scratch databases we might have created during testing @@ -40,8 +42,20 @@ class AnyDBMTestCase(unittest.TestCase): 'g': b'intended', } - def __init__(self, *args): - unittest.TestCase.__init__(self, *args) + def init_db(self): + f = dbm.open(_fname, 'n') + for k in self._dict: + f[k.encode("ascii")] = self._dict[k] + f.close() + + def keys_helper(self, f): + keys = sorted(k.decode("ascii") for k in f.keys()) + dkeys = sorted(self._dict.keys()) + self.assertEqual(keys, dkeys) + return keys + + def test_error(self): + self.assert_(issubclass(self.module.error, IOError)) def test_anydbm_creation(self): f = dbm.open(_fname, 'c') @@ -83,22 +97,11 @@ class AnyDBMTestCase(unittest.TestCase): for key in self._dict: self.assertEqual(self._dict[key], f[key.encode("ascii")]) - def init_db(self): - f = dbm.open(_fname, 'n') - for k in self._dict: - f[k.encode("ascii")] = self._dict[k] - f.close() - - def keys_helper(self, f): - keys = sorted(k.decode("ascii") for k in f.keys()) - dkeys = sorted(self._dict.keys()) - self.assertEqual(keys, dkeys) - return keys - def tearDown(self): delete_files() def setUp(self): + dbm._defaultmod = self.module delete_files() @@ -137,11 +140,11 @@ class WhichDBTestCase(unittest.TestCase): def test_main(): - try: - for module in dbm_iterator(): - test.support.run_unittest(AnyDBMTestCase, WhichDBTestCase) - finally: - delete_files() + classes = [WhichDBTestCase] + for mod in dbm_iterator(): + classes.append(type("TestCase-" + mod.__name__, (AnyDBMTestCase,), + {'module': mod})) + test.support.run_unittest(*classes) if __name__ == "__main__": test_main() diff --git a/Modules/_dbmmodule.c b/Modules/_dbmmodule.c index ddfd4cd759a..7e80381db77 100644 --- a/Modules/_dbmmodule.c +++ b/Modules/_dbmmodule.c @@ -401,7 +401,8 @@ init_dbm(void) { return; d = PyModule_GetDict(m); if (DbmError == NULL) - DbmError = PyErr_NewException("_dbm.error", NULL, NULL); + DbmError = PyErr_NewException("_dbm.error", + PyExc_IOError, NULL); s = PyUnicode_FromString(which_dbm); if (s != NULL) { PyDict_SetItemString(d, "library", s); diff --git a/Modules/_gdbmmodule.c b/Modules/_gdbmmodule.c index 6c7581969ae..abc88370911 100644 --- a/Modules/_gdbmmodule.c +++ b/Modules/_gdbmmodule.c @@ -523,7 +523,7 @@ init_gdbm(void) { if (m == NULL) return; d = PyModule_GetDict(m); - DbmError = PyErr_NewException("_gdbm.error", NULL, NULL); + DbmError = PyErr_NewException("_gdbm.error", PyExc_IOError, NULL); if (DbmError != NULL) { PyDict_SetItemString(d, "error", DbmError); s = PyUnicode_FromString(dbmmodule_open_flags);