From fa2f1cdcbb6cf2d31066605b0f71d8fe3ed337aa Mon Sep 17 00:00:00 2001 From: Michael Foord Date: Fri, 26 Mar 2010 03:18:31 +0000 Subject: [PATCH] Addition of -c command line option to unittest, to handle ctrl-c during a test run more elegantly --- Lib/unittest/__init__.py | 4 +- Lib/unittest/main.py | 61 ++++++-- Lib/unittest/runner.py | 2 + Lib/unittest/signals.py | 38 +++++ Lib/unittest/test/test_break.py | 225 ++++++++++++++++++++++++++++ Lib/unittest/test/test_discovery.py | 5 +- 6 files changed, 319 insertions(+), 16 deletions(-) create mode 100644 Lib/unittest/signals.py create mode 100644 Lib/unittest/test/test_break.py diff --git a/Lib/unittest/__init__.py b/Lib/unittest/__init__.py index 7153802eb7f..e84299e993a 100644 --- a/Lib/unittest/__init__.py +++ b/Lib/unittest/__init__.py @@ -47,7 +47,8 @@ SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. __all__ = ['TestResult', 'TestCase', 'TestSuite', 'TextTestRunner', 'TestLoader', 'FunctionTestCase', 'main', 'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless', - 'expectedFailure', 'TextTestResult'] + 'expectedFailure', 'TextTestResult', 'installHandler', + 'registerResult', 'removeResult'] # Expose obsolete functions for backwards compatibility __all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases']) @@ -62,6 +63,7 @@ from .loader import (TestLoader, defaultTestLoader, makeSuite, getTestCaseNames, findTestCases) from .main import TestProgram, main from .runner import TextTestRunner, TextTestResult +from .signals import installHandler, registerResult, removeResult # deprecated _TextTestResult = TextTestResult diff --git a/Lib/unittest/main.py b/Lib/unittest/main.py index d0da7c070a7..e2703ce7270 100644 --- a/Lib/unittest/main.py +++ b/Lib/unittest/main.py @@ -5,10 +5,14 @@ import os import types from . import loader, runner +from .signals import installHandler __unittest = True +FAILFAST = " -f, --failfast Stop on first failure\n" +CATCHBREAK = " -c, --catch Catch control-C and display results\n" + USAGE_AS_MAIN = """\ Usage: %(progName)s [options] [tests] @@ -16,8 +20,7 @@ Options: -h, --help Show this message -v, --verbose Verbose output -q, --quiet Minimal output - -f, --failfast Stop on first failure - +%(failfast)s%(catchbreak)s Examples: %(progName)s test_module - run tests from test_module %(progName)s test_module.TestClass - run tests from @@ -31,8 +34,7 @@ Alternative Usage: %(progName)s discover [options] Options: -v, --verbose Verbose output - -f, --failfast Stop on first failure - -s directory Directory to start discovery ('.' default) +%(failfast)s%(catchbreak)s -s directory Directory to start discovery ('.' default) -p pattern Pattern to match test files ('test*.py' default) -t directory Top level directory of project (default to start directory) @@ -48,8 +50,7 @@ Options: -h, --help Show this message -v, --verbose Verbose output -q, --quiet Minimal output - -f, --failfast Stop on first failure - +%(failfast)s%(catchbreak)s Examples: %(progName)s - run default set of tests %(progName)s MyTestSuite - run suite 'MyTestSuite' @@ -58,15 +59,21 @@ Examples: in MyTestCase """ + + class TestProgram(object): """A command-line program that runs a set of tests; this is primarily for making test modules conveniently executable. """ USAGE = USAGE_FROM_MODULE + + # defaults for testing + failfast = catchbreak = None + def __init__(self, module='__main__', defaultTest=None, argv=None, testRunner=None, testLoader=loader.defaultTestLoader, exit=True, - verbosity=1, failfast=False): + verbosity=1, failfast=None, catchbreak=None): if isinstance(module, basestring): self.module = __import__(module) for part in module.split('.')[1:]: @@ -78,6 +85,7 @@ class TestProgram(object): self.exit = exit self.failfast = failfast + self.catchbreak = catchbreak self.verbosity = verbosity self.defaultTest = defaultTest self.testRunner = testRunner @@ -89,7 +97,12 @@ class TestProgram(object): def usageExit(self, msg=None): if msg: print msg - print self.USAGE % self.__dict__ + usage = {'progName': self.progName, 'catchbreak': '', 'failfast': ''} + if self.failfast != False: + usage['failfast'] = FAILFAST + if self.catchbreak != False: + usage['catchbreak'] = CATCHBREAK + print self.USAGE % usage sys.exit(2) def parseArgs(self, argv): @@ -98,9 +111,9 @@ class TestProgram(object): return import getopt - long_opts = ['help', 'verbose', 'quiet', 'failfast'] + long_opts = ['help', 'verbose', 'quiet', 'failfast', 'catch'] try: - options, args = getopt.getopt(argv[1:], 'hHvqf', long_opts) + options, args = getopt.getopt(argv[1:], 'hHvqfc', long_opts) for opt, value in options: if opt in ('-h','-H','--help'): self.usageExit() @@ -109,7 +122,13 @@ class TestProgram(object): if opt in ('-v','--verbose'): self.verbosity = 2 if opt in ('-f','--failfast'): - self.failfast = True + if self.failfast is None: + self.failfast = True + # Should this raise an exception if -f is not valid? + if opt in ('-c','--catch'): + if self.catchbreak is None: + self.catchbreak = True + # Should this raise an exception if -c is not valid? if len(args) == 0 and self.defaultTest is None: # createTests will load tests from self.module self.testNames = None @@ -137,8 +156,14 @@ class TestProgram(object): parser = optparse.OptionParser() parser.add_option('-v', '--verbose', dest='verbose', default=False, help='Verbose output', action='store_true') - parser.add_option('-f', '--failfast', dest='failfast', default=False, - help='Stop on first fail or error', action='store_true') + if self.failfast != False: + parser.add_option('-f', '--failfast', dest='failfast', default=False, + help='Stop on first fail or error', + action='store_true') + if self.catchbreak != False: + parser.add_option('-c', '--catch', dest='catchbreak', default=False, + help='Catch ctrl-C and display results so far', + action='store_true') parser.add_option('-s', '--start-directory', dest='start', default='.', help="Directory to start discovery ('.' default)") parser.add_option('-p', '--pattern', dest='pattern', default='test*.py', @@ -153,7 +178,13 @@ class TestProgram(object): for name, value in zip(('start', 'pattern', 'top'), args): setattr(options, name, value) - self.failfast = options.failfast + # only set options from the parsing here + # if they weren't set explicitly in the constructor + if self.failfast is None: + self.failfast = options.failfast + if self.catchbreak is None: + self.catchbreak = options.catchbreak + if options.verbose: self.verbosity = 2 @@ -165,6 +196,8 @@ class TestProgram(object): self.test = loader.discover(start_dir, pattern, top_level_dir) def runTests(self): + if self.catchbreak: + installHandler() if self.testRunner is None: self.testRunner = runner.TextTestRunner if isinstance(self.testRunner, (type, types.ClassType)): diff --git a/Lib/unittest/runner.py b/Lib/unittest/runner.py index fd56a303499..5169d20ece0 100644 --- a/Lib/unittest/runner.py +++ b/Lib/unittest/runner.py @@ -4,6 +4,7 @@ import sys import time from . import result +from .signals import registerResult __unittest = True @@ -138,6 +139,7 @@ class TextTestRunner(object): def run(self, test): "Run the given test case or test suite." result = self._makeResult() + registerResult(result) result.failfast = self.failfast startTime = time.time() startTestRun = getattr(result, 'startTestRun', None) diff --git a/Lib/unittest/signals.py b/Lib/unittest/signals.py new file mode 100644 index 00000000000..0651cf2edea --- /dev/null +++ b/Lib/unittest/signals.py @@ -0,0 +1,38 @@ +import signal +import weakref + +__unittest = True + + +class _InterruptHandler(object): + def __init__(self, default_handler): + self.called = False + self.default_handler = default_handler + + def __call__(self, signum, frame): + installed_handler = signal.getsignal(signal.SIGINT) + if installed_handler is not self: + # if we aren't the installed handler, then delegate immediately + # to the default handler + self.default_handler(signum, frame) + + if self.called: + self.default_handler(signum, frame) + self.called = True + for result in _results.keys(): + result.stop() + +_results = weakref.WeakKeyDictionary() +def registerResult(result): + _results[result] = 1 + +def removeResult(result): + return bool(_results.pop(result, None)) + +_interrupt_handler = None +def installHandler(): + global _interrupt_handler + if _interrupt_handler is None: + default_handler = signal.getsignal(signal.SIGINT) + _interrupt_handler = _InterruptHandler(default_handler) + signal.signal(signal.SIGINT, _interrupt_handler) diff --git a/Lib/unittest/test/test_break.py b/Lib/unittest/test/test_break.py new file mode 100644 index 00000000000..0de31cd8643 --- /dev/null +++ b/Lib/unittest/test/test_break.py @@ -0,0 +1,225 @@ +import gc +import os +import signal +import weakref + +from cStringIO import StringIO + + +import unittest + + +@unittest.skipUnless(hasattr(os, 'kill'), "Test requires os.kill") +class TestBreak(unittest.TestCase): + + def setUp(self): + self._default_handler = signal.getsignal(signal.SIGINT) + + def tearDown(self): + signal.signal(signal.SIGINT, self._default_handler) + unittest.signals._results = weakref.WeakKeyDictionary() + unittest.signals._interrupt_handler = None + + + def testInstallHandler(self): + default_handler = signal.getsignal(signal.SIGINT) + unittest.installHandler() + self.assertNotEqual(signal.getsignal(signal.SIGINT), default_handler) + + try: + pid = os.getpid() + os.kill(pid, signal.SIGINT) + except KeyboardInterrupt: + self.fail("KeyboardInterrupt not handled") + + self.assertTrue(unittest.signals._interrupt_handler.called) + + def testRegisterResult(self): + result = unittest.TestResult() + unittest.registerResult(result) + + for ref in unittest.signals._results: + if ref is result: + break + elif ref is not result: + self.fail("odd object in result set") + else: + self.fail("result not found") + + + def testInterruptCaught(self): + default_handler = signal.getsignal(signal.SIGINT) + + result = unittest.TestResult() + unittest.installHandler() + unittest.registerResult(result) + + self.assertNotEqual(signal.getsignal(signal.SIGINT), default_handler) + + def test(result): + pid = os.getpid() + os.kill(pid, signal.SIGINT) + result.breakCaught = True + self.assertTrue(result.shouldStop) + + try: + test(result) + except KeyboardInterrupt: + self.fail("KeyboardInterrupt not handled") + self.assertTrue(result.breakCaught) + + + def testSecondInterrupt(self): + result = unittest.TestResult() + unittest.installHandler() + unittest.registerResult(result) + + def test(result): + pid = os.getpid() + os.kill(pid, signal.SIGINT) + result.breakCaught = True + self.assertTrue(result.shouldStop) + os.kill(pid, signal.SIGINT) + self.fail("Second KeyboardInterrupt not raised") + + try: + test(result) + except KeyboardInterrupt: + pass + else: + self.fail("Second KeyboardInterrupt not raised") + self.assertTrue(result.breakCaught) + + + def testTwoResults(self): + unittest.installHandler() + + result = unittest.TestResult() + unittest.registerResult(result) + new_handler = signal.getsignal(signal.SIGINT) + + result2 = unittest.TestResult() + unittest.registerResult(result2) + self.assertEqual(signal.getsignal(signal.SIGINT), new_handler) + + result3 = unittest.TestResult() + + def test(result): + pid = os.getpid() + os.kill(pid, signal.SIGINT) + + try: + test(result) + except KeyboardInterrupt: + self.fail("KeyboardInterrupt not handled") + + self.assertTrue(result.shouldStop) + self.assertTrue(result2.shouldStop) + self.assertFalse(result3.shouldStop) + + + def testHandlerReplacedButCalled(self): + # If our handler has been replaced (is no longer installed) but is + # called by the *new* handler, then it isn't safe to delay the + # SIGINT and we should immediately delegate to the default handler + unittest.installHandler() + + handler = signal.getsignal(signal.SIGINT) + def new_handler(frame, signum): + handler(frame, signum) + signal.signal(signal.SIGINT, new_handler) + + try: + pid = os.getpid() + os.kill(pid, signal.SIGINT) + except KeyboardInterrupt: + pass + else: + self.fail("replaced but delegated handler doesn't raise interrupt") + + def testRunner(self): + # Creating a TextTestRunner with the appropriate argument should + # register the TextTestResult it creates + runner = unittest.TextTestRunner(stream=StringIO()) + + result = runner.run(unittest.TestSuite()) + self.assertIn(result, unittest.signals._results) + + def testWeakReferences(self): + # Calling registerResult on a result should not keep it alive + result = unittest.TestResult() + unittest.registerResult(result) + + ref = weakref.ref(result) + del result + + # For non-reference counting implementations + gc.collect();gc.collect() + self.assertIsNone(ref()) + + + def testRemoveResult(self): + result = unittest.TestResult() + unittest.registerResult(result) + + unittest.installHandler() + self.assertTrue(unittest.removeResult(result)) + + # Should this raise an error instead? + self.assertFalse(unittest.removeResult(unittest.TestResult())) + + try: + pid = os.getpid() + os.kill(pid, signal.SIGINT) + except KeyboardInterrupt: + pass + + self.assertFalse(result.shouldStop) + + def testMainInstallsHandler(self): + failfast = object() + test = object() + verbosity = object() + result = object() + default_handler = signal.getsignal(signal.SIGINT) + + class FakeRunner(object): + initArgs = [] + runArgs = [] + def __init__(self, *args, **kwargs): + self.initArgs.append((args, kwargs)) + def run(self, test): + self.runArgs.append(test) + return result + + class Program(unittest.TestProgram): + def __init__(self, catchbreak): + self.exit = False + self.verbosity = verbosity + self.failfast = failfast + self.catchbreak = catchbreak + self.testRunner = FakeRunner + self.test = test + self.result = None + + p = Program(False) + p.runTests() + + self.assertEqual(FakeRunner.initArgs, [((), {'verbosity': verbosity, + 'failfast': failfast})]) + self.assertEqual(FakeRunner.runArgs, [test]) + self.assertEqual(p.result, result) + + self.assertEqual(signal.getsignal(signal.SIGINT), default_handler) + + FakeRunner.initArgs = [] + FakeRunner.runArgs = [] + p = Program(True) + p.runTests() + + self.assertEqual(FakeRunner.initArgs, [((), {'verbosity': verbosity, + 'failfast': failfast})]) + self.assertEqual(FakeRunner.runArgs, [test]) + self.assertEqual(p.result, result) + + self.assertNotEqual(signal.getsignal(signal.SIGINT), default_handler) diff --git a/Lib/unittest/test/test_discovery.py b/Lib/unittest/test/test_discovery.py index 2fa78095c79..0221fc2945b 100644 --- a/Lib/unittest/test/test_discovery.py +++ b/Lib/unittest/test/test_discovery.py @@ -279,14 +279,17 @@ class TestDiscovery(unittest.TestCase): self.assertEqual(program.test, 'tests') self.assertEqual(Loader.args, [('.', 'fish', None)]) self.assertFalse(program.failfast) + self.assertFalse(program.catchbreak) Loader.args = [] program = object.__new__(unittest.TestProgram) - program._do_discovery(['-p', 'eggs', '-s', 'fish', '-v', '-f'], Loader=Loader) + program._do_discovery(['-p', 'eggs', '-s', 'fish', '-v', '-f', '-c'], + Loader=Loader) self.assertEqual(program.test, 'tests') self.assertEqual(Loader.args, [('fish', 'eggs', None)]) self.assertEqual(program.verbosity, 2) self.assertTrue(program.failfast) + self.assertTrue(program.catchbreak) if __name__ == '__main__':