--- heapq.py 2008-01-22 08:59:02.000000000 -0200 +++ heapq_key.py 2008-01-22 09:04:43.000000000 -0200 @@ -133,23 +133,23 @@ from operator import itemgetter, neg import bisect -def heappush(heap, item): +def heappush(heap, item, key=None): """Push item onto heap, maintaining the heap invariant.""" heap.append(item) - _siftdown(heap, 0, len(heap)-1) + _siftdown(heap, 0, len(heap)-1, key) -def heappop(heap): +def heappop(heap, key=None): """Pop the smallest item off the heap, maintaining the heap invariant.""" lastelt = heap.pop() # raises appropriate IndexError if heap is empty if heap: returnitem = heap[0] heap[0] = lastelt - _siftup(heap, 0) + _siftup(heap, 0, key) else: returnitem = lastelt return returnitem -def heapreplace(heap, item): +def heapreplace(heap, item, key=None): """Pop and return the current smallest value, and add the new item. This is more efficient than heappop() followed by heappush(), and can be @@ -162,10 +162,10 @@ """ returnitem = heap[0] # raises appropriate IndexError if heap is empty heap[0] = item - _siftup(heap, 0) + _siftup(heap, 0, key) return returnitem -def heapify(x): +def heapify(x, key=None): """Transform list into a heap, in-place, in O(len(heap)) time.""" n = len(x) # Transform bottom-up. The largest index there's any point to looking at @@ -174,7 +174,7 @@ # j-1 is the largest, which is n//2 - 1. If n is odd = 2*j+1, this is # (2*j+1-1)/2 = j so j-1 is the largest, and that's again n//2-1. for i in reversed(range(n//2)): - _siftup(x, i) + _siftup(x, i, key) def nlargest(n, iterable): """Find the n largest elements in a dataset. @@ -230,15 +230,18 @@ # 'heap' is a heap at all indices >= startpos, except possibly for pos. pos # is the index of a leaf with a possibly out-of-order value. Restore the # heap invariant. -def _siftdown(heap, startpos, pos): +def _siftdown(heap, startpos, pos, key): newitem = heap[pos] # Follow the path to the root, moving parents down until finding a place # newitem fits. while pos > startpos: parentpos = (pos - 1) >> 1 parent = heap[parentpos] - if parent <= newitem: + if key and key(parent) < key(newitem): break + elif not key and parent <= newitem: + break + heap[pos] = parent pos = parentpos heap[pos] = newitem @@ -282,7 +285,7 @@ # heappop() compares): list.sort() is (unsurprisingly!) more efficient # for sorting. -def _siftup(heap, pos): +def _siftup(heap, pos, key): endpos = len(heap) startpos = pos newitem = heap[pos] @@ -291,8 +294,12 @@ while childpos < endpos: # Set childpos to index of smaller child. rightpos = childpos + 1 - if rightpos < endpos and heap[rightpos] <= heap[childpos]: - childpos = rightpos + if rightpos < endpos: + if key and key(heap[rightpos]) <= key(heap[childpos]): + childpos = rightpos + elif not key and heap[rightpos] < heap[childpos]: + childpos = rightpos + # Move the smaller child up. heap[pos] = heap[childpos] pos = childpos @@ -300,7 +307,7 @@ # The leaf at pos is empty now. Put newitem there, and bubble it up # to its final resting place (by sifting its parents down). heap[pos] = newitem - _siftdown(heap, startpos, pos) + _siftdown(heap, startpos, pos, key) # If available, use C implementation try: