Change the binary operators |, &, ^, - to return NotImplemented rather

than raising TypeError when the other argument is not a BaseSet.  This
made it necessary to separate the implementation of e.g. __or__ from
the union method; the latter should not return NotImplemented but
raise TypeError.  This is accomplished by making union(self, other)
return self|other, etc.; Python's binary operator machinery will raise
TypeError.

The idea behind this change is to allow other set implementations with
an incompatible internal structure; these can provide union (etc.) with
standard sets by implementing __ror__ etc.

I wish I could do this for comparisons too, but the default comparison
implementation allows comparing anything to anything else (returning
false); we don't want that (at least the test suite makes sure
e.g. Set()==42 raises TypeError).  That's probably fine; otherwise
other set implementations would be constrained to implementing a hash
that's compatible with ours.
This commit is contained in:
Guido van Rossum 2002-08-22 17:23:33 +00:00
parent 13090e1025
commit dc61cdf6c0
1 changed files with 43 additions and 17 deletions

View File

@ -53,7 +53,7 @@ what's tested is actually `z in y'.
# and cleaned up the docstrings.
#
# - Raymond Hettinger added a number of speedups and other
# bugs^H^H^H^Himprovements.
# improvements.
__all__ = ['BaseSet', 'Set', 'ImmutableSet']
@ -155,26 +155,35 @@ class BaseSet(object):
data[deepcopy(elt, memo)] = value
return result
# Standard set operations: union, intersection, both differences
# Standard set operations: union, intersection, both differences.
# Each has an operator version (e.g. __or__, invoked with |) and a
# method version (e.g. union).
def __or__(self, other):
"""Return the union of two sets as a new set.
(I.e. all elements that are in either set.)
"""
if not isinstance(other, BaseSet):
return NotImplemented
result = self.__class__(self._data)
result._data.update(other._data)
return result
def union(self, other):
"""Return the union of two sets as a new set.
(I.e. all elements that are in either set.)
"""
self._binary_sanity_check(other)
result = self.__class__(self._data)
result._data.update(other._data)
return result
return self | other
__or__ = union
def intersection(self, other):
def __and__(self, other):
"""Return the intersection of two sets as a new set.
(I.e. all elements that are in both sets.)
"""
self._binary_sanity_check(other)
if not isinstance(other, BaseSet):
return NotImplemented
if len(self) <= len(other):
little, big = self, other
else:
@ -187,14 +196,20 @@ class BaseSet(object):
data[elt] = value
return result
__and__ = intersection
def intersection(self, other):
"""Return the intersection of two sets as a new set.
def symmetric_difference(self, other):
(I.e. all elements that are in both sets.)
"""
return self & other
def __xor__(self, other):
"""Return the symmetric difference of two sets as a new set.
(I.e. all elements that are in exactly one of the sets.)
"""
self._binary_sanity_check(other)
if not isinstance(other, BaseSet):
return NotImplemented
result = self.__class__([])
data = result._data
value = True
@ -206,14 +221,20 @@ class BaseSet(object):
data[elt] = value
return result
__xor__ = symmetric_difference
def symmetric_difference(self, other):
"""Return the symmetric difference of two sets as a new set.
def difference(self, other):
(I.e. all elements that are in exactly one of the sets.)
"""
return self ^ other
def __sub__(self, other):
"""Return the difference of two sets as a new Set.
(I.e. all elements that are in this set and not in the other.)
"""
self._binary_sanity_check(other)
if not isinstance(other, BaseSet):
return NotImplemented
result = self.__class__([])
data = result._data
value = True
@ -222,7 +243,12 @@ class BaseSet(object):
data[elt] = value
return result
__sub__ = difference
def difference(self, other):
"""Return the difference of two sets as a new Set.
(I.e. all elements that are in this set and not in the other.)
"""
return self - other
# Membership test