diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py index 2ef48e64d74..6d971aa273b 100644 --- a/Lib/test/pickletester.py +++ b/Lib/test/pickletester.py @@ -103,6 +103,15 @@ class E(C): class H(object): pass +# Hashable mutable key +class K(object): + def __init__(self, value): + self.value = value + + def __reduce__(self): + # Shouldn't support the recursion itself + return K, (self.value,) + import __main__ __main__.C = C C.__module__ = "__main__" @@ -112,6 +121,8 @@ __main__.E = E E.__module__ = "__main__" __main__.H = H H.__module__ = "__main__" +__main__.K = K +K.__module__ = "__main__" class myint(int): def __init__(self, x): @@ -1041,9 +1052,9 @@ class AbstractPickleTests(unittest.TestCase): x = self.loads(s) self.assertIsInstance(x, list) self.assertEqual(len(x), 1) - self.assertTrue(x is x[0]) + self.assertIs(x[0], x) - def test_recursive_tuple(self): + def test_recursive_tuple_and_list(self): t = ([],) t[0].append(t) for proto in protocols: @@ -1051,8 +1062,9 @@ class AbstractPickleTests(unittest.TestCase): x = self.loads(s) self.assertIsInstance(x, tuple) self.assertEqual(len(x), 1) + self.assertIsInstance(x[0], list) self.assertEqual(len(x[0]), 1) - self.assertTrue(x is x[0][0]) + self.assertIs(x[0][0], x) def test_recursive_dict(self): d = {} @@ -1062,29 +1074,63 @@ class AbstractPickleTests(unittest.TestCase): x = self.loads(s) self.assertIsInstance(x, dict) self.assertEqual(list(x.keys()), [1]) - self.assertTrue(x[1] is x) + self.assertIs(x[1], x) + + def test_recursive_dict_key(self): + d = {} + k = K(d) + d[k] = 1 + for proto in protocols: + s = self.dumps(d, proto) + x = self.loads(s) + self.assertIsInstance(x, dict) + self.assertEqual(len(x.keys()), 1) + self.assertIsInstance(list(x.keys())[0], K) + self.assertIs(list(x.keys())[0].value, x) def test_recursive_set(self): - h = H() - y = set({h}) - h.attr = y - for proto in protocols: + y = set() + k = K(y) + y.add(k) + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): s = self.dumps(y, proto) x = self.loads(s) self.assertIsInstance(x, set) - self.assertIs(list(x)[0].attr, x) self.assertEqual(len(x), 1) + self.assertIsInstance(list(x)[0], K) + self.assertIs(list(x)[0].value, x) - def test_recursive_frozenset(self): - h = H() - y = frozenset({h}) - h.attr = y - for proto in protocols: + def test_recursive_list_subclass(self): + y = MyList() + y.append(y) + for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): s = self.dumps(y, proto) x = self.loads(s) - self.assertIsInstance(x, frozenset) - self.assertIs(list(x)[0].attr, x) + self.assertIsInstance(x, MyList) self.assertEqual(len(x), 1) + self.assertIs(x[0], x) + + def test_recursive_dict_subclass(self): + d = MyDict() + d[1] = d + for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): + s = self.dumps(d, proto) + x = self.loads(s) + self.assertIsInstance(x, MyDict) + self.assertEqual(list(x.keys()), [1]) + self.assertIs(x[1], x) + + def test_recursive_dict_subclass_key(self): + d = MyDict() + k = K(d) + d[k] = 1 + for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): + s = self.dumps(d, proto) + x = self.loads(s) + self.assertIsInstance(x, MyDict) + self.assertEqual(len(list(x.keys())), 1) + self.assertIsInstance(list(x.keys())[0], K) + self.assertIs(list(x.keys())[0].value, x) def test_recursive_inst(self): i = C() @@ -1111,6 +1157,48 @@ class AbstractPickleTests(unittest.TestCase): self.assertEqual(list(x[0].attr.keys()), [1]) self.assertTrue(x[0].attr[1] is x) + def check_recursive_collection_and_inst(self, factory): + h = H() + y = factory([h]) + h.attr = y + for proto in protocols: + s = self.dumps(y, proto) + x = self.loads(s) + self.assertIsInstance(x, type(y)) + self.assertEqual(len(x), 1) + self.assertIsInstance(list(x)[0], H) + self.assertIs(list(x)[0].attr, x) + + def test_recursive_list_and_inst(self): + self.check_recursive_collection_and_inst(list) + + def test_recursive_tuple_and_inst(self): + self.check_recursive_collection_and_inst(tuple) + + def test_recursive_dict_and_inst(self): + self.check_recursive_collection_and_inst(dict.fromkeys) + + def test_recursive_set_and_inst(self): + self.check_recursive_collection_and_inst(set) + + def test_recursive_frozenset_and_inst(self): + self.check_recursive_collection_and_inst(frozenset) + + def test_recursive_list_subclass_and_inst(self): + self.check_recursive_collection_and_inst(MyList) + + def test_recursive_tuple_subclass_and_inst(self): + self.check_recursive_collection_and_inst(MyTuple) + + def test_recursive_dict_subclass_and_inst(self): + self.check_recursive_collection_and_inst(MyDict.fromkeys) + + def test_recursive_set_subclass_and_inst(self): + self.check_recursive_collection_and_inst(MySet) + + def test_recursive_frozenset_subclass_and_inst(self): + self.check_recursive_collection_and_inst(MyFrozenSet) + def test_unicode(self): endcases = ['', '<\\u>', '<\\\u1234>', '<\n>', '<\\>', '<\\\U00012345>',