Index: Lib/decimal.py =================================================================== --- Lib/decimal.py (revision 60445) +++ Lib/decimal.py (working copy) @@ -718,6 +718,39 @@ return other._fix_nan(context) return 0 + def _compare_check_nans(self, other, context): + """Version of _check_nans used for the signaling comparisons + compare_signal, __le__, __lt__, __ge__, __gt__. + + Signal InvalidOperation if either self or other is a (quiet + or signaling) NaN. Signaling NaNs take precedence over quiet + NaNs. + + Return 0 if neither operand is a NaN. + + """ + if context is None: + context = getcontext() + + if self._is_special or other._is_special: + if self.is_snan(): + return context._raise_error(InvalidOperation, + 'comparison involving sNaN', + self) + elif other.is_snan(): + return context._raise_error(InvalidOperation, + 'comparison involving sNaN', + other) + elif self.is_qnan(): + return context._raise_error(InvalidOperation, + 'comparison involving NaN', + self) + elif other.is_qnan(): + return context._raise_error(InvalidOperation, + 'comparison involving NaN', + other) + return 0 + def __bool__(self): """Return True if self is nonzero; otherwise return False. @@ -725,18 +758,13 @@ """ return self._is_special or self._int != '0' - def __cmp__(self, other): - other = _convert_other(other) - if other is NotImplemented: - # Never return NotImplemented - return 1 + def _cmp(self, other): + """Compare the two non-NaN decimal instances self and other. + Returns -1 if self < other, 0 if self == other and 1 + if self > other. This routine is for internal use only.""" + if self._is_special or other._is_special: - # check for nans, without raising on a signaling nan - if self._isnan() or other._isnan(): - return 1 # Comparison involving NaN's always reports self > other - - # INF = INF return cmp(self._isinfinity(), other._isinfinity()) # check for zeros; note that cmp(0, -0) should return 0 @@ -766,43 +794,68 @@ return -((-1)**self._sign) def __eq__(self, other): - if not isinstance(other, (Decimal, int)): - return NotImplemented - return self.__cmp__(other) == 0 + other = _convert_other(other) + if other is NotImplemented: + return other + if self.is_nan() or other.is_nan(): + return False + return self._cmp(other) == 0 def __ne__(self, other): - if not isinstance(other, (Decimal, int)): - return NotImplemented - return self.__cmp__(other) != 0 + other = _convert_other(other) + if other is NotImplemented: + return other + if self.is_nan() or other.is_nan(): + return True + return self._cmp(other) != 0 - def __lt__(self, other): - if not isinstance(other, (Decimal, int)): - return NotImplemented - return self.__cmp__(other) < 0 + def __lt__(self, other, context=None): + other = _convert_other(other) + if other is NotImplemented: + return other + ans = self._compare_check_nans(other, context) + if ans: + return ans + return self._cmp(other) < 0 - def __le__(self, other): - if not isinstance(other, (Decimal, int)): - return NotImplemented - return self.__cmp__(other) <= 0 + def __le__(self, other, context=None): + other = _convert_other(other) + if other is NotImplemented: + return other + ans = self._compare_check_nans(other, context) + if ans: + return ans + return self._cmp(other) <= 0 - def __gt__(self, other): - if not isinstance(other, (Decimal, int)): - return NotImplemented - return self.__cmp__(other) > 0 + def __gt__(self, other, context=None): + other = _convert_other(other) + if other is NotImplemented: + return other + ans = self._compare_check_nans(other, context) + if ans: + return ans + return self._cmp(other) > 0 - def __ge__(self, other): - if not isinstance(other, (Decimal, int)): - return NotImplemented - return self.__cmp__(other) >= 0 + def __ge__(self, other, context=None): + other = _convert_other(other) + if other is NotImplemented: + return other + ans = self._compare_check_nans(other, context) + if ans: + return ans + return self._cmp(other) >= 0 def compare(self, other, context=None): - """Compares one to another. + """Compare the numerical values of self and other. - -1 => a < b - 0 => a = b - 1 => a > b - NaN => one is NaN - Like __cmp__, but returns Decimal instances. + The return value is a Decimal instance. Its value is + determined as follows: + + -1 if self < other + 0 if self == other + 1 if self > other + NaN if either self or other is a NaN. + """ other = _convert_other(other, raiseit=True) @@ -812,7 +865,7 @@ if ans: return ans - return Decimal(self.__cmp__(other)) + return Decimal(self._cmp(other)) def __hash__(self): """x.__hash__() <==> hash(x)""" @@ -2463,7 +2516,7 @@ return other._fix_nan(context) return self._check_nans(other, context) - c = self.__cmp__(other) + c = self._cmp(other) if c == 0: # If both operands are finite and equal in numerical value # then an ordering is applied: @@ -2505,7 +2558,7 @@ return other._fix_nan(context) return self._check_nans(other, context) - c = self.__cmp__(other) + c = self._cmp(other) if c == 0: c = self.compare_total(other) @@ -2553,23 +2606,10 @@ It's pretty much like compare(), but all NaNs signal, with signaling NaNs taking precedence over quiet NaNs. """ - if context is None: - context = getcontext() - - self_is_nan = self._isnan() - other_is_nan = other._isnan() - if self_is_nan == 2: - return context._raise_error(InvalidOperation, 'sNaN', - self) - if other_is_nan == 2: - return context._raise_error(InvalidOperation, 'sNaN', - other) - if self_is_nan: - return context._raise_error(InvalidOperation, 'NaN in compare_signal', - self) - if other_is_nan: - return context._raise_error(InvalidOperation, 'NaN in compare_signal', - other) + other = _convert_other(other, raiseit = True) + ans = self._compare_check_nans(other, context) + if ans: + return ans return self.compare(other, context=context) def compare_total(self, other): @@ -3076,7 +3116,7 @@ return other._fix_nan(context) return self._check_nans(other, context) - c = self.copy_abs().__cmp__(other.copy_abs()) + c = self.copy_abs()._cmp(other.copy_abs()) if c == 0: c = self.compare_total(other) @@ -3106,7 +3146,7 @@ return other._fix_nan(context) return self._check_nans(other, context) - c = self.copy_abs().__cmp__(other.copy_abs()) + c = self.copy_abs()._cmp(other.copy_abs()) if c == 0: c = self.compare_total(other) @@ -3181,7 +3221,7 @@ if ans: return ans - comparison = self.__cmp__(other) + comparison = self._cmp(other) if comparison == 0: return self.copy_sign(other) Index: Lib/test/test_decimal.py =================================================================== --- Lib/test/test_decimal.py (revision 60445) +++ Lib/test/test_decimal.py (working copy) @@ -832,6 +832,19 @@ self.assertEqual(-Decimal(45), Decimal(-45)) # - self.assertEqual(abs(Decimal(45)), abs(Decimal(-45))) # abs + def test_nan_comparisons(self): + n = Decimal('NaN') + s = Decimal('sNaN') + i = Decimal('Inf') + f = Decimal('2') + for x, y in [(n, n), (n, i), (i, n), (n, f), (f, n), + (s, n), (n, s), (s, i), (i, s), (s, f), (f, s), (s, s)]: + self.assertEqual(str(x < y), 'NaN') + self.assertEqual(str(x <= y), 'NaN') + self.assertEqual(str(x > y), 'NaN') + self.assertEqual(str(x >= y), 'NaN') + self.assert_(x != y) + self.assert_(not (x == y)) # The following are two functions used to test threading in the next class @@ -1136,14 +1149,19 @@ checkSameDec("__abs__") checkSameDec("__add__", True) checkSameDec("__divmod__", True) - checkSameDec("__cmp__", True) + checkSameDec("__eq__", True) checkSameDec("__float__") checkSameDec("__floordiv__", True) + checkSameDec("__ge__", True) + checkSameDec("__gt__", True) checkSameDec("__hash__") checkSameDec("__int__") checkSameDec("__trunc__") + checkSameDec("__le__", True) + checkSameDec("__lt__", True) checkSameDec("__mod__", True) checkSameDec("__mul__", True) + checkSameDec("__ne__", True) checkSameDec("__neg__") checkSameDec("__bool__") checkSameDec("__pos__")