diff --git a/Doc/library/ast.rst b/Doc/library/ast.rst index b468f4235df..a7e0729b902 100644 --- a/Doc/library/ast.rst +++ b/Doc/library/ast.rst @@ -161,6 +161,19 @@ and classes for traversing abstract syntax trees: Added ``type_comments``, ``mode='func_type'`` and ``feature_version``. +.. function:: unparse(ast_obj) + + Unparse an :class:`ast.AST` object and generate a string with code + that would produce an equivalent :class:`ast.AST` object if parsed + back with :func:`ast.parse`. + + .. warning:: + The produced code string will not necesarily be equal to the original + code that generated the :class:`ast.AST` object. + + .. versionadded:: 3.9 + + .. function:: literal_eval(node_or_string) Safely evaluate an expression node or a string containing a Python literal or diff --git a/Doc/whatsnew/3.9.rst b/Doc/whatsnew/3.9.rst index a3ad98d0206..9af5259de95 100644 --- a/Doc/whatsnew/3.9.rst +++ b/Doc/whatsnew/3.9.rst @@ -121,6 +121,11 @@ Added the *indent* option to :func:`~ast.dump` which allows it to produce a multiline indented output. (Contributed by Serhiy Storchaka in :issue:`37995`.) +Added the :func:`ast.unparse` as a function in the :mod:`ast` module that can +be used to unparse an :class:`ast.AST` object and produce a string with code +that would produce an equivalent :class:`ast.AST` object when parsed. +(Contributed by Pablo Galindo and Batuhan Taskaya in :issue:`38870`.) + asyncio ------- diff --git a/Lib/ast.py b/Lib/ast.py index 720dd48a761..97914ebc668 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -24,7 +24,9 @@ :copyright: Copyright 2008 by Armin Ronacher. :license: Python License. """ +import sys from _ast import * +from contextlib import contextmanager def parse(source, filename='', mode='exec', *, @@ -551,6 +553,697 @@ _const_node_type_names = { type(...): 'Ellipsis', } +# Large float and imaginary literals get turned into infinities in the AST. +# We unparse those infinities to INFSTR. +_INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1) + +class _Unparser(NodeVisitor): + """Methods in this class recursively traverse an AST and + output source code for the abstract syntax; original formatting + is disregarded.""" + + def __init__(self): + self._source = [] + self._buffer = [] + self._indent = 0 + + def interleave(self, inter, f, seq): + """Call f on each item in seq, calling inter() in between.""" + seq = iter(seq) + try: + f(next(seq)) + except StopIteration: + pass + else: + for x in seq: + inter() + f(x) + + def fill(self, text=""): + """Indent a piece of text and append it, according to the current + indentation level""" + self.write("\n" + " " * self._indent + text) + + def write(self, text): + """Append a piece of text""" + self._source.append(text) + + def buffer_writer(self, text): + self._buffer.append(text) + + @property + def buffer(self): + value = "".join(self._buffer) + self._buffer.clear() + return value + + @contextmanager + def block(self): + """A context manager for preparing the source for blocks. It adds + the character':', increases the indentation on enter and decreases + the indentation on exit.""" + self.write(":") + self._indent += 1 + yield + self._indent -= 1 + + def traverse(self, node): + if isinstance(node, list): + for item in node: + self.traverse(item) + else: + super().visit(node) + + def visit(self, node): + """Outputs a source code string that, if converted back to an ast + (using ast.parse) will generate an AST equivalent to *node*""" + self._source = [] + self.traverse(node) + return "".join(self._source) + + def visit_Module(self, node): + for subnode in node.body: + self.traverse(subnode) + + def visit_Expr(self, node): + self.fill() + self.traverse(node.value) + + def visit_NamedExpr(self, node): + self.write("(") + self.traverse(node.target) + self.write(" := ") + self.traverse(node.value) + self.write(")") + + def visit_Import(self, node): + self.fill("import ") + self.interleave(lambda: self.write(", "), self.traverse, node.names) + + def visit_ImportFrom(self, node): + self.fill("from ") + self.write("." * node.level) + if node.module: + self.write(node.module) + self.write(" import ") + self.interleave(lambda: self.write(", "), self.traverse, node.names) + + def visit_Assign(self, node): + self.fill() + for target in node.targets: + self.traverse(target) + self.write(" = ") + self.traverse(node.value) + + def visit_AugAssign(self, node): + self.fill() + self.traverse(node.target) + self.write(" " + self.binop[node.op.__class__.__name__] + "= ") + self.traverse(node.value) + + def visit_AnnAssign(self, node): + self.fill() + if not node.simple and isinstance(node.target, Name): + self.write("(") + self.traverse(node.target) + if not node.simple and isinstance(node.target, Name): + self.write(")") + self.write(": ") + self.traverse(node.annotation) + if node.value: + self.write(" = ") + self.traverse(node.value) + + def visit_Return(self, node): + self.fill("return") + if node.value: + self.write(" ") + self.traverse(node.value) + + def visit_Pass(self, node): + self.fill("pass") + + def visit_Break(self, node): + self.fill("break") + + def visit_Continue(self, node): + self.fill("continue") + + def visit_Delete(self, node): + self.fill("del ") + self.interleave(lambda: self.write(", "), self.traverse, node.targets) + + def visit_Assert(self, node): + self.fill("assert ") + self.traverse(node.test) + if node.msg: + self.write(", ") + self.traverse(node.msg) + + def visit_Global(self, node): + self.fill("global ") + self.interleave(lambda: self.write(", "), self.write, node.names) + + def visit_Nonlocal(self, node): + self.fill("nonlocal ") + self.interleave(lambda: self.write(", "), self.write, node.names) + + def visit_Await(self, node): + self.write("(") + self.write("await") + if node.value: + self.write(" ") + self.traverse(node.value) + self.write(")") + + def visit_Yield(self, node): + self.write("(") + self.write("yield") + if node.value: + self.write(" ") + self.traverse(node.value) + self.write(")") + + def visit_YieldFrom(self, node): + self.write("(") + self.write("yield from") + if node.value: + self.write(" ") + self.traverse(node.value) + self.write(")") + + def visit_Raise(self, node): + self.fill("raise") + if not node.exc: + if node.cause: + raise ValueError(f"Node can't use cause without an exception.") + return + self.write(" ") + self.traverse(node.exc) + if node.cause: + self.write(" from ") + self.traverse(node.cause) + + def visit_Try(self, node): + self.fill("try") + with self.block(): + self.traverse(node.body) + for ex in node.handlers: + self.traverse(ex) + if node.orelse: + self.fill("else") + with self.block(): + self.traverse(node.orelse) + if node.finalbody: + self.fill("finally") + with self.block(): + self.traverse(node.finalbody) + + def visit_ExceptHandler(self, node): + self.fill("except") + if node.type: + self.write(" ") + self.traverse(node.type) + if node.name: + self.write(" as ") + self.write(node.name) + with self.block(): + self.traverse(node.body) + + def visit_ClassDef(self, node): + self.write("\n") + for deco in node.decorator_list: + self.fill("@") + self.traverse(deco) + self.fill("class " + node.name) + self.write("(") + comma = False + for e in node.bases: + if comma: + self.write(", ") + else: + comma = True + self.traverse(e) + for e in node.keywords: + if comma: + self.write(", ") + else: + comma = True + self.traverse(e) + self.write(")") + + with self.block(): + self.traverse(node.body) + + def visit_FunctionDef(self, node): + self.__FunctionDef_helper(node, "def") + + def visit_AsyncFunctionDef(self, node): + self.__FunctionDef_helper(node, "async def") + + def __FunctionDef_helper(self, node, fill_suffix): + self.write("\n") + for deco in node.decorator_list: + self.fill("@") + self.traverse(deco) + def_str = fill_suffix + " " + node.name + "(" + self.fill(def_str) + self.traverse(node.args) + self.write(")") + if node.returns: + self.write(" -> ") + self.traverse(node.returns) + with self.block(): + self.traverse(node.body) + + def visit_For(self, node): + self.__For_helper("for ", node) + + def visit_AsyncFor(self, node): + self.__For_helper("async for ", node) + + def __For_helper(self, fill, node): + self.fill(fill) + self.traverse(node.target) + self.write(" in ") + self.traverse(node.iter) + with self.block(): + self.traverse(node.body) + if node.orelse: + self.fill("else") + with self.block(): + self.traverse(node.orelse) + + def visit_If(self, node): + self.fill("if ") + self.traverse(node.test) + with self.block(): + self.traverse(node.body) + # collapse nested ifs into equivalent elifs. + while node.orelse and len(node.orelse) == 1 and isinstance(node.orelse[0], If): + node = node.orelse[0] + self.fill("elif ") + self.traverse(node.test) + with self.block(): + self.traverse(node.body) + # final else + if node.orelse: + self.fill("else") + with self.block(): + self.traverse(node.orelse) + + def visit_While(self, node): + self.fill("while ") + self.traverse(node.test) + with self.block(): + self.traverse(node.body) + if node.orelse: + self.fill("else") + with self.block(): + self.traverse(node.orelse) + + def visit_With(self, node): + self.fill("with ") + self.interleave(lambda: self.write(", "), self.traverse, node.items) + with self.block(): + self.traverse(node.body) + + def visit_AsyncWith(self, node): + self.fill("async with ") + self.interleave(lambda: self.write(", "), self.traverse, node.items) + with self.block(): + self.traverse(node.body) + + def visit_JoinedStr(self, node): + self.write("f") + self._fstring_JoinedStr(node, self.buffer_writer) + self.write(repr(self.buffer)) + + def visit_FormattedValue(self, node): + self.write("f") + self._fstring_FormattedValue(node, self.buffer_writer) + self.write(repr(self.buffer)) + + def _fstring_JoinedStr(self, node, write): + for value in node.values: + meth = getattr(self, "_fstring_" + type(value).__name__) + meth(value, write) + + def _fstring_Constant(self, node, write): + if not isinstance(node.value, str): + raise ValueError("Constants inside JoinedStr should be a string.") + value = node.value.replace("{", "{{").replace("}", "}}") + write(value) + + def _fstring_FormattedValue(self, node, write): + write("{") + expr = type(self)().visit(node.value).rstrip("\n") + if expr.startswith("{"): + write(" ") # Separate pair of opening brackets as "{ {" + write(expr) + if node.conversion != -1: + conversion = chr(node.conversion) + if conversion not in "sra": + raise ValueError("Unknown f-string conversion.") + write(f"!{conversion}") + if node.format_spec: + write(":") + meth = getattr(self, "_fstring_" + type(node.format_spec).__name__) + meth(node.format_spec, write) + write("}") + + def visit_Name(self, node): + self.write(node.id) + + def _write_constant(self, value): + if isinstance(value, (float, complex)): + # Substitute overflowing decimal literal for AST infinities. + self.write(repr(value).replace("inf", _INFSTR)) + else: + self.write(repr(value)) + + def visit_Constant(self, node): + value = node.value + if isinstance(value, tuple): + self.write("(") + if len(value) == 1: + self._write_constant(value[0]) + self.write(",") + else: + self.interleave(lambda: self.write(", "), self._write_constant, value) + self.write(")") + elif value is ...: + self.write("...") + else: + if node.kind == "u": + self.write("u") + self._write_constant(node.value) + + def visit_List(self, node): + self.write("[") + self.interleave(lambda: self.write(", "), self.traverse, node.elts) + self.write("]") + + def visit_ListComp(self, node): + self.write("[") + self.traverse(node.elt) + for gen in node.generators: + self.traverse(gen) + self.write("]") + + def visit_GeneratorExp(self, node): + self.write("(") + self.traverse(node.elt) + for gen in node.generators: + self.traverse(gen) + self.write(")") + + def visit_SetComp(self, node): + self.write("{") + self.traverse(node.elt) + for gen in node.generators: + self.traverse(gen) + self.write("}") + + def visit_DictComp(self, node): + self.write("{") + self.traverse(node.key) + self.write(": ") + self.traverse(node.value) + for gen in node.generators: + self.traverse(gen) + self.write("}") + + def visit_comprehension(self, node): + if node.is_async: + self.write(" async for ") + else: + self.write(" for ") + self.traverse(node.target) + self.write(" in ") + self.traverse(node.iter) + for if_clause in node.ifs: + self.write(" if ") + self.traverse(if_clause) + + def visit_IfExp(self, node): + self.write("(") + self.traverse(node.body) + self.write(" if ") + self.traverse(node.test) + self.write(" else ") + self.traverse(node.orelse) + self.write(")") + + def visit_Set(self, node): + if not node.elts: + raise ValueError("Set node should has at least one item") + self.write("{") + self.interleave(lambda: self.write(", "), self.traverse, node.elts) + self.write("}") + + def visit_Dict(self, node): + self.write("{") + + def write_key_value_pair(k, v): + self.traverse(k) + self.write(": ") + self.traverse(v) + + def write_item(item): + k, v = item + if k is None: + # for dictionary unpacking operator in dicts {**{'y': 2}} + # see PEP 448 for details + self.write("**") + self.traverse(v) + else: + write_key_value_pair(k, v) + + self.interleave( + lambda: self.write(", "), write_item, zip(node.keys, node.values) + ) + self.write("}") + + def visit_Tuple(self, node): + self.write("(") + if len(node.elts) == 1: + elt = node.elts[0] + self.traverse(elt) + self.write(",") + else: + self.interleave(lambda: self.write(", "), self.traverse, node.elts) + self.write(")") + + unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"} + + def visit_UnaryOp(self, node): + self.write("(") + self.write(self.unop[node.op.__class__.__name__]) + self.write(" ") + self.traverse(node.operand) + self.write(")") + + binop = { + "Add": "+", + "Sub": "-", + "Mult": "*", + "MatMult": "@", + "Div": "/", + "Mod": "%", + "LShift": "<<", + "RShift": ">>", + "BitOr": "|", + "BitXor": "^", + "BitAnd": "&", + "FloorDiv": "//", + "Pow": "**", + } + + def visit_BinOp(self, node): + self.write("(") + self.traverse(node.left) + self.write(" " + self.binop[node.op.__class__.__name__] + " ") + self.traverse(node.right) + self.write(")") + + cmpops = { + "Eq": "==", + "NotEq": "!=", + "Lt": "<", + "LtE": "<=", + "Gt": ">", + "GtE": ">=", + "Is": "is", + "IsNot": "is not", + "In": "in", + "NotIn": "not in", + } + + def visit_Compare(self, node): + self.write("(") + self.traverse(node.left) + for o, e in zip(node.ops, node.comparators): + self.write(" " + self.cmpops[o.__class__.__name__] + " ") + self.traverse(e) + self.write(")") + + boolops = {And: "and", Or: "or"} + + def visit_BoolOp(self, node): + self.write("(") + s = " %s " % self.boolops[node.op.__class__] + self.interleave(lambda: self.write(s), self.traverse, node.values) + self.write(")") + + def visit_Attribute(self, node): + self.traverse(node.value) + # Special case: 3.__abs__() is a syntax error, so if node.value + # is an integer literal then we need to either parenthesize + # it or add an extra space to get 3 .__abs__(). + if isinstance(node.value, Constant) and isinstance(node.value.value, int): + self.write(" ") + self.write(".") + self.write(node.attr) + + def visit_Call(self, node): + self.traverse(node.func) + self.write("(") + comma = False + for e in node.args: + if comma: + self.write(", ") + else: + comma = True + self.traverse(e) + for e in node.keywords: + if comma: + self.write(", ") + else: + comma = True + self.traverse(e) + self.write(")") + + def visit_Subscript(self, node): + self.traverse(node.value) + self.write("[") + self.traverse(node.slice) + self.write("]") + + def visit_Starred(self, node): + self.write("*") + self.traverse(node.value) + + def visit_Ellipsis(self, node): + self.write("...") + + def visit_Index(self, node): + self.traverse(node.value) + + def visit_Slice(self, node): + if node.lower: + self.traverse(node.lower) + self.write(":") + if node.upper: + self.traverse(node.upper) + if node.step: + self.write(":") + self.traverse(node.step) + + def visit_ExtSlice(self, node): + self.interleave(lambda: self.write(", "), self.traverse, node.dims) + + def visit_arg(self, node): + self.write(node.arg) + if node.annotation: + self.write(": ") + self.traverse(node.annotation) + + def visit_arguments(self, node): + first = True + # normal arguments + all_args = node.posonlyargs + node.args + defaults = [None] * (len(all_args) - len(node.defaults)) + node.defaults + for index, elements in enumerate(zip(all_args, defaults), 1): + a, d = elements + if first: + first = False + else: + self.write(", ") + self.traverse(a) + if d: + self.write("=") + self.traverse(d) + if index == len(node.posonlyargs): + self.write(", /") + + # varargs, or bare '*' if no varargs but keyword-only arguments present + if node.vararg or node.kwonlyargs: + if first: + first = False + else: + self.write(", ") + self.write("*") + if node.vararg: + self.write(node.vararg.arg) + if node.vararg.annotation: + self.write(": ") + self.traverse(node.vararg.annotation) + + # keyword-only arguments + if node.kwonlyargs: + for a, d in zip(node.kwonlyargs, node.kw_defaults): + if first: + first = False + else: + self.write(", ") + self.traverse(a), + if d: + self.write("=") + self.traverse(d) + + # kwargs + if node.kwarg: + if first: + first = False + else: + self.write(", ") + self.write("**" + node.kwarg.arg) + if node.kwarg.annotation: + self.write(": ") + self.traverse(node.kwarg.annotation) + + def visit_keyword(self, node): + if node.arg is None: + self.write("**") + else: + self.write(node.arg) + self.write("=") + self.traverse(node.value) + + def visit_Lambda(self, node): + self.write("(") + self.write("lambda ") + self.traverse(node.args) + self.write(": ") + self.traverse(node.body) + self.write(")") + + def visit_alias(self, node): + self.write(node.name) + if node.asname: + self.write(" as " + node.asname) + + def visit_withitem(self, node): + self.traverse(node.context_expr) + if node.optional_vars: + self.write(" as ") + self.traverse(node.optional_vars) + +def unparse(ast_obj): + unparser = _Unparser() + return unparser.visit(ast_obj) + def main(): import argparse diff --git a/Lib/test/test_tools/test_unparse.py b/Lib/test/test_unparse.py similarity index 76% rename from Lib/test/test_tools/test_unparse.py rename to Lib/test/test_unparse.py index a958ebb51cc..9197c8a4a9e 100644 --- a/Lib/test/test_tools/test_unparse.py +++ b/Lib/test/test_unparse.py @@ -3,19 +3,12 @@ import unittest import test.support import io -import os +import pathlib import random import tokenize import ast +import functools -from test.test_tools import basepath, toolsdir, skip_if_missing - -skip_if_missing() - -parser_path = os.path.join(toolsdir, "parser") - -with test.support.DirsOnSysPath(parser_path): - import unparse def read_pyfile(filename): """Read and return the contents of a Python source file (as a @@ -26,6 +19,7 @@ def read_pyfile(filename): source = pyfile.read() return source + for_else = """\ def f(): for x in range(10): @@ -119,18 +113,21 @@ with f() as x, g() as y: suite1 """ + class ASTTestCase(unittest.TestCase): def assertASTEqual(self, ast1, ast2): self.assertEqual(ast.dump(ast1), ast.dump(ast2)) - def check_roundtrip(self, code1, filename="internal"): - ast1 = compile(code1, filename, "exec", ast.PyCF_ONLY_AST) - unparse_buffer = io.StringIO() - unparse.Unparser(ast1, unparse_buffer) - code2 = unparse_buffer.getvalue() - ast2 = compile(code2, filename, "exec", ast.PyCF_ONLY_AST) + def check_roundtrip(self, code1): + ast1 = ast.parse(code1) + code2 = ast.unparse(ast1) + ast2 = ast.parse(code2) self.assertASTEqual(ast1, ast2) + def check_invalid(self, node, raises=ValueError): + self.assertRaises(raises, ast.unparse, node) + + class UnparseTestCase(ASTTestCase): # Tests for specific bugs found in earlier versions of unparse @@ -174,8 +171,8 @@ class UnparseTestCase(ASTTestCase): self.check_roundtrip("-1e1000j") def test_min_int(self): - self.check_roundtrip(str(-2**31)) - self.check_roundtrip(str(-2**63)) + self.check_roundtrip(str(-(2 ** 31))) + self.check_roundtrip(str(-(2 ** 63))) def test_imaginary_literals(self): self.check_roundtrip("7j") @@ -265,54 +262,67 @@ class UnparseTestCase(ASTTestCase): self.check_roundtrip(r"""{**{'y': 2}, 'x': 1}""") self.check_roundtrip(r"""{**{'y': 2}, **{'x': 1}}""") + def test_invalid_raise(self): + self.check_invalid(ast.Raise(exc=None, cause=ast.Name(id="X"))) + + def test_invalid_fstring_constant(self): + self.check_invalid(ast.JoinedStr(values=[ast.Constant(value=100)])) + + def test_invalid_fstring_conversion(self): + self.check_invalid( + ast.FormattedValue( + value=ast.Constant(value="a", kind=None), + conversion=ord("Y"), # random character + format_spec=None, + ) + ) + + def test_invalid_set(self): + self.check_invalid(ast.Set(elts=[])) + class DirectoryTestCase(ASTTestCase): """Test roundtrip behaviour on all files in Lib and Lib/test.""" - NAMES = None - # test directories, relative to the root of the distribution - test_directories = 'Lib', os.path.join('Lib', 'test') + lib_dir = pathlib.Path(__file__).parent / ".." + test_directories = (lib_dir, lib_dir / "test") + skip_files = {"test_fstring.py"} - @classmethod - def get_names(cls): - if cls.NAMES is not None: - return cls.NAMES + @functools.cached_property + def files_to_test(self): + # bpo-31174: Use cached_property to store the names sample + # to always test the same files. It prevents false alarms + # when hunting reference leaks. - names = [] - for d in cls.test_directories: - test_dir = os.path.join(basepath, d) - for n in os.listdir(test_dir): - if n.endswith('.py') and not n.startswith('bad'): - names.append(os.path.join(test_dir, n)) + items = [ + item.resolve() + for directory in self.test_directories + for item in directory.glob("*.py") + if not item.name.startswith("bad") + ] # Test limited subset of files unless the 'cpu' resource is specified. if not test.support.is_resource_enabled("cpu"): - names = random.sample(names, 10) - # bpo-31174: Store the names sample to always test the same files. - # It prevents false alarms when hunting reference leaks. - cls.NAMES = names - return names + items = random.sample(items, 10) + return items def test_files(self): - # get names of files to test - names = self.get_names() - - for filename in names: + for item in self.files_to_test: if test.support.verbose: - print('Testing %s' % filename) + print(f"Testing {item.absolute()}") # Some f-strings are not correctly round-tripped by - # Tools/parser/unparse.py. See issue 28002 for details. - # We need to skip files that contain such f-strings. - if os.path.basename(filename) in ('test_fstring.py', ): + # Tools/parser/unparse.py. See issue 28002 for details. + # We need to skip files that contain such f-strings. + if item.name in self.skip_files: if test.support.verbose: - print(f'Skipping {filename}: see issue 28002') + print(f"Skipping {item.absolute()}: see issue 28002") continue - with self.subTest(filename=filename): - source = read_pyfile(filename) + with self.subTest(filename=item): + source = read_pyfile(item) self.check_roundtrip(source) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/Misc/NEWS.d/next/Library/2019-11-20-22-43-48.bpo-38870.rLVZEv.rst b/Misc/NEWS.d/next/Library/2019-11-20-22-43-48.bpo-38870.rLVZEv.rst new file mode 100644 index 00000000000..61af368ba55 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2019-11-20-22-43-48.bpo-38870.rLVZEv.rst @@ -0,0 +1,4 @@ +Expose :func:`ast.unparse` as a function of the :mod:`ast` module that can +be used to unparse an :class:`ast.AST` object and produce a string with code +that would produce an equivalent :class:`ast.AST` object when parsed. Patch +by Pablo Galindo and Batuhan Taskaya. diff --git a/Tools/parser/unparse.py b/Tools/parser/unparse.py deleted file mode 100644 index a5cc000676b..00000000000 --- a/Tools/parser/unparse.py +++ /dev/null @@ -1,704 +0,0 @@ -"Usage: unparse.py " -import sys -import ast -import tokenize -import io -import os - -# Large float and imaginary literals get turned into infinities in the AST. -# We unparse those infinities to INFSTR. -INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1) - -def interleave(inter, f, seq): - """Call f on each item in seq, calling inter() in between. - """ - seq = iter(seq) - try: - f(next(seq)) - except StopIteration: - pass - else: - for x in seq: - inter() - f(x) - -class Unparser: - """Methods in this class recursively traverse an AST and - output source code for the abstract syntax; original formatting - is disregarded. """ - - def __init__(self, tree, file = sys.stdout): - """Unparser(tree, file=sys.stdout) -> None. - Print the source for tree to file.""" - self.f = file - self._indent = 0 - self.dispatch(tree) - print("", file=self.f) - self.f.flush() - - def fill(self, text = ""): - "Indent a piece of text, according to the current indentation level" - self.f.write("\n"+" "*self._indent + text) - - def write(self, text): - "Append a piece of text to the current line." - self.f.write(text) - - def enter(self): - "Print ':', and increase the indentation." - self.write(":") - self._indent += 1 - - def leave(self): - "Decrease the indentation level." - self._indent -= 1 - - def dispatch(self, tree): - "Dispatcher function, dispatching tree type T to method _T." - if isinstance(tree, list): - for t in tree: - self.dispatch(t) - return - meth = getattr(self, "_"+tree.__class__.__name__) - meth(tree) - - - ############### Unparsing methods ###################### - # There should be one method per concrete grammar type # - # Constructors should be grouped by sum type. Ideally, # - # this would follow the order in the grammar, but # - # currently doesn't. # - ######################################################## - - def _Module(self, tree): - for stmt in tree.body: - self.dispatch(stmt) - - # stmt - def _Expr(self, tree): - self.fill() - self.dispatch(tree.value) - - def _NamedExpr(self, tree): - self.write("(") - self.dispatch(tree.target) - self.write(" := ") - self.dispatch(tree.value) - self.write(")") - - def _Import(self, t): - self.fill("import ") - interleave(lambda: self.write(", "), self.dispatch, t.names) - - def _ImportFrom(self, t): - self.fill("from ") - self.write("." * t.level) - if t.module: - self.write(t.module) - self.write(" import ") - interleave(lambda: self.write(", "), self.dispatch, t.names) - - def _Assign(self, t): - self.fill() - for target in t.targets: - self.dispatch(target) - self.write(" = ") - self.dispatch(t.value) - - def _AugAssign(self, t): - self.fill() - self.dispatch(t.target) - self.write(" "+self.binop[t.op.__class__.__name__]+"= ") - self.dispatch(t.value) - - def _AnnAssign(self, t): - self.fill() - if not t.simple and isinstance(t.target, ast.Name): - self.write('(') - self.dispatch(t.target) - if not t.simple and isinstance(t.target, ast.Name): - self.write(')') - self.write(": ") - self.dispatch(t.annotation) - if t.value: - self.write(" = ") - self.dispatch(t.value) - - def _Return(self, t): - self.fill("return") - if t.value: - self.write(" ") - self.dispatch(t.value) - - def _Pass(self, t): - self.fill("pass") - - def _Break(self, t): - self.fill("break") - - def _Continue(self, t): - self.fill("continue") - - def _Delete(self, t): - self.fill("del ") - interleave(lambda: self.write(", "), self.dispatch, t.targets) - - def _Assert(self, t): - self.fill("assert ") - self.dispatch(t.test) - if t.msg: - self.write(", ") - self.dispatch(t.msg) - - def _Global(self, t): - self.fill("global ") - interleave(lambda: self.write(", "), self.write, t.names) - - def _Nonlocal(self, t): - self.fill("nonlocal ") - interleave(lambda: self.write(", "), self.write, t.names) - - def _Await(self, t): - self.write("(") - self.write("await") - if t.value: - self.write(" ") - self.dispatch(t.value) - self.write(")") - - def _Yield(self, t): - self.write("(") - self.write("yield") - if t.value: - self.write(" ") - self.dispatch(t.value) - self.write(")") - - def _YieldFrom(self, t): - self.write("(") - self.write("yield from") - if t.value: - self.write(" ") - self.dispatch(t.value) - self.write(")") - - def _Raise(self, t): - self.fill("raise") - if not t.exc: - assert not t.cause - return - self.write(" ") - self.dispatch(t.exc) - if t.cause: - self.write(" from ") - self.dispatch(t.cause) - - def _Try(self, t): - self.fill("try") - self.enter() - self.dispatch(t.body) - self.leave() - for ex in t.handlers: - self.dispatch(ex) - if t.orelse: - self.fill("else") - self.enter() - self.dispatch(t.orelse) - self.leave() - if t.finalbody: - self.fill("finally") - self.enter() - self.dispatch(t.finalbody) - self.leave() - - def _ExceptHandler(self, t): - self.fill("except") - if t.type: - self.write(" ") - self.dispatch(t.type) - if t.name: - self.write(" as ") - self.write(t.name) - self.enter() - self.dispatch(t.body) - self.leave() - - def _ClassDef(self, t): - self.write("\n") - for deco in t.decorator_list: - self.fill("@") - self.dispatch(deco) - self.fill("class "+t.name) - self.write("(") - comma = False - for e in t.bases: - if comma: self.write(", ") - else: comma = True - self.dispatch(e) - for e in t.keywords: - if comma: self.write(", ") - else: comma = True - self.dispatch(e) - self.write(")") - - self.enter() - self.dispatch(t.body) - self.leave() - - def _FunctionDef(self, t): - self.__FunctionDef_helper(t, "def") - - def _AsyncFunctionDef(self, t): - self.__FunctionDef_helper(t, "async def") - - def __FunctionDef_helper(self, t, fill_suffix): - self.write("\n") - for deco in t.decorator_list: - self.fill("@") - self.dispatch(deco) - def_str = fill_suffix+" "+t.name + "(" - self.fill(def_str) - self.dispatch(t.args) - self.write(")") - if t.returns: - self.write(" -> ") - self.dispatch(t.returns) - self.enter() - self.dispatch(t.body) - self.leave() - - def _For(self, t): - self.__For_helper("for ", t) - - def _AsyncFor(self, t): - self.__For_helper("async for ", t) - - def __For_helper(self, fill, t): - self.fill(fill) - self.dispatch(t.target) - self.write(" in ") - self.dispatch(t.iter) - self.enter() - self.dispatch(t.body) - self.leave() - if t.orelse: - self.fill("else") - self.enter() - self.dispatch(t.orelse) - self.leave() - - def _If(self, t): - self.fill("if ") - self.dispatch(t.test) - self.enter() - self.dispatch(t.body) - self.leave() - # collapse nested ifs into equivalent elifs. - while (t.orelse and len(t.orelse) == 1 and - isinstance(t.orelse[0], ast.If)): - t = t.orelse[0] - self.fill("elif ") - self.dispatch(t.test) - self.enter() - self.dispatch(t.body) - self.leave() - # final else - if t.orelse: - self.fill("else") - self.enter() - self.dispatch(t.orelse) - self.leave() - - def _While(self, t): - self.fill("while ") - self.dispatch(t.test) - self.enter() - self.dispatch(t.body) - self.leave() - if t.orelse: - self.fill("else") - self.enter() - self.dispatch(t.orelse) - self.leave() - - def _With(self, t): - self.fill("with ") - interleave(lambda: self.write(", "), self.dispatch, t.items) - self.enter() - self.dispatch(t.body) - self.leave() - - def _AsyncWith(self, t): - self.fill("async with ") - interleave(lambda: self.write(", "), self.dispatch, t.items) - self.enter() - self.dispatch(t.body) - self.leave() - - # expr - def _JoinedStr(self, t): - self.write("f") - string = io.StringIO() - self._fstring_JoinedStr(t, string.write) - self.write(repr(string.getvalue())) - - def _FormattedValue(self, t): - self.write("f") - string = io.StringIO() - self._fstring_FormattedValue(t, string.write) - self.write(repr(string.getvalue())) - - def _fstring_JoinedStr(self, t, write): - for value in t.values: - meth = getattr(self, "_fstring_" + type(value).__name__) - meth(value, write) - - def _fstring_Constant(self, t, write): - assert isinstance(t.value, str) - value = t.value.replace("{", "{{").replace("}", "}}") - write(value) - - def _fstring_FormattedValue(self, t, write): - write("{") - expr = io.StringIO() - Unparser(t.value, expr) - expr = expr.getvalue().rstrip("\n") - if expr.startswith("{"): - write(" ") # Separate pair of opening brackets as "{ {" - write(expr) - if t.conversion != -1: - conversion = chr(t.conversion) - assert conversion in "sra" - write(f"!{conversion}") - if t.format_spec: - write(":") - meth = getattr(self, "_fstring_" + type(t.format_spec).__name__) - meth(t.format_spec, write) - write("}") - - def _Name(self, t): - self.write(t.id) - - def _write_constant(self, value): - if isinstance(value, (float, complex)): - # Substitute overflowing decimal literal for AST infinities. - self.write(repr(value).replace("inf", INFSTR)) - else: - self.write(repr(value)) - - def _Constant(self, t): - value = t.value - if isinstance(value, tuple): - self.write("(") - if len(value) == 1: - self._write_constant(value[0]) - self.write(",") - else: - interleave(lambda: self.write(", "), self._write_constant, value) - self.write(")") - elif value is ...: - self.write("...") - else: - if t.kind == "u": - self.write("u") - self._write_constant(t.value) - - def _List(self, t): - self.write("[") - interleave(lambda: self.write(", "), self.dispatch, t.elts) - self.write("]") - - def _ListComp(self, t): - self.write("[") - self.dispatch(t.elt) - for gen in t.generators: - self.dispatch(gen) - self.write("]") - - def _GeneratorExp(self, t): - self.write("(") - self.dispatch(t.elt) - for gen in t.generators: - self.dispatch(gen) - self.write(")") - - def _SetComp(self, t): - self.write("{") - self.dispatch(t.elt) - for gen in t.generators: - self.dispatch(gen) - self.write("}") - - def _DictComp(self, t): - self.write("{") - self.dispatch(t.key) - self.write(": ") - self.dispatch(t.value) - for gen in t.generators: - self.dispatch(gen) - self.write("}") - - def _comprehension(self, t): - if t.is_async: - self.write(" async for ") - else: - self.write(" for ") - self.dispatch(t.target) - self.write(" in ") - self.dispatch(t.iter) - for if_clause in t.ifs: - self.write(" if ") - self.dispatch(if_clause) - - def _IfExp(self, t): - self.write("(") - self.dispatch(t.body) - self.write(" if ") - self.dispatch(t.test) - self.write(" else ") - self.dispatch(t.orelse) - self.write(")") - - def _Set(self, t): - assert(t.elts) # should be at least one element - self.write("{") - interleave(lambda: self.write(", "), self.dispatch, t.elts) - self.write("}") - - def _Dict(self, t): - self.write("{") - def write_key_value_pair(k, v): - self.dispatch(k) - self.write(": ") - self.dispatch(v) - - def write_item(item): - k, v = item - if k is None: - # for dictionary unpacking operator in dicts {**{'y': 2}} - # see PEP 448 for details - self.write("**") - self.dispatch(v) - else: - write_key_value_pair(k, v) - interleave(lambda: self.write(", "), write_item, zip(t.keys, t.values)) - self.write("}") - - def _Tuple(self, t): - self.write("(") - if len(t.elts) == 1: - elt = t.elts[0] - self.dispatch(elt) - self.write(",") - else: - interleave(lambda: self.write(", "), self.dispatch, t.elts) - self.write(")") - - unop = {"Invert":"~", "Not": "not", "UAdd":"+", "USub":"-"} - def _UnaryOp(self, t): - self.write("(") - self.write(self.unop[t.op.__class__.__name__]) - self.write(" ") - self.dispatch(t.operand) - self.write(")") - - binop = { "Add":"+", "Sub":"-", "Mult":"*", "MatMult":"@", "Div":"/", "Mod":"%", - "LShift":"<<", "RShift":">>", "BitOr":"|", "BitXor":"^", "BitAnd":"&", - "FloorDiv":"//", "Pow": "**"} - def _BinOp(self, t): - self.write("(") - self.dispatch(t.left) - self.write(" " + self.binop[t.op.__class__.__name__] + " ") - self.dispatch(t.right) - self.write(")") - - cmpops = {"Eq":"==", "NotEq":"!=", "Lt":"<", "LtE":"<=", "Gt":">", "GtE":">=", - "Is":"is", "IsNot":"is not", "In":"in", "NotIn":"not in"} - def _Compare(self, t): - self.write("(") - self.dispatch(t.left) - for o, e in zip(t.ops, t.comparators): - self.write(" " + self.cmpops[o.__class__.__name__] + " ") - self.dispatch(e) - self.write(")") - - boolops = {ast.And: 'and', ast.Or: 'or'} - def _BoolOp(self, t): - self.write("(") - s = " %s " % self.boolops[t.op.__class__] - interleave(lambda: self.write(s), self.dispatch, t.values) - self.write(")") - - def _Attribute(self,t): - self.dispatch(t.value) - # Special case: 3.__abs__() is a syntax error, so if t.value - # is an integer literal then we need to either parenthesize - # it or add an extra space to get 3 .__abs__(). - if isinstance(t.value, ast.Constant) and isinstance(t.value.value, int): - self.write(" ") - self.write(".") - self.write(t.attr) - - def _Call(self, t): - self.dispatch(t.func) - self.write("(") - comma = False - for e in t.args: - if comma: self.write(", ") - else: comma = True - self.dispatch(e) - for e in t.keywords: - if comma: self.write(", ") - else: comma = True - self.dispatch(e) - self.write(")") - - def _Subscript(self, t): - self.dispatch(t.value) - self.write("[") - self.dispatch(t.slice) - self.write("]") - - def _Starred(self, t): - self.write("*") - self.dispatch(t.value) - - # slice - def _Ellipsis(self, t): - self.write("...") - - def _Index(self, t): - self.dispatch(t.value) - - def _Slice(self, t): - if t.lower: - self.dispatch(t.lower) - self.write(":") - if t.upper: - self.dispatch(t.upper) - if t.step: - self.write(":") - self.dispatch(t.step) - - def _ExtSlice(self, t): - interleave(lambda: self.write(', '), self.dispatch, t.dims) - - # argument - def _arg(self, t): - self.write(t.arg) - if t.annotation: - self.write(": ") - self.dispatch(t.annotation) - - # others - def _arguments(self, t): - first = True - # normal arguments - all_args = t.posonlyargs + t.args - defaults = [None] * (len(all_args) - len(t.defaults)) + t.defaults - for index, elements in enumerate(zip(all_args, defaults), 1): - a, d = elements - if first:first = False - else: self.write(", ") - self.dispatch(a) - if d: - self.write("=") - self.dispatch(d) - if index == len(t.posonlyargs): - self.write(", /") - - # varargs, or bare '*' if no varargs but keyword-only arguments present - if t.vararg or t.kwonlyargs: - if first:first = False - else: self.write(", ") - self.write("*") - if t.vararg: - self.write(t.vararg.arg) - if t.vararg.annotation: - self.write(": ") - self.dispatch(t.vararg.annotation) - - # keyword-only arguments - if t.kwonlyargs: - for a, d in zip(t.kwonlyargs, t.kw_defaults): - if first:first = False - else: self.write(", ") - self.dispatch(a), - if d: - self.write("=") - self.dispatch(d) - - # kwargs - if t.kwarg: - if first:first = False - else: self.write(", ") - self.write("**"+t.kwarg.arg) - if t.kwarg.annotation: - self.write(": ") - self.dispatch(t.kwarg.annotation) - - def _keyword(self, t): - if t.arg is None: - self.write("**") - else: - self.write(t.arg) - self.write("=") - self.dispatch(t.value) - - def _Lambda(self, t): - self.write("(") - self.write("lambda ") - self.dispatch(t.args) - self.write(": ") - self.dispatch(t.body) - self.write(")") - - def _alias(self, t): - self.write(t.name) - if t.asname: - self.write(" as "+t.asname) - - def _withitem(self, t): - self.dispatch(t.context_expr) - if t.optional_vars: - self.write(" as ") - self.dispatch(t.optional_vars) - -def roundtrip(filename, output=sys.stdout): - with open(filename, "rb") as pyfile: - encoding = tokenize.detect_encoding(pyfile.readline)[0] - with open(filename, "r", encoding=encoding) as pyfile: - source = pyfile.read() - tree = compile(source, filename, "exec", ast.PyCF_ONLY_AST) - Unparser(tree, output) - - - -def testdir(a): - try: - names = [n for n in os.listdir(a) if n.endswith('.py')] - except OSError: - print("Directory not readable: %s" % a, file=sys.stderr) - else: - for n in names: - fullname = os.path.join(a, n) - if os.path.isfile(fullname): - output = io.StringIO() - print('Testing %s' % fullname) - try: - roundtrip(fullname, output) - except Exception as e: - print(' Failed to compile, exception is %s' % repr(e)) - elif os.path.isdir(fullname): - testdir(fullname) - -def main(args): - if args[0] == '--testdir': - for a in args[1:]: - testdir(a) - else: - for a in args: - roundtrip(a) - -if __name__=='__main__': - main(sys.argv[1:])