From 5ffa325a82b89a7b79d779b32a4d9ad2ad91aadd Mon Sep 17 00:00:00 2001 From: Michael Foord Date: Sun, 7 Mar 2010 22:04:55 +0000 Subject: [PATCH] Addition of setUpClass and setUpModule shared fixtures to unittest. --- Lib/test/test_unittest.py | 396 +++++++++++++++++++++++++++++++++++++- Lib/unittest/__init__.py | 3 +- Lib/unittest/case.py | 11 ++ Lib/unittest/result.py | 2 + Lib/unittest/suite.py | 199 ++++++++++++++++++- 5 files changed, 600 insertions(+), 11 deletions(-) diff --git a/Lib/test/test_unittest.py b/Lib/test/test_unittest.py index be4811ccf51..1b486282e7a 100644 --- a/Lib/test/test_unittest.py +++ b/Lib/test/test_unittest.py @@ -22,6 +22,9 @@ import warnings ### Support code ################################################################ +def resultFactory(*_): + return unittest.TestResult() + class LoggingResult(unittest.TestResult): def __init__(self, log): self._events = log @@ -3937,6 +3940,397 @@ class TestDiscovery(TestCase): self.assertEqual(program.verbosity, 2) +class TestSetups(unittest.TestCase): + + def getRunner(self): + return unittest.TextTestRunner(resultclass=resultFactory, + stream=StringIO()) + def runTests(self, *cases): + suite = unittest.TestSuite() + for case in cases: + tests = unittest.defaultTestLoader.loadTestsFromTestCase(case) + suite.addTests(tests) + + runner = self.getRunner() + + # creating a nested suite exposes some potential bugs + realSuite = unittest.TestSuite() + realSuite.addTest(suite) + # adding empty suites to the end exposes potential bugs + suite.addTest(unittest.TestSuite()) + realSuite.addTest(unittest.TestSuite()) + return runner.run(realSuite) + + def test_setup_class(self): + class Test(unittest.TestCase): + setUpCalled = 0 + @classmethod + def setUpClass(cls): + Test.setUpCalled += 1 + unittest.TestCase.setUpClass() + def test_one(self): + pass + def test_two(self): + pass + + result = self.runTests(Test) + + self.assertEqual(Test.setUpCalled, 1) + self.assertEqual(result.testsRun, 2) + self.assertEqual(len(result.errors), 0) + + def test_teardown_class(self): + class Test(unittest.TestCase): + tearDownCalled = 0 + @classmethod + def tearDownClass(cls): + Test.tearDownCalled += 1 + unittest.TestCase.tearDownClass() + def test_one(self): + pass + def test_two(self): + pass + + result = self.runTests(Test) + + self.assertEqual(Test.tearDownCalled, 1) + self.assertEqual(result.testsRun, 2) + self.assertEqual(len(result.errors), 0) + + def test_teardown_class_two_classes(self): + class Test(unittest.TestCase): + tearDownCalled = 0 + @classmethod + def tearDownClass(cls): + Test.tearDownCalled += 1 + unittest.TestCase.tearDownClass() + def test_one(self): + pass + def test_two(self): + pass + + class Test2(unittest.TestCase): + tearDownCalled = 0 + @classmethod + def tearDownClass(cls): + Test2.tearDownCalled += 1 + unittest.TestCase.tearDownClass() + def test_one(self): + pass + def test_two(self): + pass + + result = self.runTests(Test, Test2) + + self.assertEqual(Test.tearDownCalled, 1) + self.assertEqual(Test2.tearDownCalled, 1) + self.assertEqual(result.testsRun, 4) + self.assertEqual(len(result.errors), 0) + + def test_error_in_setupclass(self): + class BrokenTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + raise TypeError('foo') + def test_one(self): + pass + def test_two(self): + pass + + result = self.runTests(BrokenTest) + + self.assertEqual(result.testsRun, 0) + self.assertEqual(len(result.errors), 1) + error, _ = result.errors[0] + self.assertEqual(str(error), + 'classSetUp (%s.BrokenTest)' % __name__) + + def test_error_in_teardown_class(self): + class Test(unittest.TestCase): + tornDown = 0 + @classmethod + def tearDownClass(cls): + Test.tornDown += 1 + raise TypeError('foo') + def test_one(self): + pass + def test_two(self): + pass + + class Test2(unittest.TestCase): + tornDown = 0 + @classmethod + def tearDownClass(cls): + Test2.tornDown += 1 + raise TypeError('foo') + def test_one(self): + pass + def test_two(self): + pass + + result = self.runTests(Test, Test2) + self.assertEqual(result.testsRun, 4) + self.assertEqual(len(result.errors), 2) + self.assertEqual(Test.tornDown, 1) + self.assertEqual(Test2.tornDown, 1) + + error, _ = result.errors[0] + self.assertEqual(str(error), + 'classTearDown (%s.Test)' % __name__) + + def test_class_not_torndown_when_setup_fails(self): + class Test(unittest.TestCase): + tornDown = False + @classmethod + def setUpClass(cls): + raise TypeError + @classmethod + def tearDownClass(cls): + Test.tornDown = True + raise TypeError('foo') + def test_one(self): + pass + + self.runTests(Test) + self.assertFalse(Test.tornDown) + + def test_class_not_setup_or_torndown_when_skipped(self): + class Test(unittest.TestCase): + classSetUp = False + tornDown = False + @classmethod + def setUpClass(cls): + Test.classSetUp = True + @classmethod + def tearDownClass(cls): + Test.tornDown = True + def test_one(self): + pass + + Test = unittest.skip("hop")(Test) + self.runTests(Test) + self.assertFalse(Test.classSetUp) + self.assertFalse(Test.tornDown) + + def test_setup_teardown_order_with_pathological_suite(self): + results = [] + + class Module1(object): + @staticmethod + def setUpModule(): + results.append('Module1.setUpModule') + @staticmethod + def tearDownModule(): + results.append('Module1.tearDownModule') + + class Module2(object): + @staticmethod + def setUpModule(): + results.append('Module2.setUpModule') + @staticmethod + def tearDownModule(): + results.append('Module2.tearDownModule') + + class Test1(unittest.TestCase): + @classmethod + def setUpClass(cls): + results.append('setup 1') + @classmethod + def tearDownClass(cls): + results.append('teardown 1') + def testOne(self): + results.append('Test1.testOne') + def testTwo(self): + results.append('Test1.testTwo') + + class Test2(unittest.TestCase): + @classmethod + def setUpClass(cls): + results.append('setup 2') + @classmethod + def tearDownClass(cls): + results.append('teardown 2') + def testOne(self): + results.append('Test2.testOne') + def testTwo(self): + results.append('Test2.testTwo') + + class Test3(unittest.TestCase): + @classmethod + def setUpClass(cls): + results.append('setup 3') + @classmethod + def tearDownClass(cls): + results.append('teardown 3') + def testOne(self): + results.append('Test3.testOne') + def testTwo(self): + results.append('Test3.testTwo') + + Test1.__module__ = Test2.__module__ = 'Module' + Test3.__module__ = 'Module2' + sys.modules['Module'] = Module1 + sys.modules['Module2'] = Module2 + + first = unittest.TestSuite((Test1('testOne'),)) + second = unittest.TestSuite((Test1('testTwo'),)) + third = unittest.TestSuite((Test2('testOne'),)) + fourth = unittest.TestSuite((Test2('testTwo'),)) + fifth = unittest.TestSuite((Test3('testOne'),)) + sixth = unittest.TestSuite((Test3('testTwo'),)) + suite = unittest.TestSuite((first, second, third, fourth, fifth, sixth)) + + runner = self.getRunner() + result = runner.run(suite) + self.assertEqual(result.testsRun, 6) + self.assertEqual(len(result.errors), 0) + + self.assertEqual(results, + ['Module1.setUpModule', 'setup 1', + 'Test1.testOne', 'Test1.testTwo', 'teardown 1', + 'setup 2', 'Test2.testOne', 'Test2.testTwo', + 'teardown 2', 'Module1.tearDownModule', + 'Module2.setUpModule', 'setup 3', + 'Test3.testOne', 'Test3.testTwo', + 'teardown 3', 'Module2.tearDownModule']) + + def test_setup_module(self): + class Module(object): + moduleSetup = 0 + @staticmethod + def setUpModule(): + Module.moduleSetup += 1 + + class Test(unittest.TestCase): + def test_one(self): + pass + def test_two(self): + pass + Test.__module__ = 'Module' + sys.modules['Module'] = Module + + result = self.runTests(Test) + self.assertEqual(Module.moduleSetup, 1) + self.assertEqual(result.testsRun, 2) + self.assertEqual(len(result.errors), 0) + + def test_error_in_setup_module(self): + class Module(object): + moduleSetup = 0 + moduleTornDown = 0 + @staticmethod + def setUpModule(): + Module.moduleSetup += 1 + raise TypeError('foo') + @staticmethod + def tearDownModule(): + Module.moduleTornDown += 1 + + class Test(unittest.TestCase): + classSetUp = False + classTornDown = False + @classmethod + def setUpClass(cls): + Test.classSetUp = True + @classmethod + def tearDownClass(cls): + Test.classTornDown = True + def test_one(self): + pass + def test_two(self): + pass + + class Test2(unittest.TestCase): + def test_one(self): + pass + def test_two(self): + pass + Test.__module__ = 'Module' + Test2.__module__ = 'Module' + sys.modules['Module'] = Module + + result = self.runTests(Test, Test2) + self.assertEqual(Module.moduleSetup, 1) + self.assertEqual(Module.moduleTornDown, 0) + self.assertEqual(result.testsRun, 0) + self.assertFalse(Test.classSetUp) + self.assertFalse(Test.classTornDown) + self.assertEqual(len(result.errors), 1) + error, _ = result.errors[0] + self.assertEqual(str(error), 'setUpModule (Module)') + + def test_testcase_with_missing_module(self): + class Test(unittest.TestCase): + def test_one(self): + pass + def test_two(self): + pass + Test.__module__ = 'Module' + sys.modules.pop('Module', None) + + result = self.runTests(Test) + self.assertEqual(result.testsRun, 2) + + def test_teardown_module(self): + class Module(object): + moduleTornDown = 0 + @staticmethod + def tearDownModule(): + Module.moduleTornDown += 1 + + class Test(unittest.TestCase): + def test_one(self): + pass + def test_two(self): + pass + Test.__module__ = 'Module' + sys.modules['Module'] = Module + + result = self.runTests(Test) + self.assertEqual(Module.moduleTornDown, 1) + self.assertEqual(result.testsRun, 2) + self.assertEqual(len(result.errors), 0) + + def test_error_in_teardown_module(self): + class Module(object): + moduleTornDown = 0 + @staticmethod + def tearDownModule(): + Module.moduleTornDown += 1 + raise TypeError('foo') + + class Test(unittest.TestCase): + classSetUp = False + classTornDown = False + @classmethod + def setUpClass(cls): + Test.classSetUp = True + @classmethod + def tearDownClass(cls): + Test.classTornDown = True + def test_one(self): + pass + def test_two(self): + pass + + class Test2(unittest.TestCase): + def test_one(self): + pass + def test_two(self): + pass + Test.__module__ = 'Module' + Test2.__module__ = 'Module' + sys.modules['Module'] = Module + + result = self.runTests(Test, Test2) + self.assertEqual(Module.moduleTornDown, 1) + self.assertEqual(result.testsRun, 4) + self.assertTrue(Test.classSetUp) + self.assertTrue(Test.classTornDown) + self.assertEqual(len(result.errors), 1) + error, _ = result.errors[0] + self.assertEqual(str(error), 'tearDownModule (Module)') + ###################################################################### ## Main ###################################################################### @@ -3946,7 +4340,7 @@ def test_main(): Test_TestSuite, Test_TestResult, Test_FunctionTestCase, Test_TestSkipping, Test_Assertions, TestLongMessage, Test_TestProgram, TestCleanUp, TestDiscovery, Test_TextTestRunner, - Test_OldTestResult) + Test_OldTestResult, TestSetups) if __name__ == "__main__": test_main() diff --git a/Lib/unittest/__init__.py b/Lib/unittest/__init__.py index 4a308fa1c46..06fe55d08d5 100644 --- a/Lib/unittest/__init__.py +++ b/Lib/unittest/__init__.py @@ -51,13 +51,12 @@ __all__ = ['TestResult', 'TestCase', 'TestSuite', # Expose obsolete functions for backwards compatibility __all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases']) -__all__.append('_TextTestResult') from .result import TestResult from .case import (TestCase, FunctionTestCase, SkipTest, skip, skipIf, skipUnless, expectedFailure) -from .suite import TestSuite +from .suite import BaseTestSuite, TestSuite from .loader import (TestLoader, defaultTestLoader, makeSuite, getTestCaseNames, findTestCases) from .main import TestProgram, main diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index 19b196c86fe..0bfcc757acf 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -153,6 +153,9 @@ class TestCase(object): longMessage = False + # Attribute used by TestSuite for classSetUp + + _classSetupFailed = False def __init__(self, methodName='runTest'): """Create an instance of the class that will use the named test @@ -211,6 +214,14 @@ class TestCase(object): "Hook method for deconstructing the test fixture after testing it." pass + @classmethod + def setUpClass(cls): + "Hook method for setting up class fixture before running tests in the class." + + @classmethod + def tearDownClass(cls): + "Hook method for deconstructing the class fixture after running all tests in the class." + def countTestCases(self): return 1 diff --git a/Lib/unittest/result.py b/Lib/unittest/result.py index 22e825ac61a..746967ec28a 100644 --- a/Lib/unittest/result.py +++ b/Lib/unittest/result.py @@ -16,6 +16,8 @@ class TestResult(object): contain tuples of (testcase, exceptioninfo), where exceptioninfo is the formatted traceback of the error that occurred. """ + _previousTestClass = None + _moduleSetUpFailed = False def __init__(self, stream=None, descriptions=None, verbosity=None): self.failures = [] self.errors = [] diff --git a/Lib/unittest/suite.py b/Lib/unittest/suite.py index 60e9b6c61da..cccc7efda46 100644 --- a/Lib/unittest/suite.py +++ b/Lib/unittest/suite.py @@ -1,17 +1,13 @@ """TestSuite""" +import sys + from . import case from . import util -class TestSuite(object): - """A test suite is a composite test consisting of a number of TestCases. - - For use, create an instance of TestSuite, then add test case instances. - When all tests have been added, the suite can be passed to a test - runner, such as TextTestRunner. It will run the individual test cases - in the order in which they were added, aggregating the results. When - subclassing, do not forget to call the base class constructor. +class BaseTestSuite(object): + """A simple test suite that doesn't provide class or module shared fixtures. """ def __init__(self, tests=()): self._tests = [] @@ -70,3 +66,190 @@ class TestSuite(object): """Run the tests without collecting errors in a TestResult""" for test in self: test.debug() + + +class TestSuite(BaseTestSuite): + """A test suite is a composite test consisting of a number of TestCases. + + For use, create an instance of TestSuite, then add test case instances. + When all tests have been added, the suite can be passed to a test + runner, such as TextTestRunner. It will run the individual test cases + in the order in which they were added, aggregating the results. When + subclassing, do not forget to call the base class constructor. + """ + + + def run(self, result): + self._wrapped_run(result) + self._tearDownPreviousClass(None, result) + self._handleModuleTearDown(result) + return result + + ################################ + # private methods + def _wrapped_run(self, result): + for test in self: + if result.shouldStop: + break + + if _isnotsuite(test): + self._tearDownPreviousClass(test, result) + self._handleModuleFixture(test, result) + self._handleClassSetUp(test, result) + result._previousTestClass = test.__class__ + + if (getattr(test.__class__, '_classSetupFailed', False) or + getattr(result, '_moduleSetUpFailed', False)): + continue + + if hasattr(test, '_wrapped_run'): + test._wrapped_run(result) + else: + test(result) + + def _handleClassSetUp(self, test, result): + previousClass = getattr(result, '_previousTestClass', None) + currentClass = test.__class__ + if currentClass == previousClass: + return + if result._moduleSetUpFailed: + return + if getattr(currentClass, "__unittest_skip__", False): + return + + currentClass._classSetupFailed = False + + setUpClass = getattr(currentClass, 'setUpClass', None) + if setUpClass is not None: + try: + setUpClass() + except: + currentClass._classSetupFailed = True + self._addClassSetUpError(result, currentClass) + + def _get_previous_module(self, result): + previousModule = None + previousClass = getattr(result, '_previousTestClass', None) + if previousClass is not None: + previousModule = previousClass.__module__ + return previousModule + + + def _handleModuleFixture(self, test, result): + previousModule = self._get_previous_module(result) + currentModule = test.__class__.__module__ + if currentModule == previousModule: + return + + self._handleModuleTearDown(result) + + + result._moduleSetUpFailed = False + try: + module = sys.modules[currentModule] + except KeyError: + return + setUpModule = getattr(module, 'setUpModule', None) + if setUpModule is not None: + try: + setUpModule() + except: + result._moduleSetUpFailed = True + error = _ErrorHolder('setUpModule (%s)' % currentModule) + result.addError(error, sys.exc_info()) + + def _handleModuleTearDown(self, result): + previousModule = self._get_previous_module(result) + if previousModule is None: + return + if result._moduleSetUpFailed: + return + + try: + module = sys.modules[previousModule] + except KeyError: + return + + tearDownModule = getattr(module, 'tearDownModule', None) + if tearDownModule is not None: + try: + tearDownModule() + except: + error = _ErrorHolder('tearDownModule (%s)' % previousModule) + result.addError(error, sys.exc_info()) + + def _tearDownPreviousClass(self, test, result): + previousClass = getattr(result, '_previousTestClass', None) + currentClass = test.__class__ + if currentClass == previousClass: + return + if getattr(previousClass, '_classSetupFailed', False): + return + if getattr(result, '_moduleSetUpFailed', False): + return + if getattr(previousClass, "__unittest_skip__", False): + return + + tearDownClass = getattr(previousClass, 'tearDownClass', None) + if tearDownClass is not None: + try: + tearDownClass() + except: + self._addClassTearDownError(result) + + def _addClassTearDownError(self, result): + className = util.strclass(result._previousTestClass) + error = _ErrorHolder('classTearDown (%s)' % className) + result.addError(error, sys.exc_info()) + + def _addClassSetUpError(self, result, klass): + className = util.strclass(klass) + error = _ErrorHolder('classSetUp (%s)' % className) + result.addError(error, sys.exc_info()) + + +class _ErrorHolder(object): + """ + Placeholder for a TestCase inside a result. As far as a TestResult + is concerned, this looks exactly like a unit test. Used to insert + arbitrary errors into a test suite run. + """ + # Inspired by the ErrorHolder from Twisted: + # http://twistedmatrix.com/trac/browser/trunk/twisted/trial/runner.py + + # attribute used by TestResult._exc_info_to_string + failureException = None + + def __init__(self, description): + self.description = description + + def id(self): + return self.description + + def shortDescription(self): + return None + + def __repr__(self): + return "" % (self.description,) + + def __str__(self): + return self.id() + + def run(self, result): + # could call result.addError(...) - but this test-like object + # shouldn't be run anyway + pass + + def __call__(self, result): + return self.run(result) + + def countTestCases(self): + return 0 + +def _isnotsuite(test): + "A crude way to tell apart testcases and suites with duck-typing" + try: + iter(test) + except TypeError: + return True + return False