diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py index 8555da9..82859ff 100644 --- a/Lib/multiprocessing/pool.py +++ b/Lib/multiprocessing/pool.py @@ -115,6 +115,7 @@ class Pool(object): self._worker_handler._state = RUN self._worker_handler.start() + self._putlock = threading.BoundedSemaphore(self._processes) self._task_handler = threading.Thread( target=Pool._handle_tasks, @@ -126,7 +127,7 @@ class Pool(object): self._result_handler = threading.Thread( target=Pool._handle_results, - args=(self._outqueue, self._quick_get, self._cache) + args=(self._outqueue, self._quick_get, self._cache, self._putlock) ) self._result_handler.daemon = True self._result_handler._state = RUN @@ -150,6 +151,10 @@ class Pool(object): if worker.exitcode is not None: # worker exited debug('cleaning up worker %d' % i) + try: + self._putlock.release() + except ValueError: + pass worker.join() cleaned = True del self._pool[i] @@ -234,12 +239,22 @@ class Pool(object): for i, x in enumerate(task_batches)), result._set_length)) return (item for chunk in result for item in chunk) - def apply_async(self, func, args=(), kwds={}, callback=None): + def apply_async(self, func, args=(), kwds={}, callback=None, + waitforslot=False): ''' Asynchronous equivalent of `apply()` builtin + + :keyword callback: called when the target return value + is ready. Callback must accept a single positional argument, + which is the return value of `func`. + :keyword waitforslot: If ``True``, this function will not return + until there is a worker available to process the job. + ''' assert self._state == RUN result = ApplyResult(self._cache, callback) + if waitforslot: + self._putlock.acquire() self._taskqueue.put(([(result._job, None, func, args, kwds)], None)) return result @@ -311,7 +326,7 @@ class Pool(object): debug('task handler exiting') @staticmethod - def _handle_results(outqueue, get, cache): + def _handle_results(outqueue, get, cache, putlock): thread = threading.current_thread() while 1: @@ -321,6 +336,11 @@ class Pool(object): debug('result handler got EOFError/IOError -- exiting') return + try: + putlock.release() + except ValueError: + pass + if thread._state: assert thread._state == TERMINATE debug('result handler found thread._state=TERMINATE') @@ -336,6 +356,13 @@ class Pool(object): except KeyError: pass + # Release the semaphore. + while True: + try: + putlock.release() + except ValueError: + break + while cache and thread._state != TERMINATE: try: task = get() diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py index 59b3357..7d53d29 100644 --- a/Lib/test/test_multiprocessing.py +++ b/Lib/test/test_multiprocessing.py @@ -1071,9 +1071,25 @@ class _TestPool(BaseTestCase): join() self.assertTrue(join.elapsed < 0.2) -class _TestPoolWorkerLifetime(BaseTestCase): +class _TestPoolWorkerSemaphore(BaseTestCase): ALLOWED_TYPES = ('processes', ) + + def test_async_waitforslot(self, interval=0.4, processes=3): + p = multiprocessing.Pool(processes) + results = [] + for i in xrange(processes + 1): + w = TimingWrapper(p.apply_async) + res = w(time.sleep, (interval, ), waitforslot=True) + results.append((w, res)) + for wrapper, res in results[:-1]: + self.assertLess(wrapper.elapsed, 0.1) + last_wrapper, last_res = results[-1] + self.assertGreater(last_wrapper.elapsed, interval - 0.1) + +class _TestPoolWorkerLifetime(BaseTestCase): + ALLOWED_TYPES = ('processes', ) + def test_pool_worker_lifetime(self): p = multiprocessing.Pool(3, maxtasksperchild=10) self.assertEqual(3, len(p._pool))