From f05d981f5867dcb19c0724d88378bdd35d73f02d Mon Sep 17 00:00:00 2001 From: Nick Coghlan Date: Wed, 2 Oct 2013 00:02:03 +1000 Subject: [PATCH] Close #10042: functools.total_ordering now handles NotImplemented (Patch by Katie Miller) --- Doc/library/functools.rst | 19 +++++++ Lib/functools.py | 94 +++++++++++++++++++++++++++----- Lib/test/test_functools.py | 108 ++++++++++++++++++++++++++++++++++--- Misc/ACKS | 1 + Misc/NEWS | 4 ++ 5 files changed, 207 insertions(+), 19 deletions(-) diff --git a/Doc/library/functools.rst b/Doc/library/functools.rst index 2f6d9afe0f1..5eb86ec93c5 100644 --- a/Doc/library/functools.rst +++ b/Doc/library/functools.rst @@ -134,15 +134,34 @@ The :mod:`functools` module defines the following functions: @total_ordering class Student: + def _is_valid_operand(self, other): + return (hasattr(other, "lastname") and + hasattr(other, "firstname")) def __eq__(self, other): + if not self._is_valid_operand(other): + return NotImplemented return ((self.lastname.lower(), self.firstname.lower()) == (other.lastname.lower(), other.firstname.lower())) def __lt__(self, other): + if not self._is_valid_operand(other): + return NotImplemented return ((self.lastname.lower(), self.firstname.lower()) < (other.lastname.lower(), other.firstname.lower())) + .. note:: + + While this decorator makes it easy to create well behaved totally + ordered types, it *does* come at the cost of slower execution and + more complex stack traces for the derived comparison methods. If + performance benchmarking indicates this is a bottleneck for a given + application, implementing all six rich comparison methods instead is + likely to provide an easy speed boost. + .. versionadded:: 3.2 + .. versionchanged:: 3.4 + Returning NotImplemented from the underlying comparison function for + unrecognised types is now supported. .. function:: partial(func, *args, **keywords) diff --git a/Lib/functools.py b/Lib/functools.py index 19f88c7f021..6a6974fc5ed 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -89,21 +89,91 @@ def wraps(wrapped, ### total_ordering class decorator ################################################################################ +# The correct way to indicate that a comparison operation doesn't +# recognise the other type is to return NotImplemented and let the +# interpreter handle raising TypeError if both operands return +# NotImplemented from their respective comparison methods +# +# This makes the implementation of total_ordering more complicated, since +# we need to be careful not to trigger infinite recursion when two +# different types that both use this decorator encounter each other. +# +# For example, if a type implements __lt__, it's natural to define +# __gt__ as something like: +# +# lambda self, other: not self < other and not self == other +# +# However, using the operator syntax like that ends up invoking the full +# type checking machinery again and means we can end up bouncing back and +# forth between the two operands until we run out of stack space. +# +# The solution is to define helper functions that invoke the appropriate +# magic methods directly, ensuring we only try each operand once, and +# return NotImplemented immediately if it is returned from the +# underlying user provided method. Using this scheme, the __gt__ derived +# from a user provided __lt__ becomes: +# +# lambda self, other: _not_op_and_not_eq(self.__lt__, self, other)) + +def _not_op(op, other): + # "not a < b" handles "a >= b" + # "not a <= b" handles "a > b" + # "not a >= b" handles "a < b" + # "not a > b" handles "a <= b" + op_result = op(other) + if op_result is NotImplemented: + return NotImplemented + return not op_result + +def _op_or_eq(op, self, other): + # "a < b or a == b" handles "a <= b" + # "a > b or a == b" handles "a >= b" + op_result = op(other) + if op_result is NotImplemented: + return NotImplemented + return op_result or self == other + +def _not_op_and_not_eq(op, self, other): + # "not (a < b or a == b)" handles "a > b" + # "not a < b and a != b" is equivalent + # "not (a > b or a == b)" handles "a < b" + # "not a > b and a != b" is equivalent + op_result = op(other) + if op_result is NotImplemented: + return NotImplemented + return not op_result and self != other + +def _not_op_or_eq(op, self, other): + # "not a <= b or a == b" handles "a >= b" + # "not a >= b or a == b" handles "a <= b" + op_result = op(other) + if op_result is NotImplemented: + return NotImplemented + return not op_result or self == other + +def _op_and_not_eq(op, self, other): + # "a <= b and not a == b" handles "a < b" + # "a >= b and not a == b" handles "a > b" + op_result = op(other) + if op_result is NotImplemented: + return NotImplemented + return op_result and self != other + def total_ordering(cls): """Class decorator that fills in missing ordering methods""" convert = { - '__lt__': [('__gt__', lambda self, other: not (self < other or self == other)), - ('__le__', lambda self, other: self < other or self == other), - ('__ge__', lambda self, other: not self < other)], - '__le__': [('__ge__', lambda self, other: not self <= other or self == other), - ('__lt__', lambda self, other: self <= other and not self == other), - ('__gt__', lambda self, other: not self <= other)], - '__gt__': [('__lt__', lambda self, other: not (self > other or self == other)), - ('__ge__', lambda self, other: self > other or self == other), - ('__le__', lambda self, other: not self > other)], - '__ge__': [('__le__', lambda self, other: (not self >= other) or self == other), - ('__gt__', lambda self, other: self >= other and not self == other), - ('__lt__', lambda self, other: not self >= other)] + '__lt__': [('__gt__', lambda self, other: _not_op_and_not_eq(self.__lt__, self, other)), + ('__le__', lambda self, other: _op_or_eq(self.__lt__, self, other)), + ('__ge__', lambda self, other: _not_op(self.__lt__, other))], + '__le__': [('__ge__', lambda self, other: _not_op_or_eq(self.__le__, self, other)), + ('__lt__', lambda self, other: _op_and_not_eq(self.__le__, self, other)), + ('__gt__', lambda self, other: _not_op(self.__le__, other))], + '__gt__': [('__lt__', lambda self, other: _not_op_and_not_eq(self.__gt__, self, other)), + ('__ge__', lambda self, other: _op_or_eq(self.__gt__, self, other)), + ('__le__', lambda self, other: _not_op(self.__gt__, other))], + '__ge__': [('__le__', lambda self, other: _not_op_or_eq(self.__ge__, self, other)), + ('__gt__', lambda self, other: _op_and_not_eq(self.__ge__, self, other)), + ('__lt__', lambda self, other: _not_op(self.__ge__, other))] } # Find user-defined comparisons (not those inherited from object). roots = [op for op in convert if getattr(cls, op, None) is not getattr(object, op, None)] diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index ab76efbfefa..cb493bff52d 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -584,6 +584,7 @@ class TestTotalOrdering(unittest.TestCase): self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) >= A(2)) + self.assertFalse(A(1) > A(2)) def test_total_ordering_le(self): @functools.total_ordering @@ -600,6 +601,7 @@ class TestTotalOrdering(unittest.TestCase): self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) >= A(2)) + self.assertFalse(A(1) >= A(2)) def test_total_ordering_gt(self): @functools.total_ordering @@ -616,6 +618,7 @@ class TestTotalOrdering(unittest.TestCase): self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) >= A(2)) + self.assertFalse(A(2) < A(1)) def test_total_ordering_ge(self): @functools.total_ordering @@ -632,6 +635,7 @@ class TestTotalOrdering(unittest.TestCase): self.assertTrue(A(2) >= A(1)) self.assertTrue(A(2) <= A(2)) self.assertTrue(A(2) >= A(2)) + self.assertFalse(A(2) <= A(1)) def test_total_ordering_no_overwrite(self): # new methods should not overwrite existing @@ -651,22 +655,112 @@ class TestTotalOrdering(unittest.TestCase): class A: pass - def test_bug_10042(self): + def test_type_error_when_not_implemented(self): + # bug 10042; ensure stack overflow does not occur + # when decorated types return NotImplemented @functools.total_ordering - class TestTO: + class ImplementsLessThan: def __init__(self, value): self.value = value def __eq__(self, other): - if isinstance(other, TestTO): + if isinstance(other, ImplementsLessThan): return self.value == other.value return False def __lt__(self, other): - if isinstance(other, TestTO): + if isinstance(other, ImplementsLessThan): return self.value < other.value - raise TypeError - with self.assertRaises(TypeError): - TestTO(8) <= () + return NotImplemented + @functools.total_ordering + class ImplementsGreaterThan: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsGreaterThan): + return self.value == other.value + return False + def __gt__(self, other): + if isinstance(other, ImplementsGreaterThan): + return self.value > other.value + return NotImplemented + + @functools.total_ordering + class ImplementsLessThanEqualTo: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsLessThanEqualTo): + return self.value == other.value + return False + def __le__(self, other): + if isinstance(other, ImplementsLessThanEqualTo): + return self.value <= other.value + return NotImplemented + + @functools.total_ordering + class ImplementsGreaterThanEqualTo: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ImplementsGreaterThanEqualTo): + return self.value == other.value + return False + def __ge__(self, other): + if isinstance(other, ImplementsGreaterThanEqualTo): + return self.value >= other.value + return NotImplemented + + @functools.total_ordering + class ComparatorNotImplemented: + def __init__(self, value): + self.value = value + def __eq__(self, other): + if isinstance(other, ComparatorNotImplemented): + return self.value == other.value + return False + def __lt__(self, other): + return NotImplemented + + with self.subTest("LT < 1"), self.assertRaises(TypeError): + ImplementsLessThan(-1) < 1 + + with self.subTest("LT < LE"), self.assertRaises(TypeError): + ImplementsLessThan(0) < ImplementsLessThanEqualTo(0) + + with self.subTest("LT < GT"), self.assertRaises(TypeError): + ImplementsLessThan(1) < ImplementsGreaterThan(1) + + with self.subTest("LE <= LT"), self.assertRaises(TypeError): + ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2) + + with self.subTest("LE <= GE"), self.assertRaises(TypeError): + ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3) + + with self.subTest("GT > GE"), self.assertRaises(TypeError): + ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4) + + with self.subTest("GT > LT"), self.assertRaises(TypeError): + ImplementsGreaterThan(5) > ImplementsLessThan(5) + + with self.subTest("GE >= GT"), self.assertRaises(TypeError): + ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6) + + with self.subTest("GE >= LE"), self.assertRaises(TypeError): + ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7) + + with self.subTest("GE when equal"): + a = ComparatorNotImplemented(8) + b = ComparatorNotImplemented(8) + self.assertEqual(a, b) + with self.assertRaises(TypeError): + a >= b + + with self.subTest("LE when equal"): + a = ComparatorNotImplemented(9) + b = ComparatorNotImplemented(9) + self.assertEqual(a, b) + with self.assertRaises(TypeError): + a <= b class TestLRU(unittest.TestCase): diff --git a/Misc/ACKS b/Misc/ACKS index a5d577dc9f0..bad2c517807 100644 --- a/Misc/ACKS +++ b/Misc/ACKS @@ -862,6 +862,7 @@ Chad Miller Damien Miller Jason V. Miller Jay T. Miller +Katie Miller Roman Milner Julien Miotte Andrii V. Mishkovskyi diff --git a/Misc/NEWS b/Misc/NEWS index 444a0426d58..b2030a8ad20 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -13,6 +13,10 @@ Core and Builtins Library ------- +- Issue #10042: functools.total_ordering now correctly handles + NotImplemented being returned by the underlying comparison function (Patch + by Katie Miller) + - Issue #19092: contextlib.ExitStack now correctly reraises exceptions from the __exit__ callbacks of inner context managers (Patch by Hrvoje Nikšić)