Issue 1979: Make Decimal comparisons (other than !=, ==) involving NaN

raise InvalidOperation (and return False if InvalidOperation is trapped).
This commit is contained in:
Mark Dickinson 2008-02-06 22:10:50 +00:00
parent 55b8c3e26f
commit 2fc9263df5
4 changed files with 145 additions and 40 deletions

View File

@ -1290,6 +1290,19 @@ A variant is :const:`sNaN` which signals rather than remaining quiet after every
operation. This is a useful return value when an invalid result needs to
interrupt a calculation for special handling.
The behavior of Python's comparison operators can be a little surprising where a
:const:`NaN` is involved. A test for equality where one of the operands is a
quiet or signaling :const:`NaN` always returns :const:`False` (even when doing
``Decimal('NaN')==Decimal('NaN')``), while a test for inequality always returns
:const:`True`. An attempt to compare two Decimals using any of the :const:'<',
:const:'<=', :const:'>' or :const:'>=' operators will raise the
:exc:`InvalidOperation` signal if either operand is a :const:`NaN`, and return
:const:`False` if this signal is trapped. Note that the General Decimal
Arithmetic specification does not specify the behavior of direct comparisons;
these rules for comparisons involving a :const:`NaN` were taken from the IEEE
754 standard. To ensure strict standards-compliance, use the :meth:`compare`
and :meth:`compare-signal` methods instead.
The signed zeros can result from calculations that underflow. They keep the sign
that would have resulted if the calculation had been carried out to greater
precision. Since their magnitude is zero, both positive and negative zeros are

View File

@ -717,6 +717,39 @@ class Decimal(object):
return other._fix_nan(context)
return 0
def _compare_check_nans(self, other, context):
"""Version of _check_nans used for the signaling comparisons
compare_signal, __le__, __lt__, __ge__, __gt__.
Signal InvalidOperation if either self or other is a (quiet
or signaling) NaN. Signaling NaNs take precedence over quiet
NaNs.
Return 0 if neither operand is a NaN.
"""
if context is None:
context = getcontext()
if self._is_special or other._is_special:
if self.is_snan():
return context._raise_error(InvalidOperation,
'comparison involving sNaN',
self)
elif other.is_snan():
return context._raise_error(InvalidOperation,
'comparison involving sNaN',
other)
elif self.is_qnan():
return context._raise_error(InvalidOperation,
'comparison involving NaN',
self)
elif other.is_qnan():
return context._raise_error(InvalidOperation,
'comparison involving NaN',
other)
return 0
def __nonzero__(self):
"""Return True if self is nonzero; otherwise return False.
@ -724,18 +757,13 @@ class Decimal(object):
"""
return self._is_special or self._int != '0'
def __cmp__(self, other):
other = _convert_other(other)
if other is NotImplemented:
# Never return NotImplemented
return 1
def _cmp(self, other):
"""Compare the two non-NaN decimal instances self and other.
Returns -1 if self < other, 0 if self == other and 1
if self > other. This routine is for internal use only."""
if self._is_special or other._is_special:
# check for nans, without raising on a signaling nan
if self._isnan() or other._isnan():
return 1 # Comparison involving NaN's always reports self > other
# INF = INF
return cmp(self._isinfinity(), other._isinfinity())
# check for zeros; note that cmp(0, -0) should return 0
@ -764,15 +792,71 @@ class Decimal(object):
else: # self_adjusted < other_adjusted
return -((-1)**self._sign)
# Note: The Decimal standard doesn't cover rich comparisons for
# Decimals. In particular, the specification is silent on the
# subject of what should happen for a comparison involving a NaN.
# We take the following approach:
#
# == comparisons involving a NaN always return False
# != comparisons involving a NaN always return True
# <, >, <= and >= comparisons involving a (quiet or signaling)
# NaN signal InvalidOperation, and return False if the
# InvalidOperation is trapped.
#
# This behavior is designed to conform as closely as possible to
# that specified by IEEE 754.
def __eq__(self, other):
if not isinstance(other, (Decimal, int, long)):
return NotImplemented
return self.__cmp__(other) == 0
other = _convert_other(other)
if other is NotImplemented:
return other
if self.is_nan() or other.is_nan():
return False
return self._cmp(other) == 0
def __ne__(self, other):
if not isinstance(other, (Decimal, int, long)):
return NotImplemented
return self.__cmp__(other) != 0
other = _convert_other(other)
if other is NotImplemented:
return other
if self.is_nan() or other.is_nan():
return True
return self._cmp(other) != 0
def __lt__(self, other, context=None):
other = _convert_other(other)
if other is NotImplemented:
return other
ans = self._compare_check_nans(other, context)
if ans:
return False
return self._cmp(other) < 0
def __le__(self, other, context=None):
other = _convert_other(other)
if other is NotImplemented:
return other
ans = self._compare_check_nans(other, context)
if ans:
return False
return self._cmp(other) <= 0
def __gt__(self, other, context=None):
other = _convert_other(other)
if other is NotImplemented:
return other
ans = self._compare_check_nans(other, context)
if ans:
return False
return self._cmp(other) > 0
def __ge__(self, other, context=None):
other = _convert_other(other)
if other is NotImplemented:
return other
ans = self._compare_check_nans(other, context)
if ans:
return False
return self._cmp(other) >= 0
def compare(self, other, context=None):
"""Compares one to another.
@ -791,7 +875,7 @@ class Decimal(object):
if ans:
return ans
return Decimal(self.__cmp__(other))
return Decimal(self._cmp(other))
def __hash__(self):
"""x.__hash__() <==> hash(x)"""
@ -2452,7 +2536,7 @@ class Decimal(object):
return other._fix_nan(context)
return self._check_nans(other, context)
c = self.__cmp__(other)
c = self._cmp(other)
if c == 0:
# If both operands are finite and equal in numerical value
# then an ordering is applied:
@ -2494,7 +2578,7 @@ class Decimal(object):
return other._fix_nan(context)
return self._check_nans(other, context)
c = self.__cmp__(other)
c = self._cmp(other)
if c == 0:
c = self.compare_total(other)
@ -2542,23 +2626,10 @@ class Decimal(object):
It's pretty much like compare(), but all NaNs signal, with signaling
NaNs taking precedence over quiet NaNs.
"""
if context is None:
context = getcontext()
self_is_nan = self._isnan()
other_is_nan = other._isnan()
if self_is_nan == 2:
return context._raise_error(InvalidOperation, 'sNaN',
self)
if other_is_nan == 2:
return context._raise_error(InvalidOperation, 'sNaN',
other)
if self_is_nan:
return context._raise_error(InvalidOperation, 'NaN in compare_signal',
self)
if other_is_nan:
return context._raise_error(InvalidOperation, 'NaN in compare_signal',
other)
other = _convert_other(other, raiseit = True)
ans = self._compare_check_nans(other, context)
if ans:
return ans
return self.compare(other, context=context)
def compare_total(self, other):
@ -3065,7 +3136,7 @@ class Decimal(object):
return other._fix_nan(context)
return self._check_nans(other, context)
c = self.copy_abs().__cmp__(other.copy_abs())
c = self.copy_abs()._cmp(other.copy_abs())
if c == 0:
c = self.compare_total(other)
@ -3095,7 +3166,7 @@ class Decimal(object):
return other._fix_nan(context)
return self._check_nans(other, context)
c = self.copy_abs().__cmp__(other.copy_abs())
c = self.copy_abs()._cmp(other.copy_abs())
if c == 0:
c = self.compare_total(other)
@ -3170,7 +3241,7 @@ class Decimal(object):
if ans:
return ans
comparison = self.__cmp__(other)
comparison = self._cmp(other)
if comparison == 0:
return self.copy_sign(other)

View File

@ -838,6 +838,19 @@ class DecimalArithmeticOperatorsTest(unittest.TestCase):
self.assertEqual(-Decimal(45), Decimal(-45)) # -
self.assertEqual(abs(Decimal(45)), abs(Decimal(-45))) # abs
def test_nan_comparisons(self):
n = Decimal('NaN')
s = Decimal('sNaN')
i = Decimal('Inf')
f = Decimal('2')
for x, y in [(n, n), (n, i), (i, n), (n, f), (f, n),
(s, n), (n, s), (s, i), (i, s), (s, f), (f, s), (s, s)]:
self.assert_(x != y)
self.assert_(not (x == y))
self.assert_(not (x < y))
self.assert_(not (x <= y))
self.assert_(not (x > y))
self.assert_(not (x >= y))
# The following are two functions used to test threading in the next class
@ -1147,7 +1160,12 @@ class DecimalUsabilityTest(unittest.TestCase):
checkSameDec("__add__", True)
checkSameDec("__div__", True)
checkSameDec("__divmod__", True)
checkSameDec("__cmp__", True)
checkSameDec("__eq__", True)
checkSameDec("__ne__", True)
checkSameDec("__le__", True)
checkSameDec("__lt__", True)
checkSameDec("__ge__", True)
checkSameDec("__gt__", True)
checkSameDec("__float__")
checkSameDec("__floordiv__", True)
checkSameDec("__hash__")

View File

@ -391,6 +391,9 @@ Core and builtins
Library
-------
- #1979: Add rich comparisons to Decimal, and make Decimal comparisons
involving a NaN follow the IEEE 754 standard.
- #2004: tarfile.py: Use mode 0700 for temporary directories and default
permissions for missing directories.