From 2fc9263df56d28a6c1def6c8a0517bbddfa899af Mon Sep 17 00:00:00 2001 From: Mark Dickinson Date: Wed, 6 Feb 2008 22:10:50 +0000 Subject: [PATCH] Issue 1979: Make Decimal comparisons (other than !=, ==) involving NaN raise InvalidOperation (and return False if InvalidOperation is trapped). --- Doc/library/decimal.rst | 13 ++++ Lib/decimal.py | 149 +++++++++++++++++++++++++++++---------- Lib/test/test_decimal.py | 20 +++++- Misc/NEWS | 3 + 4 files changed, 145 insertions(+), 40 deletions(-) diff --git a/Doc/library/decimal.rst b/Doc/library/decimal.rst index 7b858068552..ebb18bb2be5 100644 --- a/Doc/library/decimal.rst +++ b/Doc/library/decimal.rst @@ -1290,6 +1290,19 @@ A variant is :const:`sNaN` which signals rather than remaining quiet after every operation. This is a useful return value when an invalid result needs to interrupt a calculation for special handling. +The behavior of Python's comparison operators can be a little surprising where a +:const:`NaN` is involved. A test for equality where one of the operands is a +quiet or signaling :const:`NaN` always returns :const:`False` (even when doing +``Decimal('NaN')==Decimal('NaN')``), while a test for inequality always returns +:const:`True`. An attempt to compare two Decimals using any of the :const:'<', +:const:'<=', :const:'>' or :const:'>=' operators will raise the +:exc:`InvalidOperation` signal if either operand is a :const:`NaN`, and return +:const:`False` if this signal is trapped. Note that the General Decimal +Arithmetic specification does not specify the behavior of direct comparisons; +these rules for comparisons involving a :const:`NaN` were taken from the IEEE +754 standard. To ensure strict standards-compliance, use the :meth:`compare` +and :meth:`compare-signal` methods instead. + The signed zeros can result from calculations that underflow. They keep the sign that would have resulted if the calculation had been carried out to greater precision. Since their magnitude is zero, both positive and negative zeros are diff --git a/Lib/decimal.py b/Lib/decimal.py index eea9448f431..80340396658 100644 --- a/Lib/decimal.py +++ b/Lib/decimal.py @@ -717,6 +717,39 @@ class Decimal(object): return other._fix_nan(context) return 0 + def _compare_check_nans(self, other, context): + """Version of _check_nans used for the signaling comparisons + compare_signal, __le__, __lt__, __ge__, __gt__. + + Signal InvalidOperation if either self or other is a (quiet + or signaling) NaN. Signaling NaNs take precedence over quiet + NaNs. + + Return 0 if neither operand is a NaN. + + """ + if context is None: + context = getcontext() + + if self._is_special or other._is_special: + if self.is_snan(): + return context._raise_error(InvalidOperation, + 'comparison involving sNaN', + self) + elif other.is_snan(): + return context._raise_error(InvalidOperation, + 'comparison involving sNaN', + other) + elif self.is_qnan(): + return context._raise_error(InvalidOperation, + 'comparison involving NaN', + self) + elif other.is_qnan(): + return context._raise_error(InvalidOperation, + 'comparison involving NaN', + other) + return 0 + def __nonzero__(self): """Return True if self is nonzero; otherwise return False. @@ -724,18 +757,13 @@ class Decimal(object): """ return self._is_special or self._int != '0' - def __cmp__(self, other): - other = _convert_other(other) - if other is NotImplemented: - # Never return NotImplemented - return 1 + def _cmp(self, other): + """Compare the two non-NaN decimal instances self and other. + + Returns -1 if self < other, 0 if self == other and 1 + if self > other. This routine is for internal use only.""" if self._is_special or other._is_special: - # check for nans, without raising on a signaling nan - if self._isnan() or other._isnan(): - return 1 # Comparison involving NaN's always reports self > other - - # INF = INF return cmp(self._isinfinity(), other._isinfinity()) # check for zeros; note that cmp(0, -0) should return 0 @@ -764,15 +792,71 @@ class Decimal(object): else: # self_adjusted < other_adjusted return -((-1)**self._sign) + # Note: The Decimal standard doesn't cover rich comparisons for + # Decimals. In particular, the specification is silent on the + # subject of what should happen for a comparison involving a NaN. + # We take the following approach: + # + # == comparisons involving a NaN always return False + # != comparisons involving a NaN always return True + # <, >, <= and >= comparisons involving a (quiet or signaling) + # NaN signal InvalidOperation, and return False if the + # InvalidOperation is trapped. + # + # This behavior is designed to conform as closely as possible to + # that specified by IEEE 754. + def __eq__(self, other): - if not isinstance(other, (Decimal, int, long)): - return NotImplemented - return self.__cmp__(other) == 0 + other = _convert_other(other) + if other is NotImplemented: + return other + if self.is_nan() or other.is_nan(): + return False + return self._cmp(other) == 0 def __ne__(self, other): - if not isinstance(other, (Decimal, int, long)): - return NotImplemented - return self.__cmp__(other) != 0 + other = _convert_other(other) + if other is NotImplemented: + return other + if self.is_nan() or other.is_nan(): + return True + return self._cmp(other) != 0 + + def __lt__(self, other, context=None): + other = _convert_other(other) + if other is NotImplemented: + return other + ans = self._compare_check_nans(other, context) + if ans: + return False + return self._cmp(other) < 0 + + def __le__(self, other, context=None): + other = _convert_other(other) + if other is NotImplemented: + return other + ans = self._compare_check_nans(other, context) + if ans: + return False + return self._cmp(other) <= 0 + + def __gt__(self, other, context=None): + other = _convert_other(other) + if other is NotImplemented: + return other + ans = self._compare_check_nans(other, context) + if ans: + return False + return self._cmp(other) > 0 + + def __ge__(self, other, context=None): + other = _convert_other(other) + if other is NotImplemented: + return other + ans = self._compare_check_nans(other, context) + if ans: + return False + return self._cmp(other) >= 0 def compare(self, other, context=None): """Compares one to another. @@ -791,7 +875,7 @@ class Decimal(object): if ans: return ans - return Decimal(self.__cmp__(other)) + return Decimal(self._cmp(other)) def __hash__(self): """x.__hash__() <==> hash(x)""" @@ -2452,7 +2536,7 @@ class Decimal(object): return other._fix_nan(context) return self._check_nans(other, context) - c = self.__cmp__(other) + c = self._cmp(other) if c == 0: # If both operands are finite and equal in numerical value # then an ordering is applied: @@ -2494,7 +2578,7 @@ class Decimal(object): return other._fix_nan(context) return self._check_nans(other, context) - c = self.__cmp__(other) + c = self._cmp(other) if c == 0: c = self.compare_total(other) @@ -2542,23 +2626,10 @@ class Decimal(object): It's pretty much like compare(), but all NaNs signal, with signaling NaNs taking precedence over quiet NaNs. """ - if context is None: - context = getcontext() - - self_is_nan = self._isnan() - other_is_nan = other._isnan() - if self_is_nan == 2: - return context._raise_error(InvalidOperation, 'sNaN', - self) - if other_is_nan == 2: - return context._raise_error(InvalidOperation, 'sNaN', - other) - if self_is_nan: - return context._raise_error(InvalidOperation, 'NaN in compare_signal', - self) - if other_is_nan: - return context._raise_error(InvalidOperation, 'NaN in compare_signal', - other) + other = _convert_other(other, raiseit = True) + ans = self._compare_check_nans(other, context) + if ans: + return ans return self.compare(other, context=context) def compare_total(self, other): @@ -3065,7 +3136,7 @@ class Decimal(object): return other._fix_nan(context) return self._check_nans(other, context) - c = self.copy_abs().__cmp__(other.copy_abs()) + c = self.copy_abs()._cmp(other.copy_abs()) if c == 0: c = self.compare_total(other) @@ -3095,7 +3166,7 @@ class Decimal(object): return other._fix_nan(context) return self._check_nans(other, context) - c = self.copy_abs().__cmp__(other.copy_abs()) + c = self.copy_abs()._cmp(other.copy_abs()) if c == 0: c = self.compare_total(other) @@ -3170,7 +3241,7 @@ class Decimal(object): if ans: return ans - comparison = self.__cmp__(other) + comparison = self._cmp(other) if comparison == 0: return self.copy_sign(other) diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index 0c8852c4001..3b15df76c6f 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -838,6 +838,19 @@ class DecimalArithmeticOperatorsTest(unittest.TestCase): self.assertEqual(-Decimal(45), Decimal(-45)) # - self.assertEqual(abs(Decimal(45)), abs(Decimal(-45))) # abs + def test_nan_comparisons(self): + n = Decimal('NaN') + s = Decimal('sNaN') + i = Decimal('Inf') + f = Decimal('2') + for x, y in [(n, n), (n, i), (i, n), (n, f), (f, n), + (s, n), (n, s), (s, i), (i, s), (s, f), (f, s), (s, s)]: + self.assert_(x != y) + self.assert_(not (x == y)) + self.assert_(not (x < y)) + self.assert_(not (x <= y)) + self.assert_(not (x > y)) + self.assert_(not (x >= y)) # The following are two functions used to test threading in the next class @@ -1147,7 +1160,12 @@ class DecimalUsabilityTest(unittest.TestCase): checkSameDec("__add__", True) checkSameDec("__div__", True) checkSameDec("__divmod__", True) - checkSameDec("__cmp__", True) + checkSameDec("__eq__", True) + checkSameDec("__ne__", True) + checkSameDec("__le__", True) + checkSameDec("__lt__", True) + checkSameDec("__ge__", True) + checkSameDec("__gt__", True) checkSameDec("__float__") checkSameDec("__floordiv__", True) checkSameDec("__hash__") diff --git a/Misc/NEWS b/Misc/NEWS index 0c5cdf2523e..172678e4b91 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -391,6 +391,9 @@ Core and builtins Library ------- +- #1979: Add rich comparisons to Decimal, and make Decimal comparisons + involving a NaN follow the IEEE 754 standard. + - #2004: tarfile.py: Use mode 0700 for temporary directories and default permissions for missing directories.