diff --git a/Lib/test/test_fstring.py b/Lib/test/test_fstring.py new file mode 100644 --- /dev/null +++ b/Lib/test/test_fstring.py @@ -0,0 +1,620 @@ +import unittest +import decimal +import ast + +a_global = 'global variable' + +class TestCase(unittest.TestCase): + def assertAllRaise(self, exception_type, error_strings): + for str in error_strings: + with self.subTest(str=str): + with self.assertRaises(exception_type): + eval(str) + + def test_ast(self): + # inspired by http://bugs.python.org/issue24975 + class X: + def __init__(self): + self.called = False + def __call__(self): + self.called = True + return 4 + x = X() + expr = """ +a = 10 +f'{a * x()}'""" + t = ast.parse(expr) + c = compile(t, '', 'exec') + + # make sure x was not called + self.assertFalse(x.called) + + # actually run the code + exec(c) + + # make sure x was called + self.assertTrue(x.called) + + def test_literal_eval(self): + # With no expressions, an f-string is okay + self.assertEqual(ast.literal_eval("f'x'"), 'x') + self.assertEqual(ast.literal_eval("f'x' 'y'"), 'xy') + + # But this should raise an error + with self.assertRaises(ValueError): + ast.literal_eval("f'x{3}'") + + # As should this, which uses a different ast node + with self.assertRaises(ValueError): + ast.literal_eval("f'{3}'") + + def test_ast_compile_time_concat(self): + x = [''] + + expr = """x[0] = 'foo' f'{3}'""" + t = ast.parse(expr) + c = compile(t, '', 'exec') + exec(c) + self.assertEqual(x[0], 'foo3') + + def test_literal(self): + self.assertEqual(f'', '') + self.assertEqual(f'a', 'a') + self.assertEqual(f' ', ' ') + self.assertEqual(f'\N{GREEK CAPITAL LETTER DELTA}', + '\N{GREEK CAPITAL LETTER DELTA}') + self.assertEqual(f'\N{GREEK CAPITAL LETTER DELTA}', + '\u0394') + self.assertEqual(f'\N{True}', '\u22a8') + self.assertEqual(rf'\N{True}', r'\NTrue') + + def test_escape_order(self): + # note that hex(ord('{')) == 0x7b, so this + # string becomes f'a{4*10}b' + self.assertEqual(f'a\u007b4*10}b', 'a40b') + self.assertEqual(f'a\x7b4*10}b', 'a40b') + self.assertEqual(f'a\x7b4*10\N{RIGHT CURLY BRACKET}b', 'a40b') + self.assertEqual(f'{"a"!\N{LATIN SMALL LETTER R}}', "'a'") + self.assertEqual(f'{10\x3a02X}', '0A') + self.assertEqual(f'{10:02\N{LATIN CAPITAL LETTER X}}', '0A') + + self.assertAllRaise(SyntaxError, + ["""rf'a\u007b4*10}b'}""", # mis-matched brackets + """f'{"a"\!r}'""", + ]) + + def test_double_braces(self): + self.assertEqual(f'{{', '{') + self.assertEqual(f'a{{', 'a{') + self.assertEqual(f'{{b', '{b') + self.assertEqual(f'a{{b', 'a{b') + self.assertEqual(f'}}', '}') + self.assertEqual(f'a}}', 'a}') + self.assertEqual(f'}}b', '}b') + self.assertEqual(f'a}}b', 'a}b') + + self.assertEqual(f'{{{10}', '{10') + self.assertEqual(f'}}{10}', '}10') + self.assertEqual(f'}}{{{10}', '}{10') + self.assertEqual(f'}}a{{{10}', '}a{10') + + self.assertEqual(f'{10}{{', '10{') + self.assertEqual(f'{10}}}', '10}') + self.assertEqual(f'{10}}}{{', '10}{') + self.assertEqual(f'{10}}}a{{' '}', '10}a{}') + + # Inside of strings, don't interpret doubled brackets. + self.assertEqual(f'{"{{}}"}', '{{}}') + + self.assertAllRaise(TypeError, + ["f'{ {{}} }'", # dict in a set + ]) + + def test_compile_time_concat(self): + x = 'def' + self.assertEqual('abc' f'{x}ghi', 'abcdefghi') + self.assertEqual('abc' f'{x}' 'ghi', 'abcdefghi') + self.assertEqual('abc' f'{x}' 'gh' f'i{x:4}', 'abcdefghidef ') + self.assertEqual('{x}' f'{x}', '{x}def') + self.assertEqual('{x' f'{x}', '{xdef') + self.assertEqual('{x}' f'{x}', '{x}def') + self.assertEqual('{{x}}' f'{x}', '{{x}}def') + self.assertEqual('{{x' f'{x}', '{{xdef') + self.assertEqual('x}}' f'{x}', 'x}}def') + self.assertEqual(f'{x}' 'x}}', 'defx}}') + self.assertEqual(f'{x}' '', 'def') + self.assertEqual('' f'{x}' '', 'def') + self.assertEqual('' f'{x}', 'def') + self.assertEqual(f'{x}' '2', 'def2') + self.assertEqual('1' f'{x}' '2', '1def2') + self.assertEqual('1' f'{x}', '1def') + self.assertEqual(f'{x}' f'-{x}', 'def-def') + self.assertEqual('' f'', '') + self.assertEqual('' f'' '', '') + self.assertEqual('' f'' '' f'', '') + self.assertEqual(f'', '') + self.assertEqual(f'' '', '') + self.assertEqual(f'' '' f'', '') + self.assertEqual(f'' '' f'' '', '') + + self.assertAllRaise(SyntaxError, + ["f'{3' f'}'", # can't concat to get a valid f-string + ]) + + def test_comments(self): + # These aren't comments, since they're in strings + d = {'#': 'hash'} + self.assertEqual(f'{"#"}', '#') + self.assertEqual(f'{d["#"]}', 'hash') + + self.assertAllRaise(SyntaxError, + ["f'{1#}'", # error because the expression becomes "(1#)" + "f'{3(#)}'}", + ]) + + def test_many_expressions(self): + # Create a string with many expressions in it. Note that + # because we have a space in here as a literal, we're actually + # going to use twice as many ast nodes: one for each literal + # plus one for each expression. + def build_fstr(n, extra=''): + return "f'" + ('{x} ' * n) + extra + "'" + + x = 'X' + width = 1 + + # Test around 256 + for i in range(250, 260): + self.assertEqual(eval(build_fstr(i)), (x+' ')*i) + + # Test concatenating 2 largs fstrings + self.assertEqual(eval(build_fstr(255)*256), (x+' ')*(255*256)) + + s = build_fstr(253, '{x:{width}} ') + self.assertEqual(eval(s), (x+' ')*254) + + # Test lots of expressions and constants, concatenated. + s = "f'{1}' 'x' 'y'" * 1024 + self.assertEqual(eval(s), '1xy' * 1024) + + def test_format_specifier_expressions(self): + width = 10 + precision = 4 + value = decimal.Decimal('12.34567') + self.assertEqual(f'result: {value:{width}.{precision}}', 'result: 12.35') + self.assertEqual(f'result: {value:{width!r}.{precision}}', 'result: 12.35') + self.assertEqual(f'result: {value:{width:0}.{precision:1}}', 'result: 12.35') + self.assertEqual(f'result: {value:{1}{0:0}.{precision:1}}', 'result: 12.35') + self.assertEqual(f'result: {value:{ 1}{ 0:0}.{ precision:1}}', 'result: 12.35') + self.assertEqual(f'{10:#{1}0x}', ' 0xa') + self.assertEqual(f'{10:{"#"}1{0}{"x"}}', ' 0xa') + self.assertEqual(f'{-10:-{"#"}1{0}x}', ' -0xa') + self.assertEqual(f'{-10:{"-"}#{1}0{"x"}}', ' -0xa') + + self.assertAllRaise(SyntaxError, + [# Can't nest format specifiers + "f'result: {value:{width:{0}}.{prevision:1}}'", + + # This looks like a nested format spec + "f'{10:#{3 != {4:5} and width}x'", + + # No expansion inside conversion or for + # the : or ! itself + """f'{"s"}!{"r"}'""" + """f'{"s"!r{":10"}}'""", + ]) + + def test_side_effect_order(self): + class X: + def __init__(self): + self.i = 0 + + def __format__(self, spec): + self.i += 1 + return str(self.i) + + x = X() + self.assertEqual(f'{x} {x}', '1 2') + + def test_missing_expression(self): + self.assertAllRaise(SyntaxError, + ["f'{}'", + "f'{ }'" + "f' {} '", + "f'{!r}'", + "f'{ !r}'", + "f'{10:{ }}'", + "f' { } '", + "f'{\n}'", + "f'{\n \n}'", + ]) + + def test_escaped_quotes(self): + d = {'"': 'a', + "'": 'b'} + + self.assertEqual(fr"{d['\"']}", 'a') + self.assertEqual(fr'{d["\'"]}', 'b') + self.assertEqual(fr"{'\"'}", '"') + self.assertEqual(fr'{"\'"}', "'") + self.assertEqual(f'{"\\"3"}', '"3') + + self.assertAllRaise(SyntaxError, + ['''f'{"""\\}' ''', # Backslash at end of expression + ]) + + def test_parens_in_expressions(self): + self.assertEqual(f'{3,}', '(3,)') + + # Add these because when an expression is evaluated, parens + # are added around it. But we shouldn't go from an invalid + # expression to a valid one. The added parens are just + # supposed to allow whitespace (including newlines). + self.assertAllRaise(SyntaxError, + ["f'{,}'", + "f'{3)+(4}'", + "f'{}'", + "f'{\n}'", + "f'{,}'", # this is (,), which is an error + ]) + + def test_newlines_in_expressions(self): + self.assertEqual(f'{0}', '0') + self.assertEqual(f'{0\n}', '0') + self.assertEqual(f'{\n0\n}', '0') + self.assertEqual(f'{\n0}', '0') + self.assertEqual(f'{3+\n4}', '7') + self.assertEqual(f'{3+\\\n4}', '7') + self.assertEqual(rf'''{3+ +4}''', '7') + self.assertEqual(f'''{3+\ +4}''', '7') + + self.assertAllRaise(SyntaxError, + ["f{\n}", + ]) + + def test_lambda(self): + x = 5 + self.assertEqual(f'{(lambda y:x*y)("8")!r}', "'88888'") + self.assertEqual(f'{(lambda y:x*y)("8")!r:10}', "'88888' ") + self.assertEqual(f'{(lambda y:x*y)("8"):10}', "88888 ") + + # lambda doesn't work without parens, because the colon + # makes the parser think it's a format_spec + self.assertAllRaise(SyntaxError, + ["f'{lambda x:x}", + ]) + + def test_yield(self): + # Not terribly useful, but make sure the yield turns + # a function into a generator + def fn(y): + f'y:{yield y*2}' + + g = fn(4) + self.assertEqual(next(g), 8) + + def test_yield_send(self): + def fn(x): + yield f'x:{yield (lambda i: x * i)}' + + g = fn(10) + the_lambda = next(g) + self.assertEqual(the_lambda(4), 40) + self.assertEqual(g.send('string'), 'x:string') + + def test_expressions_with_triple_quoted_strings(self): + self.assertEqual(f"{'''x'''}", 'x') + self.assertEqual(f"{'''eric's'''}", "eric's") + self.assertEqual(f'{"""eric\'s"""}', "eric's") + self.assertEqual(f"{'''eric\"s'''}", 'eric"s') + self.assertEqual(f'{"""eric"s"""}', 'eric"s') + + # Test concatenation within an expression + self.assertEqual(f'{"x" """eric"s""" "y"}', 'xeric"sy') + self.assertEqual(f'{"x" """eric"s"""}', 'xeric"s') + self.assertEqual(f'{"""eric"s""" "y"}', 'eric"sy') + self.assertEqual(f'{"""x""" """eric"s""" "y"}', 'xeric"sy') + self.assertEqual(f'{"""x""" """eric"s""" """y"""}', 'xeric"sy') + self.assertEqual(f'{r"""x""" """eric"s""" """y"""}', 'xeric"sy') + + def test_multiple_vars(self): + x = 98 + y = 'abc' + self.assertEqual(f'{x}{y}', '98abc') + + self.assertEqual(f'X{x}{y}', 'X98abc') + self.assertEqual(f'{x}X{y}', '98Xabc') + self.assertEqual(f'{x}{y}X', '98abcX') + + self.assertEqual(f'X{x}Y{y}', 'X98Yabc') + self.assertEqual(f'X{x}{y}Y', 'X98abcY') + self.assertEqual(f'{x}X{y}Y', '98XabcY') + + self.assertEqual(f'X{x}Y{y}Z', 'X98YabcZ') + + def test_closure(self): + def outer(x): + def inner(): + return f'x:{x}' + return inner + + self.assertEqual(outer('987')(), 'x:987') + self.assertEqual(outer(7)(), 'x:7') + + def test_arguments(self): + y = 2 + def f(x, width): + return f'x={x*y:{width}}' + + self.assertEqual(f('foo', 10), 'x=foofoo ') + x = 'bar' + self.assertEqual(f(10, 10), 'x= 20') + + def test_locals(self): + value = 123 + self.assertEqual(f'v:{value}', 'v:123') + + def test_missing_variable(self): + with self.assertRaises(NameError): + f'v:{value}' + + def test_missing_format_spec(self): + class O: + def __format__(self, spec): + if not spec: + return '*' + return spec + + self.assertEqual(f'{O():x}', 'x') + self.assertEqual(f'{O()}', '*') + self.assertEqual(f'{O():}', '*') + + def test_global(self): + self.assertEqual(f'g:{a_global}', 'g:global variable') + self.assertEqual(f'g:{a_global!r}', "g:'global variable'") + + a_local = 'local variable' + self.assertEqual(f'g:{a_global} l:{a_local}', + 'g:global variable l:local variable') + self.assertEqual(f'g:{a_global!r}', + "g:'global variable'") + self.assertEqual(f'g:{a_global} l:{a_local!r}', + "g:global variable l:'local variable'") + + self.assertIn("module 'unittest' from", f'{unittest}') + + def test_shadowed_global(self): + a_global = 'really a local' + self.assertEqual(f'g:{a_global}', 'g:really a local') + self.assertEqual(f'g:{a_global!r}', "g:'really a local'") + + a_local = 'local variable' + self.assertEqual(f'g:{a_global} l:{a_local}', + 'g:really a local l:local variable') + self.assertEqual(f'g:{a_global!r}', + "g:'really a local'") + self.assertEqual(f'g:{a_global} l:{a_local!r}', + "g:really a local l:'local variable'") + + def test_call(self): + def foo(x): + return 'x=' + str(x) + + self.assertEqual(f'{foo(10)}', 'x=10') + + def test_nested_fstrings(self): + y = 5 + self.assertEqual(f'{f"{0}"*3}', '000') + self.assertEqual(f'{f"{y}"*3}', '555') + self.assertEqual(f'{f"{\'x\'}"*3}', 'xxx') + + self.assertEqual(f"{r'x' f'{\"s\"}'}", 'xs') + self.assertEqual(f"{r'x'rf'{\"s\"}'}", 'xs') + + def test_invalid_string_prefixes(self): + self.assertAllRaise(SyntaxError, + ["fu''", + "uf''", + "Fu''", + "fU''", + "Uf''", + "uF''", + "ufr''", + "urf''", + "fur''", + "fru''", + "rfu''", + "ruf''", + "FUR''", + "Fur''", + ]) + + def test_leading_trailing_spaces(self): + self.assertEqual(f'{ 3}', '3') + self.assertEqual(f'{ 3}', '3') + self.assertEqual(f'{\t3}', '3') + self.assertEqual(f'{\t\t3}', '3') + self.assertEqual(f'{3 }', '3') + self.assertEqual(f'{3 }', '3') + self.assertEqual(f'{3\t}', '3') + self.assertEqual(f'{3\t\t}', '3') + + self.assertEqual(f'expr={ {x: y for x, y in [(1, 2), ]}}', + 'expr={1: 2}') + self.assertEqual(f'expr={ {x: y for x, y in [(1, 2), ]} }', + 'expr={1: 2}') + + def test_character_name(self): + self.assertEqual(f'{4}\N{GREEK CAPITAL LETTER DELTA}{3}', + '4\N{GREEK CAPITAL LETTER DELTA}3') + self.assertEqual(f'{{}}\N{GREEK CAPITAL LETTER DELTA}{3}', + '{}\N{GREEK CAPITAL LETTER DELTA}3') + + def test_not_equal(self): + # There's a special test for this because there's a special + # case in the f-string parser to look for != as not ending an + # expression. Normally it would, while looking for !s or !r. + + self.assertEqual(f'{3!=4}', 'True') + self.assertEqual(f'{3!=4:}', 'True') + self.assertEqual(f'{3!=4!s}', 'True') + self.assertEqual(f'{3!=4!s:.3}', 'Tru') + + def test_conversions(self): + self.assertEqual(f'{3.14:10.10}', ' 3.14') + self.assertEqual(f'{3.14!s:10.10}', '3.14 ') + self.assertEqual(f'{3.14!r:10.10}', '3.14 ') + self.assertEqual(f'{3.14!a:10.10}', '3.14 ') + + self.assertEqual(f'{"a"}', 'a') + self.assertEqual(f'{"a"!r}', "'a'") + self.assertEqual(f'{"a"!a}', "'a'") + + # not a conversion + self.assertEqual(f'{"a!r"}', "a!r") + + # not a conversion, but show that ! is allowed in a format spec + self.assertEqual(f'{3.14:!<10.10}', '3.14!!!!!!') + + self.assertEqual(f'{"\N{GREEK CAPITAL LETTER DELTA}"}', '\u0394') + self.assertEqual(f'{"\N{GREEK CAPITAL LETTER DELTA}"!r}', "'\u0394'") + self.assertEqual(f'{"\N{GREEK CAPITAL LETTER DELTA}"!a}', "'\\u0394'") + + self.assertAllRaise(SyntaxError, + ["f'{3!g}'", + "f'{3!A}'", + "f'{3!A}'", + "f'{3!A}'", + "f'{3!!}'", + "f'{3!:}'", + "f'{3!\N{GREEK CAPITAL LETTER DELTA}}'", + "f'{3!ss}'", + "f'{3!ss:s}'", + "f'{3! s}'", # no space before conversion char + "f'{x!s{y}}'", + ]) + + def test_assignment(self): + self.assertAllRaise(SyntaxError, + ["f'' = 3", + "f'{0}' = x", + "f'{x}' = x", + ]) + + def test_del(self): + self.assertAllRaise(SyntaxError, + ["del f''", + "del '' f''", + ]) + + def test_mismatch_braces(self): + self.assertAllRaise(SyntaxError, + ["f'{'", + "f'{{{'", + "f'{{}'", + "f'{{}}}'", + "f'{{}}{'", + "f'}'", + "f'x{'", + "f'x}'", + "f'x{x'", + "f'x}x'", + "f'{3'", + "f'{3:'", + "f'{3!'", + + # can't have { or } in a format spec + "f'{3:{>10}'", + "f'{3:{{>10}'", + "f'{3:\\{>10}'", + """f'{3:{"{"}>10}'""", + "f'{3:}>10}'", + "f'{3:}}>10}'", + "f'{3:\\}>10}'", + """f'{3:{"}"}>10}'""", + ]) + + # But these are just normal strings + self.assertEqual(f'{"{"}', '{') + self.assertEqual(f'{"}"}', '}') + + def test_if_conditional(self): + # there's special logic in compile.c to test if the + # conditional for an if (and while) are constants. + + def test_fstring(x, expected): + flag = 0 + if f'{x}': + flag = 1 + else: + flag = 2 + self.assertEqual(flag, expected) + + def test_concat_empty(x, expected): + flag = 0 + if '' f'{x}': + flag = 1 + else: + flag = 2 + self.assertEqual(flag, expected) + + def test_concat_non_empty(x, expected): + flag = 0 + if ' ' f'{x}': + flag = 1 + else: + flag = 2 + self.assertEqual(flag, expected) + + test_fstring('', 2) + test_fstring(' ', 1) + + test_concat_empty('', 2) + test_concat_empty(' ', 1) + + test_concat_non_empty('', 1) + test_concat_non_empty(' ', 1) + + def test_empty_format_specifier(self): + x = 'test' + self.assertEqual(f'{x}', 'test') + self.assertEqual(f'{x:}', 'test') + self.assertEqual(f'{x!s:}', 'test') + self.assertEqual(f'{x!r:}', "'test'") + + def test_str_format_differences(self): + d = {'a': 'string', + 0: 'integer', + } + a = 0 + self.assertEqual(f'{d[0]}', 'integer') + self.assertEqual(f'{d["a"]}', 'string') + self.assertEqual(f'{d[a]}', 'integer') + self.assertEqual('{d[a]}'.format(d=d), 'string') + self.assertEqual('{d[0]}'.format(d=d), 'integer') + + def test_loop(self): + for i in range(1000): + self.assertEqual(f'i:{i}', 'i:' + str(i)) + + def test_dict(self): + d = {'"': 'dquote', + "'": 'squote', + 'foo': 'bar', + } + self.assertEqual(f'{d["\'"]}', 'squote') + self.assertEqual(f"{d['\"']}", 'dquote') + + self.assertEqual(f'''{d["'"]}''', 'squote') + self.assertEqual(f"""{d['"']}""", 'dquote') + + self.assertEqual(f'{d["foo"]}', 'bar') + self.assertEqual(f"{d['foo']}", 'bar') + self.assertEqual(f'{d[\'foo\']}', 'bar') + self.assertEqual(f"{d[\"foo\"]}", 'bar') + + +if __name__ == '__main__': + unittest.main() diff --git a/Parser/Python.asdl b/Parser/Python.asdl --- a/Parser/Python.asdl +++ b/Parser/Python.asdl @@ -71,6 +71,8 @@ | Call(expr func, expr* args, keyword* keywords) | Num(object n) -- a number as a PyObject. | Str(string s) -- need to specify raw, unicode, etc? + | FormattedValue(expr value, int? conversion, expr? format_spec) + | JoinedStr(expr* values) | Bytes(bytes s) | NameConstant(singleton value) | Ellipsis diff --git a/Parser/tokenizer.c b/Parser/tokenizer.c --- a/Parser/tokenizer.c +++ b/Parser/tokenizer.c @@ -1477,17 +1477,19 @@ nonascii = 0; if (is_potential_identifier_start(c)) { /* Process b"", r"", u"", br"" and rb"" */ - int saw_b = 0, saw_r = 0, saw_u = 0; + int saw_b = 0, saw_r = 0, saw_u = 0, saw_f = 0; while (1) { - if (!(saw_b || saw_u) && (c == 'b' || c == 'B')) + if (!(saw_b || saw_u || saw_f) && (c == 'b' || c == 'B')) saw_b = 1; /* Since this is a backwards compatibility support literal we don't want to support it in arbitrary order like byte literals. */ - else if (!(saw_b || saw_u || saw_r) && (c == 'u' || c == 'U')) + else if (!(saw_b || saw_u || saw_r || saw_f) && (c == 'u' || c == 'U')) saw_u = 1; /* ur"" and ru"" are not supported */ else if (!(saw_r || saw_u) && (c == 'r' || c == 'R')) saw_r = 1; + else if (!(saw_f || saw_b || saw_u) && (c == 'f' || c == 'F')) + saw_f = 1; else break; c = tok_nextc(tok); diff --git a/Python/ast.c b/Python/ast.c --- a/Python/ast.c +++ b/Python/ast.c @@ -257,6 +257,14 @@ } return 1; } + case JoinedStr_kind: + return validate_exprs(exp->v.JoinedStr.values, Load, 0); + case FormattedValue_kind: + if (validate_expr(exp->v.FormattedValue.value, Load) == 0) + return 0; + if (exp->v.FormattedValue.format_spec) + return validate_expr(exp->v.FormattedValue.format_spec, Load); + return 1; case Bytes_kind: { PyObject *b = exp->v.Bytes.s; if (!PyBytes_CheckExact(b)) { @@ -535,9 +543,7 @@ static expr_ty ast_for_call(struct compiling *, const node *, expr_ty); static PyObject *parsenumber(struct compiling *, const char *); -static PyObject *parsestr(struct compiling *, const node *n, int *bytesmode); -static PyObject *parsestrplus(struct compiling *, const node *n, - int *bytesmode); +static expr_ty parsestrplus(struct compiling *, const node *n); #define COMP_GENEXP 0 #define COMP_LISTCOMP 1 @@ -986,6 +992,8 @@ case Num_kind: case Str_kind: case Bytes_kind: + case JoinedStr_kind: + case FormattedValue_kind: expr_name = "literal"; break; case NameConstant_kind: @@ -2001,7 +2009,6 @@ | '...' | 'None' | 'True' | 'False' */ node *ch = CHILD(n, 0); - int bytesmode = 0; switch (TYPE(ch)) { case NAME: { @@ -2023,7 +2030,7 @@ return Name(name, Load, LINENO(n), n->n_col_offset, c->c_arena); } case STRING: { - PyObject *str = parsestrplus(c, n, &bytesmode); + expr_ty str = parsestrplus(c, n); if (!str) { const char *errtype = NULL; if (PyErr_ExceptionMatches(PyExc_UnicodeError)) @@ -2050,14 +2057,7 @@ } return NULL; } - if (PyArena_AddPyObject(c->c_arena, str) < 0) { - Py_DECREF(str); - return NULL; - } - if (bytesmode) - return Bytes(str, LINENO(n), n->n_col_offset, c->c_arena); - else - return Str(str, LINENO(n), n->n_col_offset, c->c_arena); + return str; } case NUMBER: { PyObject *pynum = parsenumber(c, STR(ch)); @@ -4002,12 +4002,786 @@ return v; } -/* s is a Python string literal, including the bracketing quote characters, - * and r &/or b prefixes (if any), and embedded escape sequences (if any). - * parsestr parses it, and returns the decoded Python string object. - */ +static int +fstring_find_literal(PyObject *str, Py_ssize_t *ofs, Py_ssize_t end, + Py_ssize_t *literal_start, Py_ssize_t *literal_end, + struct compiling *c, const node *n) +{ + /* Get any leading literal string. It ends when we hit an + un-doubled opening brace, or the end of the string. */ + + /* Set *literal_start and *literal_end to point to the literal + inside str. */ + + /* Return -1 on error. Return 0 if we reached the end of the + literal. Return 1 if we haven't reached the end of the literal, + but we want the caller to process the literal up to this + point. Used for doubled braces. */ + + enum PyUnicode_Kind kind = PyUnicode_KIND(str); + void *data = PyUnicode_DATA(str); + + *literal_start = *ofs; + for (; *ofs < end; *ofs += 1) { + Py_UCS4 ch = PyUnicode_READ(kind, data, *ofs); + if (ch == '{' || ch == '}') { + /* An un-doubled left brace is the start of an expression. + An un-doubled right brace is an error. */ + if (*ofs + 1 < end) { + Py_UCS4 ch1 = PyUnicode_READ(kind, data, *ofs + 1); + if (ch == ch1) { + /* We're going to tell the caller that the literal + ends here, but that they should continue + scanning. But also skip over the second brace + when we resume scanning. */ + *literal_end = *ofs + 1; + *ofs += 2; + return 1; + } + } + if (ch == '{') + break; + + /* Un-doubled right brace. Error. */ + ast_error(c, n, "single '}' encountered in format string"); + return -1; + } + } + *literal_end = *ofs; + return 0; +} + +/* The string we're looking inside is specified by str, *ofs and end. + We know *ofs starts an expression. Returns the string expression, + the conversion character (if any), and the format specifier. Note + that we don't do a perfect job here: I don't make sure that a + closing brace doesn't match an opening paren, for example. It + doesn't need to error on all invalid expressions, just correctly + find the end of all valid ones. Any errors inside the expression + will be caught when we parse it later. */ +static int +fstring_find_expr(PyObject *str, Py_ssize_t *ofs, Py_ssize_t end, + Py_ssize_t *expr_start, Py_ssize_t *expr_end, Py_UCS4 *conversion, + Py_ssize_t *fmt_spec_start, Py_ssize_t *fmt_spec_end, + struct compiling *c, const node *n) +{ + /* Return -1 on error, else 0. */ + + enum PyUnicode_Kind kind = PyUnicode_KIND(str); + void *data = PyUnicode_DATA(str); + + /* 0 if we're not in a string, else the quote char we're trying to + match (single or double quote). */ + Py_UCS4 quote_char = 0; + + /* If we're inside a string, 1=normal, 3=triple-quoted. */ + int string_type = 0; + + /* Keep track of nesting level for braces/parens/brackets in + expressions. */ + Py_ssize_t nested_depth = 0; + + /* How the expression was terminated. */ + Py_UCS4 terminal_char = 0; + + /* The first char must be a left brace. Skip over it. */ + *ofs += 1; + + *expr_start = *ofs; + for (; *ofs < end; *ofs += 1) { + Py_UCS4 ch; + + /* Loop invariants. */ + assert(nested_depth >= 0); + assert(terminal_char == 0); + assert(*ofs >= *expr_start); + if (quote_char) + assert(string_type == 1 || string_type == 3); + else + assert(string_type == 0); + + ch = PyUnicode_READ(kind, data, *ofs); + if (quote_char) { + /* We're inside a string. See if we're at the end. */ + /* This code needs to implement the same non-error logic as + tok_get from tokenizer.c, at the letter_quote label. To + actually share that code would be a nightmare. But, it's + unlikely to change and is small, so duplicate it here. Note we + don't need to catch all of the errors, since they'll be caught + when parsing the expression. We just need to match the + non-error cases. This we can ignore \n in single-quoted + strings, for example. Or non-terminated strings. */ + if (ch == quote_char) { + /* Does this match the string_type? */ + if (string_type == 3) { + if (*ofs+2 < end && + PyUnicode_READ(kind, data, *ofs+1) == ch && + PyUnicode_READ(kind, data, *ofs+2) == ch) { + /* We're at the end of a triple quoted string. */ + *ofs += 2; + string_type = 0; + quote_char = 0; + continue; + } + } else { + /* We're at the end of a normal string. */ + quote_char = 0; + string_type = 0; + continue; + } + } + /* We're inside a string, and not finished with the string. If + this is a backslash, skip the next char (it might be an end + quote that needs skipping). Otherwise, just consume this + character normally. */ + if (ch == '\\' && *ofs+1 < end) { + /* Just skip the next char, whatever it is. */ + *ofs += 1; + } + } else if (ch == '\'' || ch == '"') { + /* Is this a triple quoted string? */ + if (*ofs+2 < end && + PyUnicode_READ(kind, data, *ofs+1) == ch && + PyUnicode_READ(kind, data, *ofs+2) == ch) { + string_type = 3; + *ofs += 2; + } else { + /* Start of a normal string. */ + string_type = 1; + } + /* Start looking for the end of the string. */ + quote_char = ch; + } else if (ch == '[' || ch == '{' || ch == '(') { + nested_depth++; + } else if (nested_depth != 0 && + (ch == ']' || ch == '}' || ch == ')')) { + nested_depth--; + } else if (ch == '#') { + /* Error: can't include a comment character, inside parens + or not. */ + ast_error(c, n, "f-string cannot include #"); + return -1; + } else if (nested_depth == 0 && + (ch == '!' || ch == ':' || ch == '}')) { + /* First, test for the special case of "!=". Since '=' is + not an allowed conversion character, nothing is lost in + this test. */ + if (ch == '!' && *ofs+1 < end && PyUnicode_READ(kind, data, *ofs+1) == '=') + /* This isn't a conversion character, just continue. */ + continue; + + /* Normal way out of this loop. */ + terminal_char = ch; + break; + } else { + /* Just consume this char and loop around. */ + } + } + /* If we leave this loop in a string or with mismatched parens, we + don't care. We'll get a syntax error when parsing the + expression. */ + + if (terminal_char == 0) { + /* XXX: no terminating right brace before the end of the string. */ + /* set the appropriate error */ + ast_error(c, n, "missing '}' in format string expression"); + return -1; + } + + *expr_end = *ofs; + + /* Skip the terminal char. */ + *ofs += 1; + + /* Check for a conversion char, if present. */ + if (terminal_char == '!') { + if (*ofs >= end) { + ast_error(c, n, "invalid conversion char at end of string"); + return -1; + } + *conversion = PyUnicode_READ(kind, data, *ofs); + + /* Make sure the next char is a : or } */ + *ofs += 1; + if (*ofs >= end) { + ast_error(c, n, "invalid conversion char at end of string"); + return -1; + } + terminal_char = PyUnicode_READ(kind, data, *ofs); + + /* If this isn't ':' or '}', it's an error. */ + if (!(terminal_char == ':' || terminal_char == '}')) { + ast_error(c, n, "invalid character following conversion " + "character"); + return -1; + } + *ofs += 1; + } + + /* Check for the format spec, if present. */ + if (terminal_char == ':') { + /* Find the end of the format spec. It must be a right + brace. */ + Py_ssize_t nested_depth = 0; /* Keep track of nested brackets. */ + + *fmt_spec_start = *ofs; + for (; *ofs < end; *ofs += 1) { + Py_UCS4 ch = PyUnicode_READ(kind, data, *ofs); + if (nested_depth == 0 && ch == '}') { + /* End of the format spec. */ + terminal_char = ch; + break; + } else if (ch == '}') { + nested_depth--; + } else if (ch == '{') { + if (nested_depth >= 1) { + ast_error(c, n, "nesting of '{' in format specifier " + "is not allowed"); + return -1; + } + nested_depth++; + } + /* Just consume this char and loop around. */ + } + if (terminal_char != '}') { + ast_error(c, n, "missing '}' in format specifier"); + return -1; + } + *fmt_spec_end = *ofs; + *ofs += 1; + } + return 0; +} + + +/* The string we're looking inside is str, and we'll look at ofs + through end chars of that strings. + + If there's a leading literal, set *literal_start and *literal_end + to point to it. + + If there's an expression, set *expr_start and *expr_end to point to + it. + + If there's a conversion, set *converion to it. + + If there's a format specifier, set *fmt_spec_start and + *fmt_spec_end to point to it. + + */ +static int +fstring_find_literal_and_expr(PyObject *str, Py_ssize_t *ofs, Py_ssize_t end, + Py_ssize_t *literal_start, Py_ssize_t *literal_end, + Py_ssize_t *expr_start, Py_ssize_t *expr_end, Py_UCS4 *conversion, + Py_ssize_t *fmt_spec_start, Py_ssize_t *fmt_spec_end, + struct compiling *c, const node *n) +{ + int result; + + /* Return -1 if error, 0 if there's a literal and possibly a + literal, and 1 if there's a literal, no expression, but this + isn't the end of the input. */ + /* If the return value is 0 and there's no expression, then we're + at the end of the string. */ + /* Return type of 1 is used for doubling braces: a literal that + just includes one brace is returned, but no expression is + returned. When called again, we'll return the rest of the + literal, but having skipped over the second brace. The literal + is all joined back together in parsestrplus. */ + + /* Get any literal string. */ + result = fstring_find_literal(str, ofs, end, literal_start, literal_end, + c, n); + if (result < 0) + return -1; + + if (result == 1) + return 1; + + if (*ofs >= end) + /* We're at the end of the string: no expression. */ + return 0; + + /* We must now be the start of an expression, on a '{'. */ + assert(*ofs < end && PyUnicode_ReadChar(str, *ofs) == '{'); + + return fstring_find_expr(str, ofs, end, expr_start, expr_end, + conversion, fmt_spec_start, + fmt_spec_end, c, n); +} + +/* Compile this expression in to an expr_ty. We know that we can + temporarily modify the character before the start of this string + (it's '{'), and we know we can temporarily modify the character + after this string (it is a '}'). Leverage this to create a string + with enough room for us to add parens around the expression. This + is to allow strings with embedded newlines, for example. */ +static expr_ty +fstring_expression_compile(PyObject *str, Py_ssize_t expr_start, + Py_ssize_t expr_end, PyArena *arena) +{ + PyCompilerFlags cf; + mod_ty mod; + char *utf_expr; + PyObject *sub = NULL; + Py_ssize_t i; + int all_whitespace; + + /* If the substring is all whitespace, it's an error. We need to + catch this here, and not when we call PyParser_ASTFromString, + because turning the expression '' in to '()' would go from + being invalid to valid. */ + /* Note that this code says an empty string is all + whitespace. That's important. There's a test for it: f'{}'. */ + all_whitespace = 1; + for (i = expr_start; i < expr_end; i++) { + if (!Py_UNICODE_ISSPACE(PyUnicode_READ_CHAR(str, i))) { + all_whitespace = 0; + break; + } + } + if (all_whitespace) { + PyErr_SetString(PyExc_SyntaxError, "f-string with empty expression"); + return NULL; + } + + /* If the substring will be the entire source string, we can't use + PyUnicode_Substring, since it will return another reference to + our original string. Because we're modifying the string in + place, that's a no-no. So, detect that case and just use our + string directly. */ + + if (expr_start-1 == 0 && expr_end+1 == PyUnicode_GET_LENGTH(str)) { + sub = str; + /* No need to actually remember these characters, because we + know they must be braces. */ + assert(PyUnicode_ReadChar(sub, 0) == '{'); + assert(PyUnicode_ReadChar(sub, expr_end-expr_start+1) == '}'); + } else { + /* Create a substring object. */ + sub = PyUnicode_Substring(str, expr_start-1, expr_end+1); + if (!sub) + goto error; + } + + if (PyUnicode_WriteChar(sub, 0, '(') < 0 || + PyUnicode_WriteChar(sub, expr_end-expr_start+1, ')') < 0) + goto error; + + cf.cf_flags = PyCF_ONLY_AST; + + utf_expr = PyUnicode_AsUTF8(sub); + if (!utf_expr) + goto error; + mod = PyParser_ASTFromString(utf_expr, "", + Py_eval_input, &cf, arena); + if (!mod) + goto error; + if (sub != str) + Py_CLEAR(sub); + else { + if (PyUnicode_WriteChar(sub, 0, '{') < 0 || + PyUnicode_WriteChar(sub, expr_end-expr_start+1, '}') < 0) + goto error; + } + return mod->v.Expression.body; + +error: + if (sub != str) + Py_DECREF(sub); + return NULL; +} + +/* Forward declaration because parsing is recursive. */ +static expr_ty +fstring_parse(PyObject *str, Py_ssize_t start, Py_ssize_t end, + struct compiling *c, const node *n); + +/* Return -1 on error. + + Return 0 if we have a literal (possible zero length) and an + expression (zero length if at the end of the string. + + Return 1 if we have a literal, but no expression, and we want the + caller to call us again. This is used to deal with doubled + braces. + + When called multiple times on the string 'a{{b{0}c', this function + will return: + + 1. the literal 'a{' with no expression, and a return value + of 1. Despite the fact that there's no expression, the return + value of 1 means we're not finished yet. + + 2. the literal 'b' and the expression '0', with a return value of + 0. The fact that there's an expression means we're not finished. + + 3. literal 'c' with no expression and a return value of 0. The + combination of the return value of 0 with no expression means + we're finished. +*/ +static int +fstring_enumerate(PyObject *str, Py_ssize_t *start, Py_ssize_t end, + PyObject **literal, expr_ty *expression, + struct compiling *c, const node *n) +{ + Py_ssize_t literal_start = -1, literal_end = -1; + Py_ssize_t expr_start = -1, expr_end = -1; + Py_UCS4 conversion = 0; + Py_ssize_t fmt_spec_start = -1, fmt_spec_end = -1; + expr_ty format_spec = NULL; + expr_ty value; + int result; + + *literal = NULL; + *expression = NULL; + + result = fstring_find_literal_and_expr(str, start, end, + &literal_start, &literal_end, + &expr_start, &expr_end, + &conversion, &fmt_spec_start, + &fmt_spec_end, c, n); + if (result < 0) + goto error; + + /* If there's a literal, create the object. */ + if (literal_start != literal_end) { + *literal = PyUnicode_Substring(str, literal_start, literal_end); + if (!*literal) + goto error; + } + + if (expr_start == -1) { + /* If no expression, return. We might be done or not (see the + comment above the function definition. */ + assert(result == 0 || result == 1); + return result; + } + + /* Parse the expression, and save the value. */ + value = fstring_expression_compile(str, expr_start, expr_end, + c->c_arena); + if (!value) + goto error; + + if (conversion) { + /* Validate the conversion character. + XXX: should this go in fstring_find_expr? */ + if (!(conversion == 's' || conversion == 'r' + || conversion == 'a')) { + ast_error(c, n, "invalid conversion character"); + return -1; + } + } + + if (fmt_spec_start == fmt_spec_end) { + /* No format_spec provided. */ + format_spec = NULL; + } else { + format_spec = fstring_parse(str, fmt_spec_start, fmt_spec_end, c, n); + if (!format_spec) + goto error; + } + *expression = FormattedValue(value, (int)conversion, format_spec, + LINENO(n), n->n_col_offset, c->c_arena); + return 0; + error: + return -1; +} + + +#define EXPRLIST_N_CACHED 64 + +typedef struct { + /* Incrementally build an array of expr_ty, so be used in an + asdl_seq. Cache some small but reasonably sized number of + expr_ty's, and then after that start dynamically allocating, + doubling the number allocated each time. Note that the f-string + f'{0}a{1}' contains 3 expr_ty's: 2 FormattedValue's, and one + Str for the literal 'a'. So you add expr_ty's about twice as + fast as you add exressions in an f-string. */ + + Py_ssize_t allocated; /* Number we've allocated. */ + Py_ssize_t size; /* Number we've used. */ + expr_ty *p; /* Pointer to the memory we're actually + using. Will point to 'data' until we + start dynamically allocating. */ + expr_ty data[EXPRLIST_N_CACHED]; +} ExprList; + +static void +ExprList_Init(ExprList *l) +{ + l->allocated = EXPRLIST_N_CACHED; + l->size = 0; + + /* Until we start allocating dynamically, p points to data. */ + l->p = l->data; +} + +static int +ExprList_Append(ExprList *l, expr_ty exp) +{ + if (l->size >= l->allocated) { + /* We need to alloc (or realloc) the memory. */ + Py_ssize_t new_size = l->allocated * 2; + + /* See if we've ever allocated anything dynamically. */ + if (l->p == l->data) { + Py_ssize_t i; + /* We're still using the cached data. Switch to + alloc-ing. */ + l->p = PyMem_RawMalloc(sizeof(expr_ty) * new_size); + if (!l->p) + return -1; + /* Copy the cached data into the new buffer. */ + for (i = 0; i < l->size; i++) + l->p[i] = l->data[i]; + } else { + /* Just realloc. */ + expr_ty *tmp = PyMem_RawRealloc(l->p, sizeof(expr_ty) * new_size); + if (!tmp) { + PyMem_RawFree(l->p); + l->p = NULL; + return -1; + } + l->p = tmp; + } + + l->allocated = new_size; + assert(l->allocated == 2 * l->size); + } + + l->p[l->size++] = exp; + return 0; +} + +static void +ExprList_Dealloc(ExprList *l) +{ + /* If there's been an error, or we've never dynamically allocated, + do nothing. */ + if (!l->p || l->p == l->data) { + /* Do nothing. */ + } else { + /* We have dynamically allocated. Free the memory. */ + PyMem_RawFree(l->p); + } + l->p = NULL; +} + +static asdl_seq * +ExprList_Finish(ExprList *l, PyArena *arena) +{ + asdl_seq *seq = _Py_asdl_seq_new(l->size, arena); + if (seq != NULL) { + Py_ssize_t i; + + for (i = 0; i < l->size; i++) + asdl_seq_SET(seq, i, l->p[i]); + } + ExprList_Dealloc(l); + return seq; +} + +/* The FstringParser is designed to add a mix of strings and + f-strings, and concat them together as needed. Ultimately, it + generates an expr_ty. */ +typedef struct { + PyObject *last_str; + ExprList expr_list; +} FstringParser; + +static void +FstringParser_Init(FstringParser *state) +{ + state->last_str = NULL; + ExprList_Init(&state->expr_list); +} + +static void +FstringParser_Dealloc(FstringParser *state) +{ + Py_XDECREF(state->last_str); + ExprList_Dealloc(&state->expr_list); +} + +/* Make a Str node, but decref the PyUnicode object being addd. */ +static expr_ty +make_str_node_and_del(PyObject **str, struct compiling *c, const node* n) +{ + PyObject *s = *str; + *str = NULL; + assert(PyUnicode_CheckExact(s)); + if (PyArena_AddPyObject(c->c_arena, s) < 0) { + Py_DECREF(s); + return NULL; + } + return Str(s, LINENO(n), n->n_col_offset, c->c_arena); +} + +/* Add a non-f-string. str is decref'd. */ +static int +FstringParser_ConcatAndDel(FstringParser *state, PyObject *str) +{ + if (PyUnicode_GET_LENGTH(str) == 0) { + Py_DECREF(str); + return 0; + } + + if (!state->last_str) { + /* We didn't have a string before, so just remember this one. */ + state->last_str = str; + } else { + /* Concatenate this with the previous string. */ + PyObject *temp = PyUnicode_Concat(state->last_str, str); + Py_DECREF(state->last_str); + Py_DECREF(str); + state->last_str = temp; + if (!temp) + return -1; + } + return 0; +} + +/* Parse an f-string. The f-string is in str, between [start, end), + with no 'f' or quotes. str is not decref'd, since we don't know if + it's used elsewhere. And if we're only looking at a part of a + string, then decref'ing is definitely not the right thing to do! +*/ +static int +FstringParser_ConcatFstring(FstringParser *state, PyObject *str, + Py_ssize_t start, Py_ssize_t end, + struct compiling *c, const node *n) +{ + /* Parse the f-string. */ + while (1) { + PyObject *literal; + expr_ty expression; + + /* If there's a zero length literal in front of the + expression, literal will be NULL. If we're at the end of + the f-string, expression will be NULL (unless result == 1, + see below). */ + int result = fstring_enumerate(str, &start, end, &literal, + &expression, c, n); + if (result < 0) + return -1; + + /* Add the literal, if any. */ + if (!literal) { + /* Do nothing. Just leave last_str alone (and possibly + NULL). */ + } else if (!state->last_str) { + state->last_str = literal; + } else { + /* We have a literal, concatenate it. */ + assert(PyUnicode_GET_LENGTH(literal) != 0); + if (FstringParser_ConcatAndDel(state, literal) < 0) + return -1; + } + assert(!state->last_str || + PyUnicode_GET_LENGTH(state->last_str) != 0); + + /* See if we should just loop around to get the next literal + and expression, while ignoring the expression this + time. This is used for un-doubling braces, as an + optimization. */ + if (result == 1) + continue; + + if (!expression) + /* We're done with this f-string. */ + break; + + /* We know we have an expression. Convert any existing string + to a Str node. */ + if (!state->last_str) { + /* Do nothing. No previous literal. */ + } else { + /* Convert the existing last_str literal to a Str node. */ + expr_ty str = make_str_node_and_del(&state->last_str, c, n); + if (!str || ExprList_Append(&state->expr_list, str) < 0) + return -1; + } + + if (ExprList_Append(&state->expr_list, expression) < 0) + return -1; + } + return 0; +} + +/* Convert the partial state reflected in last_str and expr_list to an + expr_ty. The expr_ty can be a Str, or a JoinedStr. */ +static expr_ty +FstringParser_Finish(FstringParser *state, struct compiling *c, + const node *n) +{ + asdl_seq *seq; + + /* If we're just a constant string with no expressions, return + that. */ + if(state->expr_list.size == 0) { + if (!state->last_str) { + /* Create a zero length string. */ + state->last_str = PyUnicode_FromStringAndSize(NULL, 0); + if (!state->last_str) + goto error; + } + return make_str_node_and_del(&state->last_str, c, n); + } + + /* Create a Str node out of last_str, if needed. It will be the + last node in our expression list. */ + if (state->last_str) { + expr_ty str = make_str_node_and_del(&state->last_str, c, n); + if (!str || ExprList_Append(&state->expr_list, str) < 0) + goto error; + } + /* This has already been freed. */ + assert(state->last_str == NULL); + + seq = ExprList_Finish(&state->expr_list, c->c_arena); + if (!seq) + goto error; + + /* If there's only one expression, return it. Otherwise, we need + to join them together. */ + if (seq->size == 1) + return seq->elements[0]; + + return JoinedStr(seq, LINENO(n), n->n_col_offset, c->c_arena); + +error: + FstringParser_Dealloc(state); + return NULL; +} + +/* Given an f-string (with no 'f' or quotes) that's in str in the + range [start, end), parse it into an expr_ty. Return NULL on + error. Does not decref str. +*/ +static expr_ty +fstring_parse(PyObject *str, Py_ssize_t start, Py_ssize_t end, + struct compiling *c, const node *n) +{ + FstringParser state; + + FstringParser_Init(&state); + if (FstringParser_ConcatFstring(&state, str, start, end, c, n) < 0) + return NULL; + + return FstringParser_Finish(&state, c, n); +} + +/* n is a Python string literal, including the bracketing quote + characters, and r, b, u, &/or f prefixes (if any), and embedded + escape sequences (if any). parsestr parses it, and returns the + decoded Python string object. If the string is an f-string, set + *fmode and return the unparsed string object. +*/ static PyObject * -parsestr(struct compiling *c, const node *n, int *bytesmode) +parsestr(struct compiling *c, const node *n, int *bytesmode, int *fmode) { size_t len; const char *s = STR(n); @@ -4027,15 +4801,24 @@ quote = *++s; rawmode = 1; } + else if (quote == 'f' || quote == 'F') { + quote = *++s; + *fmode = 1; + } else { break; } } } + if (*fmode && *bytesmode) { + PyErr_BadInternalCall(); + return NULL; + } if (quote != '\'' && quote != '\"') { PyErr_BadInternalCall(); return NULL; } + /* Skip the leading quote char. */ s++; len = strlen(s); if (len > INT_MAX) { @@ -4044,12 +4827,17 @@ return NULL; } if (s[--len] != quote) { + /* Last quote char must match the first. */ PyErr_BadInternalCall(); return NULL; } if (len >= 4 && s[0] == quote && s[1] == quote) { + /* A triple quoted string. We've already skipped one quote at + the start and one at the end of the string. Now skip the + two at the start. */ s += 2; len -= 2; + /* And check that the last two match. */ if (s[--len] != quote || s[--len] != quote) { PyErr_BadInternalCall(); return NULL; @@ -4088,51 +4876,84 @@ } } return PyBytes_DecodeEscape(s, len, NULL, 1, - need_encoding ? c->c_encoding : NULL); + need_encoding ? c->c_encoding : NULL); } -/* Build a Python string object out of a STRING+ atom. This takes care of - * compile-time literal catenation, calling parsestr() on each piece, and - * pasting the intermediate results together. - */ -static PyObject * -parsestrplus(struct compiling *c, const node *n, int *bytesmode) +/* Accepts any number of STRING+ atoms, and concatenates them. Run + through each atom, and process it as needed. For bytes, just + concatenate them together, and the result will be a Bytes node. For + normal strings and f-strings, concatenate them together. The result + will be a Str node if there were no f-strings; a FormattedValue + node if there's just an f-string (with no leading or trailing + literals), or a JoinedStr node if there are multiple f-strings or + any literals involved. */ +static expr_ty +parsestrplus(struct compiling *c, const node *n) { - PyObject *v; + int bytesmode = 0; + PyObject *bytes_str = NULL; int i; - REQ(CHILD(n, 0), STRING); - v = parsestr(c, CHILD(n, 0), bytesmode); - if (v != NULL) { - /* String literal concatenation */ - for (i = 1; i < NCH(n); i++) { - PyObject *s; - int subbm = 0; - s = parsestr(c, CHILD(n, i), &subbm); - if (s == NULL) - goto onError; - if (*bytesmode != subbm) { - ast_error(c, n, "cannot mix bytes and nonbytes literals"); - Py_DECREF(s); - goto onError; + + FstringParser state; + FstringParser_Init(&state); + + for (i = 0; i < NCH(n); i++) { + int this_bytesmode = 0; + int this_fmode = 0; + PyObject *s; + + REQ(CHILD(n, i), STRING); + s = parsestr(c, CHILD(n, i), &this_bytesmode, &this_fmode); + if (!s) + goto error; + + /* Check that we're not mixing bytes with unicode. */ + if (i != 0 && bytesmode != this_bytesmode) { + ast_error(c, n, "cannot mix bytes and nonbytes literals"); + Py_DECREF(s); + goto error; + } + bytesmode = this_bytesmode; + + assert(bytesmode ? PyBytes_CheckExact(s) : PyUnicode_CheckExact(s)); + + if (bytesmode) { + /* For bytes, concat as we go. */ + if (i == 0) { + /* First time, just remember this value. */ + bytes_str = s; + } else { + PyBytes_ConcatAndDel(&bytes_str, s); + if (!bytes_str) + goto error; } - if (PyBytes_Check(v) && PyBytes_Check(s)) { - PyBytes_ConcatAndDel(&v, s); - if (v == NULL) - goto onError; - } - else { - PyObject *temp = PyUnicode_Concat(v, s); - Py_DECREF(s); - Py_DECREF(v); - v = temp; - if (v == NULL) - goto onError; - } + } else if (this_fmode) { + /* This is an f-string. Concatenate it. */ + if (FstringParser_ConcatFstring(&state, s, 0, + PyUnicode_GET_LENGTH(s), + c, n) < 0) + goto error; + Py_DECREF(s); + } else { + /* This is a regular string. Concatenate it. */ + if (!FstringParser_ConcatAndDel(&state, s) < 0) + goto error; } } - return v; - - onError: - Py_XDECREF(v); + if (bytesmode) { + /* Just return the bytes object and we're done. */ + if (PyArena_AddPyObject(c->c_arena, bytes_str) < 0) + goto error; + return Bytes(bytes_str, LINENO(n), n->n_col_offset, c->c_arena); + } + + /* We're not a bytes string, bytes_str should never have been set. */ + assert(bytes_str == NULL); + + return FstringParser_Finish(&state, c, n); + +error: + Py_XDECREF(bytes_str); + FstringParser_Dealloc(&state); return NULL; } diff --git a/Python/compile.c b/Python/compile.c --- a/Python/compile.c +++ b/Python/compile.c @@ -731,6 +731,7 @@ return 1; } + /* Allocate a new block and return a pointer to it. Returns NULL on error. */ @@ -3209,6 +3210,120 @@ e->v.Call.keywords); } +static int +compiler_joined_str(struct compiler *c, expr_ty e) +{ + /* Concatenate parts of a string using ''.join(parts). There are + probably better ways of doing this. + + This is used for constructs like "'x=' f'{42}'", which have to + be evaluated at compile time. */ + + static PyObject *empty_string; + static PyObject *join_string; + + if (!empty_string) { + empty_string = PyUnicode_FromString(""); + if (!empty_string) + return 0; + } + if (!join_string) { + join_string = PyUnicode_FromString("join"); + if (!join_string) + return 0; + } + + ADDOP_O(c, LOAD_CONST, empty_string, consts); + ADDOP_NAME(c, LOAD_ATTR, join_string, names); + VISIT_SEQ(c, expr, e->v.JoinedStr.values); + ADDOP_I(c, BUILD_LIST, asdl_seq_LEN(e->v.JoinedStr.values)); + ADDOP_I(c, CALL_FUNCTION, 1); + return 1; +} + +static int +compiler_formatted_value(struct compiler *c, expr_ty e) +{ + static PyObject *empty_string; + static PyObject *format_string; + static PyObject *str_string; + static PyObject *repr_string; + static PyObject *ascii_string; + + if (!empty_string) { + empty_string = PyUnicode_InternFromString(""); + if (!empty_string) + return 0; + } + + if (!format_string) { + format_string = PyUnicode_InternFromString("__format__"); + if (!format_string) + return 0; + } + + if (!str_string) { + str_string = PyUnicode_InternFromString("str"); + if (!str_string) + return 0; + } + + if (!repr_string) { + repr_string = PyUnicode_InternFromString("repr"); + if (!repr_string) + return 0; + } + if (!ascii_string) { + ascii_string = PyUnicode_InternFromString("ascii"); + if (!ascii_string) + return 0; + } + + /* If needed, convert via str, repr, or ascii. */ + if (e->v.FormattedValue.conversion) { + PyObject *conv_name; + switch (e->v.FormattedValue.conversion) { + case 's': + conv_name = str_string; + break; + case 'r': + conv_name = repr_string; + break; + case 'a': + conv_name = ascii_string; + break; + default: + PyErr_SetString(PyExc_SystemError, + "Unrecognized conversion character"); + return 0; + } + ADDOP_NAME(c, LOAD_GLOBAL, conv_name, names); + } + + /* Push the value. */ + VISIT(c, expr, e->v.FormattedValue.value); + + /* If needed, convert via str, repr, or ascii. */ + if (e->v.FormattedValue.conversion) { + /* Call the function we previously pushed. */ + ADDOP_I(c, CALL_FUNCTION, 1); + } + + ADDOP_NAME(c, LOAD_ATTR, format_string, names); + + /* The format spec, if any. */ + if (e->v.FormattedValue.format_spec) { + VISIT(c, expr, e->v.FormattedValue.format_spec); + } else { + /* No format spec specified, use an empty string. */ + ADDOP_O(c, LOAD_CONST, empty_string, consts); + } + + ADDOP_I(c, CALL_FUNCTION, 1); + + return 1; +} + /* shared code between compiler_call and compiler_class */ static int compiler_call_helper(struct compiler *c, @@ -3878,6 +3993,10 @@ case Str_kind: ADDOP_O(c, LOAD_CONST, e->v.Str.s, consts); break; + case JoinedStr_kind: + return compiler_joined_str(c, e); + case FormattedValue_kind: + return compiler_formatted_value(c, e); case Bytes_kind: ADDOP_O(c, LOAD_CONST, e->v.Bytes.s, consts); break; @@ -4784,4 +4903,3 @@ { return PyAST_CompileEx(mod, filename, flags, -1, arena); } - diff --git a/Python/symtable.c b/Python/symtable.c --- a/Python/symtable.c +++ b/Python/symtable.c @@ -1439,6 +1439,14 @@ VISIT_SEQ(st, expr, e->v.Call.args); VISIT_SEQ_WITH_NULL(st, keyword, e->v.Call.keywords); break; + case FormattedValue_kind: + VISIT(st, expr, e->v.FormattedValue.value); + if (e->v.FormattedValue.format_spec) + VISIT(st, expr, e->v.FormattedValue.format_spec); + break; + case JoinedStr_kind: + VISIT_SEQ(st, expr, e->v.JoinedStr.values); + break; case Num_kind: case Str_kind: case Bytes_kind: