diff -r 7356f71fb0a4 Lib/test/test_types.py --- a/Lib/test/test_types.py Fri May 29 09:06:05 2015 -0400 +++ b/Lib/test/test_types.py Fri May 29 22:01:30 2015 +0200 @@ -1206,28 +1206,43 @@ @types.coroutine def foo(): pass - @types.coroutine - def gen(): - def _gen(): yield - return _gen() - - for sample in (foo, gen): - with self.assertRaisesRegex(TypeError, - 'callable wrapped .* non-coroutine'): - sample() + with self.assertRaisesRegex(TypeError, + 'callable wrapped .* non-coroutine'): + foo() def test_duck_coro(self): class CoroLike: def send(self): pass def throw(self): pass def close(self): pass - def __await__(self): pass + def __await__(self): return self coro = CoroLike() @types.coroutine def foo(): return coro - self.assertIs(coro, foo()) + self.assertIs(foo().__await__(), coro) + + def test_duck_gen(self): + class GenLike: + def send(self): pass + def throw(self): pass + def close(self): pass + def __iter__(self): return self + def __next__(self): pass + + gen = GenLike() + @types.coroutine + def foo(): + return gen + self.assertIs(foo().__await__(), gen) + + def test_gen(self): + def gen(): yield + gen = gen() + @types.coroutine + def foo(): return gen + self.assertIs(foo().__await__(), gen) def test_genfunc(self): def gen(): diff -r 7356f71fb0a4 Lib/types.py --- a/Lib/types.py Fri May 29 09:06:05 2015 -0400 +++ b/Lib/types.py Fri May 29 22:01:30 2015 +0200 @@ -166,31 +166,55 @@ # We don't want to import 'dis' or 'inspect' just for # these constants. - _CO_GENERATOR = 0x20 - _CO_ITERABLE_COROUTINE = 0x100 + CO_GENERATOR = 0x20 + CO_ITERABLE_COROUTINE = 0x100 if not callable(func): raise TypeError('types.coroutine() expects a callable') if (isinstance(func, FunctionType) and isinstance(getattr(func, '__code__', None), CodeType) and - (func.__code__.co_flags & _CO_GENERATOR)): + (func.__code__.co_flags & CO_GENERATOR)): # TODO: Implement this in C. co = func.__code__ func.__code__ = CodeType( co.co_argcount, co.co_kwonlyargcount, co.co_nlocals, co.co_stacksize, - co.co_flags | _CO_ITERABLE_COROUTINE, + co.co_flags | CO_ITERABLE_COROUTINE, co.co_code, co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name, co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars) return func + # The following code is primarily to support functions that + # return generator-like objects (for instance generators + # compiled with Cython). + + class AwaitableGeneratorWrapper: + def __init__(self, gen): + self.__wrapped__ = gen + self.send = gen.send + self.throw = gen.throw + self.close = gen.close + self.__name__ = getattr(gen, '__name__', None) + self.__qualname__ = getattr(gen, '__qualname__', None) + try: + self.gi_code = gen.gi_code + except AttributeError: + pass + + def __iter__(self): return self.__wrapped__ + def __next__(self): return next(self.__wrapped__) + __await__ = __iter__ + @_functools.wraps(func) def wrapped(*args, **kwargs): coro = func(*args, **kwargs) + if (coro.__class__ is GeneratorType or + isinstance(coro, _collections_abc.Generator)): + return AwaitableGeneratorWrapper(coro) if not isinstance(coro, _collections_abc.Coroutine): raise TypeError( 'callable wrapped with types.coroutine() returned '