Skip to content

Commit 3d98ece

Browse files
authored
bpo-43417: Better buffer handling for ast.unparse (GH-24772)
1 parent a0bd9e9 commit 3d98ece

File tree

2 files changed

+86
-65
lines changed

2 files changed

+86
-65
lines changed

Lib/ast.py

Lines changed: 59 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,6 @@ class _Unparser(NodeVisitor):
678678

679679
def __init__(self, *, _avoid_backslashes=False):
680680
self._source = []
681-
self._buffer = []
682681
self._precedences = {}
683682
self._type_ignores = {}
684683
self._indent = 0
@@ -721,14 +720,15 @@ def write(self, text):
721720
"""Append a piece of text"""
722721
self._source.append(text)
723722

724-
def buffer_writer(self, text):
725-
self._buffer.append(text)
723+
@contextmanager
724+
def buffered(self, buffer = None):
725+
if buffer is None:
726+
buffer = []
726727

727-
@property
728-
def buffer(self):
729-
value = "".join(self._buffer)
730-
self._buffer.clear()
731-
return value
728+
original_source = self._source
729+
self._source = buffer
730+
yield buffer
731+
self._source = original_source
732732

733733
@contextmanager
734734
def block(self, *, extra = None):
@@ -1127,70 +1127,72 @@ def _write_str_avoiding_backslashes(self, string, *, quote_types=_ALL_QUOTES):
11271127
def visit_JoinedStr(self, node):
11281128
self.write("f")
11291129
if self._avoid_backslashes:
1130-
self._fstring_JoinedStr(node, self.buffer_writer)
1131-
self._write_str_avoiding_backslashes(self.buffer)
1132-
return
1130+
with self.buffered() as buffer:
1131+
self._write_fstring_inner(node)
1132+
return self._write_str_avoiding_backslashes("".join(buffer))
11331133

11341134
# If we don't need to avoid backslashes globally (i.e., we only need
11351135
# to avoid them inside FormattedValues), it's cosmetically preferred
11361136
# to use escaped whitespace. That is, it's preferred to use backslashes
11371137
# for cases like: f"{x}\n". To accomplish this, we keep track of what
11381138
# in our buffer corresponds to FormattedValues and what corresponds to
11391139
# Constant parts of the f-string, and allow escapes accordingly.
1140-
buffer = []
1140+
fstring_parts = []
11411141
for value in node.values:
1142-
meth = getattr(self, "_fstring_" + type(value).__name__)
1143-
meth(value, self.buffer_writer)
1144-
buffer.append((self.buffer, isinstance(value, Constant)))
1145-
new_buffer = []
1146-
quote_types = _ALL_QUOTES
1147-
for value, is_constant in buffer:
1148-
# Repeatedly narrow down the list of possible quote_types
1142+
with self.buffered() as buffer:
1143+
self._write_fstring_inner(value)
1144+
fstring_parts.append(
1145+
("".join(buffer), isinstance(value, Constant))
1146+
)
1147+
1148+
new_fstring_parts = []
1149+
quote_types = list(_ALL_QUOTES)
1150+
for value, is_constant in fstring_parts:
11491151
value, quote_types = self._str_literal_helper(
1150-
value, quote_types=quote_types,
1151-
escape_special_whitespace=is_constant
1152+
value,
1153+
quote_types=quote_types,
1154+
escape_special_whitespace=is_constant,
11521155
)
1153-
new_buffer.append(value)
1154-
value = "".join(new_buffer)
1156+
new_fstring_parts.append(value)
1157+
1158+
value = "".join(new_fstring_parts)
11551159
quote_type = quote_types[0]
11561160
self.write(f"{quote_type}{value}{quote_type}")
11571161

1162+
def _write_fstring_inner(self, node):
1163+
if isinstance(node, JoinedStr):
1164+
# for both the f-string itself, and format_spec
1165+
for value in node.values:
1166+
self._write_fstring_inner(value)
1167+
elif isinstance(node, Constant) and isinstance(node.value, str):
1168+
value = node.value.replace("{", "{{").replace("}", "}}")
1169+
self.write(value)
1170+
elif isinstance(node, FormattedValue):
1171+
self.visit_FormattedValue(node)
1172+
else:
1173+
raise ValueError(f"Unexpected node inside JoinedStr, {node!r}")
1174+
11581175
def visit_FormattedValue(self, node):
1159-
self.write("f")
1160-
self._fstring_FormattedValue(node, self.buffer_writer)
1161-
self._write_str_avoiding_backslashes(self.buffer)
1176+
def unparse_inner(inner):
1177+
unparser = type(self)(_avoid_backslashes=True)
1178+
unparser.set_precedence(_Precedence.TEST.next(), inner)
1179+
return unparser.visit(inner)
11621180

1163-
def _fstring_JoinedStr(self, node, write):
1164-
for value in node.values:
1165-
meth = getattr(self, "_fstring_" + type(value).__name__)
1166-
meth(value, write)
1167-
1168-
def _fstring_Constant(self, node, write):
1169-
if not isinstance(node.value, str):
1170-
raise ValueError("Constants inside JoinedStr should be a string.")
1171-
value = node.value.replace("{", "{{").replace("}", "}}")
1172-
write(value)
1173-
1174-
def _fstring_FormattedValue(self, node, write):
1175-
write("{")
1176-
unparser = type(self)(_avoid_backslashes=True)
1177-
unparser.set_precedence(_Precedence.TEST.next(), node.value)
1178-
expr = unparser.visit(node.value)
1179-
if expr.startswith("{"):
1180-
write(" ") # Separate pair of opening brackets as "{ {"
1181-
if "\\" in expr:
1182-
raise ValueError("Unable to avoid backslash in f-string expression part")
1183-
write(expr)
1184-
if node.conversion != -1:
1185-
conversion = chr(node.conversion)
1186-
if conversion not in "sra":
1187-
raise ValueError("Unknown f-string conversion.")
1188-
write(f"!{conversion}")
1189-
if node.format_spec:
1190-
write(":")
1191-
meth = getattr(self, "_fstring_" + type(node.format_spec).__name__)
1192-
meth(node.format_spec, write)
1193-
write("}")
1181+
with self.delimit("{", "}"):
1182+
expr = unparse_inner(node.value)
1183+
if "\\" in expr:
1184+
raise ValueError(
1185+
"Unable to avoid backslash in f-string expression part"
1186+
)
1187+
if expr.startswith("{"):
1188+
# Separate pair of opening brackets as "{ {"
1189+
self.write(" ")
1190+
self.write(expr)
1191+
if node.conversion != -1:
1192+
self.write(f"!{chr(node.conversion)}")
1193+
if node.format_spec:
1194+
self.write(":")
1195+
self._write_fstring_inner(node.format_spec)
11941196

11951197
def visit_Name(self, node):
11961198
self.write(node.id)

Lib/test/test_unparse.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,27 @@ class UnparseTestCase(ASTTestCase):
149149
# Tests for specific bugs found in earlier versions of unparse
150150

151151
def test_fstrings(self):
152+
self.check_ast_roundtrip("f'a'")
153+
self.check_ast_roundtrip("f'{{}}'")
154+
self.check_ast_roundtrip("f'{{5}}'")
155+
self.check_ast_roundtrip("f'{{5}}5'")
156+
self.check_ast_roundtrip("f'X{{}}X'")
157+
self.check_ast_roundtrip("f'{a}'")
158+
self.check_ast_roundtrip("f'{ {1:2}}'")
159+
self.check_ast_roundtrip("f'a{a}a'")
160+
self.check_ast_roundtrip("f'a{a}{a}a'")
161+
self.check_ast_roundtrip("f'a{a}a{a}a'")
162+
self.check_ast_roundtrip("f'{a!r}x{a!s}12{{}}{a!a}'")
163+
self.check_ast_roundtrip("f'{a:10}'")
164+
self.check_ast_roundtrip("f'{a:100_000{10}}'")
165+
self.check_ast_roundtrip("f'{a!r:10}'")
166+
self.check_ast_roundtrip("f'{a:a{b}10}'")
167+
self.check_ast_roundtrip(
168+
"f'a{b}{c!s}{d!r}{e!a}{f:a}{g:a{b}}{h!s:a}"
169+
"{j!s:{a}b}{k!s:a{b}c}{l!a:{b}c{d}}{x+y=}'"
170+
)
171+
172+
def test_fstrings_special_chars(self):
152173
# See issue 25180
153174
self.check_ast_roundtrip(r"""f'{f"{0}"*3}'""")
154175
self.check_ast_roundtrip(r"""f'{f"{y}"*3}'""")
@@ -323,15 +344,13 @@ def test_slices(self):
323344
def test_invalid_raise(self):
324345
self.check_invalid(ast.Raise(exc=None, cause=ast.Name(id="X")))
325346

326-
def test_invalid_fstring_constant(self):
327-
self.check_invalid(ast.JoinedStr(values=[ast.Constant(value=100)]))
328-
329-
def test_invalid_fstring_conversion(self):
347+
def test_invalid_fstring_value(self):
330348
self.check_invalid(
331-
ast.FormattedValue(
332-
value=ast.Constant(value="a", kind=None),
333-
conversion=ord("Y"), # random character
334-
format_spec=None,
349+
ast.JoinedStr(
350+
values=[
351+
ast.Name(id="test"),
352+
ast.Constant(value="test")
353+
]
335354
)
336355
)
337356

0 commit comments

Comments
 (0)