diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -217,6 +217,25 @@ class TestPartialC(TestPartial, unittest.TestCase): ['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr) for kwargs_repr in kwargs_reprs]) + def test_equality(self): + p = self.partial(capture, 1, 2, a=10, b=20) + q = self.partial(capture, 1, 2, a=10, b=20) + self.assertEqual(p, q) + + r = self.partial(capture, 1, 2, a=10) + self.assertNotEqual(p, r) + + self.assertNotEqual(p, capture) + self.assertNotEqual(q, capture) + + a = self.partial(capture) + b = self.partial(signature) + self.assertNotEqual(a, b) + + c = self.partial(capture, 1) + d = self.partial(capture, 2) + self.assertNotEqual(c, d) + def test_pickle(self): f = self.partial(signature, ['asdf'], bar=[True]) f.attr = [] diff --git a/Modules/_functoolsmodule.c b/Modules/_functoolsmodule.c --- a/Modules/_functoolsmodule.c +++ b/Modules/_functoolsmodule.c @@ -307,6 +307,54 @@ static PyMethodDef partial_methods[] = { {NULL, NULL} /* sentinel */ }; +static PyObject * +partial_richcompare(PyObject *self, PyObject *other, int op) +{ + partialobject *a, *b; + PyObject *res; + int part_self, part_other, fn_eq, args_eq, kw_eq, dict_eq; + + part_self = PyObject_IsInstance(self, (PyObject *) &partial_type); + if (part_self < 0) + return NULL; + part_other = PyObject_IsInstance(other, (PyObject *) &partial_type); + if (part_other < 0) + return NULL; + + if ((op != Py_EQ && op != Py_NE) || !part_self || !part_other) + Py_RETURN_NOTIMPLEMENTED; + + a = (partialobject *) self; + b = (partialobject *) other; + + fn_eq = PyObject_RichCompareBool(a->fn, b->fn, op); + if (fn_eq < 0) + goto error; + args_eq = PyObject_RichCompareBool(a->args, b->args, op); + if (args_eq < 0) + goto error; + kw_eq = PyObject_RichCompareBool(a->kw, b->kw, op); + if (kw_eq < 0) + goto error; + dict_eq = PyObject_RichCompareBool(a->dict, b->dict, op); + if (dict_eq < 0) + goto error; + + if (op == Py_EQ) + res = (fn_eq && args_eq && kw_eq && dict_eq) ? Py_True : Py_False; + else + res = (fn_eq || args_eq || kw_eq || dict_eq) ? Py_True : Py_False; + + Py_INCREF(res); + return res; + + error: + Py_DECREF(a); + Py_DECREF(b); + return NULL; + +} + static PyTypeObject partial_type = { PyVarObject_HEAD_INIT(NULL, 0) "functools.partial", /* tp_name */ @@ -333,7 +381,7 @@ static PyTypeObject partial_type = { partial_doc, /* tp_doc */ (traverseproc)partial_traverse, /* tp_traverse */ 0, /* tp_clear */ - 0, /* tp_richcompare */ + partial_richcompare, /* tp_richcompare */ offsetof(partialobject, weakreflist), /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */