From a993e901ebe60c38d46ecb31f771d0b4a206828c Mon Sep 17 00:00:00 2001 From: Shantanu <12621235+hauntsaninja@users.noreply.github.com> Date: Fri, 20 Nov 2020 13:16:42 -0800 Subject: [PATCH] bpo-28002: Roundtrip f-strings with ast.unparse better (#19612) By attempting to avoid backslashes in f-string expressions. We also now proactively raise errors for some backslashes we can't avoid while unparsing FormattedValues Co-authored-by: hauntsaninja <> Co-authored-by: Shantanu Co-authored-by: Batuhan Taskaya --- Lib/ast.py | 110 ++++++++++++++++++++++++++++++--------- Lib/test/test_unparse.py | 42 ++++++++++----- 2 files changed, 115 insertions(+), 37 deletions(-) diff --git a/Lib/ast.py b/Lib/ast.py index d8bd3373701..7275fe28ba8 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -662,17 +662,23 @@ class _Precedence(IntEnum): except ValueError: return self + +_SINGLE_QUOTES = ("'", '"') +_MULTI_QUOTES = ('"""', "'''") +_ALL_QUOTES = (*_SINGLE_QUOTES, *_MULTI_QUOTES) + class _Unparser(NodeVisitor): """Methods in this class recursively traverse an AST and output source code for the abstract syntax; original formatting is disregarded.""" - def __init__(self): + def __init__(self, *, _avoid_backslashes=False): self._source = [] self._buffer = [] self._precedences = {} self._type_ignores = {} self._indent = 0 + self._avoid_backslashes = _avoid_backslashes def interleave(self, inter, f, seq): """Call f on each item in seq, calling inter() in between.""" @@ -1067,15 +1073,85 @@ class _Unparser(NodeVisitor): with self.block(extra=self.get_type_comment(node)): self.traverse(node.body) + def _str_literal_helper( + self, string, *, quote_types=_ALL_QUOTES, escape_special_whitespace=False + ): + """Helper for writing string literals, minimizing escapes. + Returns the tuple (string literal to write, possible quote types). + """ + def escape_char(c): + # \n and \t are non-printable, but we only escape them if + # escape_special_whitespace is True + if not escape_special_whitespace and c in "\n\t": + return c + # Always escape backslashes and other non-printable characters + if c == "\\" or not c.isprintable(): + return c.encode("unicode_escape").decode("ascii") + return c + + escaped_string = "".join(map(escape_char, string)) + possible_quotes = quote_types + if "\n" in escaped_string: + possible_quotes = [q for q in possible_quotes if q in _MULTI_QUOTES] + possible_quotes = [q for q in possible_quotes if q not in escaped_string] + if not possible_quotes: + # If there aren't any possible_quotes, fallback to using repr + # on the original string. Try to use a quote from quote_types, + # e.g., so that we use triple quotes for docstrings. + string = repr(string) + quote = next((q for q in quote_types if string[0] in q), string[0]) + return string[1:-1], [quote] + if escaped_string: + # Sort so that we prefer '''"''' over """\"""" + possible_quotes.sort(key=lambda q: q[0] == escaped_string[-1]) + # If we're using triple quotes and we'd need to escape a final + # quote, escape it + if possible_quotes[0][0] == escaped_string[-1]: + assert len(possible_quotes[0]) == 3 + escaped_string = escaped_string[:-1] + "\\" + escaped_string[-1] + return escaped_string, possible_quotes + + def _write_str_avoiding_backslashes(self, string, *, quote_types=_ALL_QUOTES): + """Write string literal value with a best effort attempt to avoid backslashes.""" + string, quote_types = self._str_literal_helper(string, quote_types=quote_types) + quote_type = quote_types[0] + self.write(f"{quote_type}{string}{quote_type}") + def visit_JoinedStr(self, node): self.write("f") - self._fstring_JoinedStr(node, self.buffer_writer) - self.write(repr(self.buffer)) + if self._avoid_backslashes: + self._fstring_JoinedStr(node, self.buffer_writer) + self._write_str_avoiding_backslashes(self.buffer) + return + + # If we don't need to avoid backslashes globally (i.e., we only need + # to avoid them inside FormattedValues), it's cosmetically preferred + # to use escaped whitespace. That is, it's preferred to use backslashes + # for cases like: f"{x}\n". To accomplish this, we keep track of what + # in our buffer corresponds to FormattedValues and what corresponds to + # Constant parts of the f-string, and allow escapes accordingly. + buffer = [] + for value in node.values: + meth = getattr(self, "_fstring_" + type(value).__name__) + meth(value, self.buffer_writer) + buffer.append((self.buffer, isinstance(value, Constant))) + new_buffer = [] + quote_types = _ALL_QUOTES + for value, is_constant in buffer: + # Repeatedly narrow down the list of possible quote_types + value, quote_types = self._str_literal_helper( + value, quote_types=quote_types, + escape_special_whitespace=is_constant + ) + new_buffer.append(value) + value = "".join(new_buffer) + quote_type = quote_types[0] + self.write(f"{quote_type}{value}{quote_type}") def visit_FormattedValue(self, node): self.write("f") self._fstring_FormattedValue(node, self.buffer_writer) - self.write(repr(self.buffer)) + self._write_str_avoiding_backslashes(self.buffer) def _fstring_JoinedStr(self, node, write): for value in node.values: @@ -1090,11 +1166,13 @@ class _Unparser(NodeVisitor): def _fstring_FormattedValue(self, node, write): write("{") - unparser = type(self)() + unparser = type(self)(_avoid_backslashes=True) unparser.set_precedence(_Precedence.TEST.next(), node.value) expr = unparser.visit(node.value) if expr.startswith("{"): write(" ") # Separate pair of opening brackets as "{ {" + if "\\" in expr: + raise ValueError("Unable to avoid backslash in f-string expression part") write(expr) if node.conversion != -1: conversion = chr(node.conversion) @@ -1111,33 +1189,17 @@ class _Unparser(NodeVisitor): self.write(node.id) def _write_docstring(self, node): - def esc_char(c): - if c in ("\n", "\t"): - # In the AST form, we don't know the author's intentation - # about how this should be displayed. We'll only escape - # \n and \t, because they are more likely to be unescaped - # in the source - return c - return c.encode('unicode_escape').decode('ascii') - self.fill() if node.kind == "u": self.write("u") - - value = node.value - if value: - # Preserve quotes in the docstring by escaping them - value = "".join(map(esc_char, value)) - if value[-1] == '"': - value = value.replace('"', '\\"', -1) - value = value.replace('"""', '""\\"') - - self.write(f'"""{value}"""') + self._write_str_avoiding_backslashes(node.value, quote_types=_MULTI_QUOTES) def _write_constant(self, value): if isinstance(value, (float, complex)): # Substitute overflowing decimal literal for AST infinities. self.write(repr(value).replace("inf", _INFSTR)) + elif self._avoid_backslashes and isinstance(value, str): + self._write_str_avoiding_backslashes(value) else: self.write(repr(value)) diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py index 532aa3a6390..c7c8613ea27 100644 --- a/Lib/test/test_unparse.py +++ b/Lib/test/test_unparse.py @@ -152,6 +152,18 @@ class UnparseTestCase(ASTTestCase): # See issue 25180 self.check_ast_roundtrip(r"""f'{f"{0}"*3}'""") self.check_ast_roundtrip(r"""f'{f"{y}"*3}'""") + self.check_ast_roundtrip("""f''""") + self.check_ast_roundtrip('''f"""'end' "quote\\""""''') + + def test_fstrings_complicated(self): + # See issue 28002 + self.check_ast_roundtrip("""f'''{"'"}'''""") + self.check_ast_roundtrip('''f\'\'\'-{f"""*{f"+{f'.{x}.'}+"}*"""}-\'\'\'''') + self.check_ast_roundtrip('''f\'\'\'-{f"""*{f"+{f'.{x}.'}+"}*"""}-'single quote\\'\'\'\'''') + self.check_ast_roundtrip('f"""{\'\'\'\n\'\'\'}"""') + self.check_ast_roundtrip('f"""{g(\'\'\'\n\'\'\')}"""') + self.check_ast_roundtrip('''f"a\\r\\nb"''') + self.check_ast_roundtrip('''f"\\u2028{'x'}"''') def test_strings(self): self.check_ast_roundtrip("u'foo'") @@ -311,6 +323,9 @@ class UnparseTestCase(ASTTestCase): ) ) + def test_invalid_fstring_backslash(self): + self.check_invalid(ast.FormattedValue(value=ast.Constant(value="\\\\"))) + def test_invalid_set(self): self.check_invalid(ast.Set(elts=[])) @@ -330,8 +345,8 @@ class UnparseTestCase(ASTTestCase): '\r\\r\t\\t\n\\n', '""">>> content = \"\"\"blabla\"\"\" <<<"""', r'foo\n\x00', - '🐍⛎𩸽üéş^\N{LONG RIGHTWARDS SQUIGGLE ARROW}' - + "' \\'\\'\\'\"\"\" \"\"\\'\\' \\'", + '🐍⛎𩸽üéş^\\\\X\\\\BB\N{LONG RIGHTWARDS SQUIGGLE ARROW}' ) for docstring in docstrings: # check as Module docstrings for easy testing @@ -416,7 +431,6 @@ class CosmeticTestCase(ASTTestCase): self.check_src_roundtrip("call((yield x))") self.check_src_roundtrip("return x + (yield x)") - def test_class_bases_and_keywords(self): self.check_src_roundtrip("class X:\n pass") self.check_src_roundtrip("class X(A):\n pass") @@ -429,6 +443,13 @@ class CosmeticTestCase(ASTTestCase): self.check_src_roundtrip("class X(*args):\n pass") self.check_src_roundtrip("class X(*args, **kwargs):\n pass") + def test_fstrings(self): + self.check_src_roundtrip('''f\'\'\'-{f"""*{f"+{f'.{x}.'}+"}*"""}-\'\'\'''') + self.check_src_roundtrip('''f"\\u2028{'x'}"''') + self.check_src_roundtrip(r"f'{x}\n'") + self.check_src_roundtrip('''f''\'{"""\n"""}\\n''\'''') + self.check_src_roundtrip('''f''\'{f"""{x}\n"""}\\n''\'''') + def test_docstrings(self): docstrings = ( '"""simple doc string"""', @@ -443,6 +464,10 @@ class CosmeticTestCase(ASTTestCase): '""""""', '"""\'\'\'"""', '"""\'\'\'\'\'\'"""', + '"""🐍⛎𩸽üéş^\\\\X\\\\BB⟿"""', + '"""end in single \'quote\'"""', + "'''end in double \"quote\"'''", + '"""almost end in double "quote"."""', ) for prefix in docstring_prefixes: @@ -483,9 +508,8 @@ class DirectoryTestCase(ASTTestCase): lib_dir = pathlib.Path(__file__).parent / ".." test_directories = (lib_dir, lib_dir / "test") - skip_files = {"test_fstring.py"} run_always_files = {"test_grammar.py", "test_syntax.py", "test_compile.py", - "test_ast.py", "test_asdl_parser.py"} + "test_ast.py", "test_asdl_parser.py", "test_fstring.py"} _files_to_test = None @@ -525,14 +549,6 @@ class DirectoryTestCase(ASTTestCase): if test.support.verbose: print(f"Testing {item.absolute()}") - # Some f-strings are not correctly round-tripped by - # Tools/parser/unparse.py. See issue 28002 for details. - # We need to skip files that contain such f-strings. - if item.name in self.skip_files: - if test.support.verbose: - print(f"Skipping {item.absolute()}: see issue 28002") - continue - with self.subTest(filename=item): source = read_pyfile(item) self.check_ast_roundtrip(source)