Index: Lib/fractions.py =================================================================== --- Lib/fractions.py (revision 74022) +++ Lib/fractions.py (working copy) @@ -501,54 +501,56 @@ if isinstance(b, numbers.Complex) and b.imag == 0: b = b.real if isinstance(b, float): - return a == a.from_float(b) + if math.isnan(b) or math.isinf(b): + # comparisons with an infinity or nan should behave in + # the same way for any finite a, so treat a as zero. + return 0.0 == b + else: + return a == a.from_float(b) else: - # XXX: If b.__eq__ is implemented like this method, it may - # give the wrong answer after float(a) changes a's - # value. Better ways of doing this are welcome. - return float(a) == b + # Since a doesn't know how to compare with b, let's give b + # a chance to compare itself with a. + return NotImplemented - def _subtractAndCompareToZero(a, b, op): - """Helper function for comparison operators. + def _richcmp(self, other, op): + """Helper for comparison operators, for internal use only. - Subtracts b from a, exactly if possible, and compares the - result with 0 using op, in such a way that the comparison - won't recurse. If the difference raises a TypeError, returns - NotImplemented instead. + Implement comparison between a Rational instance `self`, and + either another Rational instance or a float `other`. If + `other` is not a Rational instance or a float, return + NotImplemented. `op` should be one of the six standard + comparison operators. """ - if isinstance(b, numbers.Complex) and b.imag == 0: - b = b.real - if isinstance(b, float): - b = a.from_float(b) - try: - # XXX: If b <: Real but not <: Rational, this is likely - # to fall back to a float. If the actual values differ by - # less than MIN_FLOAT, this could falsely call them equal, - # which would make <= inconsistent with ==. Better ways of - # doing this are welcome. - diff = a - b - except TypeError: + # convert other to a Rational instance where reasonable. + if isinstance(other, numbers.Rational): + return op(self._numerator * other.denominator, + self._denominator * other.numerator) + if isinstance(other, numbers.Complex) and other.imag == 0: + other = other.real + if isinstance(other, float): + if math.isnan(other) or math.isinf(other): + return op(0.0, other) + else: + return op(self, self.from_float(other)) + else: return NotImplemented - if isinstance(diff, numbers.Rational): - return op(diff.numerator, 0) - return op(diff, 0) def __lt__(a, b): """a < b""" - return a._subtractAndCompareToZero(b, operator.lt) + return a._richcmp(b, operator.lt) def __gt__(a, b): """a > b""" - return a._subtractAndCompareToZero(b, operator.gt) + return a._richcmp(b, operator.gt) def __le__(a, b): """a <= b""" - return a._subtractAndCompareToZero(b, operator.le) + return a._richcmp(b, operator.le) def __ge__(a, b): """a >= b""" - return a._subtractAndCompareToZero(b, operator.ge) + return a._richcmp(b, operator.ge) def __bool__(a): """a != 0""" Index: Lib/test/test_fractions.py =================================================================== --- Lib/test/test_fractions.py (revision 74022) +++ Lib/test/test_fractions.py (working copy) @@ -3,6 +3,7 @@ from decimal import Decimal from test.support import run_unittest import math +import numbers import operator import fractions import unittest @@ -11,7 +12,70 @@ F = fractions.Fraction gcd = fractions.gcd +class DummyFloat(object): + """Dummy float class for testing comparisons with Fractions""" + def __init__(self, value): + if not isinstance(value, float): + raise TypeError("DummyFloat can only be initialized from float") + self.value = value + + def _richcmp(self, other, op): + if isinstance(other, numbers.Rational): + return op(F.from_float(self.value), other) + elif isinstance(other, DummyFloat): + return op(self.value, other.value) + else: + return NotImplemented + + def __eq__(self, other): return self._richcmp(other, operator.eq) + def __le__(self, other): return self._richcmp(other, operator.le) + def __lt__(self, other): return self._richcmp(other, operator.lt) + def __ge__(self, other): return self._richcmp(other, operator.ge) + def __gt__(self, other): return self._richcmp(other, operator.gt) + + # shouldn't be calling __float__ at all when doing comparisons + def __float__(self): + assert False, "__float__ should not be invoked for comparisons" + + # same goes for subtraction + def __sub__(self, other): + assert False, "__sub__ should not be invoked for comparisons" + __rsub__ = __sub__ + + +class DummyRational(object): + """Test comparison of Fraction with a naive rational implementation.""" + + def __init__(self, num, den): + g = gcd(num, den) + self.num = num // g + self.den = den // g + + def __eq__(self, other): + if isinstance(other, fractions.Fraction): + return (self.num == other._numerator and + self.den == other._denominator) + else: + return NotImplemented + + def __lt__(self, other): + return(self.num * other._denominator < self.den * other._numerator) + + def __gt__(self, other): + return(self.num * other._denominator > self.den * other._numerator) + + def __le__(self, other): + return(self.num * other._denominator <= self.den * other._numerator) + + def __ge__(self, other): + return(self.num * other._denominator >= self.den * other._numerator) + + # this class is for testing comparisons; conversion to float + # should never be used for a comparison, since it loses accuracy + def __float__(self): + assert False, "__float__ should not be invoked" + class GcdTest(unittest.TestCase): def testMisc(self): @@ -324,6 +388,50 @@ self.assertFalse(F(1, 2) != F(1, 2)) self.assertTrue(F(1, 2) != F(1, 3)) + def testComparisonsDummyRational(self): + self.assertTrue(F(1, 2) == DummyRational(1, 2)) + self.assertTrue(DummyRational(1, 2) == F(1, 2)) + self.assertFalse(F(1, 2) == DummyRational(3, 4)) + self.assertFalse(DummyRational(3, 4) == F(1, 2)) + + self.assertTrue(F(1, 2) < DummyRational(3, 4)) + self.assertFalse(F(1, 2) < DummyRational(1, 2)) + self.assertFalse(F(1, 2) < DummyRational(1, 7)) + self.assertFalse(F(1, 2) > DummyRational(3, 4)) + self.assertFalse(F(1, 2) > DummyRational(1, 2)) + self.assertTrue(F(1, 2) > DummyRational(1, 7)) + self.assertTrue(F(1, 2) <= DummyRational(3, 4)) + self.assertTrue(F(1, 2) <= DummyRational(1, 2)) + self.assertFalse(F(1, 2) <= DummyRational(1, 7)) + self.assertFalse(F(1, 2) >= DummyRational(3, 4)) + self.assertTrue(F(1, 2) >= DummyRational(1, 2)) + self.assertTrue(F(1, 2) >= DummyRational(1, 7)) + + self.assertTrue(DummyRational(1, 2) < F(3, 4)) + self.assertFalse(DummyRational(1, 2) < F(1, 2)) + self.assertFalse(DummyRational(1, 2) < F(1, 7)) + self.assertFalse(DummyRational(1, 2) > F(3, 4)) + self.assertFalse(DummyRational(1, 2) > F(1, 2)) + self.assertTrue(DummyRational(1, 2) > F(1, 7)) + self.assertTrue(DummyRational(1, 2) <= F(3, 4)) + self.assertTrue(DummyRational(1, 2) <= F(1, 2)) + self.assertFalse(DummyRational(1, 2) <= F(1, 7)) + self.assertFalse(DummyRational(1, 2) >= F(3, 4)) + self.assertTrue(DummyRational(1, 2) >= F(1, 2)) + self.assertTrue(DummyRational(1, 2) >= F(1, 7)) + + def testComparisonsDummyFloat(self): + x = DummyFloat(1./3.) + y = F(1, 3) + self.assertTrue(x != y) + self.assertTrue(x < y or x > y) + self.assertFalse(x == y) + self.assertFalse(x <= y and x >= y) + self.assertTrue(y != x) + self.assertTrue(y < x or y > x) + self.assertFalse(y == x) + self.assertFalse(y <= x and y >= x) + def testMixedLess(self): self.assertTrue(2 < F(5, 2)) self.assertFalse(2 < F(4, 2)) @@ -335,6 +443,13 @@ self.assertTrue(0.4 < F(1, 2)) self.assertFalse(0.5 < F(1, 2)) + self.assertFalse(float('inf') < F(1, 2)) + self.assertTrue(float('-inf') < F(0, 10)) + self.assertFalse(float('nan') < F(-3, 7)) + self.assertTrue(F(1, 2) < float('inf')) + self.assertFalse(F(17, 12) < float('-inf')) + self.assertFalse(F(144, -89) < float('nan')) + def testMixedLessEqual(self): self.assertTrue(0.5 <= F(1, 2)) self.assertFalse(0.6 <= F(1, 2)) @@ -345,6 +460,13 @@ self.assertTrue(F(4, 2) <= 2) self.assertFalse(F(5, 2) <= 2) + self.assertFalse(float('inf') <= F(1, 2)) + self.assertTrue(float('-inf') <= F(0, 10)) + self.assertFalse(float('nan') <= F(-3, 7)) + self.assertTrue(F(1, 2) <= float('inf')) + self.assertFalse(F(17, 12) <= float('-inf')) + self.assertFalse(F(144, -89) <= float('nan')) + def testBigFloatComparisons(self): # Because 10**23 can't be represented exactly as a float: self.assertFalse(F(10**23) == float(10**23)) @@ -369,6 +491,10 @@ self.assertFalse(2 == F(3, 2)) self.assertTrue(F(4, 2) == 2) self.assertFalse(F(5, 2) == 2) + self.assertFalse(F(5, 2) == float('nan')) + self.assertFalse(float('nan') == F(3, 7)) + self.assertFalse(F(5, 2) == float('inf')) + self.assertFalse(float('-inf') == F(2, 5)) def testStringification(self): self.assertEquals("Fraction(7, 3)", repr(F(7, 3)))