Copyright (C) 2010 Ksplice, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Index: Lib/multiprocessing/pool.py
===================================================================
--- Lib/multiprocessing/pool.py	(revision 82645)
+++ Lib/multiprocessing/pool.py	(working copy)
@@ -57,9 +57,14 @@
     while maxtasks is None or (maxtasks and completed < maxtasks):
         try:
             task = get()
-        except (EOFError, IOError):
+        except (EOFError, IOError), e:
             debug('worker got EOFError or IOError -- exiting')
+            put((None, None, e))
             break
+        except BaseException, e:
+            debug('worker got exception %s' % e)
+            put((None, None, e))
+            break
 
         if task is None:
             debug('worker got sentinel -- exiting')
@@ -93,6 +98,7 @@
         self._maxtasksperchild = maxtasksperchild
         self._initializer = initializer
         self._initargs = initargs
+        self._termination_requested = False
 
         if processes is None:
             try:
@@ -142,18 +148,25 @@
 
     def _join_exited_workers(self):
         """Cleanup after any worker processes which have exited due to reaching
-        their specified lifetime. Returns True if any workers were cleaned up.
+        their specified lifetime. Returns a pair (cleaned, abnormal) where:
+        - cleaned is True if any workers were cleaned up and False otherwise
+        - abnormal is a list of (pid, exitcode) tuples for any workers that
+          exited unexpectedly with a nonzero exitcode.
        """
         cleaned = False
+        abnormal = []
         for i in reversed(range(len(self._pool))):
             worker = self._pool[i]
-            if worker.exitcode is not None:
+            exitcode = worker.exitcode
+            if exitcode is not None:
+                cleaned = True
+                if exitcode != 0 and not worker._termination_requested:
+                    abnormal.append((worker.pid, exitcode))
                 # worker exited
                 debug('cleaning up worker %d' % i)
                 worker.join()
-                cleaned = True
                 del self._pool[i]
-        return cleaned
+        return cleaned, abnormal
 
     def _repopulate_pool(self):
         """Bring the number of pool processes up to the specified number,
@@ -171,10 +184,25 @@
             w.start()
             debug('added worker')
 
-    def _maintain_pool(self):
-        """Clean up any exited workers and start replacements for them.
+    def _maintain_pool(self, repopulate=True):
+        """Clean up any exited workers. If repopulate, also start replacements for them.
""" - if self._join_exited_workers(): + cleaned, abnormal = self._join_exited_workers() + if abnormal: + # Other threads may be altering self._cache, so don't + # use itervalues + for application in self._cache.values(): + for pid, exitcode in abnormal: + msg = [] + if exitcode > 0: + msg.append('worker %d exited with code %d' % + (pid, exitcode)) + else: + msg.append('worker %d killed with signal %d' % + (pid, -exitcode)) + self._termination_requested = True + application._set_error(RuntimeError('Something went wrong: %s' % ', '.join(msg))) + if cleaned and repopulate: self._repopulate_pool() def _setup_queues(self): @@ -182,8 +210,17 @@ self._inqueue = SimpleQueue() self._outqueue = SimpleQueue() self._quick_put = self._inqueue._writer.send - self._quick_get = self._outqueue._reader.recv + recv = self._outqueue._reader.recv + poll = self._outqueue._reader.poll + def _quick_get(): + while True: + if poll(1): + return recv() + if self._termination_requested: + raise EOFError + self._quick_get = _quick_get + def apply(self, func, args=(), kwds={}): ''' Equivalent of `apply()` builtin @@ -269,6 +306,10 @@ while pool._worker_handler._state == RUN and pool._state == RUN: pool._maintain_pool() time.sleep(0.1) + # Do one last run to ensure that _termination_requested is set + # if necessary. Don't repopulate because of a race condition + # in Pool.close. + pool._maintain_pool(repopulate=False) debug('worker handler exiting') @staticmethod @@ -313,6 +354,7 @@ @staticmethod def _handle_results(outqueue, get, cache): thread = threading.current_thread() + failed = False while 1: try: @@ -332,11 +374,18 @@ job, i, obj = task try: - cache[job]._set(i, obj) + if job is None: + debug('received a job of None with an exception %s' % obj) + assert isinstance(obj, BaseException) + failed = True + for application in cache.values(): + application._set_error(obj) + else: + cache[job]._set(i, obj) except KeyError: pass - while cache and thread._state != TERMINATE: + while cache and thread._state != TERMINATE and not failed: try: task = get() except (IOError, EOFError): @@ -348,7 +397,14 @@ continue job, i, obj = task try: - cache[job]._set(i, obj) + if job is None: + debug('received a job of None with an exception %s' % obj) + assert isinstance(obj, BaseException) + failed = True + for application in cache.values(): + application._set_error(obj) + else: + cache[job]._set(i, obj) except KeyError: pass @@ -400,15 +456,32 @@ assert self._state in (CLOSE, TERMINATE) self._worker_handler.join() self._task_handler.join() + self._outqueue.put(None) self._result_handler.join() for p in self._pool: - p.join() + # If termination has been requested, the pool is in a + # broken state and we really shouldn't block on anything. + if not self._termination_requested: + p.join() + elif hasattr(p, 'terminate'): + p.terminate() + else: + p.join(1) @staticmethod def _help_stuff_finish(inqueue, task_handler, size): - # task_handler may be blocked trying to put items on inqueue + # task_handler may be blocked trying to put items on inqueue. + # + # We should try to acquire the queue read lock so that reads + # from the pipe remain atomic. + # + # However, the lock may not be available, if for example a + # worker acquired the lock and died without releasing it. If + # we haven't obtained it after a second, we're probably not + # getting it at all. 
         debug('removing tasks from inqueue until task handler finished')
-        inqueue._rlock.acquire()
+        if not inqueue._rlock.acquire(timeout=1):
+            debug('could not acquire inqueue read lock; carrying on anyway')
         while task_handler.is_alive() and inqueue._reader.poll():
             inqueue._reader.recv()
             time.sleep(0)
@@ -426,8 +499,6 @@
         debug('helping task handler/workers to finish')
         cls._help_stuff_finish(inqueue, task_handler, len(pool))
 
-        assert result_handler.is_alive() or len(cache) == 0
-
         result_handler._state = TERMINATE
         outqueue.put(None)                  # sentinel
 
@@ -450,7 +521,7 @@
             if p.is_alive():
                 # worker has not yet exited
                 debug('cleaning up worker %d' % p.pid)
-                p.join()
+                p.join(1)
 #
 # Class whose instances are returned by `Pool.apply_async()`
 #
@@ -464,6 +535,7 @@
         self._cache = cache
         self._ready = False
         self._callback = callback
+        self._error = None
         cache[self._job] = self
 
     def ready(self):
@@ -476,8 +548,12 @@
     def wait(self, timeout=None):
         self._cond.acquire()
         try:
+            if self._error:
+                raise self._error
             if not self._ready:
                 self._cond.wait(timeout)
+                if self._error:
+                    raise self._error
         finally:
             self._cond.release()
 
@@ -502,6 +578,14 @@
             self._cond.release()
         del self._cache[self._job]
 
+    def _set_error(self, exception):
+        self._cond.acquire()
+        try:
+            self._error = exception
+            self._cond.notify()
+        finally:
+            self._cond.release()
+
 #
 # Class whose instances are returned by `Pool.map_async()`
 #
@@ -651,7 +735,16 @@
         self._inqueue = Queue.Queue()
         self._outqueue = Queue.Queue()
         self._quick_put = self._inqueue.put
-        self._quick_get = self._outqueue.get
+        get = self._outqueue.get
+        def _quick_get():
+            while True:
+                try:
+                    return get(timeout=1)
+                except Queue.Empty:
+                    pass
+                if self._termination_requested:
+                    raise EOFError
+        self._quick_get = _quick_get
 
     @staticmethod
     def _help_stuff_finish(inqueue, task_handler, size):
Index: Lib/multiprocessing/process.py
===================================================================
--- Lib/multiprocessing/process.py	(revision 82645)
+++ Lib/multiprocessing/process.py	(working copy)
@@ -79,6 +79,7 @@
         self._kwargs = dict(kwargs)
         self._name = name or type(self).__name__ + '-' + \
                      ':'.join(str(i) for i in self._identity)
+        self._termination_requested = False
 
     def run(self):
         '''
Index: Lib/multiprocessing/forking.py
===================================================================
--- Lib/multiprocessing/forking.py	(revision 82645)
+++ Lib/multiprocessing/forking.py	(working copy)
@@ -135,6 +135,7 @@
 
     def terminate(self):
         if self.returncode is None:
+            self._termination_requested = True
             try:
                 os.kill(self.pid, signal.SIGTERM)
             except OSError, e:
Index: Lib/test/test_multiprocessing.py
===================================================================
--- Lib/test/test_multiprocessing.py	(revision 82645)
+++ Lib/test/test_multiprocessing.py	(working copy)
@@ -995,6 +995,14 @@
 def sqr(x, wait=0.0):
     time.sleep(wait)
     return x*x
+
+def do_raise(e):
+    raise e
+
+def kill_self(sig):
+    os.kill(os.getpid(), sig)
+
+
 class _TestPool(BaseTestCase):
 
     def test_apply(self):
@@ -1071,6 +1079,27 @@
             join()
         self.assertTrue(join.elapsed < 0.2)
 
+    def test_exceptions(self):
+        global un_unpickleable_across_processes
+        try:
+            del un_unpickleable_across_processes
+        except NameError:
+            pass
+        p = multiprocessing.Pool(3)
+        # This is defined after forking occurs, so the worker
+        # processes can't unpickle it, although worker threads can.
+        def un_unpickleable_across_processes(x):
+            return x
+        self.assertRaises(RuntimeError, p.map, kill_self, [signal.SIGKILL])
+        self.assertRaises(RuntimeError, p.map, do_raise, [SystemExit(1)])
+        # The worker fails to unpickle the function, so the resulting
+        # AttributeError should propagate back to the caller.
+        self.assertRaises(AttributeError, p.map,
+                          un_unpickleable_across_processes, [1])
+        p.close()
+        p.join()
+
+
 class _TestPoolWorkerLifetime(BaseTestCase):
     ALLOWED_TYPES = ('processes', )
 
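
For reference, here is a minimal driver (not part of the patch) that exercises
the failure mode this change addresses. Without the patch, Pool.map() blocks
forever once a worker dies; with it, the error installed via _set_error
propagates to the caller. This is an illustrative sketch only: the function
name "die" and the printed message are not taken from the patch itself.

    import os
    import signal
    import multiprocessing

    def die(arg):
        # Simulate a worker being killed out from under the pool,
        # e.g. by the OOM killer.
        os.kill(os.getpid(), signal.SIGKILL)

    if __name__ == '__main__':
        pool = multiprocessing.Pool(processes=2)
        try:
            # Unpatched: hangs here. Patched: raises RuntimeError naming
            # the worker pid and the signal that killed it.
            pool.map(die, [1, 2])
        except RuntimeError, e:
            print 'pool failure detected: %s' % e
        finally:
            pool.terminate()
            pool.join()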