diff --git a/Lib/functools.py b/Lib/functools.py --- a/Lib/functools.py +++ b/Lib/functools.py @@ -97,7 +97,7 @@ """Convert a cmp= function into a key= function""" class K(object): __slots__ = ['obj'] - def __init__(self, obj, *args): + def __init__(self, obj): self.obj = obj def __lt__(self, other): return mycmp(self.obj, other.obj) < 0 @@ -115,6 +115,11 @@ raise TypeError('hash not implemented') return K +try: + from _functools import cmp_to_key +except ImportError: + pass + _CacheInfo = namedtuple("CacheInfo", "hits misses maxsize currsize") def lru_cache(maxsize=100): diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -435,18 +435,84 @@ self.assertEqual(self.func(add, d), "".join(d.keys())) class TestCmpToKey(unittest.TestCase): + def test_cmp_to_key(self): + def cmp1(x, y): + return (x > y) - (x < y) + key = functools.cmp_to_key(cmp1) + self.assertEqual(key(3), key(3)) + self.assertGreater(key(3), key(1)) + def cmp2(x, y): + return int(x) - int(y) + key = functools.cmp_to_key(cmp2) + self.assertEqual(key(4.0), key('4')) + self.assertLess(key(2), key('35')) + + def test_cmp_to_key_arguments(self): + def cmp1(x, y): + return (x > y) - (x < y) + key = functools.cmp_to_key(mycmp=cmp1) + self.assertEqual(key(obj=3), key(obj=3)) + self.assertGreater(key(obj=3), key(obj=1)) + with self.assertRaises(TypeError): + key(3) > 1 # rhs is not a K object + with self.assertRaises(TypeError): + 1 < key(3) # lhs is not a K object + with self.assertRaises(TypeError): + key = functools.cmp_to_key() # too few args + with self.assertRaises(TypeError): + key = functools.cmp_to_key(cmp1, None) # too many args + key = functools.cmp_to_key(cmp1) + with self.assertRaises(TypeError): + key() # too few args + with self.assertRaises(TypeError): + key(None, None) # too many args + + def test_bad_cmp(self): + def cmp1(x, y): + raise ZeroDivisionError + key = functools.cmp_to_key(cmp1) + with self.assertRaises(ZeroDivisionError): + key(3) > key(1) + + class BadCmp: + def __lt__(self, other): + raise ZeroDivisionError + def __gt__(self, other): + raise ZeroDivsioinError + def cmp1(x, y): + return BadCmp() + with self.assertRaises(ZeroDivisionError): + key(3) > key(1) + + def test_obj_field(self): + def cmp1(x, y): + return (x > y) - (x < y) + key = functools.cmp_to_key(mycmp=cmp1) + self.assertEqual(key(50).obj, 50) + self.assertIs(key.obj, None) + + def test_sort_int(self): def mycmp(x, y): return y - x self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)), [4, 3, 2, 1, 0]) + def test_sort_int_str(self): + def mycmp(x, y): + x, y = int(x), int(y) + return (x > y) - (x < y) + values = [5, '3', 7, 2, '0', '1', 4, '10', 1] + values = sorted(values, key=functools.cmp_to_key(mycmp)) + self.assertEqual([int(value) for value in values], + [0, 1, 1, 2, 3, 4, 5, 7, 10]) + def test_hash(self): def mycmp(x, y): return y - x key = functools.cmp_to_key(mycmp) k = key(10) - self.assertRaises(TypeError, hash(k)) + self.assertRaises(TypeError, hash, k) class TestTotalOrdering(unittest.TestCase): @@ -655,6 +721,7 @@ def test_main(verbose=None): test_classes = ( + TestCmpToKey, TestPartial, TestPartialSubclass, TestPythonPartial, diff --git a/Misc/NEWS b/Misc/NEWS --- a/Misc/NEWS +++ b/Misc/NEWS @@ -97,6 +97,9 @@ - Issue #10791: Implement missing method GzipFile.read1(), allowing GzipFile to be wrapped in a TextIOWrapper. Patch by Nadeem Vawda. +- Issue #11707: Added a fast C version of functools.cmp_to_key(). + Patch by Filip GruszczyƄski. + - Issue #11688: Add sqlite3.Connection.set_trace_callback(). Patch by Torsten Landschoff. diff --git a/Modules/_functoolsmodule.c b/Modules/_functoolsmodule.c --- a/Modules/_functoolsmodule.c +++ b/Modules/_functoolsmodule.c @@ -330,6 +330,165 @@ }; +/* cmp_to_key ***************************************************************/ + +typedef struct { + PyObject_HEAD; + PyObject *cmp; + PyObject *object; +} keyobject; + +static void +keyobject_dealloc(keyobject *ko) +{ + Py_DECREF(ko->cmp); + Py_XDECREF(ko->object); + PyObject_FREE(ko); +} + +static int +keyobject_traverse(keyobject *ko, visitproc visit, void *arg) +{ + Py_VISIT(ko->cmp); + if (ko->object) + Py_VISIT(ko->object); + return 0; +} + +static PyMemberDef keyobject_members[] = { + {"obj", T_OBJECT, + offsetof(keyobject, object), 0, + PyDoc_STR("Value wrapped by a key function.")}, + {NULL} +}; + +static PyObject * +keyobject_call(keyobject *ko, PyObject *args, PyObject *kw); + +static PyObject * +keyobject_richcompare(PyObject *ko, PyObject *other, int op); + +static PyTypeObject keyobject_type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "functools.KeyWrapper", /* tp_name */ + sizeof(keyobject), /* tp_basicsize */ + 0, /* tp_itemsize */ + /* methods */ + (destructor)keyobject_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + (ternaryfunc)keyobject_call, /* tp_call */ + 0, /* tp_str */ + PyObject_GenericGetAttr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + 0, /* tp_doc */ + (traverseproc)keyobject_traverse, /* tp_traverse */ + 0, /* tp_clear */ + keyobject_richcompare, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + 0, /* tp_methods */ + keyobject_members, /* tp_members */ + 0, /* tp_getset */ +}; + +static PyObject * +keyobject_call(keyobject *ko, PyObject *args, PyObject *kwds) +{ + PyObject *object; + keyobject *result; + static char *kwargs[] = {"obj", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:K", kwargs, &object)) + return NULL; + result = PyObject_New(keyobject, &keyobject_type); + if (!result) + return NULL; + Py_INCREF(ko->cmp); + result->cmp = ko->cmp; + Py_INCREF(object); + result->object = object; + return (PyObject *)result; +} + +static PyObject * +keyobject_richcompare(PyObject *ko, PyObject *other, int op) +{ + PyObject *res; + PyObject *args; + PyObject *x; + PyObject *y; + PyObject *compare; + PyObject *answer; + static PyObject *zero; + + if (zero == NULL) { + zero = PyLong_FromLong(0); + if (!zero) + return NULL; + } + + if (Py_TYPE(other) != &keyobject_type){ + PyErr_Format(PyExc_TypeError, "other argument must be K instance"); + return NULL; + } + compare = ((keyobject *) ko)->cmp; + assert(compare != NULL); + x = ((keyobject *) ko)->object; + y = ((keyobject *) other)->object; + if (!x || !y){ + PyErr_Format(PyExc_AttributeError, "object"); + return NULL; + } + + /* Call the user's comparison function and translate the 3-way + * result into true or false (or error). + */ + args = PyTuple_New(2); + if (args == NULL) + return NULL; + Py_INCREF(x); + Py_INCREF(y); + PyTuple_SET_ITEM(args, 0, x); + PyTuple_SET_ITEM(args, 1, y); + res = PyObject_Call(compare, args, NULL); + Py_DECREF(args); + if (res == NULL) + return NULL; + answer = PyObject_RichCompare(res, zero, op); + Py_DECREF(res); + return answer; +} + +static PyObject * +functools_cmp_to_key(PyObject *self, PyObject *args, PyObject *kwds){ + PyObject *cmp; + static char *kwargs[] = {"mycmp", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:cmp_to_key", kwargs, &cmp)) + return NULL; + keyobject *object = PyObject_New(keyobject, &keyobject_type); + if (!object) + return NULL; + Py_INCREF(cmp); + object->cmp = cmp; + object->object = NULL; + return (PyObject *)object; +} + +PyDoc_STRVAR(functools_cmp_to_key_doc, +"Convert a cmp= function into a key= function."); + /* reduce (used to be a builtin) ********************************************/ static PyObject * @@ -413,6 +572,8 @@ static PyMethodDef module_methods[] = { {"reduce", functools_reduce, METH_VARARGS, functools_reduce_doc}, + {"cmp_to_key", functools_cmp_to_key, METH_VARARGS | METH_KEYWORDS, + functools_cmp_to_key_doc}, {NULL, NULL} /* sentinel */ };