diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py
index 8555da9..4d10c9c 100644
--- a/Lib/multiprocessing/pool.py
+++ b/Lib/multiprocessing/pool.py
@@ -12,6 +12,7 @@ __all__ = ['Pool']
 # Imports
 #
 
+import os
 import threading
 import Queue
 import itertools
@@ -30,6 +31,21 @@ CLOSE = 1
 TERMINATE = 2
 
 #
+# Constants representing the state of a job
+#
+
+ACK = 0
+READY = 1
+
+#
+# Exceptions
+#
+
+class WorkerLostError(Exception):
+    """The worker processing a job has exited prematurely."""
+    pass
+
+#
 # Miscellaneous
 #
 
@@ -44,6 +60,7 @@ def mapstar(args):
 
 def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
     assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
+    pid = os.getpid()
     put = outqueue.put
     get = inqueue.get
     if hasattr(inqueue, '_writer'):
@@ -66,11 +83,12 @@
             break
 
         job, i, func, args, kwds = task
+        put((ACK, (job, i, pid)))
         try:
             result = (True, func(*args, **kwds))
         except Exception, e:
             result = (False, e)
-        put((job, i, result))
+        put((READY, (job, i, result)))
         completed += 1
     debug('worker exiting after %d tasks' % completed)
 
@@ -126,7 +144,8 @@ class Pool(object):
 
         self._result_handler = threading.Thread(
             target=Pool._handle_results,
-            args=(self._outqueue, self._quick_get, self._cache)
+            args=(self._outqueue, self._quick_get, self._cache,
+                  self._poll_result, self._join_exited_workers)
             )
         self._result_handler.daemon = True
         self._result_handler._state = RUN
@@ -144,22 +163,33 @@
         """Cleanup after any worker processes which have exited due to
         reaching their specified lifetime. Returns True if any workers were
         cleaned up.
         """
-        cleaned = False
+        cleaned = []
         for i in reversed(range(len(self._pool))):
             worker = self._pool[i]
             if worker.exitcode is not None:
                 # worker exited
                 debug('cleaning up worker %d' % i)
                 worker.join()
-                cleaned = True
+                cleaned.append(worker.pid)
                 del self._pool[i]
-        return cleaned
+        if cleaned:
+            for job in self._cache.values():
+                for worker_pid in job.worker_pids():
+                    if worker_pid in cleaned:
+                        err = WorkerLostError("Worker exited prematurely.")
+                        job._set(None, (False, err))
+                        break
+
+            return True
+        return False
 
     def _repopulate_pool(self):
         """Bring the number of pool processes up to the specified number,
         for use after reaping workers which have exited.
         """
         for i in range(self._processes - len(self._pool)):
+            if self._state != RUN:
+                return
             w = self.Process(target=worker,
                              args=(self._inqueue, self._outqueue,
                                    self._initializer,
@@ -184,6 +214,12 @@ class Pool(object):
         self._quick_put = self._inqueue._writer.send
         self._quick_get = self._outqueue._reader.recv
 
+        def _poll_result(timeout):
+            if self._outqueue._reader.poll(timeout):
+                return True, self._quick_get()
+            return False, None
+        self._poll_result = _poll_result
+
     def apply(self, func, args=(), kwds={}):
         '''
         Equivalent of `apply()` builtin
@@ -311,9 +347,30 @@
         debug('task handler exiting')
 
     @staticmethod
-    def _handle_results(outqueue, get, cache):
+    def _handle_results(outqueue, get, cache, poll, join_exited_workers):
         thread = threading.current_thread()
 
+        def on_ack(job, i, pid):
+            try:
+                cache[job]._ack(i, pid)
+            except (KeyError, AttributeError):
+                # Object gone or doesn't support _ack (e.g. IMAPIterator)
+                pass
+
+        def on_ready(job, i, obj):
+            try:
+                cache[job]._set(i, obj)
+            except KeyError:
+                pass
+
+        state_handlers = {ACK: on_ack, READY: on_ready}
+
+        def on_state_change(state, args):
+            try:
+                state_handlers[state](*args)
+            except KeyError:
+                debug("Unknown job state: %s (%s)" % (state, args))
+
         while 1:
             try:
                 task = get()
@@ -330,27 +387,25 @@
                 debug('result handler got sentinel')
                 break
 
-            job, i, obj = task
-            try:
-                cache[job]._set(i, obj)
-            except KeyError:
-                pass
+            state, args = task
+            on_state_change(state, args)
 
         while cache and thread._state != TERMINATE:
             try:
-                task = get()
+                ready, task = poll(0.2)
             except (IOError, EOFError):
                 debug('result handler got EOFError/IOError -- exiting')
                 return
 
-            if task is None:
-                debug('result handler ignoring extra sentinel')
-                continue
-            job, i, obj = task
-            try:
-                cache[job]._set(i, obj)
-            except KeyError:
-                pass
+            if ready:
+                if task is None:
+                    debug('result handler ignoring extra sentinel')
+                    continue
+
+                state, args = task
+                on_state_change(state, args)
+
+            join_exited_workers()
 
         if hasattr(outqueue, '_reader'):
             debug('ensuring that outqueue is not full')
@@ -387,6 +442,7 @@
         if self._state == RUN:
             self._state = CLOSE
             self._worker_handler._state = CLOSE
+            self._worker_handler.join()
             self._taskqueue.put(None)
 
     def terminate(self):
@@ -464,6 +520,8 @@ class ApplyResult(object):
         self._cache = cache
         self._ready = False
         self._callback = callback
+        self._accepted = False
+        self._worker_pid = None
         cache[self._job] = self
 
     def ready(self):
@@ -473,6 +531,12 @@
         assert self._ready
         return self._success
 
+    def accepted(self):
+        return self._accepted
+
+    def worker_pids(self):
+        return filter(None, [self._worker_pid])
+
     def wait(self, timeout=None):
         self._cond.acquire()
         try:
@@ -500,7 +564,15 @@
             self._cond.notify()
         finally:
             self._cond.release()
-        del self._cache[self._job]
+        if self._accepted:
+            self._cache.pop(self._job, None)
+
+    def _ack(self, i, pid):
+        self._accepted = True
+        self._worker_pid = pid
+        if self._ready:
+            self._cache.pop(self._job, None)
+
 
 #
 # Class whose instances are returned by `Pool.map_async()`
@@ -511,7 +583,10 @@ class MapResult(ApplyResult):
     def __init__(self, cache, chunksize, length, callback):
         ApplyResult.__init__(self, cache, callback)
         self._success = True
+        self._length = length
         self._value = [None] * length
+        self._accepted = [False] * length
+        self._worker_pid = [None] * length
         self._chunksize = chunksize
         if chunksize <= 0:
             self._number_left = 0
@@ -527,18 +602,19 @@
             if self._number_left == 0:
                 if self._callback:
                     self._callback(self._value)
-                del self._cache[self._job]
+                if self.accepted():
+                    self._cache.pop(self._job, None)
                 self._cond.acquire()
                 try:
                     self._ready = True
                     self._cond.notify()
                 finally:
                     self._cond.release()
-
         else:
             self._success = False
             self._value = result
-            del self._cache[self._job]
+            if self.accepted():
+                self._cache.pop(self._job, None)
             self._cond.acquire()
             try:
                 self._ready = True
@@ -546,6 +622,21 @@
             finally:
                 self._cond.release()
 
+    def _ack(self, i, pid):
+        start = i * self._chunksize
+        stop = min((i + 1) * self._chunksize, self._length)
+        for j in range(start, stop):
+            self._accepted[j] = True
+            self._worker_pid[j] = pid
+        if self._ready:
+            self._cache.pop(self._job, None)
+
+    def accepted(self):
+        return all(self._accepted)
+
+    def worker_pids(self):
+        return filter(None, self._worker_pid)
+
 #
 # Class whose instances are returned by `Pool.imap()`
 #
@@ -653,6 +744,13 @@ class ThreadPool(Pool):
         self._quick_put = self._inqueue.put
         self._quick_get = self._outqueue.get
 
+        def _poll_result(timeout):
+            try:
+                return True, self._quick_get(timeout=timeout)
+            except Queue.Empty:
+                return False, None
+        self._poll_result = _poll_result
+
     @staticmethod
     def _help_stuff_finish(inqueue, task_handler, size):
         # put sentinels at head of inqueue to make workers finish
diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py
index 59b3357..e24877f 100644
--- a/Lib/test/test_multiprocessing.py
+++ b/Lib/test/test_multiprocessing.py
@@ -34,6 +34,8 @@ import multiprocessing.pool
 
 from multiprocessing import util
 
+from multiprocessing.pool import WorkerLostError
+
 #
 #
 #
@@ -995,6 +997,7 @@
 def sqr(x, wait=0.0):
     time.sleep(wait)
     return x*x
+
 class _TestPool(BaseTestCase):
 
     def test_apply(self):
@@ -1071,9 +1074,63 @@
         join()
         self.assertTrue(join.elapsed < 0.2)
 
-class _TestPoolWorkerLifetime(BaseTestCase):
+def terminates_by_signal(signum):
+    os.kill(os.getpid(), signum)
+
+def terminates_by_SystemExit():
+    raise SystemExit
+
+def sends_SIGKILL_sometimes(i):
+    if not i % 5:
+        os.kill(os.getpid(), signal.SIGKILL)
+
+class _TestPoolSupervisor(BaseTestCase):
     ALLOWED_TYPES = ('processes', )
+
+    def test_job_killed_by_signal(self):
+        # TODO: this test won't pass on Windows (no SIGKILL there).
+        p = multiprocessing.Pool(3)
+        results = [p.apply_async(terminates_by_signal, (signal.SIGKILL, ))
+                   for i in xrange(20)]
+
+        res = p.apply_async(sqr, (7, 0.0))
+        self.assertEqual(res.get(), 49,
+                         'supervisor did not restart crashed workers')
+
+        for result in results:
+            with self.assertRaises(WorkerLostError):
+                result.get()
+
+        p.close()
+        p.join()
+
+    def test_map_killed_by_signal(self):
+        p = multiprocessing.Pool(3)
+        res = p.map_async(sends_SIGKILL_sometimes, xrange(12))
+        with self.assertRaises(WorkerLostError):
+            res.get()
+        p.close()
+        p.join()
+
+    def test_job_raising_SystemExit(self):
+        p = multiprocessing.Pool(3)
+        results = [p.apply_async(terminates_by_SystemExit)
+                   for i in xrange(20)]
+        for result in results:
+            with self.assertRaises(WorkerLostError):
+                result.get()
+
+        res = p.apply_async(sqr, (7, 0.0))
+        self.assertEqual(res.get(), 49,
+                         'supervisor did not restart crashed workers')
+
+        p.close()
+        p.join()
+
+class _TestPoolWorkerLifetime(BaseTestCase):
+    ALLOWED_TYPES = ('processes', )
+
     def test_pool_worker_lifetime(self):
         p = multiprocessing.Pool(3, maxtasksperchild=10)
         self.assertEqual(3, len(p._pool))
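
For reviewers: a minimal sketch (not part of the patch) of the caller-visible behaviour, assuming the patched pool.py is importable. Without the patch, the get() below blocks forever because a SIGKILLed worker never writes a (READY, ...) message to the result queue; with the ACK bookkeeping, _join_exited_workers() matches the dead worker's pid against the jobs it had acknowledged and fails them with WorkerLostError, while _repopulate_pool() replaces the worker.

    import os
    import signal

    import multiprocessing
    from multiprocessing.pool import WorkerLostError

    def crash():
        # Simulate a hard crash: SIGKILL cannot be caught, so no
        # (READY, ...) message for this job ever reaches the result handler.
        os.kill(os.getpid(), signal.SIGKILL)

    def square(x):
        return x * x

    if __name__ == '__main__':
        pool = multiprocessing.Pool(processes=2)
        result = pool.apply_async(crash)
        try:
            result.get()
        except WorkerLostError:
            print 'job failed instead of blocking forever'
        # The supervisor has already replaced the dead worker,
        # so the pool remains usable.
        assert pool.apply_async(square, (7,)).get() == 49
        pool.close()
        pool.join()

Note the ordering in worker(): the (ACK, (job, i, pid)) message is sent before the task runs, so a pid is associated with a job for its whole execution; that is what lets _join_exited_workers() attribute a premature exit to the right in-flight jobs.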