Skip to content

bpo-43417: Better buffer handling for ast.unparse #24772

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 59 additions & 57 deletions Lib/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,6 @@ class _Unparser(NodeVisitor):

def __init__(self, *, _avoid_backslashes=False):
self._source = []
self._buffer = []
self._precedences = {}
self._type_ignores = {}
self._indent = 0
Expand Down Expand Up @@ -720,14 +719,15 @@ def write(self, text):
"""Append a piece of text"""
self._source.append(text)

def buffer_writer(self, text):
self._buffer.append(text)
@contextmanager
def buffered(self, buffer = None):
if buffer is None:
buffer = []

@property
def buffer(self):
value = "".join(self._buffer)
self._buffer.clear()
return value
original_source = self._source
self._source = buffer
yield buffer
self._source = original_source

@contextmanager
def block(self, *, extra = None):
Expand Down Expand Up @@ -1123,70 +1123,72 @@ def _write_str_avoiding_backslashes(self, string, *, quote_types=_ALL_QUOTES):
def visit_JoinedStr(self, node):
self.write("f")
if self._avoid_backslashes:
self._fstring_JoinedStr(node, self.buffer_writer)
self._write_str_avoiding_backslashes(self.buffer)
return
with self.buffered() as buffer:
self._write_fstring_inner(node)
return self._write_str_avoiding_backslashes("".join(buffer))

# If we don't need to avoid backslashes globally (i.e., we only need
# to avoid them inside FormattedValues), it's cosmetically preferred
# to use escaped whitespace. That is, it's preferred to use backslashes
# for cases like: f"{x}\n". To accomplish this, we keep track of what
# in our buffer corresponds to FormattedValues and what corresponds to
# Constant parts of the f-string, and allow escapes accordingly.
buffer = []
fstring_parts = []
for value in node.values:
meth = getattr(self, "_fstring_" + type(value).__name__)
meth(value, self.buffer_writer)
buffer.append((self.buffer, isinstance(value, Constant)))
new_buffer = []
quote_types = _ALL_QUOTES
for value, is_constant in buffer:
# Repeatedly narrow down the list of possible quote_types
with self.buffered() as buffer:
self._write_fstring_inner(value)
fstring_parts.append(
("".join(buffer), isinstance(value, Constant))
)

new_fstring_parts = []
quote_types = list(_ALL_QUOTES)
for value, is_constant in fstring_parts:
value, quote_types = self._str_literal_helper(
value, quote_types=quote_types,
escape_special_whitespace=is_constant
value,
quote_types=quote_types,
escape_special_whitespace=is_constant,
)
new_buffer.append(value)
value = "".join(new_buffer)
new_fstring_parts.append(value)

value = "".join(new_fstring_parts)
quote_type = quote_types[0]
self.write(f"{quote_type}{value}{quote_type}")

def _write_fstring_inner(self, node):
    """Write the contents of an f-string node, without prefix or quotes.

    JoinedStr nodes are walked recursively, string Constant parts are
    written with their braces doubled, and FormattedValue nodes are
    handed to visit_FormattedValue.

    Raises ValueError for any other node, including a Constant whose
    value is not a str.
    """
    if isinstance(node, JoinedStr):
        # for both the f-string itself, and format_spec
        for value in node.values:
            self._write_fstring_inner(value)
    elif isinstance(node, Constant) and isinstance(node.value, str):
        # Literal braces are doubled so they are not read back as the
        # start or end of a replacement field.
        value = node.value.replace("{", "{{").replace("}", "}}")
        self.write(value)
    elif isinstance(node, FormattedValue):
        self.visit_FormattedValue(node)
    else:
        raise ValueError(f"Unexpected node inside JoinedStr, {node!r}")

def visit_FormattedValue(self, node):
self.write("f")
self._fstring_FormattedValue(node, self.buffer_writer)
self._write_str_avoiding_backslashes(self.buffer)
def unparse_inner(inner):
unparser = type(self)(_avoid_backslashes=True)
unparser.set_precedence(_Precedence.TEST.next(), inner)
return unparser.visit(inner)

def _fstring_JoinedStr(self, node, write):
for value in node.values:
meth = getattr(self, "_fstring_" + type(value).__name__)
meth(value, write)

def _fstring_Constant(self, node, write):
if not isinstance(node.value, str):
raise ValueError("Constants inside JoinedStr should be a string.")
value = node.value.replace("{", "{{").replace("}", "}}")
write(value)

def _fstring_FormattedValue(self, node, write):
write("{")
unparser = type(self)(_avoid_backslashes=True)
unparser.set_precedence(_Precedence.TEST.next(), node.value)
expr = unparser.visit(node.value)
if expr.startswith("{"):
write(" ") # Separate pair of opening brackets as "{ {"
if "\\" in expr:
raise ValueError("Unable to avoid backslash in f-string expression part")
write(expr)
if node.conversion != -1:
conversion = chr(node.conversion)
if conversion not in "sra":
raise ValueError("Unknown f-string conversion.")
Comment on lines -1182 to -1183
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not a big fan of handling invalid cases in `ast.unparse` when they are not going to break the rest of the code, and in this case they won't.

write(f"!{conversion}")
if node.format_spec:
write(":")
meth = getattr(self, "_fstring_" + type(node.format_spec).__name__)
meth(node.format_spec, write)
write("}")
with self.delimit("{", "}"):
expr = unparse_inner(node.value)
if "\\" in expr:
raise ValueError(
"Unable to avoid backslash in f-string expression part"
)
if expr.startswith("{"):
# Separate pair of opening brackets as "{ {"
self.write(" ")
self.write(expr)
if node.conversion != -1:
self.write(f"!{chr(node.conversion)}")
if node.format_spec:
self.write(":")
self._write_fstring_inner(node.format_spec)

def visit_Name(self, node):
self.write(node.id)
Expand Down
35 changes: 27 additions & 8 deletions Lib/test/test_unparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,27 @@ class UnparseTestCase(ASTTestCase):
# Tests for specific bugs found in earlier versions of unparse

def test_fstrings(self):
    """Round-trip f-strings covering escaped braces, replacement fields,
    conversions and (nested) format specs."""
    self.check_ast_roundtrip("f'a'")
    self.check_ast_roundtrip("f'{{}}'")
    self.check_ast_roundtrip("f'{{5}}'")
    self.check_ast_roundtrip("f'{{5}}5'")
    self.check_ast_roundtrip("f'X{{}}X'")
    self.check_ast_roundtrip("f'{a}'")
    self.check_ast_roundtrip("f'{ {1:2}}'")
    self.check_ast_roundtrip("f'a{a}a'")
    self.check_ast_roundtrip("f'a{a}{a}a'")
    self.check_ast_roundtrip("f'a{a}a{a}a'")
    self.check_ast_roundtrip("f'{a!r}x{a!s}12{{}}{a!a}'")
    self.check_ast_roundtrip("f'{a:10}'")
    self.check_ast_roundtrip("f'{a:100_000{10}}'")
    self.check_ast_roundtrip("f'{a!r:10}'")
    self.check_ast_roundtrip("f'{a:a{b}10}'")
    self.check_ast_roundtrip(
        "f'a{b}{c!s}{d!r}{e!a}{f:a}{g:a{b}}{h!s:a}"
        "{j!s:{a}b}{k!s:a{b}c}{l!a:{b}c{d}}{x+y=}'"
    )

def test_fstrings_special_chars(self):
# See issue 25180
self.check_ast_roundtrip(r"""f'{f"{0}"*3}'""")
self.check_ast_roundtrip(r"""f'{f"{y}"*3}'""")
Expand Down Expand Up @@ -311,15 +332,13 @@ def test_slices(self):
def test_invalid_raise(self):
    """A Raise node with a cause but no exception is checked as invalid."""
    self.check_invalid(ast.Raise(exc=None, cause=ast.Name(id="X")))

def test_invalid_fstring_constant(self):
self.check_invalid(ast.JoinedStr(values=[ast.Constant(value=100)]))

def test_invalid_fstring_conversion(self):
def test_invalid_fstring_value(self):
self.check_invalid(
ast.FormattedValue(
value=ast.Constant(value="a", kind=None),
conversion=ord("Y"), # random character
format_spec=None,
ast.JoinedStr(
values=[
ast.Name(id="test"),
ast.Constant(value="test")
]
)
)

Expand Down