Support all the new stuff supported by the new pickle code:

- subclasses of list or dict
- __reduce__ returning a 4-tuple or 5-tuple
- slots
This commit is contained in:
Guido van Rossum 2003-02-06 19:53:22 +00:00
parent 0189266456
commit c755758906
2 changed files with 109 additions and 13 deletions

View File

@ -7,7 +7,7 @@ Interface summary:
x = copy.copy(y) # make a shallow copy of y
x = copy.deepcopy(y) # make a deep copy of y
For module specific errors, copy.error is raised.
For module specific errors, copy.Error is raised.
The difference between shallow and deep copying is only relevant for
compound objects (objects that contain other objects, like lists or
@ -51,6 +51,7 @@ __getstate__() and __setstate__(). See the documentation for module
# XXX need to support copy_reg here too...
import types
from pickle import _slotnames
class Error(Exception):
pass
@ -61,7 +62,7 @@ try:
except ImportError:
PyStringMap = None
__all__ = ["Error", "error", "copy", "deepcopy"]
__all__ = ["Error", "copy", "deepcopy"]
def copy(x):
"""Shallow copy operation on arbitrary Python objects.
@ -76,18 +77,60 @@ def copy(x):
copier = x.__copy__
except AttributeError:
try:
reductor = x.__reduce__
reductor = x.__class__.__reduce__
if reductor == object.__reduce__:
reductor = _better_reduce
except AttributeError:
raise error, \
"un(shallow)copyable object of type %s" % type(x)
raise Error("un(shallow)copyable object of type %s" % type(x))
else:
y = _reconstruct(x, reductor(), 0)
y = _reconstruct(x, reductor(x), 0)
else:
y = copier()
else:
y = copierfunction(x)
return y
def __newobj__(cls, *args):
return cls.__new__(cls, *args)
def _better_reduce(obj):
cls = obj.__class__
getnewargs = getattr(obj, "__getnewargs__", None)
if getnewargs:
args = getnewargs()
else:
args = ()
getstate = getattr(obj, "__getstate__", None)
if getstate:
try:
state = getstate()
except TypeError, err:
# XXX Catch generic exception caused by __slots__
if str(err) != ("a class that defines __slots__ "
"without defining __getstate__ "
"cannot be pickled"):
raise # Not that specific exception
getstate = None
if not getstate:
state = getattr(obj, "__dict__", None)
names = _slotnames(cls)
if names:
slots = {}
nil = []
for name in names:
value = getattr(obj, name, nil)
if value is not nil:
slots[name] = value
if slots:
state = (state, slots)
listitems = dictitems = None
if isinstance(obj, list):
listitems = iter(obj)
elif isinstance(obj, dict):
dictitems = obj.iteritems()
return __newobj__, (cls, args), state, listitems, dictitems
_copy_dispatch = d = {}
def _copy_atomic(x):
@ -175,12 +218,14 @@ def deepcopy(x, memo = None):
copier = x.__deepcopy__
except AttributeError:
try:
reductor = x.__reduce__
reductor = x.__class__.__reduce__
if reductor == object.__reduce__:
reductor = _better_reduce
except AttributeError:
raise error, \
"un-deep-copyable object of type %s" % type(x)
raise Error("un(shallow)copyable object of type %s" %
type(x))
else:
y = _reconstruct(x, reductor(), 1, memo)
y = _reconstruct(x, reductor(x), 1, memo)
else:
y = copier(memo)
else:
@ -331,7 +376,15 @@ def _reconstruct(x, info, deep, memo=None):
if hasattr(y, '__setstate__'):
y.__setstate__(state)
else:
y.__dict__.update(state)
if isinstance(state, tuple) and len(state) == 2:
state, slotstate = state
else:
slotstate = None
if state is not None:
y.__dict__.update(state)
if slotstate is not None:
for key, value in slotstate.iteritems():
setattr(y, key, value)
return y
del d

View File

@ -41,11 +41,13 @@ class TestCopy(unittest.TestCase):
self.assert_(y is x)
def test_copy_cant(self):
class C(object):
class Meta(type):
def __getattribute__(self, name):
if name == "__reduce__":
raise AttributeError, name
return object.__getattribute__(self, name)
class C:
__metaclass__ = Meta
x = C()
self.assertRaises(copy.Error, copy.copy, x)
@ -189,11 +191,13 @@ class TestCopy(unittest.TestCase):
self.assert_(y is x)
def test_deepcopy_cant(self):
class C(object):
class Meta(type):
def __getattribute__(self, name):
if name == "__reduce__":
raise AttributeError, name
return object.__getattribute__(self, name)
class C:
__metaclass__ = Meta
x = C()
self.assertRaises(copy.Error, copy.deepcopy, x)
@ -411,6 +415,45 @@ class TestCopy(unittest.TestCase):
self.assert_(x is not y)
self.assert_(x["foo"] is not y["foo"])
def test_copy_slots(self):
class C(object):
__slots__ = ["foo"]
x = C()
x.foo = [42]
y = copy.copy(x)
self.assert_(x.foo is y.foo)
def test_deepcopy_slots(self):
class C(object):
__slots__ = ["foo"]
x = C()
x.foo = [42]
y = copy.deepcopy(x)
self.assertEqual(x.foo, y.foo)
self.assert_(x.foo is not y.foo)
def test_copy_list_subclass(self):
class C(list):
pass
x = C([[1, 2], 3])
x.foo = [4, 5]
y = copy.copy(x)
self.assertEqual(list(x), list(y))
self.assertEqual(x.foo, y.foo)
self.assert_(x[0] is y[0])
self.assert_(x.foo is y.foo)
def test_deepcopy_list_subclass(self):
class C(list):
pass
x = C([[1, 2], 3])
x.foo = [4, 5]
y = copy.deepcopy(x)
self.assertEqual(list(x), list(y))
self.assertEqual(x.foo, y.foo)
self.assert_(x[0] is not y[0])
self.assert_(x.foo is not y.foo)
def test_main():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(TestCopy))