diff --git a/Lib/heapq.py b/Lib/heapq.py --- a/Lib/heapq.py +++ b/Lib/heapq.py @@ -176,6 +176,16 @@ for i in reversed(range(n//2)): _siftup(x, i) +def _heappop_max(heap): + """Maxheap version of a heappop.""" + lastelt = heap.pop() # raises appropriate IndexError if heap is empty + if heap: + returnitem = heap[0] + heap[0] = lastelt + _siftup_max(heap, 0) + return returnitem + return lastelt + def _heapreplace_max(heap, item): """Maxheap version of a heappop followed by a heappush.""" returnitem = heap[0] # raises appropriate IndexError if heap is empty @@ -311,7 +321,7 @@ except ImportError: pass -def merge(*iterables): +def merge(*iterables, key=None, reverse=False): '''Merge multiple sorted inputs into a single sorted output. Similar to sorted(itertools.chain(*iterables)) but returns a generator, @@ -321,31 +331,73 @@ >>> list(merge([1,3,5,7], [0,2,4,8], [5,10,15,20], [], [25])) [0, 1, 2, 3, 4, 5, 5, 7, 8, 10, 15, 20, 25] + If *key* is not None, applies a key function to each element to determine + its sort order. + + >>> list(merge(['dog', 'horse'], ['cat', 'fish', 'kangaroo'], key=len)) + ['dog', 'cat', 'fish', 'horse', 'kangaroo'] + ''' h = [] h_append = h.append + + if reverse: + _heapify = _heapify_max + _heappop = _heappop_max + _heapreplace = _heapreplace_max + direction = -1 + else: + _heapify = heapify + _heappop = heappop + _heapreplace = heapreplace + direction = 1 + + if key is None: + for order, it in enumerate(map(iter, iterables)): + try: + next = it.__next__ + h_append([next(), order * direction, next]) + except StopIteration: + pass + _heapify(h) + while len(h) > 1: + try: + while True: + value, order, next = s = h[0] + yield value + s[0] = next() # raises StopIteration when exhausted + _heapreplace(h, s) # restore heap condition + except StopIteration: + _heappop(h) # remove empty iterator + if h: + # fast case when only a single iterator remains + value, order, next = h[0] + yield value + yield from next.__self__ + return + for order, it in enumerate(map(iter, iterables)): try: next = it.__next__ - h_append([next(), order, next]) + value = next() + h_append([key(value), order * direction, value, next]) except StopIteration: pass - heapify(h) - - _heapreplace = heapreplace + _heapify(h) while len(h) > 1: try: while True: - value, order, next = s = h[0] + key_value, order, value, next = s = h[0] yield value - s[0] = next() # raises StopIteration when exhausted - _heapreplace(h, s) # restore heap condition + value = next() + s[0] = key(value) + s[2] = value + _heapreplace(h, s) except StopIteration: - heappop(h) # remove empty iterator + _heappop(h) if h: - # fast case when only a single iterator remains - value, order, next = h[0] + key_value, order, value, next = h[0] yield value yield from next.__self__ diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py --- a/Lib/test/test_heapq.py +++ b/Lib/test/test_heapq.py @@ -6,6 +6,7 @@ from test import support from unittest import TestCase, skipUnless +from operator import itemgetter py_heapq = support.import_fresh_module('heapq', blocked=['_heapq']) c_heapq = support.import_fresh_module('heapq', fresh=['_heapq']) @@ -152,11 +153,20 @@ def test_merge(self): inputs = [] - for i in range(random.randrange(5)): - row = sorted(random.randrange(1000) for j in range(random.randrange(10))) + for i in range(random.randrange(20)): + row = [] + for j in range(random.randrange(100)): + tup = random.choice('ABC'), random.randrange(-100, 100) + row.append(tup) inputs.append(row) - self.assertEqual(sorted(chain(*inputs)), list(self.module.merge(*inputs))) - self.assertEqual(list(self.module.merge()), []) + + for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]: + for reverse in [False, True]: + for seq in inputs: + seq.sort(key=key, reverse=reverse) + self.assertEqual(sorted(chain(*inputs), key=key, reverse=reverse), + list(self.module.merge(*inputs, key=key, reverse=reverse))) + self.assertEqual(list(self.module.merge()), []) def test_merge_does_not_suppress_index_error(self): # Issue 19018: Heapq.merge suppresses IndexError from user generator