Close #10042: functools.total_ordering now handles NotImplemented
(Patch by Katie Miller)
This commit is contained in:
parent
e6f4631f08
commit
f05d981f58
|
@ -134,15 +134,34 @@ The :mod:`functools` module defines the following functions:
|
||||||
|
|
||||||
@total_ordering
|
@total_ordering
|
||||||
class Student:
|
class Student:
|
||||||
|
def _is_valid_operand(self, other):
|
||||||
|
return (hasattr(other, "lastname") and
|
||||||
|
hasattr(other, "firstname"))
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
if not self._is_valid_operand(other):
|
||||||
|
return NotImplemented
|
||||||
return ((self.lastname.lower(), self.firstname.lower()) ==
|
return ((self.lastname.lower(), self.firstname.lower()) ==
|
||||||
(other.lastname.lower(), other.firstname.lower()))
|
(other.lastname.lower(), other.firstname.lower()))
|
||||||
def __lt__(self, other):
|
def __lt__(self, other):
|
||||||
|
if not self._is_valid_operand(other):
|
||||||
|
return NotImplemented
|
||||||
return ((self.lastname.lower(), self.firstname.lower()) <
|
return ((self.lastname.lower(), self.firstname.lower()) <
|
||||||
(other.lastname.lower(), other.firstname.lower()))
|
(other.lastname.lower(), other.firstname.lower()))
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
While this decorator makes it easy to create well behaved totally
|
||||||
|
ordered types, it *does* come at the cost of slower execution and
|
||||||
|
more complex stack traces for the derived comparison methods. If
|
||||||
|
performance benchmarking indicates this is a bottleneck for a given
|
||||||
|
application, implementing all six rich comparison methods instead is
|
||||||
|
likely to provide an easy speed boost.
|
||||||
|
|
||||||
.. versionadded:: 3.2
|
.. versionadded:: 3.2
|
||||||
|
|
||||||
|
.. versionchanged:: 3.4
|
||||||
|
Returning NotImplemented from the underlying comparison function for
|
||||||
|
unrecognised types is now supported.
|
||||||
|
|
||||||
.. function:: partial(func, *args, **keywords)
|
.. function:: partial(func, *args, **keywords)
|
||||||
|
|
||||||
|
|
|
@ -89,21 +89,91 @@ def wraps(wrapped,
|
||||||
### total_ordering class decorator
|
### total_ordering class decorator
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
||||||
|
# The correct way to indicate that a comparison operation doesn't
|
||||||
|
# recognise the other type is to return NotImplemented and let the
|
||||||
|
# interpreter handle raising TypeError if both operands return
|
||||||
|
# NotImplemented from their respective comparison methods
|
||||||
|
#
|
||||||
|
# This makes the implementation of total_ordering more complicated, since
|
||||||
|
# we need to be careful not to trigger infinite recursion when two
|
||||||
|
# different types that both use this decorator encounter each other.
|
||||||
|
#
|
||||||
|
# For example, if a type implements __lt__, it's natural to define
|
||||||
|
# __gt__ as something like:
|
||||||
|
#
|
||||||
|
# lambda self, other: not self < other and not self == other
|
||||||
|
#
|
||||||
|
# However, using the operator syntax like that ends up invoking the full
|
||||||
|
# type checking machinery again and means we can end up bouncing back and
|
||||||
|
# forth between the two operands until we run out of stack space.
|
||||||
|
#
|
||||||
|
# The solution is to define helper functions that invoke the appropriate
|
||||||
|
# magic methods directly, ensuring we only try each operand once, and
|
||||||
|
# return NotImplemented immediately if it is returned from the
|
||||||
|
# underlying user provided method. Using this scheme, the __gt__ derived
|
||||||
|
# from a user provided __lt__ becomes:
|
||||||
|
#
|
||||||
|
# lambda self, other: _not_op_and_not_eq(self.__lt__, self, other))
|
||||||
|
|
||||||
|
def _not_op(op, other):
|
||||||
|
# "not a < b" handles "a >= b"
|
||||||
|
# "not a <= b" handles "a > b"
|
||||||
|
# "not a >= b" handles "a < b"
|
||||||
|
# "not a > b" handles "a <= b"
|
||||||
|
op_result = op(other)
|
||||||
|
if op_result is NotImplemented:
|
||||||
|
return NotImplemented
|
||||||
|
return not op_result
|
||||||
|
|
||||||
|
def _op_or_eq(op, self, other):
|
||||||
|
# "a < b or a == b" handles "a <= b"
|
||||||
|
# "a > b or a == b" handles "a >= b"
|
||||||
|
op_result = op(other)
|
||||||
|
if op_result is NotImplemented:
|
||||||
|
return NotImplemented
|
||||||
|
return op_result or self == other
|
||||||
|
|
||||||
|
def _not_op_and_not_eq(op, self, other):
|
||||||
|
# "not (a < b or a == b)" handles "a > b"
|
||||||
|
# "not a < b and a != b" is equivalent
|
||||||
|
# "not (a > b or a == b)" handles "a < b"
|
||||||
|
# "not a > b and a != b" is equivalent
|
||||||
|
op_result = op(other)
|
||||||
|
if op_result is NotImplemented:
|
||||||
|
return NotImplemented
|
||||||
|
return not op_result and self != other
|
||||||
|
|
||||||
|
def _not_op_or_eq(op, self, other):
|
||||||
|
# "not a <= b or a == b" handles "a >= b"
|
||||||
|
# "not a >= b or a == b" handles "a <= b"
|
||||||
|
op_result = op(other)
|
||||||
|
if op_result is NotImplemented:
|
||||||
|
return NotImplemented
|
||||||
|
return not op_result or self == other
|
||||||
|
|
||||||
|
def _op_and_not_eq(op, self, other):
|
||||||
|
# "a <= b and not a == b" handles "a < b"
|
||||||
|
# "a >= b and not a == b" handles "a > b"
|
||||||
|
op_result = op(other)
|
||||||
|
if op_result is NotImplemented:
|
||||||
|
return NotImplemented
|
||||||
|
return op_result and self != other
|
||||||
|
|
||||||
def total_ordering(cls):
|
def total_ordering(cls):
|
||||||
"""Class decorator that fills in missing ordering methods"""
|
"""Class decorator that fills in missing ordering methods"""
|
||||||
convert = {
|
convert = {
|
||||||
'__lt__': [('__gt__', lambda self, other: not (self < other or self == other)),
|
'__lt__': [('__gt__', lambda self, other: _not_op_and_not_eq(self.__lt__, self, other)),
|
||||||
('__le__', lambda self, other: self < other or self == other),
|
('__le__', lambda self, other: _op_or_eq(self.__lt__, self, other)),
|
||||||
('__ge__', lambda self, other: not self < other)],
|
('__ge__', lambda self, other: _not_op(self.__lt__, other))],
|
||||||
'__le__': [('__ge__', lambda self, other: not self <= other or self == other),
|
'__le__': [('__ge__', lambda self, other: _not_op_or_eq(self.__le__, self, other)),
|
||||||
('__lt__', lambda self, other: self <= other and not self == other),
|
('__lt__', lambda self, other: _op_and_not_eq(self.__le__, self, other)),
|
||||||
('__gt__', lambda self, other: not self <= other)],
|
('__gt__', lambda self, other: _not_op(self.__le__, other))],
|
||||||
'__gt__': [('__lt__', lambda self, other: not (self > other or self == other)),
|
'__gt__': [('__lt__', lambda self, other: _not_op_and_not_eq(self.__gt__, self, other)),
|
||||||
('__ge__', lambda self, other: self > other or self == other),
|
('__ge__', lambda self, other: _op_or_eq(self.__gt__, self, other)),
|
||||||
('__le__', lambda self, other: not self > other)],
|
('__le__', lambda self, other: _not_op(self.__gt__, other))],
|
||||||
'__ge__': [('__le__', lambda self, other: (not self >= other) or self == other),
|
'__ge__': [('__le__', lambda self, other: _not_op_or_eq(self.__ge__, self, other)),
|
||||||
('__gt__', lambda self, other: self >= other and not self == other),
|
('__gt__', lambda self, other: _op_and_not_eq(self.__ge__, self, other)),
|
||||||
('__lt__', lambda self, other: not self >= other)]
|
('__lt__', lambda self, other: _not_op(self.__ge__, other))]
|
||||||
}
|
}
|
||||||
# Find user-defined comparisons (not those inherited from object).
|
# Find user-defined comparisons (not those inherited from object).
|
||||||
roots = [op for op in convert if getattr(cls, op, None) is not getattr(object, op, None)]
|
roots = [op for op in convert if getattr(cls, op, None) is not getattr(object, op, None)]
|
||||||
|
|
|
@ -584,6 +584,7 @@ class TestTotalOrdering(unittest.TestCase):
|
||||||
self.assertTrue(A(2) >= A(1))
|
self.assertTrue(A(2) >= A(1))
|
||||||
self.assertTrue(A(2) <= A(2))
|
self.assertTrue(A(2) <= A(2))
|
||||||
self.assertTrue(A(2) >= A(2))
|
self.assertTrue(A(2) >= A(2))
|
||||||
|
self.assertFalse(A(1) > A(2))
|
||||||
|
|
||||||
def test_total_ordering_le(self):
|
def test_total_ordering_le(self):
|
||||||
@functools.total_ordering
|
@functools.total_ordering
|
||||||
|
@ -600,6 +601,7 @@ class TestTotalOrdering(unittest.TestCase):
|
||||||
self.assertTrue(A(2) >= A(1))
|
self.assertTrue(A(2) >= A(1))
|
||||||
self.assertTrue(A(2) <= A(2))
|
self.assertTrue(A(2) <= A(2))
|
||||||
self.assertTrue(A(2) >= A(2))
|
self.assertTrue(A(2) >= A(2))
|
||||||
|
self.assertFalse(A(1) >= A(2))
|
||||||
|
|
||||||
def test_total_ordering_gt(self):
|
def test_total_ordering_gt(self):
|
||||||
@functools.total_ordering
|
@functools.total_ordering
|
||||||
|
@ -616,6 +618,7 @@ class TestTotalOrdering(unittest.TestCase):
|
||||||
self.assertTrue(A(2) >= A(1))
|
self.assertTrue(A(2) >= A(1))
|
||||||
self.assertTrue(A(2) <= A(2))
|
self.assertTrue(A(2) <= A(2))
|
||||||
self.assertTrue(A(2) >= A(2))
|
self.assertTrue(A(2) >= A(2))
|
||||||
|
self.assertFalse(A(2) < A(1))
|
||||||
|
|
||||||
def test_total_ordering_ge(self):
|
def test_total_ordering_ge(self):
|
||||||
@functools.total_ordering
|
@functools.total_ordering
|
||||||
|
@ -632,6 +635,7 @@ class TestTotalOrdering(unittest.TestCase):
|
||||||
self.assertTrue(A(2) >= A(1))
|
self.assertTrue(A(2) >= A(1))
|
||||||
self.assertTrue(A(2) <= A(2))
|
self.assertTrue(A(2) <= A(2))
|
||||||
self.assertTrue(A(2) >= A(2))
|
self.assertTrue(A(2) >= A(2))
|
||||||
|
self.assertFalse(A(2) <= A(1))
|
||||||
|
|
||||||
def test_total_ordering_no_overwrite(self):
|
def test_total_ordering_no_overwrite(self):
|
||||||
# new methods should not overwrite existing
|
# new methods should not overwrite existing
|
||||||
|
@ -651,22 +655,112 @@ class TestTotalOrdering(unittest.TestCase):
|
||||||
class A:
|
class A:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_bug_10042(self):
|
def test_type_error_when_not_implemented(self):
|
||||||
|
# bug 10042; ensure stack overflow does not occur
|
||||||
|
# when decorated types return NotImplemented
|
||||||
@functools.total_ordering
|
@functools.total_ordering
|
||||||
class TestTO:
|
class ImplementsLessThan:
|
||||||
def __init__(self, value):
|
def __init__(self, value):
|
||||||
self.value = value
|
self.value = value
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if isinstance(other, TestTO):
|
if isinstance(other, ImplementsLessThan):
|
||||||
return self.value == other.value
|
return self.value == other.value
|
||||||
return False
|
return False
|
||||||
def __lt__(self, other):
|
def __lt__(self, other):
|
||||||
if isinstance(other, TestTO):
|
if isinstance(other, ImplementsLessThan):
|
||||||
return self.value < other.value
|
return self.value < other.value
|
||||||
raise TypeError
|
return NotImplemented
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
TestTO(8) <= ()
|
|
||||||
|
|
||||||
|
@functools.total_ordering
|
||||||
|
class ImplementsGreaterThan:
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
def __eq__(self, other):
|
||||||
|
if isinstance(other, ImplementsGreaterThan):
|
||||||
|
return self.value == other.value
|
||||||
|
return False
|
||||||
|
def __gt__(self, other):
|
||||||
|
if isinstance(other, ImplementsGreaterThan):
|
||||||
|
return self.value > other.value
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
@functools.total_ordering
|
||||||
|
class ImplementsLessThanEqualTo:
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
def __eq__(self, other):
|
||||||
|
if isinstance(other, ImplementsLessThanEqualTo):
|
||||||
|
return self.value == other.value
|
||||||
|
return False
|
||||||
|
def __le__(self, other):
|
||||||
|
if isinstance(other, ImplementsLessThanEqualTo):
|
||||||
|
return self.value <= other.value
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
@functools.total_ordering
|
||||||
|
class ImplementsGreaterThanEqualTo:
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
def __eq__(self, other):
|
||||||
|
if isinstance(other, ImplementsGreaterThanEqualTo):
|
||||||
|
return self.value == other.value
|
||||||
|
return False
|
||||||
|
def __ge__(self, other):
|
||||||
|
if isinstance(other, ImplementsGreaterThanEqualTo):
|
||||||
|
return self.value >= other.value
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
@functools.total_ordering
|
||||||
|
class ComparatorNotImplemented:
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
def __eq__(self, other):
|
||||||
|
if isinstance(other, ComparatorNotImplemented):
|
||||||
|
return self.value == other.value
|
||||||
|
return False
|
||||||
|
def __lt__(self, other):
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
with self.subTest("LT < 1"), self.assertRaises(TypeError):
|
||||||
|
ImplementsLessThan(-1) < 1
|
||||||
|
|
||||||
|
with self.subTest("LT < LE"), self.assertRaises(TypeError):
|
||||||
|
ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
|
||||||
|
|
||||||
|
with self.subTest("LT < GT"), self.assertRaises(TypeError):
|
||||||
|
ImplementsLessThan(1) < ImplementsGreaterThan(1)
|
||||||
|
|
||||||
|
with self.subTest("LE <= LT"), self.assertRaises(TypeError):
|
||||||
|
ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
|
||||||
|
|
||||||
|
with self.subTest("LE <= GE"), self.assertRaises(TypeError):
|
||||||
|
ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
|
||||||
|
|
||||||
|
with self.subTest("GT > GE"), self.assertRaises(TypeError):
|
||||||
|
ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
|
||||||
|
|
||||||
|
with self.subTest("GT > LT"), self.assertRaises(TypeError):
|
||||||
|
ImplementsGreaterThan(5) > ImplementsLessThan(5)
|
||||||
|
|
||||||
|
with self.subTest("GE >= GT"), self.assertRaises(TypeError):
|
||||||
|
ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
|
||||||
|
|
||||||
|
with self.subTest("GE >= LE"), self.assertRaises(TypeError):
|
||||||
|
ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
|
||||||
|
|
||||||
|
with self.subTest("GE when equal"):
|
||||||
|
a = ComparatorNotImplemented(8)
|
||||||
|
b = ComparatorNotImplemented(8)
|
||||||
|
self.assertEqual(a, b)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
a >= b
|
||||||
|
|
||||||
|
with self.subTest("LE when equal"):
|
||||||
|
a = ComparatorNotImplemented(9)
|
||||||
|
b = ComparatorNotImplemented(9)
|
||||||
|
self.assertEqual(a, b)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
a <= b
|
||||||
|
|
||||||
class TestLRU(unittest.TestCase):
|
class TestLRU(unittest.TestCase):
|
||||||
|
|
||||||
|
|
|
@ -862,6 +862,7 @@ Chad Miller
|
||||||
Damien Miller
|
Damien Miller
|
||||||
Jason V. Miller
|
Jason V. Miller
|
||||||
Jay T. Miller
|
Jay T. Miller
|
||||||
|
Katie Miller
|
||||||
Roman Milner
|
Roman Milner
|
||||||
Julien Miotte
|
Julien Miotte
|
||||||
Andrii V. Mishkovskyi
|
Andrii V. Mishkovskyi
|
||||||
|
|
|
@ -13,6 +13,10 @@ Core and Builtins
|
||||||
Library
|
Library
|
||||||
-------
|
-------
|
||||||
|
|
||||||
|
- Issue #10042: functools.total_ordering now correctly handles
|
||||||
|
NotImplemented being returned by the underlying comparison function (Patch
|
||||||
|
by Katie Miller)
|
||||||
|
|
||||||
- Issue #19092: contextlib.ExitStack now correctly reraises exceptions
|
- Issue #19092: contextlib.ExitStack now correctly reraises exceptions
|
||||||
from the __exit__ callbacks of inner context managers (Patch by Hrvoje
|
from the __exit__ callbacks of inner context managers (Patch by Hrvoje
|
||||||
Nikšić)
|
Nikšić)
|
||||||
|
|
Loading…
Reference in New Issue