diff -r 42bf74b90626 Lib/ast.py --- a/Lib/ast.py Wed Mar 20 14:26:33 2013 -0700 +++ b/Lib/ast.py Wed Mar 20 18:39:12 2013 -0700 @@ -35,53 +35,130 @@ return compile(source, filename, mode, PyCF_ONLY_AST) -def literal_eval(node_or_string): +class CannotConvertNode(Exception): + pass + +def convert_node(node, converters, *, nested=False): + """ + Converts an AST node into an appropriate Python object. + + *converters* is an iterable of converter functions. + + Converter functions accept an AST node, a conversion function (for + recursion) and a flag indicating whether or not this is a recursive + call for a nested AST node and return a converted value. + + Converters raise ast.CannotConvertNode to indicate that no conversion + occurred and the node should be passed to the next registered converter. + """ + def _convert(node): + return convert_node(node, converters, nested=True) + for converter in converters: + try: + return converter(node, _convert, nested) + except CannotConvertNode: + pass + if nested: + raise CannotConvertNode + raise ValueError('malformed node: ' + repr(node)) + +_LITERAL_CONVERTERS = [] + +def _add_converter(f): + _LITERAL_CONVERTERS.append(f) + return f + +_LITERAL_NODE_MAP = { + Str: (lambda node, convert: node.s), + Bytes: (lambda node, convert: node.s), + Num: (lambda node, convert: node.n), + Tuple: (lambda node, convert: tuple(map(convert, node.elts))), + List: (lambda node, convert: list(map(convert, node.elts))), + Set: (lambda node, convert: set(map(convert, node.elts))), + Dict: (lambda node, convert: dict((convert(k), convert(v)) for k, v + in zip(node.keys, node.values))), + NameConstant: (lambda node, convert: node.value), +} + +@_add_converter +def _convert_node_by_type(node, convert, nested): + try: + converter = _LITERAL_NODE_MAP[type(node)] + except KeyError: + raise CannotConvertNode + return converter(node, convert) + +@_add_converter +def _convert_unary_plus_or_minus(node, convert, nested): + if isinstance(node, UnaryOp) and \ + isinstance(node.op, (UAdd, USub)) and \ + isinstance(node.operand, (Num, UnaryOp, BinOp)): + operand = convert(node.operand) + if isinstance(node.op, UAdd): + return + operand + else: + return - operand + raise CannotConvertNode + +@_add_converter +def _convert_binary_plus_or_minus(node, convert, nested): + # This should really only be one level deep to support complex numbers + # but has been released with support for arbitrarily deep nesting of + # addition and subtraction. We retain the support, but don't document + # it, and handle the case where the left or right operand fails to + # convert relatively gracefully + if isinstance(node, BinOp) and \ + isinstance(node.op, (Add, Sub)) and \ + isinstance(node.right, (Num, UnaryOp, BinOp)) and \ + isinstance(node.left, (Num, UnaryOp, BinOp)): + left = convert(node.left) + right = convert(node.right) + if isinstance(node.op, Add): + return left + right + else: + return left - right + raise CannotConvertNode + + +def literal_eval(node_or_string, allow=()): """ Safely evaluate an expression node or a string containing a Python expression. The string or node provided may only consist of the following - Python literal structures: strings, bytes, numbers, tuples, lists, dicts, - sets, booleans, and None. + Python literal structures: strings, bytes, numbers (including complex + numbers), tuples, lists, dicts, sets, booleans, and None. + + Additional AST nodes may be supported by passing *allow*, an optional + iterable of custom converter functions. + + Converter functions accept an AST node, a conversion function (for + recursion) and a flag indicating whether or not this is a recursive + call for a nested AST node and return a converted value. + + Converters raise ast.CannotConvertNode to indicate that no conversion + occurred and the node should be passed to the next registered converter. """ if isinstance(node_or_string, str): - node_or_string = parse(node_or_string, mode='eval') - if isinstance(node_or_string, Expression): - node_or_string = node_or_string.body - def _convert(node): - if isinstance(node, (Str, Bytes)): - return node.s - elif isinstance(node, Num): - return node.n - elif isinstance(node, Tuple): - return tuple(map(_convert, node.elts)) - elif isinstance(node, List): - return list(map(_convert, node.elts)) - elif isinstance(node, Set): - return set(map(_convert, node.elts)) - elif isinstance(node, Dict): - return dict((_convert(k), _convert(v)) for k, v - in zip(node.keys, node.values)) - elif isinstance(node, NameConstant): - return node.value - elif isinstance(node, UnaryOp) and \ - isinstance(node.op, (UAdd, USub)) and \ - isinstance(node.operand, (Num, UnaryOp, BinOp)): - operand = _convert(node.operand) - if isinstance(node.op, UAdd): - return + operand - else: - return - operand - elif isinstance(node, BinOp) and \ - isinstance(node.op, (Add, Sub)) and \ - isinstance(node.right, (Num, UnaryOp, BinOp)) and \ - isinstance(node.left, (Num, UnaryOp, BinOp)): - left = _convert(node.left) - right = _convert(node.right) - if isinstance(node.op, Add): - return left + right - else: - return left - right - raise ValueError('malformed node or string: ' + repr(node)) - return _convert(node_or_string) + # We use exec mode so custom converter functions can process + # additional node types, like import statements, function and class + # headers, etc + node = parse(node_or_string) + assert isinstance(node, Module) + if len(node.body) == 1: + node = node.body[0] + else: + node = node_or_string + # Unwraps the top level Expression node, *not* the Expr stmt node + if isinstance(node, Expression): + node = node.body + # Also unwrap the Expr stmt node + if isinstance(node, Expr): + node = node.value + try: + return convert_node(node, _LITERAL_CONVERTERS + list(allow)) + except ValueError: + raise ValueError('malformed node or string: ' + + repr(node_or_string)) from None + def dump(node, annotate_fields=True, include_attributes=False): diff -r 42bf74b90626 Lib/test/test_ast.py --- a/Lib/test/test_ast.py Wed Mar 20 14:26:33 2013 -0700 +++ b/Lib/test/test_ast.py Wed Mar 20 18:39:12 2013 -0700 @@ -523,16 +523,68 @@ self.assertEqual(ast.get_docstring(node.body[0]), 'line one\nline two') + _literal_eval_examples = [ + ('[1, 2, 3]', [1, 2, 3]), + ('{"foo": 42}', {"foo": 42}), + ('(True, False, None)', (True, False, None)), + ('{1, 2, 3}', {1, 2, 3}), + ('b"hi"', b"hi"), + ('-6', -6), + ('-6j+3', 3-6j), + ('6j--3', 3+6j), + ('(2j+4j)+(1+2)', 3+6j), + ('(2j+4j)-(1+2)', -3+6j), + ('3.25', 3.25), + ] + def test_literal_eval(self): - self.assertEqual(ast.literal_eval('[1, 2, 3]'), [1, 2, 3]) - self.assertEqual(ast.literal_eval('{"foo": 42}'), {"foo": 42}) - self.assertEqual(ast.literal_eval('(True, False, None)'), (True, False, None)) - self.assertEqual(ast.literal_eval('{1, 2, 3}'), {1, 2, 3}) - self.assertEqual(ast.literal_eval('b"hi"'), b"hi") - self.assertRaises(ValueError, ast.literal_eval, 'foo()') - self.assertEqual(ast.literal_eval('-6'), -6) - self.assertEqual(ast.literal_eval('-6j+3'), 3-6j) - self.assertEqual(ast.literal_eval('3.25'), 3.25) + for expr, expected in self._literal_eval_examples: + with self.subTest(expr=expr, expected=expected): + self.assertEqual(ast.literal_eval(expr), expected) + + _literal_eval_errors = [ + '', + 'foo()', + 'foo(),', + '[foo()]', + '{foo()}', + '{1:foo()}', + '{foo():1}', + '1 * 2', + '1 ** 2', + '(1 * 2) + 3', + '1 + (2 * 3)', + '"" + 1', + '1 + ""', + ] + + def test_literal_eval_raises(self): + for expr in self._literal_eval_errors: + with self.subTest(expr=expr): + self.assertRaises(ValueError, ast.literal_eval, expr) + + def test_literal_eval_node(self): + expr = ast.parse('1', mode='eval') + self.assertEqual(ast.literal_eval(expr), 1) + self.assertEqual(ast.literal_eval(expr.body), 1) + + def test_literal_eval_custom(self): + def convert_any(node, convert, nested): + if nested: + raise ast.CannotConvertNode + return node + def eval_any(source): + return ast.literal_eval(source, [convert_any]) + self.assertIsInstance(eval_any('class A: pass'), ast.ClassDef) + self.assertIsInstance(eval_any('def f(): pass'), ast.FunctionDef) + for expr in self._literal_eval_errors: + with self.subTest(expr=expr): + node = ast.parse(expr) + if len(node.body) == 1: + node = node.body[0].value + self.assertIs(eval_any(node), node) + + #TODO: Explicit tests for convert_node def test_literal_eval_issue4907(self): self.assertEqual(ast.literal_eval('2j'), 2j)