# Sentinel values stored in _NodeInfo.npred once a node stops waiting on
# its predecessors.
_NODE_OUT = -1   # handed out by getreadytodo(), done() not yet called
_NODE_DONE = -2  # done() has been called for the node


class _NodeInfo:
    """Per-node bookkeeping record used by TS."""

    __slots__ = "node", "npred", "succs"

    def __init__(self, node):
        self.node = node
        # Number of predecessors not yet done (>= 0).  Replaced by
        # _NODE_OUT when the node is handed out by getreadytodo(), and by
        # _NODE_DONE when done() is called for it.
        self.npred = 0
        # Successor nodes.  Duplicates are harmless as long as every
        # occurrence is matched by one unit in that successor's npred.
        self.succs = []


class CycleError(ValueError):
    """Raised by TS.prepare() when the dependency graph contains a cycle.

    ``args[1]`` is the cycle as a list of nodes whose first and last
    entries are the same node.
    """


class TS:
    """Topological sorter that releases nodes as their dependencies finish.

    Typical use: declare the graph with add(), call prepare() once, then
    loop while isactive(): take the nodes returned by getreadytodo(),
    process them (possibly in parallel), and report each with done().
    """

    def __init__(self):
        self.node2info = {}     # node -> _NodeInfo
        self.readytodo = None   # list of ready nodes; None until prepare()
        self.npassedout = 0     # total nodes returned by getreadytodo()
        self.nfinished = 0      # total nodes reported through done()

    def _get_nodeinfo(self, node):
        """Return the _NodeInfo for *node*, creating it on first sight."""
        info = self.node2info.get(node)
        if info is None:
            info = _NodeInfo(node)
            self.node2info[node] = info
        return info

    def add(self, node, *dependson):
        """Declare *node* as depending on every node in *dependson*.

        Nodes are created implicitly, and repeated calls for the same node
        accumulate dependencies.

        Raises:
            ValueError: if called after prepare().
        """
        if self.readytodo is not None:
            raise ValueError("can't add nodes after prepare()")
        # Create *node* before its predecessors: node2info insertion order
        # determines the order of the initial ready list.
        self._get_nodeinfo(node).npred += len(dependson)
        for pred in dependson:
            self._get_nodeinfo(pred).succs.append(node)

    def prepare(self):
        """Freeze the graph, seed the ready list, and check for cycles.

        Raises:
            ValueError: if called more than once.
            CycleError: if the graph contains a cycle.
        """
        if self.readytodo is not None:
            raise ValueError("can't prepare() more than once")
        # The ready list is seeded *before* the cycle check on purpose: a
        # caller that catches CycleError can keep using the instance to
        # pull every node that the cycle does not block.
        self.readytodo = [
            info.node for info in self.node2info.values() if info.npred == 0
        ]
        cycle = self._find_cycle()
        if cycle:
            raise CycleError("nodes are in a cycle", cycle)

    def getreadytodo(self):
        """Hand out every currently ready node as a tuple.

        Each returned node is marked as handed out and must later be
        reported with done().  The tuple is empty if nothing is ready.

        Raises:
            ValueError: if prepare() has not been called.
        """
        if self.readytodo is None:
            raise ValueError("must call prepare() first")
        batch = tuple(self.readytodo)
        del self.readytodo[:]
        infos = self.node2info
        for node in batch:
            infos[node].npred = _NODE_OUT
        self.npassedout += len(batch)
        return batch

    def isactive(self):
        """Return True while any node is ready, or handed out but not done.

        Raises:
            ValueError: if prepare() has not been called.
        """
        if self.readytodo is None:
            raise ValueError("must call prepare() first")
        assert self.nfinished <= self.npassedout
        if self.readytodo:
            return True
        return self.nfinished < self.npassedout

    def done(self, node):
        """Mark a previously handed-out *node* as finished.

        Each successor for which this was the last unfinished predecessor
        is appended to the ready list for the next getreadytodo().

        Raises:
            ValueError: if prepare() has not been called, or if *node* is
                unknown, was never handed out, or was already marked done.
        """
        if self.readytodo is None:
            raise ValueError("must call prepare() first")
        infos = self.node2info
        info = infos.get(node)
        if info is None:
            raise ValueError(f"don't know about node {node!r}")
        stat = info.npred
        if stat == _NODE_DONE:
            raise ValueError(f"node {node!r} already marked done")
        if stat >= 0:
            raise ValueError(f"node {node!r} wasn't passed out")
        if stat != _NODE_OUT:
            raise ValueError(f"node {node!r} unknown status {stat}")

        # TODO? move the rest into a remove_edge() helper, which could
        # also be used to break cycles.
        info.npred = _NODE_DONE
        ready = self.readytodo
        for succ in info.succs:
            succinfo = infos[succ]
            assert succinfo.npred > 0
            succinfo.npred -= 1
            if not succinfo.npred:
                ready.append(succ)
        self.nfinished += 1

    def _find_cycle(self):
        """Return one cycle as a list of nodes (first == last), or [].

        This is an iterative depth-first search, so it does not recurse
        on deep graphs.
        """
        infos = self.node2info
        path = []           # nodes on the current DFS path
        nexts = []          # parallel to path: __next__ of each successor iterator
        pos_on_path = {}    # node -> its index in path
        visited = set()

        for start in infos:
            if start in visited:
                continue
            node = start
            while True:
                if node not in visited:
                    # First visit: push the node onto the DFS path.
                    visited.add(node)
                    nexts.append(iter(infos[node].succs).__next__)
                    pos_on_path[node] = len(path)
                    path.append(node)
                elif node in pos_on_path:
                    # We reached a node that is still on the path: a cycle.
                    return path[pos_on_path[node]:] + [node]
                # Otherwise the node was fully explored earlier; just move
                # on to the next successor.

                # Backtrack to the deepest path entry that still has an
                # unexplored successor.
                while path:
                    try:
                        node = nexts[-1]()
                    except StopIteration:
                        del pos_on_path[path.pop()]
                        nexts.pop()
                    else:
                        break
                else:
                    # The path is empty: everything reachable from *start*
                    # has been explored.
                    break

        assert not path
        assert not nexts
        assert not pos_on_path
        return []


import queue

# Flip to 1 to run the demo with threads instead of processes.
if 0:
    import threading
    workermaker = threading.Thread
    queuemaker = queue.Queue
else:
    import multiprocessing
    workermaker = multiprocessing.Process
    queuemaker = multiprocessing.Queue


def worker(inq, outq):
    """Copy items from *inq* to *outq* until a None sentinel arrives.

    The sentinel is put back on *inq* so that every other worker sharing
    the queue also sees it and stops.
    """
    while True:
        item = inq.get()
        if item is None:
            break
        outq.put(item)
    inq.put(None)


if __name__ == "__main__":
    inq = queuemaker()
    outq = queuemaker()
    NWORKERS = 5
    ths = [workermaker(target=worker, args=(inq, outq))
           for _ in range(NWORKERS)]
    for th in ths:
        th.start()

    # Graph on nodes 0..N: each even i < N waits on i+1 and i+2, so every
    # odd node (and N itself) is ready right away.
    t = TS()
    N = 100000
    for i in range(0, N, 2):
        t.add(i, i + 1, i + 2)
    t.prepare()

    allofem = []
    if 1:
        # Simple strategy: ship everything that is ready, then wait for
        # exactly one result before asking the sorter again.
        while t.isactive():
            for node in t.getreadytodo():
                inq.put(node)
            node = outq.get()
            t.done(node)
            allofem.append(node)
    else:
        # Throttled strategy: keep at most 200 nodes in flight, and after
        # the first (blocking) result drain the rest without blocking.
        pending = []
        outstanding = 0
        while t.isactive():
            pending.extend(t.getreadytodo())
            while pending and outstanding < 200:
                inq.put(pending.pop())
                outstanding += 1
            assert outstanding > 0
            block = True
            try:
                while outstanding:
                    node = outq.get(block=block)
                    outstanding -= 1
                    allofem.append(node)
                    t.done(node)
                    if pending:
                        inq.put(pending.pop())
                        outstanding += 1
                    block = False
            except queue.Empty:
                pass
        assert not outstanding

    # One sentinel stops all workers, because each one re-posts it.
    inq.put(None)
    for th in ths:
        th.join()
    # The last worker to exit leaves the sentinel behind on inq.
    node = inq.get()
    assert node is None
    assert sorted(allofem) == list(range(N + 1))