diff --git a/Lib/sets.py b/Lib/sets.py index a05c66f73c2..5dac370a644 100644 --- a/Lib/sets.py +++ b/Lib/sets.py @@ -366,6 +366,11 @@ class ImmutableSet(BaseSet): self._hashcode = self._compute_hash() return self._hashcode + def __getstate__(self): + return self._data, self._hashcode + + def __setstate__(self, state): + self._data, self._hashcode = state class Set(BaseSet): """ Mutable set class.""" @@ -380,6 +385,13 @@ class Set(BaseSet): if iterable is not None: self._update(iterable) + def __getstate__(self): + # getstate's results are ignored if it is not + return self._data, + + def __setstate__(self, data): + self._data, = data + def __hash__(self): """A Set cannot be hashed.""" # We inherit object.__hash__, so we must deny this explicitly diff --git a/Lib/test/test_sets.py b/Lib/test/test_sets.py index 4521335c516..cf0cd59dc99 100644 --- a/Lib/test/test_sets.py +++ b/Lib/test/test_sets.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -import unittest, operator, copy +import unittest, operator, copy, pickle from sets import Set, ImmutableSet from test import test_support @@ -74,6 +74,14 @@ class TestBasicOps(unittest.TestCase): for v in self.set: self.assert_(v in self.values) + def test_pickling(self): + p = pickle.dumps(self.set) + print repr(p) + copy = pickle.loads(p) + repr(copy) + self.assertEqual(self.set, copy, + "%s != %s" % (self.set, copy)) + #------------------------------------------------------------------------------ class TestBasicOpsEmpty(TestBasicOps):