bpo-38870: Add docstring support to ast.unparse (GH-17760)
Allow ast.unparse to detect docstrings in functions, modules and classes and produce nicely formatted unparsed output for said docstrings. Co-Authored-By: Pablo Galindo <Pablogsal@gmail.com>
This commit is contained in:
parent
66b7973c1b
commit
89aa4694fc
55
Lib/ast.py
55
Lib/ast.py
|
@ -667,6 +667,22 @@ class _Unparser(NodeVisitor):
|
|||
for node in nodes:
|
||||
self._precedences[node] = precedence
|
||||
|
||||
def get_raw_docstring(self, node):
|
||||
"""If a docstring node is found in the body of the *node* parameter,
|
||||
return that docstring node, None otherwise.
|
||||
|
||||
Logic mirrored from ``_PyAST_GetDocString``."""
|
||||
if not isinstance(
|
||||
node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)
|
||||
) or len(node.body) < 1:
|
||||
return None
|
||||
node = node.body[0]
|
||||
if not isinstance(node, Expr):
|
||||
return None
|
||||
node = node.value
|
||||
if isinstance(node, Constant) and isinstance(node.value, str):
|
||||
return node
|
||||
|
||||
def traverse(self, node):
|
||||
if isinstance(node, list):
|
||||
for item in node:
|
||||
|
@ -681,9 +697,15 @@ class _Unparser(NodeVisitor):
|
|||
self.traverse(node)
|
||||
return "".join(self._source)
|
||||
|
||||
def _write_docstring_and_traverse_body(self, node):
|
||||
if (docstring := self.get_raw_docstring(node)):
|
||||
self._write_docstring(docstring)
|
||||
self.traverse(node.body[1:])
|
||||
else:
|
||||
self.traverse(node.body)
|
||||
|
||||
def visit_Module(self, node):
|
||||
for subnode in node.body:
|
||||
self.traverse(subnode)
|
||||
self._write_docstring_and_traverse_body(node)
|
||||
|
||||
def visit_Expr(self, node):
|
||||
self.fill()
|
||||
|
@ -850,15 +872,15 @@ class _Unparser(NodeVisitor):
|
|||
self.traverse(e)
|
||||
|
||||
with self.block():
|
||||
self.traverse(node.body)
|
||||
self._write_docstring_and_traverse_body(node)
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
self.__FunctionDef_helper(node, "def")
|
||||
self._function_helper(node, "def")
|
||||
|
||||
def visit_AsyncFunctionDef(self, node):
|
||||
self.__FunctionDef_helper(node, "async def")
|
||||
self._function_helper(node, "async def")
|
||||
|
||||
def __FunctionDef_helper(self, node, fill_suffix):
|
||||
def _function_helper(self, node, fill_suffix):
|
||||
self.write("\n")
|
||||
for deco in node.decorator_list:
|
||||
self.fill("@")
|
||||
|
@ -871,15 +893,15 @@ class _Unparser(NodeVisitor):
|
|||
self.write(" -> ")
|
||||
self.traverse(node.returns)
|
||||
with self.block():
|
||||
self.traverse(node.body)
|
||||
self._write_docstring_and_traverse_body(node)
|
||||
|
||||
def visit_For(self, node):
|
||||
self.__For_helper("for ", node)
|
||||
self._for_helper("for ", node)
|
||||
|
||||
def visit_AsyncFor(self, node):
|
||||
self.__For_helper("async for ", node)
|
||||
self._for_helper("async for ", node)
|
||||
|
||||
def __For_helper(self, fill, node):
|
||||
def _for_helper(self, fill, node):
|
||||
self.fill(fill)
|
||||
self.traverse(node.target)
|
||||
self.write(" in ")
|
||||
|
@ -974,6 +996,19 @@ class _Unparser(NodeVisitor):
|
|||
def visit_Name(self, node):
|
||||
self.write(node.id)
|
||||
|
||||
def _write_docstring(self, node):
|
||||
self.fill()
|
||||
if node.kind == "u":
|
||||
self.write("u")
|
||||
|
||||
# Preserve quotes in the docstring by escaping them
|
||||
value = node.value.replace("\\", "\\\\")
|
||||
value = value.replace('"""', '""\"')
|
||||
if value[-1] == '"':
|
||||
value = value.replace('"', '\\"', -1)
|
||||
|
||||
self.write(f'"""{value}"""')
|
||||
|
||||
def _write_constant(self, value):
|
||||
if isinstance(value, (float, complex)):
|
||||
# Substitute overflowing decimal literal for AST infinities.
|
||||
|
|
|
@ -111,12 +111,18 @@ with f() as x, g() as y:
|
|||
suite1
|
||||
"""
|
||||
|
||||
docstring_prefixes = [
|
||||
"",
|
||||
"class foo():\n ",
|
||||
"def foo():\n ",
|
||||
"async def foo():\n ",
|
||||
]
|
||||
|
||||
class ASTTestCase(unittest.TestCase):
|
||||
def assertASTEqual(self, ast1, ast2):
|
||||
self.assertEqual(ast.dump(ast1), ast.dump(ast2))
|
||||
|
||||
def check_roundtrip(self, code1):
|
||||
def check_ast_roundtrip(self, code1):
|
||||
ast1 = ast.parse(code1)
|
||||
code2 = ast.unparse(ast1)
|
||||
ast2 = ast.parse(code2)
|
||||
|
@ -125,147 +131,154 @@ class ASTTestCase(unittest.TestCase):
|
|||
def check_invalid(self, node, raises=ValueError):
|
||||
self.assertRaises(raises, ast.unparse, node)
|
||||
|
||||
def check_src_roundtrip(self, code1, code2=None, strip=True):
|
||||
def get_source(self, code1, code2=None, strip=True):
|
||||
code2 = code2 or code1
|
||||
code1 = ast.unparse(ast.parse(code1))
|
||||
if strip:
|
||||
code1 = code1.strip()
|
||||
return code1, code2
|
||||
|
||||
def check_src_roundtrip(self, code1, code2=None, strip=True):
|
||||
code1, code2 = self.get_source(code1, code2, strip)
|
||||
self.assertEqual(code2, code1)
|
||||
|
||||
def check_src_dont_roundtrip(self, code1, code2=None, strip=True):
|
||||
code1, code2 = self.get_source(code1, code2, strip)
|
||||
self.assertNotEqual(code2, code1)
|
||||
|
||||
class UnparseTestCase(ASTTestCase):
|
||||
# Tests for specific bugs found in earlier versions of unparse
|
||||
|
||||
def test_fstrings(self):
|
||||
# See issue 25180
|
||||
self.check_roundtrip(r"""f'{f"{0}"*3}'""")
|
||||
self.check_roundtrip(r"""f'{f"{y}"*3}'""")
|
||||
self.check_ast_roundtrip(r"""f'{f"{0}"*3}'""")
|
||||
self.check_ast_roundtrip(r"""f'{f"{y}"*3}'""")
|
||||
|
||||
def test_strings(self):
|
||||
self.check_roundtrip("u'foo'")
|
||||
self.check_roundtrip("r'foo'")
|
||||
self.check_roundtrip("b'foo'")
|
||||
self.check_ast_roundtrip("u'foo'")
|
||||
self.check_ast_roundtrip("r'foo'")
|
||||
self.check_ast_roundtrip("b'foo'")
|
||||
|
||||
def test_del_statement(self):
|
||||
self.check_roundtrip("del x, y, z")
|
||||
self.check_ast_roundtrip("del x, y, z")
|
||||
|
||||
def test_shifts(self):
|
||||
self.check_roundtrip("45 << 2")
|
||||
self.check_roundtrip("13 >> 7")
|
||||
self.check_ast_roundtrip("45 << 2")
|
||||
self.check_ast_roundtrip("13 >> 7")
|
||||
|
||||
def test_for_else(self):
|
||||
self.check_roundtrip(for_else)
|
||||
self.check_ast_roundtrip(for_else)
|
||||
|
||||
def test_while_else(self):
|
||||
self.check_roundtrip(while_else)
|
||||
self.check_ast_roundtrip(while_else)
|
||||
|
||||
def test_unary_parens(self):
|
||||
self.check_roundtrip("(-1)**7")
|
||||
self.check_roundtrip("(-1.)**8")
|
||||
self.check_roundtrip("(-1j)**6")
|
||||
self.check_roundtrip("not True or False")
|
||||
self.check_roundtrip("True or not False")
|
||||
self.check_ast_roundtrip("(-1)**7")
|
||||
self.check_ast_roundtrip("(-1.)**8")
|
||||
self.check_ast_roundtrip("(-1j)**6")
|
||||
self.check_ast_roundtrip("not True or False")
|
||||
self.check_ast_roundtrip("True or not False")
|
||||
|
||||
def test_integer_parens(self):
|
||||
self.check_roundtrip("3 .__abs__()")
|
||||
self.check_ast_roundtrip("3 .__abs__()")
|
||||
|
||||
def test_huge_float(self):
|
||||
self.check_roundtrip("1e1000")
|
||||
self.check_roundtrip("-1e1000")
|
||||
self.check_roundtrip("1e1000j")
|
||||
self.check_roundtrip("-1e1000j")
|
||||
self.check_ast_roundtrip("1e1000")
|
||||
self.check_ast_roundtrip("-1e1000")
|
||||
self.check_ast_roundtrip("1e1000j")
|
||||
self.check_ast_roundtrip("-1e1000j")
|
||||
|
||||
def test_min_int(self):
|
||||
self.check_roundtrip(str(-(2 ** 31)))
|
||||
self.check_roundtrip(str(-(2 ** 63)))
|
||||
self.check_ast_roundtrip(str(-(2 ** 31)))
|
||||
self.check_ast_roundtrip(str(-(2 ** 63)))
|
||||
|
||||
def test_imaginary_literals(self):
|
||||
self.check_roundtrip("7j")
|
||||
self.check_roundtrip("-7j")
|
||||
self.check_roundtrip("0j")
|
||||
self.check_roundtrip("-0j")
|
||||
self.check_ast_roundtrip("7j")
|
||||
self.check_ast_roundtrip("-7j")
|
||||
self.check_ast_roundtrip("0j")
|
||||
self.check_ast_roundtrip("-0j")
|
||||
|
||||
def test_lambda_parentheses(self):
|
||||
self.check_roundtrip("(lambda: int)()")
|
||||
self.check_ast_roundtrip("(lambda: int)()")
|
||||
|
||||
def test_chained_comparisons(self):
|
||||
self.check_roundtrip("1 < 4 <= 5")
|
||||
self.check_roundtrip("a is b is c is not d")
|
||||
self.check_ast_roundtrip("1 < 4 <= 5")
|
||||
self.check_ast_roundtrip("a is b is c is not d")
|
||||
|
||||
def test_function_arguments(self):
|
||||
self.check_roundtrip("def f(): pass")
|
||||
self.check_roundtrip("def f(a): pass")
|
||||
self.check_roundtrip("def f(b = 2): pass")
|
||||
self.check_roundtrip("def f(a, b): pass")
|
||||
self.check_roundtrip("def f(a, b = 2): pass")
|
||||
self.check_roundtrip("def f(a = 5, b = 2): pass")
|
||||
self.check_roundtrip("def f(*, a = 1, b = 2): pass")
|
||||
self.check_roundtrip("def f(*, a = 1, b): pass")
|
||||
self.check_roundtrip("def f(*, a, b = 2): pass")
|
||||
self.check_roundtrip("def f(a, b = None, *, c, **kwds): pass")
|
||||
self.check_roundtrip("def f(a=2, *args, c=5, d, **kwds): pass")
|
||||
self.check_roundtrip("def f(*args, **kwargs): pass")
|
||||
self.check_ast_roundtrip("def f(): pass")
|
||||
self.check_ast_roundtrip("def f(a): pass")
|
||||
self.check_ast_roundtrip("def f(b = 2): pass")
|
||||
self.check_ast_roundtrip("def f(a, b): pass")
|
||||
self.check_ast_roundtrip("def f(a, b = 2): pass")
|
||||
self.check_ast_roundtrip("def f(a = 5, b = 2): pass")
|
||||
self.check_ast_roundtrip("def f(*, a = 1, b = 2): pass")
|
||||
self.check_ast_roundtrip("def f(*, a = 1, b): pass")
|
||||
self.check_ast_roundtrip("def f(*, a, b = 2): pass")
|
||||
self.check_ast_roundtrip("def f(a, b = None, *, c, **kwds): pass")
|
||||
self.check_ast_roundtrip("def f(a=2, *args, c=5, d, **kwds): pass")
|
||||
self.check_ast_roundtrip("def f(*args, **kwargs): pass")
|
||||
|
||||
def test_relative_import(self):
|
||||
self.check_roundtrip(relative_import)
|
||||
self.check_ast_roundtrip(relative_import)
|
||||
|
||||
def test_nonlocal(self):
|
||||
self.check_roundtrip(nonlocal_ex)
|
||||
self.check_ast_roundtrip(nonlocal_ex)
|
||||
|
||||
def test_raise_from(self):
|
||||
self.check_roundtrip(raise_from)
|
||||
self.check_ast_roundtrip(raise_from)
|
||||
|
||||
def test_bytes(self):
|
||||
self.check_roundtrip("b'123'")
|
||||
self.check_ast_roundtrip("b'123'")
|
||||
|
||||
def test_annotations(self):
|
||||
self.check_roundtrip("def f(a : int): pass")
|
||||
self.check_roundtrip("def f(a: int = 5): pass")
|
||||
self.check_roundtrip("def f(*args: [int]): pass")
|
||||
self.check_roundtrip("def f(**kwargs: dict): pass")
|
||||
self.check_roundtrip("def f() -> None: pass")
|
||||
self.check_ast_roundtrip("def f(a : int): pass")
|
||||
self.check_ast_roundtrip("def f(a: int = 5): pass")
|
||||
self.check_ast_roundtrip("def f(*args: [int]): pass")
|
||||
self.check_ast_roundtrip("def f(**kwargs: dict): pass")
|
||||
self.check_ast_roundtrip("def f() -> None: pass")
|
||||
|
||||
def test_set_literal(self):
|
||||
self.check_roundtrip("{'a', 'b', 'c'}")
|
||||
self.check_ast_roundtrip("{'a', 'b', 'c'}")
|
||||
|
||||
def test_set_comprehension(self):
|
||||
self.check_roundtrip("{x for x in range(5)}")
|
||||
self.check_ast_roundtrip("{x for x in range(5)}")
|
||||
|
||||
def test_dict_comprehension(self):
|
||||
self.check_roundtrip("{x: x*x for x in range(10)}")
|
||||
self.check_ast_roundtrip("{x: x*x for x in range(10)}")
|
||||
|
||||
def test_class_decorators(self):
|
||||
self.check_roundtrip(class_decorator)
|
||||
self.check_ast_roundtrip(class_decorator)
|
||||
|
||||
def test_class_definition(self):
|
||||
self.check_roundtrip("class A(metaclass=type, *[], **{}): pass")
|
||||
self.check_ast_roundtrip("class A(metaclass=type, *[], **{}): pass")
|
||||
|
||||
def test_elifs(self):
|
||||
self.check_roundtrip(elif1)
|
||||
self.check_roundtrip(elif2)
|
||||
self.check_ast_roundtrip(elif1)
|
||||
self.check_ast_roundtrip(elif2)
|
||||
|
||||
def test_try_except_finally(self):
|
||||
self.check_roundtrip(try_except_finally)
|
||||
self.check_ast_roundtrip(try_except_finally)
|
||||
|
||||
def test_starred_assignment(self):
|
||||
self.check_roundtrip("a, *b, c = seq")
|
||||
self.check_roundtrip("a, (*b, c) = seq")
|
||||
self.check_roundtrip("a, *b[0], c = seq")
|
||||
self.check_roundtrip("a, *(b, c) = seq")
|
||||
self.check_ast_roundtrip("a, *b, c = seq")
|
||||
self.check_ast_roundtrip("a, (*b, c) = seq")
|
||||
self.check_ast_roundtrip("a, *b[0], c = seq")
|
||||
self.check_ast_roundtrip("a, *(b, c) = seq")
|
||||
|
||||
def test_with_simple(self):
|
||||
self.check_roundtrip(with_simple)
|
||||
self.check_ast_roundtrip(with_simple)
|
||||
|
||||
def test_with_as(self):
|
||||
self.check_roundtrip(with_as)
|
||||
self.check_ast_roundtrip(with_as)
|
||||
|
||||
def test_with_two_items(self):
|
||||
self.check_roundtrip(with_two_items)
|
||||
self.check_ast_roundtrip(with_two_items)
|
||||
|
||||
def test_dict_unpacking_in_dict(self):
|
||||
# See issue 26489
|
||||
self.check_roundtrip(r"""{**{'y': 2}, 'x': 1}""")
|
||||
self.check_roundtrip(r"""{**{'y': 2}, **{'x': 1}}""")
|
||||
self.check_ast_roundtrip(r"""{**{'y': 2}, 'x': 1}""")
|
||||
self.check_ast_roundtrip(r"""{**{'y': 2}, **{'x': 1}}""")
|
||||
|
||||
def test_invalid_raise(self):
|
||||
self.check_invalid(ast.Raise(exc=None, cause=ast.Name(id="X")))
|
||||
|
@ -288,6 +301,16 @@ class UnparseTestCase(ASTTestCase):
|
|||
def test_invalid_yield_from(self):
|
||||
self.check_invalid(ast.YieldFrom(value=None))
|
||||
|
||||
def test_docstrings(self):
|
||||
docstrings = (
|
||||
'this ends with double quote"',
|
||||
'this includes a """triple quote"""'
|
||||
)
|
||||
for docstring in docstrings:
|
||||
# check as Module docstrings for easy testing
|
||||
self.check_ast_roundtrip(f"'{docstring}'")
|
||||
|
||||
|
||||
class CosmeticTestCase(ASTTestCase):
|
||||
"""Test if there are cosmetic issues caused by unnecesary additions"""
|
||||
|
||||
|
@ -321,6 +344,39 @@ class CosmeticTestCase(ASTTestCase):
|
|||
self.check_src_roundtrip("call((yield x))")
|
||||
self.check_src_roundtrip("return x + (yield x)")
|
||||
|
||||
def test_docstrings(self):
|
||||
docstrings = (
|
||||
'"""simple doc string"""',
|
||||
'''"""A more complex one
|
||||
with some newlines"""''',
|
||||
'''"""Foo bar baz
|
||||
|
||||
empty newline"""''',
|
||||
'"""With some \t"""',
|
||||
'"""Foo "bar" baz """',
|
||||
)
|
||||
|
||||
for prefix in docstring_prefixes:
|
||||
for docstring in docstrings:
|
||||
self.check_src_roundtrip(f"{prefix}{docstring}")
|
||||
|
||||
def test_docstrings_negative_cases(self):
|
||||
# Test some cases that involve strings in the children of the
|
||||
# first node but aren't docstrings to make sure we don't have
|
||||
# False positives.
|
||||
docstrings_negative = (
|
||||
'a = """false"""',
|
||||
'"""false""" + """unless its optimized"""',
|
||||
'1 + 1\n"""false"""',
|
||||
'f"""no, top level but f-fstring"""'
|
||||
)
|
||||
for prefix in docstring_prefixes:
|
||||
for negative in docstrings_negative:
|
||||
# this cases should be result with single quote
|
||||
# rather then triple quoted docstring
|
||||
src = f"{prefix}{negative}"
|
||||
self.check_ast_roundtrip(src)
|
||||
self.check_src_dont_roundtrip(src)
|
||||
|
||||
class DirectoryTestCase(ASTTestCase):
|
||||
"""Test roundtrip behaviour on all files in Lib and Lib/test."""
|
||||
|
@ -379,7 +435,7 @@ class DirectoryTestCase(ASTTestCase):
|
|||
|
||||
with self.subTest(filename=item):
|
||||
source = read_pyfile(item)
|
||||
self.check_roundtrip(source)
|
||||
self.check_ast_roundtrip(source)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue