from pprint import pprint
from collections import defaultdict
from itertools import *  # kept for backward compatibility with existing users of this module


def topological_sort(iterable):
    """Merge several partial orderings into one topologically sorted list.

    topological_sort(iterable_of_iterables) --> list

    * Input can be a list of pairs or of longer sequences to be merged.
      Each inner iterable says "these items appear in this order"; every
      adjacent pair (p, s) adds the constraint "p comes before s".  Inner
      iterables are consumed in a single pass, so strings, tuples, lists
      and generators all work.
    * The sort is stable -- ties are resolved by first seen.
    * For acyclic input the result contains every distinct node exactly
      once.
    * Work in setup phase: O(sum(seq_lens)).
    * Work in visit phase: O(E) + O(N).

    Nodes must be hashable.  Any hashable value, including None, is a
    valid node.

    Raises:
        ValueError: if the constraints contain a cycle (including a node
            that directly follows itself), since no valid order exists.

    Open ideas carried over from the original notes:
        * convert the output from a list to a generator
        * accept a graph (dict of lists) as an alternate input
        * possibly put this in the itertools module
    """
    # --- Setup phase -----------------------------------------------------
    preds = defaultdict(list)   # node -> predecessor nodes, in order encountered
    ordered_nodes = []          # every distinct node, in order first encountered
    seen = set()                # membership test backing ordered_nodes
    has_successor = set()       # nodes that are a predecessor of some node
    no_prev = object()          # unique marker: "no previous element yet"

    for seq in iterable:
        prev = no_prev
        for node in seq:
            if prev is not no_prev:
                preds[node].append(prev)
                has_successor.add(prev)
            if node not in seen:
                seen.add(node)
                ordered_nodes.append(node)
            prev = node

    # Traversal starts from nodes that nothing depends on (the "last" nodes)
    # and walks backwards through predecessors.
    start_points = [n for n in ordered_nodes if n not in has_successor]

    # --- Visit phase -----------------------------------------------------
    # Iterative post-order depth-first walk over predecessor edges: a node is
    # emitted only after all of its predecessors have been emitted.  Every
    # node is emitted once and every edge is followed once.
    result = []
    visited = set()   # nodes already emitted
    on_path = set()   # nodes currently on the explicit stack (cycle detection)

    for start in start_points:
        on_path.add(start)
        stack = [(start, iter(preds.get(start, ())))]
        while stack:
            node, pending = stack[-1]
            # The iterator keeps its position between visits to this frame,
            # so each predecessor edge is examined exactly once.
            for pred in pending:
                if pred in visited:
                    continue
                if pred in on_path:
                    raise ValueError(
                        'cycle detected involving node {0!r}'.format(pred))
                on_path.add(pred)
                stack.append((pred, iter(preds.get(pred, ()))))
                break
            else:
                # All predecessors emitted: this node can be emitted now.
                stack.pop()
                on_path.discard(node)
                visited.add(node)
                result.append(node)

    # A node that was never reached has no path to any start point, which
    # is only possible if it sits on (or leads only into) a cycle.
    if len(result) != len(ordered_nodes):
        raise ValueError('cycle detected: no valid ordering exists')

    return result


if __name__ == '__main__':
    # Single-argument print calls produce identical output on Python 2 and 3.
    print(topological_sort(['ABDGI', 'BEG', 'CEH', 'KCFHJ']))
    target = ['A', 'B', 'D', 'K', 'C', 'E', 'G', 'I', 'F', 'H', 'J']
    print('%s <-- target' % (target,))