From 4ad1d6f81a1fec3b9822e104e5df3a61a5cc328c Mon Sep 17 00:00:00 2001 From: Robert Schuppenies Date: Sun, 17 May 2009 17:32:20 +0000 Subject: [PATCH] Issue 5964: Fixed WeakSet __eq__ comparison to handle non-WeakSet objects. --- Lib/_weakrefset.py | 2 ++ Lib/test/test_weakset.py | 23 +++++++++++++---------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/Lib/_weakrefset.py b/Lib/_weakrefset.py index 0046133925d..addc7afeef9 100644 --- a/Lib/_weakrefset.py +++ b/Lib/_weakrefset.py @@ -118,6 +118,8 @@ class WeakSet: return self.data >= set(ref(item) for item in other) def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented return self.data == set(ref(item) for item in other) def symmetric_difference(self, other): diff --git a/Lib/test/test_weakset.py b/Lib/test/test_weakset.py index cff20455566..4d54576d15b 100644 --- a/Lib/test/test_weakset.py +++ b/Lib/test/test_weakset.py @@ -134,13 +134,11 @@ class TestWeakSet(unittest.TestCase): def test_gc(self): # Create a nest of cycles to exercise overall ref count check - class A: - pass - s = set(A() for i in range(1000)) + s = WeakSet(Foo() for i in range(1000)) for elem in s: elem.cycle = s elem.sub = elem - elem.set = set([elem]) + elem.set = WeakSet([elem]) def test_subclass_with_custom_hash(self): # Bug #1257731 @@ -169,17 +167,12 @@ class TestWeakSet(unittest.TestCase): t = WeakSet(s) self.assertNotEqual(id(s), id(t)) - def test_set_literal(self): - s = set([1,2,3]) - t = {1,2,3} - self.assertEqual(s, t) - def test_hash(self): self.assertRaises(TypeError, hash, self.s) def test_clear(self): self.s.clear() - self.assertEqual(self.s, set()) + self.assertEqual(self.s, WeakSet([])) self.assertEqual(len(self.s), 0) def test_copy(self): @@ -304,6 +297,16 @@ class TestWeakSet(unittest.TestCase): t ^= t self.assertEqual(t, WeakSet()) + def test_eq(self): + # issue 5964 + self.assertTrue(self.s == self.s) + self.assertTrue(self.s == WeakSet(self.items)) + self.assertFalse(self.s == set(self.items)) + self.assertFalse(self.s == list(self.items)) + self.assertFalse(self.s == tuple(self.items)) + self.assertFalse(self.s == WeakSet([Foo])) + self.assertFalse(self.s == 1) + def test_main(verbose=None): support.run_unittest(TestWeakSet)