some cleanup and modernization

This commit is contained in:
Benjamin Peterson 2009-03-24 00:35:20 +00:00
parent 21b617bd98
commit a7d441de68
1 changed files with 50 additions and 41 deletions

View File

@ -251,12 +251,16 @@ class TestResult(object):
(_strclass(self.__class__), self.testsRun, len(self.errors), (_strclass(self.__class__), self.testsRun, len(self.errors),
len(self.failures)) len(self.failures))
class AssertRaisesContext(object): class AssertRaisesContext(object):
def __init__(self, expected, test_case): def __init__(self, expected, test_case):
self.expected = expected self.expected = expected
self.failureException = test_case.failureException self.failureException = test_case.failureException
def __enter__(self): def __enter__(self):
pass pass
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None: if exc_type is None:
try: try:
@ -270,6 +274,7 @@ class AssertRaisesContext(object):
# Let unexpected exceptions skip through # Let unexpected exceptions skip through
return False return False
class TestCase(object): class TestCase(object):
"""A class whose instances are single test cases. """A class whose instances are single test cases.
@ -303,13 +308,13 @@ class TestCase(object):
method when executed. Raises a ValueError if the instance does method when executed. Raises a ValueError if the instance does
not have a method with the specified name. not have a method with the specified name.
""" """
self._testMethodName = methodName
try: try:
self._testMethodName = methodName
testMethod = getattr(self, methodName) testMethod = getattr(self, methodName)
self._testMethodDoc = testMethod.__doc__
except AttributeError: except AttributeError:
raise ValueError("no such test method in %s: %s" % \ raise ValueError("no such test method in %s: %s" % \
(self.__class__, methodName)) (self.__class__, methodName))
self._testMethodDoc = testMethod.__doc__
def setUp(self): def setUp(self):
"Hook method for setting up the test fixture before exercising it." "Hook method for setting up the test fixture before exercising it."
@ -340,7 +345,7 @@ class TestCase(object):
def __eq__(self, other): def __eq__(self, other):
if type(self) is not type(other): if type(self) is not type(other):
return False return NotImplemented
return self._testMethodName == other._testMethodName return self._testMethodName == other._testMethodName
@ -358,7 +363,8 @@ class TestCase(object):
(_strclass(self.__class__), self._testMethodName) (_strclass(self.__class__), self._testMethodName)
def run(self, result=None): def run(self, result=None):
if result is None: result = self.defaultTestResult() if result is None:
result = self.defaultTestResult()
result.startTest(self) result.startTest(self)
testMethod = getattr(self, self._testMethodName) testMethod = getattr(self, self._testMethodName)
try: try:
@ -423,11 +429,13 @@ class TestCase(object):
def failIf(self, expr, msg=None): def failIf(self, expr, msg=None):
"Fail the test if the expression is true." "Fail the test if the expression is true."
if expr: raise self.failureException(msg) if expr:
raise self.failureException(msg)
def failUnless(self, expr, msg=None): def failUnless(self, expr, msg=None):
"""Fail the test unless the expression is true.""" """Fail the test unless the expression is true."""
if not expr: raise self.failureException(msg) if not expr:
raise self.failureException(msg)
def failUnlessRaises(self, excClass, callableObj=None, *args, **kwargs): def failUnlessRaises(self, excClass, callableObj=None, *args, **kwargs):
"""Fail unless an exception of class excClass is thrown """Fail unless an exception of class excClass is thrown
@ -521,8 +529,6 @@ class TestSuite(object):
def __repr__(self): def __repr__(self):
return "<%s tests=%s>" % (_strclass(self.__class__), self._tests) return "<%s tests=%s>" % (_strclass(self.__class__), self._tests)
__str__ = __repr__
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return NotImplemented return NotImplemented
@ -547,8 +553,7 @@ class TestSuite(object):
# sanity checks # sanity checks
if not hasattr(test, '__call__'): if not hasattr(test, '__call__'):
raise TypeError("the test to add must be callable") raise TypeError("the test to add must be callable")
if (isinstance(test, (type, types.ClassType)) and if isinstance(test, type) and issubclass(test, (TestCase, TestSuite)):
issubclass(test, (TestCase, TestSuite))):
raise TypeError("TestCases and TestSuites must be instantiated " raise TypeError("TestCases and TestSuites must be instantiated "
"before passing them to addTest()") "before passing them to addTest()")
self._tests.append(test) self._tests.append(test)
@ -571,7 +576,8 @@ class TestSuite(object):
def debug(self): def debug(self):
"""Run the tests without collecting errors in a TestResult""" """Run the tests without collecting errors in a TestResult"""
for test in self._tests: test.debug() for test in self._tests:
test.debug()
class ClassTestSuite(TestSuite): class ClassTestSuite(TestSuite):
@ -614,9 +620,8 @@ class FunctionTestCase(TestCase):
always be called if the set-up ('setUp') function ran successfully. always be called if the set-up ('setUp') function ran successfully.
""" """
def __init__(self, testFunc, setUp=None, tearDown=None, def __init__(self, testFunc, setUp=None, tearDown=None, description=None):
description=None): super(FunctionTestCase, self).__init__()
TestCase.__init__(self)
self.__setUpFunc = setUp self.__setUpFunc = setUp
self.__tearDownFunc = tearDown self.__tearDownFunc = tearDown
self.__testFunc = testFunc self.__testFunc = testFunc
@ -637,8 +642,8 @@ class FunctionTestCase(TestCase):
return self.__testFunc.__name__ return self.__testFunc.__name__
def __eq__(self, other): def __eq__(self, other):
if type(self) is not type(other): if not isinstance(other, self.__class__):
return False return NotImplemented
return self.__setUpFunc == other.__setUpFunc and \ return self.__setUpFunc == other.__setUpFunc and \
self.__tearDownFunc == other.__tearDownFunc and \ self.__tearDownFunc == other.__tearDownFunc and \
@ -670,8 +675,9 @@ class FunctionTestCase(TestCase):
############################################################################## ##############################################################################
class TestLoader(object): class TestLoader(object):
"""This class is responsible for loading tests according to various """
criteria and returning them wrapped in a TestSuite This class is responsible for loading tests according to various criteria
and returning them wrapped in a TestSuite
""" """
testMethodPrefix = 'test' testMethodPrefix = 'test'
sortTestMethodsUsing = cmp sortTestMethodsUsing = cmp
@ -681,7 +687,8 @@ class TestLoader(object):
def loadTestsFromTestCase(self, testCaseClass): def loadTestsFromTestCase(self, testCaseClass):
"""Return a suite of all tests cases contained in testCaseClass""" """Return a suite of all tests cases contained in testCaseClass"""
if issubclass(testCaseClass, TestSuite): if issubclass(testCaseClass, TestSuite):
raise TypeError("Test cases should not be derived from TestSuite. Maybe you meant to derive from TestCase?") raise TypeError("Test cases should not be derived from TestSuite." \
" Maybe you meant to derive from TestCase?")
testCaseNames = self.getTestCaseNames(testCaseClass) testCaseNames = self.getTestCaseNames(testCaseClass)
if not testCaseNames and hasattr(testCaseClass, 'runTest'): if not testCaseNames and hasattr(testCaseClass, 'runTest'):
testCaseNames = ['runTest'] testCaseNames = ['runTest']
@ -694,8 +701,7 @@ class TestLoader(object):
tests = [] tests = []
for name in dir(module): for name in dir(module):
obj = getattr(module, name) obj = getattr(module, name)
if (isinstance(obj, (type, types.ClassType)) and if isinstance(obj, type) and issubclass(obj, TestCase):
issubclass(obj, TestCase)):
tests.append(self.loadTestsFromTestCase(obj)) tests.append(self.loadTestsFromTestCase(obj))
return self.suiteClass(tests) return self.suiteClass(tests)
@ -717,7 +723,8 @@ class TestLoader(object):
break break
except ImportError: except ImportError:
del parts_copy[-1] del parts_copy[-1]
if not parts_copy: raise if not parts_copy:
raise
parts = parts[1:] parts = parts[1:]
obj = module obj = module
for part in parts: for part in parts:
@ -725,11 +732,10 @@ class TestLoader(object):
if isinstance(obj, types.ModuleType): if isinstance(obj, types.ModuleType):
return self.loadTestsFromModule(obj) return self.loadTestsFromModule(obj)
elif (isinstance(obj, (type, types.ClassType)) and elif isinstance(obj, type) and issubclass(obj, TestCase):
issubclass(obj, TestCase)):
return self.loadTestsFromTestCase(obj) return self.loadTestsFromTestCase(obj)
elif (isinstance(obj, types.UnboundMethodType) and elif (isinstance(obj, types.UnboundMethodType) and
isinstance(parent, (type, types.ClassType)) and isinstance(parent, type) and
issubclass(parent, TestCase)): issubclass(parent, TestCase)):
return TestSuite([parent(obj.__name__)]) return TestSuite([parent(obj.__name__)])
elif isinstance(obj, TestSuite): elif isinstance(obj, TestSuite):
@ -756,8 +762,10 @@ class TestLoader(object):
def getTestCaseNames(self, testCaseClass): def getTestCaseNames(self, testCaseClass):
"""Return a sorted sequence of method names found within testCaseClass """Return a sorted sequence of method names found within testCaseClass
""" """
def isTestMethod(attrname, testCaseClass=testCaseClass, prefix=self.testMethodPrefix): def isTestMethod(attrname, testCaseClass=testCaseClass,
return attrname.startswith(prefix) and hasattr(getattr(testCaseClass, attrname), '__call__') prefix=self.testMethodPrefix):
return attrname.startswith(prefix) and \
hasattr(getattr(testCaseClass, attrname), '__call__')
testFnNames = filter(isTestMethod, dir(testCaseClass)) testFnNames = filter(isTestMethod, dir(testCaseClass))
if self.sortTestMethodsUsing: if self.sortTestMethodsUsing:
testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing)) testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
@ -815,7 +823,7 @@ class _TextTestResult(TestResult):
separator2 = '-' * 70 separator2 = '-' * 70
def __init__(self, stream, descriptions, verbosity): def __init__(self, stream, descriptions, verbosity):
TestResult.__init__(self) super(_TextTestResult, self).__init__()
self.stream = stream self.stream = stream
self.showAll = verbosity > 1 self.showAll = verbosity > 1
self.dots = verbosity == 1 self.dots = verbosity == 1
@ -828,14 +836,14 @@ class _TextTestResult(TestResult):
return str(test) return str(test)
def startTest(self, test): def startTest(self, test):
TestResult.startTest(self, test) super(_TextTestResult, self).startTest(test)
if self.showAll: if self.showAll:
self.stream.write(self.getDescription(test)) self.stream.write(self.getDescription(test))
self.stream.write(" ... ") self.stream.write(" ... ")
self.stream.flush() self.stream.flush()
def addSuccess(self, test): def addSuccess(self, test):
TestResult.addSuccess(self, test) super(_TextTestResult, self).addSuccess(test)
if self.showAll: if self.showAll:
self.stream.writeln("ok") self.stream.writeln("ok")
elif self.dots: elif self.dots:
@ -843,7 +851,7 @@ class _TextTestResult(TestResult):
self.stream.flush() self.stream.flush()
def addError(self, test, err): def addError(self, test, err):
TestResult.addError(self, test, err) super(_TextTestResult, self).addError(test, err)
if self.showAll: if self.showAll:
self.stream.writeln("ERROR") self.stream.writeln("ERROR")
elif self.dots: elif self.dots:
@ -851,7 +859,7 @@ class _TextTestResult(TestResult):
self.stream.flush() self.stream.flush()
def addFailure(self, test, err): def addFailure(self, test, err):
TestResult.addFailure(self, test, err) super(_TextTestResult, self).addFailure(test, err)
if self.showAll: if self.showAll:
self.stream.writeln("FAIL") self.stream.writeln("FAIL")
elif self.dots: elif self.dots:
@ -859,7 +867,7 @@ class _TextTestResult(TestResult):
self.stream.flush() self.stream.flush()
def addSkip(self, test, reason): def addSkip(self, test, reason):
TestResult.addSkip(self, test, reason) super(_TextTestResult, self).addSkip(test, reason)
if self.showAll: if self.showAll:
self.stream.writeln("skipped {0!r}".format(reason)) self.stream.writeln("skipped {0!r}".format(reason))
elif self.dots: elif self.dots:
@ -867,7 +875,7 @@ class _TextTestResult(TestResult):
self.stream.flush() self.stream.flush()
def addExpectedFailure(self, test, err): def addExpectedFailure(self, test, err):
TestResult.addExpectedFailure(self, test, err) super(_TextTestResult, self).addExpectedFailure(test, err)
if self.showAll: if self.showAll:
self.stream.writeln("expected failure") self.stream.writeln("expected failure")
elif self.dots: elif self.dots:
@ -875,7 +883,7 @@ class _TextTestResult(TestResult):
self.stream.flush() self.stream.flush()
def addUnexpectedSuccess(self, test): def addUnexpectedSuccess(self, test):
TestResult.addUnexpectedSuccess(self, test) super(_TextTestResult, self).addUnexpectedSuccess(test)
if self.showAll: if self.showAll:
self.stream.writeln("unexpected success") self.stream.writeln("unexpected success")
elif self.dots: elif self.dots:
@ -936,13 +944,13 @@ class TextTestRunner(object):
if errored: if errored:
infos.append("errors=%d" % errored) infos.append("errors=%d" % errored)
else: else:
self.stream.write("OK") self.stream.writeln("OK")
if skipped: if skipped:
infos.append("skipped=%d" % skipped) infos.append("skipped=%d" % skipped)
if expected_fails: if expectedFails:
infos.append("expected failures=%d" % expected_fails) infos.append("expected failures=%d" % expectedFails)
if unexpected_successes: if unexpectedSuccesses:
infos.append("unexpected successes=%d" % unexpected_successes) infos.append("unexpected successes=%d" % unexpectedSuccesses)
if infos: if infos:
self.stream.writeln(" (%s)" % (", ".join(infos),)) self.stream.writeln(" (%s)" % (", ".join(infos),))
return result return result
@ -992,7 +1000,8 @@ Examples:
self.runTests() self.runTests()
def usageExit(self, msg=None): def usageExit(self, msg=None):
if msg: print msg if msg:
print msg
print self.USAGE % self.__dict__ print self.USAGE % self.__dict__
sys.exit(2) sys.exit(2)