Index: Doc/library/contextlib.rst =================================================================== --- Doc/library/contextlib.rst (revision 82389) +++ Doc/library/contextlib.rst (working copy) @@ -50,7 +50,12 @@ generator context manager will indicate to the :keyword:`with` statement that the exception has been handled, and execution will resume with the statement immediately following the :keyword:`with` statement. + + contextmanager uses :class:`ContextDecorator` so the context managers it + creates can be used as decorators as well as in :keyword:`with` statements. + .. versionchanged:: 3.2 + Use of :class:`ContextDecorator`. .. function:: closing(thing) @@ -79,6 +84,53 @@ ``page.close()`` will be called when the :keyword:`with` block is exited. +.. class:: ContextDecorator() + + A base class that enables a context manager to also be used as a decorator. + + Context managers inheriting from ``ContextDecorator`` have to implement + ``__enter__`` and ``__exit__`` as normal. ``__exit__`` retains its optional + exception handling even when used as a decorator. + + Example:: + + from contextlib import ContextDecorator + + class mycontext(ContextDecorator): + def __enter__(self): + print('Starting') + return self + + def __exit__(self, *exc): + print('Finishing') + return False + + @mycontext() + def function(): + print('The bit in the middle') + + >>> with mycontext(): + ... print('The bit in the middle') + ... + Starting + The bit in the middle + Finishing + + Existing context managers that already have a base class can be extended by + using ``ContextDecorator`` as a mixin class:: + + from contextlib import ContextDecorator + + class mycontext(ContextBaseClass, ContextDecorator): + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + .. versionadded:: 3.2 + + .. seealso:: :pep:`0343` - The "with" statement Index: Lib/contextlib.py =================================================================== --- Lib/contextlib.py (revision 82389) +++ Lib/contextlib.py (working copy) @@ -4,9 +4,20 @@ from functools import wraps from warnings import warn -__all__ = ["contextmanager", "closing"] +__all__ = ["contextmanager", "closing", "ContextDecorator"] -class GeneratorContextManager(object): + +class ContextDecorator(object): + "A base class or mixin that enables context managers to work as decorators." + def __call__(self, func): + @wraps(func) + def inner(*args, **kwds): + with self: + return func(*args, **kwds) + return inner + + +class GeneratorContextManager(ContextDecorator): """Helper for @contextmanager decorator.""" def __init__(self, gen): Index: Lib/test/test_contextlib.py =================================================================== --- Lib/test/test_contextlib.py (revision 82389) +++ Lib/test/test_contextlib.py (working copy) @@ -202,6 +202,169 @@ return True self.boilerPlate(lock, locked) + +class mycontext(ContextDecorator): + started = False + exc = None + catch = False + + def __enter__(self): + self.started = True + return self + + def __exit__(self, *exc): + self.exc = exc + return self.catch + + +class TestContextDecorator(unittest.TestCase): + + def test_contextdecorator(self): + context = mycontext() + with context as result: + self.assertIs(result, context) + self.assertTrue(context.started) + + self.assertEqual(context.exc, (None, None, None)) + + + def test_contextdecorator_with_exception(self): + context = mycontext() + + with self.assertRaisesRegexp(NameError, 'foo'): + with context: + raise NameError('foo') + self.assertIsNotNone(context.exc) + self.assertIs(context.exc[0], NameError) + + context = mycontext() + context.catch = True + with context: + raise NameError('foo') + self.assertIsNotNone(context.exc) + self.assertIs(context.exc[0], NameError) + + + def test_decorator(self): + context = mycontext() + + @context + def test(): + self.assertIsNone(context.exc) + self.assertTrue(context.started) + test() + self.assertEqual(context.exc, (None, None, None)) + + + def test_decorator_with_exception(self): + context = mycontext() + + @context + def test(): + self.assertIsNone(context.exc) + self.assertTrue(context.started) + raise NameError('foo') + + with self.assertRaisesRegexp(NameError, 'foo'): + test() + self.assertIsNotNone(context.exc) + self.assertIs(context.exc[0], NameError) + + + def test_decorating_method(self): + context = mycontext() + + class Test(object): + + @context + def method(self, a, b, c=None): + self.a = a + self.b = b + self.c = c + + # these tests are for argument passing when used as a decorator + test = Test() + test.method(1, 2) + self.assertEqual(test.a, 1) + self.assertEqual(test.b, 2) + self.assertEqual(test.c, None) + + test = Test() + test.method('a', 'b', 'c') + self.assertEqual(test.a, 'a') + self.assertEqual(test.b, 'b') + self.assertEqual(test.c, 'c') + + test = Test() + test.method(a=1, b=2) + self.assertEqual(test.a, 1) + self.assertEqual(test.b, 2) + + + def test_typo_enter(self): + class mycontext(ContextDecorator): + def __unter__(self): + pass + def __exit__(self, *exc): + pass + + with self.assertRaises(AttributeError): + with mycontext(): + pass + + + def test_typo_exit(self): + class mycontext(ContextDecorator): + def __enter__(self): + pass + def __uxit__(self, *exc): + pass + + with self.assertRaises(AttributeError): + with mycontext(): + pass + + + def test_contextdecorator_as_mixin(self): + class somecontext(object): + started = False + exc = None + + def __enter__(self): + self.started = True + return self + + def __exit__(self, *exc): + self.exc = exc + + class mycontext(somecontext, ContextDecorator): + pass + + context = mycontext() + @context + def test(): + self.assertIsNone(context.exc) + self.assertTrue(context.started) + test() + self.assertEqual(context.exc, (None, None, None)) + + + def test_contextmanager_as_decorator(self): + state = [] + @contextmanager + def woohoo(y): + state.append(y) + yield + state.append(999) + + @woohoo(1) + def test(x): + self.assertEqual(state, [1]) + state.append(x) + test('something') + self.assertEqual(state, [1, 'something', 999]) + + # This is needed to make the test actually run under regrtest.py! def test_main(): support.run_unittest(__name__)