diff --git a/Lib/heapq.py b/Lib/heapq.py --- a/Lib/heapq.py +++ b/Lib/heapq.py @@ -127,7 +127,7 @@ __all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'merge', 'nlargest', 'nsmallest', 'heappushpop'] -from itertools import islice, count, tee, chain +from itertools import islice, count def heappush(heap, item): """Push item onto heap, maintaining the heap invariant.""" @@ -179,12 +179,12 @@ for i in reversed(range(n//2)): _siftup(x, i) -def _heappushpop_max(heap, item): - """Maxheap version of a heappush followed by a heappop.""" - if heap and item < heap[0]: - item, heap[0] = heap[0], item - _siftup_max(heap, 0) - return item +def _heapreplace_max(heap, item): + """Maxheap version of a heappop followed by a heappush.""" + returnitem = heap[0] # raises appropriate IndexError if heap is empty + heap[0] = item + _siftup_max(heap, 0) + return returnitem def _heapify_max(x): """Transform list into a maxheap, in-place, in O(len(x)) time.""" @@ -192,24 +192,6 @@ for i in reversed(range(n//2)): _siftup_max(x, i) -def nsmallest(n, iterable): - """Find the n smallest elements in a dataset. - - Equivalent to: sorted(iterable)[:n] - """ - if n <= 0: - return [] - it = iter(iterable) - result = list(islice(it, n)) - if not result: - return result - _heapify_max(result) - _heappushpop = _heappushpop_max - for elem in it: - _heappushpop(result, elem) - result.sort() - return result - # 'heap' is a heap at all indices >= startpos, except possibly for pos. pos # is the index of a leaf with a possibly out-of-order value. Restore the # heap invariant. @@ -327,6 +309,10 @@ from _heapq import * except ImportError: pass +try: + from _heapq import _heapreplace_max +except ImportError: + pass def merge(*iterables): '''Merge multiple sorted inputs into a single sorted output. @@ -367,22 +353,21 @@ yield v yield from next.__self__ -# Extend the implementations of nsmallest and nlargest to use a key= argument -_nsmallest = nsmallest def nsmallest(n, iterable, key=None): """Find the n smallest elements in a dataset. Equivalent to: sorted(iterable, key=key)[:n] """ + # Short-cut for n==1 is to use min() when len(iterable)>0 if n == 1: it = iter(iterable) - head = list(islice(it, 1)) - if not head: - return [] + sentinel = object() if key is None: - return [min(chain(head, it))] - return [min(chain(head, it), key=key)] + result = min(it, default=sentinel) + else: + result = min(it, default=sentinel, key=key) + return [] if result is sentinel else [result] # When n>=size, it's faster to use sorted() try: @@ -395,15 +380,39 @@ # When key is none, use simpler decoration if key is None: - it = zip(iterable, count()) # decorate - result = _nsmallest(n, it) - return [r[0] for r in result] # undecorate + it = iter(iterable) + result = list(islice(zip(it, count()), n)) + if not result: + return result + _heapify_max(result) + order = n + top = result[0][0] + _heapreplace = _heapreplace_max + for elem in it: + if elem < top: + _heapreplace(result, (elem, order)) + top = result[0][0] + order += 1 + result.sort() + return [r[0] for r in result] # General case, slowest method - in1, in2 = tee(iterable) - it = zip(map(key, in1), count(), in2) # decorate - result = _nsmallest(n, it) - return [r[2] for r in result] # undecorate + it = iter(iterable) + result = [(key(elem), i, elem) for i, elem in zip(range(n), it)] + if not result: + return result + _heapify_max(result) + order = n + top = result[0][0] + _heapreplace = _heapreplace_max + for elem in it: + k = key(elem) + if k < top: + _heapreplace(result, (k, order, elem)) + top = result[0][0] + order += 1 + result.sort() + return [r[2] for r in result] def nlargest(n, iterable, key=None): """Find the n largest elements in a dataset. @@ -442,9 +451,9 @@ _heapreplace = heapreplace for elem in it: if top < elem: - order -= 1 _heapreplace(result, (elem, order)) top = result[0][0] + order -= 1 result.sort(reverse=True) return [r[0] for r in result] @@ -460,9 +469,9 @@ for elem in it: k = key(elem) if top < k: - order -= 1 _heapreplace(result, (k, order, elem)) top = result[0][0] + order -= 1 result.sort(reverse=True) return [r[2] for r in result] diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py --- a/Lib/test/test_heapq.py +++ b/Lib/test/test_heapq.py @@ -13,7 +13,7 @@ # _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when # _heapq is imported, so check them there func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', - 'heapreplace', '_nsmallest'] + 'heapreplace', '_heapreplace_max'] class TestModules(TestCase): def test_py_functions(self): diff --git a/Modules/_heapqmodule.c b/Modules/_heapqmodule.c --- a/Modules/_heapqmodule.c +++ b/Modules/_heapqmodule.c @@ -354,88 +354,34 @@ } static PyObject * -nsmallest(PyObject *self, PyObject *args) +_heapreplace_max(PyObject *self, PyObject *args) { - PyObject *heap=NULL, *elem, *iterable, *los, *it, *oldelem; - Py_ssize_t i, n; - int cmp; + PyObject *heap, *item, *returnitem; - if (!PyArg_ParseTuple(args, "nO:nsmallest", &n, &iterable)) + if (!PyArg_UnpackTuple(args, "_heapreplace_max", 2, 2, &heap, &item)) return NULL; - it = PyObject_GetIter(iterable); - if (it == NULL) + if (!PyList_Check(heap)) { + PyErr_SetString(PyExc_TypeError, "heap argument must be a list"); return NULL; - - heap = PyList_New(0); - if (heap == NULL) - goto fail; - - for (i=0 ; i=0 ; i--) - if(_siftupmax((PyListObject *)heap, i) == -1) - goto fail; - - los = PyList_GET_ITEM(heap, 0); - while (1) { - elem = PyIter_Next(it); - if (elem == NULL) { - if (PyErr_Occurred()) - goto fail; - else - goto sortit; - } - cmp = PyObject_RichCompareBool(elem, los, Py_LT); - if (cmp == -1) { - Py_DECREF(elem); - goto fail; - } - if (cmp == 0) { - Py_DECREF(elem); - continue; - } - - oldelem = PyList_GET_ITEM(heap, 0); - PyList_SET_ITEM(heap, 0, elem); - Py_DECREF(oldelem); - if (_siftupmax((PyListObject *)heap, 0) == -1) - goto fail; - los = PyList_GET_ITEM(heap, 0); } -sortit: - if (PyList_Sort(heap) == -1) - goto fail; - Py_DECREF(it); - return heap; + if (PyList_GET_SIZE(heap) < 1) { + PyErr_SetString(PyExc_IndexError, "index out of range"); + return NULL; + } -fail: - Py_DECREF(it); - Py_XDECREF(heap); - return NULL; + returnitem = PyList_GET_ITEM(heap, 0); + Py_INCREF(item); + PyList_SET_ITEM(heap, 0, item); + if (_siftupmax((PyListObject *)heap, 0) == -1) { + Py_DECREF(returnitem); + return NULL; + } + return returnitem; } -PyDoc_STRVAR(nsmallest_doc, -"Find the n smallest elements in a dataset.\n\ -\n\ -Equivalent to: sorted(iterable)[:n]\n"); +PyDoc_STRVAR(heapreplace_max_doc, "Maxheap variant of heapreplace"); static PyMethodDef heapq_methods[] = { {"heappush", (PyCFunction)heappush, @@ -448,8 +394,8 @@ METH_VARARGS, heapreplace_doc}, {"heapify", (PyCFunction)heapify, METH_O, heapify_doc}, - {"nsmallest", (PyCFunction)nsmallest, - METH_VARARGS, nsmallest_doc}, + {"_heapreplace_max",(PyCFunction)_heapreplace_max, + METH_VARARGS, heapreplace_max_doc}, {NULL, NULL} /* sentinel */ };