Issue #892902: Fixed pickling recursive objects.

This commit is contained in:
Serhiy Storchaka 2015-11-07 11:15:32 +02:00
parent 43415ba571
commit da87e45add
5 changed files with 145 additions and 30 deletions

View File

@ -402,7 +402,13 @@ class Pickler:
write(REDUCE)
if obj is not None:
self.memoize(obj)
# If the object is already in the memo, this means it is
# recursive. In this case, throw away everything we put on the
# stack, and fetch the object back from the memo.
if id(obj) in self.memo:
write(POP + self.get(self.memo[id(obj)][0]))
else:
self.memoize(obj)
# More new special cases (that work with older protocols as
# well): when __reduce__ returns a tuple with 4 or 5 items,

View File

@ -117,6 +117,18 @@ class E(C):
def __getinitargs__(self):
return ()
class H(object):
pass
# Hashable mutable key
class K(object):
def __init__(self, value):
self.value = value
def __reduce__(self):
# Shouldn't support the recursion itself
return K, (self.value,)
import __main__
__main__.C = C
C.__module__ = "__main__"
@ -124,6 +136,10 @@ __main__.D = D
D.__module__ = "__main__"
__main__.E = E
E.__module__ = "__main__"
__main__.H = H
H.__module__ = "__main__"
__main__.K = K
K.__module__ = "__main__"
class myint(int):
def __init__(self, x):
@ -676,18 +692,21 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(l, proto)
x = self.loads(s)
self.assertIsInstance(x, list)
self.assertEqual(len(x), 1)
self.assertTrue(x is x[0])
self.assertIs(x[0], x)
def test_recursive_tuple(self):
def test_recursive_tuple_and_list(self):
t = ([],)
t[0].append(t)
for proto in protocols:
s = self.dumps(t, proto)
x = self.loads(s)
self.assertIsInstance(x, tuple)
self.assertEqual(len(x), 1)
self.assertIsInstance(x[0], list)
self.assertEqual(len(x[0]), 1)
self.assertTrue(x is x[0][0])
self.assertIs(x[0][0], x)
def test_recursive_dict(self):
d = {}
@ -695,8 +714,50 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(d, proto)
x = self.loads(s)
self.assertIsInstance(x, dict)
self.assertEqual(x.keys(), [1])
self.assertTrue(x[1] is x)
self.assertIs(x[1], x)
def test_recursive_dict_key(self):
d = {}
k = K(d)
d[k] = 1
for proto in protocols:
s = self.dumps(d, proto)
x = self.loads(s)
self.assertIsInstance(x, dict)
self.assertEqual(len(x.keys()), 1)
self.assertIsInstance(x.keys()[0], K)
self.assertIs(x.keys()[0].value, x)
def test_recursive_list_subclass(self):
y = MyList()
y.append(y)
s = self.dumps(y, 2)
x = self.loads(s)
self.assertIsInstance(x, MyList)
self.assertEqual(len(x), 1)
self.assertIs(x[0], x)
def test_recursive_dict_subclass(self):
d = MyDict()
d[1] = d
s = self.dumps(d, 2)
x = self.loads(s)
self.assertIsInstance(x, MyDict)
self.assertEqual(x.keys(), [1])
self.assertIs(x[1], x)
def test_recursive_dict_subclass_key(self):
d = MyDict()
k = K(d)
d[k] = 1
s = self.dumps(d, 2)
x = self.loads(s)
self.assertIsInstance(x, MyDict)
self.assertEqual(len(x.keys()), 1)
self.assertIsInstance(x.keys()[0], K)
self.assertIs(x.keys()[0].value, x)
def test_recursive_inst(self):
i = C()
@ -721,6 +782,42 @@ class AbstractPickleTests(unittest.TestCase):
self.assertEqual(x[0].attr.keys(), [1])
self.assertTrue(x[0].attr[1] is x)
def check_recursive_collection_and_inst(self, factory):
h = H()
y = factory([h])
h.attr = y
for proto in protocols:
s = self.dumps(y, proto)
x = self.loads(s)
self.assertIsInstance(x, type(y))
self.assertEqual(len(x), 1)
self.assertIsInstance(list(x)[0], H)
self.assertIs(list(x)[0].attr, x)
def test_recursive_list_and_inst(self):
self.check_recursive_collection_and_inst(list)
def test_recursive_tuple_and_inst(self):
self.check_recursive_collection_and_inst(tuple)
def test_recursive_dict_and_inst(self):
self.check_recursive_collection_and_inst(dict.fromkeys)
def test_recursive_set_and_inst(self):
self.check_recursive_collection_and_inst(set)
def test_recursive_frozenset_and_inst(self):
self.check_recursive_collection_and_inst(frozenset)
def test_recursive_list_subclass_and_inst(self):
self.check_recursive_collection_and_inst(MyList)
def test_recursive_tuple_subclass_and_inst(self):
self.check_recursive_collection_and_inst(MyTuple)
def test_recursive_dict_subclass_and_inst(self):
self.check_recursive_collection_and_inst(MyDict.fromkeys)
if have_unicode:
def test_unicode(self):
endcases = [u'', u'<\\u>', u'<\\\u1234>', u'<\n>',

View File

@ -1,6 +1,7 @@
import cPickle
import cStringIO
import io
import functools
import unittest
from test.pickletester import (AbstractUnpickleTests,
AbstractPickleTests,
@ -151,31 +152,6 @@ class cPickleFastPicklerTests(AbstractPickleTests):
finally:
self.close(f)
def test_recursive_list(self):
self.assertRaises(ValueError,
AbstractPickleTests.test_recursive_list,
self)
def test_recursive_tuple(self):
self.assertRaises(ValueError,
AbstractPickleTests.test_recursive_tuple,
self)
def test_recursive_inst(self):
self.assertRaises(ValueError,
AbstractPickleTests.test_recursive_inst,
self)
def test_recursive_dict(self):
self.assertRaises(ValueError,
AbstractPickleTests.test_recursive_dict,
self)
def test_recursive_multi(self):
self.assertRaises(ValueError,
AbstractPickleTests.test_recursive_multi,
self)
def test_nonrecursive_deep(self):
# If it's not cyclic, it should pickle OK even if the nesting
# depth exceeds PY_CPICKLE_FAST_LIMIT. That happens to be
@ -187,6 +163,19 @@ class cPickleFastPicklerTests(AbstractPickleTests):
b = self.loads(self.dumps(a))
self.assertEqual(a, b)
for name in dir(AbstractPickleTests):
if name.startswith('test_recursive_'):
func = getattr(AbstractPickleTests, name)
if '_subclass' in name and '_and_inst' not in name:
assert_args = RuntimeError, 'maximum recursion depth exceeded'
else:
assert_args = ValueError, "can't pickle cyclic objects"
def wrapper(self, func=func, assert_args=assert_args):
with self.assertRaisesRegexp(*assert_args):
func(self)
functools.update_wrapper(wrapper, func)
setattr(cPickleFastPicklerTests, name, wrapper)
class cStringIOCPicklerFastTests(cStringIOMixin, cPickleFastPicklerTests):
pass

View File

@ -46,6 +46,8 @@ Core and Builtins
Library
-------
- Issue #892902: Fixed pickling recursive objects.
- Issue #18010: Fix the pydoc GUI's search function to handle exceptions
from importing packages.

View File

@ -2533,6 +2533,27 @@ save_reduce(Picklerobject *self, PyObject *args, PyObject *fn, PyObject *ob)
/* Memoize. */
/* XXX How can ob be NULL? */
if (ob != NULL) {
/* If the object is already in the memo, this means it is
recursive. In this case, throw away everything we put on the
stack, and fetch the object back from the memo. */
if (Py_REFCNT(ob) > 1 && !self->fast) {
PyObject *py_ob_id = PyLong_FromVoidPtr(ob);
if (!py_ob_id)
return -1;
if (PyDict_GetItem(self->memo, py_ob_id)) {
const char pop_op = POP;
if (self->write_func(self, &pop_op, 1) < 0 ||
get(self, py_ob_id) < 0) {
Py_DECREF(py_ob_id);
return -1;
}
Py_DECREF(py_ob_id);
return 0;
}
Py_DECREF(py_ob_id);
if (PyErr_Occurred())
return -1;
}
if (state && !PyDict_Check(state)) {
if (put2(self, ob) < 0)
return -1;