bpo-38870: Simplify sequence interleaves in ast.unparse (GH-17892)

This commit is contained in:
Batuhan Taşkaya 2020-03-09 23:27:03 +03:00 committed by GitHub
parent 111e4ee52a
commit e7cab7f780
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 37 deletions

View File

@ -613,6 +613,16 @@ class _Unparser(NodeVisitor):
inter() inter()
f(x) f(x)
def items_view(self, traverser, items):
"""Traverse and separate the given *items* with a comma and append it to
the buffer. If *items* is a single item sequence, a trailing comma
will be added."""
if len(items) == 1:
traverser(items[0])
self.write(",")
else:
self.interleave(lambda: self.write(", "), traverser, items)
def fill(self, text=""): def fill(self, text=""):
"""Indent a piece of text and append it, according to the current """Indent a piece of text and append it, according to the current
indentation level""" indentation level"""
@ -1020,11 +1030,7 @@ class _Unparser(NodeVisitor):
value = node.value value = node.value
if isinstance(value, tuple): if isinstance(value, tuple):
with self.delimit("(", ")"): with self.delimit("(", ")"):
if len(value) == 1: self.items_view(self._write_constant, value)
self._write_constant(value[0])
self.write(",")
else:
self.interleave(lambda: self.write(", "), self._write_constant, value)
elif value is ...: elif value is ...:
self.write("...") self.write("...")
else: else:
@ -1116,12 +1122,7 @@ class _Unparser(NodeVisitor):
def visit_Tuple(self, node): def visit_Tuple(self, node):
with self.delimit("(", ")"): with self.delimit("(", ")"):
if len(node.elts) == 1: self.items_view(self.traverse, node.elts)
elt = node.elts[0]
self.traverse(elt)
self.write(",")
else:
self.interleave(lambda: self.write(", "), self.traverse, node.elts)
unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"} unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
unop_precedence = { unop_precedence = {
@ -1264,12 +1265,7 @@ class _Unparser(NodeVisitor):
if (isinstance(node.slice, Index) if (isinstance(node.slice, Index)
and isinstance(node.slice.value, Tuple) and isinstance(node.slice.value, Tuple)
and node.slice.value.elts): and node.slice.value.elts):
if len(node.slice.value.elts) == 1: self.items_view(self.traverse, node.slice.value.elts)
elt = node.slice.value.elts[0]
self.traverse(elt)
self.write(",")
else:
self.interleave(lambda: self.write(", "), self.traverse, node.slice.value.elts)
else: else:
self.traverse(node.slice) self.traverse(node.slice)
@ -1296,12 +1292,7 @@ class _Unparser(NodeVisitor):
self.traverse(node.step) self.traverse(node.step)
def visit_ExtSlice(self, node): def visit_ExtSlice(self, node):
if len(node.dims) == 1: self.items_view(self.traverse, node.dims)
elt = node.dims[0]
self.traverse(elt)
self.write(",")
else:
self.interleave(lambda: self.write(", "), self.traverse, node.dims)
def visit_arg(self, node): def visit_arg(self, node):
self.write(node.arg) self.write(node.arg)

View File

@ -280,6 +280,20 @@ class UnparseTestCase(ASTTestCase):
self.check_ast_roundtrip(r"""{**{'y': 2}, 'x': 1}""") self.check_ast_roundtrip(r"""{**{'y': 2}, 'x': 1}""")
self.check_ast_roundtrip(r"""{**{'y': 2}, **{'x': 1}}""") self.check_ast_roundtrip(r"""{**{'y': 2}, **{'x': 1}}""")
def test_ext_slices(self):
self.check_ast_roundtrip("a[i]")
self.check_ast_roundtrip("a[i,]")
self.check_ast_roundtrip("a[i, j]")
self.check_ast_roundtrip("a[()]")
self.check_ast_roundtrip("a[i:j]")
self.check_ast_roundtrip("a[:j]")
self.check_ast_roundtrip("a[i:]")
self.check_ast_roundtrip("a[i:j:k]")
self.check_ast_roundtrip("a[:j:k]")
self.check_ast_roundtrip("a[i::k]")
self.check_ast_roundtrip("a[i:j,]")
self.check_ast_roundtrip("a[i:j, k]")
def test_invalid_raise(self): def test_invalid_raise(self):
self.check_invalid(ast.Raise(exc=None, cause=ast.Name(id="X"))) self.check_invalid(ast.Raise(exc=None, cause=ast.Name(id="X")))
@ -310,6 +324,12 @@ class UnparseTestCase(ASTTestCase):
# check as Module docstrings for easy testing # check as Module docstrings for easy testing
self.check_ast_roundtrip(f"'{docstring}'") self.check_ast_roundtrip(f"'{docstring}'")
def test_constant_tuples(self):
self.check_src_roundtrip(ast.Constant(value=(1,), kind=None), "(1,)")
self.check_src_roundtrip(
ast.Constant(value=(1, 2, 3), kind=None), "(1, 2, 3)"
)
class CosmeticTestCase(ASTTestCase): class CosmeticTestCase(ASTTestCase):
"""Test if there are cosmetic issues caused by unnecesary additions""" """Test if there are cosmetic issues caused by unnecesary additions"""
@ -344,20 +364,6 @@ class CosmeticTestCase(ASTTestCase):
self.check_src_roundtrip("call((yield x))") self.check_src_roundtrip("call((yield x))")
self.check_src_roundtrip("return x + (yield x)") self.check_src_roundtrip("return x + (yield x)")
def test_subscript(self):
self.check_src_roundtrip("a[i]")
self.check_src_roundtrip("a[i,]")
self.check_src_roundtrip("a[i, j]")
self.check_src_roundtrip("a[()]")
self.check_src_roundtrip("a[i:j]")
self.check_src_roundtrip("a[:j]")
self.check_src_roundtrip("a[i:]")
self.check_src_roundtrip("a[i:j:k]")
self.check_src_roundtrip("a[:j:k]")
self.check_src_roundtrip("a[i::k]")
self.check_src_roundtrip("a[i:j,]")
self.check_src_roundtrip("a[i:j, k]")
def test_docstrings(self): def test_docstrings(self):
docstrings = ( docstrings = (
'"""simple doc string"""', '"""simple doc string"""',