Close #10042: functools.total_ordering now handles NotImplemented

(Patch by Katie Miller)
This commit is contained in:
Nick Coghlan 2013-10-02 00:02:03 +10:00
parent e6f4631f08
commit f05d981f58
5 changed files with 207 additions and 19 deletions

View File

@ -134,15 +134,34 @@ The :mod:`functools` module defines the following functions:
@total_ordering @total_ordering
class Student: class Student:
def _is_valid_operand(self, other):
return (hasattr(other, "lastname") and
hasattr(other, "firstname"))
def __eq__(self, other): def __eq__(self, other):
if not self._is_valid_operand(other):
return NotImplemented
return ((self.lastname.lower(), self.firstname.lower()) == return ((self.lastname.lower(), self.firstname.lower()) ==
(other.lastname.lower(), other.firstname.lower())) (other.lastname.lower(), other.firstname.lower()))
def __lt__(self, other): def __lt__(self, other):
if not self._is_valid_operand(other):
return NotImplemented
return ((self.lastname.lower(), self.firstname.lower()) < return ((self.lastname.lower(), self.firstname.lower()) <
(other.lastname.lower(), other.firstname.lower())) (other.lastname.lower(), other.firstname.lower()))
.. note::
While this decorator makes it easy to create well behaved totally
ordered types, it *does* come at the cost of slower execution and
more complex stack traces for the derived comparison methods. If
performance benchmarking indicates this is a bottleneck for a given
application, implementing all six rich comparison methods instead is
likely to provide an easy speed boost.
.. versionadded:: 3.2 .. versionadded:: 3.2
.. versionchanged:: 3.4
Returning NotImplemented from the underlying comparison function for
unrecognised types is now supported.
.. function:: partial(func, *args, **keywords) .. function:: partial(func, *args, **keywords)

View File

@ -89,21 +89,91 @@ def wraps(wrapped,
### total_ordering class decorator ### total_ordering class decorator
################################################################################ ################################################################################
# The correct way to indicate that a comparison operation doesn't
# recognise the other type is to return NotImplemented and let the
# interpreter handle raising TypeError if both operands return
# NotImplemented from their respective comparison methods
#
# This makes the implementation of total_ordering more complicated, since
# we need to be careful not to trigger infinite recursion when two
# different types that both use this decorator encounter each other.
#
# For example, if a type implements __lt__, it's natural to define
# __gt__ as something like:
#
# lambda self, other: not self < other and not self == other
#
# However, using the operator syntax like that ends up invoking the full
# type checking machinery again and means we can end up bouncing back and
# forth between the two operands until we run out of stack space.
#
# The solution is to define helper functions that invoke the appropriate
# magic methods directly, ensuring we only try each operand once, and
# return NotImplemented immediately if it is returned from the
# underlying user provided method. Using this scheme, the __gt__ derived
# from a user provided __lt__ becomes:
#
# lambda self, other: _not_op_and_not_eq(self.__lt__, self, other))
def _not_op(op, other):
# "not a < b" handles "a >= b"
# "not a <= b" handles "a > b"
# "not a >= b" handles "a < b"
# "not a > b" handles "a <= b"
op_result = op(other)
if op_result is NotImplemented:
return NotImplemented
return not op_result
def _op_or_eq(op, self, other):
# "a < b or a == b" handles "a <= b"
# "a > b or a == b" handles "a >= b"
op_result = op(other)
if op_result is NotImplemented:
return NotImplemented
return op_result or self == other
def _not_op_and_not_eq(op, self, other):
# "not (a < b or a == b)" handles "a > b"
# "not a < b and a != b" is equivalent
# "not (a > b or a == b)" handles "a < b"
# "not a > b and a != b" is equivalent
op_result = op(other)
if op_result is NotImplemented:
return NotImplemented
return not op_result and self != other
def _not_op_or_eq(op, self, other):
# "not a <= b or a == b" handles "a >= b"
# "not a >= b or a == b" handles "a <= b"
op_result = op(other)
if op_result is NotImplemented:
return NotImplemented
return not op_result or self == other
def _op_and_not_eq(op, self, other):
# "a <= b and not a == b" handles "a < b"
# "a >= b and not a == b" handles "a > b"
op_result = op(other)
if op_result is NotImplemented:
return NotImplemented
return op_result and self != other
def total_ordering(cls): def total_ordering(cls):
"""Class decorator that fills in missing ordering methods""" """Class decorator that fills in missing ordering methods"""
convert = { convert = {
'__lt__': [('__gt__', lambda self, other: not (self < other or self == other)), '__lt__': [('__gt__', lambda self, other: _not_op_and_not_eq(self.__lt__, self, other)),
('__le__', lambda self, other: self < other or self == other), ('__le__', lambda self, other: _op_or_eq(self.__lt__, self, other)),
('__ge__', lambda self, other: not self < other)], ('__ge__', lambda self, other: _not_op(self.__lt__, other))],
'__le__': [('__ge__', lambda self, other: not self <= other or self == other), '__le__': [('__ge__', lambda self, other: _not_op_or_eq(self.__le__, self, other)),
('__lt__', lambda self, other: self <= other and not self == other), ('__lt__', lambda self, other: _op_and_not_eq(self.__le__, self, other)),
('__gt__', lambda self, other: not self <= other)], ('__gt__', lambda self, other: _not_op(self.__le__, other))],
'__gt__': [('__lt__', lambda self, other: not (self > other or self == other)), '__gt__': [('__lt__', lambda self, other: _not_op_and_not_eq(self.__gt__, self, other)),
('__ge__', lambda self, other: self > other or self == other), ('__ge__', lambda self, other: _op_or_eq(self.__gt__, self, other)),
('__le__', lambda self, other: not self > other)], ('__le__', lambda self, other: _not_op(self.__gt__, other))],
'__ge__': [('__le__', lambda self, other: (not self >= other) or self == other), '__ge__': [('__le__', lambda self, other: _not_op_or_eq(self.__ge__, self, other)),
('__gt__', lambda self, other: self >= other and not self == other), ('__gt__', lambda self, other: _op_and_not_eq(self.__ge__, self, other)),
('__lt__', lambda self, other: not self >= other)] ('__lt__', lambda self, other: _not_op(self.__ge__, other))]
} }
# Find user-defined comparisons (not those inherited from object). # Find user-defined comparisons (not those inherited from object).
roots = [op for op in convert if getattr(cls, op, None) is not getattr(object, op, None)] roots = [op for op in convert if getattr(cls, op, None) is not getattr(object, op, None)]

View File

@ -584,6 +584,7 @@ class TestTotalOrdering(unittest.TestCase):
self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) >= A(1))
self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) <= A(2))
self.assertTrue(A(2) >= A(2)) self.assertTrue(A(2) >= A(2))
self.assertFalse(A(1) > A(2))
def test_total_ordering_le(self): def test_total_ordering_le(self):
@functools.total_ordering @functools.total_ordering
@ -600,6 +601,7 @@ class TestTotalOrdering(unittest.TestCase):
self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) >= A(1))
self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) <= A(2))
self.assertTrue(A(2) >= A(2)) self.assertTrue(A(2) >= A(2))
self.assertFalse(A(1) >= A(2))
def test_total_ordering_gt(self): def test_total_ordering_gt(self):
@functools.total_ordering @functools.total_ordering
@ -616,6 +618,7 @@ class TestTotalOrdering(unittest.TestCase):
self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) >= A(1))
self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) <= A(2))
self.assertTrue(A(2) >= A(2)) self.assertTrue(A(2) >= A(2))
self.assertFalse(A(2) < A(1))
def test_total_ordering_ge(self): def test_total_ordering_ge(self):
@functools.total_ordering @functools.total_ordering
@ -632,6 +635,7 @@ class TestTotalOrdering(unittest.TestCase):
self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) >= A(1))
self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) <= A(2))
self.assertTrue(A(2) >= A(2)) self.assertTrue(A(2) >= A(2))
self.assertFalse(A(2) <= A(1))
def test_total_ordering_no_overwrite(self): def test_total_ordering_no_overwrite(self):
# new methods should not overwrite existing # new methods should not overwrite existing
@ -651,22 +655,112 @@ class TestTotalOrdering(unittest.TestCase):
class A: class A:
pass pass
def test_bug_10042(self): def test_type_error_when_not_implemented(self):
# bug 10042; ensure stack overflow does not occur
# when decorated types return NotImplemented
@functools.total_ordering @functools.total_ordering
class TestTO: class ImplementsLessThan:
def __init__(self, value): def __init__(self, value):
self.value = value self.value = value
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, TestTO): if isinstance(other, ImplementsLessThan):
return self.value == other.value return self.value == other.value
return False return False
def __lt__(self, other): def __lt__(self, other):
if isinstance(other, TestTO): if isinstance(other, ImplementsLessThan):
return self.value < other.value return self.value < other.value
raise TypeError return NotImplemented
with self.assertRaises(TypeError):
TestTO(8) <= ()
@functools.total_ordering
class ImplementsGreaterThan:
def __init__(self, value):
self.value = value
def __eq__(self, other):
if isinstance(other, ImplementsGreaterThan):
return self.value == other.value
return False
def __gt__(self, other):
if isinstance(other, ImplementsGreaterThan):
return self.value > other.value
return NotImplemented
@functools.total_ordering
class ImplementsLessThanEqualTo:
def __init__(self, value):
self.value = value
def __eq__(self, other):
if isinstance(other, ImplementsLessThanEqualTo):
return self.value == other.value
return False
def __le__(self, other):
if isinstance(other, ImplementsLessThanEqualTo):
return self.value <= other.value
return NotImplemented
@functools.total_ordering
class ImplementsGreaterThanEqualTo:
def __init__(self, value):
self.value = value
def __eq__(self, other):
if isinstance(other, ImplementsGreaterThanEqualTo):
return self.value == other.value
return False
def __ge__(self, other):
if isinstance(other, ImplementsGreaterThanEqualTo):
return self.value >= other.value
return NotImplemented
@functools.total_ordering
class ComparatorNotImplemented:
def __init__(self, value):
self.value = value
def __eq__(self, other):
if isinstance(other, ComparatorNotImplemented):
return self.value == other.value
return False
def __lt__(self, other):
return NotImplemented
with self.subTest("LT < 1"), self.assertRaises(TypeError):
ImplementsLessThan(-1) < 1
with self.subTest("LT < LE"), self.assertRaises(TypeError):
ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
with self.subTest("LT < GT"), self.assertRaises(TypeError):
ImplementsLessThan(1) < ImplementsGreaterThan(1)
with self.subTest("LE <= LT"), self.assertRaises(TypeError):
ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
with self.subTest("LE <= GE"), self.assertRaises(TypeError):
ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
with self.subTest("GT > GE"), self.assertRaises(TypeError):
ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
with self.subTest("GT > LT"), self.assertRaises(TypeError):
ImplementsGreaterThan(5) > ImplementsLessThan(5)
with self.subTest("GE >= GT"), self.assertRaises(TypeError):
ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
with self.subTest("GE >= LE"), self.assertRaises(TypeError):
ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
with self.subTest("GE when equal"):
a = ComparatorNotImplemented(8)
b = ComparatorNotImplemented(8)
self.assertEqual(a, b)
with self.assertRaises(TypeError):
a >= b
with self.subTest("LE when equal"):
a = ComparatorNotImplemented(9)
b = ComparatorNotImplemented(9)
self.assertEqual(a, b)
with self.assertRaises(TypeError):
a <= b
class TestLRU(unittest.TestCase): class TestLRU(unittest.TestCase):

View File

@ -862,6 +862,7 @@ Chad Miller
Damien Miller Damien Miller
Jason V. Miller Jason V. Miller
Jay T. Miller Jay T. Miller
Katie Miller
Roman Milner Roman Milner
Julien Miotte Julien Miotte
Andrii V. Mishkovskyi Andrii V. Mishkovskyi

View File

@ -13,6 +13,10 @@ Core and Builtins
Library Library
------- -------
- Issue #10042: functools.total_ordering now correctly handles
NotImplemented being returned by the underlying comparison function (Patch
by Katie Miller)
- Issue #19092: contextlib.ExitStack now correctly reraises exceptions - Issue #19092: contextlib.ExitStack now correctly reraises exceptions
from the __exit__ callbacks of inner context managers (Patch by Hrvoje from the __exit__ callbacks of inner context managers (Patch by Hrvoje
Nikšić) Nikšić)