import sys
import ast
from pprint import pprint

#print("__optimizer__.py imported!")

# Taken from symtable.h:
DEF_GLOBAL = 1           # global stmt
DEF_LOCAL = 2            # assignment in code block
DEF_PARAM = 2<<1         # formal parameter
DEF_NONLOCAL = 2<<2      # nonlocal stmt
USE = 2<<3               # name is used
DEF_FREE = 2<<4          # name used but not defined in nested block
DEF_FREE_CLASS = 2<<5    # free variable from class's method
DEF_IMPORT = 2<<6        # assignment occurred via import

DEF_BOUND = (DEF_LOCAL | DEF_PARAM | DEF_IMPORT)

# GLOBAL_EXPLICIT and GLOBAL_IMPLICIT are used internally by the symbol
# table.  GLOBAL is returned from PyST_GetScope() for either of them.
# It is stored in ste_symbols at bits 12-15.
SCOPE_OFFSET = 11
SCOPE_MASK = (DEF_GLOBAL | DEF_LOCAL | DEF_PARAM | DEF_NONLOCAL)

LOCAL = 1
GLOBAL_EXPLICIT = 2
GLOBAL_IMPLICIT = 3
FREE = 4
CELL = 5

# Suppress various local-manipulation transforms in the presence of usage
# of these builtins.  "eval" and "exec" are included because, when called
# without explicit namespaces, they run in the caller's local namespace:
BUILTINS_THAT_READ_LOCALS = {'locals', 'vars', 'dir', 'eval', 'exec'}


def log(msg):
    """Module-level debug logging; disabled unless the ``if 0`` is edited."""
    if 0:
        print(msg)


# Analogous to PyST_GetScope:
def get_scope(ste, name):
    """Return the scope code (LOCAL, GLOBAL_EXPLICIT, ...) of *name* in *ste*.

    *ste* must expose a ``symbols`` mapping from names to flag words.
    Returns 0 when the name is not present (or its flag word is 0).
    """
    return (ste.symbols.get(name, 0) >> SCOPE_OFFSET) & SCOPE_MASK


def add_local(ste, name):
    """Add a new local variable called *name* to the symbol table entry *ste*.

    Mutates ``ste.symbols`` and ``ste.varnames``.  The name must not already
    be known to the entry.
    """
    assert name not in ste.varnames
    assert name not in ste.symbols
    ste.symbols[name] = (LOCAL << SCOPE_OFFSET)
    ste.varnames.append(name)


def to_dot(t):
    """Render the AST (or list / leaf value) *t* as Graphviz "dot" source.

    Every AST node, list and leaf value becomes a graph node keyed by its
    id(); nodes carrying a ``ste`` attribute are wrapped in a labelled
    cluster.  Returns the complete ``digraph`` as a string.
    """
    from html import escape

    def _node_to_dot(node, indent, out):
        # Append the dot statements for "node" (and its children) to "out"
        prefix = ' ' * indent
        if isinstance(node, ast.AST):
            has_ste = hasattr(node, 'ste')
            if has_ste:
                out.append(prefix + 'subgraph cluster_%s {\n' % id(node))
                # The terminating ";" belongs before the newline:
                out.append(prefix + ' label = "%s";\n' % node.ste.name)
            out.append(prefix + ' node%i [label=<%s>];\n'
                       % (id(node), node.__class__.__name__))
            for name, field in ast.iter_fields(node):
                if field is not None:
                    out.append(prefix + ' node%i -> node%i [label="%s"];\n'
                               % (id(node), id(field), name))
                    _node_to_dot(field, indent + 2, out)
            if has_ste:
                out.append(prefix + '}\n')
        elif isinstance(node, list):
            out.append(prefix + 'node%i [label=<[]>];\n' % (id(node)))
            for i, item in enumerate(node):
                out.append(prefix + 'node%i -> node%i [label="[%i]"];\n'
                           % (id(node), id(item), i))
                _node_to_dot(item, indent, out)
        elif node is not None:
            # <...> is an HTML-like label, so "<", ">" and "&" within the
            # repr() of a leaf value have to be escaped:
            out.append(prefix + 'node%i [label=<%s>];\n'
                       % (id(node), escape(repr(node), quote=False)))

    parts = ['digraph {\n']
    _node_to_dot(t, 1, parts)
    parts.append('}')
    return ''.join(parts)


def dot_to_png(dot, filename):
    """Pipe the dot source *dot* through /usr/bin/dot, writing a PNG to *filename*."""
    from subprocess import Popen, PIPE
    p = Popen(['/usr/bin/dot', '-T', 'png', '-o', filename],
              stdin=PIPE)
    p.communicate(dot.encode('utf-8'))
    p.wait()


class NodePathEntry:
    """One step of a NodePath: a node plus the field (and index) descended into."""
    __slots__ = ('node',   # the ast.Node
                 'field',  # the name of the field
                 'index',  # the index within the field (for lists), or None
                 )

    def __init__(self, node, field, index):
        self.node = node
        self.field = field
        self.index = index

    def __str__(self):
        result = self.node.__class__.__name__
        if hasattr(self.node, 'ste'):
            result += '<%s>' % repr(self.node.ste.name)
        result += '.%s' % self.field
        if self.index is not None:
            result += '[%i]' % self.index
        return result


class NodePath:
    ''' A list of NodePathEntries '''
    __slots__ = ('entries', )

    def __init__(self, entries):
        self.entries = entries

    def __str__(self):
        return '/'.join([str(entry) for entry in self.entries])

    def __repr__(self):
        return '/'.join([str(entry) for entry in self.entries])

    def extend(self, node, field, index):
        """Return a new NodePath with one more entry; self is not modified."""
        return NodePath(self.entries + [NodePathEntry(node, field, index)])

    def get_dotted_name(self, childnode=None):
        """Return the dotted namespace name of this path (plus *childnode*)."""
        nsp = NamespacePath.from_node_path(self, childnode)
        return nsp.as_dotted_str()


class NamespacePath:
    ''' A list of symbol table entries '''
    __slots__ = ('_stes',)

    def __init__(self, stes):
        self._stes = stes

    @classmethod
    def from_node_path(cls, nodepath, childnode=None):
        """Collect the ``ste`` of each node along *nodepath* that has one.

        If *childnode* is given and has a ``ste``, it is appended last.
        """
        result = [npe.node.ste
                  for npe in nodepath.entries
                  if hasattr(npe.node, 'ste')]
        if childnode is not None and hasattr(childnode, 'ste'):
            result.append(childnode.ste)
        return cls(result)

    def as_dotted_str(self):
        '''
        Generate a dotted string representing the namespace
        e.g. "SomeClass.some_method"
        '''
        # Start at 1: don't include the "top" STE:
        return '.'.join([ste.name for ste in self._stes[1:]])

    def get_parent_path(self):
        """Return the path with the innermost entry dropped."""
        return NamespacePath(self._stes[:-1])

    def get_innermost_scope(self):
        """Return the last (innermost) symbol table entry; IndexError if empty."""
        return self._stes[-1]


class PathTransformer:
    """
    Similar to an ast.NodeTransformer, but passes in a path when visiting a node

    The path is passed in as a list of (node, field, index) triples
    """
    def visit(self, node, path=None):
        """Visit a node."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        if path is None:
            path = NodePath([])
        return visitor(node, path)

    def generic_visit(self, node, path):
        """Visit every child of *node*, splicing in whatever the visitors return.

        For list fields a visitor may return None (drop the item), a single
        node, or an iterable of nodes (inserted in place).  For single-node
        fields, None deletes the attribute.
        """
        for field, old_value in ast.iter_fields(node):
            old_value = getattr(node, field, None)
            if isinstance(old_value, list):
                new_values = []
                for idx, value in enumerate(old_value):
                    if isinstance(value, ast.AST):
                        value = self.visit(value,
                                           path.extend(node, field, idx))
                        if value is None:
                            continue
                        elif not isinstance(value, ast.AST):
                            new_values.extend(value)
                            continue
                    new_values.append(value)
                old_value[:] = new_values
            elif isinstance(old_value, ast.AST):
                new_node = self.visit(old_value,
                                      path.extend(node, field, None))
                if new_node is None:
                    delattr(node, field)
                else:
                    setattr(node, field, new_node)
        return node


def make_load_name(name, locnode):
    """Return a Name node loading *name*, located where *locnode* is."""
    return ast.copy_location(ast.Name(id=name, ctx=ast.Load()), locnode)


def make_assignment(name, expr, locnode):
    """Return an Assign node for ``name = expr``, located where *locnode* is."""
    name_node = ast.copy_location(ast.Name(id=name, ctx=ast.Store()), locnode)
    return ast.copy_location(ast.Assign(targets=[name_node], value=expr),
                             locnode)


def ast_clone(node):
    """Return a deep copy of the AST node *node*.

    Only the declared fields and the location attributes are copied; extra
    attributes attached to nodes (such as ``ste``) are not.  Immutable leaf
    values (str, bytes, numbers, bool, None) are returned unchanged, so that
    list fields containing them (e.g. ``Global.names``, or the ``None`` key
    of a ``{**x}`` Dict) can be cloned.

    Raises ValueError for anything else.
    """
    #log('ast_clone', node)
    if isinstance(node, ast.AST):
        fields = {}
        for name, value in ast.iter_fields(node):
            if isinstance(value, ast.AST):
                fields[name] = ast_clone(value)
            elif isinstance(value, list):
                fields[name] = [ast_clone(el) for el in value]
            else:
                fields[name] = value
        return ast.copy_location(node.__class__(**fields), node)
    elif node is None or isinstance(node, (str, bytes, int, float, complex)):
        return node
    else:
        raise ValueError("Don't know how to clone %r" % node)


class InlineBodyFixups(ast.NodeTransformer):
    """
    Fix up the cloned body of a function, for inlining
    """
    def __init__(self, varprefix, ste):
        self.varprefix = varprefix  # prefix given to the callee's locals
        self.ste = ste              # symbol table entry of the callee

    def visit_Name(self, node):
        # Replace local names with prefixed versions:
        self.generic_visit(node)
        scope = get_scope(self.ste, node.id)
        if scope == LOCAL:
            node.id = self.varprefix + node.id
        return node

    def visit_Return(self, node):
        self.generic_visit(node)
        # replace (the final) return with "__returnval__ = expr":
        value = node.value
        if value is None:
            # A bare "return" yields None; an Assign needs an actual value
            # node, so load "None" the same way the call site does:
            value = make_load_name('None', node)
        return make_assignment(self.varprefix + "__returnval__", value, node)


class FunctionInliner(PathTransformer):
    """Replace calls to known-inlinable functions with ast.Specialize nodes."""
    def __init__(self, tree, def_dict):
        self.tree = tree
        self.def_dict = def_dict # dict from dottedname to ast.FunctionDef
        #self.funcdef = funcdef
        #self.dotted_name = dotted_name
        #assert hasattr(funcdef, 'ste')
        self.num_callsites = 0
        self.log('inlining calls to %r' % def_dict)
        #self.log('ste for body: %r' % funcdef.ste)

    def log(self, msg):
        if 0:
            print('%s: %s' % (self.__class__.__name__, msg))

    def guess_dotted_name_for_def_from_call(self, call, path):
        # Return the name of the stored global if this is inlinable, otherwise None
        assert isinstance(call, ast.Call)
        if isinstance(call.func, ast.Name):
            # Name must match:
            if call.func.id in self.def_dict:
                return call.func.id

        if isinstance(call.func, ast.Attribute):
            # Handle simple "self.METHOD_NAME" case:
            attr = call.func
            value = attr.value
            if isinstance(value, ast.Name) and isinstance(value.ctx, ast.Load):
                if value.id == 'self':
                    #print('attr.attr:', attr.attr)
                    #print('path:', path)
                    #print(path.get_dotted_name())
                    parent_nsp = NamespacePath.from_node_path(path).get_parent_path()
                    return parent_nsp.as_dotted_str() + '.' + attr.attr
                    # FIXME: only makes sense to traverse within this class and within subclasses
                    # FIXME: fake the MRO and pick an appropriate class

        # Don't try to inline where the function is a non-trivial
        # expression e.g. "f()()", or for other awkward cases
        return None

    def _call_matches_signature(self, funcdef, call, instance):
        # Can the actual arguments be paired off one-to-one with the formal
        # positional parameters?  _setup_args pairs them with zip(), which
        # would otherwise silently drop whatever does not line up.
        if call.keywords:
            return False
        # "starargs" and "kwargs" only exist on older forms of ast.Call:
        if getattr(call, 'starargs', None) is not None:
            return False
        if getattr(call, 'kwargs', None) is not None:
            return False
        if any(isinstance(arg, ast.Starred) for arg in call.args):
            return False
        num_actual = len(call.args)
        if instance is not None:
            num_actual += 1  # the synthesized "self"
        return num_actual == len(funcdef.args.args)

    def _setup_args(self, ste, varprefix, funcdef, call, instance):
        # Create assignment statements of the form:
        #    __inline__x = expr for x
        # for each parameter
        # We will insert before the callsite
        assignments = []
        formalparams = funcdef.args.args
        actualparams = call.args
        if instance:
            # Synthesize a "self" at the front of the params:
            # Note that this must be just a load of a local name, to avoid,
            # say, injecting a 2nd getattr:
            assert is_read_from_local(ste, instance)
            actualparams = [instance] + actualparams
        for formal, actual in zip(formalparams, actualparams):
            self.log(' formal: %s' % ast.dump(formal))
            self.log(' actual: %s' % ast.dump(actual))
            add_local(ste, varprefix+formal.arg)
            assign = make_assignment(varprefix+formal.arg, actual, call)
            assignments.append(assign)
            # FIXME: these seem to be being done using LOAD_NAME/STORE_NAME; why isn't it using _FAST?
            # aha: because they're in module scope, not within a function.
        return assignments

    def visit_Call(self, call, path):
        """Return an ast.Specialize wrapping *call* if it can be inlined.

        Otherwise *call* itself is returned (with its children visited).
        """
        # Stop inlining beyond an arbitrary cutoff
        # (bm_simplecall was exploding):
        if self.num_callsites > 1000:
            return call

        # Visit children:
        self.generic_visit(call, path)

        dotted_name = self.guess_dotted_name_for_def_from_call(call, path)
        if dotted_name is None:
            return call

        self.log('Got inlinable callsite of:\n dotted_name: %r\n path: %s\n node:%s'
                 % (dotted_name, path, ast.dump(call)))

        if not isinstance(call.func, (ast.Name, ast.Attribute)):
            # Don't try to inline where the function is a non-trivial
            # expression e.g. "f()()"
            return call

        if isinstance(call.func, ast.Attribute):
            # Emergency cutoff for method inlining:
            if 0:
                self.log('Not inlining attribute %s' % ast.dump(call.func))
                return call

        self.log('Considering call to: %s' % ast.dump(call.func))
        self.log('NodePath: %r' % path)
        nsp = NamespacePath.from_node_path(path)
        self.log('NamespacePath for callsite: %r' % nsp)

        if dotted_name not in self.def_dict:
            return call
        #if call.func.id != self.funcdef.name:
        #    return call

        funcdef = self.def_dict[dotted_name]

        if isinstance(call.func, ast.Attribute):
            attr = call.func
            value = attr.value
            assert isinstance(value, ast.Name) and isinstance(value.ctx, ast.Load)
            instance = value
        else:
            instance = None

        # Leave alone any call whose arguments cannot be matched up with the
        # parameters: inlining it would lose arguments (or the TypeError):
        if not self._call_matches_signature(funcdef, call, instance):
            self.log('Not inlining %r: arguments do not match the signature'
                     % dotted_name)
            return call

        # Locate innermost scope at callsite:
        ste = nsp.get_innermost_scope()

        self.log('Inlining call to: %r within %r' % (dotted_name, ste.name))
        self.num_callsites += 1
        self.log(ast.dump(funcdef))

        varprefix = '__internal__inline_%s%x__' % (dotted_name, id(call))
        self.log('varprefix: %s' % varprefix)

        # Generate a body of specialized statements that can replace the call:
        specialized = self._setup_args(ste, varprefix, funcdef, call, instance)

        # Introduce __returnval__; initialize it to None, equivalent to
        # implicit "return None" if there are no "Return" nodes:
        returnval = varprefix + "__returnval__"
        add_local(ste, returnval)
        # FIXME: this requires "None", how to do this in AST?
        assign = make_assignment(returnval,
                                 make_load_name('None', call), call)
        # FIXME: this leads to LOAD_NAME None, when it should be LOAD_CONST, surely?
        specialized.append(assign)

        # Make inline body, generating various statements
        # ending with:
        #    __inline____returnval = expr
        inline_body = []
        fixer = InlineBodyFixups(varprefix, funcdef.ste)
        for stmt in funcdef.body:
            assert isinstance(stmt, ast.AST)
            inline_body.append(fixer.visit(ast_clone(stmt)))
        #log('inline_body:', inline_body)
        specialized += inline_body

        #log('Parent: %s' % ast.dump(find_parent(self.tree, call)))

        # FIXME: need some smarts about the value of the "Specialize":
        # it's the final Expr within the body
        specialized_result = ast.copy_location(ast.Name(id=returnval,
                                                        ctx=ast.Load()),
                                               call)

        self.log(' specialized:')
        for stmt in specialized:
            self.log(' %s' % ast.dump(stmt))

        # NOTE: ast.Specialize is not part of the standard "ast" module; it
        # has to be supplied by the interpreter this optimizer is built for.
        return ast.copy_location(ast.Specialize(name=call.func,
                                                expected_value='__internal__.saved.' + dotted_name,
                                                generalized=call,
                                                specialized_body=specialized,
                                                specialized_result=specialized_result),
                                 call)


class NotInlinable(Exception):
    """Raised by CheckInlinableVisitor on finding a construct it can't handle."""
    pass


class CheckInlinableVisitor(PathTransformer):
    # Walk an ast.FunctionDef subtree, determining if it's inlinable
    def __init__(self, funcdef):
        self.funcdef = funcdef
        self.returns = []  # the NodePath of each "return" statement found

    # Various nodes aren't handlable:
    def visit_FunctionDef(self, node, path):
        # Nested function definitions are not supported:
        if node is not self.funcdef:
            raise NotInlinable()
        self.generic_visit(node, path)
        return node

    def visit_ClassDef(self, node, path):
        raise NotInlinable()

    def visit_Yield(self, node, path):
        raise NotInlinable()

    def visit_YieldFrom(self, node, path):
        # "yield from" makes the function a generator, just as "yield" does
        raise NotInlinable()

    def visit_Await(self, node, path):
        raise NotInlinable()

    def visit_Import(self, node, path):
        raise NotInlinable()

    def visit_ImportFrom(self, node, path):
        raise NotInlinable()

    def visit_Return(self, node, path):
        self.returns.append(path)
        # Look inside the returned expression too, e.g. "return (yield x)":
        self.generic_visit(node, path)
        return node


def fn_is_inlinable(funcdef, mod):
    """Should we inline calls to the given FunctionDef ?

    *funcdef* must carry a ``ste`` attribute; *mod* is the tree searched for
    assignments that rebind the function's name.
    """
    assert(isinstance(funcdef, ast.FunctionDef))

    # Only inline "simple" calling conventions for now:
    if len(funcdef.decorator_list) > 0:
        return False

    if (funcdef.args.vararg is not None or
        funcdef.args.kwarg is not None or
        funcdef.args.kwonlyargs != [] or
        funcdef.args.defaults != [] or
        funcdef.args.kw_defaults != []):
        return False

    # Positional-only parameters (where the "ast" module has them) are not
    # part of args.args, so argument setup would not bind them:
    if getattr(funcdef.args, 'posonlyargs', None):
        return False

    # Don't try to inline generators and other awkward cases:
    v = CheckInlinableVisitor(funcdef)
    try:
        v.visit(funcdef)
    except NotInlinable:
        return False

    # TODO: restrict to just those functions with only a "return" at
    # the end (or implicit "return None"), no "return" in awkward places
    # (but could have other control flow)
    # Specifically: no returns in places that have successor code
    # for each return:
    # FIXME: for now, only inline functions which have a single, final
    # explicit "return" at the end, or no returns:
    log('returns of %s: %r' % (funcdef.name, v.returns))
    if len(v.returns) > 1:
        return False

    # Single "return"?  Then it must be at the top level
    if len(v.returns) == 1:
        rpath = v.returns[0]
        # Must be at toplevel:
        if len(rpath.entries) != 1:
            return False

        # Must be at the end of that level
        if rpath.entries[0].index != len(funcdef.body)-1:
            return False

    # Don't inline functions that use free or cell vars
    # (just locals and globals):
    assert hasattr(funcdef, 'ste')
    ste = funcdef.ste
    for varname in ste.varnames:
        scope = get_scope(ste, varname)
        #log('%r: %r' % (varname, scope))
        if scope not in {LOCAL, GLOBAL_EXPLICIT, GLOBAL_IMPLICIT}:
            return False

        # Don't inline functions that use the builtins listed in
        # BUILTINS_THAT_READ_LOCALS:
        if varname in BUILTINS_THAT_READ_LOCALS:
            return False

    # Don't inline functions with names that get rebound:
    for node in ast.walk(mod):
        if isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name):
                    if target.id == funcdef.name:
                        if isinstance(target.ctx, ast.Store):
                            return False

    return True


class InlinableFunctionFinder(PathTransformer):
    # Locate function definitions that look inlinable, recording, and adding globals:
    def __init__(self, mod):
        self.mod = mod
        self.funcdefs = {}  # dict from dotted name to ast.FunctionDef

    def log(self, msg):
        if 0:
            print('%s: %s' % (self.__class__.__name__, msg))

    def visit_FunctionDef(self, funcdef, path):
        self.log('got function def: %r %r' % (funcdef.name, path))
        self.generic_visit(funcdef, path)

        if fn_is_inlinable(funcdef, self.mod):
            dotted_name = path.get_dotted_name(funcdef)
            self.log('using dotted name: %r for %s' % (dotted_name, path))
            self.funcdefs[dotted_name] = funcdef
            # Follow the definition with statements saving the function
            # object under a name that callsites can compare against:
            storedname = '__internal__.saved.' + dotted_name
            global_ = ast.copy_location(ast.Global(names=[storedname]), funcdef)
            assign = ast.copy_location(ast.Assign(targets=[ast.Name(id=storedname, ctx=ast.Store())],
                                                  value=ast.Name(id=funcdef.name, ctx=ast.Load())),
                                       funcdef)
            ast.fix_missing_locations(assign)
            return [funcdef, global_, assign]
        else:
            return funcdef


def _inline_function_calls(t):
    """Find the inlinable functions within *t*, then inline calls to them."""
    v = InlinableFunctionFinder(t)
    v.visit(t)
    def_dict = v.funcdefs
    # print('def_dict:%r' % def_dict)

    # Locate call sites:
    inliner = FunctionInliner(t, def_dict)
    inliner.visit(t)
    return t


def is_local_name(ste, expr):
    """Is *expr* a Name node whose scope within *ste* is LOCAL?"""
    assert isinstance(expr, ast.AST)
    return isinstance(expr, ast.Name) and get_scope(ste, expr.id) == LOCAL


def is_write_to_local(ste, expr):
    """Is *expr* a store to a local name of *ste*?"""
    assert isinstance(expr, ast.AST)
    return is_local_name(ste, expr) and isinstance(expr.ctx, ast.Store)


def is_read_from_local(ste, expr):
    """Is *expr* a load of a local name of *ste*?"""
    assert isinstance(expr, ast.AST)
    return is_local_name(ste, expr) and isinstance(expr.ctx, ast.Load)


def is_constant(expr):
    """Is *expr* a numeric, string or bytes literal?"""
    assert isinstance(expr, ast.AST)
    constant_cls = getattr(ast, 'Constant', None)
    if constant_cls is not None and isinstance(expr, constant_cls):
        # ast.Constant also represents True/False/None/Ellipsis; only accept
        # the kinds of value that ast.Num, ast.Str and ast.Bytes covered:
        value = expr.value
        return (isinstance(value, (int, float, complex, str, bytes))
                and not isinstance(value, bool))
    if sys.version_info < (3, 8):
        # Before 3.8 the parser produced these dedicated node classes:
        return isinstance(expr, (ast.Num, ast.Str, ast.Bytes))
    return False


# The following optimizations currently can only cope with functions with a
# single basic-block, to avoid the need to build a CFG and do real data flow
# analysis.
class MoreThanOneBasicBlock(Exception):
    """Raised by LocalAssignmentWalker on meeting branching or looping code."""
    pass


class LocalAssignmentWalker(ast.NodeTransformer):
    """Copy-propagate constants and locals through one function's body.

    Walks the body in order, recording for each local the constant or local
    that was last directly assigned to it, and replacing later reads of that
    local with a copy of the recorded value.  Raises MoreThanOneBasicBlock on
    meeting any construct that implies branching or looping.
    """
    def __init__(self, funcdef):
        self.funcdef = funcdef
        # dict from local varname to the AST node it is currently known to
        # equal, or to None if its value is not known:
        self.local_values = {}

    def log(self, msg):
        if 0:
            print(msg)

    def _is_local(self, varname):
        ste = self.funcdef.ste
        return get_scope(ste, varname) == LOCAL

    def _is_propagatable(self, expr):
        # Is this expression propagatable to successive store operations?

        # We can propagate reads of locals (until they are written to):
        if is_read_from_local(self.funcdef.ste, expr):
            return True

        if is_constant(expr):
            return True

        return False

    def _invalidate(self, varname):
        # "varname" is being rebound (or deleted): forget its value, along
        # with the value of every local recorded as being a copy of it
        self.log(' %r is no longer known' % varname)
        self.local_values[varname] = None
        for other, value in list(self.local_values.items()):
            if isinstance(value, ast.Name) and value.id == varname:
                self.log(' %r is no longer known' % other)
                self.local_values[other] = None

    def visit_Assign(self, assign):
        # Visit the value before the targets, matching evaluation order, so
        # that "a = a + 1" can still use the value "a" had beforehand.
        # Visiting the targets invalidates each local written to (see
        # visit_Name), however deeply it is nested in the target:
        assign.value = self.visit(assign.value)
        assign.targets = [self.visit(target) for target in assign.targets]
        self.log(' got assign: %s' % ast.dump(assign))

        # Keep track of assignments to locals that are directly of constants
        # or of other locals:
        if len(assign.targets) == 1:
            target = assign.targets[0]
            if is_write_to_local(self.funcdef.ste, target):
                self.log(' write to %r <- %s'
                         % (target.id, ast.dump(assign.value)))
                if self._is_propagatable(assign.value):
                    self.log(' recording value for %r: %s'
                             % (target.id, ast.dump(assign.value)))
                    self.local_values[target.id] = assign.value
        return assign

    def visit_Name(self, name):
        self.log('visit_Name %r' % name)
        ste = self.funcdef.ste
        if is_read_from_local(ste, name):
            self.log(' got read from local: %s' % ast.dump(name))
            value = self.local_values.get(name.id)
            if value is not None:
                self.log(' copy-propagating: %r <- %s'
                         % (name.id, ast.dump(value)))
                # Use a fresh copy, so that no node appears twice in the tree:
                return ast.copy_location(ast_clone(value), name)
        elif is_local_name(ste, name):
            # Any other use of a local (a store, a delete, ...) means that
            # what we recorded for it no longer holds.  This covers tuple
            # unpacking, augmented assignment and "del":
            self._invalidate(name.id)
        return name

    # Nested scopes have symbol tables of their own, so don't rewrite the
    # names within them; but defining a function or class (re)binds its name:
    def visit_FunctionDef(self, node):
        if node is self.funcdef:
            return self.generic_visit(node)
        return self._visit_nested_scope(node)

    def _visit_nested_scope(self, node):
        name = getattr(node, 'name', None)  # a Lambda has no name
        if name is not None and self._is_local(name):
            self._invalidate(name)
        return node

    visit_AsyncFunctionDef = _visit_nested_scope
    visit_ClassDef = _visit_nested_scope
    visit_Lambda = _visit_nested_scope

    # Imports bind names without there being a Name node for them:
    def visit_Import(self, node):
        for alias in node.names:
            bound = alias.asname or alias.name.split('.')[0]
            if self._is_local(bound):
                self._invalidate(bound)
        return node

    visit_ImportFrom = visit_Import

    # Nodes implying branching and looping:
    def _more_than_one_basic_block(self, node):
        raise MoreThanOneBasicBlock()

    visit_For = _more_than_one_basic_block
    visit_While = _more_than_one_basic_block
    visit_If = _more_than_one_basic_block
    visit_With = _more_than_one_basic_block
    visit_TryExcept = _more_than_one_basic_block
    visit_TryFinally = _more_than_one_basic_block
    visit_IfExp = _more_than_one_basic_block
    # ...including those spelt differently by other versions of "ast":
    visit_Try = _more_than_one_basic_block
    visit_TryStar = _more_than_one_basic_block
    visit_AsyncFor = _more_than_one_basic_block
    visit_AsyncWith = _more_than_one_basic_block
    visit_Match = _more_than_one_basic_block
    # Comprehensions loop, and their targets may be scoped separately:
    visit_ListComp = _more_than_one_basic_block
    visit_SetComp = _more_than_one_basic_block
    visit_DictComp = _more_than_one_basic_block
    visit_GeneratorExp = _more_than_one_basic_block


class CopyPropagation(ast.NodeTransformer):
    """Apply LocalAssignmentWalker to every function definition in a tree."""
    def visit_FunctionDef(self, funcdef):
        log('CopyPropagation: got function def: %r' % funcdef.name)
        self.generic_visit(funcdef)

        try:
            w = LocalAssignmentWalker(funcdef)
            w.visit(funcdef)
        except MoreThanOneBasicBlock:
            # Deliberately ignored: the walker simply stops where the
            # straight-line code ends, keeping what it did up to there
            pass

        return funcdef


def _copy_propagation(t):
    # Very simple copy propagation, which (for simplicity), requires that we
    # have a single basic block
    v = CopyPropagation()
    v.visit(t)
    return t


class ReferenceToLocalFinder(PathTransformer):
    # Gather all reads/writes of locals within the given FunctionDef
    def __init__(self, funcdef):
        assert isinstance(funcdef, ast.FunctionDef)
        self.funcdef = funcdef
        # varnames of locals:
        self.local_reads = set()
        self.local_writes = set()
        # names with global scope that are referenced:
        self.globals = set()

    def log(self, msg):
        if 0:
            print(msg)

    def visit_Name(self, node, path):
        scope = get_scope(self.funcdef.ste, node.id)
        if scope == LOCAL:
            self.log(' found local: %r %r' % (ast.dump(node), path))
            if isinstance(node.ctx, ast.Store):
                self.local_writes.add(node.id)
            else:
                # Loads, and also "del x" and any other context: all of
                # these need the variable to exist, so count them as reads
                # rather than let the assignment feeding them be removed:
                self.local_reads.add(node.id)
        elif scope in {GLOBAL_EXPLICIT, GLOBAL_IMPLICIT}:
            self.globals.add(node.id)
        return node

    def visit_AugAssign(self, augassign, path):
        # VAR += EXPR references VAR, and we can't remove it since it could
        # have arbitrary sideeffects
        if isinstance(augassign.target, ast.Name):
            target = augassign.target
            scope = get_scope(self.funcdef.ste, target.id)
            if scope == LOCAL:
                self.log(' found local: %r %r' % (ast.dump(augassign), path))
                # An augassign is both a read and a write:
                self.local_writes.add(target.id)
                self.local_reads.add(target.id)
        self.generic_visit(augassign, path)
        return augassign


class RemoveAssignmentToUnusedLocals(ast.NodeTransformer):
    # Replace all Assign([local], expr) with Expr(expr) for the given locals
    def __init__(self, varnames):
        self.varnames = varnames

    def visit_Assign(self, node):
        if len(node.targets) == 1:
            target = node.targets[0]
            if (isinstance(target, ast.Name)
                    and target.id in self.varnames
                    and isinstance(target.ctx, ast.Store)):
                # Eliminate the assignment, keeping the evaluation of the
                # value, since that may have side effects:
                return ast.copy_location(ast.Expr(node.value),
                                         node.value)
                # FIXME: also eliminate the variable from symtable
        return node


class RedundantLocalRemover(PathTransformer):
    """Remove assignments to locals that are written to but never read."""
    def log(self, msg):
        if 0:
            print('%s: %s' % (self.__class__.__name__, msg))

    def visit_FunctionDef(self, funcdef, path):
        self.log('got function def: %r' % funcdef.name)
        finder = ReferenceToLocalFinder(funcdef)
        finder.visit(funcdef, path)

        self.generic_visit(funcdef, path)

        # Don't ellide locals if this function references the builtins
        # that access them:
        if finder.globals.intersection(BUILTINS_THAT_READ_LOCALS):
            return funcdef

        unused_writes = finder.local_writes - finder.local_reads
        if unused_writes:
            self.log(' globals: %s' % finder.globals)
            self.log(' loaded from: %s' % finder.local_reads)
            self.log(' stored to: %s:' % finder.local_writes)
            self.log(' unused: %s' % unused_writes)
            remover = RemoveAssignmentToUnusedLocals(unused_writes)
            remover.visit(funcdef)
        return funcdef


def _remove_redundant_locals(t):
    """Apply RedundantLocalRemover to the tree *t*, returning *t*."""
    v = RedundantLocalRemover()
    v.visit(t)
    return t


# For now restrict ourselves to just
# a few places:
def is_test_code(t, filename):
    """Is this one of the few pieces of code that we optimize?

    True for the file 'optimizable.py', and for any tree defining a function
    called 'function_to_be_inlined'.
    """
    if filename == 'optimizable.py':
        return True
    return any(isinstance(n, ast.FunctionDef)
               and n.name == 'function_to_be_inlined'
               for n in ast.walk(t))


# Debugging aid: set to True to have optimize_ast render "before.png" and
# "after.png" for trees that define a function called 'simple_method'.
_DUMP_DOT_ENABLED = False


def dump_dot(t, filename):
    """Should the tree *t* be rendered to PNG before/after optimization?

    Always False unless _DUMP_DOT_ENABLED is switched on.
    """
    if not _DUMP_DOT_ENABLED:
        return False
    return any(isinstance(n, ast.FunctionDef) and n.name == 'simple_method'
               for n in ast.walk(t))


#class OptimizationError(Exception):
#    def __init__(self

timing = 0
if timing:
    try:
        import time
    except ImportError:
        # "time" doesn't exist during the build process
        timing = 0

from pprint import pprint


def optimize_ast(t, filename, st_blocks):
    """Optimize the AST *t* of the source file *filename*, returning the result.

    Trees that are not recognized by is_test_code() are returned untouched,
    as are trees that are not an ast.Module.  *st_blocks* is currently only
    referred to by commented-out debugging code.

    Any exception raised while optimizing is reported on stdout and then
    re-raised.
    """
    if 0:
        print("optimize_ast called: %s" % filename)
    if is_test_code(t, filename):
        if timing:
            t0 = time.time()
        # For debugging failures, capture the tree before it is modified:
        # dot_before = to_dot(t)
        try:
            # pprint(st_blocks)
            # log(t)
            # pprint(t)
            # log(ast.dump(t))

            if isinstance(t, ast.Module):
                if dump_dot(t, filename):
                    dot_to_png(to_dot(t), 'before.png')

                t = _inline_function_calls(t)
                #cfg = CFG.from_ast(t)
                #print(cfg.to_dot())
                #dot_to_png(cfg.to_dot(), 'cfg.png')

                t = _copy_propagation(t)

                t = _remove_redundant_locals(t)

                if dump_dot(t, filename):
                    dot_to_png(to_dot(t), 'after.png')

        except Exception:
            print('Exception during optimization of %r' % filename)
            # dot_to_png(dot_before, 'before.png')
            raise
        if timing:
            t1 = time.time()
            print('Optimizing %r took %ss' % (filename, t1 - t0))
    if 0:
        print('finished optimizing')
    if filename == 'optimizable.py':
        print(ast.dump(t))
    return t