diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index dcaae93aed5..ca545f97efd 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -129,6 +129,27 @@ class _AssertRaisesContext(object): return True +class _TypeEqualityDict(object): + + def __init__(self, testcase): + self.testcase = testcase + self._store = {} + + def __setitem__(self, key, value): + self._store[key] = value + + def __getitem__(self, key): + value = self._store[key] + if isinstance(value, basestring): + return getattr(self.testcase, value) + return value + + def get(self, key, default=None): + if key in self._store: + return self[key] + return default + + class TestCase(object): """A class whose instances are single test cases. @@ -195,13 +216,13 @@ class TestCase(object): # Map types to custom assertEqual functions that will compare # instances of said type in more detail to generate a more useful # error message. - self._type_equality_funcs = {} - self.addTypeEqualityFunc(dict, self.assertDictEqual) - self.addTypeEqualityFunc(list, self.assertListEqual) - self.addTypeEqualityFunc(tuple, self.assertTupleEqual) - self.addTypeEqualityFunc(set, self.assertSetEqual) - self.addTypeEqualityFunc(frozenset, self.assertSetEqual) - self.addTypeEqualityFunc(unicode, self.assertMultiLineEqual) + self._type_equality_funcs = _TypeEqualityDict(self) + self.addTypeEqualityFunc(dict, 'assertDictEqual') + self.addTypeEqualityFunc(list, 'assertListEqual') + self.addTypeEqualityFunc(tuple, 'assertTupleEqual') + self.addTypeEqualityFunc(set, 'assertSetEqual') + self.addTypeEqualityFunc(frozenset, 'assertSetEqual') + self.addTypeEqualityFunc(unicode, 'assertMultiLineEqual') def addTypeEqualityFunc(self, typeobj, function): """Add a type specific assertEqual style function to compare a type. diff --git a/Lib/unittest/test/test_case.py b/Lib/unittest/test/test_case.py index 2e449c5c2ae..e92b0191674 100644 --- a/Lib/unittest/test/test_case.py +++ b/Lib/unittest/test/test_case.py @@ -1,5 +1,6 @@ import difflib import pprint +import pickle import re import sys @@ -1104,6 +1105,20 @@ test case self.assertEqual(len(result.errors), 1) self.assertEqual(result.testsRun, 1) + def testPickle(self): + # Issue 10326 + + # Can't use TestCase classes defined in Test class as + # pickle does not work with inner classes + test = unittest.TestCase('run') + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + + # blew up prior to fix + pickled_test = pickle.dumps(test, protocol=protocol) + + unpickled_test = pickle.loads(pickled_test) + self.assertEqual(test, unpickled_test) + if __name__ == '__main__': unittest.main()