cpython/Lib/compiler/pycodegen.py

823 lines
24 KiB
Python

import os
import marshal
import stat
import struct
import types
from cStringIO import StringIO
from compiler import ast, parse, walk
from compiler import pyassem, misc
from compiler.pyassem import CO_VARARGS, CO_VARKEYWORDS, CO_NEWLOCALS, TupleArg
callfunc_opcode_info = {
# (Have *args, Have **args) : opcode
(0,0) : "CALL_FUNCTION",
(1,0) : "CALL_FUNCTION_VAR",
(0,1) : "CALL_FUNCTION_KW",
(1,1) : "CALL_FUNCTION_VAR_KW",
}
def compile(filename):
f = open(filename)
buf = f.read()
f.close()
mod = Module(buf, filename)
mod.compile()
f = open(filename + "c", "wb")
mod.dump(f)
f.close()
class Module:
def __init__(self, source, filename):
self.filename = filename
self.source = source
self.code = None
def compile(self):
ast = parse(self.source)
root, filename = os.path.split(self.filename)
gen = ModuleCodeGenerator(filename)
walk(ast, gen, 1)
self.code = gen.getCode()
def dump(self, f):
f.write(self.getPycHeader())
marshal.dump(self.code, f)
MAGIC = (50811 | (ord('\r')<<16) | (ord('\n')<<24))
def getPycHeader(self):
# compile.c uses marshal to write a long directly, with
# calling the interface that would also generate a 1-byte code
# to indicate the type of the value. simplest way to get the
# same effect is to call marshal and then skip the code.
magic = marshal.dumps(self.MAGIC)[1:]
mtime = os.stat(self.filename)[stat.ST_MTIME]
mtime = struct.pack('i', mtime)
return magic + mtime
class CodeGenerator:
optimized = 0 # is namespace access optimized?
def __init__(self, filename):
## Subclasses must define a constructor that intializes self.graph
## before calling this init function
## self.graph = pyassem.PyFlowGraph()
self.filename = filename
self.locals = misc.Stack()
self.loops = misc.Stack()
self.curStack = 0
self.maxStack = 0
self._setupGraphDelegation()
def _setupGraphDelegation(self):
self.emit = self.graph.emit
self.newBlock = self.graph.newBlock
self.startBlock = self.graph.startBlock
self.nextBlock = self.graph.nextBlock
self.setDocstring = self.graph.setDocstring
def getCode(self):
"""Return a code object"""
return self.graph.getCode()
# Next five methods handle name access
def isLocalName(self, name):
return self.locals.top().has_elt(name)
def storeName(self, name):
self._nameOp('STORE', name)
def loadName(self, name):
self._nameOp('LOAD', name)
def delName(self, name):
self._nameOp('DELETE', name)
def _nameOp(self, prefix, name):
if not self.optimized:
self.emit(prefix + '_NAME', name)
return
if self.isLocalName(name):
self.emit(prefix + '_FAST', name)
else:
self.emit(prefix + '_GLOBAL', name)
def set_lineno(self, node):
"""Emit SET_LINENO if node has lineno attribute
Returns true if SET_LINENO was emitted.
There are no rules for when an AST node should have a lineno
attribute. The transformer and AST code need to be reviewed
and a consistent policy implemented and documented. Until
then, this method works around missing line numbers.
"""
lineno = getattr(node, 'lineno', None)
if lineno is not None:
self.emit('SET_LINENO', lineno)
return 1
return 0
# The first few visitor methods handle nodes that generator new
# code objects
def visitModule(self, node):
lnf = walk(node.node, LocalNameFinder(), 0)
self.locals.push(lnf.getLocals())
self.setDocstring(node.doc)
self.visit(node.node)
self.emit('LOAD_CONST', None)
self.emit('RETURN_VALUE')
def visitFunction(self, node):
self._visitFuncOrLambda(node, isLambda=0)
self.storeName(node.name)
def visitLambda(self, node):
self._visitFuncOrLambda(node, isLambda=1)
## self.storeName("<lambda>")
def _visitFuncOrLambda(self, node, isLambda):
gen = FunctionCodeGenerator(node, self.filename, isLambda)
walk(node.code, gen)
gen.finish()
self.set_lineno(node)
for default in node.defaults:
self.visit(default)
self.emit('LOAD_CONST', gen.getCode())
self.emit('MAKE_FUNCTION', len(node.defaults))
def visitClass(self, node):
gen = ClassCodeGenerator(node, self.filename)
walk(node.code, gen)
gen.finish()
self.set_lineno(node)
self.emit('LOAD_CONST', node.name)
for base in node.bases:
self.visit(base)
self.emit('BUILD_TUPLE', len(node.bases))
self.emit('LOAD_CONST', gen.getCode())
self.emit('MAKE_FUNCTION', 0)
self.emit('CALL_FUNCTION', 0)
self.emit('BUILD_CLASS')
self.storeName(node.name)
# The rest are standard visitor methods
# The next few implement control-flow statements
def visitIf(self, node):
end = self.newBlock()
numtests = len(node.tests)
for i in range(numtests):
test, suite = node.tests[i]
self.set_lineno(test)
self.visit(test)
## if i == numtests - 1 and not node.else_:
## nextTest = end
## else:
## nextTest = self.newBlock()
nextTest = self.newBlock()
self.emit('JUMP_IF_FALSE', nextTest)
self.nextBlock()
self.emit('POP_TOP')
self.visit(suite)
self.emit('JUMP_FORWARD', end)
self.nextBlock(nextTest)
self.emit('POP_TOP')
if node.else_:
self.visit(node.else_)
self.nextBlock(end)
def visitWhile(self, node):
self.set_lineno(node)
loop = self.newBlock()
else_ = self.newBlock()
after = self.newBlock()
self.emit('SETUP_LOOP', after)
self.nextBlock(loop)
self.loops.push(loop)
self.set_lineno(node)
self.visit(node.test)
self.emit('JUMP_IF_FALSE', else_ or after)
self.nextBlock()
self.emit('POP_TOP')
self.visit(node.body)
self.emit('JUMP_ABSOLUTE', loop)
self.startBlock(else_) # or just the POPs if not else clause
self.emit('POP_TOP')
self.emit('POP_BLOCK')
if node.else_:
self.visit(node.else_)
self.loops.pop()
self.nextBlock(after)
def visitFor(self, node):
start = self.newBlock()
anchor = self.newBlock()
after = self.newBlock()
self.loops.push(start)
self.set_lineno(node)
self.emit('SETUP_LOOP', after)
self.visit(node.list)
self.visit(ast.Const(0))
self.nextBlock(start)
self.set_lineno(node)
self.emit('FOR_LOOP', anchor)
self.visit(node.assign)
self.visit(node.body)
self.emit('JUMP_ABSOLUTE', start)
self.nextBlock(anchor)
self.emit('POP_BLOCK')
if node.else_:
self.visit(node.else_)
self.loops.pop()
self.nextBlock(after)
def visitBreak(self, node):
if not self.loops:
raise SyntaxError, "'break' outside loop (%s, %d)" % \
(self.filename, node.lineno)
self.set_lineno(node)
self.emit('BREAK_LOOP')
def visitContinue(self, node):
if not self.loops:
raise SyntaxError, "'continue' outside loop (%s, %d)" % \
(self.filename, node.lineno)
l = self.loops.top()
self.set_lineno(node)
self.emit('JUMP_ABSOLUTE', l)
self.nextBlock()
def visitTest(self, node, jump):
end = self.newBlock()
for child in node.nodes[:-1]:
self.visit(child)
self.emit(jump, end)
self.nextBlock()
self.emit('POP_TOP')
self.visit(node.nodes[-1])
self.nextBlock(end)
def visitAnd(self, node):
self.visitTest(node, 'JUMP_IF_FALSE')
def visitOr(self, node):
self.visitTest(node, 'JUMP_IF_TRUE')
def visitCompare(self, node):
self.visit(node.expr)
cleanup = self.newBlock()
for op, code in node.ops[:-1]:
self.visit(code)
self.emit('DUP_TOP')
self.emit('ROT_THREE')
self.emit('COMPARE_OP', op)
self.emit('JUMP_IF_FALSE', cleanup)
self.nextBlock()
self.emit('POP_TOP')
# now do the last comparison
if node.ops:
op, code = node.ops[-1]
self.visit(code)
self.emit('COMPARE_OP', op)
if len(node.ops) > 1:
end = self.newBlock()
self.emit('JUMP_FORWARD', end)
self.nextBlock(cleanup)
self.emit('ROT_TWO')
self.emit('POP_TOP')
self.nextBlock(end)
# exception related
def visitAssert(self, node):
# XXX would be interesting to implement this via a
# transformation of the AST before this stage
end = self.newBlock()
self.set_lineno(node)
# XXX __debug__ and AssertionError appear to be special cases
# -- they are always loaded as globals even if there are local
# names. I guess this is a sort of renaming op.
self.emit('LOAD_GLOBAL', '__debug__')
self.emit('JUMP_IF_FALSE', end)
self.nextBlock()
self.emit('POP_TOP')
self.visit(node.test)
self.emit('JUMP_IF_TRUE', end)
self.nextBlock()
self.emit('LOAD_GLOBAL', 'AssertionError')
self.visit(node.fail)
self.emit('RAISE_VARARGS', 2)
self.nextBlock(end)
self.emit('POP_TOP')
def visitRaise(self, node):
self.set_lineno(node)
n = 0
if node.expr1:
self.visit(node.expr1)
n = n + 1
if node.expr2:
self.visit(node.expr2)
n = n + 1
if node.expr3:
self.visit(node.expr3)
n = n + 1
self.emit('RAISE_VARARGS', n)
def visitTryExcept(self, node):
handlers = self.newBlock()
end = self.newBlock()
if node.else_:
lElse = self.newBlock()
else:
lElse = end
self.set_lineno(node)
self.emit('SETUP_EXCEPT', handlers)
self.visit(node.body)
self.emit('POP_BLOCK')
self.emit('JUMP_FORWARD', lElse)
self.nextBlock(handlers)
last = len(node.handlers) - 1
for i in range(len(node.handlers)):
expr, target, body = node.handlers[i]
self.set_lineno(expr)
if expr:
self.emit('DUP_TOP')
self.visit(expr)
self.emit('COMPARE_OP', 'exception match')
next = self.newBlock()
self.emit('JUMP_IF_FALSE', next)
self.nextBlock()
self.emit('POP_TOP')
self.emit('POP_TOP')
if target:
self.visit(target)
else:
self.emit('POP_TOP')
self.emit('POP_TOP')
self.visit(body)
self.emit('JUMP_FORWARD', end)
if expr:
self.nextBlock(next)
self.emit('POP_TOP')
self.emit('END_FINALLY')
if node.else_:
self.nextBlock(lElse)
self.visit(node.else_)
self.nextBlock(end)
def visitTryFinally(self, node):
final = self.newBlock()
self.set_lineno(node)
self.emit('SETUP_FINALLY', final)
self.visit(node.body)
self.emit('POP_BLOCK')
self.emit('LOAD_CONST', None)
self.nextBlock(final)
self.visit(node.final)
self.emit('END_FINALLY')
# misc
## def visitStmt(self, node):
## # nothing to do except walk the children
## pass
def visitDiscard(self, node):
self.visit(node.expr)
self.emit('POP_TOP')
def visitConst(self, node):
self.emit('LOAD_CONST', node.value)
def visitKeyword(self, node):
self.emit('LOAD_CONST', node.name)
self.visit(node.expr)
def visitGlobal(self, node):
# no code to generate
pass
def visitName(self, node):
self.loadName(node.name)
def visitPass(self, node):
self.set_lineno(node)
def visitImport(self, node):
self.set_lineno(node)
for name in node.names:
self.emit('IMPORT_NAME', name)
self.storeName(name)
def visitFrom(self, node):
self.set_lineno(node)
self.emit('IMPORT_NAME', node.modname)
for name in node.names:
if name == '*':
self.namespace = 0
self.emit('IMPORT_FROM', name)
self.emit('POP_TOP')
def visitGetattr(self, node):
self.visit(node.expr)
self.emit('LOAD_ATTR', node.attrname)
# next five implement assignments
def visitAssign(self, node):
self.set_lineno(node)
self.visit(node.expr)
dups = len(node.nodes) - 1
for i in range(len(node.nodes)):
elt = node.nodes[i]
if i < dups:
self.emit('DUP_TOP')
if isinstance(elt, ast.Node):
self.visit(elt)
def visitAssName(self, node):
if node.flags == 'OP_ASSIGN':
self.storeName(node.name)
elif node.flags == 'OP_DELETE':
self.delName(node.name)
else:
print "oops", node.flags
def visitAssAttr(self, node):
self.visit(node.expr)
if node.flags == 'OP_ASSIGN':
self.emit('STORE_ATTR', node.attrname)
elif node.flags == 'OP_DELETE':
self.emit('DELETE_ATTR', node.attrname)
else:
print "warning: unexpected flags:", node.flags
print node
def visitAssTuple(self, node):
if findOp(node) != 'OP_DELETE':
self.emit('UNPACK_SEQUENCE', len(node.nodes))
for child in node.nodes:
self.visit(child)
visitAssList = visitAssTuple
def visitExec(self, node):
self.visit(node.expr)
if node.locals is None:
self.emit('LOAD_CONST', None)
else:
self.visit(node.locals)
if node.globals is None:
self.emit('DUP_TOP')
else:
self.visit(node.globals)
self.emit('EXEC_STMT')
def visitCallFunc(self, node):
pos = 0
kw = 0
self.set_lineno(node)
self.visit(node.node)
for arg in node.args:
self.visit(arg)
if isinstance(arg, ast.Keyword):
kw = kw + 1
else:
pos = pos + 1
if node.star_args is not None:
self.visit(node.star_args)
if node.dstar_args is not None:
self.visit(node.dstar_args)
have_star = node.star_args is not None
have_dstar = node.dstar_args is not None
opcode = callfunc_opcode_info[have_star, have_dstar]
self.emit(opcode, kw << 8 | pos)
def visitPrint(self, node):
self.set_lineno(node)
for child in node.nodes:
self.visit(child)
self.emit('PRINT_ITEM')
def visitPrintnl(self, node):
self.visitPrint(node)
self.emit('PRINT_NEWLINE')
def visitReturn(self, node):
self.set_lineno(node)
self.visit(node.value)
self.emit('RETURN_VALUE')
# slice and subscript stuff
def visitSlice(self, node):
self.visit(node.expr)
slice = 0
if node.lower:
self.visit(node.lower)
slice = slice | 1
if node.upper:
self.visit(node.upper)
slice = slice | 2
if node.flags == 'OP_APPLY':
self.emit('SLICE+%d' % slice)
elif node.flags == 'OP_ASSIGN':
self.emit('STORE_SLICE+%d' % slice)
elif node.flags == 'OP_DELETE':
self.emit('DELETE_SLICE+%d' % slice)
else:
print "weird slice", node.flags
raise
def visitSubscript(self, node):
self.visit(node.expr)
for sub in node.subs:
self.visit(sub)
if len(node.subs) > 1:
self.emit('BUILD_TUPLE', len(node.subs))
if node.flags == 'OP_APPLY':
self.emit('BINARY_SUBSCR')
elif node.flags == 'OP_ASSIGN':
self.emit('STORE_SUBSCR')
elif node.flags == 'OP_DELETE':
self.emit('DELETE_SUBSCR')
# binary ops
def binaryOp(self, node, op):
self.visit(node.left)
self.visit(node.right)
self.emit(op)
def visitAdd(self, node):
return self.binaryOp(node, 'BINARY_ADD')
def visitSub(self, node):
return self.binaryOp(node, 'BINARY_SUBTRACT')
def visitMul(self, node):
return self.binaryOp(node, 'BINARY_MULTIPLY')
def visitDiv(self, node):
return self.binaryOp(node, 'BINARY_DIVIDE')
def visitMod(self, node):
return self.binaryOp(node, 'BINARY_MODULO')
def visitPower(self, node):
return self.binaryOp(node, 'BINARY_POWER')
def visitLeftShift(self, node):
return self.binaryOp(node, 'BINARY_LSHIFT')
def visitRightShift(self, node):
return self.binaryOp(node, 'BINARY_RSHIFT')
# unary ops
def unaryOp(self, node, op):
self.visit(node.expr)
self.emit(op)
def visitInvert(self, node):
return self.unaryOp(node, 'UNARY_INVERT')
def visitUnarySub(self, node):
return self.unaryOp(node, 'UNARY_NEGATIVE')
def visitUnaryAdd(self, node):
return self.unaryOp(node, 'UNARY_POSITIVE')
def visitUnaryInvert(self, node):
return self.unaryOp(node, 'UNARY_INVERT')
def visitNot(self, node):
return self.unaryOp(node, 'UNARY_NOT')
def visitBackquote(self, node):
return self.unaryOp(node, 'UNARY_CONVERT')
# bit ops
def bitOp(self, nodes, op):
self.visit(nodes[0])
for node in nodes[1:]:
self.visit(node)
self.emit(op)
def visitBitand(self, node):
return self.bitOp(node.nodes, 'BINARY_AND')
def visitBitor(self, node):
return self.bitOp(node.nodes, 'BINARY_OR')
def visitBitxor(self, node):
return self.bitOp(node.nodes, 'BINARY_XOR')
# object constructors
def visitEllipsis(self, node):
self.emit('LOAD_CONST', Ellipsis)
def visitTuple(self, node):
for elt in node.nodes:
self.visit(elt)
self.emit('BUILD_TUPLE', len(node.nodes))
def visitList(self, node):
for elt in node.nodes:
self.visit(elt)
self.emit('BUILD_LIST', len(node.nodes))
def visitSliceobj(self, node):
for child in node.nodes:
self.visit(child)
self.emit('BUILD_SLICE', len(node.nodes))
def visitDict(self, node):
lineno = getattr(node, 'lineno', None)
if lineno:
set.emit('SET_LINENO', lineno)
self.emit('BUILD_MAP', 0)
for k, v in node.items:
lineno2 = getattr(node, 'lineno', None)
if lineno2 is not None and lineno != lineno2:
self.emit('SET_LINENO', lineno2)
lineno = lineno2
self.emit('DUP_TOP')
self.visit(v)
self.emit('ROT_TWO')
self.visit(k)
self.emit('STORE_SUBSCR')
class ModuleCodeGenerator(CodeGenerator):
super_init = CodeGenerator.__init__
def __init__(self, filename):
# XXX <module> is ? in compile.c
self.graph = pyassem.PyFlowGraph("<module>", filename)
self.super_init(filename)
class FunctionCodeGenerator(CodeGenerator):
super_init = CodeGenerator.__init__
optimized = 1
lambdaCount = 0
def __init__(self, func, filename, isLambda=0):
if isLambda:
klass = FunctionCodeGenerator
name = "<lambda.%d>" % klass.lambdaCount
klass.lambdaCount = klass.lambdaCount + 1
else:
name = func.name
args, hasTupleArg = generateArgList(func.argnames)
self.graph = pyassem.PyFlowGraph(name, filename, args,
optimized=1)
self.isLambda = isLambda
self.super_init(filename)
lnf = walk(func.code, LocalNameFinder(args), 0)
self.locals.push(lnf.getLocals())
if func.varargs:
self.graph.setFlag(CO_VARARGS)
if func.kwargs:
self.graph.setFlag(CO_VARKEYWORDS)
self.set_lineno(func)
if hasTupleArg:
self.generateArgUnpack(func.argnames)
def finish(self):
self.graph.startExitBlock()
if not self.isLambda:
self.emit('LOAD_CONST', None)
self.emit('RETURN_VALUE')
def generateArgUnpack(self, args):
count = 0
for arg in args:
if type(arg) == types.TupleType:
self.emit('LOAD_FAST', '.nested%d' % count)
count = count + 1
self.unpackSequence(arg)
def unpackSequence(self, tup):
self.emit('UNPACK_SEQUENCE', len(tup))
for elt in tup:
if type(elt) == types.TupleType:
self.unpackSequence(elt)
else:
self.emit('STORE_FAST', elt)
unpackTuple = unpackSequence
class ClassCodeGenerator(CodeGenerator):
super_init = CodeGenerator.__init__
def __init__(self, klass, filename):
self.graph = pyassem.PyFlowGraph(klass.name, filename,
optimized=0)
self.super_init(filename)
lnf = walk(klass.code, LocalNameFinder(), 0)
self.locals.push(lnf.getLocals())
self.graph.setFlag(CO_NEWLOCALS)
def finish(self):
self.graph.startExitBlock()
self.emit('LOAD_LOCALS')
self.emit('RETURN_VALUE')
def generateArgList(arglist):
"""Generate an arg list marking TupleArgs"""
args = []
extra = []
count = 0
for elt in arglist:
if type(elt) == types.StringType:
args.append(elt)
elif type(elt) == types.TupleType:
args.append(TupleArg(count, elt))
count = count + 1
extra.extend(misc.flatten(elt))
else:
raise ValueError, "unexpect argument type:", elt
return args + extra, count
class LocalNameFinder:
"""Find local names in scope"""
def __init__(self, names=()):
self.names = misc.Set()
self.globals = misc.Set()
for name in names:
self.names.add(name)
def getLocals(self):
for elt in self.globals.elements():
if self.names.has_elt(elt):
self.names.remove(elt)
return self.names
def visitDict(self, node):
pass
def visitGlobal(self, node):
for name in node.names:
self.globals.add(name)
def visitFunction(self, node):
self.names.add(node.name)
def visitLambda(self, node):
pass
def visitImport(self, node):
for name in node.names:
self.names.add(name)
def visitFrom(self, node):
for name in node.names:
self.names.add(name)
def visitClass(self, node):
self.names.add(node.name)
def visitAssName(self, node):
self.names.add(node.name)
def findOp(node):
"""Find the op (DELETE, LOAD, STORE) in an AssTuple tree"""
v = OpFinder()
walk(node, v, 0)
return v.op
class OpFinder:
def __init__(self):
self.op = None
def visitAssName(self, node):
if self.op is None:
self.op = node.flags
elif self.op != node.flags:
raise ValueError, "mixed ops in stmt"
if __name__ == "__main__":
import sys
for file in sys.argv[1:]:
compile(file)