diff --git a/Include/symtable.h b/Include/symtable.h index 86ae3c28e8..375edeb5b5 100644 --- a/Include/symtable.h +++ b/Include/symtable.h @@ -47,6 +47,7 @@ typedef struct _symtable_entry { unsigned ste_free : 1; /* true if block has free variables */ unsigned ste_child_free : 1; /* true if a child block has free vars, including free refs to globals */ + unsigned ste_yields : 1; /* true if block has yields */ unsigned ste_generator : 1; /* true if namespace is a generator */ unsigned ste_coroutine : 1; /* true if namespace is a coroutine */ unsigned ste_varargs : 1; /* true if block has varargs */ diff --git a/Lib/test/test_grammar.py b/Lib/test/test_grammar.py index 65e26bfd38..b89c1ded7d 100644 --- a/Lib/test/test_grammar.py +++ b/Lib/test/test_grammar.py @@ -841,6 +841,43 @@ class GrammarTests(unittest.TestCase): # Check annotation refleak on SyntaxError check_syntax_error(self, "def g(a:(yield)): pass") + def test_yield_in_comprehensions(self): + # Check yield in comprehensions + def g(): [(yield x) for x in range(3)] + def g(): {(yield x) for x in range(3)} + def g(): {(yield x): (yield x) for x in range(3)} + def g(): [(yield from x) for x in ('a', 'bc')] + def g(): {(yield from x) for x in ('a', 'bc')} + def g(): {(yield from x): (yield from x) for x in ('a', 'bc')} + def g(): [x for x in range(3) if not (yield x)] + def g(): {x for x in range(3) if not (yield x)} + def g(): {x: x for x in range(3) if not (yield x)} + def g(): [y for x in range(3) for y in [(yield x)]] + def g(): {y for x in range(3) for y in [(yield x)]} + def g(): {y: y for x in range(3) for y in [(yield x)]} + # Not allowed at top level + check_syntax_error(self, "[(yield x) for x in range(3)]") + check_syntax_error(self, "{(yield x) for x in range(3)}") + check_syntax_error(self, "{(yield x): (yield x) for x in range(3)}") + check_syntax_error(self, "[x for x in range(3) if not (yield x)]") + check_syntax_error(self, "{x for x in range(3) if not (yield x)}") + check_syntax_error(self, "{x: x for x in range(3) if not (yield x)}") + check_syntax_error(self, "[y for x in range(3) for y in [(yield x)]]") + check_syntax_error(self, "{y for x in range(3) for y in [(yield x)]}") + check_syntax_error(self, "{y: y for x in range(3) for y in [(yield x)]}") + # Not allowed at class scope + check_syntax_error(self, "class C:[(yield x) for x in range(3)]") + check_syntax_error(self, "class C:{(yield x) for x in range(3)}") + check_syntax_error(self, "class C:{(yield x): (yield x) for x in range(3)}") + check_syntax_error(self, "class C:[x for x in range(3) if not (yield x)]") + check_syntax_error(self, "class C:{x for x in range(3) if not (yield x)}") + check_syntax_error(self, "class C:{x: x for x in range(3) if not (yield x)}") + check_syntax_error(self, "class C:[y for x in range(3) for y in [(yield x)]]") + check_syntax_error(self, "class C:{y for x in range(3) for y in [(yield x)]}") + check_syntax_error(self, "class C:{y: y for x in range(3) for y in [(yield x)]}") + # Not allowed in generators + check_syntax_error(self, "def g(): ((yield x) for x in range(3))") + def test_raise(self): # 'raise' test [',' test] try: raise RuntimeError('just testing') diff --git a/Lib/test/test_yield_from.py b/Lib/test/test_yield_from.py index ce21c3df81..7b33d91587 100644 --- a/Lib/test/test_yield_from.py +++ b/Lib/test/test_yield_from.py @@ -1044,6 +1044,122 @@ class TestPEP380Operation(unittest.TestCase): g.send((1, 2, 3, 4)) self.assertEqual(v, (1, 2, 3, 4)) + def test_yield_in_comprehensions(self): + # Check yield in comprehensions + def g(): return [(yield x) for x in range(2)] + self.assertEqual(list(g()), [0, 1]) + it = g() + self.assertEqual(next(it), 0) + self.assertEqual(it.send('a'), 1) + try: + it.send('b') + except StopIteration as exc: + self.assertEqual(exc.value, ['a', 'b']) + + def g(): return {(yield x) for x in range(2)} + self.assertEqual(list(g()), [0, 1]) + it = g() + self.assertEqual(next(it), 0) + self.assertEqual(it.send('a'), 1) + try: + it.send('b') + except StopIteration as exc: + self.assertEqual(exc.value, {'a', 'b'}) + + def g(): return {(yield x): (yield str(x)) for x in range(2)} + self.assertEqual(list(g()), ['0', 0, '1', 1]) + it = g() + self.assertEqual(next(it), '0') + self.assertEqual(it.send('a'), 0) + self.assertEqual(it.send('b'), '1') + self.assertEqual(it.send('c'), 1) + try: + it.send('d') + except StopIteration as exc: + self.assertEqual(exc.value, {'b': 'a', 'd': 'c'}) + + def g2(): + nonlocal res + res = (yield from g()) + + def g(): return [(yield from x) for x in ('a', 'bc')] + self.assertEqual(list(g()), ['a', 'b', 'c']) + res = None + list(g2()) + self.assertEqual(res, [None, None]) + + def g(): return {(yield from x) for x in ('a', 'bc')} + self.assertEqual(list(g()), ['a', 'b', 'c']) + res = None + list(g2()) + self.assertEqual(res, {None}) + + def g(): return {(yield from x): (yield from x) for x in ('a', 'bc')} + self.assertEqual(list(g()), ['a', 'a', 'b', 'c', 'b', 'c']) + res = None + list(g2()) + self.assertEqual(res, {None: None}) + + def g(): return [x for x in range(2) if not (yield x)] + self.assertEqual(list(g()), [0, 1]) + it = g() + self.assertEqual(next(it), 0) + self.assertEqual(it.send(True), 1) + try: + it.send(False) + except StopIteration as exc: + self.assertEqual(exc.value, [1]) + + def g(): return {x for x in range(2) if not (yield x)} + self.assertEqual(list(g()), [0, 1]) + it = g() + self.assertEqual(next(it), 0) + self.assertEqual(it.send(True), 1) + try: + it.send(False) + except StopIteration as exc: + self.assertEqual(exc.value, {1}) + + def g(): return {x: x for x in range(2) if not (yield x)} + self.assertEqual(list(g()), [0, 1]) + it = g() + self.assertEqual(next(it), 0) + self.assertEqual(it.send(True), 1) + try: + it.send(False) + except StopIteration as exc: + self.assertEqual(exc.value, {1: 1}) + + def g(): return [y for x in range(2) for y in [(yield x)]] + self.assertEqual(list(g()), [0, 1]) + it = g() + self.assertEqual(next(it), 0) + self.assertEqual(it.send('a'), 1) + try: + it.send('b') + except StopIteration as exc: + self.assertEqual(exc.value, ['a', 'b']) + + def g(): return {y for x in range(2) for y in [(yield x)]} + self.assertEqual(list(g()), [0, 1]) + it = g() + self.assertEqual(next(it), 0) + self.assertEqual(it.send('a'), 1) + try: + it.send('b') + except StopIteration as exc: + self.assertEqual(exc.value, {'a', 'b'}) + + def g(): return {y: x for x in range(2) for y in [(yield x)]} + self.assertEqual(list(g()), [0, 1]) + it = g() + self.assertEqual(next(it), 0) + self.assertEqual(it.send('a'), 1) + try: + it.send('b') + except StopIteration as exc: + self.assertEqual(exc.value, {'a': 0, 'b': 1}) + if __name__ == '__main__': unittest.main() diff --git a/Python/compile.c b/Python/compile.c index a3ea60d07c..a54d30fc95 100644 --- a/Python/compile.c +++ b/Python/compile.c @@ -3961,8 +3961,10 @@ compiler_comprehension(struct compiler *c, expr_ty e, int type, PyCodeObject *co = NULL; comprehension_ty outermost; PyObject *qualname = NULL; + int is_function = (c->u->u_ste->ste_type == FunctionBlock); int is_async_function = c->u->u_ste->ste_coroutine; int is_async_generator = 0; + int has_yields = 0; outermost = (comprehension_ty) asdl_seq_GET(generators, 0); @@ -3973,8 +3975,25 @@ compiler_comprehension(struct compiler *c, expr_ty e, int type, } is_async_generator = c->u->u_ste->ste_coroutine; + has_yields = c->u->u_ste->ste_yields; - if (is_async_generator && !is_async_function && type != COMP_GENEXP) { + if (has_yields && type == COMP_GENEXP) { + if (e->lineno > c->u->u_lineno) { + c->u->u_lineno = e->lineno; + c->u->u_lineno_set = 0; + } + compiler_error(c, "'yield' inside generator expression"); + goto error_in_scope; + } + else if (has_yields && !is_function && type != COMP_GENEXP) { + if (e->lineno > c->u->u_lineno) { + c->u->u_lineno = e->lineno; + c->u->u_lineno_set = 0; + } + compiler_error(c, "'yield' outside function"); + goto error_in_scope; + } + else if (is_async_generator && !is_async_function && type != COMP_GENEXP) { if (e->lineno > c->u->u_lineno) { c->u->u_lineno = e->lineno; c->u->u_lineno_set = 0; @@ -4035,8 +4054,13 @@ compiler_comprehension(struct compiler *c, expr_ty e, int type, ADDOP_I(c, CALL_FUNCTION, 1); - if (is_async_generator && type != COMP_GENEXP) { - ADDOP(c, GET_AWAITABLE); + if ((is_async_generator || has_yields) && type != COMP_GENEXP) { + if (is_async_generator) { + ADDOP(c, GET_AWAITABLE); + } + else { + ADDOP(c, GET_YIELD_FROM_ITER); + } ADDOP_O(c, LOAD_CONST, Py_None, consts); ADDOP(c, YIELD_FROM); } diff --git a/Python/symtable.c b/Python/symtable.c index 55815c91cc..c5cc69919a 100644 --- a/Python/symtable.c +++ b/Python/symtable.c @@ -78,6 +78,7 @@ ste_new(struct symtable *st, identifier name, _Py_block_ty block, st->st_cur->ste_type == FunctionBlock)) ste->ste_nested = 1; ste->ste_child_free = 0; + ste->ste_yields = 0; ste->ste_generator = 0; ste->ste_coroutine = 0; ste->ste_returns_value = 0; @@ -1452,10 +1453,12 @@ symtable_visit_expr(struct symtable *st, expr_ty e) case Yield_kind: if (e->v.Yield.value) VISIT(st, expr, e->v.Yield.value); + st->st_cur->ste_yields = 1; st->st_cur->ste_generator = 1; break; case YieldFrom_kind: VISIT(st, expr, e->v.YieldFrom.value); + st->st_cur->ste_yields = 1; st->st_cur->ste_generator = 1; break; case Await_kind: @@ -1724,6 +1727,7 @@ symtable_handle_comprehension(struct symtable *st, expr_ty e, { int is_generator = (e->kind == GeneratorExp_kind); int needs_tmp = !is_generator; + PySTEntryObject *prev = st->st_cur; comprehension_ty outermost = ((comprehension_ty) asdl_seq_GET(generators, 0)); /* Outermost iterator is evaluated in current scope */ @@ -1754,6 +1758,9 @@ symtable_handle_comprehension(struct symtable *st, expr_ty e, if (value) VISIT(st, expr, value); VISIT(st, expr, elt); + if (st->st_cur->ste_yields) { + prev->ste_generator = 1; + } return symtable_exit_block(st, (void *)e); }