diff --git a/Lib/ast.py b/Lib/ast.py index ee3f74358ee..76e0cac838b 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -26,6 +26,7 @@ """ import sys from _ast import * +from contextlib import contextmanager, nullcontext def parse(source, filename='', mode='exec', *, @@ -613,6 +614,21 @@ class _Unparser(NodeVisitor): def block(self): return self._Block(self) + @contextmanager + def delimit(self, start, end): + """A context manager for preparing the source for expressions. It adds + *start* to the buffer and enters, after exit it adds *end*.""" + + self.write(start) + yield + self.write(end) + + def delimit_if(self, start, end, condition): + if condition: + return self.delimit(start, end) + else: + return nullcontext() + def traverse(self, node): if isinstance(node, list): for item in node: @@ -636,11 +652,10 @@ class _Unparser(NodeVisitor): self.traverse(node.value) def visit_NamedExpr(self, node): - self.write("(") - self.traverse(node.target) - self.write(" := ") - self.traverse(node.value) - self.write(")") + with self.delimit("(", ")"): + self.traverse(node.target) + self.write(" := ") + self.traverse(node.value) def visit_Import(self, node): self.fill("import ") @@ -669,11 +684,8 @@ class _Unparser(NodeVisitor): def visit_AnnAssign(self, node): self.fill() - if not node.simple and isinstance(node.target, Name): - self.write("(") - self.traverse(node.target) - if not node.simple and isinstance(node.target, Name): - self.write(")") + with self.delimit_if("(", ")", not node.simple and isinstance(node.target, Name)): + self.traverse(node.target) self.write(": ") self.traverse(node.annotation) if node.value: @@ -715,28 +727,25 @@ class _Unparser(NodeVisitor): self.interleave(lambda: self.write(", "), self.write, node.names) def visit_Await(self, node): - self.write("(") - self.write("await") - if node.value: - self.write(" ") - self.traverse(node.value) - self.write(")") + with self.delimit("(", ")"): + self.write("await") + if node.value: + self.write(" ") + self.traverse(node.value) def visit_Yield(self, node): - self.write("(") - self.write("yield") - if node.value: - self.write(" ") - self.traverse(node.value) - self.write(")") + with self.delimit("(", ")"): + self.write("yield") + if node.value: + self.write(" ") + self.traverse(node.value) def visit_YieldFrom(self, node): - self.write("(") - self.write("yield from") - if node.value: - self.write(" ") - self.traverse(node.value) - self.write(")") + with self.delimit("(", ")"): + self.write("yield from") + if node.value: + self.write(" ") + self.traverse(node.value) def visit_Raise(self, node): self.fill("raise") @@ -782,21 +791,20 @@ class _Unparser(NodeVisitor): self.fill("@") self.traverse(deco) self.fill("class " + node.name) - self.write("(") - comma = False - for e in node.bases: - if comma: - self.write(", ") - else: - comma = True - self.traverse(e) - for e in node.keywords: - if comma: - self.write(", ") - else: - comma = True - self.traverse(e) - self.write(")") + with self.delimit("(", ")"): + comma = False + for e in node.bases: + if comma: + self.write(", ") + else: + comma = True + self.traverse(e) + for e in node.keywords: + if comma: + self.write(", ") + else: + comma = True + self.traverse(e) with self.block(): self.traverse(node.body) @@ -812,10 +820,10 @@ class _Unparser(NodeVisitor): for deco in node.decorator_list: self.fill("@") self.traverse(deco) - def_str = fill_suffix + " " + node.name + "(" + def_str = fill_suffix + " " + node.name self.fill(def_str) - self.traverse(node.args) - self.write(")") + with self.delimit("(", ")"): + self.traverse(node.args) if node.returns: self.write(" -> ") self.traverse(node.returns) @@ -931,13 +939,12 @@ class _Unparser(NodeVisitor): def visit_Constant(self, node): value = node.value if isinstance(value, tuple): - self.write("(") - if len(value) == 1: - self._write_constant(value[0]) - self.write(",") - else: - self.interleave(lambda: self.write(", "), self._write_constant, value) - self.write(")") + with self.delimit("(", ")"): + if len(value) == 1: + self._write_constant(value[0]) + self.write(",") + else: + self.interleave(lambda: self.write(", "), self._write_constant, value) elif value is ...: self.write("...") else: @@ -946,39 +953,34 @@ class _Unparser(NodeVisitor): self._write_constant(node.value) def visit_List(self, node): - self.write("[") - self.interleave(lambda: self.write(", "), self.traverse, node.elts) - self.write("]") + with self.delimit("[", "]"): + self.interleave(lambda: self.write(", "), self.traverse, node.elts) def visit_ListComp(self, node): - self.write("[") - self.traverse(node.elt) - for gen in node.generators: - self.traverse(gen) - self.write("]") + with self.delimit("[", "]"): + self.traverse(node.elt) + for gen in node.generators: + self.traverse(gen) def visit_GeneratorExp(self, node): - self.write("(") - self.traverse(node.elt) - for gen in node.generators: - self.traverse(gen) - self.write(")") + with self.delimit("(", ")"): + self.traverse(node.elt) + for gen in node.generators: + self.traverse(gen) def visit_SetComp(self, node): - self.write("{") - self.traverse(node.elt) - for gen in node.generators: - self.traverse(gen) - self.write("}") + with self.delimit("{", "}"): + self.traverse(node.elt) + for gen in node.generators: + self.traverse(gen) def visit_DictComp(self, node): - self.write("{") - self.traverse(node.key) - self.write(": ") - self.traverse(node.value) - for gen in node.generators: - self.traverse(gen) - self.write("}") + with self.delimit("{", "}"): + self.traverse(node.key) + self.write(": ") + self.traverse(node.value) + for gen in node.generators: + self.traverse(gen) def visit_comprehension(self, node): if node.is_async: @@ -993,24 +995,20 @@ class _Unparser(NodeVisitor): self.traverse(if_clause) def visit_IfExp(self, node): - self.write("(") - self.traverse(node.body) - self.write(" if ") - self.traverse(node.test) - self.write(" else ") - self.traverse(node.orelse) - self.write(")") + with self.delimit("(", ")"): + self.traverse(node.body) + self.write(" if ") + self.traverse(node.test) + self.write(" else ") + self.traverse(node.orelse) def visit_Set(self, node): if not node.elts: raise ValueError("Set node should has at least one item") - self.write("{") - self.interleave(lambda: self.write(", "), self.traverse, node.elts) - self.write("}") + with self.delimit("{", "}"): + self.interleave(lambda: self.write(", "), self.traverse, node.elts) def visit_Dict(self, node): - self.write("{") - def write_key_value_pair(k, v): self.traverse(k) self.write(": ") @@ -1026,29 +1024,27 @@ class _Unparser(NodeVisitor): else: write_key_value_pair(k, v) - self.interleave( - lambda: self.write(", "), write_item, zip(node.keys, node.values) - ) - self.write("}") + with self.delimit("{", "}"): + self.interleave( + lambda: self.write(", "), write_item, zip(node.keys, node.values) + ) def visit_Tuple(self, node): - self.write("(") - if len(node.elts) == 1: - elt = node.elts[0] - self.traverse(elt) - self.write(",") - else: - self.interleave(lambda: self.write(", "), self.traverse, node.elts) - self.write(")") + with self.delimit("(", ")"): + if len(node.elts) == 1: + elt = node.elts[0] + self.traverse(elt) + self.write(",") + else: + self.interleave(lambda: self.write(", "), self.traverse, node.elts) unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"} def visit_UnaryOp(self, node): - self.write("(") - self.write(self.unop[node.op.__class__.__name__]) - self.write(" ") - self.traverse(node.operand) - self.write(")") + with self.delimit("(", ")"): + self.write(self.unop[node.op.__class__.__name__]) + self.write(" ") + self.traverse(node.operand) binop = { "Add": "+", @@ -1067,11 +1063,10 @@ class _Unparser(NodeVisitor): } def visit_BinOp(self, node): - self.write("(") - self.traverse(node.left) - self.write(" " + self.binop[node.op.__class__.__name__] + " ") - self.traverse(node.right) - self.write(")") + with self.delimit("(", ")"): + self.traverse(node.left) + self.write(" " + self.binop[node.op.__class__.__name__] + " ") + self.traverse(node.right) cmpops = { "Eq": "==", @@ -1087,20 +1082,18 @@ class _Unparser(NodeVisitor): } def visit_Compare(self, node): - self.write("(") - self.traverse(node.left) - for o, e in zip(node.ops, node.comparators): - self.write(" " + self.cmpops[o.__class__.__name__] + " ") - self.traverse(e) - self.write(")") + with self.delimit("(", ")"): + self.traverse(node.left) + for o, e in zip(node.ops, node.comparators): + self.write(" " + self.cmpops[o.__class__.__name__] + " ") + self.traverse(e) - boolops = {And: "and", Or: "or"} + boolops = {"And": "and", "Or": "or"} def visit_BoolOp(self, node): - self.write("(") - s = " %s " % self.boolops[node.op.__class__] - self.interleave(lambda: self.write(s), self.traverse, node.values) - self.write(")") + with self.delimit("(", ")"): + s = " %s " % self.boolops[node.op.__class__.__name__] + self.interleave(lambda: self.write(s), self.traverse, node.values) def visit_Attribute(self, node): self.traverse(node.value) @@ -1114,27 +1107,25 @@ class _Unparser(NodeVisitor): def visit_Call(self, node): self.traverse(node.func) - self.write("(") - comma = False - for e in node.args: - if comma: - self.write(", ") - else: - comma = True - self.traverse(e) - for e in node.keywords: - if comma: - self.write(", ") - else: - comma = True - self.traverse(e) - self.write(")") + with self.delimit("(", ")"): + comma = False + for e in node.args: + if comma: + self.write(", ") + else: + comma = True + self.traverse(e) + for e in node.keywords: + if comma: + self.write(", ") + else: + comma = True + self.traverse(e) def visit_Subscript(self, node): self.traverse(node.value) - self.write("[") - self.traverse(node.slice) - self.write("]") + with self.delimit("[", "]"): + self.traverse(node.slice) def visit_Starred(self, node): self.write("*") @@ -1225,12 +1216,11 @@ class _Unparser(NodeVisitor): self.traverse(node.value) def visit_Lambda(self, node): - self.write("(") - self.write("lambda ") - self.traverse(node.args) - self.write(": ") - self.traverse(node.body) - self.write(")") + with self.delimit("(", ")"): + self.write("lambda ") + self.traverse(node.args) + self.write(": ") + self.traverse(node.body) def visit_alias(self, node): self.write(node.name)