bpo-42809: Improve pickle tests for recursive data. (GH-24060)

(cherry picked from commit a25011be8c)

Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
This commit is contained in:
Miss Islington (bot) 2021-01-02 09:50:28 -08:00 committed by GitHub
parent 6dffa67b98
commit 2e8b1c9e9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 252 additions and 86 deletions

View File

@ -65,6 +65,10 @@ def count_opcode(code, pickle):
return n return n
def identity(x):
return x
class UnseekableIO(io.BytesIO): class UnseekableIO(io.BytesIO):
def peek(self, *args): def peek(self, *args):
raise NotImplementedError raise NotImplementedError
@ -134,11 +138,12 @@ class E(C):
def __getinitargs__(self): def __getinitargs__(self):
return () return ()
class H(object): # Simple mutable object.
class Object:
pass pass
# Hashable mutable key # Hashable immutable key object containing unheshable mutable data.
class K(object): class K:
def __init__(self, value): def __init__(self, value):
self.value = value self.value = value
@ -153,10 +158,6 @@ __main__.D = D
D.__module__ = "__main__" D.__module__ = "__main__"
__main__.E = E __main__.E = E
E.__module__ = "__main__" E.__module__ = "__main__"
__main__.H = H
H.__module__ = "__main__"
__main__.K = K
K.__module__ = "__main__"
class myint(int): class myint(int):
def __init__(self, x): def __init__(self, x):
@ -1490,54 +1491,182 @@ class AbstractPickleTests(unittest.TestCase):
got = filelike.getvalue() got = filelike.getvalue()
self.assertEqual(expected, got) self.assertEqual(expected, got)
def test_recursive_list(self): def _test_recursive_list(self, cls, aslist=identity, minprotocol=0):
l = [] # List containing itself.
l = cls()
l.append(l) l.append(l)
for proto in protocols: for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(l, proto) s = self.dumps(l, proto)
x = self.loads(s) x = self.loads(s)
self.assertIsInstance(x, list) self.assertIsInstance(x, cls)
self.assertEqual(len(x), 1) y = aslist(x)
self.assertIs(x[0], x) self.assertEqual(len(y), 1)
self.assertIs(y[0], x)
def test_recursive_tuple_and_list(self): def test_recursive_list(self):
t = ([],) self._test_recursive_list(list)
def test_recursive_list_subclass(self):
self._test_recursive_list(MyList, minprotocol=2)
def test_recursive_list_like(self):
self._test_recursive_list(REX_six, aslist=lambda x: x.items)
def _test_recursive_tuple_and_list(self, cls, aslist=identity, minprotocol=0):
# Tuple containing a list containing the original tuple.
t = (cls(),)
t[0].append(t) t[0].append(t)
for proto in protocols: for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(t, proto) s = self.dumps(t, proto)
x = self.loads(s) x = self.loads(s)
self.assertIsInstance(x, tuple) self.assertIsInstance(x, tuple)
self.assertEqual(len(x), 1) self.assertEqual(len(x), 1)
self.assertIsInstance(x[0], list) self.assertIsInstance(x[0], cls)
self.assertEqual(len(x[0]), 1) y = aslist(x[0])
self.assertIs(x[0][0], x) self.assertEqual(len(y), 1)
self.assertIs(y[0], x)
# List containing a tuple containing the original list.
t, = t
for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(t, proto)
x = self.loads(s)
self.assertIsInstance(x, cls)
y = aslist(x)
self.assertEqual(len(y), 1)
self.assertIsInstance(y[0], tuple)
self.assertEqual(len(y[0]), 1)
self.assertIs(y[0][0], x)
def test_recursive_tuple_and_list(self):
self._test_recursive_tuple_and_list(list)
def test_recursive_tuple_and_list_subclass(self):
self._test_recursive_tuple_and_list(MyList, minprotocol=2)
def test_recursive_tuple_and_list_like(self):
self._test_recursive_tuple_and_list(REX_six, aslist=lambda x: x.items)
def _test_recursive_dict(self, cls, asdict=identity, minprotocol=0):
# Dict containing itself.
d = cls()
d[1] = d
for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(d, proto)
x = self.loads(s)
self.assertIsInstance(x, cls)
y = asdict(x)
self.assertEqual(list(y.keys()), [1])
self.assertIs(y[1], x)
def test_recursive_dict(self): def test_recursive_dict(self):
d = {} self._test_recursive_dict(dict)
d[1] = d
for proto in protocols: def test_recursive_dict_subclass(self):
self._test_recursive_dict(MyDict, minprotocol=2)
def test_recursive_dict_like(self):
self._test_recursive_dict(REX_seven, asdict=lambda x: x.table)
def _test_recursive_tuple_and_dict(self, cls, asdict=identity, minprotocol=0):
# Tuple containing a dict containing the original tuple.
t = (cls(),)
t[0][1] = t
for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(t, proto)
x = self.loads(s)
self.assertIsInstance(x, tuple)
self.assertEqual(len(x), 1)
self.assertIsInstance(x[0], cls)
y = asdict(x[0])
self.assertEqual(list(y), [1])
self.assertIs(y[1], x)
# Dict containing a tuple containing the original dict.
t, = t
for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(t, proto)
x = self.loads(s)
self.assertIsInstance(x, cls)
y = asdict(x)
self.assertEqual(list(y), [1])
self.assertIsInstance(y[1], tuple)
self.assertEqual(len(y[1]), 1)
self.assertIs(y[1][0], x)
def test_recursive_tuple_and_dict(self):
self._test_recursive_tuple_and_dict(dict)
def test_recursive_tuple_and_dict_subclass(self):
self._test_recursive_tuple_and_dict(MyDict, minprotocol=2)
def test_recursive_tuple_and_dict_like(self):
self._test_recursive_tuple_and_dict(REX_seven, asdict=lambda x: x.table)
def _test_recursive_dict_key(self, cls, asdict=identity, minprotocol=0):
# Dict containing an immutable object (as key) containing the original
# dict.
d = cls()
d[K(d)] = 1
for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(d, proto) s = self.dumps(d, proto)
x = self.loads(s) x = self.loads(s)
self.assertIsInstance(x, dict) self.assertIsInstance(x, cls)
self.assertEqual(list(x.keys()), [1]) y = asdict(x)
self.assertIs(x[1], x) self.assertEqual(len(y.keys()), 1)
self.assertIsInstance(list(y.keys())[0], K)
self.assertIs(list(y.keys())[0].value, x)
def test_recursive_dict_key(self): def test_recursive_dict_key(self):
d = {} self._test_recursive_dict_key(dict)
k = K(d)
d[k] = 1 def test_recursive_dict_subclass_key(self):
for proto in protocols: self._test_recursive_dict_key(MyDict, minprotocol=2)
s = self.dumps(d, proto)
def test_recursive_dict_like_key(self):
self._test_recursive_dict_key(REX_seven, asdict=lambda x: x.table)
def _test_recursive_tuple_and_dict_key(self, cls, asdict=identity, minprotocol=0):
# Tuple containing a dict containing an immutable object (as key)
# containing the original tuple.
t = (cls(),)
t[0][K(t)] = 1
for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(t, proto)
x = self.loads(s) x = self.loads(s)
self.assertIsInstance(x, dict) self.assertIsInstance(x, tuple)
self.assertEqual(len(x.keys()), 1) self.assertEqual(len(x), 1)
self.assertIsInstance(list(x.keys())[0], K) self.assertIsInstance(x[0], cls)
self.assertIs(list(x.keys())[0].value, x) y = asdict(x[0])
self.assertEqual(len(y), 1)
self.assertIsInstance(list(y.keys())[0], K)
self.assertIs(list(y.keys())[0].value, x)
# Dict containing an immutable object (as key) containing a tuple
# containing the original dict.
t, = t
for proto in range(minprotocol, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(t, proto)
x = self.loads(s)
self.assertIsInstance(x, cls)
y = asdict(x)
self.assertEqual(len(y), 1)
self.assertIsInstance(list(y.keys())[0], K)
self.assertIs(list(y.keys())[0].value[0], x)
def test_recursive_tuple_and_dict_key(self):
self._test_recursive_tuple_and_dict_key(dict)
def test_recursive_tuple_and_dict_subclass_key(self):
self._test_recursive_tuple_and_dict_key(MyDict, minprotocol=2)
def test_recursive_tuple_and_dict_like_key(self):
self._test_recursive_tuple_and_dict_key(REX_seven, asdict=lambda x: x.table)
def test_recursive_set(self): def test_recursive_set(self):
# Set containing an immutable object containing the original set.
y = set() y = set()
k = K(y) y.add(K(y))
y.add(k)
for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(y, proto) s = self.dumps(y, proto)
x = self.loads(s) x = self.loads(s)
@ -1546,52 +1675,31 @@ class AbstractPickleTests(unittest.TestCase):
self.assertIsInstance(list(x)[0], K) self.assertIsInstance(list(x)[0], K)
self.assertIs(list(x)[0].value, x) self.assertIs(list(x)[0].value, x)
def test_recursive_list_subclass(self): # Immutable object containing a set containing the original object.
y = MyList() y, = y
y.append(y) for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(y, proto) s = self.dumps(y, proto)
x = self.loads(s) x = self.loads(s)
self.assertIsInstance(x, MyList) self.assertIsInstance(x, K)
self.assertEqual(len(x), 1) self.assertIsInstance(x.value, set)
self.assertIs(x[0], x) self.assertEqual(len(x.value), 1)
self.assertIs(list(x.value)[0], x)
def test_recursive_dict_subclass(self):
d = MyDict()
d[1] = d
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(d, proto)
x = self.loads(s)
self.assertIsInstance(x, MyDict)
self.assertEqual(list(x.keys()), [1])
self.assertIs(x[1], x)
def test_recursive_dict_subclass_key(self):
d = MyDict()
k = K(d)
d[k] = 1
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(d, proto)
x = self.loads(s)
self.assertIsInstance(x, MyDict)
self.assertEqual(len(list(x.keys())), 1)
self.assertIsInstance(list(x.keys())[0], K)
self.assertIs(list(x.keys())[0].value, x)
def test_recursive_inst(self): def test_recursive_inst(self):
i = C() # Mutable object containing itself.
i = Object()
i.attr = i i.attr = i
for proto in protocols: for proto in protocols:
s = self.dumps(i, proto) s = self.dumps(i, proto)
x = self.loads(s) x = self.loads(s)
self.assertIsInstance(x, C) self.assertIsInstance(x, Object)
self.assertEqual(dir(x), dir(i)) self.assertEqual(dir(x), dir(i))
self.assertIs(x.attr, x) self.assertIs(x.attr, x)
def test_recursive_multi(self): def test_recursive_multi(self):
l = [] l = []
d = {1:l} d = {1:l}
i = C() i = Object()
i.attr = d i.attr = d
l.append(i) l.append(i)
for proto in protocols: for proto in protocols:
@ -1601,49 +1709,94 @@ class AbstractPickleTests(unittest.TestCase):
self.assertEqual(len(x), 1) self.assertEqual(len(x), 1)
self.assertEqual(dir(x[0]), dir(i)) self.assertEqual(dir(x[0]), dir(i))
self.assertEqual(list(x[0].attr.keys()), [1]) self.assertEqual(list(x[0].attr.keys()), [1])
self.assertTrue(x[0].attr[1] is x) self.assertIs(x[0].attr[1], x)
def check_recursive_collection_and_inst(self, factory): def _test_recursive_collection_and_inst(self, factory):
h = H() # Mutable object containing a collection containing the original
y = factory([h]) # object.
h.attr = y o = Object()
o.attr = factory([o])
t = type(o.attr)
for proto in protocols: for proto in protocols:
s = self.dumps(y, proto) s = self.dumps(o, proto)
x = self.loads(s) x = self.loads(s)
self.assertIsInstance(x, type(y)) self.assertIsInstance(x.attr, t)
self.assertEqual(len(x.attr), 1)
self.assertIsInstance(list(x.attr)[0], Object)
self.assertIs(list(x.attr)[0], x)
# Collection containing a mutable object containing the original
# collection.
o = o.attr
for proto in protocols:
s = self.dumps(o, proto)
x = self.loads(s)
self.assertIsInstance(x, t)
self.assertEqual(len(x), 1) self.assertEqual(len(x), 1)
self.assertIsInstance(list(x)[0], H) self.assertIsInstance(list(x)[0], Object)
self.assertIs(list(x)[0].attr, x) self.assertIs(list(x)[0].attr, x)
def test_recursive_list_and_inst(self): def test_recursive_list_and_inst(self):
self.check_recursive_collection_and_inst(list) self._test_recursive_collection_and_inst(list)
def test_recursive_tuple_and_inst(self): def test_recursive_tuple_and_inst(self):
self.check_recursive_collection_and_inst(tuple) self._test_recursive_collection_and_inst(tuple)
def test_recursive_dict_and_inst(self): def test_recursive_dict_and_inst(self):
self.check_recursive_collection_and_inst(dict.fromkeys) self._test_recursive_collection_and_inst(dict.fromkeys)
def test_recursive_set_and_inst(self): def test_recursive_set_and_inst(self):
self.check_recursive_collection_and_inst(set) self._test_recursive_collection_and_inst(set)
def test_recursive_frozenset_and_inst(self): def test_recursive_frozenset_and_inst(self):
self.check_recursive_collection_and_inst(frozenset) self._test_recursive_collection_and_inst(frozenset)
def test_recursive_list_subclass_and_inst(self): def test_recursive_list_subclass_and_inst(self):
self.check_recursive_collection_and_inst(MyList) self._test_recursive_collection_and_inst(MyList)
def test_recursive_tuple_subclass_and_inst(self): def test_recursive_tuple_subclass_and_inst(self):
self.check_recursive_collection_and_inst(MyTuple) self._test_recursive_collection_and_inst(MyTuple)
def test_recursive_dict_subclass_and_inst(self): def test_recursive_dict_subclass_and_inst(self):
self.check_recursive_collection_and_inst(MyDict.fromkeys) self._test_recursive_collection_and_inst(MyDict.fromkeys)
def test_recursive_set_subclass_and_inst(self): def test_recursive_set_subclass_and_inst(self):
self.check_recursive_collection_and_inst(MySet) self._test_recursive_collection_and_inst(MySet)
def test_recursive_frozenset_subclass_and_inst(self): def test_recursive_frozenset_subclass_and_inst(self):
self.check_recursive_collection_and_inst(MyFrozenSet) self._test_recursive_collection_and_inst(MyFrozenSet)
def test_recursive_inst_state(self):
# Mutable object containing itself.
y = REX_state()
y.state = y
for proto in protocols:
s = self.dumps(y, proto)
x = self.loads(s)
self.assertIsInstance(x, REX_state)
self.assertIs(x.state, x)
def test_recursive_tuple_and_inst_state(self):
# Tuple containing a mutable object containing the original tuple.
t = (REX_state(),)
t[0].state = t
for proto in protocols:
s = self.dumps(t, proto)
x = self.loads(s)
self.assertIsInstance(x, tuple)
self.assertEqual(len(x), 1)
self.assertIsInstance(x[0], REX_state)
self.assertIs(x[0].state, x)
# Mutable object containing a tuple containing the object.
t, = t
for proto in protocols:
s = self.dumps(t, proto)
x = self.loads(s)
self.assertIsInstance(x, REX_state)
self.assertIsInstance(x.state, tuple)
self.assertEqual(len(x.state), 1)
self.assertIs(x.state[0], x)
def test_unicode(self): def test_unicode(self):
endcases = ['', '<\\u>', '<\\\u1234>', '<\n>', endcases = ['', '<\\u>', '<\\\u1234>', '<\n>',
@ -3045,6 +3198,19 @@ class REX_seven(object):
def __reduce__(self): def __reduce__(self):
return type(self), (), None, None, iter(self.table.items()) return type(self), (), None, None, iter(self.table.items())
class REX_state(object):
"""This class is used to check the 3th argument (state) of
the reduce protocol.
"""
def __init__(self, state=None):
self.state = state
def __eq__(self, other):
return type(self) is type(other) and self.state == other.state
def __setstate__(self, state):
self.state = state
def __reduce__(self):
return type(self), (), self.state
# Test classes for newobj # Test classes for newobj