Merged revisions 66321 via svnmerge from

svn+ssh://pythondev@svn.python.org/python/trunk

........
  r66321 | brett.cannon | 2008-09-08 17:49:16 -0700 (Mon, 08 Sep 2008) | 7 lines

  warnings.catch_warnings() now returns a list or None instead of the custom
  WarningsRecorder object. This makes the API simpler to use as no special object
  must be learned.

  Closes issue 3781.
  Review by Benjamin Peterson.
........
This commit is contained in:
Brett Cannon 2008-09-09 01:52:27 +00:00
parent 4c19e6e02d
commit 1cd0247a4d
10 changed files with 193 additions and 137 deletions

View File

@ -160,6 +160,67 @@ ImportWarning can also be enabled explicitly in Python code using::
warnings.simplefilter('default', ImportWarning) warnings.simplefilter('default', ImportWarning)
.. _warning-suppress:
Temporarily Suppressing Warnings
--------------------------------
If you are using code that you know will raise a warning, such some deprecated
function, but do not want to see the warning, then suppress the warning using
the :class:`catch_warnings` context manager::
import warnings
def fxn():
warnings.warn("deprecated", DeprecationWarning)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
fxn()
While within the context manager all warnings will simply be ignored. This
allows you to use known-deprecated code without having to see the warning while
not suppressing the warning for other code that might not be aware of its use
of deprecated code.
.. _warning-testing:
Testing Warnings
----------------
To test warnings raised by code, use the :class:`catch_warnings` context
manager. With it you can temporarily mutate the warnings filter to facilitate
your testing. For instance, do the following to capture all raised warnings to
check::
import warnings
def fxn():
warnings.warn("deprecated", DeprecationWarning)
with warnings.catch_warnings(record=True) as w:
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
# Trigger a warning.
fxn()
# Verify some things
assert len(w) == 1
assert isinstance(w[-1].category, DeprecationWarning)
assert "deprecated" in str(w[-1].message)
One can also cause all warnings to be exceptions by using ``error`` instead of
``always``. One thing to be aware of is that if a warning has already been
raised because of a ``once``/``default`` rule, then no matter what filters are
set the warning will not be seen again unless the warnings registry related to
the warning has been cleared.
Once the context manager exits, the warnings filter is restored to its state
when the context was entered. This prevents tests from changing the warnings
filter in unexpected ways between tests and leading to indeterminate test
results.
.. _warning-functions: .. _warning-functions:
Available Functions Available Functions
@ -248,31 +309,22 @@ Available Functions
and calls to :func:`simplefilter`. and calls to :func:`simplefilter`.
Available Classes Available Context Managers
----------------- --------------------------
.. class:: catch_warnings([\*, record=False, module=None]) .. class:: catch_warnings([\*, record=False, module=None])
A context manager that guards the warnings filter from being permanently A context manager that copies and, upon exit, restores the warnings filter.
mutated. The manager returns an instance of :class:`WarningsRecorder`. The If the *record* argument is False (the default) the context manager returns
*record* argument specifies whether warnings that would typically be :class:`None`. If *record* is true, a list is returned that is populated
handled by :func:`showwarning` should instead be recorded by the with objects as seen by a custom :func:`showwarning` function (which also
:class:`WarningsRecorder` instance. This argument is typically set when suppresses output to ``sys.stdout``). Each object has attributes with the
testing for expected warnings behavior. The *module* argument may be a same names as the arguments to :func:`showwarning`.
module object that is to be used instead of the :mod:`warnings` module.
This argument should only be set when testing the :mod:`warnings` module
or some similar use-case.
Typical usage of the context manager is like so:: The *module* argument takes a module that will be used instead of the
module returned when you import :mod:`warnings` whose filter will be
def fxn(): protected. This arguments exists primarily for testing the :mod:`warnings`
warn("fxn is deprecated", DeprecationWarning) module itself.
return "spam spam bacon spam"
# The function 'fxn' is known to raise a DeprecationWarning.
with catch_warnings() as w:
warnings.filterwarning('ignore', 'fxn is deprecated', DeprecationWarning)
fxn() # DeprecationWarning is temporarily suppressed.
.. versionadded:: 2.6 .. versionadded:: 2.6
@ -280,19 +332,3 @@ Available Classes
Constructor arguments turned into keyword-only arguments. Constructor arguments turned into keyword-only arguments.
.. class:: WarningsRecorder()
A subclass of :class:`list` that stores all warnings passed to
:func:`showwarning` when returned by a :class:`catch_warnings` context
manager created with its *record* argument set to ``True``. Each recorded
warning is represented by an object whose attributes correspond to the
arguments to :func:`showwarning`. As a convenience, a
:class:`WarningsRecorder` instance has the attributes of the last
recorded warning set on the :class:`WarningsRecorder` instance as well.
.. method:: reset()
Delete all recorded warnings.
.. versionadded:: 2.6

View File

@ -1,5 +1,5 @@
import unittest import unittest
from test.support import run_unittest, catch_warning from test.support import run_unittest
import sys import sys
import warnings import warnings
@ -8,7 +8,7 @@ class AllTest(unittest.TestCase):
def check_all(self, modname): def check_all(self, modname):
names = {} names = {}
with catch_warning(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".* (module|package)", warnings.filterwarnings("ignore", ".* (module|package)",
DeprecationWarning) DeprecationWarning)
try: try:

View File

@ -211,7 +211,7 @@ class TestVectorsTestCase(unittest.TestCase):
def digest(self): def digest(self):
return self._x.digest() return self._x.digest()
with support.catch_warning(): with warnings.catch_warnings():
warnings.simplefilter('error', RuntimeWarning) warnings.simplefilter('error', RuntimeWarning)
try: try:
hmac.HMAC(b'a', b'b', digestmod=MockCrazyHash) hmac.HMAC(b'a', b'b', digestmod=MockCrazyHash)

View File

@ -6,7 +6,7 @@ import sys
import py_compile import py_compile
import warnings import warnings
import imp import imp
from test.support import unlink, TESTFN, unload, run_unittest, catch_warning from test.support import unlink, TESTFN, unload, run_unittest
def remove_files(name): def remove_files(name):
@ -153,7 +153,7 @@ class ImportTest(unittest.TestCase):
self.assert_(y is test.support, y.__name__) self.assert_(y is test.support, y.__name__)
def test_import_initless_directory_warning(self): def test_import_initless_directory_warning(self):
with catch_warning(): with warnings.catch_warnings():
# Just a random non-package directory we always expect to be # Just a random non-package directory we always expect to be
# somewhere in sys.path... # somewhere in sys.path...
warnings.simplefilter('error', ImportWarning) warnings.simplefilter('error', ImportWarning)

View File

@ -1,7 +1,7 @@
import sys import sys
sys.path = ['.'] + sys.path sys.path = ['.'] + sys.path
from test.support import verbose, run_unittest, catch_warning from test.support import verbose, run_unittest
import re import re
from re import Scanner from re import Scanner
import sys, os, traceback import sys, os, traceback

View File

@ -4,7 +4,7 @@ import struct
import warnings import warnings
from functools import wraps from functools import wraps
from test.support import TestFailed, verbose, run_unittest, catch_warning from test.support import TestFailed, verbose, run_unittest
import sys import sys
ISBIGENDIAN = sys.byteorder == "big" ISBIGENDIAN = sys.byteorder == "big"
@ -34,7 +34,7 @@ def bigendian_to_native(value):
def with_warning_restore(func): def with_warning_restore(func):
@wraps(func) @wraps(func)
def decorator(*args, **kw): def decorator(*args, **kw):
with catch_warning(): with warnings.catch_warnings():
# We need this function to warn every time, so stick an # We need this function to warn every time, so stick an
# unqualifed 'always' at the head of the filter list # unqualifed 'always' at the head of the filter list
warnings.simplefilter("always") warnings.simplefilter("always")

View File

@ -66,35 +66,35 @@ class ReadWriteTests(unittest.TestCase):
class TestWarnings(unittest.TestCase): class TestWarnings(unittest.TestCase):
def has_warned(self, w): def has_warned(self, w):
self.assertEqual(w.category, RuntimeWarning) self.assertEqual(w[-1].category, RuntimeWarning)
def test_byte_max(self): def test_byte_max(self):
with support.catch_warning() as w: with warnings.catch_warnings(record=True) as w:
ts.T_BYTE = CHAR_MAX+1 ts.T_BYTE = CHAR_MAX+1
self.has_warned(w) self.has_warned(w)
def test_byte_min(self): def test_byte_min(self):
with support.catch_warning() as w: with warnings.catch_warnings(record=True) as w:
ts.T_BYTE = CHAR_MIN-1 ts.T_BYTE = CHAR_MIN-1
self.has_warned(w) self.has_warned(w)
def test_ubyte_max(self): def test_ubyte_max(self):
with support.catch_warning() as w: with warnings.catch_warnings(record=True) as w:
ts.T_UBYTE = UCHAR_MAX+1 ts.T_UBYTE = UCHAR_MAX+1
self.has_warned(w) self.has_warned(w)
def test_short_max(self): def test_short_max(self):
with support.catch_warning() as w: with warnings.catch_warnings(record=True) as w:
ts.T_SHORT = SHRT_MAX+1 ts.T_SHORT = SHRT_MAX+1
self.has_warned(w) self.has_warned(w)
def test_short_min(self): def test_short_min(self):
with support.catch_warning() as w: with warnings.catch_warnings(record=True) as w:
ts.T_SHORT = SHRT_MIN-1 ts.T_SHORT = SHRT_MIN-1
self.has_warned(w) self.has_warned(w)
def test_ushort_max(self): def test_ushort_max(self):
with support.catch_warning() as w: with warnings.catch_warnings(record=True) as w:
ts.T_USHORT = USHRT_MAX+1 ts.T_USHORT = USHRT_MAX+1
self.has_warned(w) self.has_warned(w)

View File

@ -7,7 +7,7 @@ import warnings
class TestUntestedModules(unittest.TestCase): class TestUntestedModules(unittest.TestCase):
def test_at_least_import_untested_modules(self): def test_at_least_import_untested_modules(self):
with support.catch_warning(): with warnings.catch_warnings(record=True):
import aifc import aifc
import bdb import bdb
import cgitb import cgitb

View File

@ -72,64 +72,69 @@ class FilterTests(object):
"""Testing the filtering functionality.""" """Testing the filtering functionality."""
def test_error(self): def test_error(self):
with support.catch_warning(self.module) as w: with original_warnings.catch_warnings(module=self.module) as w:
self.module.resetwarnings() self.module.resetwarnings()
self.module.filterwarnings("error", category=UserWarning) self.module.filterwarnings("error", category=UserWarning)
self.assertRaises(UserWarning, self.module.warn, self.assertRaises(UserWarning, self.module.warn,
"FilterTests.test_error") "FilterTests.test_error")
def test_ignore(self): def test_ignore(self):
with support.catch_warning(module=self.module) as w: with original_warnings.catch_warnings(record=True,
module=self.module) as w:
self.module.resetwarnings() self.module.resetwarnings()
self.module.filterwarnings("ignore", category=UserWarning) self.module.filterwarnings("ignore", category=UserWarning)
self.module.warn("FilterTests.test_ignore", UserWarning) self.module.warn("FilterTests.test_ignore", UserWarning)
self.assertEquals(len(w), 0) self.assertEquals(len(w), 0)
def test_always(self): def test_always(self):
with support.catch_warning(module=self.module) as w: with original_warnings.catch_warnings(record=True,
module=self.module) as w:
self.module.resetwarnings() self.module.resetwarnings()
self.module.filterwarnings("always", category=UserWarning) self.module.filterwarnings("always", category=UserWarning)
message = "FilterTests.test_always" message = "FilterTests.test_always"
self.module.warn(message, UserWarning) self.module.warn(message, UserWarning)
self.assert_(message, w.message) self.assert_(message, w[-1].message)
self.module.warn(message, UserWarning) self.module.warn(message, UserWarning)
self.assert_(w.message, message) self.assert_(w[-1].message, message)
def test_default(self): def test_default(self):
with support.catch_warning(self.module) as w: with original_warnings.catch_warnings(record=True,
module=self.module) as w:
self.module.resetwarnings() self.module.resetwarnings()
self.module.filterwarnings("default", category=UserWarning) self.module.filterwarnings("default", category=UserWarning)
message = UserWarning("FilterTests.test_default") message = UserWarning("FilterTests.test_default")
for x in range(2): for x in range(2):
self.module.warn(message, UserWarning) self.module.warn(message, UserWarning)
if x == 0: if x == 0:
self.assertEquals(w.message, message) self.assertEquals(w[-1].message, message)
w.reset() del w[:]
elif x == 1: elif x == 1:
self.assert_(not len(w), "unexpected warning: " + str(w)) self.assertEquals(len(w), 0)
else: else:
raise ValueError("loop variant unhandled") raise ValueError("loop variant unhandled")
def test_module(self): def test_module(self):
with support.catch_warning(self.module) as w: with original_warnings.catch_warnings(record=True,
module=self.module) as w:
self.module.resetwarnings() self.module.resetwarnings()
self.module.filterwarnings("module", category=UserWarning) self.module.filterwarnings("module", category=UserWarning)
message = UserWarning("FilterTests.test_module") message = UserWarning("FilterTests.test_module")
self.module.warn(message, UserWarning) self.module.warn(message, UserWarning)
self.assertEquals(w.message, message) self.assertEquals(w[-1].message, message)
w.reset() del w[:]
self.module.warn(message, UserWarning) self.module.warn(message, UserWarning)
self.assert_(not len(w), "unexpected message: " + str(w)) self.assertEquals(len(w), 0)
def test_once(self): def test_once(self):
with support.catch_warning(self.module) as w: with original_warnings.catch_warnings(record=True,
module=self.module) as w:
self.module.resetwarnings() self.module.resetwarnings()
self.module.filterwarnings("once", category=UserWarning) self.module.filterwarnings("once", category=UserWarning)
message = UserWarning("FilterTests.test_once") message = UserWarning("FilterTests.test_once")
self.module.warn_explicit(message, UserWarning, "test_warnings.py", self.module.warn_explicit(message, UserWarning, "test_warnings.py",
42) 42)
self.assertEquals(w.message, message) self.assertEquals(w[-1].message, message)
w.reset() del w[:]
self.module.warn_explicit(message, UserWarning, "test_warnings.py", self.module.warn_explicit(message, UserWarning, "test_warnings.py",
13) 13)
self.assertEquals(len(w), 0) self.assertEquals(len(w), 0)
@ -138,19 +143,20 @@ class FilterTests(object):
self.assertEquals(len(w), 0) self.assertEquals(len(w), 0)
def test_inheritance(self): def test_inheritance(self):
with support.catch_warning(self.module) as w: with original_warnings.catch_warnings(module=self.module) as w:
self.module.resetwarnings() self.module.resetwarnings()
self.module.filterwarnings("error", category=Warning) self.module.filterwarnings("error", category=Warning)
self.assertRaises(UserWarning, self.module.warn, self.assertRaises(UserWarning, self.module.warn,
"FilterTests.test_inheritance", UserWarning) "FilterTests.test_inheritance", UserWarning)
def test_ordering(self): def test_ordering(self):
with support.catch_warning(self.module) as w: with original_warnings.catch_warnings(record=True,
module=self.module) as w:
self.module.resetwarnings() self.module.resetwarnings()
self.module.filterwarnings("ignore", category=UserWarning) self.module.filterwarnings("ignore", category=UserWarning)
self.module.filterwarnings("error", category=UserWarning, self.module.filterwarnings("error", category=UserWarning,
append=True) append=True)
w.reset() del w[:]
try: try:
self.module.warn("FilterTests.test_ordering", UserWarning) self.module.warn("FilterTests.test_ordering", UserWarning)
except UserWarning: except UserWarning:
@ -160,28 +166,29 @@ class FilterTests(object):
def test_filterwarnings(self): def test_filterwarnings(self):
# Test filterwarnings(). # Test filterwarnings().
# Implicitly also tests resetwarnings(). # Implicitly also tests resetwarnings().
with support.catch_warning(self.module) as w: with original_warnings.catch_warnings(record=True,
module=self.module) as w:
self.module.filterwarnings("error", "", Warning, "", 0) self.module.filterwarnings("error", "", Warning, "", 0)
self.assertRaises(UserWarning, self.module.warn, 'convert to error') self.assertRaises(UserWarning, self.module.warn, 'convert to error')
self.module.resetwarnings() self.module.resetwarnings()
text = 'handle normally' text = 'handle normally'
self.module.warn(text) self.module.warn(text)
self.assertEqual(str(w.message), text) self.assertEqual(str(w[-1].message), text)
self.assert_(w.category is UserWarning) self.assert_(w[-1].category is UserWarning)
self.module.filterwarnings("ignore", "", Warning, "", 0) self.module.filterwarnings("ignore", "", Warning, "", 0)
text = 'filtered out' text = 'filtered out'
self.module.warn(text) self.module.warn(text)
self.assertNotEqual(str(w.message), text) self.assertNotEqual(str(w[-1].message), text)
self.module.resetwarnings() self.module.resetwarnings()
self.module.filterwarnings("error", "hex*", Warning, "", 0) self.module.filterwarnings("error", "hex*", Warning, "", 0)
self.assertRaises(UserWarning, self.module.warn, 'hex/oct') self.assertRaises(UserWarning, self.module.warn, 'hex/oct')
text = 'nonmatching text' text = 'nonmatching text'
self.module.warn(text) self.module.warn(text)
self.assertEqual(str(w.message), text) self.assertEqual(str(w[-1].message), text)
self.assert_(w.category is UserWarning) self.assert_(w[-1].category is UserWarning)
class CFilterTests(BaseTest, FilterTests): class CFilterTests(BaseTest, FilterTests):
module = c_warnings module = c_warnings
@ -195,12 +202,13 @@ class WarnTests(unittest.TestCase):
"""Test warnings.warn() and warnings.warn_explicit().""" """Test warnings.warn() and warnings.warn_explicit()."""
def test_message(self): def test_message(self):
with support.catch_warning(self.module) as w: with original_warnings.catch_warnings(record=True,
module=self.module) as w:
for i in range(4): for i in range(4):
text = 'multi %d' %i # Different text on each call. text = 'multi %d' %i # Different text on each call.
self.module.warn(text) self.module.warn(text)
self.assertEqual(str(w.message), text) self.assertEqual(str(w[-1].message), text)
self.assert_(w.category is UserWarning) self.assert_(w[-1].category is UserWarning)
# Issue 3639 # Issue 3639
def test_warn_nonstandard_types(self): def test_warn_nonstandard_types(self):
@ -210,35 +218,45 @@ class WarnTests(unittest.TestCase):
self.module.warn(ob) self.module.warn(ob)
# Don't directly compare objects since # Don't directly compare objects since
# ``Warning() != Warning()``. # ``Warning() != Warning()``.
self.assertEquals(str(w.message), str(UserWarning(ob))) self.assertEquals(str(w[-1].message), str(UserWarning(ob)))
def test_filename(self): def test_filename(self):
with warnings_state(self.module): with warnings_state(self.module):
with support.catch_warning(self.module) as w: with original_warnings.catch_warnings(record=True,
module=self.module) as w:
warning_tests.inner("spam1") warning_tests.inner("spam1")
self.assertEqual(os.path.basename(w.filename), "warning_tests.py") self.assertEqual(os.path.basename(w[-1].filename),
"warning_tests.py")
warning_tests.outer("spam2") warning_tests.outer("spam2")
self.assertEqual(os.path.basename(w.filename), "warning_tests.py") self.assertEqual(os.path.basename(w[-1].filename),
"warning_tests.py")
def test_stacklevel(self): def test_stacklevel(self):
# Test stacklevel argument # Test stacklevel argument
# make sure all messages are different, so the warning won't be skipped # make sure all messages are different, so the warning won't be skipped
with warnings_state(self.module): with warnings_state(self.module):
with support.catch_warning(self.module) as w: with original_warnings.catch_warnings(record=True,
module=self.module) as w:
warning_tests.inner("spam3", stacklevel=1) warning_tests.inner("spam3", stacklevel=1)
self.assertEqual(os.path.basename(w.filename), "warning_tests.py") self.assertEqual(os.path.basename(w[-1].filename),
"warning_tests.py")
warning_tests.outer("spam4", stacklevel=1) warning_tests.outer("spam4", stacklevel=1)
self.assertEqual(os.path.basename(w.filename), "warning_tests.py") self.assertEqual(os.path.basename(w[-1].filename),
"warning_tests.py")
warning_tests.inner("spam5", stacklevel=2) warning_tests.inner("spam5", stacklevel=2)
self.assertEqual(os.path.basename(w.filename), "test_warnings.py") self.assertEqual(os.path.basename(w[-1].filename),
"test_warnings.py")
warning_tests.outer("spam6", stacklevel=2) warning_tests.outer("spam6", stacklevel=2)
self.assertEqual(os.path.basename(w.filename), "warning_tests.py") self.assertEqual(os.path.basename(w[-1].filename),
"warning_tests.py")
warning_tests.outer("spam6.5", stacklevel=3) warning_tests.outer("spam6.5", stacklevel=3)
self.assertEqual(os.path.basename(w.filename), "test_warnings.py") self.assertEqual(os.path.basename(w[-1].filename),
"test_warnings.py")
warning_tests.inner("spam7", stacklevel=9999) warning_tests.inner("spam7", stacklevel=9999)
self.assertEqual(os.path.basename(w.filename), "sys") self.assertEqual(os.path.basename(w[-1].filename),
"sys")
def test_missing_filename_not_main(self): def test_missing_filename_not_main(self):
# If __file__ is not specified and __main__ is not the module name, # If __file__ is not specified and __main__ is not the module name,
@ -247,9 +265,10 @@ class WarnTests(unittest.TestCase):
try: try:
del warning_tests.__file__ del warning_tests.__file__
with warnings_state(self.module): with warnings_state(self.module):
with support.catch_warning(self.module) as w: with original_warnings.catch_warnings(record=True,
module=self.module) as w:
warning_tests.inner("spam8", stacklevel=1) warning_tests.inner("spam8", stacklevel=1)
self.assertEqual(w.filename, warning_tests.__name__) self.assertEqual(w[-1].filename, warning_tests.__name__)
finally: finally:
warning_tests.__file__ = filename warning_tests.__file__ = filename
@ -264,9 +283,10 @@ class WarnTests(unittest.TestCase):
del warning_tests.__file__ del warning_tests.__file__
warning_tests.__name__ = '__main__' warning_tests.__name__ = '__main__'
with warnings_state(self.module): with warnings_state(self.module):
with support.catch_warning(self.module) as w: with original_warnings.catch_warnings(record=True,
module=self.module) as w:
warning_tests.inner('spam9', stacklevel=1) warning_tests.inner('spam9', stacklevel=1)
self.assertEqual(w.filename, sys.argv[0]) self.assertEqual(w[-1].filename, sys.argv[0])
finally: finally:
warning_tests.__file__ = filename warning_tests.__file__ = filename
warning_tests.__name__ = module_name warning_tests.__name__ = module_name
@ -282,9 +302,10 @@ class WarnTests(unittest.TestCase):
warning_tests.__name__ = '__main__' warning_tests.__name__ = '__main__'
del sys.argv del sys.argv
with warnings_state(self.module): with warnings_state(self.module):
with support.catch_warning(self.module) as w: with original_warnings.catch_warnings(record=True,
module=self.module) as w:
warning_tests.inner('spam10', stacklevel=1) warning_tests.inner('spam10', stacklevel=1)
self.assertEqual(w.filename, '__main__') self.assertEqual(w[-1].filename, '__main__')
finally: finally:
warning_tests.__file__ = filename warning_tests.__file__ = filename
warning_tests.__name__ = module_name warning_tests.__name__ = module_name
@ -302,9 +323,10 @@ class WarnTests(unittest.TestCase):
warning_tests.__name__ = '__main__' warning_tests.__name__ = '__main__'
sys.argv = [''] sys.argv = ['']
with warnings_state(self.module): with warnings_state(self.module):
with support.catch_warning(self.module) as w: with original_warnings.catch_warnings(record=True,
module=self.module) as w:
warning_tests.inner('spam11', stacklevel=1) warning_tests.inner('spam11', stacklevel=1)
self.assertEqual(w.filename, '__main__') self.assertEqual(w[-1].filename, '__main__')
finally: finally:
warning_tests.__file__ = file_name warning_tests.__file__ = file_name
warning_tests.__name__ = module_name warning_tests.__name__ = module_name
@ -337,7 +359,7 @@ class WCmdLineTests(unittest.TestCase):
def test_improper_input(self): def test_improper_input(self):
# Uses the private _setoption() function to test the parsing # Uses the private _setoption() function to test the parsing
# of command-line warning arguments # of command-line warning arguments
with support.catch_warning(self.module): with original_warnings.catch_warnings(module=self.module):
self.assertRaises(self.module._OptionError, self.assertRaises(self.module._OptionError,
self.module._setoption, '1:2:3:4:5:6') self.module._setoption, '1:2:3:4:5:6')
self.assertRaises(self.module._OptionError, self.assertRaises(self.module._OptionError,
@ -362,7 +384,7 @@ class _WarningsTests(BaseTest):
def test_filter(self): def test_filter(self):
# Everything should function even if 'filters' is not in warnings. # Everything should function even if 'filters' is not in warnings.
with support.catch_warning(self.module) as w: with original_warnings.catch_warnings(module=self.module) as w:
self.module.filterwarnings("error", "", Warning, "", 0) self.module.filterwarnings("error", "", Warning, "", 0)
self.assertRaises(UserWarning, self.module.warn, self.assertRaises(UserWarning, self.module.warn,
'convert to error') 'convert to error')
@ -377,21 +399,22 @@ class _WarningsTests(BaseTest):
try: try:
original_registry = self.module.onceregistry original_registry = self.module.onceregistry
__warningregistry__ = {} __warningregistry__ = {}
with support.catch_warning(self.module) as w: with original_warnings.catch_warnings(record=True,
module=self.module) as w:
self.module.resetwarnings() self.module.resetwarnings()
self.module.filterwarnings("once", category=UserWarning) self.module.filterwarnings("once", category=UserWarning)
self.module.warn_explicit(message, UserWarning, "file", 42) self.module.warn_explicit(message, UserWarning, "file", 42)
self.failUnlessEqual(w.message, message) self.failUnlessEqual(w[-1].message, message)
w.reset() del w[:]
self.module.warn_explicit(message, UserWarning, "file", 42) self.module.warn_explicit(message, UserWarning, "file", 42)
self.assertEquals(len(w), 0) self.assertEquals(len(w), 0)
# Test the resetting of onceregistry. # Test the resetting of onceregistry.
self.module.onceregistry = {} self.module.onceregistry = {}
__warningregistry__ = {} __warningregistry__ = {}
self.module.warn('onceregistry test') self.module.warn('onceregistry test')
self.failUnlessEqual(w.message.args, message.args) self.failUnlessEqual(w[-1].message.args, message.args)
# Removal of onceregistry is okay. # Removal of onceregistry is okay.
w.reset() del w[:]
del self.module.onceregistry del self.module.onceregistry
__warningregistry__ = {} __warningregistry__ = {}
self.module.warn_explicit(message, UserWarning, "file", 42) self.module.warn_explicit(message, UserWarning, "file", 42)
@ -402,7 +425,7 @@ class _WarningsTests(BaseTest):
def test_showwarning_missing(self): def test_showwarning_missing(self):
# Test that showwarning() missing is okay. # Test that showwarning() missing is okay.
text = 'del showwarning test' text = 'del showwarning test'
with support.catch_warning(self.module): with original_warnings.catch_warnings(module=self.module):
self.module.filterwarnings("always", category=UserWarning) self.module.filterwarnings("always", category=UserWarning)
del self.module.showwarning del self.module.showwarning
with support.captured_output('stderr') as stream: with support.captured_output('stderr') as stream:
@ -423,7 +446,7 @@ class _WarningsTests(BaseTest):
def test_show_warning_output(self): def test_show_warning_output(self):
# With showarning() missing, make sure that output is okay. # With showarning() missing, make sure that output is okay.
text = 'test show_warning' text = 'test show_warning'
with support.catch_warning(self.module): with original_warnings.catch_warnings(module=self.module):
self.module.filterwarnings("always", category=UserWarning) self.module.filterwarnings("always", category=UserWarning)
del self.module.showwarning del self.module.showwarning
with support.captured_output('stderr') as stream: with support.captured_output('stderr') as stream:
@ -494,6 +517,7 @@ class CWarningsDisplayTests(BaseTest, WarningsDisplayTests):
class PyWarningsDisplayTests(BaseTest, WarningsDisplayTests): class PyWarningsDisplayTests(BaseTest, WarningsDisplayTests):
module = py_warnings module = py_warnings
class CatchWarningTests(BaseTest): class CatchWarningTests(BaseTest):
"""Test catch_warnings().""" """Test catch_warnings()."""
@ -517,12 +541,12 @@ class CatchWarningTests(BaseTest):
self.assertEqual(w, []) self.assertEqual(w, [])
wmod.simplefilter("always") wmod.simplefilter("always")
wmod.warn("foo") wmod.warn("foo")
self.assertEqual(str(w.message), "foo") self.assertEqual(str(w[-1].message), "foo")
wmod.warn("bar") wmod.warn("bar")
self.assertEqual(str(w.message), "bar") self.assertEqual(str(w[-1].message), "bar")
self.assertEqual(str(w[0].message), "foo") self.assertEqual(str(w[0].message), "foo")
self.assertEqual(str(w[1].message), "bar") self.assertEqual(str(w[1].message), "bar")
w.reset() del w[:]
self.assertEqual(w, []) self.assertEqual(w, [])
orig_showwarning = wmod.showwarning orig_showwarning = wmod.showwarning
with support.catch_warning(module=wmod, record=False) as w: with support.catch_warning(module=wmod, record=False) as w:

View File

@ -7,7 +7,7 @@ import linecache
import sys import sys
__all__ = ["warn", "showwarning", "formatwarning", "filterwarnings", __all__ = ["warn", "showwarning", "formatwarning", "filterwarnings",
"resetwarnings"] "resetwarnings", "catch_warnings"]
def showwarning(message, category, filename, lineno, file=None, line=None): def showwarning(message, category, filename, lineno, file=None, line=None):
@ -274,28 +274,20 @@ class WarningMessage(object):
self.filename, self.lineno, self.line)) self.filename, self.lineno, self.line))
class WarningsRecorder(list):
"""Record the result of various showwarning() calls."""
def showwarning(self, *args, **kwargs):
self.append(WarningMessage(*args, **kwargs))
def __getattr__(self, attr):
return getattr(self[-1], attr)
def reset(self):
del self[:]
class catch_warnings(object): class catch_warnings(object):
"""Guard the warnings filter from being permanently changed and optionally """A context manager that copies and restores the warnings filter upon
record the details of any warnings that are issued. exiting the context.
Context manager returns an instance of warnings.WarningRecorder which is a The 'record' argument specifies whether warnings should be captured by a
list of WarningMessage instances. Attributes on WarningRecorder are custom implementation of warnings.showwarning() and be appended to a list
redirected to the last created WarningMessage instance. returned by the context manager. Otherwise None is returned by the context
manager. The objects appended to the list are arguments whose attributes
mirror the arguments to showwarning().
The 'module' argument is to specify an alternative module to the module
named 'warnings' and imported under that name. This argument is only useful
when testing the warnings module itself.
""" """
@ -307,17 +299,21 @@ class catch_warnings(object):
keyword-only. keyword-only.
""" """
self._recorder = WarningsRecorder() if record else None self._record = record
self._module = sys.modules['warnings'] if module is None else module self._module = sys.modules['warnings'] if module is None else module
def __enter__(self): def __enter__(self):
self._filters = self._module.filters self._filters = self._module.filters
self._module.filters = self._filters[:] self._module.filters = self._filters[:]
self._showwarning = self._module.showwarning self._showwarning = self._module.showwarning
if self._recorder is not None: if self._record:
self._recorder.reset() # In case the instance is being reused. log = []
self._module.showwarning = self._recorder.showwarning def showwarning(*args, **kwargs):
return self._recorder log.append(WarningMessage(*args, **kwargs))
self._module.showwarning = showwarning
return log
else:
return None
def __exit__(self, *exc_info): def __exit__(self, *exc_info):
self._module.filters = self._filters self._module.filters = self._filters