"""Demo data loader: a scheduler feeds I/O jobs to a thread pool, and each job
offloads its CPU-bound part to a process pool."""

import multiprocessing
import os
import queue
import signal
import threading
import traceback


class ThreadPool:
    """Fixed-size pool of worker threads fed through a bounded FIFO queue.

    Jobs are zero-argument callables. A ``None`` item on the queue is the
    shutdown sentinel: each worker exits when it dequeues one.
    """

    def __init__(self):
        self.thrs = []
        # cpu_count() // 2 is 0 on a single-CPU host; a pool with no workers
        # would accept jobs and never run them, so keep at least one.
        self.min_size = max(1, multiprocessing.cpu_count() // 2)
        self.work_queue = queue.Queue(256)
        for _ in range(self.min_size):
            self.thrs.append(threading.Thread(target=self.run))
        self._lock = threading.Lock()
        self._started = False
        self._stopped = False
        # Copied onto every worker thread when start() is called.
        self.daemon = True

    def start(self) -> None:
        """Start all worker threads; calls after the first are no-ops."""
        with self._lock:
            if self._started:
                return
            self._started = True
            # Start under the lock so tear_down() never observes
            # _started=True while some threads are still unstarted.
            for thr in self.thrs:
                thr.daemon = self.daemon
                thr.start()
        print("ThreadPool started.")

    def tear_down(self) -> None:
        """Ask every worker to exit; join them when the pool is non-daemon.

        Idempotent. Jobs already queued ahead of the sentinels are still
        executed by the workers before they exit.
        """
        with self._lock:
            if self._stopped:
                return
            self._stopped = True
            was_started = self._started

        # One sentinel per worker.
        for _ in self.thrs:
            self.work_queue.put(None)

        if not self.daemon:
            print("Wait for threads to stop.")
            # Joining a thread that was never started raises RuntimeError.
            if was_started:
                for thr in self.thrs:
                    thr.join()
        print("ThreadPool stopped.")

    def enqueue_jobs(self, jobs, block=True) -> None:
        """Put each job on the work queue.

        With ``block=False`` a full queue raises ``queue.Full``.
        """
        for job in jobs:
            self.work_queue.put(job, block)

    def run(self) -> None:
        """Worker loop: execute queued jobs until a ``None`` sentinel arrives."""
        work_queue = self.work_queue
        thread_name = threading.current_thread().name
        while True:
            print(f"Executing to get job for {thread_name}")
            job = work_queue.get()
            if job is None:
                break
            # tear_down() enqueues exactly one sentinel per worker, so a
            # worker must not add more: extra blocking puts on the bounded
            # queue can leave every worker stuck in put() when it is full.
            try:
                job()
            except Exception:
                # A failing job must not kill the worker thread.
                print(traceback.format_exc())
            qsize = work_queue.qsize()
            print(f"Thread work_queue_size={qsize}")
        print(f"Worker thread {thread_name} stopped.")


class ProcessPool:
    """Thin wrapper around ``multiprocessing.Pool`` with an idempotent teardown."""

    def __init__(self):
        # Pool(processes=0) raises ValueError, so keep at least one process.
        self.size = max(1, multiprocessing.cpu_count() // 2)
        self._pool = multiprocessing.Pool(processes=self.size, maxtasksperchild=3)
        self._stopped = False

    def tear_down(self) -> None:
        """Close the pool and wait for its worker processes to exit."""
        if self._stopped:
            return
        self._stopped = True
        print("ProcessPool teardown started...")
        self._pool.close()
        print("ProcessPool closed.")
        self._pool.join()
        print("ProcessPool stopped.")

    def apply(self, func, args=(), kwargs=None):
        """Run ``func(*args, **kwargs)`` in a pool process and return its result.

        Blocks until the result is available. Returns ``None`` without
        running anything once the pool has been torn down.
        """
        if self._stopped:
            return None
        return self._pool.apply(func, args, {} if kwargs is None else kwargs)


class Manager:
    """Polls the scheduler for ready jobs and dispatches them to the I/O pool."""

    def __init__(self):
        # A truthy item on this queue tells run() to exit.
        self.wakeup_queue = queue.Queue()
        self.io_pool = ThreadPool()
        self._started = False
        self._stopped = False

    def run(self) -> None:
        """Run the dispatch loop on the calling thread until tear_down().

        Creates the process pool and scheduler, starts the I/O pool, then
        repeatedly enqueues whatever the scheduler returns and sleeps on the
        wake-up queue. Both pools are torn down on exit, including when the
        loop exits because of an exception.
        """
        if self._started:
            return
        self._started = True

        self.cpu_pool = ProcessPool()
        self.scheduler = Scheduler(self)
        self.io_pool.start()
        print("DataLoader started.")

        io_pool = self.io_pool
        wakeup_q = self.wakeup_queue
        try:
            while True:
                try:
                    sleep_time, jobs = self.scheduler.get_ready_jobs()
                except Exception:
                    print(f"Failed to get jobs, reason={traceback.format_exc()}")
                    jobs = ()
                    sleep_time = 1

                print(f"putting {len(jobs)} jobs in queue")
                io_pool.enqueue_jobs(jobs)

                # The wake-up queue doubles as the sleep timer between polls.
                try:
                    go_exit = wakeup_q.get(timeout=sleep_time)
                except queue.Empty:
                    continue
                if go_exit:
                    break
        finally:
            self._stopped = True
            io_pool.tear_down()
            self.cpu_pool.tear_down()
        print("DataLoader stopped.")

    def tear_down(self) -> None:
        """Request run() to exit; only enqueues a flag, so it returns at once."""
        self.wakeup_queue.put(True)

    def stopped(self) -> bool:
        """Return True once run() has left its dispatch loop."""
        return self._stopped

    def run_io_jobs(self, jobs, block=True) -> None:
        """Enqueue ``jobs`` on the I/O thread pool."""
        self.io_pool.enqueue_jobs(jobs, block)

    def run_computing_job(self, func, args=(), kwargs=None):
        """Run ``func`` in the process pool and return its result.

        Only usable after run() has created the process pool.
        """
        thread_name = threading.current_thread().name
        print(f"executing cpu_pool apply {thread_name}")
        res = self.cpu_pool.apply(func, args, {} if kwargs is None else kwargs)
        print(f"completed cpu_pool apply {thread_name}")
        return res


def processing_func():
    """Placeholder CPU-bound task executed in a pool process."""
    print("processing something")


class Job:
    """Callable unit of work that delegates its computation to the process pool."""

    def __init__(self, manager):
        self.lock = threading.Lock()
        self.manager = manager

    def __call__(self):
        # The lock serialises concurrent invocations of the same Job object.
        with self.lock:
            thread_name = threading.current_thread().name
            print(f"executing dummy_func {thread_name}")
            self.manager.run_computing_job(processing_func)
            print(f"completed dummy_func {thread_name}")


class Scheduler:
    """Test scheduler: emits a batch of 100 jobs on every fifth poll."""

    def __init__(self, manager):
        self.count = 0
        self.manager = manager
        self.lock = threading.Lock()

    def get_ready_jobs(self):
        """Return ``(sleep_seconds, jobs)`` for the next poll.

        The poll counter stops advancing once it exceeds 50; from then on
        every call returns an empty batch.
        """
        with self.lock:
            if self.count > 50:
                return (1, [])
            self.count += 1
            if self.count % 5 == 0:
                return (1, [Job(self.manager) for _ in range(100)])
            return (1, [])


def setup_signal_handler(manager):
    """Install SIGTERM/SIGINT (and SIGBREAK on Windows) handlers that call
    ``manager.tear_down()``."""

    def _handle_exit(signum, frame):
        print("Starting teardown...")
        manager.tear_down()

    def handle_tear_down_signals(callback):
        signal.signal(signal.SIGTERM, callback)
        signal.signal(signal.SIGINT, callback)
        if os.name == "nt":
            signal.signal(signal.SIGBREAK, callback)

    handle_tear_down_signals(_handle_exit)


if __name__ == "__main__":
    manager = Manager()
    setup_signal_handler(manager)
    manager.run()