From dd5e53a086fcabc84ee1ac96b98057437863973a Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Mon, 26 May 2014 00:09:04 -0700 Subject: [PATCH] Issue 8743: Improve interoperability between sets and the collections.Set abstract base class. --- Lib/_collections_abc.py | 23 ++++- Lib/test/test_collections.py | 160 ++++++++++++++++++++++++++++++++++- Misc/NEWS | 3 + 3 files changed, 180 insertions(+), 6 deletions(-) diff --git a/Lib/_collections_abc.py b/Lib/_collections_abc.py index faa1ff22ff4..656ae064fe1 100644 --- a/Lib/_collections_abc.py +++ b/Lib/_collections_abc.py @@ -207,12 +207,17 @@ class Set(Sized, Iterable, Container): def __gt__(self, other): if not isinstance(other, Set): return NotImplemented - return other.__lt__(self) + return len(self) > len(other) and self.__ge__(other) def __ge__(self, other): if not isinstance(other, Set): return NotImplemented - return other.__le__(self) + if len(self) < len(other): + return False + for elem in other: + if elem not in self: + return False + return True def __eq__(self, other): if not isinstance(other, Set): @@ -236,6 +241,8 @@ class Set(Sized, Iterable, Container): return NotImplemented return self._from_iterable(value for value in other if value in self) + __rand__ = __and__ + def isdisjoint(self, other): 'Return True if two sets have a null intersection.' for value in other: @@ -249,6 +256,8 @@ class Set(Sized, Iterable, Container): chain = (e for s in (self, other) for e in s) return self._from_iterable(chain) + __ror__ = __or__ + def __sub__(self, other): if not isinstance(other, Set): if not isinstance(other, Iterable): @@ -257,6 +266,14 @@ class Set(Sized, Iterable, Container): return self._from_iterable(value for value in self if value not in other) + def __rsub__(self, other): + if not isinstance(other, Set): + if not isinstance(other, Iterable): + return NotImplemented + other = self._from_iterable(other) + return self._from_iterable(value for value in other + if value not in self) + def __xor__(self, other): if not isinstance(other, Set): if not isinstance(other, Iterable): @@ -264,6 +281,8 @@ class Set(Sized, Iterable, Container): other = self._from_iterable(other) return (self - other) | (other - self) + __rxor__ = __xor__ + def _hash(self): """Compute the hash value of a set. diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index ee28a6c0b35..6407b6f4f3d 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -720,14 +720,166 @@ class TestCollectionABCs(ABCTestCase): cs = MyComparableSet() ncs = MyNonComparableSet() + self.assertFalse(ncs < cs) + self.assertTrue(ncs <= cs) + self.assertFalse(ncs > cs) + self.assertTrue(ncs >= cs) + + def assertSameSet(self, s1, s2): + # coerce both to a real set then check equality + self.assertSetEqual(set(s1), set(s2)) + + def test_Set_interoperability_with_real_sets(self): + # Issue: 8743 + class ListSet(Set): + def __init__(self, elements=()): + self.data = [] + for elem in elements: + if elem not in self.data: + self.data.append(elem) + def __contains__(self, elem): + return elem in self.data + def __iter__(self): + return iter(self.data) + def __len__(self): + return len(self.data) + def __repr__(self): + return 'Set({!r})'.format(self.data) + + r1 = set('abc') + r2 = set('bcd') + r3 = set('abcde') + f1 = ListSet('abc') + f2 = ListSet('bcd') + f3 = ListSet('abcde') + l1 = list('abccba') + l2 = list('bcddcb') + l3 = list('abcdeedcba') + + target = r1 & r2 + self.assertSameSet(f1 & f2, target) + self.assertSameSet(f1 & r2, target) + self.assertSameSet(r2 & f1, target) + self.assertSameSet(f1 & l2, target) + + target = r1 | r2 + self.assertSameSet(f1 | f2, target) + self.assertSameSet(f1 | r2, target) + self.assertSameSet(r2 | f1, target) + self.assertSameSet(f1 | l2, target) + + fwd_target = r1 - r2 + rev_target = r2 - r1 + self.assertSameSet(f1 - f2, fwd_target) + self.assertSameSet(f2 - f1, rev_target) + self.assertSameSet(f1 - r2, fwd_target) + self.assertSameSet(f2 - r1, rev_target) + self.assertSameSet(r1 - f2, fwd_target) + self.assertSameSet(r2 - f1, rev_target) + self.assertSameSet(f1 - l2, fwd_target) + self.assertSameSet(f2 - l1, rev_target) + + target = r1 ^ r2 + self.assertSameSet(f1 ^ f2, target) + self.assertSameSet(f1 ^ r2, target) + self.assertSameSet(r2 ^ f1, target) + self.assertSameSet(f1 ^ l2, target) + + # Don't change the following to use assertLess or other + # "more specific" unittest assertions. The current + # assertTrue/assertFalse style makes the pattern of test + # case combinations clear and allows us to know for sure + # the exact operator being invoked. + + # proper subset + self.assertTrue(f1 < f3) + self.assertFalse(f1 < f1) + self.assertFalse(f1 < f2) + self.assertTrue(r1 < f3) + self.assertFalse(r1 < f1) + self.assertFalse(r1 < f2) + self.assertTrue(r1 < r3) + self.assertFalse(r1 < r1) + self.assertFalse(r1 < r2) with self.assertRaises(TypeError): - ncs < cs + f1 < l3 with self.assertRaises(TypeError): - ncs <= cs + f1 < l1 with self.assertRaises(TypeError): - cs > ncs + f1 < l2 + + # any subset + self.assertTrue(f1 <= f3) + self.assertTrue(f1 <= f1) + self.assertFalse(f1 <= f2) + self.assertTrue(r1 <= f3) + self.assertTrue(r1 <= f1) + self.assertFalse(r1 <= f2) + self.assertTrue(r1 <= r3) + self.assertTrue(r1 <= r1) + self.assertFalse(r1 <= r2) with self.assertRaises(TypeError): - cs >= ncs + f1 <= l3 + with self.assertRaises(TypeError): + f1 <= l1 + with self.assertRaises(TypeError): + f1 <= l2 + + # proper superset + self.assertTrue(f3 > f1) + self.assertFalse(f1 > f1) + self.assertFalse(f2 > f1) + self.assertTrue(r3 > r1) + self.assertFalse(f1 > r1) + self.assertFalse(f2 > r1) + self.assertTrue(r3 > r1) + self.assertFalse(r1 > r1) + self.assertFalse(r2 > r1) + with self.assertRaises(TypeError): + f1 > l3 + with self.assertRaises(TypeError): + f1 > l1 + with self.assertRaises(TypeError): + f1 > l2 + + # any superset + self.assertTrue(f3 >= f1) + self.assertTrue(f1 >= f1) + self.assertFalse(f2 >= f1) + self.assertTrue(r3 >= r1) + self.assertTrue(f1 >= r1) + self.assertFalse(f2 >= r1) + self.assertTrue(r3 >= r1) + self.assertTrue(r1 >= r1) + self.assertFalse(r2 >= r1) + with self.assertRaises(TypeError): + f1 >= l3 + with self.assertRaises(TypeError): + f1 >=l1 + with self.assertRaises(TypeError): + f1 >= l2 + + # equality + self.assertTrue(f1 == f1) + self.assertTrue(r1 == f1) + self.assertTrue(f1 == r1) + self.assertFalse(f1 == f3) + self.assertFalse(r1 == f3) + self.assertFalse(f1 == r3) + self.assertFalse(f1 == l3) + self.assertFalse(f1 == l1) + self.assertFalse(f1 == l2) + + # inequality + self.assertFalse(f1 != f1) + self.assertFalse(r1 != f1) + self.assertFalse(f1 != r1) + self.assertTrue(f1 != f3) + self.assertTrue(r1 != f3) + self.assertTrue(f1 != r3) + self.assertTrue(f1 != l3) + self.assertTrue(f1 != l1) + self.assertTrue(f1 != l2) def test_Mapping(self): for sample in [dict]: diff --git a/Misc/NEWS b/Misc/NEWS index 950040b47aa..cdce6f06ffa 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -24,6 +24,9 @@ Library - Issue #14710: pkgutil.find_loader() no longer raises an exception when a module doesn't exist. +- Issue #8743: Fix interoperability between set objects and the + collections.Set() abstract base class. + - Issue #13355: random.triangular() no longer fails with a ZeroDivisionError when low equals high.