diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -217,6 +217,26 @@ ['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr) for kwargs_repr in kwargs_reprs]) + def test_equality(self): + p = functools.partial(capture, 1, 2, a=10, b=20) + q = functools.partial(capture, 1, 2, a=10, b=20) + self.assertTrue(p == q) + self.assertFalse(p != q) + self.assertTrue(p.__eq__(q)) + self.assertFalse(p.__ne__(q)) + + q = self.partial(capture, 1, 2, a=10) + self.assertFalse(p == q) + self.assertTrue(p != q) + + self.assertNotEqual(p, capture) + self.assertNotEqual(q, capture) + + a = self.partial(capture) + b = self.partial(signature) + self.assertFalse(a == b) + self.assertTrue(a != b) + def test_pickle(self): f = self.partial(signature, ['asdf'], bar=[True]) f.attr = [] diff --git a/Modules/_functoolsmodule.c b/Modules/_functoolsmodule.c --- a/Modules/_functoolsmodule.c +++ b/Modules/_functoolsmodule.c @@ -22,6 +22,8 @@ static PyTypeObject partial_type; +#define PyPartial_CheckExact(op) ((op)->ob_type == &partial_type) + static PyObject * partial_new(PyTypeObject *type, PyObject *args, PyObject *kw) { @@ -307,6 +309,45 @@ {NULL, NULL} /* sentinel */ }; +static PyObject * +partial_richcompare(PyObject *self, PyObject *other, int op) +{ + partialobject *a, *b; + PyObject *res; + int eq; + + if ((op != Py_EQ && op != Py_NE) || + !PyPartial_CheckExact(self) || + !PyPartial_CheckExact(other)) + { + Py_RETURN_NOTIMPLEMENTED; + } + + a = (partialobject *) self; + b = (partialobject *) other; + + eq = PyObject_RichCompareBool(a->fn, b->fn, Py_EQ); + if (eq == 1) { + eq = PyObject_RichCompareBool(a->args, b->args, Py_EQ); + if (eq == 1) { + eq = PyObject_RichCompareBool(a->kw, b->kw, Py_EQ); + } + } + + if (eq < 0) { + return NULL; + } + + if (op == Py_EQ) + res = eq ? Py_True : Py_False; + else + res = eq ? Py_False : Py_True; + + Py_INCREF(res); + + return res; +} + static PyTypeObject partial_type = { PyVarObject_HEAD_INIT(NULL, 0) "functools.partial", /* tp_name */ @@ -333,7 +374,7 @@ partial_doc, /* tp_doc */ (traverseproc)partial_traverse, /* tp_traverse */ 0, /* tp_clear */ - 0, /* tp_richcompare */ + partial_richcompare, /* tp_richcompare */ offsetof(partialobject, weakreflist), /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */