"""Build a large tree of asynq coroutines over two toy batches and run it.

``DoubleBatch`` and ``SquareBatch`` resolve keys by doubling / squaring them.
The coroutines below chain and fan out requests against those batches, and
``root`` schedules ``big_batch`` for 5000 consecutive inputs.
"""

import asynq


def sync_wait_for(future_or_coroutine):
    """Return the result of ``future_or_coroutine.value()``."""
    return future_or_coroutine.value()


@asynq.asynq(pure=True)
def gather(*awaitables):
    """Yield all ``awaitables`` together as one tuple and return the reply.

    NOTE(review): presumably asynq answers a yielded tuple with a tuple of
    the corresponding results, in the same order -- confirm against asynq.
    """
    return (yield tuple(awaitables))


class Batch(asynq.BatchBase):
    """Base class for key-based batches with one current instance per subclass.

    Each subclass gets its own ``instance`` attribute holding the batch that
    newly generated items are attached to. Subclasses implement
    ``resolve_futures`` to turn a list of keys into a list of values.
    """

    def __init_subclass__(cls, **kwargs):
        """Give every subclass its own initial batch instance."""
        # Chain to the parent hook so class keyword arguments and any other
        # cooperative ``__init_subclass__`` in the MRO are honoured instead
        # of being silently swallowed.
        super().__init_subclass__(**kwargs)
        cls.instance = cls()

    @staticmethod
    def resolve_futures(batch):
        """Compute the values for a list of keys; must be overridden.

        Args:
            batch: the list of item keys collected by ``_flush`` (the
                parameter holds keys, despite its name).

        Returns:
            An iterable of values, one per key, in the same order.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    @classmethod
    def _try_switch_active_batch(cls):
        """Replace ``cls.instance`` with a fresh, empty batch.

        NOTE(review): presumably invoked by asynq around flushing so that
        later items land in a new batch -- confirm against asynq.BatchBase.
        """
        cls.instance = cls()

    def _flush(self):
        """Resolve every queued item through a single ``resolve_futures`` call."""
        keys = [item.key for item in self.items]
        results = self.resolve_futures(keys)
        # Items and results are paired positionally; ``zip`` stops at the
        # shorter of the two, so this method assigns nothing past that point.
        for future, value in zip(self.items, results):
            future.set_value(value)

    @classmethod
    def gen(cls, key):
        """Return a new ``BatchItem`` for ``key`` on the current batch instance."""
        return BatchItem(cls.instance, key)

    @classmethod
    def genv(cls, keys):
        """Return a ``gather`` over one ``gen`` item per key in ``keys``."""
        return gather(*[cls.gen(key) for key in keys])


class BatchItem(asynq.BatchItemBase):
    """A batch item that remembers the key it was requested for."""

    def __init__(self, batch, key):
        """Attach the item to ``batch`` and store ``key`` for ``_flush``."""
        super().__init__(batch)
        self.key = key


class DoubleBatch(Batch):
    """Batch whose value for key ``x`` is ``x + x``."""

    @staticmethod
    def resolve_futures(batch):
        """Return ``x + x`` for every key in ``batch``."""
        return [x+x for x in batch]


class SquareBatch(Batch):
    """Batch whose value for key ``x`` is ``x * x``."""

    @staticmethod
    def resolve_futures(batch):
        """Return ``x * x`` for every key in ``batch``."""
        return [x*x for x in batch]


@asynq.asynq(pure=True)
def double_square(x):
    """Request ``x`` from ``DoubleBatch``, then that result from ``SquareBatch``."""
    double = yield DoubleBatch.gen(x)
    square = yield SquareBatch.gen(double)
    return square


@asynq.asynq(pure=True)
def square_double(x):
    """Request ``x`` from ``SquareBatch``, then that result from ``DoubleBatch``."""
    square = yield SquareBatch.gen(x)
    double = yield DoubleBatch.gen(square)
    return double


@asynq.asynq(pure=True)
def triple_double(x):
    """Pass ``x`` through ``DoubleBatch`` three times in sequence."""
    d1 = yield DoubleBatch.gen(x)
    d2 = yield DoubleBatch.gen(d1)
    d3 = yield DoubleBatch.gen(d2)
    return d3


@asynq.asynq(pure=True)
def double_square_square_double(x):
    """Run ``double_square`` on ``x`` and feed its result to ``square_double``."""
    ds = yield double_square(x)
    sd = yield square_double(ds)
    return sd


@asynq.asynq(pure=True)
def big_batch(x):
    """Yield six groups of requests derived from ``x`` in a single step.

    The groups mix nested coroutines (via ``gather``) with direct multi-key
    batch requests (via ``genv``); the reply to that one yield is returned.
    """
    return (yield (
        gather(
            square_double(x + 0),
            square_double(x + 1),
            square_double(x + 2),
        ),
        gather(
            double_square(x * 1),
            double_square(x * 2),
            double_square(x * 3),
        ),
        DoubleBatch.genv(range(x, x + 5, +1)),
        SquareBatch.genv(range(x, x - 5, -1)),
        gather(
            triple_double(x - 1),
            triple_double(x - 2),
            triple_double(x - 3),
        ),
        gather(
            double_square_square_double(x * 5),
            double_square_square_double(x * 7),
            double_square_square_double(x * 9),
        ),
    ))


@asynq.asynq(pure=True)
def root():
    """Yield ``big_batch(x)`` for every ``x`` in ``range(0, 5000)`` at once."""
    return (yield [
        big_batch(x) for x in range(0, 5000)
    ])


if __name__ == '__main__':
    sync_wait_for(root())