diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py index 8555da9..2be1ad2 100644 --- a/Lib/multiprocessing/pool.py +++ b/Lib/multiprocessing/pool.py @@ -42,6 +42,23 @@ def mapstar(args): # Code run by worker processes # +class MaybeEncodingError(Exception): + """Wraps possible unpickleable errors, so they can be + safely sent through the socket.""" + + def __init__(self, exc, value): + self.exc = str(exc) + self.value = repr(value) + super(MaybeEncodingError, self).__init__(self.exc, self.value) + + def __str__(self): + return "Error sending result: '%s'. Reason: '%s'" % (self.value, + self.exc) + + def __repr__(self): + return "" % str(self) + + def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None): assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0) put = outqueue.put @@ -70,7 +87,13 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None): result = (True, func(*args, **kwds)) except Exception, e: result = (False, e) - put((job, i, result)) + try: + put((job, i, result)) + except Exception, e: + wrapped = MaybeEncodingError(e, result[1]) + debug("Possible encoding error while sending result: %s" % ( + wrapped)) + put((job, i, (False, wrapped))) completed += 1 debug('worker exiting after %d tasks' % completed) @@ -234,12 +257,13 @@ class Pool(object): for i, x in enumerate(task_batches)), result._set_length)) return (item for chunk in result for item in chunk) - def apply_async(self, func, args=(), kwds={}, callback=None): + def apply_async(self, func, args=(), kwds={}, callback=None, + error_callback=None): ''' Asynchronous equivalent of `apply()` builtin ''' assert self._state == RUN - result = ApplyResult(self._cache, callback) + result = ApplyResult(self._cache, callback, error_callback) self._taskqueue.put(([(result._job, None, func, args, kwds)], None)) return result @@ -458,12 +482,13 @@ class Pool(object): class ApplyResult(object): - def __init__(self, cache, callback): + def __init__(self, cache, callback, error_callback=None): self._cond = threading.Condition(threading.Lock()) self._job = job_counter.next() self._cache = cache self._ready = False self._callback = callback + self._errback = error_callback cache[self._job] = self def ready(self): @@ -494,6 +519,8 @@ class ApplyResult(object): self._success, self._value = obj if self._callback and self._success: self._callback(self._value) + if self._errback and not self._success: + self._errback(self._value) self._cond.acquire() try: self._ready = True diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py index 59b3357..41aeda8 100644 --- a/Lib/test/test_multiprocessing.py +++ b/Lib/test/test_multiprocessing.py @@ -995,6 +995,7 @@ class _TestContainers(BaseTestCase): def sqr(x, wait=0.0): time.sleep(wait) return x*x + class _TestPool(BaseTestCase): def test_apply(self): @@ -1020,6 +1021,7 @@ class _TestPool(BaseTestCase): self.assertEqual(get(), 49) self.assertTimingAlmostEqual(get.elapsed, TIMEOUT1) + def test_async_timeout(self): res = self.pool.apply_async(sqr, (6, TIMEOUT2 + 0.2)) get = TimingWrapper(res.get) @@ -1071,6 +1073,54 @@ class _TestPool(BaseTestCase): join() self.assertTrue(join.elapsed < 0.2) +def raising(): + raise KeyError("key") + +def unpickleable_result(): + return lambda: 42 + +class _TestPoolWorkerErrors(BaseTestCase): + ALLOWED_TYPES = ('processes', ) + + def test_async_error_callback(self): + p = multiprocessing.Pool(2) + + scratchpad = [None] + def errback(exc): + scratchpad[0] = exc + + res = p.apply_async(raising, error_callback=errback) + self.assertRaises(KeyError, res.get) + self.assertTrue(scratchpad[0]) + self.assertIsInstance(scratchpad[0], KeyError) + + p.close() + p.join() + + def test_unpickleable_result(self): + from multiprocessing.pool import MaybeEncodingError + p = multiprocessing.Pool(2) + + # Make sure we don't lose pool processes because of encoding errors. + for iteration in xrange(20): + + scratchpad = [None] + def errback(exc): + scratchpad[0] = exc + + res = p.apply_async(unpickleable_result, error_callback=errback) + self.assertRaises(MaybeEncodingError, res.get) + wrapped = scratchpad[0] + self.assertTrue(wrapped) + self.assertIsInstance(scratchpad[0], MaybeEncodingError) + self.assertIsNotNone(wrapped.exc) + self.assertIsNotNone(wrapped.value) + + p.close() + p.join() + + + class _TestPoolWorkerLifetime(BaseTestCase): ALLOWED_TYPES = ('processes', )