Index: Lib/contextlib.py =================================================================== --- Lib/contextlib.py (revision 82335) +++ Lib/contextlib.py (working copy) @@ -4,9 +4,20 @@ from functools import wraps from warnings import warn -__all__ = ["contextmanager", "closing"] +__all__ = ["contextmanager", "closing", "ContextDecorator"] -class GeneratorContextManager(object): + +class ContextDecorator(object): + "A base class or mixin that enables context managers to work as decorators." + def __call__(self, func): + @wraps(func) + def inner(*args, **kwds): + with self: + return func(*args, **kwds) + return inner + + +class GeneratorContextManager(ContextDecorator): """Helper for @contextmanager decorator.""" def __init__(self, gen): Index: Lib/test/test_contextlib.py =================================================================== --- Lib/test/test_contextlib.py (revision 82335) +++ Lib/test/test_contextlib.py (working copy) @@ -202,6 +202,150 @@ return True self.boilerPlate(lock, locked) + +class mycontext(ContextDecorator): + started = False + exc = None + catch = False + + def __enter__(self): + self.started = True + return self + + def __exit__(self, *exc): + self.exc = exc + return self.catch + + +class TestContextDecorator(unittest.TestCase): + + def test_contextdecorator(self): + context = mycontext() + with context as result: + self.assertIs(result, context) + self.assertTrue(context.started) + + self.assertEqual(context.exc, (None, None, None)) + + def test_contextdecorator_with_exception(self): + context = mycontext() + + with self.assertRaisesRegexp(NameError, 'foo'): + with context: + raise NameError('foo') + + context = mycontext() + context.catch = True + with context: + raise NameError('foo') + self.assertIsNotNone(context.exc) + + def test_decorator(self): + context = mycontext() + + @context + def test(): + self.assertIsNone(context.exc) + self.assertTrue(context.started) + test() + self.assertEqual(context.exc, (None, None, None)) + + def test_decorator_with_exception(self): + context = mycontext() + + @context + def test(): + self.assertIsNone(context.exc) + self.assertTrue(context.started) + raise NameError('foo') + + with self.assertRaisesRegexp(NameError, 'foo'): + test() + self.assertNotEqual(context.exc, (None, None, None)) + + def test_decorating_method(self): + context = mycontext() + + class Test(object): + + @context + def method(self, a, b, c=None): + self.a = a + self.b = b + self.c = c + + # these tests + test = Test() + test.method(1, 2) + self.assertEqual(test.a, 1) + self.assertEqual(test.b, 2) + self.assertEqual(test.c, None) + + test = Test() + test.method('a', 'b', 'c') + self.assertEqual(test.a, 'a') + self.assertEqual(test.b, 'b') + self.assertEqual(test.c, 'c') + + test = Test() + test.method(a=1, b=2) + self.assertEqual(test.a, 1) + self.assertEqual(test.b, 2) + + + def test_typo(self): + class mycontext(ContextDecorator): + def __enter__(self): + pass + def __uxit__(self, *exc): + pass + + with self.assertRaises(AttributeError): + with mycontext(): + pass + + + def test_decorator_as_mixin(self): + class somecontext(object): + started = False + exc = None + + def __enter__(self): + self.started = True + return self + + def __exit__(self, *exc): + self.exc = exc + + class mycontext(somecontext, ContextDecorator): + pass + + context = mycontext() + + @context + def test(): + self.assertIsNone(context.exc) + self.assertTrue(context.started) + test() + self.assertEqual(context.exc, (None, None, None)) + + + def test_contextmanager_as_decorator(self): + state = [] + @contextmanager + def woohoo(y): + state.append(y) + yield + state.append(999) + + @woohoo(1) + def test(x): + self.assertEqual(state, [1]) + state.append(x) + test('something') + self.assertEqual(state, [1, 'something', 999]) + + # This is needed to make the test actually run under regrtest.py! def test_main(): support.run_unittest(__name__)