Skip to content

Commit 397b96f

Browse files
authored
bpo-38870: Implement a precedence algorithm in ast.unparse (pythonGH-17377)
Implement a simple precedence algorithm for ast.unparse in order to avoid redundant parentheses for nested structures in the final output.
1 parent 185903d commit 397b96f

File tree

3 files changed

+172
-16
lines changed

3 files changed

+172
-16
lines changed

Lib/ast.py

Lines changed: 123 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import sys
2828
from _ast import *
2929
from contextlib import contextmanager, nullcontext
30+
from enum import IntEnum, auto
3031

3132

3233
def parse(source, filename='<unknown>', mode='exec', *,
@@ -560,6 +561,35 @@ def __new__(cls, *args, **kwargs):
560561
# We unparse those infinities to INFSTR.
561562
_INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)
562563

564+
class _Precedence(IntEnum):
    """Precedence levels derived from the Python grammar.

    Members are declared from the loosest construct (tuples) to the tightest
    (atoms); ``auto()`` numbers them in declaration order, so they can be
    compared as integers.
    """

    TUPLE = auto()
    YIELD = auto()   # 'yield', 'yield from'
    TEST = auto()    # 'if'-'else', 'lambda'
    OR = auto()      # 'or'
    AND = auto()     # 'and'
    NOT = auto()     # 'not'
    CMP = auto()     # '<', '>', '==', '>=', '<=', '!=',
                     # 'in', 'not in', 'is', 'is not'
    EXPR = auto()
    BOR = EXPR       # '|' -- alias of EXPR, shares its value
    BXOR = auto()    # '^'
    BAND = auto()    # '&'
    SHIFT = auto()   # '<<', '>>'
    ARITH = auto()   # '+', '-'
    TERM = auto()    # '*', '@', '/', '%', '//'
    FACTOR = auto()  # unary '+', '-', '~'
    POWER = auto()   # '**'
    AWAIT = auto()   # 'await'
    ATOM = auto()

    def next(self):
        """Return the member one level above this one.

        The highest member has no successor and is returned unchanged.
        """
        enum_type = type(self)
        successor_value = int(self) + 1
        try:
            return enum_type(successor_value)
        except ValueError:
            # No member carries that value: we are already at the top.
            return self
592+
563593
class _Unparser(NodeVisitor):
564594
"""Methods in this class recursively traverse an AST and
565595
output source code for the abstract syntax; original formatting
@@ -568,6 +598,7 @@ class _Unparser(NodeVisitor):
568598
def __init__(self):
569599
self._source = []
570600
self._buffer = []
601+
self._precedences = {}
571602
self._indent = 0
572603

573604
def interleave(self, inter, f, seq):
@@ -625,6 +656,17 @@ def delimit_if(self, start, end, condition):
625656
else:
626657
return nullcontext()
627658

659+
def require_parens(self, precedence, node):
    """Return the ``delimit_if`` result that wraps *node* in parentheses
    only when the precedence recorded for *node* is strictly greater
    than *precedence*.
    """
    needs_parens = self.get_precedence(node) > precedence
    return self.delimit_if("(", ")", needs_parens)
662+
663+
def get_precedence(self, node):
    """Return the precedence recorded for *node*.

    Nodes that were never passed to ``set_precedence`` report
    ``_Precedence.TEST``.
    """
    fallback = _Precedence.TEST
    return self._precedences.get(node, fallback)
665+
666+
def set_precedence(self, precedence, *nodes):
    """Record *precedence* for every node given in *nodes*.

    An entry already stored for a node is overwritten.
    """
    self._precedences.update((node, precedence) for node in nodes)
669+
628670
def traverse(self, node):
629671
if isinstance(node, list):
630672
for item in node:
@@ -645,10 +687,12 @@ def visit_Module(self, node):
645687

646688
def visit_Expr(self, node):
    """Unparse an expression statement.

    The wrapped expression is registered at YIELD precedence before it
    is traversed.
    """
    self.fill()
    expression = node.value
    self.set_precedence(_Precedence.YIELD, expression)
    self.traverse(expression)
649692

650693
def visit_NamedExpr(self, node):
651-
with self.delimit("(", ")"):
694+
with self.require_parens(_Precedence.TUPLE, node):
695+
self.set_precedence(_Precedence.ATOM, node.target, node.value)
652696
self.traverse(node.target)
653697
self.write(" := ")
654698
self.traverse(node.value)
@@ -723,24 +767,27 @@ def visit_Nonlocal(self, node):
723767
self.interleave(lambda: self.write(", "), self.write, node.names)
724768

725769
def visit_Await(self, node):
726-
with self.delimit("(", ")"):
770+
with self.require_parens(_Precedence.AWAIT, node):
727771
self.write("await")
728772
if node.value:
729773
self.write(" ")
774+
self.set_precedence(_Precedence.ATOM, node.value)
730775
self.traverse(node.value)
731776

732777
def visit_Yield(self, node):
733-
with self.delimit("(", ")"):
778+
with self.require_parens(_Precedence.YIELD, node):
734779
self.write("yield")
735780
if node.value:
736781
self.write(" ")
782+
self.set_precedence(_Precedence.ATOM, node.value)
737783
self.traverse(node.value)
738784

739785
def visit_YieldFrom(self, node):
740-
with self.delimit("(", ")"):
786+
with self.require_parens(_Precedence.YIELD, node):
741787
self.write("yield from ")
742788
if not node.value:
743789
raise ValueError("Node can't be used without a value attribute.")
790+
self.set_precedence(_Precedence.ATOM, node.value)
744791
self.traverse(node.value)
745792

746793
def visit_Raise(self, node):
@@ -907,7 +954,9 @@ def _fstring_Constant(self, node, write):
907954

908955
def _fstring_FormattedValue(self, node, write):
909956
write("{")
910-
expr = type(self)().visit(node.value).rstrip("\n")
957+
unparser = type(self)()
958+
unparser.set_precedence(_Precedence.TEST.next(), node.value)
959+
expr = unparser.visit(node.value).rstrip("\n")
911960
if expr.startswith("{"):
912961
write(" ") # Separate pair of opening brackets as "{ {"
913962
write(expr)
@@ -983,19 +1032,23 @@ def visit_comprehension(self, node):
9831032
self.write(" async for ")
9841033
else:
9851034
self.write(" for ")
1035+
self.set_precedence(_Precedence.TUPLE, node.target)
9861036
self.traverse(node.target)
9871037
self.write(" in ")
1038+
self.set_precedence(_Precedence.TEST.next(), node.iter, *node.ifs)
9881039
self.traverse(node.iter)
9891040
for if_clause in node.ifs:
9901041
self.write(" if ")
9911042
self.traverse(if_clause)
9921043

9931044
def visit_IfExp(self, node):
994-
with self.delimit("(", ")"):
1045+
with self.require_parens(_Precedence.TEST, node):
1046+
self.set_precedence(_Precedence.TEST.next(), node.body, node.test)
9951047
self.traverse(node.body)
9961048
self.write(" if ")
9971049
self.traverse(node.test)
9981050
self.write(" else ")
1051+
self.set_precedence(_Precedence.TEST, node.orelse)
9991052
self.traverse(node.orelse)
10001053

10011054
def visit_Set(self, node):
@@ -1016,6 +1069,7 @@ def write_item(item):
10161069
# for dictionary unpacking operator in dicts {**{'y': 2}}
10171070
# see PEP 448 for details
10181071
self.write("**")
1072+
self.set_precedence(_Precedence.EXPR, v)
10191073
self.traverse(v)
10201074
else:
10211075
write_key_value_pair(k, v)
@@ -1035,11 +1089,20 @@ def visit_Tuple(self, node):
10351089
self.interleave(lambda: self.write(", "), self.traverse, node.elts)
10361090

10371091
unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
1092+
unop_precedence = {
1093+
"~": _Precedence.FACTOR,
1094+
"not": _Precedence.NOT,
1095+
"+": _Precedence.FACTOR,
1096+
"-": _Precedence.FACTOR
1097+
}
10381098

10391099
def visit_UnaryOp(self, node):
1040-
with self.delimit("(", ")"):
1041-
self.write(self.unop[node.op.__class__.__name__])
1100+
operator = self.unop[node.op.__class__.__name__]
1101+
operator_precedence = self.unop_precedence[operator]
1102+
with self.require_parens(operator_precedence, node):
1103+
self.write(operator)
10421104
self.write(" ")
1105+
self.set_precedence(operator_precedence, node.operand)
10431106
self.traverse(node.operand)
10441107

10451108
binop = {
@@ -1058,10 +1121,38 @@ def visit_UnaryOp(self, node):
10581121
"Pow": "**",
10591122
}
10601123

1124+
binop_precedence = {
1125+
"+": _Precedence.ARITH,
1126+
"-": _Precedence.ARITH,
1127+
"*": _Precedence.TERM,
1128+
"@": _Precedence.TERM,
1129+
"/": _Precedence.TERM,
1130+
"%": _Precedence.TERM,
1131+
"<<": _Precedence.SHIFT,
1132+
">>": _Precedence.SHIFT,
1133+
"|": _Precedence.BOR,
1134+
"^": _Precedence.BXOR,
1135+
"&": _Precedence.BAND,
1136+
"//": _Precedence.TERM,
1137+
"**": _Precedence.POWER,
1138+
}
1139+
1140+
binop_rassoc = frozenset(("**",))
10611141
def visit_BinOp(self, node):
1062-
with self.delimit("(", ")"):
1142+
operator = self.binop[node.op.__class__.__name__]
1143+
operator_precedence = self.binop_precedence[operator]
1144+
with self.require_parens(operator_precedence, node):
1145+
if operator in self.binop_rassoc:
1146+
left_precedence = operator_precedence.next()
1147+
right_precedence = operator_precedence
1148+
else:
1149+
left_precedence = operator_precedence
1150+
right_precedence = operator_precedence.next()
1151+
1152+
self.set_precedence(left_precedence, node.left)
10631153
self.traverse(node.left)
1064-
self.write(" " + self.binop[node.op.__class__.__name__] + " ")
1154+
self.write(f" {operator} ")
1155+
self.set_precedence(right_precedence, node.right)
10651156
self.traverse(node.right)
10661157

10671158
cmpops = {
@@ -1078,20 +1169,32 @@ def visit_BinOp(self, node):
10781169
}
10791170

10801171
def visit_Compare(self, node):
1081-
with self.delimit("(", ")"):
1172+
with self.require_parens(_Precedence.CMP, node):
1173+
self.set_precedence(_Precedence.CMP.next(), node.left, *node.comparators)
10821174
self.traverse(node.left)
10831175
for o, e in zip(node.ops, node.comparators):
10841176
self.write(" " + self.cmpops[o.__class__.__name__] + " ")
10851177
self.traverse(e)
10861178

10871179
boolops = {"And": "and", "Or": "or"}
1180+
boolop_precedence = {"and": _Precedence.AND, "or": _Precedence.OR}
10881181

10891182
def visit_BoolOp(self, node):
1090-
with self.delimit("(", ")"):
1091-
s = " %s " % self.boolops[node.op.__class__.__name__]
1092-
self.interleave(lambda: self.write(s), self.traverse, node.values)
1183+
operator = self.boolops[node.op.__class__.__name__]
1184+
operator_precedence = self.boolop_precedence[operator]
1185+
1186+
def increasing_level_traverse(node):
1187+
nonlocal operator_precedence
1188+
operator_precedence = operator_precedence.next()
1189+
self.set_precedence(operator_precedence, node)
1190+
self.traverse(node)
1191+
1192+
with self.require_parens(operator_precedence, node):
1193+
s = f" {operator} "
1194+
self.interleave(lambda: self.write(s), increasing_level_traverse, node.values)
10931195

10941196
def visit_Attribute(self, node):
1197+
self.set_precedence(_Precedence.ATOM, node.value)
10951198
self.traverse(node.value)
10961199
# Special case: 3.__abs__() is a syntax error, so if node.value
10971200
# is an integer literal then we need to either parenthesize
@@ -1102,6 +1205,7 @@ def visit_Attribute(self, node):
11021205
self.write(node.attr)
11031206

11041207
def visit_Call(self, node):
1208+
self.set_precedence(_Precedence.ATOM, node.func)
11051209
self.traverse(node.func)
11061210
with self.delimit("(", ")"):
11071211
comma = False
@@ -1119,18 +1223,21 @@ def visit_Call(self, node):
11191223
self.traverse(e)
11201224

11211225
def visit_Subscript(self, node):
    """Unparse ``value[slice]``.

    The subscripted value is registered at ATOM precedence before it is
    traversed; the slice is written between square brackets.
    """
    container = node.value
    self.set_precedence(_Precedence.ATOM, container)
    self.traverse(container)
    with self.delimit("[", "]"):
        self.traverse(node.slice)
11251230

11261231
def visit_Starred(self, node):
    """Unparse ``*value``; the operand is registered at EXPR precedence."""
    self.write("*")
    operand = node.value
    self.set_precedence(_Precedence.EXPR, operand)
    self.traverse(operand)
11291235

11301236
def visit_Ellipsis(self, node):
    """Write the ellipsis literal; *node* itself is not inspected."""
    literal = "..."
    self.write(literal)
11321238

11331239
def visit_Index(self, node):
    """Unparse an index wrapper by unparsing its value.

    The value is registered at TUPLE precedence before it is traversed.
    """
    inner = node.value
    self.set_precedence(_Precedence.TUPLE, inner)
    self.traverse(inner)
11351242

11361243
def visit_Slice(self, node):
@@ -1212,10 +1319,11 @@ def visit_keyword(self, node):
12121319
self.traverse(node.value)
12131320

12141321
def visit_Lambda(self, node):
1215-
with self.delimit("(", ")"):
1322+
with self.require_parens(_Precedence.TEST, node):
12161323
self.write("lambda ")
12171324
self.traverse(node.args)
12181325
self.write(": ")
1326+
self.set_precedence(_Precedence.TEST, node.body)
12191327
self.traverse(node.body)
12201328

12211329
def visit_alias(self, node):

Lib/test/test_ast.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,13 @@ def to_tuple(t):
247247

248248
class AST_Tests(unittest.TestCase):
249249

250+
def _is_ast_node(self, name, node):
    """Tell whether the module attribute *name* / *node* looks like an AST
    node class.

    It must be a class whose ``__module__`` contains "ast", and its name
    must start with an uppercase letter and not be 'AST' itself.
    """
    if isinstance(node, type) and "ast" in node.__module__:
        return name != 'AST' and name[0].isupper()
    return False
256+
250257
def _assertTrueorder(self, ast_node, parent_pos):
251258
if not isinstance(ast_node, ast.AST) or ast_node._fields is None:
252259
return
@@ -335,7 +342,7 @@ def test_base_classes(self):
335342

336343
def test_field_attr_existence(self):
337344
for name, item in ast.__dict__.items():
338-
if isinstance(item, type) and name != 'AST' and name[0].isupper():
345+
if self._is_ast_node(name, item):
339346
x = item()
340347
if isinstance(x, ast.AST):
341348
self.assertEqual(type(x._fields), tuple)

Lib/test/test_unparse.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,13 @@ def check_roundtrip(self, code1):
125125
def check_invalid(self, node, raises=ValueError):
126126
self.assertRaises(raises, ast.unparse, node)
127127

128+
def check_src_roundtrip(self, code1, code2=None, strip=True):
    """Assert that parsing and unparsing *code1* produces the expected text.

    code1: source text that is parsed and then unparsed.
    code2: expected unparsed text; when None, *code1* itself is expected.
    strip: when true, leading/trailing whitespace is removed from the
        unparsed text before the comparison.
    """
    # Compare against None explicitly: an empty string is a legitimate
    # expectation and must not silently fall back to *code1*.
    expected = code1 if code2 is None else code2
    actual = ast.unparse(ast.parse(code1))
    if strip:
        actual = actual.strip()
    self.assertEqual(expected, actual)
134+
128135

129136
class UnparseTestCase(ASTTestCase):
130137
# Tests for specific bugs found in earlier versions of unparse
@@ -281,6 +288,40 @@ def test_invalid_set(self):
281288
def test_invalid_yield_from(self):
282289
self.check_invalid(ast.YieldFrom(value=None))
283290

291+
class CosmeticTestCase(ASTTestCase):
    """Test if there are cosmetic issues caused by unnecessary additions"""

    def test_simple_expressions_parens(self):
        # Each source must come back from parse + unparse unchanged,
        # i.e. without extra parentheses being introduced.
        sources = (
            "(a := b)",
            "await x",
            "x if x else y",
            "lambda x: x",
            "1 + 1",
            "1 + 2 / 3",
            "(1 + 2) / 3",
            "(1 + 2) * 3 + 4 * (5 + 2)",
            "(1 + 2) * 3 + 4 * (5 + 2) ** 2",
            "~ x",
            "x and y",
            "x and y and z",
            "x and (y and x)",
            "(x and y) and z",
            "(x ** y) ** z ** q",
            "x >> y",
            "x << y",
            "x >> y and x >> z",
            "x + y - z * q ^ t ** k",
            "P * V if P and V else n * R * T",
            "lambda P, V, n: P * V == n * R * T",
            "flag & (other | foo)",
            "not x == y",
            "x == (not y)",
            "yield x",
            "yield from x",
            "call((yield x))",
            "return x + (yield x)",
        )
        for source in sources:
            self.check_src_roundtrip(source)
323+
324+
284325
class DirectoryTestCase(ASTTestCase):
285326
"""Test roundtrip behaviour on all files in Lib and Lib/test."""
286327

0 commit comments

Comments
 (0)