diff --git a/Lib/unittest/main.py b/Lib/unittest/main.py index 2f488e191bf..45af6992237 100644 --- a/Lib/unittest/main.py +++ b/Lib/unittest/main.py @@ -70,10 +70,10 @@ class TestProgram(object): # defaults for testing failfast = catchbreak = buffer = None - def __init__(self, module='__main__', defaultTest=None, - argv=None, testRunner=None, - testLoader=loader.defaultTestLoader, exit=True, - verbosity=1, failfast=None, catchbreak=None, buffer=None): + def __init__(self, module='__main__', defaultTest=None, argv=None, + testRunner=None, testLoader=loader.defaultTestLoader, + exit=True, verbosity=1, failfast=None, catchbreak=None, + buffer=None): if isinstance(module, str): self.module = __import__(module) for part in module.split('.')[1:]: diff --git a/Lib/unittest/test/test_program.py b/Lib/unittest/test/test_program.py index b6a69dc28ad..752a0664b00 100644 --- a/Lib/unittest/test/test_program.py +++ b/Lib/unittest/test/test_program.py @@ -1,10 +1,27 @@ import io +import os import unittest class Test_TestProgram(unittest.TestCase): + def test_discovery_from_dotted_path(self): + loader = unittest.TestLoader() + + tests = [self] + expectedPath = os.path.abspath(os.path.dirname(unittest.test.__file__)) + + self.wasRun = False + def _find_tests(start_dir, pattern): + self.wasRun = True + self.assertEqual(start_dir, expectedPath) + return tests + loader._find_tests = _find_tests + suite = loader.discover('unittest.test') + self.assertTrue(self.wasRun) + self.assertEqual(suite._tests, tests) + # Horrible white box test def testNoExit(self): result = object()