diff --git a/Lib/unittest/test/test_suite.py b/Lib/unittest/test/test_suite.py index ab2e16eb99a..47b57de3b74 100644 --- a/Lib/unittest/test/test_suite.py +++ b/Lib/unittest/test/test_suite.py @@ -1,5 +1,6 @@ import unittest +import sys from .support import LoggingResult, TestEquality @@ -300,5 +301,49 @@ class Test_TestSuite(unittest.TestCase, TestEquality): suite.run(unittest.TestResult()) + + def test_basetestsuite(self): + class Test(unittest.TestCase): + wasSetUp = False + wasTornDown = False + @classmethod + def setUpClass(cls): + cls.wasSetUp = True + @classmethod + def tearDownClass(cls): + cls.wasTornDown = True + def testPass(self): + pass + def testFail(self): + fail + class Module(object): + wasSetUp = False + wasTornDown = False + @staticmethod + def setUpModule(): + Module.wasSetUp = True + @staticmethod + def tearDownModule(): + Module.wasTornDown = True + + Test.__module__ = 'Module' + sys.modules['Module'] = Module + self.addCleanup(sys.modules.pop, 'Module') + + suite = unittest.BaseTestSuite() + suite.addTests([Test('testPass'), Test('testFail')]) + self.assertEqual(suite.countTestCases(), 2) + + result = unittest.TestResult() + suite.run(result) + self.assertFalse(Module.wasSetUp) + self.assertFalse(Module.wasTornDown) + self.assertFalse(Test.wasSetUp) + self.assertFalse(Test.wasTornDown) + self.assertEqual(len(result.errors), 1) + self.assertEqual(len(result.failures), 0) + self.assertEqual(result.testsRun, 2) + + if __name__ == '__main__': unittest.main()