From c06e3acc735a6e9cf28d0f511493bcfd8829117d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 7 Feb 2003 17:30:18 +0000 Subject: [PATCH] Add support for copy_reg.dispatch_table. Rewrote copy() and deepcopy() without avoidable try/except statements; getattr(x, name, None) or dict.get() are much faster than try/except. --- Lib/copy.py | 98 ++++++++++++++++++++++--------------------- Lib/test/test_copy.py | 27 ++++++++++++ 2 files changed, 77 insertions(+), 48 deletions(-) diff --git a/Lib/copy.py b/Lib/copy.py index 4133a1f1dc3..c1c0ec0cbdc 100644 --- a/Lib/copy.py +++ b/Lib/copy.py @@ -48,10 +48,8 @@ __getstate__() and __setstate__(). See the documentation for module "pickle" for information on these methods. """ -# XXX need to support copy_reg here too... - import types -from copy_reg import _better_reduce +from copy_reg import _better_reduce, dispatch_table class Error(Exception): pass @@ -70,25 +68,25 @@ def copy(x): See the module's __doc__ string for more info. """ - try: - copierfunction = _copy_dispatch[type(x)] - except KeyError: - try: - copier = x.__copy__ - except AttributeError: - try: - reductor = x.__class__.__reduce__ - if reductor == object.__reduce__: - reductor = _better_reduce - except AttributeError: - raise Error("un(shallow)copyable object of type %s" % type(x)) - else: - y = _reconstruct(x, reductor(x), 0) - else: - y = copier() - else: - y = copierfunction(x) - return y + cls = type(x) + + copier = _copy_dispatch.get(cls) + if copier: + return copier(x) + + copier = getattr(cls, "__copy__", None) + if copier: + return copier(x) + + reductor = dispatch_table.get(cls) + if not reductor: + reductor = getattr(cls, "__reduce__", None) + if reductor == object.__reduce__: + reductor = _better_reduce + elif not reductor: + raise Error("un(shallow)copyable object of type %s" % cls) + + return _reconstruct(x, reductor(x), 0) _copy_dispatch = d = {} @@ -153,7 +151,7 @@ d[types.InstanceType] = _copy_inst del d -def deepcopy(x, memo = None): +def deepcopy(x, memo=None, _nil=[]): """Deep copy operation on arbitrary Python objects. See the module's __doc__ string for more info. @@ -161,35 +159,39 @@ def deepcopy(x, memo = None): if memo is None: memo = {} + d = id(x) - if d in memo: - return memo[d] - try: - copierfunction = _deepcopy_dispatch[type(x)] - except KeyError: + y = memo.get(d, _nil) + if y is not _nil: + return y + + cls = type(x) + + copier = _deepcopy_dispatch.get(cls) + if copier: + y = copier(x, memo) + else: try: - issc = issubclass(type(x), type) - except TypeError: + issc = issubclass(cls, type) + except TypeError: # cls is not a class (old Boost; see SF #502085) issc = 0 if issc: - y = _deepcopy_dispatch[type](x, memo) + copier = _deepcopy_atomic else: - try: - copier = x.__deepcopy__ - except AttributeError: - try: - reductor = x.__class__.__reduce__ - if reductor == object.__reduce__: - reductor = _better_reduce - except AttributeError: - raise Error("un(shallow)copyable object of type %s" % - type(x)) - else: - y = _reconstruct(x, reductor(x), 1, memo) - else: - y = copier(memo) - else: - y = copierfunction(x, memo) + copier = getattr(cls, "__deepcopy__", None) + + if copier: + y = copier(x, memo) + else: + reductor = dispatch_table.get(cls) + if not reductor: + reductor = getattr(cls, "__reduce__", None) + if reductor == object.__reduce__: + reductor = _better_reduce + elif not reductor: + raise Error("un(deep)copyable object of type %s" % cls) + y = _reconstruct(x, reductor(x), 1, memo) + memo[d] = y _keep_alive(x, memo) # Make sure x lives at least as long as d return y @@ -380,7 +382,7 @@ def _test(): def __setstate__(self, state): for key, value in state.iteritems(): setattr(self, key, value) - def __deepcopy__(self, memo = None): + def __deepcopy__(self, memo=None): new = self.__class__(deepcopy(self.arg, memo)) new.a = self.a return new diff --git a/Lib/test/test_copy.py b/Lib/test/test_copy.py index c97d54d7491..35ce46a5232 100644 --- a/Lib/test/test_copy.py +++ b/Lib/test/test_copy.py @@ -2,6 +2,7 @@ import sys import copy +import copy_reg import unittest from test import test_support @@ -32,6 +33,19 @@ class TestCopy(unittest.TestCase): self.assertEqual(y.__class__, x.__class__) self.assertEqual(y.foo, x.foo) + def test_copy_registry(self): + class C(object): + def __new__(cls, foo): + obj = object.__new__(cls) + obj.foo = foo + return obj + def pickle_C(obj): + return (C, (obj.foo,)) + x = C(42) + self.assertRaises(TypeError, copy.copy, x) + copy_reg.pickle(C, pickle_C, C) + y = copy.copy(x) + def test_copy_reduce(self): class C(object): def __reduce__(self): @@ -182,6 +196,19 @@ class TestCopy(unittest.TestCase): self.assertEqual(y.__class__, x.__class__) self.assertEqual(y.foo, x.foo) + def test_deepcopy_registry(self): + class C(object): + def __new__(cls, foo): + obj = object.__new__(cls) + obj.foo = foo + return obj + def pickle_C(obj): + return (C, (obj.foo,)) + x = C(42) + self.assertRaises(TypeError, copy.deepcopy, x) + copy_reg.pickle(C, pickle_C, C) + y = copy.deepcopy(x) + def test_deepcopy_reduce(self): class C(object): def __reduce__(self):