mirror of https://github.com/python/cpython
bpo-38870: Implement a precedence algorithm in ast.unparse (GH-17377)
Implement a simple precedence algorithm for ast.unparse to avoid emitting redundant parentheses around nested expressions in the final output.
This commit is contained in:
parent
185903de12
commit
397b96f6d7
138
Lib/ast.py
138
Lib/ast.py
|
@ -27,6 +27,7 @@
|
|||
import sys
|
||||
from _ast import *
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from enum import IntEnum, auto
|
||||
|
||||
|
||||
def parse(source, filename='<unknown>', mode='exec', *,
|
||||
|
@ -560,6 +561,35 @@ _const_node_type_names = {
|
|||
# We unparse those infinities to INFSTR.
|
||||
_INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)
|
||||
|
||||
class _Precedence(IntEnum):
    """Precedence table that originated from the Python grammar.

    Members are listed from the loosest-binding construct to the
    tightest-binding one.  Because this is an ``IntEnum``, levels can be
    compared directly with ``<`` / ``>``.
    """

    TUPLE = auto()
    YIELD = auto()           # 'yield', 'yield from'
    TEST = auto()            # 'if'-'else', 'lambda'
    OR = auto()              # 'or'
    AND = auto()             # 'and'
    NOT = auto()             # 'not'
    CMP = auto()             # '<', '>', '==', '>=', '<=', '!=',
                             # 'in', 'not in', 'is', 'is not'
    EXPR = auto()
    BOR = EXPR               # '|'  (alias: same level as EXPR)
    BXOR = auto()            # '^'
    BAND = auto()            # '&'
    SHIFT = auto()           # '<<', '>>'
    ARITH = auto()           # '+', '-'
    TERM = auto()            # '*', '@', '/', '%', '//'
    FACTOR = auto()          # unary '+', '-', '~'
    POWER = auto()           # '**'
    AWAIT = auto()           # 'await'
    ATOM = auto()

    def next(self):
        """Return the level immediately above this one.

        The highest level has no successor; in that case the enum lookup
        raises ``ValueError`` and the level itself is returned instead.
        """
        try:
            return self.__class__(self + 1)
        except ValueError:
            return self
|
||||
|
||||
class _Unparser(NodeVisitor):
|
||||
"""Methods in this class recursively traverse an AST and
|
||||
output source code for the abstract syntax; original formatting
|
||||
|
@ -568,6 +598,7 @@ class _Unparser(NodeVisitor):
|
|||
def __init__(self):
|
||||
self._source = []
|
||||
self._buffer = []
|
||||
self._precedences = {}
|
||||
self._indent = 0
|
||||
|
||||
def interleave(self, inter, f, seq):
|
||||
|
@ -625,6 +656,17 @@ class _Unparser(NodeVisitor):
|
|||
else:
|
||||
return nullcontext()
|
||||
|
||||
def require_parens(self, precedence, node):
|
||||
"""Shortcut to adding precedence related parens"""
|
||||
return self.delimit_if("(", ")", self.get_precedence(node) > precedence)
|
||||
|
||||
def get_precedence(self, node):
|
||||
return self._precedences.get(node, _Precedence.TEST)
|
||||
|
||||
def set_precedence(self, precedence, *nodes):
|
||||
for node in nodes:
|
||||
self._precedences[node] = precedence
|
||||
|
||||
def traverse(self, node):
|
||||
if isinstance(node, list):
|
||||
for item in node:
|
||||
|
@ -645,10 +687,12 @@ class _Unparser(NodeVisitor):
|
|||
|
||||
def visit_Expr(self, node):
|
||||
self.fill()
|
||||
self.set_precedence(_Precedence.YIELD, node.value)
|
||||
self.traverse(node.value)
|
||||
|
||||
def visit_NamedExpr(self, node):
|
||||
with self.delimit("(", ")"):
|
||||
with self.require_parens(_Precedence.TUPLE, node):
|
||||
self.set_precedence(_Precedence.ATOM, node.target, node.value)
|
||||
self.traverse(node.target)
|
||||
self.write(" := ")
|
||||
self.traverse(node.value)
|
||||
|
@ -723,24 +767,27 @@ class _Unparser(NodeVisitor):
|
|||
self.interleave(lambda: self.write(", "), self.write, node.names)
|
||||
|
||||
def visit_Await(self, node):
|
||||
with self.delimit("(", ")"):
|
||||
with self.require_parens(_Precedence.AWAIT, node):
|
||||
self.write("await")
|
||||
if node.value:
|
||||
self.write(" ")
|
||||
self.set_precedence(_Precedence.ATOM, node.value)
|
||||
self.traverse(node.value)
|
||||
|
||||
def visit_Yield(self, node):
|
||||
with self.delimit("(", ")"):
|
||||
with self.require_parens(_Precedence.YIELD, node):
|
||||
self.write("yield")
|
||||
if node.value:
|
||||
self.write(" ")
|
||||
self.set_precedence(_Precedence.ATOM, node.value)
|
||||
self.traverse(node.value)
|
||||
|
||||
def visit_YieldFrom(self, node):
|
||||
with self.delimit("(", ")"):
|
||||
with self.require_parens(_Precedence.YIELD, node):
|
||||
self.write("yield from ")
|
||||
if not node.value:
|
||||
raise ValueError("Node can't be used without a value attribute.")
|
||||
self.set_precedence(_Precedence.ATOM, node.value)
|
||||
self.traverse(node.value)
|
||||
|
||||
def visit_Raise(self, node):
|
||||
|
@ -907,7 +954,9 @@ class _Unparser(NodeVisitor):
|
|||
|
||||
def _fstring_FormattedValue(self, node, write):
|
||||
write("{")
|
||||
expr = type(self)().visit(node.value).rstrip("\n")
|
||||
unparser = type(self)()
|
||||
unparser.set_precedence(_Precedence.TEST.next(), node.value)
|
||||
expr = unparser.visit(node.value).rstrip("\n")
|
||||
if expr.startswith("{"):
|
||||
write(" ") # Separate pair of opening brackets as "{ {"
|
||||
write(expr)
|
||||
|
@ -983,19 +1032,23 @@ class _Unparser(NodeVisitor):
|
|||
self.write(" async for ")
|
||||
else:
|
||||
self.write(" for ")
|
||||
self.set_precedence(_Precedence.TUPLE, node.target)
|
||||
self.traverse(node.target)
|
||||
self.write(" in ")
|
||||
self.set_precedence(_Precedence.TEST.next(), node.iter, *node.ifs)
|
||||
self.traverse(node.iter)
|
||||
for if_clause in node.ifs:
|
||||
self.write(" if ")
|
||||
self.traverse(if_clause)
|
||||
|
||||
def visit_IfExp(self, node):
|
||||
with self.delimit("(", ")"):
|
||||
with self.require_parens(_Precedence.TEST, node):
|
||||
self.set_precedence(_Precedence.TEST.next(), node.body, node.test)
|
||||
self.traverse(node.body)
|
||||
self.write(" if ")
|
||||
self.traverse(node.test)
|
||||
self.write(" else ")
|
||||
self.set_precedence(_Precedence.TEST, node.orelse)
|
||||
self.traverse(node.orelse)
|
||||
|
||||
def visit_Set(self, node):
|
||||
|
@ -1016,6 +1069,7 @@ class _Unparser(NodeVisitor):
|
|||
# for dictionary unpacking operator in dicts {**{'y': 2}}
|
||||
# see PEP 448 for details
|
||||
self.write("**")
|
||||
self.set_precedence(_Precedence.EXPR, v)
|
||||
self.traverse(v)
|
||||
else:
|
||||
write_key_value_pair(k, v)
|
||||
|
@ -1035,11 +1089,20 @@ class _Unparser(NodeVisitor):
|
|||
self.interleave(lambda: self.write(", "), self.traverse, node.elts)
|
||||
|
||||
unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
|
||||
unop_precedence = {
|
||||
"~": _Precedence.FACTOR,
|
||||
"not": _Precedence.NOT,
|
||||
"+": _Precedence.FACTOR,
|
||||
"-": _Precedence.FACTOR
|
||||
}
|
||||
|
||||
def visit_UnaryOp(self, node):
|
||||
with self.delimit("(", ")"):
|
||||
self.write(self.unop[node.op.__class__.__name__])
|
||||
operator = self.unop[node.op.__class__.__name__]
|
||||
operator_precedence = self.unop_precedence[operator]
|
||||
with self.require_parens(operator_precedence, node):
|
||||
self.write(operator)
|
||||
self.write(" ")
|
||||
self.set_precedence(operator_precedence, node.operand)
|
||||
self.traverse(node.operand)
|
||||
|
||||
binop = {
|
||||
|
@ -1058,10 +1121,38 @@ class _Unparser(NodeVisitor):
|
|||
"Pow": "**",
|
||||
}
|
||||
|
||||
binop_precedence = {
|
||||
"+": _Precedence.ARITH,
|
||||
"-": _Precedence.ARITH,
|
||||
"*": _Precedence.TERM,
|
||||
"@": _Precedence.TERM,
|
||||
"/": _Precedence.TERM,
|
||||
"%": _Precedence.TERM,
|
||||
"<<": _Precedence.SHIFT,
|
||||
">>": _Precedence.SHIFT,
|
||||
"|": _Precedence.BOR,
|
||||
"^": _Precedence.BXOR,
|
||||
"&": _Precedence.BAND,
|
||||
"//": _Precedence.TERM,
|
||||
"**": _Precedence.POWER,
|
||||
}
|
||||
|
||||
binop_rassoc = frozenset(("**",))
|
||||
def visit_BinOp(self, node):
|
||||
with self.delimit("(", ")"):
|
||||
operator = self.binop[node.op.__class__.__name__]
|
||||
operator_precedence = self.binop_precedence[operator]
|
||||
with self.require_parens(operator_precedence, node):
|
||||
if operator in self.binop_rassoc:
|
||||
left_precedence = operator_precedence.next()
|
||||
right_precedence = operator_precedence
|
||||
else:
|
||||
left_precedence = operator_precedence
|
||||
right_precedence = operator_precedence.next()
|
||||
|
||||
self.set_precedence(left_precedence, node.left)
|
||||
self.traverse(node.left)
|
||||
self.write(" " + self.binop[node.op.__class__.__name__] + " ")
|
||||
self.write(f" {operator} ")
|
||||
self.set_precedence(right_precedence, node.right)
|
||||
self.traverse(node.right)
|
||||
|
||||
cmpops = {
|
||||
|
@ -1078,20 +1169,32 @@ class _Unparser(NodeVisitor):
|
|||
}
|
||||
|
||||
def visit_Compare(self, node):
|
||||
with self.delimit("(", ")"):
|
||||
with self.require_parens(_Precedence.CMP, node):
|
||||
self.set_precedence(_Precedence.CMP.next(), node.left, *node.comparators)
|
||||
self.traverse(node.left)
|
||||
for o, e in zip(node.ops, node.comparators):
|
||||
self.write(" " + self.cmpops[o.__class__.__name__] + " ")
|
||||
self.traverse(e)
|
||||
|
||||
boolops = {"And": "and", "Or": "or"}
|
||||
boolop_precedence = {"and": _Precedence.AND, "or": _Precedence.OR}
|
||||
|
||||
def visit_BoolOp(self, node):
|
||||
with self.delimit("(", ")"):
|
||||
s = " %s " % self.boolops[node.op.__class__.__name__]
|
||||
self.interleave(lambda: self.write(s), self.traverse, node.values)
|
||||
operator = self.boolops[node.op.__class__.__name__]
|
||||
operator_precedence = self.boolop_precedence[operator]
|
||||
|
||||
def increasing_level_traverse(node):
|
||||
nonlocal operator_precedence
|
||||
operator_precedence = operator_precedence.next()
|
||||
self.set_precedence(operator_precedence, node)
|
||||
self.traverse(node)
|
||||
|
||||
with self.require_parens(operator_precedence, node):
|
||||
s = f" {operator} "
|
||||
self.interleave(lambda: self.write(s), increasing_level_traverse, node.values)
|
||||
|
||||
def visit_Attribute(self, node):
|
||||
self.set_precedence(_Precedence.ATOM, node.value)
|
||||
self.traverse(node.value)
|
||||
# Special case: 3.__abs__() is a syntax error, so if node.value
|
||||
# is an integer literal then we need to either parenthesize
|
||||
|
@ -1102,6 +1205,7 @@ class _Unparser(NodeVisitor):
|
|||
self.write(node.attr)
|
||||
|
||||
def visit_Call(self, node):
|
||||
self.set_precedence(_Precedence.ATOM, node.func)
|
||||
self.traverse(node.func)
|
||||
with self.delimit("(", ")"):
|
||||
comma = False
|
||||
|
@ -1119,18 +1223,21 @@ class _Unparser(NodeVisitor):
|
|||
self.traverse(e)
|
||||
|
||||
def visit_Subscript(self, node):
|
||||
self.set_precedence(_Precedence.ATOM, node.value)
|
||||
self.traverse(node.value)
|
||||
with self.delimit("[", "]"):
|
||||
self.traverse(node.slice)
|
||||
|
||||
def visit_Starred(self, node):
|
||||
self.write("*")
|
||||
self.set_precedence(_Precedence.EXPR, node.value)
|
||||
self.traverse(node.value)
|
||||
|
||||
def visit_Ellipsis(self, node):
|
||||
self.write("...")
|
||||
|
||||
def visit_Index(self, node):
|
||||
self.set_precedence(_Precedence.TUPLE, node.value)
|
||||
self.traverse(node.value)
|
||||
|
||||
def visit_Slice(self, node):
|
||||
|
@ -1212,10 +1319,11 @@ class _Unparser(NodeVisitor):
|
|||
self.traverse(node.value)
|
||||
|
||||
def visit_Lambda(self, node):
|
||||
with self.delimit("(", ")"):
|
||||
with self.require_parens(_Precedence.TEST, node):
|
||||
self.write("lambda ")
|
||||
self.traverse(node.args)
|
||||
self.write(": ")
|
||||
self.set_precedence(_Precedence.TEST, node.body)
|
||||
self.traverse(node.body)
|
||||
|
||||
def visit_alias(self, node):
|
||||
|
|
|
@ -247,6 +247,13 @@ eval_tests = [
|
|||
|
||||
class AST_Tests(unittest.TestCase):
|
||||
|
||||
def _is_ast_node(self, name, node):
|
||||
if not isinstance(node, type):
|
||||
return False
|
||||
if "ast" not in node.__module__:
|
||||
return False
|
||||
return name != 'AST' and name[0].isupper()
|
||||
|
||||
def _assertTrueorder(self, ast_node, parent_pos):
|
||||
if not isinstance(ast_node, ast.AST) or ast_node._fields is None:
|
||||
return
|
||||
|
@ -335,7 +342,7 @@ class AST_Tests(unittest.TestCase):
|
|||
|
||||
def test_field_attr_existence(self):
|
||||
for name, item in ast.__dict__.items():
|
||||
if isinstance(item, type) and name != 'AST' and name[0].isupper():
|
||||
if self._is_ast_node(name, item):
|
||||
x = item()
|
||||
if isinstance(x, ast.AST):
|
||||
self.assertEqual(type(x._fields), tuple)
|
||||
|
|
|
@ -125,6 +125,13 @@ class ASTTestCase(unittest.TestCase):
|
|||
def check_invalid(self, node, raises=ValueError):
|
||||
self.assertRaises(raises, ast.unparse, node)
|
||||
|
||||
def check_src_roundtrip(self, code1, code2=None, strip=True):
|
||||
code2 = code2 or code1
|
||||
code1 = ast.unparse(ast.parse(code1))
|
||||
if strip:
|
||||
code1 = code1.strip()
|
||||
self.assertEqual(code2, code1)
|
||||
|
||||
|
||||
class UnparseTestCase(ASTTestCase):
|
||||
# Tests for specific bugs found in earlier versions of unparse
|
||||
|
@ -281,6 +288,40 @@ class UnparseTestCase(ASTTestCase):
|
|||
def test_invalid_yield_from(self):
|
||||
self.check_invalid(ast.YieldFrom(value=None))
|
||||
|
||||
class CosmeticTestCase(ASTTestCase):
    """Test if there are cosmetic issues caused by unnecessary additions"""

    def test_simple_expressions_parens(self):
        # Each source string must survive parse -> unparse unchanged, i.e.
        # the unparser may add neither more nor fewer parentheses than
        # appear in the input.
        self.check_src_roundtrip("(a := b)")
        self.check_src_roundtrip("await x")
        self.check_src_roundtrip("x if x else y")
        self.check_src_roundtrip("lambda x: x")
        self.check_src_roundtrip("1 + 1")
        self.check_src_roundtrip("1 + 2 / 3")
        self.check_src_roundtrip("(1 + 2) / 3")
        self.check_src_roundtrip("(1 + 2) * 3 + 4 * (5 + 2)")
        self.check_src_roundtrip("(1 + 2) * 3 + 4 * (5 + 2) ** 2")
        self.check_src_roundtrip("~ x")
        self.check_src_roundtrip("x and y")
        self.check_src_roundtrip("x and y and z")
        self.check_src_roundtrip("x and (y and x)")
        self.check_src_roundtrip("(x and y) and z")
        self.check_src_roundtrip("(x ** y) ** z ** q")
        self.check_src_roundtrip("x >> y")
        self.check_src_roundtrip("x << y")
        self.check_src_roundtrip("x >> y and x >> z")
        self.check_src_roundtrip("x + y - z * q ^ t ** k")
        self.check_src_roundtrip("P * V if P and V else n * R * T")
        self.check_src_roundtrip("lambda P, V, n: P * V == n * R * T")
        self.check_src_roundtrip("flag & (other | foo)")
        self.check_src_roundtrip("not x == y")
        self.check_src_roundtrip("x == (not y)")
        self.check_src_roundtrip("yield x")
        self.check_src_roundtrip("yield from x")
        self.check_src_roundtrip("call((yield x))")
        self.check_src_roundtrip("return x + (yield x)")
|
||||
|
||||
|
||||
class DirectoryTestCase(ASTTestCase):
|
||||
"""Test roundtrip behaviour on all files in Lib and Lib/test."""
|
||||
|
||||
|
|
Loading…
Reference in New Issue