"""merge() implemented with a binary tree of iterators.

An alternative to heapq.merge(): each input iterator sits at a leaf of a
balanced binary tree, and every internal node caches the smallest value in
its subtree together with the leaf that produced it.  After an item is
yielded only the path from that leaf to the root is recomputed, and
exhausted leaves are spliced out of the tree.
"""


class Node:
    """A node of the binary tree used by merge().

    Invariants:

    * top node: ``parent`` is None
    * leaf node: ``left`` is None, ``leaf`` is self, ``right`` is the
      zero-argument ``__next__`` callable of that leaf's iterator
    * non-leaf node: ``left`` and ``right`` are Nodes, ``leaf`` is the leaf
      Node whose value is cached in ``value`` (the smaller of the two
      children's values, the left child's on ties)
    """

    __slots__ = 'value', 'leaf', 'parent', 'left', 'right'

    def __init__(self, value, leaf=None, left=None, right=None):
        self.value = value
        self.leaf = self if leaf is None else leaf
        self.left = left
        self.right = right
        self.parent = None

    @classmethod
    def treeify(cls, data):
        """Build a balanced tree over *data* and return its root node.

        *data* is a sequence of ``(value, next_callable)`` pairs, one per
        input, in input order.  When the count is odd the left subtree gets
        the extra pair.

        Raises ValueError if *data* is empty.
        """
        n = len(data)
        if n == 0:
            # Without this guard (0 + 1) // 2 == 0 and the recursion below
            # would never shrink its input.
            raise ValueError(
                "treeify() requires at least one (value, next) pair")
        if n == 1:
            value, nxt = data[0]
            return cls(value, right=nxt)
        mid = (n + 1) // 2
        c1 = cls.treeify(data[:mid])
        c2 = cls.treeify(data[mid:])
        # Ties favour c1 so items from earlier inputs come out first.
        selection = c2 if c2.value < c1.value else c1
        node = cls(selection.value, selection.leaf, c1, c2)
        c1.parent = c2.parent = node
        return node


class _Keyed:
    """Pair an item with its sort key; ``<`` compares only the keys."""

    __slots__ = 'key', 'value'

    def __init__(self, key, value):
        self.key = key
        self.value = value

    def __lt__(self, other):
        return self.key < other.key


class _ReversedKeyed(_Keyed):
    """Like _Keyed, but ``<`` is true for the *larger* key (descending)."""

    __slots__ = ()

    def __lt__(self, other):
        return other.key < self.key


def merge(*iterables, key=None, reverse=False):
    """Merge already-sorted inputs into a single sorted iterator.

    Works like heapq.merge(): every input must already be sorted (by *key*
    if given, and from largest to smallest when *reverse* is true).  The
    merge is stable, so equal items are produced in the order of the inputs
    they came from.  Only ``<`` is used to compare values (or keys).

    *key*, if not None, is a one-argument function that extracts the
    comparison key from each item; it is called once per item.
    *reverse*, if true, merges inputs sorted in descending order.
    """
    if key is not None or reverse:
        # Decorate every item so the plain merge below orders by key
        # (descending when reverse=True), then strip the decoration.
        wrap = _ReversedKeyed if reverse else _Keyed
        if key is None:
            decorated = [(wrap(v, v) for v in it) for it in iterables]
        else:
            decorated = [(wrap(key(v), v) for v in it) for it in iterables]
        for item in merge(*decorated):
            yield item.value
        return

    # Get the first value from each input; empty inputs are dropped.  The
    # bound __next__ is stored so later refills avoid next()'s overhead.
    sentinel = object()
    pairs = []
    for iterable in iterables:
        it = iter(iterable)
        value = next(it, sentinel)
        if value is not sentinel:
            pairs.append((value, it.__next__))
    if not pairs:
        return
    if len(pairs) == 1:
        # Fast case: a single input needs no comparisons.  nxt is a bound
        # __next__ method, so nxt.__self__ is the iterator itself.
        value, nxt = pairs[0]
        yield value
        yield from nxt.__self__
        return

    # Build the binary tree.
    node = Node.treeify(pairs)

    while True:
        # The root caches the smallest pending value.
        yield node.value

        # Jump straight to the leaf that supplied it.
        node = node.leaf

        # Refill that leaf from its iterator.
        try:
            node.value = node.right()
        except StopIteration:
            # The leaf is exhausted: splice it out by moving its sibling's
            # contents up into their shared parent.
            parent = node.parent
            left, right = parent.left, parent.right
            sibling = left if right is node else right
            parent.value = sibling.value
            parent.left = sibling.left
            parent.right = sibling.right
            if sibling.left is None:
                # The sibling was a leaf, so the parent becomes a leaf.
                parent.leaf = parent
                # If that new leaf is the root, only one input remains:
                # finish with the fast case.
                if parent.parent is None:
                    yield parent.value
                    yield from parent.right.__self__
                    return
            else:
                parent.leaf = sibling.leaf
                # The adopted children must point at their new parent.
                sibling.left.parent = sibling.right.parent = parent
            node = parent

        # Recompute the cached minimum on every node from here to the root.
        while (parent := node.parent) is not None:
            node = parent
            c1, c2 = node.left, node.right
            selection = c2 if c2.value < c1.value else c1
            node.value = selection.value
            node.leaf = selection.leaf


if __name__ == '__main__':
    # Randomised self-check of merge() against sorted().
    from itertools import chain
    from random import randrange as rr

    for i in range(10_000):
        its = [sorted(rr(10) for _ in range(rr(10))) for _ in range(rr(10))]
        expected = sorted(chain.from_iterable(its))
        actual = list(merge(*its))
        if expected != actual:
            print(f"{its=}")
            print(f"{expected=}")
            print(f"{actual=}")
            break
    print(i)