diff --git a/Lib/unittest.py b/Lib/unittest.py index 043b9a848a4..f44769e9261 100644 --- a/Lib/unittest.py +++ b/Lib/unittest.py @@ -27,7 +27,7 @@ Further information is available in the bundled documentation, and from http://pyunit.sourceforge.net/ -Copyright (c) 1999, 2000, 2001 Steve Purcell +Copyright (c) 1999-2003 Steve Purcell This module is free software, and you may redistribute it and/or modify it under the same terms as Python itself, so long as this copyright message and disclaimer are retained in their original form. @@ -46,12 +46,11 @@ SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. __author__ = "Steve Purcell" __email__ = "stephen_purcell at yahoo dot com" -__version__ = "#Revision: 1.46 $"[11:-2] +__version__ = "#Revision: 1.56 $"[11:-2] import time import sys import traceback -import string import os import types @@ -61,10 +60,26 @@ import types __all__ = ['TestResult', 'TestCase', 'TestSuite', 'TextTestRunner', 'TestLoader', 'FunctionTestCase', 'main', 'defaultTestLoader'] -# Expose obsolete functions for backwards compatability +# Expose obsolete functions for backwards compatibility __all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases']) +############################################################################## +# Backward compatibility +############################################################################## +if sys.version_info[:2] < (2, 2): + False, True = 0, 1 + def isinstance(obj, clsinfo): + import __builtin__ + if type(clsinfo) in (types.TupleType, types.ListType): + for cls in clsinfo: + if cls is type: cls = types.ClassType + if __builtin__.isinstance(obj, cls): + return 1 + return 0 + else: return __builtin__.isinstance(obj, clsinfo) + + ############################################################################## # Test framework core ############################################################################## @@ -121,11 +136,11 @@ class TestResult: def stop(self): "Indicates that the tests should be aborted" - self.shouldStop = 1 + self.shouldStop = True def _exc_info_to_string(self, err): """Converts a sys.exc_info()-style tuple of values into a string.""" - return string.join(traceback.format_exception(*err), '') + return ''.join(traceback.format_exception(*err)) def __repr__(self): return "<%s run=%i errors=%i failures=%i>" % \ @@ -196,7 +211,7 @@ class TestCase: the specified test method's docstring. """ doc = self.__testMethodDoc - return doc and string.strip(string.split(doc, "\n")[0]) or None + return doc and doc.split("\n")[0].strip() or None def id(self): return "%s.%s" % (_strclass(self.__class__), self.__testMethodName) @@ -209,9 +224,6 @@ class TestCase: (_strclass(self.__class__), self.__testMethodName) def run(self, result=None): - return self(result) - - def __call__(self, result=None): if result is None: result = self.defaultTestResult() result.startTest(self) testMethod = getattr(self, self.__testMethodName) @@ -224,10 +236,10 @@ class TestCase: result.addError(self, self.__exc_info()) return - ok = 0 + ok = False try: testMethod() - ok = 1 + ok = True except self.failureException: result.addFailure(self, self.__exc_info()) except KeyboardInterrupt: @@ -241,11 +253,13 @@ class TestCase: raise except: result.addError(self, self.__exc_info()) - ok = 0 + ok = False if ok: result.addSuccess(self) finally: result.stopTest(self) + __call__ = run + def debug(self): """Run the test without collecting errors in a TestResult""" self.setUp() @@ -292,7 +306,7 @@ class TestCase: else: if hasattr(excClass,'__name__'): excName = excClass.__name__ else: excName = str(excClass) - raise self.failureException, excName + raise self.failureException, "%s not raised" % excName def failUnlessEqual(self, first, second, msg=None): """Fail if the two objects are unequal as determined by the '==' @@ -334,6 +348,8 @@ class TestCase: raise self.failureException, \ (msg or '%s == %s within %s places' % (`first`, `second`, `places`)) + # Synonyms for assertion methods + assertEqual = assertEquals = failUnlessEqual assertNotEqual = assertNotEquals = failIfEqual @@ -344,7 +360,9 @@ class TestCase: assertRaises = failUnlessRaises - assert_ = failUnless + assert_ = assertTrue = failUnless + + assertFalse = failIf @@ -369,7 +387,7 @@ class TestSuite: def countTestCases(self): cases = 0 for test in self._tests: - cases = cases + test.countTestCases() + cases += test.countTestCases() return cases def addTest(self, test): @@ -434,7 +452,7 @@ class FunctionTestCase(TestCase): def shortDescription(self): if self.__description is not None: return self.__description doc = self.__testFunc.__doc__ - return doc and string.strip(string.split(doc, "\n")[0]) or None + return doc and doc.split("\n")[0].strip() or None @@ -452,8 +470,10 @@ class TestLoader: def loadTestsFromTestCase(self, testCaseClass): """Return a suite of all tests cases contained in testCaseClass""" - return self.suiteClass(map(testCaseClass, - self.getTestCaseNames(testCaseClass))) + testCaseNames = self.getTestCaseNames(testCaseClass) + if not testCaseNames and hasattr(testCaseClass, 'runTest'): + testCaseNames = ['runTest'] + return self.suiteClass(map(testCaseClass, testCaseNames)) def loadTestsFromModule(self, module): """Return a suite of all tests cases contained in the given module""" @@ -474,23 +494,20 @@ class TestLoader: The method optionally resolves the names relative to a given module. """ - parts = string.split(name, '.') + parts = name.split('.') if module is None: - if not parts: - raise ValueError, "incomplete test name: %s" % name - else: - parts_copy = parts[:] - while parts_copy: - try: - module = __import__(string.join(parts_copy,'.')) - break - except ImportError: - del parts_copy[-1] - if not parts_copy: raise + parts_copy = parts[:] + while parts_copy: + try: + module = __import__('.'.join(parts_copy)) + break + except ImportError: + del parts_copy[-1] + if not parts_copy: raise parts = parts[1:] obj = module for part in parts: - obj = getattr(obj, part) + parent, obj = obj, getattr(obj, part) import unittest if type(obj) == types.ModuleType: @@ -499,11 +516,13 @@ class TestLoader: issubclass(obj, unittest.TestCase)): return self.loadTestsFromTestCase(obj) elif type(obj) == types.UnboundMethodType: + return parent(obj.__name__) return obj.im_class(obj.__name__) + elif isinstance(obj, unittest.TestSuite): + return obj elif callable(obj): test = obj() - if not isinstance(test, unittest.TestCase) and \ - not isinstance(test, unittest.TestSuite): + if not isinstance(test, (unittest.TestCase, unittest.TestSuite)): raise ValueError, \ "calling %s returned %s, not a test" % (obj,test) return test @@ -514,16 +533,15 @@ class TestLoader: """Return a suite of all tests cases found using the given sequence of string specifiers. See 'loadTestsFromName()'. """ - suites = [] - for name in names: - suites.append(self.loadTestsFromName(name, module)) + suites = [self.loadTestsFromName(name, module) for name in names] return self.suiteClass(suites) def getTestCaseNames(self, testCaseClass): """Return a sorted sequence of method names found within testCaseClass """ - testFnNames = filter(lambda n,p=self.testMethodPrefix: n[:len(p)] == p, - dir(testCaseClass)) + def isTestMethod(attrname, testCaseClass=testCaseClass, prefix=self.testMethodPrefix): + return attrname[:len(prefix)] == prefix and callable(getattr(testCaseClass, attrname)) + testFnNames = filter(isTestMethod, dir(testCaseClass)) for baseclass in testCaseClass.__bases__: for testFnName in self.getTestCaseNames(baseclass): if testFnName not in testFnNames: # handle overridden methods @@ -706,7 +724,7 @@ Examples: argv=None, testRunner=None, testLoader=defaultTestLoader): if type(module) == type(''): self.module = __import__(module) - for part in string.split(module,'.')[1:]: + for part in module.split('.')[1:]: self.module = getattr(self.module, part) else: self.module = module