diff -r 168cd3d19fef Doc/library/concurrent.futures.rst
--- a/Doc/library/concurrent.futures.rst Mon Jul 21 21:40:55 2014 +0200
+++ b/Doc/library/concurrent.futures.rst Wed Jul 23 11:17:45 2014 -0400
@@ -115,11 +115,14 @@
    executor.submit(wait_on_future)
 
 
-.. class:: ThreadPoolExecutor(max_workers)
+.. class:: ThreadPoolExecutor(max_workers, initializer=None, initargs=())
 
    An :class:`Executor` subclass that uses a pool of at most *max_workers*
    threads to execute calls asynchronously.
 
+   *initializer*: an optional callable used to initialize new worker threads.
+   *initargs*: a tuple of arguments passed to the initializer.
+
 
 .. _threadpoolexecutor-example:
 
diff -r 168cd3d19fef Lib/concurrent/futures/process.py
--- a/Lib/concurrent/futures/process.py Mon Jul 21 21:40:55 2014 +0200
+++ b/Lib/concurrent/futures/process.py Wed Jul 23 11:17:45 2014 -0400
@@ -108,7 +108,7 @@
         self.args = args
         self.kwargs = kwargs
 
-def _process_worker(call_queue, result_queue):
+def _process_worker(call_queue, result_queue, initializer, initargs):
     """Evaluates calls from call_queue and places the results in result_queue.
 
     This worker is run in a separate process.
@@ -118,9 +118,16 @@
             evaluated by the worker.
         result_queue: A multiprocessing.Queue of _ResultItems that will written
             to by the worker.
-        shutdown: A multiprocessing.Event that will be set as a signal to the
-            worker that it should exit when call_queue is empty.
+        initializer: a callable used to initialize new worker processes.
+        initargs: a tuple of arguments passed to the initializer.
     """
+    if initializer is not None:
+        try:
+            initializer(*initargs)
+        except Exception:
+            result_queue.put(None)
+            raise
+
     while True:
         call_item = call_queue.get(block=True)
         if call_item is None:
@@ -321,13 +328,15 @@
 
 
 class ProcessPoolExecutor(_base.Executor):
-    def __init__(self, max_workers=None):
+    def __init__(self, max_workers=None, initializer=None, initargs=()):
         """Initializes a new ProcessPoolExecutor instance.
 
         Args:
             max_workers: The maximum number of processes that can be used to
                 execute the given calls. If None or not given then as many
                 worker processes will be created as the machine has processors.
+            initializer: a callable used to initialize new worker processes.
+            initargs: a tuple of arguments passed to the initializer.
         """
         _check_system_limits()
 
@@ -338,6 +347,11 @@
                 raise ValueError("max_workers must be greater than 0")
 
             self._max_workers = max_workers
+        self._initializer = initializer
+        self._initargs = initargs
+
+        if initializer is not None and not callable(initializer):
+            raise TypeError('initializer must be a callable')
 
         # Make the call queue slightly larger than the number of processes to
         # prevent the worker processes from idling. But don't make it too big
@@ -386,7 +400,9 @@
             p = multiprocessing.Process(
                     target=_process_worker,
                     args=(self._call_queue,
-                          self._result_queue))
+                          self._result_queue,
+                          self._initializer,
+                          self._initargs))
             p.start()
             self._processes[p.pid] = p
 
diff -r 168cd3d19fef Lib/concurrent/futures/thread.py
--- a/Lib/concurrent/futures/thread.py Mon Jul 21 21:40:55 2014 +0200
+++ b/Lib/concurrent/futures/thread.py Wed Jul 23 11:17:45 2014 -0400
@@ -57,7 +57,19 @@
         else:
             self.future.set_result(result)
 
-def _worker(executor_reference, work_queue):
+def _worker(executor_reference, work_queue, initializer, initargs):
+    if initializer is not None:
+        try:
+            initializer(*initargs)
+        except Exception:
+            _base.LOGGER.critical('Exception in initializer:', exc_info=True)
+            # executor_reference is a weakref; the executor may already have
+            # been garbage collected, in which case there is nothing to flag.
+            executor = executor_reference()
+            if executor is not None:
+                with executor._initializer_lock:
+                    executor._initializer_failed = True
+            return
     try:
         while True:
             work_item = work_queue.get(block=True)
@@ -80,12 +92,14 @@
         _base.LOGGER.critical('Exception in worker', exc_info=True)
 
 class ThreadPoolExecutor(_base.Executor):
-    def __init__(self, max_workers):
+    def __init__(self, max_workers, initializer=None, initargs=()):
         """Initializes a new ThreadPoolExecutor instance.
 
         Args:
             max_workers: The maximum number of threads that can be used to
                 execute the given calls.
+            initializer: a callable used to initialize new worker threads.
+            initargs: a tuple of arguments passed to the initializer.
         """
         if max_workers <= 0:
             raise ValueError("max_workers must be greater than 0")
@@ -95,11 +109,23 @@
         self._threads = set()
         self._shutdown = False
         self._shutdown_lock = threading.Lock()
+        self._initializer_failed = False
+        self._initializer_lock = threading.Lock()
+        self._initializer = initializer
+        self._initargs = initargs
+
+        if initializer is not None and not callable(initializer):
+            raise TypeError('initializer must be a callable')
 
     def submit(self, fn, *args, **kwargs):
         with self._shutdown_lock:
             if self._shutdown:
                 raise RuntimeError('cannot schedule new futures after shutdown')
+
+            with self._initializer_lock:
+                if self._initializer_failed:
+                    raise RuntimeError('cannot schedule new futures after '
+                                       'initializer fails')
 
             f = _base.Future()
             w = _WorkItem(f, fn, args, kwargs)
@@ -119,7 +145,9 @@
         if len(self._threads) < self._max_workers:
             t = threading.Thread(target=_worker,
                                  args=(weakref.ref(self, weakref_cb),
-                                       self._work_queue))
+                                       self._work_queue,
+                                       self._initializer,
+                                       self._initargs))
             t.daemon = True
             t.start()
             self._threads.add(t)
diff -r 168cd3d19fef Lib/test/test_concurrent_futures.py
--- a/Lib/test/test_concurrent_futures.py Mon Jul 21 21:40:55 2014 +0200
+++ b/Lib/test/test_concurrent_futures.py Wed Jul 23 11:17:45 2014 -0400
@@ -38,6 +38,7 @@
 EXCEPTION_FUTURE = create_future(state=FINISHED, exception=OSError())
 SUCCESSFUL_FUTURE = create_future(state=FINISHED, result=42)
 
+INITIALIZER_STATUS = 'uninitialized'
 
 def mul(x, y):
     return x * y
@@ -52,6 +53,15 @@
     print(msg)
     sys.stdout.flush()
 
+def init(x):
+    global INITIALIZER_STATUS
+    INITIALIZER_STATUS = x
+
+def get_status():
+    return INITIALIZER_STATUS
+
+def fail():
+    raise ValueError('error in initializer')
 
 class MyObject(object):
     def my_method(self):
@@ -94,6 +104,57 @@
     executor_type = futures.ProcessPoolExecutor
 
 
+class ExecutorInitializerMixin(ExecutorMixin):
+    def setUp(self):
+        self.t1 = time.time()
+        try:
+            self.executor = self.executor_type(max_workers=self.worker_count,
+                    initializer=init, initargs=('initialized', ))
+        except NotImplementedError as e:
+            self.skipTest(str(e))
+        self._prime_executor()
+
+    def test_initializer(self):
+        futures = [self.executor.submit(get_status)
+                   for _ in range(self.worker_count)]
+
+        for f in futures:
+            self.assertEqual(f.result(), 'initialized')
+
+
+class FailingInitializerMixin(ExecutorMixin):
+    def setUp(self):
+        self.t1 = time.time()
+        try:
+            self.executor = self.executor_type(max_workers=self.worker_count,
+                    initializer=fail, initargs=())
+        except NotImplementedError as e:
+            self.skipTest(str(e))
+
+    def test_initializer(self):
+        self.assertRaises((RuntimeError, BrokenProcessPool), self._prime_executor)
+
+
+class ThreadPoolInitializerTest(ExecutorInitializerMixin,
+                                ThreadPoolMixin, unittest.TestCase):
+    pass
+
+
+class ProcessPoolInitializerTest(ExecutorInitializerMixin,
+                                 ProcessPoolMixin, unittest.TestCase):
+    pass
+
+
+class ThreadPoolFailingInitializerTest(FailingInitializerMixin,
+                                       ThreadPoolMixin, unittest.TestCase):
+    pass
+
+
+class ProcessPoolFailingInitializerTest(FailingInitializerMixin,
+                                        ProcessPoolMixin, unittest.TestCase):
+    pass
+
+
 class ExecutorShutdownTest:
     def test_run_after_shutdown(self):
         self.executor.shutdown()