From e98839a1f48b2915f1cc747884e64f4d6e4c8e7a Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Mon, 9 Jun 2008 01:28:30 +0000 Subject: [PATCH] Issue3065: Fixed pickling of named tuples. Added tests. --- Doc/library/collections.rst | 3 +++ Lib/collections.py | 4 +++- Lib/test/test_collections.py | 21 ++++++++++++++++++++- 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/Doc/library/collections.rst b/Doc/library/collections.rst index 2b6f9b1078a..cbad297a4f6 100644 --- a/Doc/library/collections.rst +++ b/Doc/library/collections.rst @@ -539,6 +539,9 @@ Example: if kwds: raise ValueError('Got unexpected field names: %r' % kwds.keys()) return result + + def __getnewargs__(self): + return tuple(self) x = property(itemgetter(0)) y = property(itemgetter(1)) diff --git a/Lib/collections.py b/Lib/collections.py index f6233a7a448..24088183900 100644 --- a/Lib/collections.py +++ b/Lib/collections.py @@ -82,7 +82,9 @@ def namedtuple(typename, field_names, verbose=False): result = self._make(map(kwds.pop, %(field_names)r, self)) if kwds: raise ValueError('Got unexpected field names: %%r' %% kwds.keys()) - return result \n\n''' % locals() + return result \n + def __getnewargs__(self): + return tuple(self) \n\n''' % locals() for i, name in enumerate(field_names): template += ' %s = property(itemgetter(%d))\n' % (name, i) if verbose: diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index a770155bdf5..4f823e393ce 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -1,12 +1,14 @@ import unittest, doctest from test import test_support from collections import namedtuple +import pickle, cPickle, copy from collections import Hashable, Iterable, Iterator from collections import Sized, Container, Callable from collections import Set, MutableSet from collections import Mapping, MutableMapping from collections import Sequence, MutableSequence +TestNT = namedtuple('TestNT', 'x y z') # type used for pickle tests class TestNamedTuple(unittest.TestCase): @@ -108,7 +110,7 @@ class TestNamedTuple(unittest.TestCase): self.assertEqual(Dot(1)._replace(d=999), (999,)) self.assertEqual(Dot(1)._fields, ('d',)) - n = 10000 + n = 5000 import string, random names = list(set(''.join([random.choice(string.ascii_letters) for j in range(10)]) for i in range(n))) @@ -130,6 +132,23 @@ class TestNamedTuple(unittest.TestCase): self.assertEqual(b2, tuple(b2_expected)) self.assertEqual(b._fields, tuple(names)) + def test_pickle(self): + p = TestNT(x=10, y=20, z=30) + for module in pickle, cPickle: + loads = getattr(module, 'loads') + dumps = getattr(module, 'dumps') + for protocol in -1, 0, 1, 2: + q = loads(dumps(p, protocol)) + self.assertEqual(p, q) + self.assertEqual(p._fields, q._fields) + + def test_copy(self): + p = TestNT(x=10, y=20, z=30) + for copier in copy.copy, copy.deepcopy: + q = copier(p) + self.assertEqual(p, q) + self.assertEqual(p._fields, q._fields) + class TestOneTrickPonyABCs(unittest.TestCase): def test_Hashable(self):