diff -r 20b8f0ee3d64 Lib/contextlib.py --- a/Lib/contextlib.py Wed May 30 17:57:50 2012 +0300 +++ b/Lib/contextlib.py Thu May 31 12:21:45 2012 +0300 @@ -225,32 +225,17 @@ return self def __exit__(self, *exc_details): - if not self._exit_callbacks: - return - # This looks complicated, but it is really just - # setting up a chain of try-expect statements to ensure - # that outer callbacks still get invoked even if an - # inner one throws an exception - def _invoke_next_callback(exc_details): - # Callbacks are removed from the list in FIFO order - # but the recursion means they're invoked in LIFO order - cb = self._exit_callbacks.popleft() - if not self._exit_callbacks: - # Innermost callback is invoked directly - return cb(*exc_details) - # More callbacks left, so descend another level in the stack + while self._exit_callbacks: + cb = self._exit_callbacks.pop() try: - suppress_exc = _invoke_next_callback(exc_details) + if cb(*exc_details): + exc_details = (None, None, None) except: - suppress_exc = cb(*sys.exc_info()) - # Check if this cb suppressed the inner exception - if not suppress_exc: + new_exc_details = sys.exc_info() + if exc_details != (None, None, None): + # simulate the stack of exceptions by setting the context + new_exc_details[1].__context__ = exc_details[1] + if not self._exit_callbacks: raise - else: - # Check if inner cb suppressed the original exception - if suppress_exc: - exc_details = (None, None, None) - suppress_exc = cb(*exc_details) or suppress_exc - return suppress_exc - # Kick off the recursive chain - return _invoke_next_callback(exc_details) + exc_details = new_exc_details + return exc_details == (None, None, None) diff -r 20b8f0ee3d64 Lib/test/test_contextlib.py --- a/Lib/test/test_contextlib.py Wed May 30 17:57:50 2012 +0300 +++ b/Lib/test/test_contextlib.py Thu May 31 12:21:45 2012 +0300 @@ -483,6 +483,43 @@ new_stack.close() self.assertEqual(result, [1, 2, 3]) + def test_exit_raise(self): + with self.assertRaises(ZeroDivisionError): + with ExitStack() as stack: + stack.push(lambda *exc: False) + 1/0 + + def test_exit_suppress(self): + with ExitStack() as stack: + stack.push(lambda *exc: True) + 1/0 + + def test_exit_change_exception(self): + def raise_exc(exception): + def func(*exc): raise exception + return func + try: + with ExitStack() as stack: + stack.push(raise_exc(IndexError())) + stack.push(raise_exc(KeyError())) + stack.push(raise_exc(AttributeError())) + stack.push(lambda *exc: True) + stack.push(raise_exc(ValueError())) + 1 / 0 + except IndexError as e: + self.assertIsInstance(e.__context__, KeyError) + self.assertIsInstance(e.__context__.__context__, AttributeError) + # the ValueError was skipped and the original exception is found + self.assertIsInstance(e.__context__.__context__.__context__, ZeroDivisionError) + else: + self.fail() + + def test_exit_change_exception_suppress(self): + with ExitStack() as stack: + stack.push(lambda *exc: True) + stack.push(lambda *exc: 1/0) + stack.push(lambda *exc: {}[1]) + def test_instance_bypass(self): class Example(object): pass cm = Example()