class _MergeNode:
    """Node of the tournament tree that drives ``merge``.

    Binary tree invariants:
    - A node N is a leaf iff N.leaf is N
    - For each leaf node:
        - leaf.right is an iterator
        - leaf.left is the item most recently produced by leaf.right
        - leaf.key is key(leaf.left) (or leaf.left itself when no key
          function is in use)
    - For each non-leaf node:
        - node.left and node.right are Nodes
        - node.left.parent is node is node.right.parent
        - node.leaf is one of node's descendant leaves
        - node.key is node.leaf.key
        - if "winner" is the higher priority of node.left and
          node.right based on their keys, then node.leaf is
          winner.leaf and node.key is winner.key.
    """

    __slots__ = "key", "leaf", "parent", "left", "right"

    @classmethod
    def construct_leaf(cls, iterator, keyfunc):
        """Build a leaf primed with the first item of *iterator*.

        Returns None when *iterator* yields nothing, so that empty
        inputs never enter the tree.  *keyfunc* may be None, in which
        case the item itself is used as its key.
        """
        it = iter(iterator)
        try:
            item = next(it)
        except StopIteration:
            return None
        node = cls()
        node.key = item if keyfunc is None else keyfunc(item)
        node.left = item    # current item of this input
        node.right = it     # the rest of this input
        node.parent = None
        node.leaf = node    # a leaf is its own winning leaf
        return node

    @classmethod
    def construct_parent(cls, left, right, reverse):
        """Join *left* and *right* under a new parent and return it.

        The parent records the winning child's key and leaf.  The right
        child wins only when it is strictly better (strictly smaller,
        or strictly larger when *reverse* is true), so ties go left.
        """
        if reverse:
            winner = right if left.key < right.key else left
        else:
            winner = right if right.key < left.key else left
        node = cls()
        node.key = winner.key
        node.left = left
        node.right = right
        node.parent = None
        node.leaf = winner.leaf
        left.parent = right.parent = node
        return node


def merge(*iterables, key=None, reverse=False):
    """Merge multiple sorted inputs into a single sorted output.

    Similar to sorted(itertools.chain(*iterables)) but returns a generator,
    does not pull the data into memory all at once, and assumes that each of
    the input streams is already sorted (smallest to largest).

    >>> list(merge([1,3,5,7], [0,2,4,8], [5,10,15,20], [], [25]))
    [0, 1, 2, 3, 4, 5, 5, 7, 8, 10, 15, 20, 25]

    If *key* is not None, applies a key function to each element to determine
    its sort order.

    >>> list(merge(['dog', 'horse'], ['cat', 'fish', 'kangaroo'], key=len))
    ['dog', 'cat', 'fish', 'horse', 'kangaroo']

    If *reverse* is true, the comparisons are inverted: each input is then
    expected to be sorted from largest to smallest, and so is the output.

    Items whose keys compare equal are produced in the order of the
    iterables they came from (earlier iterables first).
    """
    # One leaf per non-empty input; empty inputs are dropped up front.
    nodes = []
    for it in iterables:
        leaf = _MergeNode.construct_leaf(it, key)
        if leaf is not None:
            nodes.append(leaf)
    n = len(nodes)
    if not nodes:
        return
    if n == 1:
        # A lone input needs no comparisons, so never call the key function.
        key = None

    # unite pairs of adjacent nodes with a common parent until all nodes
    # are united into one big tree.
    while n > 1:
        new_nodes = nodes[:n & 1]  # leave unpaired nodes to the left
        for i in range(n & 1, n - 1, 2):
            left, right = nodes[i:i + 2]
            parent = _MergeNode.construct_parent(left, right, reverse)
            new_nodes.append(parent)
        nodes = new_nodes
        n = len(nodes)
    (root,) = nodes

    # Local aliases for the names used on every iteration of the main loop.
    _StopIteration = StopIteration
    _next = next

    while True:
        # To find the value to yield, check which leaf
        # the root's key came from.
        node = root.leaf
        yield node.left
        try:
            node.left = _next(node.right)
        except _StopIteration:
            # When a leaf is exhausted, move its sibling up to where
            # its parent is now.
            parent = node.parent
            if parent is None:
                # The exhausted leaf was the root: nothing is left to merge.
                return
            left = parent.left
            right = parent.right
            sibling = left if right is node else right
            parent.left = sibling.left
            parent.right = sibling.right
            parent.key = sibling.key
            if sibling.leaf is sibling:
                # sibling was a leaf, so now parent becomes a leaf
                parent.leaf = parent
                # Fast out (don't compute the keys).
                if parent is root:
                    key = None
            else:
                parent.leaf = sibling.leaf
                # Re-home the sibling's children under the reused parent.
                sibling.left.parent = sibling.right.parent = parent
            node = parent
        else:
            # Item successfully produced by iterator.
            if key is None:
                # use the value as the key.
                node.key = node.left
            else:
                node.key = key(node.left)

        # Replay the matches on the path from the changed node to the root.
        if reverse:
            while node is not root:
                node = node.parent
                left, right = node.left, node.right
                winner = right if left.key < right.key else left
                node.leaf = winner.leaf
                node.key = winner.key
        else:
            while node is not root:
                node = node.parent
                left, right = node.left, node.right
                winner = right if right.key < left.key else left
                node.leaf = winner.leaf
                node.key = winner.key


def _self_test(rounds=500):
    """Compare ``merge`` against ``sorted`` on small random inputs.

    Cycles through key=None / key=abs and reverse False / True.  Prints
    the offending inputs and stops at the first mismatch; re-raises any
    exception after printing the inputs that triggered it.
    """
    from itertools import chain
    from random import randrange as rr

    for i in range(rounds):
        key = [None, abs][i & 1]
        reverse = bool(i & 2)
        its = [
            sorted((rr(-5, 6) for _ in range(rr(7))), key=key, reverse=reverse)
            for _ in range(rr(7))
        ]
        try:
            expected = sorted(chain.from_iterable(its), key=key, reverse=reverse)
            actual = list(merge(*its, key=key, reverse=reverse))
        except Exception:
            # Show the failing inputs, then let the error propagate.
            print(f"{its = }")
            raise
        if expected != actual:
            print(f"{its = }")
            print(f"{expected = }")
            print(f"{actual = }")
            break
        # NOTE(review): assumed to be per-iteration progress output; confirm
        # it was not meant to run once after the loop.
        print(i)


if __name__ == "__main__":
    _self_test()