Close #22457: Honour load_tests in the start_dir of discovery.

We were not honouring load_tests in a package/__init__.py when that was the
start_dir parameter, though we do when it is a child package. The fix required
a little care since it introduces the possibility of infinite recursion.
This commit is contained in:
Robert Collins 2014-11-05 03:09:01 +13:00
parent d39e199a0d
commit bf2bda3c97
6 changed files with 166 additions and 58 deletions

View File

@ -1668,7 +1668,11 @@ Loading and running tests
If a package (a directory containing a file named :file:`__init__.py`) is
found, the package will be checked for a ``load_tests`` function. If this
exists then it will be called with *loader*, *tests*, *pattern*.
exists then it will be called
``package.load_tests(loader, tests, pattern)``. Test discovery takes care
to ensure that a package is only checked for tests once during an
invocation, even if the load_tests function itself calls
``loader.discover``.
If ``load_tests`` exists then discovery does *not* recurse into the
package, ``load_tests`` is responsible for loading all tests in the

View File

@ -67,3 +67,12 @@ from .signals import installHandler, registerResult, removeResult, removeHandler
# deprecated
_TextTestResult = TextTestResult
# There are no tests here, so don't try to run anything discovered from
# introspecting the symbols (e.g. FunctionTestCase). Instead, all our
# tests come from within unittest.test.
def load_tests(loader, tests, pattern):
import os.path
# top level directory cached on loader instance
this_dir = os.path.dirname(__file__)
return loader.discover(start_dir=this_dir, pattern=pattern)

View File

@ -65,6 +65,9 @@ class TestLoader(object):
def __init__(self):
super(TestLoader, self).__init__()
self.errors = []
# Tracks packages which we have called into via load_tests, to
# avoid infinite re-entrancy.
self._loading_packages = set()
def loadTestsFromTestCase(self, testCaseClass):
"""Return a suite of all tests cases contained in testCaseClass"""
@ -229,9 +232,13 @@ class TestLoader(object):
If a test package name (directory with '__init__.py') matches the
pattern then the package will be checked for a 'load_tests' function. If
this exists then it will be called with loader, tests, pattern.
this exists then it will be called with (loader, tests, pattern) unless
the package has already had load_tests called from the same discovery
invocation, in which case the package module object is not scanned for
tests - this ensures that when a package uses discover to further
discover child tests that infinite recursion does not happen.
If load_tests exists then discovery does *not* recurse into the package,
If load_tests exists then discovery does *not* recurse into the package,
load_tests is responsible for loading all tests in the package.
The pattern is deliberately not stored as a loader attribute so that
@ -355,69 +362,110 @@ class TestLoader(object):
def _find_tests(self, start_dir, pattern, namespace=False):
"""Used by discovery. Yields test suites it loads."""
# Handle the __init__ in this package
name = self._get_name_from_path(start_dir)
# name is '.' when start_dir == top_level_dir (and top_level_dir is by
# definition not a package).
if name != '.' and name not in self._loading_packages:
# name is in self._loading_packages while we have called into
# loadTestsFromModule with name.
tests, should_recurse = self._find_test_path(
start_dir, pattern, namespace)
if tests is not None:
yield tests
if not should_recurse:
# Either an error occured, or load_tests was used by the
# package.
return
# Handle the contents.
paths = sorted(os.listdir(start_dir))
for path in paths:
full_path = os.path.join(start_dir, path)
if os.path.isfile(full_path):
if not VALID_MODULE_NAME.match(path):
# valid Python identifiers only
continue
if not self._match_path(path, full_path, pattern):
continue
# if the test file matches, load it
tests, should_recurse = self._find_test_path(
full_path, pattern, namespace)
if tests is not None:
yield tests
if should_recurse:
# we found a package that didn't use load_tests.
name = self._get_name_from_path(full_path)
self._loading_packages.add(name)
try:
module = self._get_module_from_name(name)
except case.SkipTest as e:
yield _make_skipped_test(name, e, self.suiteClass)
except:
error_case, error_message = \
_make_failed_import_test(name, self.suiteClass)
self.errors.append(error_message)
yield error_case
else:
mod_file = os.path.abspath(getattr(module, '__file__', full_path))
realpath = _jython_aware_splitext(os.path.realpath(mod_file))
fullpath_noext = _jython_aware_splitext(os.path.realpath(full_path))
if realpath.lower() != fullpath_noext.lower():
module_dir = os.path.dirname(realpath)
mod_name = _jython_aware_splitext(os.path.basename(full_path))
expected_dir = os.path.dirname(full_path)
msg = ("%r module incorrectly imported from %r. Expected %r. "
"Is this module globally installed?")
raise ImportError(msg % (mod_name, module_dir, expected_dir))
yield self.loadTestsFromModule(module, pattern=pattern)
elif os.path.isdir(full_path):
if (not namespace and
not os.path.isfile(os.path.join(full_path, '__init__.py'))):
continue
yield from self._find_tests(full_path, pattern, namespace)
finally:
self._loading_packages.discard(name)
load_tests = None
tests = None
name = self._get_name_from_path(full_path)
def _find_test_path(self, full_path, pattern, namespace=False):
"""Used by discovery.
Loads tests from a single file, or a directories' __init__.py when
passed the directory.
Returns a tuple (None_or_tests_from_file, should_recurse).
"""
basename = os.path.basename(full_path)
if os.path.isfile(full_path):
if not VALID_MODULE_NAME.match(basename):
# valid Python identifiers only
return None, False
if not self._match_path(basename, full_path, pattern):
return None, False
# if the test file matches, load it
name = self._get_name_from_path(full_path)
try:
module = self._get_module_from_name(name)
except case.SkipTest as e:
return _make_skipped_test(name, e, self.suiteClass), False
except:
error_case, error_message = \
_make_failed_import_test(name, self.suiteClass)
self.errors.append(error_message)
return error_case, False
else:
mod_file = os.path.abspath(
getattr(module, '__file__', full_path))
realpath = _jython_aware_splitext(
os.path.realpath(mod_file))
fullpath_noext = _jython_aware_splitext(
os.path.realpath(full_path))
if realpath.lower() != fullpath_noext.lower():
module_dir = os.path.dirname(realpath)
mod_name = _jython_aware_splitext(
os.path.basename(full_path))
expected_dir = os.path.dirname(full_path)
msg = ("%r module incorrectly imported from %r. Expected "
"%r. Is this module globally installed?")
raise ImportError(
msg % (mod_name, module_dir, expected_dir))
return self.loadTestsFromModule(module, pattern=pattern), False
elif os.path.isdir(full_path):
if (not namespace and
not os.path.isfile(os.path.join(full_path, '__init__.py'))):
return None, False
load_tests = None
tests = None
name = self._get_name_from_path(full_path)
try:
package = self._get_module_from_name(name)
except case.SkipTest as e:
return _make_skipped_test(name, e, self.suiteClass), False
except:
error_case, error_message = \
_make_failed_import_test(name, self.suiteClass)
self.errors.append(error_message)
return error_case, False
else:
load_tests = getattr(package, 'load_tests', None)
# Mark this package as being in load_tests (possibly ;))
self._loading_packages.add(name)
try:
package = self._get_module_from_name(name)
except case.SkipTest as e:
yield _make_skipped_test(name, e, self.suiteClass)
except:
error_case, error_message = \
_make_failed_import_test(name, self.suiteClass)
self.errors.append(error_message)
yield error_case
else:
load_tests = getattr(package, 'load_tests', None)
tests = self.loadTestsFromModule(package, pattern=pattern)
if tests is not None:
# tests loaded from package file
yield tests
if load_tests is not None:
# loadTestsFromModule(package) has load_tests for us.
continue
# recurse into the package
yield from self._find_tests(full_path, pattern,
namespace=namespace)
# loadTestsFromModule(package) has loaded tests for us.
return tests, False
return tests, True
finally:
self._loading_packages.discard(name)
defaultTestLoader = TestLoader()

View File

@ -368,6 +368,51 @@ class TestDiscovery(unittest.TestCase):
self.assertEqual(_find_tests_args, [(start_dir, 'pattern')])
self.assertIn(top_level_dir, sys.path)
def test_discover_start_dir_is_package_calls_package_load_tests(self):
# This test verifies that the package load_tests in a package is indeed
# invoked when the start_dir is a package (and not the top level).
# http://bugs.python.org/issue22457
# Test data: we expect the following:
# an isfile to verify the package, then importing and scanning
# as per _find_tests' normal behaviour.
# We expect to see our load_tests hook called once.
vfs = {abspath('/toplevel'): ['startdir'],
abspath('/toplevel/startdir'): ['__init__.py']}
def list_dir(path):
return list(vfs[path])
self.addCleanup(setattr, os, 'listdir', os.listdir)
os.listdir = list_dir
self.addCleanup(setattr, os.path, 'isfile', os.path.isfile)
os.path.isfile = lambda path: path.endswith('.py')
self.addCleanup(setattr, os.path, 'isdir', os.path.isdir)
os.path.isdir = lambda path: not path.endswith('.py')
self.addCleanup(sys.path.remove, abspath('/toplevel'))
class Module(object):
paths = []
load_tests_args = []
def __init__(self, path):
self.path = path
def load_tests(self, loader, tests, pattern):
return ['load_tests called ' + self.path]
def __eq__(self, other):
return self.path == other.path
loader = unittest.TestLoader()
loader._get_module_from_name = lambda name: Module(name)
loader.suiteClass = lambda thing: thing
suite = loader.discover('/toplevel/startdir', top_level_dir='/toplevel')
# We should have loaded tests from the package __init__.
# (normally this would be nested TestSuites.)
self.assertEqual(suite,
[['load_tests called startdir']])
def setup_import_issue_tests(self, fakefile):
listdir = os.listdir
os.listdir = lambda _: [fakefile]

View File

@ -841,7 +841,7 @@ class Test_TestLoader(unittest.TestCase):
loader = unittest.TestLoader()
suite = loader.loadTestsFromNames(
['unittest.loader.sdasfasfasdf', 'unittest'])
['unittest.loader.sdasfasfasdf', 'unittest.test.dummy'])
error, test = self.check_deferred_error(loader, list(suite)[0])
expected = "module 'unittest.loader' has no attribute 'sdasfasfasdf'"
self.assertIn(

View File

@ -212,6 +212,8 @@ Library
- Issue #22217: Implemented reprs of classes in the zipfile module.
- Issue #22457: Honour load_tests in the start_dir of discovery.
- Issue #18216: gettext now raises an error when a .mo file has an
unsupported major version number. Patch by Aaron Hill.