mirror of https://github.com/python/cpython
- Fixed loading of tests by name when name refers to unbound
method (PyUnit issue 563882, thanks to Alexandre Fayolle) - Ignore non-callable attributes of classes when searching for test method names (PyUnit issue 769338, thanks to Seth Falcon) - New assertTrue and assertFalse aliases for comfort of JUnit users - Automatically discover 'runTest()' test methods (PyUnit issue 469444, thanks to Roeland Rengelink) - Dropped Python 1.5.2 compatibility, merged appropriate shortcuts from Python CVS; should work with Python >= 2.1. - Removed all references to string module by using string methods instead
This commit is contained in:
parent
1e80359733
commit
7e74384af5
|
@ -27,7 +27,7 @@ Further information is available in the bundled documentation, and from
|
|||
|
||||
http://pyunit.sourceforge.net/
|
||||
|
||||
Copyright (c) 1999, 2000, 2001 Steve Purcell
|
||||
Copyright (c) 1999-2003 Steve Purcell
|
||||
This module is free software, and you may redistribute it and/or modify
|
||||
it under the same terms as Python itself, so long as this copyright message
|
||||
and disclaimer are retained in their original form.
|
||||
|
@ -46,12 +46,11 @@ SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
|
|||
|
||||
__author__ = "Steve Purcell"
|
||||
__email__ = "stephen_purcell at yahoo dot com"
|
||||
__version__ = "#Revision: 1.46 $"[11:-2]
|
||||
__version__ = "#Revision: 1.56 $"[11:-2]
|
||||
|
||||
import time
|
||||
import sys
|
||||
import traceback
|
||||
import string
|
||||
import os
|
||||
import types
|
||||
|
||||
|
@ -61,10 +60,26 @@ import types
|
|||
__all__ = ['TestResult', 'TestCase', 'TestSuite', 'TextTestRunner',
|
||||
'TestLoader', 'FunctionTestCase', 'main', 'defaultTestLoader']
|
||||
|
||||
# Expose obsolete functions for backwards compatability
|
||||
# Expose obsolete functions for backwards compatibility
|
||||
__all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases'])
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Backward compatibility
|
||||
##############################################################################
|
||||
if sys.version_info[:2] < (2, 2):
|
||||
False, True = 0, 1
|
||||
def isinstance(obj, clsinfo):
|
||||
import __builtin__
|
||||
if type(clsinfo) in (types.TupleType, types.ListType):
|
||||
for cls in clsinfo:
|
||||
if cls is type: cls = types.ClassType
|
||||
if __builtin__.isinstance(obj, cls):
|
||||
return 1
|
||||
return 0
|
||||
else: return __builtin__.isinstance(obj, clsinfo)
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Test framework core
|
||||
##############################################################################
|
||||
|
@ -121,11 +136,11 @@ class TestResult:
|
|||
|
||||
def stop(self):
|
||||
"Indicates that the tests should be aborted"
|
||||
self.shouldStop = 1
|
||||
self.shouldStop = True
|
||||
|
||||
def _exc_info_to_string(self, err):
|
||||
"""Converts a sys.exc_info()-style tuple of values into a string."""
|
||||
return string.join(traceback.format_exception(*err), '')
|
||||
return ''.join(traceback.format_exception(*err))
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s run=%i errors=%i failures=%i>" % \
|
||||
|
@ -196,7 +211,7 @@ class TestCase:
|
|||
the specified test method's docstring.
|
||||
"""
|
||||
doc = self.__testMethodDoc
|
||||
return doc and string.strip(string.split(doc, "\n")[0]) or None
|
||||
return doc and doc.split("\n")[0].strip() or None
|
||||
|
||||
def id(self):
|
||||
return "%s.%s" % (_strclass(self.__class__), self.__testMethodName)
|
||||
|
@ -209,9 +224,6 @@ class TestCase:
|
|||
(_strclass(self.__class__), self.__testMethodName)
|
||||
|
||||
def run(self, result=None):
|
||||
return self(result)
|
||||
|
||||
def __call__(self, result=None):
|
||||
if result is None: result = self.defaultTestResult()
|
||||
result.startTest(self)
|
||||
testMethod = getattr(self, self.__testMethodName)
|
||||
|
@ -224,10 +236,10 @@ class TestCase:
|
|||
result.addError(self, self.__exc_info())
|
||||
return
|
||||
|
||||
ok = 0
|
||||
ok = False
|
||||
try:
|
||||
testMethod()
|
||||
ok = 1
|
||||
ok = True
|
||||
except self.failureException:
|
||||
result.addFailure(self, self.__exc_info())
|
||||
except KeyboardInterrupt:
|
||||
|
@ -241,11 +253,13 @@ class TestCase:
|
|||
raise
|
||||
except:
|
||||
result.addError(self, self.__exc_info())
|
||||
ok = 0
|
||||
ok = False
|
||||
if ok: result.addSuccess(self)
|
||||
finally:
|
||||
result.stopTest(self)
|
||||
|
||||
__call__ = run
|
||||
|
||||
def debug(self):
|
||||
"""Run the test without collecting errors in a TestResult"""
|
||||
self.setUp()
|
||||
|
@ -292,7 +306,7 @@ class TestCase:
|
|||
else:
|
||||
if hasattr(excClass,'__name__'): excName = excClass.__name__
|
||||
else: excName = str(excClass)
|
||||
raise self.failureException, excName
|
||||
raise self.failureException, "%s not raised" % excName
|
||||
|
||||
def failUnlessEqual(self, first, second, msg=None):
|
||||
"""Fail if the two objects are unequal as determined by the '=='
|
||||
|
@ -334,6 +348,8 @@ class TestCase:
|
|||
raise self.failureException, \
|
||||
(msg or '%s == %s within %s places' % (`first`, `second`, `places`))
|
||||
|
||||
# Synonyms for assertion methods
|
||||
|
||||
assertEqual = assertEquals = failUnlessEqual
|
||||
|
||||
assertNotEqual = assertNotEquals = failIfEqual
|
||||
|
@ -344,7 +360,9 @@ class TestCase:
|
|||
|
||||
assertRaises = failUnlessRaises
|
||||
|
||||
assert_ = failUnless
|
||||
assert_ = assertTrue = failUnless
|
||||
|
||||
assertFalse = failIf
|
||||
|
||||
|
||||
|
||||
|
@ -369,7 +387,7 @@ class TestSuite:
|
|||
def countTestCases(self):
|
||||
cases = 0
|
||||
for test in self._tests:
|
||||
cases = cases + test.countTestCases()
|
||||
cases += test.countTestCases()
|
||||
return cases
|
||||
|
||||
def addTest(self, test):
|
||||
|
@ -434,7 +452,7 @@ class FunctionTestCase(TestCase):
|
|||
def shortDescription(self):
|
||||
if self.__description is not None: return self.__description
|
||||
doc = self.__testFunc.__doc__
|
||||
return doc and string.strip(string.split(doc, "\n")[0]) or None
|
||||
return doc and doc.split("\n")[0].strip() or None
|
||||
|
||||
|
||||
|
||||
|
@ -452,8 +470,10 @@ class TestLoader:
|
|||
|
||||
def loadTestsFromTestCase(self, testCaseClass):
|
||||
"""Return a suite of all tests cases contained in testCaseClass"""
|
||||
return self.suiteClass(map(testCaseClass,
|
||||
self.getTestCaseNames(testCaseClass)))
|
||||
testCaseNames = self.getTestCaseNames(testCaseClass)
|
||||
if not testCaseNames and hasattr(testCaseClass, 'runTest'):
|
||||
testCaseNames = ['runTest']
|
||||
return self.suiteClass(map(testCaseClass, testCaseNames))
|
||||
|
||||
def loadTestsFromModule(self, module):
|
||||
"""Return a suite of all tests cases contained in the given module"""
|
||||
|
@ -474,23 +494,20 @@ class TestLoader:
|
|||
|
||||
The method optionally resolves the names relative to a given module.
|
||||
"""
|
||||
parts = string.split(name, '.')
|
||||
parts = name.split('.')
|
||||
if module is None:
|
||||
if not parts:
|
||||
raise ValueError, "incomplete test name: %s" % name
|
||||
else:
|
||||
parts_copy = parts[:]
|
||||
while parts_copy:
|
||||
try:
|
||||
module = __import__(string.join(parts_copy,'.'))
|
||||
break
|
||||
except ImportError:
|
||||
del parts_copy[-1]
|
||||
if not parts_copy: raise
|
||||
parts_copy = parts[:]
|
||||
while parts_copy:
|
||||
try:
|
||||
module = __import__('.'.join(parts_copy))
|
||||
break
|
||||
except ImportError:
|
||||
del parts_copy[-1]
|
||||
if not parts_copy: raise
|
||||
parts = parts[1:]
|
||||
obj = module
|
||||
for part in parts:
|
||||
obj = getattr(obj, part)
|
||||
parent, obj = obj, getattr(obj, part)
|
||||
|
||||
import unittest
|
||||
if type(obj) == types.ModuleType:
|
||||
|
@ -499,11 +516,13 @@ class TestLoader:
|
|||
issubclass(obj, unittest.TestCase)):
|
||||
return self.loadTestsFromTestCase(obj)
|
||||
elif type(obj) == types.UnboundMethodType:
|
||||
return parent(obj.__name__)
|
||||
return obj.im_class(obj.__name__)
|
||||
elif isinstance(obj, unittest.TestSuite):
|
||||
return obj
|
||||
elif callable(obj):
|
||||
test = obj()
|
||||
if not isinstance(test, unittest.TestCase) and \
|
||||
not isinstance(test, unittest.TestSuite):
|
||||
if not isinstance(test, (unittest.TestCase, unittest.TestSuite)):
|
||||
raise ValueError, \
|
||||
"calling %s returned %s, not a test" % (obj,test)
|
||||
return test
|
||||
|
@ -514,16 +533,15 @@ class TestLoader:
|
|||
"""Return a suite of all tests cases found using the given sequence
|
||||
of string specifiers. See 'loadTestsFromName()'.
|
||||
"""
|
||||
suites = []
|
||||
for name in names:
|
||||
suites.append(self.loadTestsFromName(name, module))
|
||||
suites = [self.loadTestsFromName(name, module) for name in names]
|
||||
return self.suiteClass(suites)
|
||||
|
||||
def getTestCaseNames(self, testCaseClass):
|
||||
"""Return a sorted sequence of method names found within testCaseClass
|
||||
"""
|
||||
testFnNames = filter(lambda n,p=self.testMethodPrefix: n[:len(p)] == p,
|
||||
dir(testCaseClass))
|
||||
def isTestMethod(attrname, testCaseClass=testCaseClass, prefix=self.testMethodPrefix):
|
||||
return attrname[:len(prefix)] == prefix and callable(getattr(testCaseClass, attrname))
|
||||
testFnNames = filter(isTestMethod, dir(testCaseClass))
|
||||
for baseclass in testCaseClass.__bases__:
|
||||
for testFnName in self.getTestCaseNames(baseclass):
|
||||
if testFnName not in testFnNames: # handle overridden methods
|
||||
|
@ -706,7 +724,7 @@ Examples:
|
|||
argv=None, testRunner=None, testLoader=defaultTestLoader):
|
||||
if type(module) == type(''):
|
||||
self.module = __import__(module)
|
||||
for part in string.split(module,'.')[1:]:
|
||||
for part in module.split('.')[1:]:
|
||||
self.module = getattr(self.module, part)
|
||||
else:
|
||||
self.module = module
|
||||
|
|
Loading…
Reference in New Issue