diff --git a/Lib/heapq.py b/Lib/heapq.py --- a/Lib/heapq.py +++ b/Lib/heapq.py @@ -311,7 +311,7 @@ except ImportError: pass -def merge(*iterables): +def merge(*iterables, key=None): '''Merge multiple sorted inputs into a single sorted output. Similar to sorted(itertools.chain(*iterables)) but returns a generator, @@ -321,31 +321,63 @@ >>> list(merge([1,3,5,7], [0,2,4,8], [5,10,15,20], [], [25])) [0, 1, 2, 3, 4, 5, 5, 7, 8, 10, 15, 20, 25] + If *key* is not None, applies a key function to each element to determine + its sort order. + + >>> list(merge(['dog', 'horse'], ['cat', 'fish', 'kangaroo'], key=len)) + ['dog', 'cat', 'fish', 'horse', 'kangaroo'] + ''' h = [] h_append = h.append + _heapreplace = heapreplace + + if key is None: + for order, it in enumerate(map(iter, iterables)): + try: + next = it.__next__ + h_append([next(), order, next]) + except StopIteration: + pass + heapify(h) + while len(h) > 1: + try: + while True: + value, order, next = s = h[0] + yield value + s[0] = next() # raises StopIteration when exhausted + _heapreplace(h, s) # restore heap condition + except StopIteration: + heappop(h) # remove empty iterator + if h: + # fast case when only a single iterator remains + value, order, next = h[0] + yield value + yield from next.__self__ + return + for order, it in enumerate(map(iter, iterables)): try: next = it.__next__ - h_append([next(), order, next]) + value = next() + h_append([key(value), order, value, next]) except StopIteration: pass heapify(h) - - _heapreplace = heapreplace while len(h) > 1: try: while True: - value, order, next = s = h[0] + key_value, order, value, next = s = h[0] yield value - s[0] = next() # raises StopIteration when exhausted - _heapreplace(h, s) # restore heap condition + value = next() + s[0] = key(value) + s[2] = value + _heapreplace(h, s) except StopIteration: - heappop(h) # remove empty iterator + heappop(h) if h: - # fast case when only a single iterator remains - value, order, next = h[0] + key_value, order, value, next = h[0] yield value yield from next.__self__ diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py --- a/Lib/test/test_heapq.py +++ b/Lib/test/test_heapq.py @@ -158,6 +158,17 @@ self.assertEqual(sorted(chain(*inputs)), list(self.module.merge(*inputs))) self.assertEqual(list(self.module.merge()), []) + def test_merge_with_key_function(self): + inputs = [] + for i in range(random.randrange(5)): + row = list(random.randrange(-1000, 1000) + for j in range(random.randrange(10))) + row.sort(key=abs) + inputs.append(row) + self.assertEqual(sorted(chain(*inputs), key=abs), + list(self.module.merge(*inputs, key=abs))) + self.assertEqual(list(self.module.merge()), []) + def test_merge_does_not_suppress_index_error(self): # Issue 19018: Heapq.merge suppresses IndexError from user generator def iterable():