Copyright (C) 2010 Ksplice, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py index 8555da9..0148639 100644 --- a/Lib/multiprocessing/pool.py +++ b/Lib/multiprocessing/pool.py @@ -16,6 +16,8 @@ import threading import Queue import itertools import collections +import os +import sys import time from multiprocessing import Process, cpu_count, TimeoutError @@ -29,6 +31,9 @@ RUN = 0 CLOSE = 1 TERMINATE = 2 +ACK = 'ACK' +NACK = 'NACK' + # # Miscellaneous # @@ -42,7 +47,24 @@ def mapstar(args): # Code run by worker processes # -def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None): +def worker(inqueue, outqueue, report_task, get_ack, initializer=None, initargs=(), maxtasks=None, pid=None): + # Worker protocol: + # - Task handler puts task into inqueue + # - Worker removes task from queue, puts report into report_task_queue + # - Task handler reads pending reports and ACKs the first worker to claim + # the current task, NACKing all other worker claims. + # - If the task handler waits a second without receiving any reports, + # it reschedules the current task. + # - Once the worker receives an ACK, it proceeds. If it receives a NACK + # it drops the task it's holding and begins the cycle anew. + # + # If the worker exits with a nonzero exit status, then all tasks + # currently assigned to the worker are treated as if they had raised + # a RuntimeError. 
+ # + # This protocol ensures that each task is begun exactly once, and + # only after the task handler has recorded which worker has the + # task. assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0) put = outqueue.put get = inqueue.get @@ -52,6 +74,8 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None): if initializer is not None: initializer(*initargs) + if pid is None: + pid = os.getpid() completed = 0 while maxtasks is None or (maxtasks and completed < maxtasks): @@ -59,21 +83,43 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None): task = get() except (EOFError, IOError): debug('worker got EOFError or IOError -- exiting') - break + sys.exit(1) if task is None: debug('worker got sentinel -- exiting') break job, i, func, args, kwds = task + report_task((pid, job, i)) + try: + acked_job, acked_i, response = get_ack() + except Exception, e: + debug('could not retrieve ACK from parent -- %s' % e) + sys.exit(1) + + assert job == acked_job + assert i == acked_i + + if response == NACK: + continue + + assert response == ACK try: result = (True, func(*args, **kwds)) except Exception, e: result = (False, e) - put((job, i, result)) + put((pid, job, i, result)) completed += 1 debug('worker exiting after %d tasks' % completed) + +class LockableDict(dict): + ''' + A dictionary class with its own lock. 
+ ''' + def __init__(self): + self.lock = threading.Lock() + # # Class representing a process pool # @@ -103,6 +149,11 @@ class Pool(object): if initializer is not None and not hasattr(initializer, '__call__'): raise TypeError('initializer must be a callable') + self._putmsg_by_pid = LockableDict() + self._task_assignments = LockableDict() + + self._report_task_queue = self._queue_factory() + self._processes = processes self._pool = [] self._repopulate_pool() @@ -115,10 +166,22 @@ class Pool(object): self._worker_handler._state = RUN self._worker_handler.start() + def get_reported_tasks(target_job, target_i): + """Page through the queue of reported tasks until one with + the desired job and i values is found.""" + reported_tasks = [] + while self._report_task_queue._reader.poll(1): + report = self._report_task_queue.get() + reported_tasks.append(report) + pid, job, i = report + if job == target_job and i == target_i: + break + return reported_tasks self._task_handler = threading.Thread( target=Pool._handle_tasks, - args=(self._taskqueue, self._quick_put, self._outqueue, self._pool) + args=(self._taskqueue, self._quick_put, self._outqueue, self._pool, + get_reported_tasks, self._putmsg_by_pid, self._task_assignments) ) self._task_handler.daemon = True self._task_handler._state = RUN @@ -126,7 +189,7 @@ class Pool(object): self._result_handler = threading.Thread( target=Pool._handle_results, - args=(self._outqueue, self._quick_get, self._cache) + args=(self._outqueue, self._quick_get, self._cache, self._task_assignments) ) self._result_handler.daemon = True self._result_handler._state = RUN @@ -153,6 +216,25 @@ class Pool(object): worker.join() cleaned = True del self._pool[i] + with self._putmsg_by_pid.lock: + del self._putmsg_by_pid[worker.pid] + if not worker.exitcode: + continue + + if worker.exitcode > 0: + msg = 'worker %d exited with code %d' % \ + (worker.pid, worker.exitcode) + else: + msg = 'worker %d killed with signal %d' % \ + (worker.pid, 
-worker.exitcode) + with self._task_assignments.lock: + # If it has no assigned task, then there's nothing to worry about. + tasks = self._task_assignments.get(worker.pid) + for (job, i) in tasks: + debug("inserting error as result for worker %s's work on job %s, index %s" % + (worker.pid, job, i)) + result = (False, RuntimeError(msg)) + self._outqueue.put((worker.pid, job, i, result)) return cleaned def _repopulate_pool(self): @@ -160,15 +242,23 @@ class Pool(object): for use after reaping workers which have exited. """ for i in range(self._processes - len(self._pool)): + q = self._queue_factory() w = self.Process(target=worker, args=(self._inqueue, self._outqueue, + self._report_task_queue.put, q.get, self._initializer, self._initargs, self._maxtasksperchild) ) self._pool.append(w) w.name = w.name.replace('Process', 'PoolWorker') w.daemon = True - w.start() + # We don't know the pid before starting w, but once we + # start w another thread might have to look it up in + # self._putmsg_by_pid. So we take out its lock. 
+ with self._putmsg_by_pid.lock: + w.start() + assert w.pid not in self._putmsg_by_pid + self._putmsg_by_pid[w.pid] = q.put debug('added worker') def _maintain_pool(self): @@ -183,6 +273,7 @@ class Pool(object): self._outqueue = SimpleQueue() self._quick_put = self._inqueue._writer.send self._quick_get = self._outqueue._reader.recv + self._queue_factory = SimpleQueue def apply(self, func, args=(), kwds={}): ''' @@ -272,19 +363,70 @@ class Pool(object): debug('worker handler exiting') @staticmethod - def _handle_tasks(taskqueue, put, outqueue, pool): + def _send_task(put, get_reported_tasks, putmsg_by_pid, task_assignments, task): thread = threading.current_thread() + while True: + if thread._state: + debug('task handler found thread._state != RUN') + return False + + try: + put(task) + except IOError: + debug('could not put task on queue') + return False + + if task is None: + return False + + job, i, func, args, kwds = task + assigned = False + try: + for pid, acked_job, acked_i in get_reported_tasks(job, i): + # Hold the putmsg_by_pid lock throughout the body + # of the loop, just in case the worker dies while + # we are in the middle. + with putmsg_by_pid.lock: + try: + putmsg = putmsg_by_pid[pid] + except KeyError: + # If pid not in putmsg_by_pid, then the worker + # with the given pid has died. Move along. + debug("Couldn't find putmsg method for worker with pid %d" % pid) + continue + + if job == acked_job and i == acked_i and not assigned: + msg = ACK + else: + msg = NACK + + with task_assignments.lock: + try: + putmsg((acked_job, acked_i, msg)) + except IOError: + # Couldn't send an ACK? That's fine. 
+ debug("Couldn't send method for worker with pid %d" % pid) + continue + else: + if msg == ACK: + assigned = True + task_assignments.setdefault(pid, set()).add((acked_job, acked_i)) + except Exception, e: + debug('could not retrieve ack from queue: %s' % e) + return False + else: + if assigned: + return True + else: + continue + + @staticmethod + def _handle_tasks(taskqueue, put, outqueue, pool, get_reported_tasks, putmsg_by_pid, task_assignments): for taskseq, set_length in iter(taskqueue.get, None): i = -1 for i, task in enumerate(taskseq): - if thread._state: - debug('task handler found thread._state != RUN') - break - try: - put(task) - except IOError: - debug('could not put task on queue') + if not Pool._send_task(put, get_reported_tasks, putmsg_by_pid, task_assignments, task): break else: if set_length: @@ -311,7 +453,7 @@ class Pool(object): debug('task handler exiting') @staticmethod - def _handle_results(outqueue, get, cache): + def _handle_results(outqueue, get, cache, task_assignments): thread = threading.current_thread() while 1: @@ -330,7 +472,19 @@ class Pool(object): debug('result handler got sentinel') break - job, i, obj = task + pid, job, i, obj = task + with task_assignments.lock: + # If the process died right after trying to send a + # result, the worker handler may have inserted a + # result for it, yielding 2 results for the same task. 
+ try: + assignments = task_assignments[pid] + assignments.remove((job, i)) + except KeyError: + continue + if not assignments: + del task_assignments[pid] + try: cache[job]._set(i, obj) except KeyError: @@ -346,7 +500,17 @@ class Pool(object): if task is None: debug('result handler ignoring extra sentinel') continue - job, i, obj = task + + pid, job, i, obj = task + with task_assignments.lock: + assignments = task_assignments[pid] + try: + assignments.remove((job, i)) + except KeyError: + continue + if not assignments: + del task_assignments[pid] + try: cache[job]._set(i, obj) except KeyError: @@ -645,6 +809,7 @@ class ThreadPool(Pool): from .dummy import Process def __init__(self, processes=None, initializer=None, initargs=()): + raise NotImplementedError('Have not yet added support for threaded pool') Pool.__init__(self, processes, initializer, initargs) def _setup_queues(self): diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py index 59b3357..1aa7976 100644 --- a/Lib/test/test_multiprocessing.py +++ b/Lib/test/test_multiprocessing.py @@ -995,6 +995,14 @@ class _TestContainers(BaseTestCase): def sqr(x, wait=0.0): time.sleep(wait) return x*x + +def do_raise(e): + raise e + +def kill_self(sig): + os.kill(os.getpid(), sig) + + class _TestPool(BaseTestCase): def test_apply(self): @@ -1071,6 +1079,27 @@ class _TestPool(BaseTestCase): join() self.assertTrue(join.elapsed < 0.2) + def test_exceptions(self): + global un_unpickleable_across_processes + try: + del un_unpickleable_across_processes + except NameError: + pass + p = multiprocessing.Pool(3) + # This is defined after forking occurs, so the worker + # processes can't unpickle it, although worker threads can. 
+ def un_unpickleable_across_processes(x): + return x + self.assertRaises(RuntimeError, p.map, kill_self, [signal.SIGKILL]) + self.assertRaises(RuntimeError, p.map, do_raise, [SystemExit(1)]) + try: + p.map(un_unpickleable_across_processes, [1]) + except Exception, e: + self.assertIsInstance(e, AttributeError) + p.close() + p.join() + + class _TestPoolWorkerLifetime(BaseTestCase): ALLOWED_TYPES = ('processes', ) @@ -2022,14 +2051,14 @@ def test_main(run=None): multiprocessing.get_logger().setLevel(LOG_LEVEL) ProcessesMixin.pool = multiprocessing.Pool(4) - ThreadsMixin.pool = multiprocessing.dummy.Pool(4) + # ThreadsMixin.pool = multiprocessing.dummy.Pool(4) ManagerMixin.manager.__init__() ManagerMixin.manager.start() ManagerMixin.pool = ManagerMixin.manager.Pool(4) testcases = ( sorted(testcases_processes.values(), key=lambda tc:tc.__name__) + - sorted(testcases_threads.values(), key=lambda tc:tc.__name__) + + # sorted(testcases_threads.values(), key=lambda tc:tc.__name__) + sorted(testcases_manager.values(), key=lambda tc:tc.__name__) + testcases_other ) @@ -2046,12 +2075,13 @@ def test_main(run=None): quiet=True): run(suite) - ThreadsMixin.pool.terminate() + # ThreadsMixin.pool.terminate() ProcessesMixin.pool.terminate() ManagerMixin.pool.terminate() ManagerMixin.manager.shutdown() - del ProcessesMixin.pool, ThreadsMixin.pool, ManagerMixin.pool + # del ProcessesMixin.pool, ThreadsMixin.pool, ManagerMixin.pool + del ProcessesMixin.pool, ManagerMixin.pool def main(): test_main(unittest.TextTestRunner(verbosity=2).run)