diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py index 8555da9..ff5deca 100644 --- a/Lib/multiprocessing/pool.py +++ b/Lib/multiprocessing/pool.py @@ -12,6 +12,7 @@ __all__ = ['Pool'] # Imports # +import os import threading import Queue import itertools @@ -30,6 +31,21 @@ CLOSE = 1 TERMINATE = 2 # +# Constants representing the state of a job +# + +ACK = 0 +READY = 1 + +# +# Exceptions +# + +class WorkerLostError(Exception): + """The worker processing a job has exited prematurely.""" + pass + +# # Miscellaneous # @@ -42,8 +58,10 @@ def mapstar(args): # Code run by worker processes # -def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None): +def worker(inqueue, outqueue, initializer=None, initargs=(), + maxtasks=None): assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0) + pid = os.getpid() put = outqueue.put get = inqueue.get if hasattr(inqueue, '_writer'): @@ -66,11 +84,12 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None): break job, i, func, args, kwds = task + put((ACK, (job, i, time.time(), pid))) try: result = (True, func(*args, **kwds)) except Exception, e: result = (False, e) - put((job, i, result)) + put((READY, (job, i, result))) completed += 1 debug('worker exiting after %d tasks' % completed) @@ -118,7 +137,8 @@ class Pool(object): self._task_handler = threading.Thread( target=Pool._handle_tasks, - args=(self._taskqueue, self._quick_put, self._outqueue, self._pool) + args=(self._taskqueue, self._quick_put, self._outqueue, + self._pool) ) self._task_handler.daemon = True self._task_handler._state = RUN @@ -126,7 +146,8 @@ class Pool(object): self._result_handler = threading.Thread( target=Pool._handle_results, - args=(self._outqueue, self._quick_get, self._cache) + args=(self._outqueue, self._quick_get, self._cache, + self._poll_result, self._join_exited_workers) ) self._result_handler.daemon = True self._result_handler._state = RUN @@ -134,8 +155,8 @@ class Pool(object): 
self._terminate = Finalize( self, self._terminate_pool, - args=(self._taskqueue, self._inqueue, self._outqueue, self._pool, - self._worker_handler, self._task_handler, + args=(self._taskqueue, self._inqueue, self._outqueue, + self._pool, self._worker_handler, self._task_handler, self._result_handler, self._cache), exitpriority=15 ) @@ -144,24 +165,36 @@ class Pool(object): """Cleanup after any worker processes which have exited due to reaching their specified lifetime. Returns True if any workers were cleaned up. """ - cleaned = False + cleaned = [] for i in reversed(range(len(self._pool))): worker = self._pool[i] if worker.exitcode is not None: # worker exited debug('cleaning up worker %d' % i) worker.join() - cleaned = True + cleaned.append(worker.pid) del self._pool[i] - return cleaned + if cleaned: + for job in self._cache.values(): + for worker_pid in job.worker_pids(): + if worker_pid in cleaned: + err = WorkerLostError("Worker exited prematurely.") + job._set(None, (False, err)) + continue + + return True + return False def _repopulate_pool(self): """Bring the number of pool processes up to the specified number, for use after reaping workers which have exited. 
""" for i in range(self._processes - len(self._pool)): + if self._state != RUN: + return w = self.Process(target=worker, - args=(self._inqueue, self._outqueue, + args=(self._inqueue, + self._outqueue, self._initializer, self._initargs, self._maxtasksperchild) ) @@ -184,6 +217,12 @@ class Pool(object): self._quick_put = self._inqueue._writer.send self._quick_get = self._outqueue._reader.recv + def _poll_result(timeout): + if self._outqueue._reader.poll(timeout): + return True, self._quick_get() + return False, None + self._poll_result = _poll_result + def apply(self, func, args=(), kwds={}): ''' Equivalent of `apply()` builtin @@ -234,12 +273,25 @@ class Pool(object): for i, x in enumerate(task_batches)), result._set_length)) return (item for chunk in result for item in chunk) - def apply_async(self, func, args=(), kwds={}, callback=None): + def apply_async(self, func, args=(), kwds={}, callback=None, + accept_callback=None): ''' Asynchronous equivalent of `apply()` builtin + + Callback is called when the functions return value is ready. + The accept callback is called when the job is reserved by a worker + process. + + Simplified the flow is like this: + + >>> if accept_callback: + ... accept_callback() + >>> retval = func(*args, **kwds) + >>> if callback: + ... callback(retval) ''' assert self._state == RUN - result = ApplyResult(self._cache, callback) + result = ApplyResult(self._cache, callback, accept_callback) self._taskqueue.put(([(result._job, None, func, args, kwds)], None)) return result @@ -311,9 +363,31 @@ class Pool(object): debug('task handler exiting') @staticmethod - def _handle_results(outqueue, get, cache): + def _handle_results(outqueue, get, cache, poll, join_exited_workers): thread = threading.current_thread() + def on_ack(job, i, time_accepted, pid): + try: + cache[job]._ack(i, time_accepted, pid) + except (KeyError, AttributeError): + # Object gone or doesn't support _ack (e.g. 
IMAPIterator) + pass + + def on_ready(job, i, obj): + try: + cache[job]._set(i, obj) + except KeyError: + pass + + state_handlers = {ACK: on_ack, READY: on_ready} + + def on_state_change(state, args): + try: + state_handlers[state](*args) + except KeyError: + debug("Unknown job state: %s (%s)" % (state, args)) + + while 1: try: task = get() @@ -330,27 +404,26 @@ debug('result handler got sentinel') break - job, i, obj = task - try: - cache[job]._set(i, obj) - except KeyError: - pass + state, args = task + on_state_change(state, args) + while cache and thread._state != TERMINATE: try: - task = get() + ready, task = poll(0.2) except (IOError, EOFError): debug('result handler got EOFError/IOError -- exiting') return - if task is None: - debug('result handler ignoring extra sentinel') - continue - job, i, obj = task - try: - cache[job]._set(i, obj) - except KeyError: - pass + if ready: + if task is None: + debug('result handler ignoring extra sentinel') + continue + + state, args = task + on_state_change(state, args) + + join_exited_workers() if hasattr(outqueue, '_reader'): debug('ensuring that outqueue is not full') @@ -387,6 +460,7 @@ if self._state == RUN: self._state = CLOSE self._worker_handler._state = CLOSE + self._worker_handler.join() self._taskqueue.put(None) def terminate(self): @@ -398,10 +472,14 @@ def join(self): debug('joining pool') assert self._state in (CLOSE, TERMINATE) + debug('joining worker handler') self._worker_handler.join() + debug('joining task handler') self._task_handler.join() + debug('joining result handler') self._result_handler.join() for p in self._pool: + debug('joining worker %r' % (p, )) p.join() @staticmethod @@ -415,7 +493,8 @@ @classmethod def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool, - worker_handler, task_handler, result_handler, cache): + worker_handler, task_handler, + result_handler, cache): # this is guaranteed to only be called
once debug('finalizing pool') @@ -458,12 +537,17 @@ class Pool(object): class ApplyResult(object): - def __init__(self, cache, callback): + def __init__(self, cache, callback, accept_callback): self._cond = threading.Condition(threading.Lock()) self._job = job_counter.next() self._cache = cache self._ready = False self._callback = callback + self._accept_callback = accept_callback + + self._accepted = False + self._worker_pid = None + self._time_accepted = None cache[self._job] = self def ready(self): @@ -473,6 +557,12 @@ class ApplyResult(object): assert self._ready return self._success + def accepted(self): + return self._accepted + + def worker_pids(self): + return filter(None, [self._worker_pid]) + def wait(self, timeout=None): self._cond.acquire() try: @@ -500,7 +590,18 @@ class ApplyResult(object): self._cond.notify() finally: self._cond.release() - del self._cache[self._job] + if self._accepted: + self._cache.pop(self._job, None) + + def _ack(self, i, time_accepted, pid): + self._accepted = True + self._time_accepted = time_accepted + self._worker_pid = pid + if self._accept_callback: + self._accept_callback() + if self._ready: + self._cache.pop(self._job, None) + # # Class whose instances are returned by `Pool.map_async()` @@ -508,10 +609,15 @@ class ApplyResult(object): class MapResult(ApplyResult): - def __init__(self, cache, chunksize, length, callback): - ApplyResult.__init__(self, cache, callback) + def __init__(self, cache, chunksize, length, callback, + accept_callback=None): + ApplyResult.__init__(self, cache, callback, accept_callback) self._success = True + self._length = length self._value = [None] * length + self._accepted = [False] * length + self._worker_pid = [None] * length + self._time_accepted = [None] * length self._chunksize = chunksize if chunksize <= 0: self._number_left = 0 @@ -527,18 +633,19 @@ class MapResult(ApplyResult): if self._number_left == 0: if self._callback: self._callback(self._value) - del self._cache[self._job] + if 
self.accepted(): + self._cache.pop(self._job, None) self._cond.acquire() try: self._ready = True self._cond.notify() finally: self._cond.release() - else: self._success = False self._value = result - del self._cache[self._job] + if self.accepted(): + self._cache.pop(self._job, None) self._cond.acquire() try: self._ready = True @@ -546,6 +653,22 @@ finally: self._cond.release() + def _ack(self, i, time_accepted, pid): + start = i * self._chunksize + stop = (i + 1) * self._chunksize + for j in range(start, stop): + self._accepted[j] = True + self._worker_pid[j] = pid + self._time_accepted[j] = time_accepted + if self._ready: + self._cache.pop(self._job, None) + + def accepted(self): + return all(self._accepted) + + def worker_pids(self): + return filter(None, self._worker_pid) + # # Class whose instances are returned by `Pool.imap()` # @@ -653,6 +776,13 @@ self._quick_put = self._inqueue.put self._quick_get = self._outqueue.get + def _poll_result(timeout): + try: + return True, self._quick_get(timeout=timeout) + except Queue.Empty: + return False, None + self._poll_result = _poll_result + @staticmethod def _help_stuff_finish(inqueue, task_handler, size): # put sentinels at head of inqueue to make workers finish diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py index 59b3357..64b33c1 100644 --- a/Lib/test/test_multiprocessing.py +++ b/Lib/test/test_multiprocessing.py @@ -34,6 +34,8 @@ import multiprocessing.pool from multiprocessing import util +from multiprocessing.pool import WorkerLostError + # # # @@ -995,6 +997,7 @@ class _TestContainers(BaseTestCase): def sqr(x, wait=0.0): time.sleep(wait) return x*x + class _TestPool(BaseTestCase): def test_apply(self): @@ -1020,6 +1023,24 @@ self.assertEqual(get(), 49) self.assertTimingAlmostEqual(get.elapsed, TIMEOUT1) + def test_async_accept_callback(self): + if self.TYPE == 'manager': + return + scratchpad =
[False] + def accept_callback(): + scratchpad[0] = True + + res = self.pool.apply_async(sqr, (7, TIMEOUT1), + accept_callback=accept_callback) + get = TimingWrapper(res.get) + self.assertEqual(get(), 49) + self.assertTimingAlmostEqual(get.elapsed, TIMEOUT1) + self.assertTrue(scratchpad[0]) + self.assertTrue(res._worker_pid) + self.assertTrue(res._time_accepted) + self.assertTrue(res.accepted()) + self.assertTrue(res.worker_pids()) + def test_async_timeout(self): res = self.pool.apply_async(sqr, (6, TIMEOUT2 + 0.2)) get = TimingWrapper(res.get) @@ -1071,9 +1092,62 @@ class _TestPool(BaseTestCase): join() self.assertTrue(join.elapsed < 0.2) -class _TestPoolWorkerLifetime(BaseTestCase): +def terminates_by_signal(signum): + os.kill(os.getpid(), signum) + +def terminates_by_SystemExit(): + raise SystemExit + +def sends_SIGKILL_sometimes(i): + if not i % 5: + os.kill(os.getpid(), signal.SIGKILL) + +class _TestPoolSupervisor(BaseTestCase): + ALLOWED_TYPES = ('processes', ) + + def test_job_killed_by_signal(self): + p = multiprocessing.Pool(3) + results = [p.apply_async(terminates_by_signal, (signal.SIGKILL, )) + for i in xrange(20)] + + res = p.apply_async(sqr, (7, 0.0)) + self.assertEqual(res.get(), 49, + 'supervisor did restart crashed workers') + + for result in results: + with self.assertRaises(WorkerLostError): + result.get() + + p.close() + p.join() + + def test_map_killed_by_signal(self): + p = multiprocessing.Pool(3) + res = p.map_async(sends_SIGKILL_sometimes, xrange(12)) + with self.assertRaises(WorkerLostError): + res.get() + p.close() + p.join() + + def test_job_raising_SystemExit(self): + p = multiprocessing.Pool(3) + results = [p.apply_async(terminates_by_SystemExit) + for i in xrange(20)] + for result in results: + with self.assertRaises(WorkerLostError): + result.get() + + res = p.apply_async(sqr, (7, 0.0)) + self.assertEqual(res.get(), 49, + 'supervisor did restart crashed workers') + + p.close() + p.join() + +class _TestPoolWorkerLifetime(BaseTestCase): 
ALLOWED_TYPES = ('processes', ) + def test_pool_worker_lifetime(self): p = multiprocessing.Pool(3, maxtasksperchild=10) self.assertEqual(3, len(p._pool))