Index: Objects/complexobject.c =================================================================== --- Objects/complexobject.c (revision 81388) +++ Objects/complexobject.c (working copy) @@ -783,40 +783,62 @@ static PyObject * complex_richcompare(PyObject *v, PyObject *w, int op) { - int c; - Py_complex i, j; PyObject *res; + Py_complex i; + int equal = 0; - c = PyNumber_CoerceEx(&v, &w); - if (c < 0) - return NULL; - if (c > 0) { + if (!(PyLong_Check(w) || PyInt_Check(w) + || PyFloat_Check(w) || PyComplex_Check(w))) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } - /* Make sure both arguments are complex. */ - if (!(PyComplex_Check(v) && PyComplex_Check(w))) { - Py_DECREF(v); - Py_DECREF(w); - Py_INCREF(Py_NotImplemented); - return Py_NotImplemented; - } - i = ((PyComplexObject *)v)->cval; - j = ((PyComplexObject *)w)->cval; - Py_DECREF(v); - Py_DECREF(w); - if (op != Py_EQ && op != Py_NE) { PyErr_SetString(PyExc_TypeError, - "no ordering relation is defined for complex numbers"); + "no ordering relation is defined for complex numbers"); return NULL; } - if ((i.real == j.real && i.imag == j.imag) == (op == Py_EQ)) - res = Py_True; + assert(PyComplex_Check(v)); + TO_COMPLEX(v, i); + + if (PyLong_Check(w) || PyInt_Check(w)) { + PyObject *j = NULL; + PyObject *sub_res = NULL; + + j = PyFloat_FromDouble(i.real); + if (j == NULL) + return NULL; + + /* Check for 0.0 imaginary part first to avoid the rich + * comparison when possible. + */ + if (i.imag == 0.0) { + sub_res = PyObject_RichCompare(j, w, Py_EQ); + if (sub_res == NULL) + return NULL; + + equal = (sub_res == Py_True); + Py_DECREF(sub_res); + } + } else if (PyFloat_Check(w)) { + double j = PyFloat_AsDouble(w); + + equal = (i.real == j && i.imag == 0.0); + } else if (PyComplex_Check(w)) { + Py_complex j; + + TO_COMPLEX(w, j); + equal = (i.real == j.real && i.imag == j.imag); + } else { + /* unreachable */ + assert(0); + } + + if (equal == (op == Py_EQ)) + res = Py_True; else - res = Py_False; + res = Py_False; Py_INCREF(res); return res; Index: Lib/test/test_complex.py =================================================================== --- Lib/test/test_complex.py (revision 81388) +++ Lib/test/test_complex.py (working copy) @@ -116,7 +116,7 @@ self.assertRaises(OverflowError, complex.__coerce__, 1+1j, 1L<<10000) def test_richcompare(self): - self.assertRaises(OverflowError, complex.__eq__, 1+1j, 1L<<10000) + self.assertEqual(complex.__eq__(1+1j, 1L<<10000), False) self.assertEqual(complex.__lt__(1+1j, None), NotImplemented) self.assertIs(complex.__eq__(1+1j, 1+1j), True) self.assertIs(complex.__eq__(1+1j, 2+2j), False) @@ -127,6 +127,23 @@ self.assertRaises(TypeError, complex.__gt__, 1+1j, 2+2j) self.assertRaises(TypeError, complex.__ge__, 1+1j, 2+2j) + def test_richcompare_boundaries(self): + def check(n, deltas, is_equal, imag = 0.0): + for delta in deltas: + i = n + delta + z = complex(i, imag) + self.assertIs(complex.__eq__(z, i), is_equal(delta)) + self.assertIs(complex.__ne__(z, i), not is_equal(delta)) + # For IEEE-754 doubles the following should hold: + # x in [2 ** (52 + i), 2 ** (53 + i + 1)] -> x mod 2 ** i == 0 + # where the interval is representable, of course. + for i in range(1, 10): + pow = 52 + i + mult = 2 ** i + check(2 ** pow, range(1, 101), lambda delta: delta % mult == 0) + check(2 ** pow, range(1, 101), lambda delta: False, float(i)) + check(2 ** 53, range(-100, 0), lambda delta: True) + def test_mod(self): self.assertRaises(ZeroDivisionError, (1+1j).__mod__, 0+0j)