Issue 8351. Suppress large diffs in unittest.TestCase.assertSequenceEqual.

This commit is contained in:
Michael Foord 2010-06-05 11:23:51 +00:00
parent 9ef5d33084
commit 0100702b9a
2 changed files with 29 additions and 3 deletions

View File

@ -13,7 +13,7 @@ from .util import (
) )
__unittest = True __unittest = True
TRUNCATED_DIFF = '\n[diff truncated...]'
class SkipTest(Exception): class SkipTest(Exception):
""" """
@ -589,7 +589,8 @@ class TestCase(object):
failUnlessRaises = _deprecate(assertRaises) failUnlessRaises = _deprecate(assertRaises)
failIf = _deprecate(assertFalse) failIf = _deprecate(assertFalse)
def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None): def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None,
max_diff=80*8):
"""An equality assertion for ordered sequences (like lists and tuples). """An equality assertion for ordered sequences (like lists and tuples).
For the purposes of this function, a valid ordered sequence type is one For the purposes of this function, a valid ordered sequence type is one
@ -602,6 +603,7 @@ class TestCase(object):
datatype should be enforced. datatype should be enforced.
msg: Optional message to use on failure instead of a list of msg: Optional message to use on failure instead of a list of
differences. differences.
max_diff: Maximum size off the diff, larger diffs are not shown
""" """
if seq_type is not None: if seq_type is not None:
seq_type_name = seq_type.__name__ seq_type_name = seq_type.__name__
@ -684,9 +686,14 @@ class TestCase(object):
except (TypeError, IndexError, NotImplementedError): except (TypeError, IndexError, NotImplementedError):
differing += ('Unable to index element %d ' differing += ('Unable to index element %d '
'of second %s\n' % (len1, seq_type_name)) 'of second %s\n' % (len1, seq_type_name))
standardMsg = differing + '\n' + '\n'.join( standardMsg = differing
diffMsg = '\n' + '\n'.join(
difflib.ndiff(pprint.pformat(seq1).splitlines(), difflib.ndiff(pprint.pformat(seq1).splitlines(),
pprint.pformat(seq2).splitlines())) pprint.pformat(seq2).splitlines()))
if max_diff is None or len(diffMsg) <= max_diff:
standardMsg += diffMsg
else:
standardMsg += diffMsg[:max_diff] + TRUNCATED_DIFF
msg = self._formatMessage(msg, standardMsg) msg = self._formatMessage(msg, standardMsg)
self.fail(msg) self.fail(msg)

View File

@ -1,3 +1,5 @@
import difflib
import pprint
import re import re
import sys import sys
@ -588,6 +590,23 @@ class Test_TestCase(unittest.TestCase, TestEquality, TestHashing):
self.assertRaises(self.failureException, self.assertDictEqual, [], d) self.assertRaises(self.failureException, self.assertDictEqual, [], d)
self.assertRaises(self.failureException, self.assertDictEqual, 1, 1) self.assertRaises(self.failureException, self.assertDictEqual, 1, 1)
def testAssertSequenceEqualMaxDiff(self):
seq1 = 'a' + 'x' * 80**2
seq2 = 'b' + 'x' * 80**2
diff = '\n'.join(difflib.ndiff(pprint.pformat(seq1).splitlines(),
pprint.pformat(seq2).splitlines()))
try:
self.assertSequenceEqual(seq1, seq2, max_diff=len(diff)/2)
except AssertionError as e:
msg = e.args[0]
self.assertTrue(len(msg) < len(diff))
try:
self.assertSequenceEqual(seq1, seq2, max_diff=len(diff)*2)
except AssertionError as e:
msg = e.args[0]
self.assertTrue(len(msg) > len(diff))
def testAssertItemsEqual(self): def testAssertItemsEqual(self):
a = object() a = object()
self.assertItemsEqual([1, 2, 3], [3, 2, 1]) self.assertItemsEqual([1, 2, 3], [3, 2, 1])