multiprocessing.pool: report jobs whose worker process has gone away.

Summary of the changes carried by this patch:

  * pool.py gains a WorkerLostError exception.
  * worker() catches BaseException raised by a job (e.g. SystemExit),
    sends a WorkerLostError result for that job, then re-raises so the
    process exits.
  * Pool._handle_results takes two extra callables, poll and
    workers_gone.  Its "while cache" loop polls with a 0.2 second
    timeout instead of blocking; when nothing is ready and
    workers_gone() is true, every job still in the cache is failed with
    WorkerLostError and the loop ends.
  * Pool.close() stops and joins the worker handler thread before
    setting the pool state to CLOSE.
  * Pool and ThreadPool each install a _poll_result helper next to
    _quick_get.
  * test_multiprocessing.py gains _TestPoolSupervisor, covering jobs
    killed by a signal, jobs raising SystemExit, and a function that
    cannot be unpickled in the worker processes.

diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py
index 8555da9..c792b63 100644
--- a/Lib/multiprocessing/pool.py
+++ b/Lib/multiprocessing/pool.py
@@ -38,6 +38,9 @@ job_counter = itertools.count()
 def mapstar(args):
     return map(*args)
 
+class WorkerLostError(Exception):
+    """The worker processing a job has exited prematurely."""
+
 #
 # Code run by worker processes
 #
@@ -70,6 +73,14 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
             result = (True, func(*args, **kwds))
         except Exception, e:
             result = (False, e)
+        except BaseException, e:
+            # Job raised SystemExit or equivalent, so tell the result
+            # handler and exit the process so it can be replaced.
+            err = WorkerLostError(
+                "Worker has terminated by user request: %r" % (e, ))
+            put((job, i, (False, err)))
+            raise
+
         put((job, i, result))
         completed += 1
     debug('worker exiting after %d tasks' % completed)
@@ -126,7 +137,8 @@ class Pool(object):
 
         self._result_handler = threading.Thread(
             target=Pool._handle_results,
-            args=(self._outqueue, self._quick_get, self._cache)
+            args=(self._outqueue, self._quick_get, self._cache,
+                  self._poll_result, self._workers_gone)
             )
         self._result_handler.daemon = True
         self._result_handler._state = RUN
@@ -140,6 +152,10 @@ class Pool(object):
             exitpriority=15
             )
 
+    def _workers_gone(self):
+        self._join_exited_workers()
+        return not self._pool
+
     def _join_exited_workers(self):
         """Cleanup after any worker processes which have exited due to reaching
         their specified lifetime.  Returns True if any workers were cleaned up.
@@ -184,6 +200,13 @@ class Pool(object):
         self._quick_put = self._inqueue._writer.send
         self._quick_get = self._outqueue._reader.recv
 
+        def _poll_result(timeout):
+            if self._outqueue._reader.poll(timeout):
+                return True, self._quick_get()
+            return False, None
+
+        self._poll_result = _poll_result
+
     def apply(self, func, args=(), kwds={}):
         '''
         Equivalent of `apply()` builtin
@@ -311,7 +334,7 @@ class Pool(object):
         debug('task handler exiting')
 
     @staticmethod
-    def _handle_results(outqueue, get, cache):
+    def _handle_results(outqueue, get, cache, poll, workers_gone):
         thread = threading.current_thread()
 
         while 1:
@@ -338,19 +361,30 @@ class Pool(object):
 
         while cache and thread._state != TERMINATE:
             try:
-                task = get()
+                ready, task = poll(0.2)
             except (IOError, EOFError):
                 debug('result handler got EOFError/IOError -- exiting')
                 return
 
-            if task is None:
-                debug('result handler ignoring extra sentinel')
-                continue
-            job, i, obj = task
-            try:
-                cache[job]._set(i, obj)
-            except KeyError:
-                pass
+            if ready:
+                if task is None:
+                    debug('result handler ignoring extra sentinel')
+                    continue
+
+                job, i, obj = task
+                try:
+                    cache[job]._set(i, obj)
+                except KeyError:
+                    pass
+            else:
+                if workers_gone():
+                    debug("%s active job(s), but no active workers! "
+                          "Terminating..." % (len(cache), ))
+                    err = WorkerLostError(
+                        "The worker processing this job has terminated.")
+                    for job in cache.values():
+                        job._set(None, (False, err))
+                    break
 
         if hasattr(outqueue, '_reader'):
             debug('ensuring that outqueue is not full')
@@ -385,8 +419,11 @@ class Pool(object):
     def close(self):
         debug('closing pool')
         if self._state == RUN:
-            self._state = CLOSE
+            # Worker handler can't run while the result handler
+            # does its second pass, so wait for it to finish.
             self._worker_handler._state = CLOSE
+            self._worker_handler.join()
+            self._state = CLOSE
             self._taskqueue.put(None)
 
     def terminate(self):
@@ -653,6 +690,14 @@ class ThreadPool(Pool):
         self._quick_put = self._inqueue.put
         self._quick_get = self._outqueue.get
 
+        def _poll_result(timeout):
+            try:
+                return True, self._quick_get(timeout=timeout)
+            except Queue.Empty:
+                return False, None
+
+        self._poll_result = _poll_result
+
     @staticmethod
     def _help_stuff_finish(inqueue, task_handler, size):
         # put sentinels at head of inqueue to make workers finish
diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py
index 59b3357..4ef2a6c 100644
--- a/Lib/test/test_multiprocessing.py
+++ b/Lib/test/test_multiprocessing.py
@@ -34,6 +34,8 @@ import multiprocessing.pool
 
 from multiprocessing import util
 
+from multiprocessing.pool import WorkerLostError
+
 #
 #
 #
@@ -995,6 +997,7 @@ class _TestContainers(BaseTestCase):
 def sqr(x, wait=0.0):
     time.sleep(wait)
     return x*x
+
 class _TestPool(BaseTestCase):
 
     def test_apply(self):
@@ -1071,6 +1074,77 @@ class _TestPool(BaseTestCase):
         join()
         self.assertTrue(join.elapsed < 0.2)
 
+def terminates_by_signal(signum):
+    os.kill(os.getpid(), signum)
+
+def terminates_by_SystemExit():
+    raise SystemExit
+
+class _TestPoolSupervisor(BaseTestCase):
+    ALLOWED_TYPES = ('processes', )
+
+    def test_job_killed_by_signal(self):
+        p = multiprocessing.Pool(3)
+        results = [p.apply_async(terminates_by_signal, (signal.SIGKILL, ))
+                        for i in xrange(20)]
+
+        res = p.apply_async(sqr, (7, 0.0))
+        self.assertEqual(res.get(), 49,
+                         'supervisor did not restart crashed workers')
+
+        # That the jobs have crashed won't be noticed until
+        # the result handler is done, so join the pool.
+        p.close()
+        p.join()
+
+        for result in results:
+            with self.assertRaises(WorkerLostError):
+                result.get()
+
+    def test_job_raising_SystemExit(self):
+        p = multiprocessing.Pool(3)
+        results = [p.apply_async(terminates_by_SystemExit)
+                        for i in xrange(20)]
+        for result in results:
+            with self.assertRaises(WorkerLostError):
+                result.get()
+
+        res = p.apply_async(sqr, (7, 0.0))
+        self.assertEqual(res.get(), 49,
+                         'supervisor did not restart crashed workers')
+
+        p.close()
+        p.join()
+
+    def test_unpickleable_across_processes(self, processes=3):
+        global unpickleable_across_processes
+        try:
+            del(unpickleable_across_processes)
+        except NameError:
+            pass
+        p = multiprocessing.Pool(processes)
+        def unpickleable_across_processes(x):
+            return x
+
+        results = []
+        for i in xrange(processes):
+            try:
+                results.append(p.apply_async(
+                    unpickleable_across_processes, [1]))
+            except AttributeError:
+                pass
+
+        res = p.apply_async(sqr, (7, 0.0))
+        self.assertEqual(res.get(), 49,
+                         'supervisor did not restart crashed workers')
+
+        p.close()
+        p.join()
+
+        for result in results:
+            with self.assertRaises(WorkerLostError):
+                result.get()
+
 class _TestPoolWorkerLifetime(BaseTestCase):
 
     ALLOWED_TYPES = ('processes', )