From dc61cdf6c0f433cba6a51b05346acb5b538a7617 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 22 Aug 2002 17:23:33 +0000 Subject: [PATCH] Change the binary operators |, &, ^, - to return NotImplemented rather than raising TypeError when the other argument is not a BaseSet. This made it necessary to separate the implementation of e.g. __or__ from the union method; the latter should not return NotImplemented but raise TypeError. This is accomplished by making union(self, other) return self|other, etc.; Python's binary operator machinery will raise TypeError. The idea behind this change is to allow other set implementations with an incompatible internal structure; these can provide union (etc.) with standard sets by implementing __ror__ etc. I wish I could do this for comparisons too, but the default comparison implementation allows comparing anything to anything else (returning false); we don't want that (at least the test suite makes sure e.g. Set()==42 raises TypeError). That's probably fine; otherwise other set implementations would be constrained to implementing a hash that's compatible with ours. --- Lib/sets.py | 60 ++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 17 deletions(-) diff --git a/Lib/sets.py b/Lib/sets.py index eeef0e8087f..fee06d76be0 100644 --- a/Lib/sets.py +++ b/Lib/sets.py @@ -53,7 +53,7 @@ what's tested is actually `z in y'. # and cleaned up the docstrings. # # - Raymond Hettinger added a number of speedups and other -# bugs^H^H^H^Himprovements. +# improvements. __all__ = ['BaseSet', 'Set', 'ImmutableSet'] @@ -155,26 +155,35 @@ class BaseSet(object): data[deepcopy(elt, memo)] = value return result - # Standard set operations: union, intersection, both differences + # Standard set operations: union, intersection, both differences. + # Each has an operator version (e.g. __or__, invoked with |) and a + # method version (e.g. union). + + def __or__(self, other): + """Return the union of two sets as a new set. + + (I.e. all elements that are in either set.) + """ + if not isinstance(other, BaseSet): + return NotImplemented + result = self.__class__(self._data) + result._data.update(other._data) + return result def union(self, other): """Return the union of two sets as a new set. (I.e. all elements that are in either set.) """ - self._binary_sanity_check(other) - result = self.__class__(self._data) - result._data.update(other._data) - return result + return self | other - __or__ = union - - def intersection(self, other): + def __and__(self, other): """Return the intersection of two sets as a new set. (I.e. all elements that are in both sets.) """ - self._binary_sanity_check(other) + if not isinstance(other, BaseSet): + return NotImplemented if len(self) <= len(other): little, big = self, other else: @@ -187,14 +196,20 @@ class BaseSet(object): data[elt] = value return result - __and__ = intersection + def intersection(self, other): + """Return the intersection of two sets as a new set. - def symmetric_difference(self, other): + (I.e. all elements that are in both sets.) + """ + return self & other + + def __xor__(self, other): """Return the symmetric difference of two sets as a new set. (I.e. all elements that are in exactly one of the sets.) """ - self._binary_sanity_check(other) + if not isinstance(other, BaseSet): + return NotImplemented result = self.__class__([]) data = result._data value = True @@ -206,14 +221,20 @@ class BaseSet(object): data[elt] = value return result - __xor__ = symmetric_difference + def symmetric_difference(self, other): + """Return the symmetric difference of two sets as a new set. - def difference(self, other): + (I.e. all elements that are in exactly one of the sets.) + """ + return self ^ other + + def __sub__(self, other): """Return the difference of two sets as a new Set. (I.e. all elements that are in this set and not in the other.) """ - self._binary_sanity_check(other) + if not isinstance(other, BaseSet): + return NotImplemented result = self.__class__([]) data = result._data value = True @@ -222,7 +243,12 @@ class BaseSet(object): data[elt] = value return result - __sub__ = difference + def difference(self, other): + """Return the difference of two sets as a new Set. + + (I.e. all elements that are in this set and not in the other.) + """ + return self - other # Membership test