diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -217,6 +217,14 @@ class TestPartialC(TestPartial, unittest.TestCase): ['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr) for kwargs_repr in kwargs_reprs]) + def test_equality(self): + p = functools.partial(capture, 1, 2, a=10, b=20) + q = functools.partial(capture, 1, 2, a=10, b=20) + self.assertEqual(p, q) + + q = self.partial(capture, 1, 2, a=10) + self.assertNotEqual(p, q) + def test_pickle(self): f = self.partial(signature, ['asdf'], bar=[True]) f.attr = [] diff --git a/Modules/_functoolsmodule.c b/Modules/_functoolsmodule.c --- a/Modules/_functoolsmodule.c +++ b/Modules/_functoolsmodule.c @@ -20,7 +20,9 @@ typedef struct { PyObject *weakreflist; /* List of weak references */ } partialobject; -static PyTypeObject partial_type; +PyAPI_DATA(PyTypeObject) partial_type; + +#define PyPartial_Check(op) ((op)->ob_type == &partial_type) static PyObject * partial_new(PyTypeObject *type, PyObject *args, PyObject *kw) @@ -307,7 +309,41 @@ static PyMethodDef partial_methods[] = { {NULL, NULL} /* sentinel */ }; -static PyTypeObject partial_type = { +static PyObject * +partial_richcompare(PyObject *self, PyObject *other, int op) +{ + partialobject *a, *b; + PyObject *res; + int eq; + + if ((op != Py_EQ && op != Py_NE) || + !PyPartial_Check(self) || + !PyPartial_Check(other)) + { + Py_RETURN_NOTIMPLEMENTED; + } + + a = (partialobject *) self; + b = (partialobject *) other; + + if (op == Py_EQ) + eq = PyObject_RichCompareBool(a->fn, b->fn, op) && + PyObject_RichCompareBool(a->args, b->args, op) && + PyObject_RichCompareBool(a->kw, b->kw, op); + + else + eq = PyObject_RichCompareBool(a->fn, b->fn, op) || + PyObject_RichCompareBool(a->args, b->args, op) || + PyObject_RichCompareBool(a->kw, b->kw, op); + + res = eq ? Py_True : Py_False; + + Py_INCREF(res); + + return res; +} + +PyTypeObject partial_type = { PyVarObject_HEAD_INIT(NULL, 0) "functools.partial", /* tp_name */ sizeof(partialobject), /* tp_basicsize */ @@ -333,7 +369,7 @@ static PyTypeObject partial_type = { partial_doc, /* tp_doc */ (traverseproc)partial_traverse, /* tp_traverse */ 0, /* tp_clear */ - 0, /* tp_richcompare */ + partial_richcompare, /* tp_richcompare */ offsetof(partialobject, weakreflist), /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */