Index: Python/optimize.c =================================================================== --- Python/optimize.c (revision 0) +++ Python/optimize.c (revision 0) @@ -0,0 +1,601 @@ +#include "Python.h" + +#include "Python-ast.h" +#include "node.h" +#include "opcode.h" +#include "optimize.h" + +typedef expr_ty (*expr_visitor)(void *, expr_ty); +typedef stmt_ty (*stmt_visitor)(void *, stmt_ty); +struct visitor_context +{ + void *context; + expr_visitor preorder_expr_visitor; + expr_visitor postorder_expr_visitor; + stmt_visitor preorder_stmt_visitor; + stmt_visitor postorder_stmt_visitor; +}; + + +static stmt_ty visit_stmt(struct visitor_context *c, stmt_ty s); +static expr_ty visit_expr(struct visitor_context *c, expr_ty e); +static keyword_ty visit_keyword(struct visitor_context *c, keyword_ty k); +static slice_ty visit_slice(struct visitor_context *c, slice_ty s, expr_context_ty ctx); +static comprehension_ty visit_comprehension(struct visitor_context *c, comprehension_ty k); +static excepthandler_ty visit_excepthandler(struct visitor_context *c, excepthandler_ty k); + + +/* GEN_VISIT and GEN_VISIT_SEQ takes an ASDL type as their second argument. They use + the ASDL name to synthesize the name of the C type and the visit function. +*/ + +#define GEN_VISIT(C, TYPE, V) { \ + TYPE ## _ty __tmp; \ + if (!(__tmp = visit_ ## TYPE((C), (V)))) \ + return 0; \ + (V) = __tmp; \ +} + +#define GEN_VISIT_SLICE(C, V, CTX) { \ + slice_ty __tmp; \ + if (!(__tmp = visit_slice((C), (V), (CTX)))) \ + return 0; \ + (V) = __tmp; \ +} + +#define GEN_VISIT_SEQ(C, TYPE, SEQ) { \ + int i; \ + asdl_seq *seq = (SEQ); /* avoid variable capture */ \ + for (i = 0; i < asdl_seq_LEN(seq); i++) { \ + TYPE ## _ty __tmp, elt = asdl_seq_GET(seq, i); \ + if (!(__tmp=visit_ ## TYPE((C), elt))) \ + return 0; \ + asdl_seq_SET(seq, i, __tmp); \ + } \ +} + +static stmt_ty +visit_stmt(struct visitor_context *c, stmt_ty s) +{ + if(c->preorder_stmt_visitor) { + if(!(s = c->preorder_stmt_visitor(c->context, s))) + return 0; + } + + switch (s->kind) { + case FunctionDef_kind: + GEN_VISIT_SEQ(c, stmt, s->v.FunctionDef.body); + break; + case ClassDef_kind: + GEN_VISIT_SEQ(c, stmt, s->v.ClassDef.body); + break; + case Return_kind: + if(s->v.Return.value) + GEN_VISIT(c, expr, s->v.Return.value); + break; + case Delete_kind: + GEN_VISIT_SEQ(c, expr, s->v.Delete.targets); + break; + case Assign_kind: + GEN_VISIT_SEQ(c, expr, s->v.Assign.targets); + GEN_VISIT(c, expr, s->v.Assign.value); + break; + case AugAssign_kind: + GEN_VISIT(c, expr, s->v.AugAssign.target); + GEN_VISIT(c, expr, s->v.AugAssign.value); + break; + case For_kind: + GEN_VISIT(c, expr, s->v.For.target); + GEN_VISIT(c, expr, s->v.For.iter); + GEN_VISIT_SEQ(c, stmt, s->v.For.body); + GEN_VISIT_SEQ(c, stmt, s->v.For.orelse); + break; + case While_kind: + GEN_VISIT(c, expr, s->v.While.test); + GEN_VISIT_SEQ(c, stmt, s->v.While.body); + GEN_VISIT_SEQ(c, stmt, s->v.While.orelse); + break; + case If_kind: + GEN_VISIT(c, expr, s->v.If.test); + GEN_VISIT_SEQ(c, stmt, s->v.If.body); + GEN_VISIT_SEQ(c, stmt, s->v.If.orelse); + break; + case With_kind: + GEN_VISIT(c, expr, s->v.With.context_expr); + if (s->v.With.optional_vars) + GEN_VISIT(c, expr, s->v.With.optional_vars); + GEN_VISIT_SEQ(c, stmt, s->v.With.body); + break; + case Raise_kind: + if (s->v.Raise.exc) + GEN_VISIT(c, expr, s->v.Raise.exc); + if (s->v.Raise.cause) + GEN_VISIT(c, expr, s->v.Raise.cause); + break; + case TryExcept_kind: + GEN_VISIT_SEQ(c, stmt, s->v.TryExcept.body); + GEN_VISIT_SEQ(c, excepthandler, s->v.TryExcept.handlers); + GEN_VISIT_SEQ(c, stmt, s->v.TryExcept.orelse); + break; + case TryFinally_kind: + GEN_VISIT_SEQ(c, stmt, s->v.TryFinally.body); + GEN_VISIT_SEQ(c, stmt, s->v.TryFinally.finalbody); + break; + case Assert_kind: + GEN_VISIT(c, expr, s->v.Assert.test); + if(s->v.Assert.msg) + GEN_VISIT(c, expr, s->v.Assert.msg); + break; + case Import_kind: + break; + case ImportFrom_kind: + break; + case Global_kind: + break; + case Nonlocal_kind: + break; + case Expr_kind: + GEN_VISIT(c, expr, s->v.Expr.value); + break; + case Pass_kind: + break; + case Break_kind: + break; + case Continue_kind: + break; + default: + fprintf(stderr, "Unknown statement kind %d\n", s->kind); + assert(0); + return 0; + } + if(c->postorder_stmt_visitor) { + if(!(s = c->postorder_stmt_visitor(c->context, s))) + return 0; + } + + return s; +} + + +static expr_ty +visit_expr(struct visitor_context *c, expr_ty e) +{ + if(c->preorder_expr_visitor) { + if(!(e = c->preorder_expr_visitor(c->context, e))) + return 0; + } + + switch (e->kind) { + case BoolOp_kind: + GEN_VISIT_SEQ(c, expr, e->v.BoolOp.values); + break; + case BinOp_kind: + GEN_VISIT(c, expr, e->v.BinOp.left); + GEN_VISIT(c, expr, e->v.BinOp.right); + break; + case UnaryOp_kind: + GEN_VISIT(c, expr, e->v.UnaryOp.operand); + break; + case Lambda_kind: + GEN_VISIT(c, expr, e->v.Lambda.body); + break; + case IfExp_kind: + GEN_VISIT(c, expr, e->v.IfExp.test); + GEN_VISIT(c, expr, e->v.IfExp.body); + GEN_VISIT(c, expr, e->v.IfExp.orelse); + break; + case Dict_kind: + GEN_VISIT_SEQ(c, expr, e->v.Dict.keys); + GEN_VISIT_SEQ(c, expr, e->v.Dict.values); + break; + case Set_kind: + GEN_VISIT_SEQ(c, expr, e->v.Set.elts); + break; + case ListComp_kind: + GEN_VISIT(c, expr, e->v.ListComp.elt); + GEN_VISIT_SEQ(c, comprehension, e->v.ListComp.generators); + break; + case SetComp_kind: + GEN_VISIT(c, expr, e->v.SetComp.elt); + GEN_VISIT_SEQ(c, comprehension, e->v.SetComp.generators); + break; + case DictComp_kind: + GEN_VISIT(c, expr, e->v.DictComp.key); + GEN_VISIT(c, expr, e->v.DictComp.value); + GEN_VISIT_SEQ(c, comprehension, e->v.DictComp.generators); + break; + case GeneratorExp_kind: + GEN_VISIT(c, expr, e->v.GeneratorExp.elt); + GEN_VISIT_SEQ(c, comprehension, e->v.GeneratorExp.generators); + break; + case Yield_kind: + if (e->v.Yield.value) + GEN_VISIT(c, expr, e->v.Yield.value); + break; + case Compare_kind: + GEN_VISIT(c, expr, e->v.Compare.left); + GEN_VISIT_SEQ(c, expr, e->v.Compare.comparators); + break; + case Call_kind: + GEN_VISIT(c, expr, e->v.Call.func); + GEN_VISIT_SEQ(c, expr, e->v.Call.args); + if(e->v.Call.keywords) + GEN_VISIT_SEQ(c, keyword, e->v.Call.keywords); + if (e->v.Call.starargs) + GEN_VISIT(c, expr, e->v.Call.starargs); + if (e->v.Call.kwargs) + GEN_VISIT(c, expr, e->v.Call.kwargs); + break; + case Num_kind: + break; + case Str_kind: + break; + case Bytes_kind: + break; + case Ellipsis_kind: + break; + case Attribute_kind: + if (e->v.Attribute.ctx != AugStore) + GEN_VISIT(c, expr, e->v.Attribute.value); + break; + case Subscript_kind: + switch (e->v.Subscript.ctx) { + case AugLoad: + GEN_VISIT(c, expr, e->v.Subscript.value); + GEN_VISIT_SLICE(c, e->v.Subscript.slice, AugLoad); + break; + case Load: + GEN_VISIT(c, expr, e->v.Subscript.value); + GEN_VISIT_SLICE(c, e->v.Subscript.slice, Load); + break; + case AugStore: + GEN_VISIT_SLICE(c, e->v.Subscript.slice, AugStore); + break; + case Store: + GEN_VISIT(c, expr, e->v.Subscript.value); + GEN_VISIT_SLICE(c, e->v.Subscript.slice, Store); + break; + case Del: + GEN_VISIT(c, expr, e->v.Subscript.value); + GEN_VISIT_SLICE(c, e->v.Subscript.slice, Del); + break; + case Param: + assert(0); + break; + } + break; + case Starred_kind: + GEN_VISIT(c, expr, e->v.Starred.value); + break; + case Name_kind: + break; + case List_kind: + GEN_VISIT_SEQ(c, expr, e->v.List.elts); + break; + case Tuple_kind: + GEN_VISIT_SEQ(c, expr, e->v.Tuple.elts); + break; + default: + fprintf(stderr, "Unknown expression kind %d\n", e->kind); + assert(0); + return 0; + } + + if(c->postorder_expr_visitor) { + if(!(e = c->postorder_expr_visitor(c->context, e))) + return 0; + } + return e; +} + + + +static keyword_ty +visit_keyword(struct visitor_context *c, keyword_ty k) +{ + GEN_VISIT(c, expr, k->value); + return k; +} + +static comprehension_ty +visit_comprehension(struct visitor_context *c, comprehension_ty k) +{ + GEN_VISIT(c, expr, k->target); + GEN_VISIT(c, expr, k->iter); + GEN_VISIT_SEQ(c, expr, k->ifs); + return k; +} + + +static excepthandler_ty +visit_excepthandler(struct visitor_context *c, excepthandler_ty h) +{ + if(h->v.ExceptHandler.type) + GEN_VISIT(c, expr, h->v.ExceptHandler.type); + GEN_VISIT_SEQ(c, stmt, h->v.ExceptHandler.body); + return h; +} + + +static slice_ty +visit_nested_slice(struct visitor_context *c, slice_ty s, expr_context_ty ctx) +{ + switch (s->kind) { + case Slice_kind: + if(s->v.Slice.lower) + GEN_VISIT(c, expr, s->v.Slice.lower); + if(s->v.Slice.upper) + GEN_VISIT(c, expr, s->v.Slice.upper); + if(s->v.Slice.step) + GEN_VISIT(c, expr, s->v.Slice.step); + break; + case Index_kind: + GEN_VISIT(c, expr, s->v.Index.value); + break; + case ExtSlice_kind: + break; + } + return s; +} + + +static slice_ty +visit_slice(struct visitor_context *c, slice_ty s, expr_context_ty ctx) +{ + switch (s->kind) { + case Slice_kind: + if (s->v.Slice.lower) { + GEN_VISIT(c, expr, s->v.Slice.lower); + } + if (s->v.Slice.upper) { + GEN_VISIT(c, expr, s->v.Slice.upper); + } + if (s->v.Slice.step) { + GEN_VISIT(c, expr, s->v.Slice.step); + } + break; + case ExtSlice_kind: { + int i, n = asdl_seq_LEN(s->v.ExtSlice.dims); + for (i = 0; i < n; i++) { + slice_ty sub = asdl_seq_GET(s->v.ExtSlice.dims, i); + if (!(asdl_seq_SET(s->v.ExtSlice.dims, i, + visit_nested_slice(c, sub, ctx)))) + return 0; + } + break; + } + case Index_kind: + GEN_VISIT(c, expr, s->v.Index.value); + break; + } + return s; +} + + +static mod_ty +visit_mod(struct visitor_context *c, mod_ty mod) +{ + switch (mod->kind) { + case Module_kind: + GEN_VISIT_SEQ(c, stmt, mod->v.Module.body); + break; + case Interactive_kind: + GEN_VISIT_SEQ(c, stmt, mod->v.Interactive.body); + break; + case Expression_kind: + GEN_VISIT(c, expr, mod->v.Expression.body); + break; + case Suite_kind: + GEN_VISIT_SEQ(c, stmt, mod->v.Suite.body); + break; + default: + assert(0); + } + return mod; +} + + + +static PyObject *get_constant_object(expr_ty e) +{ + switch (e->kind) { + case Num_kind: + return e->v.Num.n; + case Str_kind: + return e->v.Str.s; + default: + return 0; + } +} + +static expr_ty create_ast_from_constant_object(PyObject *obj, int lineno, int col_offset, PyArena *arena) +{ + if (PyNumber_Check(obj)) + return Num(obj, lineno, col_offset, arena); + if (PyBytes_Check(obj) || PyUnicode_Check(obj)) + return Str(obj, lineno, col_offset, arena); + return 0; +} + +static expr_ty constant_fold_expr_visitor(void *context, expr_ty e) +{ + PyObject *v, *w, *newconst = 0; + expr_ty newast = 0; + int i, n; + int err; + int lineno = e->lineno; + int col_offset = e->col_offset; + PyArena *arena = (PyArena*)context; + switch(e->kind) { + case BinOp_kind: + if ((v = get_constant_object(e->v.BinOp.left)) && + (w = get_constant_object(e->v.BinOp.right))) { + switch(e->v.BinOp.op) { + case Add: + newconst = PyNumber_Add(v, w); + break; + case Sub: + newconst = PyNumber_Subtract(v, w); + break; + case Mult: + newconst = PyNumber_Multiply(v, w); + break; + case Div: + /* don't touch, since division behaviour can be changed + during run-time */ + break; + case Mod: + newconst = PyNumber_Remainder(v, w); + break; + case Pow: + newconst = PyNumber_Power(v, w, Py_None); + break; + case LShift: + newconst = PyNumber_Lshift(v, w); + break; + case RShift: + newconst = PyNumber_Rshift(v, w); + break; + case BitOr: + newconst = PyNumber_Or(v, w); + break; + case BitXor: + newconst = PyNumber_Xor(v, w); + break; + case BitAnd: + newconst = PyNumber_And(v, w); + break; + case FloorDiv: + newconst = PyNumber_FloorDivide(v, w); + break; + default: + break; + } + + if (newconst == NULL) { + /* + an exception may have been raised during the + calculation, but we want the exception to occur during + run-time instead. Therefore, clear the error and leave + the node untouched so the error will reoccur at + run-time. + */ + PyErr_Clear(); + break; + } + + if(newconst) { + int size = PyObject_Size(newconst); + if (size == -1) + PyErr_Clear(); + else if (size > 20) { + Py_DECREF(newconst); + break; + } + newast = create_ast_from_constant_object(newconst, lineno, col_offset, arena); + if(newast) { + return newast; + } + else { + Py_DECREF(newconst); + } + } + } + break; + + + case UnaryOp_kind: + if((v = get_constant_object(e->v.UnaryOp.operand))) { + switch(e->v.UnaryOp.op) { + case UAdd: + newconst = PyNumber_Positive(v); + break; + case USub: + /* Preserve the sign of -0.0 */ + if (PyObject_IsTrue(v) == 1) + newconst = PyNumber_Negative(v); + break; + case Invert: + newconst = PyNumber_Invert(v); + break; + default: + /* TODO: do something about Not as well? */ + break; + } + + if (newconst == NULL) { + PyErr_Clear(); + break; + } + + if (newconst) { + return Num(newconst, lineno, col_offset, arena); + } + } + break; + + case BoolOp_kind: + if(asdl_seq_LEN(e->v.BoolOp.values)) { + n = asdl_seq_LEN(e->v.BoolOp.values); + for(i = 0; i < n-1; ++i) { + v = get_constant_object(asdl_seq_GET(e->v.BoolOp.values, i)); + if(!v) goto partial_result; + err = PyObject_IsTrue(v); + if(err == -1) goto partial_result; + if((err != 0) ^ (e->v.BoolOp.op == And)) { + goto shortcircuit; + } + } + v = get_constant_object(asdl_seq_GET(e->v.BoolOp.values, n-1)); + if(!v) goto partial_result; + + shortcircuit: + newast = create_ast_from_constant_object(v, lineno, col_offset, arena); + if(newast) { + Py_INCREF(v); + return newast; + } + + partial_result: + if(i == 0) + return e; + if(i == n-1) { + expr_ty tmp = asdl_seq_GET(e->v.BoolOp.values, i); + asdl_seq_SET(e->v.BoolOp.values, i, 0); + return tmp; + } + else { + /* partial result: the values 0..i-1 are constant and + can be skipped, the rest must be evaluated */ + asdl_seq *new_seq = asdl_seq_new(n-i, arena); + int j; + for(j = i; j < n; ++j) { + asdl_seq_SET(new_seq, j-i, asdl_seq_GET(e->v.BoolOp.values, j)); + } + e->v.BoolOp.values = new_seq; + } + return e; + } + break; + + case Compare_kind: + /* + This operation could be constant folded, however, it is hard to + generate new references to True and False after the symbol table + pass has been performed. Postponed for now. + */ + break; + + default: + break; + } + return e; +} + +mod_ty ast_optimize(mod_ty m, PyArena *arena) +{ + struct visitor_context constant_fold_context = {0}; + constant_fold_context.context = arena; + constant_fold_context.postorder_expr_visitor = constant_fold_expr_visitor; + return visit_mod(&constant_fold_context, m); +} Index: Python/compile.c =================================================================== --- Python/compile.c (revision 86219) +++ Python/compile.c (working copy) @@ -31,6 +31,7 @@ #include "compile.h" #include "symtable.h" #include "opcode.h" +#include "optimize.h" int Py_OptimizeFlag = 0; @@ -1169,11 +1170,17 @@ PyCodeObject *co; int addNone = 1; static PyObject *module; + mod_ty tmp; if (!module) { module = PyUnicode_InternFromString(""); if (!module) return NULL; } + + tmp = ast_optimize(mod, c->c_arena); + if(tmp) + mod = tmp; + /* Use 0 for firstlineno initially, will fixup in assemble(). */ if (!compiler_enter_scope(c, module, mod, 0)) return NULL; Index: Include/optimize.h =================================================================== --- Include/optimize.h (revision 0) +++ Include/optimize.h (revision 0) @@ -0,0 +1,16 @@ + +#ifndef Py_OPTIMIZE_H +#define Py_OPTIMIZE_H +#ifdef __cplusplus +extern "C" { +#endif + +struct _mod; /* Declare the existence of this type */ +struct _mod *ast_optimize(struct _mod *, PyArena *); + + + +#ifdef __cplusplus +} +#endif +#endif /* !Py_OPTIMIZE_H */ Index: Lib/test/test_optimize.py =================================================================== --- Lib/test/test_optimize.py (revision 0) +++ Lib/test/test_optimize.py (revision 0) @@ -0,0 +1,101 @@ +import dis +import re +import sys +from io import StringIO +import unittest + +def disassemble(func): + f = StringIO() + tmp = sys.stdout + sys.stdout = f + dis.dis(func) + sys.stdout = tmp + result = f.getvalue() + f.close() + return result + +def dis_single(line): + return disassemble(compile(line, '', 'single')) + +class TestConstantFolding(unittest.TestCase): + def test_elim_build_tuple(self): + # Tuples containing constant expressions should get folded into + # constant tuples: + asm = dis_single('(1 * 1, 2 * 2, 3 * 3, 4 * 4)') + self.assertNotIn('BUILD_TUPLE', asm) + self.assertIn('((1, 4, 9, 16))', asm) + + def test_elim_binary_op(self): + asm = dis_single('3 + 4 + x') + self.assertIn('(7)', asm) + + def test_elim_binary_power(self): + asm = dis_single('2 ** 2 ** 2') + self.assertNotIn('BINARY_POWER', asm) + self.assertIn('(16)', asm) + + def test_elim_boolean_op_1(self): + asm = dis_single('4 and 5 and x and 6') + self.assertNotIn('(4)', asm) + self.assertNotIn('(5)', asm) + self.assertIn('(6)', asm) + + def test_elim_boolean_op_2(self): + asm = dis_single('4 or 5 or x') + self.assertIn('(4)', asm) + self.assertNotIn('(5)', asm) + self.assertNotIn('(x)', asm) + + def test_elim_boolean_op_3(self): + asm = dis_single('4 and 5 and ~6') + self.assertNotIn('(4)', asm) + self.assertNotIn('(5)', asm) + self.assertNotIn('(6)', asm) + self.assertIn('(-7)', asm) + + def test_elim_dead_code(self): + # Elimination of code guarded by expression that can be folded to + # constant 0: + def f(): + if 2 ** 2 ** 2 - 16: + expensive_computation() + asm = disassemble(f) + self.assertNotIn('CALL_FUNCTION', asm) + self.assertNotIn('expensive_computation', asm) + + def test_large_constants(self): + # Ensure that the optimizers limits itself and doesn't generate large + # constants: + asm = dis_single("'-' * 100") + self.assertIn("('-')", asm) + self.assertIn('(100)', asm) + self.assertIn('BINARY_MULTIPLY', asm) + + asm = dis_single('(None,)*2000') + self.assertIn('(None,)', asm) + self.assertIn('(2000)', asm) + self.assertIn('BINARY_MULTIPLY', asm) + + def test_zero(self): + # -0.0 != 0.0 + asm = dis_single('-0.0') + self.assertIn('UNARY_NEGATIVE', asm) + +def test_main(verbose=None): + import sys + from test import support + test_classes = (TestConstantFolding,) + support.run_unittest(*test_classes) + + # verify reference counting + if verbose and hasattr(sys, "gettotalrefcount"): + import gc + counts = [None] * 5 + for i in range(len(counts)): + support.run_unittest(*test_classes) + gc.collect() + counts[i] = sys.gettotalrefcount() + print(counts) + +if __name__ == "__main__": + test_main(verbose=True) Index: Makefile.pre.in =================================================================== --- Makefile.pre.in (revision 86219) +++ Makefile.pre.in (working copy) @@ -310,6 +310,7 @@ Python/modsupport.o \ Python/mystrtoul.o \ Python/mysnprintf.o \ + Python/optimize.o \ Python/peephole.o \ Python/pyarena.o \ Python/pyctype.o \