diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py
index 8555da9..00a641a 100644
--- a/Lib/multiprocessing/pool.py
+++ b/Lib/multiprocessing/pool.py
@@ -12,6 +12,7 @@ __all__ = ['Pool']
 # Imports
 #
 
+import os
 import threading
 import Queue
 import itertools
@@ -30,6 +31,14 @@
 CLOSE = 1
 TERMINATE = 2
 
 #
+# Exceptions
+#
+
+class WorkerLostError(Exception):
+    """The worker processing a job has exited prematurely."""
+    pass
+
+#
 # Miscellaneous
 #
@@ -42,10 +51,13 @@ def mapstar(args):
 # Code run by worker processes
 #
 
-def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
+def worker(inqueue, outqueue, ackqueue, initializer=None, initargs=(),
+           maxtasks=None):
     assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
+    pid = os.getpid()
     put = outqueue.put
     get = inqueue.get
+    ack = ackqueue.put
     if hasattr(inqueue, '_writer'):
         inqueue._writer.close()
         outqueue._reader.close()
@@ -66,6 +78,7 @@
             break
 
         job, i, func, args, kwds = task
+        ack((job, i, time.time(), pid))
         try:
             result = (True, func(*args, **kwds))
         except Exception, e:
@@ -118,15 +131,25 @@ class Pool(object):
         self._task_handler = threading.Thread(
             target=Pool._handle_tasks,
-            args=(self._taskqueue, self._quick_put, self._outqueue, self._pool)
+            args=(self._taskqueue, self._quick_put, self._outqueue,
+                  self._pool, self._ackqueue)
             )
         self._task_handler.daemon = True
         self._task_handler._state = RUN
         self._task_handler.start()
 
+        self._ack_handler = threading.Thread(
+            target=Pool._handle_acks,
+            args=(self._ackqueue, self._quick_get_ack, self._cache)
+            )
+        self._ack_handler.daemon = True
+        self._ack_handler._state = RUN
+        self._ack_handler.start()
+
         self._result_handler = threading.Thread(
             target=Pool._handle_results,
-            args=(self._outqueue, self._quick_get, self._cache)
+            args=(self._outqueue, self._quick_get, self._cache,
+                  self._poll_result, self._join_exited_workers)
             )
         self._result_handler.daemon = True
         self._result_handler._state = RUN
@@ -134,7 +157,8 @@
         self._terminate = Finalize(
             self, self._terminate_pool,
-            args=(self._taskqueue, self._inqueue, self._outqueue, self._pool,
+            args=(self._taskqueue, self._inqueue, self._outqueue,
+                  self._ackqueue, self._pool, self._ack_handler,
                   self._worker_handler, self._task_handler,
                   self._result_handler, self._cache),
             exitpriority=15
             )
@@ -144,24 +168,37 @@
         """Cleanup after any worker processes which have exited due to
         reaching their specified lifetime. Returns True if any workers were
         cleaned up.
         """
-        cleaned = False
+        cleaned = []
         for i in reversed(range(len(self._pool))):
             worker = self._pool[i]
             if worker.exitcode is not None:
                 # worker exited
                 debug('cleaning up worker %d' % i)
                 worker.join()
-                cleaned = True
+                cleaned.append(worker.pid)
                 del self._pool[i]
-        return cleaned
+        if cleaned:
+            for job in self._cache.values():
+                for worker_pid in job.worker_pids():
+                    if worker_pid in cleaned:
+                        err = WorkerLostError("Worker exited prematurely.")
+                        job._set(None, (False, err))
+                        break
+
+            return True
+        return False
 
     def _repopulate_pool(self):
         """Bring the number of pool processes up to the specified number,
         for use after reaping workers which have exited.
""" for i in range(self._processes - len(self._pool)): + if self._state != RUN: + return w = self.Process(target=worker, - args=(self._inqueue, self._outqueue, + args=(self._inqueue, + self._outqueue, + self._ackqueue, self._initializer, self._initargs, self._maxtasksperchild) ) @@ -181,8 +218,16 @@ class Pool(object): from .queues import SimpleQueue self._inqueue = SimpleQueue() self._outqueue = SimpleQueue() + self._ackqueue = SimpleQueue() self._quick_put = self._inqueue._writer.send self._quick_get = self._outqueue._reader.recv + self._quick_get_ack = self._ackqueue._reader.recv + + def _poll_result(timeout): + if self._outqueue._reader.poll(timeout): + return True, self._quick_get() + return False, None + self._poll_result = _poll_result def apply(self, func, args=(), kwds={}): ''' @@ -234,12 +279,25 @@ class Pool(object): for i, x in enumerate(task_batches)), result._set_length)) return (item for chunk in result for item in chunk) - def apply_async(self, func, args=(), kwds={}, callback=None): + def apply_async(self, func, args=(), kwds={}, callback=None, + accept_callback=None): ''' Asynchronous equivalent of `apply()` builtin + + Callback is called when the functions return value is ready. + The accept callback is called when the job is reserved by a worker + process. + + Simplified the flow is like this: + + >>> if accept_callback: + ... accept_callback() + >>> retval = func(*args, **kwds) + >>> if callback: + ... callback(retval) ''' assert self._state == RUN - result = ApplyResult(self._cache, callback) + result = ApplyResult(self._cache, callback, accept_callback) self._taskqueue.put(([(result._job, None, func, args, kwds)], None)) return result @@ -272,7 +330,7 @@ class Pool(object): debug('worker handler exiting') @staticmethod - def _handle_tasks(taskqueue, put, outqueue, pool): + def _handle_tasks(taskqueue, put, outqueue, pool, ackqueue): thread = threading.current_thread() for taskseq, set_length in iter(taskqueue.get, None): @@ -301,6 +359,9 @@ class Pool(object): debug('task handler sending sentinel to result handler') outqueue.put(None) + debug('task handler sending sentinel to ack handler') + ackqueue.put(None) + # tell workers there is no more work debug('task handler sending sentinel to workers') for p in pool: @@ -311,47 +372,98 @@ class Pool(object): debug('task handler exiting') @staticmethod - def _handle_results(outqueue, get, cache): + def _handle_acks(ackqueue, get, cache): + debug('ack handler starting') thread = threading.current_thread() while 1: try: task = get() - except (IOError, EOFError): - debug('result handler got EOFError/IOError -- exiting') + except (IOError, EOFError), e: + debug('ack handler got %r -- exiting' % (e, )) return if thread._state: assert thread._state == TERMINATE - debug('result handler found thread._state=TERMINATE') + debug('ack handler found thread._state=TERMINATE') break if task is None: - debug('result handler got sentinel') + debug('ack handler received sentinel') break - job, i, obj = task + job, i, time_accepted, pid = task try: - cache[job]._set(i, obj) - except KeyError: + cache[job]._ack(i, time_accepted, pid) + except (KeyError, AttributeError): + # Object gone, or doesn't support _ack (e.g. 
                 pass
 
         while cache and thread._state != TERMINATE:
             try:
                 task = get()
+            except (EOFError, IOError), e:
+                debug('ack handler got %r -- exiting' % (e, ))
+                return
+
+            if task is None:
+                debug('ack handler ignoring extra sentinel')
+                continue
+
+            job, i, time_accepted, pid = task
+            try:
+                cache[job]._ack(i, time_accepted, pid)
+            except (KeyError, AttributeError):
+                pass
+
+        debug('ack handler exiting: not accepted=%s thread._state=%s' % (
+            sum(1 for job in cache.values() if not job.accepted()),
+            thread._state))
+
+    @staticmethod
+    def _handle_results(outqueue, get, cache, poll, join_exited_workers):
+        thread = threading.current_thread()
+
+        while 1:
+            try:
+                task = get()
             except (IOError, EOFError):
                 debug('result handler got EOFError/IOError -- exiting')
                 return
 
+            if thread._state:
+                assert thread._state == TERMINATE
+                debug('result handler found thread._state=TERMINATE')
+                break
+
             if task is None:
-                debug('result handler ignoring extra sentinel')
-                continue
+                debug('result handler got sentinel')
+                break
+
             job, i, obj = task
             try:
                 cache[job]._set(i, obj)
             except KeyError:
                 pass
 
+        while cache and thread._state != TERMINATE:
+            try:
+                ready, task = poll(0.2)
+            except (IOError, EOFError):
+                debug('result handler got EOFError/IOError -- exiting')
+                return
+
+            if ready:
+                if task is None:
+                    debug('result handler ignoring extra sentinel')
+                    continue
+                job, i, obj = task
+                try:
+                    cache[job]._set(i, obj)
+                except KeyError:
+                    pass
+            join_exited_workers()
+
         if hasattr(outqueue, '_reader'):
             debug('ensuring that outqueue is not full')
             # If we don't make room available in outqueue then
@@ -387,6 +499,7 @@
         if self._state == RUN:
             self._state = CLOSE
             self._worker_handler._state = CLOSE
+            self._worker_handler.join()
             self._taskqueue.put(None)
 
     def terminate(self):
@@ -398,10 +511,16 @@
     def join(self):
         debug('joining pool')
         assert self._state in (CLOSE, TERMINATE)
+        debug('joining worker handler')
         self._worker_handler.join()
+        debug('joining task handler')
         self._task_handler.join()
+        debug('joining result handler')
         self._result_handler.join()
+        debug('joining ack handler')
+        self._ack_handler.join()
         for p in self._pool:
+            debug('joining worker %r' % (p, ))
             p.join()
 
     @staticmethod
@@ -414,8 +533,9 @@
             time.sleep(0)
 
     @classmethod
-    def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
-                        worker_handler, task_handler, result_handler, cache):
+    def _terminate_pool(cls, taskqueue, inqueue, outqueue, ackqueue, pool,
+                        ack_handler, worker_handler, task_handler,
+                        result_handler, cache):
         # this is guaranteed to only be called once
         debug('finalizing pool')
 
@@ -431,6 +551,9 @@
         result_handler._state = TERMINATE
         outqueue.put(None)                  # sentinel
 
+        ack_handler._state = TERMINATE
+        ackqueue.put(None)
+
         # Terminate workers which haven't already finished.
         if pool and hasattr(pool[0], 'terminate'):
             debug('terminating workers')
@@ -444,6 +567,9 @@
         debug('joining result handler')
         result_handler.join(1e100)
 
+        debug('joining ack handler')
+        ack_handler.join(1e100)
+
         if pool and hasattr(pool[0], 'terminate'):
             debug('joining pool workers')
             for p in pool:
@@ -458,12 +584,17 @@
 class ApplyResult(object):
 
-    def __init__(self, cache, callback):
+    def __init__(self, cache, callback, accept_callback):
         self._cond = threading.Condition(threading.Lock())
         self._job = job_counter.next()
         self._cache = cache
         self._ready = False
         self._callback = callback
+        self._accept_callback = accept_callback
+
+        self._accepted = False
+        self._worker_pid = None
+        self._time_accepted = None
         cache[self._job] = self
 
     def ready(self):
@@ -473,6 +604,12 @@
         assert self._ready
         return self._success
 
+    def accepted(self):
+        return self._accepted
+
+    def worker_pids(self):
+        return filter(None, [self._worker_pid])
+
     def wait(self, timeout=None):
         self._cond.acquire()
         try:
@@ -500,7 +637,18 @@
             self._cond.notify()
         finally:
             self._cond.release()
-        del self._cache[self._job]
+        if self._accepted:
+            self._cache.pop(self._job, None)
+
+    def _ack(self, i, time_accepted, pid):
+        self._accepted = True
+        self._time_accepted = time_accepted
+        self._worker_pid = pid
+        if self._accept_callback:
+            self._accept_callback()
+        if self._ready:
+            self._cache.pop(self._job, None)
+
 #
 # Class whose instances are returned by `Pool.map_async()`
 #
@@ -508,10 +656,15 @@
 class MapResult(ApplyResult):
 
-    def __init__(self, cache, chunksize, length, callback):
-        ApplyResult.__init__(self, cache, callback)
+    def __init__(self, cache, chunksize, length, callback,
+                 accept_callback=None):
+        ApplyResult.__init__(self, cache, callback, accept_callback)
         self._success = True
+        self._length = length
         self._value = [None] * length
+        self._accepted = [False] * length
+        self._worker_pid = [None] * length
+        self._time_accepted = [None] * length
         self._chunksize = chunksize
         if chunksize <= 0:
             self._number_left = 0
@@ -527,18 +680,19 @@
             if self._number_left == 0:
                 if self._callback:
                     self._callback(self._value)
-                del self._cache[self._job]
+                if self.accepted():
+                    self._cache.pop(self._job, None)
                 self._cond.acquire()
                 try:
                     self._ready = True
                     self._cond.notify()
                 finally:
                     self._cond.release()
-
         else:
             self._success = False
             self._value = result
-            del self._cache[self._job]
+            if self.accepted():
+                self._cache.pop(self._job, None)
             self._cond.acquire()
             try:
                 self._ready = True
@@ -546,6 +700,22 @@
             finally:
                 self._cond.release()
 
+    def _ack(self, i, time_accepted, pid):
+        start = i * self._chunksize
+        stop = (i + 1) * self._chunksize
+        for j in range(start, stop):
+            self._accepted[j] = True
+            self._worker_pid[j] = pid
+            self._time_accepted[j] = time_accepted
+        if self._ready:
+            self._cache.pop(self._job, None)
+
+    def accepted(self):
+        return all(self._accepted)
+
+    def worker_pids(self):
+        return filter(None, self._worker_pid)
+
 #
 # Class whose instances are returned by `Pool.imap()`
 #
@@ -650,8 +820,17 @@
     def _setup_queues(self):
         self._inqueue = Queue.Queue()
         self._outqueue = Queue.Queue()
+        self._ackqueue = Queue.Queue()
         self._quick_put = self._inqueue.put
         self._quick_get = self._outqueue.get
+        self._quick_get_ack = self._ackqueue.get
+
+        def _poll_result(timeout):
+            try:
+                return True, self._quick_get(timeout=timeout)
+            except Queue.Empty:
+                return False, None
+        self._poll_result = _poll_result
 
     @staticmethod
     def _help_stuff_finish(inqueue, task_handler, size):
diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py
index 59b3357..dd07800 100644
--- a/Lib/test/test_multiprocessing.py
+++ b/Lib/test/test_multiprocessing.py
@@ -34,6 +34,8 @@ import multiprocessing.pool
 
 from multiprocessing import util
 
+from multiprocessing.pool import WorkerLostError
+
 #
 #
 #
@@ -995,6 +997,7 @@ class _TestContainers(BaseTestCase):
 def sqr(x, wait=0.0):
     time.sleep(wait)
     return x*x
+
 class _TestPool(BaseTestCase):
 
     def test_apply(self):
@@ -1020,6 +1023,24 @@
         self.assertEqual(get(), 49)
         self.assertTimingAlmostEqual(get.elapsed, TIMEOUT1)
 
+    def test_async_accept_callback(self):
+        if self.TYPE == 'manager':
+            return
+        scratchpad = [False]
+        def accept_callback():
+            scratchpad[0] = True
+
+        res = self.pool.apply_async(sqr, (7, TIMEOUT1),
+                                    accept_callback=accept_callback)
+        get = TimingWrapper(res.get)
+        self.assertEqual(get(), 49)
+        self.assertTimingAlmostEqual(get.elapsed, TIMEOUT1)
+        self.assertTrue(scratchpad[0])
+        self.assertTrue(res._worker_pid)
+        self.assertTrue(res._time_accepted)
+        self.assertTrue(res.accepted())
+        self.assertTrue(res.worker_pids())
+
     def test_async_timeout(self):
         res = self.pool.apply_async(sqr, (6, TIMEOUT2 + 0.2))
         get = TimingWrapper(res.get)
@@ -1071,9 +1092,63 @@
         join()
         self.assertTrue(join.elapsed < 0.2)
 
-class _TestPoolWorkerLifetime(BaseTestCase):
+def terminates_by_signal(signum):
+    os.kill(os.getpid(), signum)
+
+def terminates_by_SystemExit():
+    raise SystemExit
+
+def sends_SIGKILL_sometimes(i):
+    if not i % 5:
+        os.kill(os.getpid(), signal.SIGKILL)
+
+class _TestPoolSupervisor(BaseTestCase):
+    ALLOWED_TYPES = ('processes', )
+
+    def test_job_killed_by_signal(self):
+        p = multiprocessing.Pool(3)
+        results = [p.apply_async(terminates_by_signal, (signal.SIGKILL, ))
+                   for i in xrange(20)]
+
+        res = p.apply_async(sqr, (7, 0.0))
+        self.assertEqual(res.get(), 49,
+                         'supervisor did not restart crashed workers')
+
+        for result in results:
+            with self.assertRaises(WorkerLostError):
+                result.get()
+
+        p.close()
+        p.join()
+
+    def test_map_killed_by_signal(self):
+        p = multiprocessing.Pool(3)
+        res = p.map_async(sends_SIGKILL_sometimes, xrange(12))
+        with self.assertRaises(WorkerLostError):
+            res.get()
+        p.close()
+        p.join()
+
+    def test_job_raising_SystemExit(self):
+        p = multiprocessing.Pool(3)
+        results = [p.apply_async(terminates_by_SystemExit)
+                   for i in xrange(20)]
+        for result in results:
+            with self.assertRaises(WorkerLostError):
+                result.get()
+
+        res = p.apply_async(sqr, (7, 0.0))
+        self.assertEqual(res.get(), 49,
+                         'supervisor did not restart crashed workers')
+
+        p.close()
+        p.join()
+
+class _TestPoolWorkerLifetime(BaseTestCase):
 
     ALLOWED_TYPES = ('processes', )
+
     def test_pool_worker_lifetime(self):
         p = multiprocessing.Pool(3, maxtasksperchild=10)
         self.assertEqual(3, len(p._pool))
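
Usage sketch (not part of the patch itself): the snippet below shows how the
new accept_callback argument and WorkerLostError are meant to fit together,
assuming the patched pool.py above is in use as multiprocessing.pool. The job
function double() and the callback name on_accept are made up for
illustration.

    from multiprocessing import Pool
    from multiprocessing.pool import WorkerLostError

    def double(x):
        return x * 2

    def on_accept():
        # Runs in the pool's ack handler thread as soon as a worker
        # process has reserved the job (see ApplyResult._ack above).
        print 'job accepted by a worker'

    if __name__ == '__main__':
        pool = Pool(processes=2)
        res = pool.apply_async(double, (21, ), accept_callback=on_accept)
        try:
            print res.get()     # 42
        except WorkerLostError:
            # With this patch, a job whose worker dies prematurely (e.g.
            # killed by SIGKILL) fails with WorkerLostError instead of
            # blocking get() forever, and the supervisor replaces the
            # dead worker process.
            pass
        pool.close()
        pool.join()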