#!/usr/bin/python
# Toy compiler for a tiny language with assignments, ``print`` and
# ``if <expr>: ... [else: ...] fi`` blocks.  Source text is tokenized and
# parsed with PLY, pretty-printed by ``code()``, and then "desugared" by
# ``desugar_if()``, which rewrites if-statements into unconditional
# arithmetic assignments.
#
# ``print`` is only called here with a single parenthesized argument, which
# behaves the same under Python 2 and Python 3.
import itertools
from pprint import pprint

from simplegeneric import generic
from ply import lex, yacc

keywords = ("print", "if", "else", "fi")

tokens = ("NUMBER", "VARIABLE", "PLUS", "MINUS", "TIMES", "EQUAL", "COLON",
          "NEWLINE") + tuple(map(str.upper, keywords))

# Simple tokens are plain regular expressions.
t_NUMBER = r'\d+'
t_PLUS = r'\+'
t_MINUS = r'-'
t_TIMES = r'\*'
t_EQUAL = r'='
t_COLON = r':'


def t_VARIABLE(t):
    r"""[a-z]+"""
    # Words that spell a keyword are re-tagged with the keyword's token type
    # (e.g. "if" -> IF), so keywords need no lexer rules of their own.
    if t.value in keywords:
        t.type = t.value.upper()
    return t


def t_NEWLINE(t):
    r"""\n"""
    # Newlines terminate statements, so the token is returned to the parser;
    # the lexer's line counter is kept up to date as well.
    t.lexer.lineno += 1
    return t


def t_error(t):
    # Report the offending character and skip it so lexing can continue.
    print("Illegal character '%s'" % t.value[0])
    t.lexer.skip(1)


# Spaces and tabs separate tokens but are otherwise discarded.
t_ignore = ' \t'


class BinOp(object):
    """Binary arithmetic node; the subclass name identifies the operator."""

    def __init__(self, lhs, rhs):
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self):
        return "%s(%s, %s)" % (self.__class__.__name__, self.lhs, self.rhs)


class Add(BinOp):
    """lhs + rhs"""


class Sub(BinOp):
    """lhs - rhs"""


class Mul(BinOp):
    """lhs * rhs"""


class AssignStmt(object):
    """Assignment ``lhs = rhs``; lhs is the variable-name string."""

    def __init__(self, lhs, rhs):
        self.lhs = lhs
        self.rhs = rhs

    def __repr__(self):
        return "AssignStmt(%r, %r)" % (self.lhs, self.rhs)


class PrintStmt(object):
    """``print expr``; the operand is stored in ``expr``."""

    def __init__(self, expr):
        self.expr = expr

    def __repr__(self):
        return "PrintStmt(%r)" % self.expr


class IfStmt(object):
    """``if bool_expr: then_stmts [else: else_stmts] fi``.

    Both branches are lists of statements; a missing or empty else-branch is
    normalized to an empty list.
    """

    def __init__(self, bool_expr, then_stmts, else_stmts=None):
        self.bool_expr = bool_expr
        self.then_stmts = then_stmts
        self.else_stmts = else_stmts or []

    def __repr__(self):
        return "IfStmt(%r, %r, %r)" % (self.bool_expr, self.then_stmts,
                                       self.else_stmts)


# Grammar rules.  PLY reads each rule from the function's docstring, so the
# docstrings below are grammar, not documentation.  The first rule defined
# (``program``) is the start symbol.

def p_program(p):
    """program : statement"""
    # A program is a non-empty list of statements.
    p[0] = [p[1]]


def p_program_cont(p):
    """program : program statement"""
    p[0] = p[1] + [p[2]]


def p_statement(p):
    """statement : assignment NEWLINE
                 | print NEWLINE
                 | if NEWLINE"""
    p[0] = p[1]


def p_assignment(p):
    """assignment : VARIABLE EQUAL expr"""
    p[0] = AssignStmt(p[1], p[3])


def p_print(p):
    """print : PRINT expr"""
    p[0] = PrintStmt(p[2])


def p_if(p):
    """if : IF expr COLON NEWLINE program FI"""
    # p[5] is the then-branch statement list.
    p[0] = IfStmt(p[2], p[5])


def p_if_else(p):
    """if : IF expr COLON NEWLINE program ELSE COLON NEWLINE program FI"""
    # p[5] / p[9] are the then- and else-branch statement lists.
    p[0] = IfStmt(p[2], p[5], p[9])


def p_expr(p):
    """expr : expr PLUS term
            | expr MINUS term"""
    if p[2] == '+':
        p[0] = Add(p[1], p[3])
    else:
        p[0] = Sub(p[1], p[3])


def p_expr_term(p):
    """expr : term"""
    p[0] = p[1]


def p_term(p):
    """term : term TIMES factor"""
    p[0] = Mul(p[1], p[3])


def p_term_factor(p):
    """term : factor"""
    p[0] = p[1]


def p_factor(p):
    """factor : NUMBER
              | VARIABLE"""
    # Leaves are kept as the raw token strings (digits or a variable name).
    p[0] = p[1]


def p_error(p):
    # p is the offending token (PLY passes None at unexpected end of input).
    print("Syntax error!")
    print(p)


lex.lex()
yacc.yacc()


### Printing code ###

@generic
def code(node):
    """Render an AST node as source text; leaves render via str()."""
    return str(node)


@code.when_type(list)
def code_result(result):
    # A statement list renders one statement per line.
    return "\n".join(map(code, result))


@code.when_type(Add)
def code_add(node):
    return "(%s + %s)" % (code(node.lhs), code(node.rhs))


@code.when_type(Sub)
def code_sub(node):
    return "(%s - %s)" % (code(node.lhs), code(node.rhs))


@code.when_type(Mul)
def code_mul(node):
    return "%s * %s" % (code(node.lhs), code(node.rhs))


@code.when_type(AssignStmt)
def code_assignment(node):
    return "%s = %s" % (code(node.lhs), code(node.rhs))


@code.when_type(PrintStmt)
def code_print(node):
    # The operand lives in ``expr`` and is rendered like any other expression.
    return "print %s" % code(node.expr)


@code.when_type(IfStmt)
def code_if(node):
    result = ["if %s:" % code(node.bool_expr)]
    for stmt in node.then_stmts:
        result.append(code(stmt))
    if node.else_stmts:
        result.append("else:")
        for stmt in node.else_stmts:
            result.append(code(stmt))
    result.append("fi")
    return "\n".join(result)


### Desugaring of if-statements ###
#
# Every if-statement is removed by evaluating its condition once into a
# temporary and guarding each assignment in its branches with that
# condition (the guarding itself is done by the AssignStmt handler).
# NOTE(review): this relies on conditions evaluating to exactly 0 or 1, and
# print statements inside a branch are not guarded (they always execute).

# Supplies fresh names for condition temporaries.  User variables match
# [a-z]+ (see t_VARIABLE), so names containing '_' and digits cannot clash.
_cond_counter = itertools.count()


@generic
def desugar_if(node, expr=None):
    """Return ``node`` with if-statements removed.

    ``expr`` is the condition under which ``node`` executes, or None at top
    level.  Nodes without a specialised handler are returned unchanged.
    """
    return node


@desugar_if.when_type(list)
def desugar_if_list(node, expr=None):
    """Desugar each statement, splicing returned statement lists in place so
    the result is always a flat list of statements."""
    result = []
    for stmt in node:
        desugared = desugar_if(stmt, expr)
        if isinstance(desugared, list):
            result.extend(desugared)
        else:
            result.append(desugared)
    return result


@desugar_if.when_type(IfStmt)
def desugar_if_if(node, expr=None):
    """Replace an if-statement by a flat list of unconditional statements.

    The condition is stored in a fresh temporary first, so assignments inside
    the branches that change variables the condition reads cannot change
    which branch later statements (or the else-branch) take.  When the if is
    nested in another branch, ``expr`` is that branch's condition and scales
    both of this if's branch conditions.
    """
    if expr is not None:
        cond = Mul(expr, node.bool_expr)
    else:
        cond = node.bool_expr
    tmp = "_cond%d" % next(_cond_counter)
    # else runs when the enclosing condition holds and bool_expr does not:
    # expr * (1 - bool_expr) == expr - tmp  (or 1 - tmp at top level).
    if expr is not None:
        else_cond = Sub(expr, tmp)
    else:
        else_cond = Sub(1, tmp)
    # Dispatching on the whole branch list flattens nested ifs.
    return ([AssignStmt(tmp, cond)]
            + desugar_if(node.then_stmts, tmp)
            + desugar_if(node.else_stmts, else_cond))
@desugar_if.when_type(AssignStmt)
def desugar_if_assignment(node, expr=None):
    """Guard an assignment by the condition ``expr`` (None = unconditional).

    ``var = rhs`` under condition c becomes ``var = c * rhs + (1 - c) * var``,
    which leaves ``var`` unchanged when c is 0.  A new AssignStmt is returned
    instead of mutating ``node``, so the parsed tree is left intact.
    """
    if expr is None:
        return node
    guarded = Add(Mul(expr, node.rhs), Mul(Sub(1, expr), node.lhs))
    return AssignStmt(node.lhs, guarded)


### Conversion of parse tree into nested tuples ###

@generic
def tree_to_tuples(node):
    """Convert statements into nested tuples for pretty-printing.

    Leaves and expression nodes are returned unchanged (pprint then shows
    expressions through their __repr__).
    """
    return node


@tree_to_tuples.when_type(list)
def tree_to_tuples_list(node):
    # Build a real list (not a lazy map object) so pprint shows the contents.
    return [tree_to_tuples(stmt) for stmt in node]


@tree_to_tuples.when_type(IfStmt)
def tree_to_tuples_if(node):
    # The else-branch entry is only present when there is an else-branch.
    result = ('IfStmt', tree_to_tuples(node.bool_expr),
              tree_to_tuples(node.then_stmts))
    if node.else_stmts:
        result += (tree_to_tuples(node.else_stmts),)
    return result


@tree_to_tuples.when_type(AssignStmt)
def tree_to_tuples_assign(node):
    return ('AssignStmt', tree_to_tuples(node.lhs), tree_to_tuples(node.rhs))


@tree_to_tuples.when_type(PrintStmt)
def tree_to_tuples_print(node):
    # Print statements are converted like the other statement kinds.
    return ('PrintStmt', tree_to_tuples(node.expr))


if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        sys.stderr.write("usage: %s SOURCE_FILE\n" % sys.argv[0])
        sys.exit(2)

    # Read the whole program and close the file promptly.
    with open(sys.argv[1], 'r') as fp:
        data = fp.read()

    print("Original program:")
    print(data)

    tree = yacc.parse(data)
    if tree is None:
        # Nothing to print or desugar; p_error has reported the problem.
        sys.stderr.write("Parsing failed; no parse tree produced.\n")
        sys.exit(1)

    print("Parse tree:")
    pprint(tree_to_tuples(tree))
    print("")
    print("Raw Python code:")
    print(code(tree))

    tree = desugar_if(tree)
    print("Desugared Python code:")
    print(code(tree))