diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py index 1c69b1b5096..b4592f98392 100644 --- a/Lib/collections/__init__.py +++ b/Lib/collections/__init__.py @@ -1036,6 +1036,13 @@ class UserDict(_collections_abc.MutableMapping): # Now, add the methods in dicts but not in MutableMapping def __repr__(self): return repr(self.data) + def __copy__(self): + inst = self.__class__.__new__(self.__class__) + inst.__dict__.update(self.__dict__) + # Create a copy and avoid triggering descriptors + inst.__dict__["data"] = self.__dict__["data"].copy() + return inst + def copy(self): if self.__class__ is UserDict: return UserDict(self.data.copy()) @@ -1048,6 +1055,7 @@ class UserDict(_collections_abc.MutableMapping): self.data = data c.update(self) return c + @classmethod def fromkeys(cls, iterable, value=None): d = cls() @@ -1112,6 +1120,12 @@ class UserList(_collections_abc.MutableSequence): def __imul__(self, n): self.data *= n return self + def __copy__(self): + inst = self.__class__.__new__(self.__class__) + inst.__dict__.update(self.__dict__) + # Create a copy and avoid triggering descriptors + inst.__dict__["data"] = self.__dict__["data"][:] + return inst def append(self, item): self.data.append(item) def insert(self, i, item): self.data.insert(i, item) def pop(self, i=-1): return self.data.pop(i) diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index 00b33ce2767..16735b815e5 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -38,6 +38,20 @@ class TestUserObjects(unittest.TestCase): b=b.__name__, ), ) + + def _copy_test(self, obj): + # Test internal copy + obj_copy = obj.copy() + self.assertIsNot(obj.data, obj_copy.data) + self.assertEqual(obj.data, obj_copy.data) + + # Test copy.copy + obj.test = [1234] # Make sure instance vars are also copied. + obj_copy = copy.copy(obj) + self.assertIsNot(obj.data, obj_copy.data) + self.assertEqual(obj.data, obj_copy.data) + self.assertIs(obj.test, obj_copy.test) + def test_str_protocol(self): self._superset_test(UserString, str) @@ -47,6 +61,16 @@ class TestUserObjects(unittest.TestCase): def test_dict_protocol(self): self._superset_test(UserDict, dict) + def test_list_copy(self): + obj = UserList() + obj.append(123) + self._copy_test(obj) + + def test_dict_copy(self): + obj = UserDict() + obj[123] = "abc" + self._copy_test(obj) + ################################################################################ ### ChainMap (helper class for configparser and the string module) diff --git a/Misc/NEWS.d/next/Library/2017-10-24-00-42-14.bpo-27141.zbAgSs.rst b/Misc/NEWS.d/next/Library/2017-10-24-00-42-14.bpo-27141.zbAgSs.rst new file mode 100644 index 00000000000..76c2abbf82d --- /dev/null +++ b/Misc/NEWS.d/next/Library/2017-10-24-00-42-14.bpo-27141.zbAgSs.rst @@ -0,0 +1,3 @@ +Added a ``__copy__()`` to ``collections.UserList`` and +``collections.UserDict`` in order to correctly implement shallow copying of +the objects. Patch by Bar Harel.