Change the binary operators |, &, ^, - to return NotImplemented rather

than raising TypeError when the other argument is not a BaseSet.  This
made it necessary to separate the implementation of e.g. __or__ from
the union method; the latter should not return NotImplemented but
raise TypeError.  This is accomplished by making union(self, other)
return self|other, etc.; Python's binary operator machinery will raise
TypeError.

The idea behind this change is to allow other set implementations with
an incompatible internal structure; these can provide union (etc.) with
standard sets by implementing __ror__ etc.

I wish I could do this for comparisons too, but the default comparison
implementation allows comparing anything to anything else (returning
false); we don't want that (at least the test suite makes sure
e.g. Set()==42 raises TypeError).  That's probably fine; otherwise
other set implementations would be constrained to implementing a hash
that's compatible with ours.
This commit is contained in:
Guido van Rossum 2002-08-22 17:23:33 +00:00
parent 13090e1025
commit dc61cdf6c0
1 changed files with 43 additions and 17 deletions

View File

@ -53,7 +53,7 @@ what's tested is actually `z in y'.
# and cleaned up the docstrings. # and cleaned up the docstrings.
# #
# - Raymond Hettinger added a number of speedups and other # - Raymond Hettinger added a number of speedups and other
# bugs^H^H^H^Himprovements. # improvements.
__all__ = ['BaseSet', 'Set', 'ImmutableSet'] __all__ = ['BaseSet', 'Set', 'ImmutableSet']
@ -155,26 +155,35 @@ class BaseSet(object):
data[deepcopy(elt, memo)] = value data[deepcopy(elt, memo)] = value
return result return result
# Standard set operations: union, intersection, both differences # Standard set operations: union, intersection, both differences.
# Each has an operator version (e.g. __or__, invoked with |) and a
# method version (e.g. union).
def __or__(self, other):
"""Return the union of two sets as a new set.
(I.e. all elements that are in either set.)
"""
if not isinstance(other, BaseSet):
return NotImplemented
result = self.__class__(self._data)
result._data.update(other._data)
return result
def union(self, other): def union(self, other):
"""Return the union of two sets as a new set. """Return the union of two sets as a new set.
(I.e. all elements that are in either set.) (I.e. all elements that are in either set.)
""" """
self._binary_sanity_check(other) return self | other
result = self.__class__(self._data)
result._data.update(other._data)
return result
__or__ = union def __and__(self, other):
def intersection(self, other):
"""Return the intersection of two sets as a new set. """Return the intersection of two sets as a new set.
(I.e. all elements that are in both sets.) (I.e. all elements that are in both sets.)
""" """
self._binary_sanity_check(other) if not isinstance(other, BaseSet):
return NotImplemented
if len(self) <= len(other): if len(self) <= len(other):
little, big = self, other little, big = self, other
else: else:
@ -187,14 +196,20 @@ class BaseSet(object):
data[elt] = value data[elt] = value
return result return result
__and__ = intersection def intersection(self, other):
"""Return the intersection of two sets as a new set.
def symmetric_difference(self, other): (I.e. all elements that are in both sets.)
"""
return self & other
def __xor__(self, other):
"""Return the symmetric difference of two sets as a new set. """Return the symmetric difference of two sets as a new set.
(I.e. all elements that are in exactly one of the sets.) (I.e. all elements that are in exactly one of the sets.)
""" """
self._binary_sanity_check(other) if not isinstance(other, BaseSet):
return NotImplemented
result = self.__class__([]) result = self.__class__([])
data = result._data data = result._data
value = True value = True
@ -206,14 +221,20 @@ class BaseSet(object):
data[elt] = value data[elt] = value
return result return result
__xor__ = symmetric_difference def symmetric_difference(self, other):
"""Return the symmetric difference of two sets as a new set.
def difference(self, other): (I.e. all elements that are in exactly one of the sets.)
"""
return self ^ other
def __sub__(self, other):
"""Return the difference of two sets as a new Set. """Return the difference of two sets as a new Set.
(I.e. all elements that are in this set and not in the other.) (I.e. all elements that are in this set and not in the other.)
""" """
self._binary_sanity_check(other) if not isinstance(other, BaseSet):
return NotImplemented
result = self.__class__([]) result = self.__class__([])
data = result._data data = result._data
value = True value = True
@ -222,7 +243,12 @@ class BaseSet(object):
data[elt] = value data[elt] = value
return result return result
__sub__ = difference def difference(self, other):
"""Return the difference of two sets as a new Set.
(I.e. all elements that are in this set and not in the other.)
"""
return self - other
# Membership test # Membership test