From 3d98ececda1335c7ed2a6c6a2b0d3bb46f2d3c55 Mon Sep 17 00:00:00 2001
From: Batuhan Taskaya
Date: Sun, 9 May 2021 02:32:04 +0300
Subject: [PATCH] bpo-43417: Better buffer handling for ast.unparse (GH-24772)

---
Review notes (ignored by `git am`; not part of the commit message):

* What the patch does: replaces the shared `_buffer` list and the
  `buffer_writer` / `buffer` pair on `_Unparser` with a `buffered()`
  context manager that temporarily redirects `self._source` (the list
  that `write()` appends to) into a scratch list. The `_fstring_*`
  helpers that took a `write` callback are folded into
  `_write_fstring_inner()` and `visit_FormattedValue()`, which now use
  the normal `self.write()`.

* Amendment relative to the original commit: `buffered()` now restores
  `self._source` in a `finally:` block. The code run inside the `with`
  body can raise ValueError (`_write_fstring_inner` on an unexpected
  node, `visit_FormattedValue` on a backslash in the expression part);
  without the `finally`, the unparser would be left writing into the
  scratch buffer after such an error.

* Because of the two added lines, the second hunk is `+720,17`, the
  following Lib/ast.py hunks start two lines later on the new side, and
  the diffstat below is adjusted. The post-image blob id on the
  `index` line for Lib/ast.py is the one from the original commit and
  no longer matches the amended content.

* Behaviour change to be aware of (unchanged from the original commit):
  the conversion character of a FormattedValue is no longer validated
  against "s", "r", "a"; it is written as `!{chr(node.conversion)}`.

 Lib/ast.py               | 120 ++++++++++++++++++++-------------------
 Lib/test/test_unparse.py |  35 +++++++++---
 2 files changed, 89 insertions(+), 66 deletions(-)

diff --git a/Lib/ast.py b/Lib/ast.py
index 66bcee8a252..18163d6b7bd 100644
--- a/Lib/ast.py
+++ b/Lib/ast.py
@@ -678,7 +678,6 @@ class _Unparser(NodeVisitor):
 
     def __init__(self, *, _avoid_backslashes=False):
         self._source = []
-        self._buffer = []
         self._precedences = {}
         self._type_ignores = {}
         self._indent = 0
@@ -721,14 +720,17 @@ class _Unparser(NodeVisitor):
         """Append a piece of text"""
         self._source.append(text)
 
-    def buffer_writer(self, text):
-        self._buffer.append(text)
+    @contextmanager
+    def buffered(self, buffer = None):
+        if buffer is None:
+            buffer = []
 
-    @property
-    def buffer(self):
-        value = "".join(self._buffer)
-        self._buffer.clear()
-        return value
+        original_source = self._source
+        try:
+            self._source = buffer
+            yield buffer
+        finally:
+            self._source = original_source
 
     @contextmanager
     def block(self, *, extra = None):
@@ -1127,9 +1129,9 @@ class _Unparser(NodeVisitor):
     def visit_JoinedStr(self, node):
         self.write("f")
         if self._avoid_backslashes:
-            self._fstring_JoinedStr(node, self.buffer_writer)
-            self._write_str_avoiding_backslashes(self.buffer)
-            return
+            with self.buffered() as buffer:
+                self._write_fstring_inner(node)
+            return self._write_str_avoiding_backslashes("".join(buffer))
 
         # If we don't need to avoid backslashes globally (i.e., we only need
         # to avoid them inside FormattedValues), it's cosmetically preferred
@@ -1137,60 +1139,62 @@ class _Unparser(NodeVisitor):
         # for cases like: f"{x}\n". To accomplish this, we keep track of what
         # in our buffer corresponds to FormattedValues and what corresponds to
         # Constant parts of the f-string, and allow escapes accordingly.
-        buffer = []
+        fstring_parts = []
         for value in node.values:
-            meth = getattr(self, "_fstring_" + type(value).__name__)
-            meth(value, self.buffer_writer)
-            buffer.append((self.buffer, isinstance(value, Constant)))
-        new_buffer = []
-        quote_types = _ALL_QUOTES
-        for value, is_constant in buffer:
-            # Repeatedly narrow down the list of possible quote_types
-            value, quote_types = self._str_literal_helper(
-                value, quote_types=quote_types,
-                escape_special_whitespace=is_constant
+            with self.buffered() as buffer:
+                self._write_fstring_inner(value)
+            fstring_parts.append(
+                ("".join(buffer), isinstance(value, Constant))
             )
-            new_buffer.append(value)
-        value = "".join(new_buffer)
+
+        new_fstring_parts = []
+        quote_types = list(_ALL_QUOTES)
+        for value, is_constant in fstring_parts:
+            value, quote_types = self._str_literal_helper(
+                value,
+                quote_types=quote_types,
+                escape_special_whitespace=is_constant,
+            )
+            new_fstring_parts.append(value)
+
+        value = "".join(new_fstring_parts)
         quote_type = quote_types[0]
         self.write(f"{quote_type}{value}{quote_type}")
 
+    def _write_fstring_inner(self, node):
+        if isinstance(node, JoinedStr):
+            # for both the f-string itself, and format_spec
+            for value in node.values:
+                self._write_fstring_inner(value)
+        elif isinstance(node, Constant) and isinstance(node.value, str):
+            value = node.value.replace("{", "{{").replace("}", "}}")
+            self.write(value)
+        elif isinstance(node, FormattedValue):
+            self.visit_FormattedValue(node)
+        else:
+            raise ValueError(f"Unexpected node inside JoinedStr, {node!r}")
+
     def visit_FormattedValue(self, node):
-        self.write("f")
-        self._fstring_FormattedValue(node, self.buffer_writer)
-        self._write_str_avoiding_backslashes(self.buffer)
+        def unparse_inner(inner):
+            unparser = type(self)(_avoid_backslashes=True)
+            unparser.set_precedence(_Precedence.TEST.next(), inner)
+            return unparser.visit(inner)
 
-    def _fstring_JoinedStr(self, node, write):
-        for value in node.values:
-            meth = getattr(self, "_fstring_" + type(value).__name__)
-            meth(value, write)
-
-    def _fstring_Constant(self, node, write):
-        if not isinstance(node.value, str):
-            raise ValueError("Constants inside JoinedStr should be a string.")
-        value = node.value.replace("{", "{{").replace("}", "}}")
-        write(value)
-
-    def _fstring_FormattedValue(self, node, write):
-        write("{")
-        unparser = type(self)(_avoid_backslashes=True)
-        unparser.set_precedence(_Precedence.TEST.next(), node.value)
-        expr = unparser.visit(node.value)
-        if expr.startswith("{"):
-            write(" ")  # Separate pair of opening brackets as "{ {"
-        if "\\" in expr:
-            raise ValueError("Unable to avoid backslash in f-string expression part")
-        write(expr)
-        if node.conversion != -1:
-            conversion = chr(node.conversion)
-            if conversion not in "sra":
-                raise ValueError("Unknown f-string conversion.")
-            write(f"!{conversion}")
-        if node.format_spec:
-            write(":")
-            meth = getattr(self, "_fstring_" + type(node.format_spec).__name__)
-            meth(node.format_spec, write)
-        write("}")
+        with self.delimit("{", "}"):
+            expr = unparse_inner(node.value)
+            if "\\" in expr:
+                raise ValueError(
+                    "Unable to avoid backslash in f-string expression part"
+                )
+            if expr.startswith("{"):
+                # Separate pair of opening brackets as "{ {"
+                self.write(" ")
+            self.write(expr)
+            if node.conversion != -1:
+                self.write(f"!{chr(node.conversion)}")
+            if node.format_spec:
+                self.write(":")
+                self._write_fstring_inner(node.format_spec)
 
     def visit_Name(self, node):
         self.write(node.id)
diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py
index 9f67b49f3a6..534431bc969 100644
--- a/Lib/test/test_unparse.py
+++ b/Lib/test/test_unparse.py
@@ -149,6 +149,27 @@ class UnparseTestCase(ASTTestCase):
     # Tests for specific bugs found in earlier versions of unparse
 
     def test_fstrings(self):
+        self.check_ast_roundtrip("f'a'")
+        self.check_ast_roundtrip("f'{{}}'")
+        self.check_ast_roundtrip("f'{{5}}'")
+        self.check_ast_roundtrip("f'{{5}}5'")
+        self.check_ast_roundtrip("f'X{{}}X'")
+        self.check_ast_roundtrip("f'{a}'")
+        self.check_ast_roundtrip("f'{ {1:2}}'")
+        self.check_ast_roundtrip("f'a{a}a'")
+        self.check_ast_roundtrip("f'a{a}{a}a'")
+        self.check_ast_roundtrip("f'a{a}a{a}a'")
+        self.check_ast_roundtrip("f'{a!r}x{a!s}12{{}}{a!a}'")
+        self.check_ast_roundtrip("f'{a:10}'")
+        self.check_ast_roundtrip("f'{a:100_000{10}}'")
+        self.check_ast_roundtrip("f'{a!r:10}'")
+        self.check_ast_roundtrip("f'{a:a{b}10}'")
+        self.check_ast_roundtrip(
+            "f'a{b}{c!s}{d!r}{e!a}{f:a}{g:a{b}}{h!s:a}"
+            "{j!s:{a}b}{k!s:a{b}c}{l!a:{b}c{d}}{x+y=}'"
+        )
+
+    def test_fstrings_special_chars(self):
         # See issue 25180
         self.check_ast_roundtrip(r"""f'{f"{0}"*3}'""")
         self.check_ast_roundtrip(r"""f'{f"{y}"*3}'""")
@@ -323,15 +344,13 @@ class UnparseTestCase(ASTTestCase):
     def test_invalid_raise(self):
         self.check_invalid(ast.Raise(exc=None, cause=ast.Name(id="X")))
 
-    def test_invalid_fstring_constant(self):
-        self.check_invalid(ast.JoinedStr(values=[ast.Constant(value=100)]))
-
-    def test_invalid_fstring_conversion(self):
+    def test_invalid_fstring_value(self):
         self.check_invalid(
-            ast.FormattedValue(
-                value=ast.Constant(value="a", kind=None),
-                conversion=ord("Y"),  # random character
-                format_spec=None,
+            ast.JoinedStr(
+                values=[
+                    ast.Name(id="test"),
+                    ast.Constant(value="test")
+                ]
             )
         )
 