diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py
index 8555da9..f40cfc8 100644
--- a/Lib/multiprocessing/pool.py
+++ b/Lib/multiprocessing/pool.py
@@ -42,6 +42,23 @@ def mapstar(args):
 # Code run by worker processes
 #
 
+class MaybeEncodingError(Exception):
+    """Wraps possible unpickleable errors, so they can be
+    safely sent through the socket."""
+
+    def __init__(self, exc, value):
+        self.exc = str(exc)
+        self.value = repr(value)
+        super(MaybeEncodingError, self).__init__(self.exc, self.value)
+
+    def __str__(self):
+        return "Error sending result: '%s'. Reason: '%s'" % (self.value,
+                                                             self.exc)
+
+    def __repr__(self):
+        return "<MaybeEncodingError: %s>" % str(self)
+
+
 def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
     assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
     put = outqueue.put
@@ -70,7 +87,13 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
             result = (True, func(*args, **kwds))
         except Exception, e:
             result = (False, e)
-        put((job, i, result))
+        try:
+            put((job, i, result))
+        except Exception, e:
+            wrapped = MaybeEncodingError(e, result[1])
+            debug("Possible encoding error while sending result: %s" % (
+                wrapped))
+            put((job, i, (False, wrapped)))
         completed += 1
     debug('worker exiting after %d tasks' % completed)
 
@@ -234,16 +257,18 @@ class Pool(object):
                      for i, x in enumerate(task_batches)), result._set_length))
             return (item for chunk in result for item in chunk)
 
-    def apply_async(self, func, args=(), kwds={}, callback=None):
+    def apply_async(self, func, args=(), kwds={}, callback=None,
+            error_callback=None):
         '''
         Asynchronous equivalent of `apply()` builtin
         '''
         assert self._state == RUN
-        result = ApplyResult(self._cache, callback)
+        result = ApplyResult(self._cache, callback, error_callback)
         self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
         return result
 
-    def map_async(self, func, iterable, chunksize=None, callback=None):
+    def map_async(self, func, iterable, chunksize=None, callback=None,
+            error_callback=None):
         '''
         Asynchronous equivalent of `map()` builtin
         '''
@@ -259,7 +284,8 @@ class Pool(object):
             chunksize = 0
 
         task_batches = Pool._get_tasks(func, iterable, chunksize)
-        result = MapResult(self._cache, chunksize, len(iterable), callback)
+        result = MapResult(self._cache, chunksize, len(iterable), callback,
+                           error_callback=error_callback)
         self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                               for i, x in enumerate(task_batches)), None))
         return result
@@ -458,12 +484,13 @@ class Pool(object):
 
 class ApplyResult(object):
 
-    def __init__(self, cache, callback):
+    def __init__(self, cache, callback, error_callback=None):
         self._cond = threading.Condition(threading.Lock())
         self._job = job_counter.next()
         self._cache = cache
         self._ready = False
         self._callback = callback
+        self._error_callback = error_callback
         cache[self._job] = self
 
     def ready(self):
@@ -494,6 +521,8 @@ class ApplyResult(object):
         self._success, self._value = obj
         if self._callback and self._success:
             self._callback(self._value)
+        if self._error_callback and not self._success:
+            self._error_callback(self._value)
         self._cond.acquire()
         try:
             self._ready = True
@@ -508,8 +537,9 @@ class ApplyResult(object):
 
 class MapResult(ApplyResult):
 
-    def __init__(self, cache, chunksize, length, callback):
-        ApplyResult.__init__(self, cache, callback)
+    def __init__(self, cache, chunksize, length, callback, error_callback):
+        ApplyResult.__init__(self, cache, callback,
+                             error_callback=error_callback)
         self._success = True
         self._value = [None] * length
         self._chunksize = chunksize
@@ -534,10 +564,11 @@ class MapResult(ApplyResult):
                     self._cond.notify()
                 finally:
                     self._cond.release()
-
         else:
             self._success = False
             self._value = result
+            if self._error_callback:
+                self._error_callback(self._value)
             del self._cache[self._job]
             self._cond.acquire()
             try:
diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py
index 59b3357..80711d1 100644
--- a/Lib/test/test_multiprocessing.py
+++ b/Lib/test/test_multiprocessing.py
@@ -995,6 +995,7 @@ class _TestContainers(BaseTestCase):
 def sqr(x, wait=0.0):
     time.sleep(wait)
     return x*x
+
 class _TestPool(BaseTestCase):
 
     def test_apply(self):
@@ -1071,9 +1072,55 @@ class _TestPool(BaseTestCase):
         join()
         self.assertTrue(join.elapsed < 0.2)
 
 
-class _TestPoolWorkerLifetime(BaseTestCase):
+def raising():
+    raise KeyError("key")
+
+def unpickleable_result():
+    return lambda: 42
+
+class _TestPoolWorkerErrors(BaseTestCase):
+    ALLOWED_TYPES = ('processes', )
+
+    def test_async_error_callback(self):
+        p = multiprocessing.Pool(2)
+
+        scratchpad = [None]
+        def errback(exc):
+            scratchpad[0] = exc
+
+        res = p.apply_async(raising, error_callback=errback)
+        self.assertRaises(KeyError, res.get)
+        self.assertTrue(scratchpad[0])
+        self.assertIsInstance(scratchpad[0], KeyError)
+
+        p.close()
+        p.join()
+
+    def test_unpickleable_result(self):
+        from multiprocessing.pool import MaybeEncodingError
+        p = multiprocessing.Pool(2)
+
+        # Make sure we don't lose pool processes because of encoding errors.
+        for iteration in xrange(20):
+
+            scratchpad = [None]
+            def errback(exc):
+                scratchpad[0] = exc
+
+            res = p.apply_async(unpickleable_result, error_callback=errback)
+            self.assertRaises(MaybeEncodingError, res.get)
+            wrapped = scratchpad[0]
+            self.assertTrue(wrapped)
+            self.assertIsInstance(scratchpad[0], MaybeEncodingError)
+            self.assertIsNotNone(wrapped.exc)
+            self.assertIsNotNone(wrapped.value)
+        p.close()
+        p.join()
+
+class _TestPoolWorkerLifetime(BaseTestCase):
     ALLOWED_TYPES = ('processes', )
+
    def test_pool_worker_lifetime(self):
         p = multiprocessing.Pool(3, maxtasksperchild=10)
         self.assertEqual(3, len(p._pool))
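
For reference, a minimal usage sketch of the API this patch adds: the error_callback passed to apply_async() (or map_async()) is invoked in the parent process, by the result-handler thread, with the exception the worker raised, or with the MaybeEncodingError wrapper if the result could not be pickled; get() still re-raises it in the caller. The fail() and on_error() helpers below are illustrative only, not part of the patch.

import multiprocessing

def fail(x):
    # Runs in a worker process and deliberately raises.
    raise ValueError("bad value: %r" % (x,))

if __name__ == '__main__':
    seen = []

    def on_error(exc):
        # Invoked in the parent process when the task fails.
        seen.append(exc)

    pool = multiprocessing.Pool(2)
    res = pool.apply_async(fail, (42,), error_callback=on_error)
    try:
        res.get()                     # re-raises the worker's ValueError
    except ValueError, e:
        print 'get() re-raised:', e
    print 'error_callback received:', seen

    pool.close()
    pool.join()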