Skip to content

bpo-38870: Implement Simple Preceding to AST Unparser #17377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 1, 2020
Merged
138 changes: 123 additions & 15 deletions Lib/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import sys
from _ast import *
from contextlib import contextmanager, nullcontext
from enum import IntEnum, auto


def parse(source, filename='<unknown>', mode='exec', *,
Expand Down Expand Up @@ -560,6 +561,35 @@ def __new__(cls, *args, **kwargs):
# We unparse those infinities to INFSTR.
_INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)

class _Precedence(IntEnum):
"""Precedence table that originated from python grammar."""

TUPLE = auto()
YIELD = auto() # 'yield', 'yield from'
TEST = auto() # 'if'-'else', 'lambda'
OR = auto() # 'or'
AND = auto() # 'and'
NOT = auto() # 'not'
CMP = auto() # '<', '>', '==', '>=', '<=', '!=',
# 'in', 'not in', 'is', 'is not'
EXPR = auto()
BOR = EXPR # '|'
BXOR = auto() # '^'
BAND = auto() # '&'
SHIFT = auto() # '<<', '>>'
ARITH = auto() # '+', '-'
TERM = auto() # '*', '@', '/', '%', '//'
FACTOR = auto() # unary '+', '-', '~'
POWER = auto() # '**'
AWAIT = auto() # 'await'
ATOM = auto()

def next(self):
try:
return self.__class__(self + 1)
except ValueError:
return self

class _Unparser(NodeVisitor):
"""Methods in this class recursively traverse an AST and
output source code for the abstract syntax; original formatting
Expand All @@ -568,6 +598,7 @@ class _Unparser(NodeVisitor):
def __init__(self):
self._source = []
self._buffer = []
self._precedences = {}
self._indent = 0

def interleave(self, inter, f, seq):
Expand Down Expand Up @@ -625,6 +656,17 @@ def delimit_if(self, start, end, condition):
else:
return nullcontext()

def require_parens(self, precedence, node):
"""Shortcut to adding precedence related parens"""
return self.delimit_if("(", ")", self.get_precedence(node) > precedence)

def get_precedence(self, node):
return self._precedences.get(node, _Precedence.TEST)

def set_precedence(self, precedence, *nodes):
for node in nodes:
self._precedences[node] = precedence

def traverse(self, node):
if isinstance(node, list):
for item in node:
Expand All @@ -645,10 +687,12 @@ def visit_Module(self, node):

def visit_Expr(self, node):
self.fill()
self.set_precedence(_Precedence.YIELD, node.value)
self.traverse(node.value)

def visit_NamedExpr(self, node):
with self.delimit("(", ")"):
with self.require_parens(_Precedence.TUPLE, node):
self.set_precedence(_Precedence.ATOM, node.target, node.value)
self.traverse(node.target)
self.write(" := ")
self.traverse(node.value)
Expand Down Expand Up @@ -723,24 +767,27 @@ def visit_Nonlocal(self, node):
self.interleave(lambda: self.write(", "), self.write, node.names)

def visit_Await(self, node):
with self.delimit("(", ")"):
with self.require_parens(_Precedence.AWAIT, node):
self.write("await")
if node.value:
self.write(" ")
self.set_precedence(_Precedence.ATOM, node.value)
self.traverse(node.value)

def visit_Yield(self, node):
with self.delimit("(", ")"):
with self.require_parens(_Precedence.YIELD, node):
self.write("yield")
if node.value:
self.write(" ")
self.set_precedence(_Precedence.ATOM, node.value)
self.traverse(node.value)

def visit_YieldFrom(self, node):
with self.delimit("(", ")"):
with self.require_parens(_Precedence.YIELD, node):
self.write("yield from ")
if not node.value:
raise ValueError("Node can't be used without a value attribute.")
self.set_precedence(_Precedence.ATOM, node.value)
self.traverse(node.value)

def visit_Raise(self, node):
Expand Down Expand Up @@ -907,7 +954,9 @@ def _fstring_Constant(self, node, write):

def _fstring_FormattedValue(self, node, write):
write("{")
expr = type(self)().visit(node.value).rstrip("\n")
unparser = type(self)()
unparser.set_precedence(_Precedence.TEST.next(), node.value)
expr = unparser.visit(node.value).rstrip("\n")
if expr.startswith("{"):
write(" ") # Separate pair of opening brackets as "{ {"
write(expr)
Expand Down Expand Up @@ -983,19 +1032,23 @@ def visit_comprehension(self, node):
self.write(" async for ")
else:
self.write(" for ")
self.set_precedence(_Precedence.TUPLE, node.target)
self.traverse(node.target)
self.write(" in ")
self.set_precedence(_Precedence.TEST.next(), node.iter, *node.ifs)
self.traverse(node.iter)
for if_clause in node.ifs:
self.write(" if ")
self.traverse(if_clause)

def visit_IfExp(self, node):
with self.delimit("(", ")"):
with self.require_parens(_Precedence.TEST, node):
self.set_precedence(_Precedence.TEST.next(), node.body, node.test)
self.traverse(node.body)
self.write(" if ")
self.traverse(node.test)
self.write(" else ")
self.set_precedence(_Precedence.TEST, node.orelse)
self.traverse(node.orelse)

def visit_Set(self, node):
Expand All @@ -1016,6 +1069,7 @@ def write_item(item):
# for dictionary unpacking operator in dicts {**{'y': 2}}
# see PEP 448 for details
self.write("**")
self.set_precedence(_Precedence.EXPR, v)
self.traverse(v)
else:
write_key_value_pair(k, v)
Expand All @@ -1035,11 +1089,20 @@ def visit_Tuple(self, node):
self.interleave(lambda: self.write(", "), self.traverse, node.elts)

unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
unop_precedence = {
"~": _Precedence.FACTOR,
"not": _Precedence.NOT,
"+": _Precedence.FACTOR,
"-": _Precedence.FACTOR
}

def visit_UnaryOp(self, node):
with self.delimit("(", ")"):
self.write(self.unop[node.op.__class__.__name__])
operator = self.unop[node.op.__class__.__name__]
operator_precedence = self.unop_precedence[operator]
with self.require_parens(operator_precedence, node):
self.write(operator)
self.write(" ")
self.set_precedence(operator_precedence, node.operand)
self.traverse(node.operand)

binop = {
Expand All @@ -1058,10 +1121,38 @@ def visit_UnaryOp(self, node):
"Pow": "**",
}

binop_precedence = {
"+": _Precedence.ARITH,
"-": _Precedence.ARITH,
"*": _Precedence.TERM,
"@": _Precedence.TERM,
"/": _Precedence.TERM,
"%": _Precedence.TERM,
"<<": _Precedence.SHIFT,
">>": _Precedence.SHIFT,
"|": _Precedence.BOR,
"^": _Precedence.BXOR,
"&": _Precedence.BAND,
"//": _Precedence.TERM,
"**": _Precedence.POWER,
}

binop_rassoc = frozenset(("**",))
def visit_BinOp(self, node):
with self.delimit("(", ")"):
operator = self.binop[node.op.__class__.__name__]
operator_precedence = self.binop_precedence[operator]
with self.require_parens(operator_precedence, node):
if operator in self.binop_rassoc:
left_precedence = operator_precedence.next()
right_precedence = operator_precedence
else:
left_precedence = operator_precedence
right_precedence = operator_precedence.next()

self.set_precedence(left_precedence, node.left)
self.traverse(node.left)
self.write(" " + self.binop[node.op.__class__.__name__] + " ")
self.write(f" {operator} ")
self.set_precedence(right_precedence, node.right)
self.traverse(node.right)

cmpops = {
Expand All @@ -1078,20 +1169,32 @@ def visit_BinOp(self, node):
}

def visit_Compare(self, node):
with self.delimit("(", ")"):
with self.require_parens(_Precedence.CMP, node):
self.set_precedence(_Precedence.CMP.next(), node.left, *node.comparators)
self.traverse(node.left)
for o, e in zip(node.ops, node.comparators):
self.write(" " + self.cmpops[o.__class__.__name__] + " ")
self.traverse(e)

boolops = {"And": "and", "Or": "or"}
boolop_precedence = {"and": _Precedence.AND, "or": _Precedence.OR}

def visit_BoolOp(self, node):
with self.delimit("(", ")"):
s = " %s " % self.boolops[node.op.__class__.__name__]
self.interleave(lambda: self.write(s), self.traverse, node.values)
operator = self.boolops[node.op.__class__.__name__]
operator_precedence = self.boolop_precedence[operator]

def increasing_level_traverse(node):
nonlocal operator_precedence
operator_precedence = operator_precedence.next()
self.set_precedence(operator_precedence, node)
self.traverse(node)

with self.require_parens(operator_precedence, node):
s = f" {operator} "
self.interleave(lambda: self.write(s), increasing_level_traverse, node.values)

def visit_Attribute(self, node):
self.set_precedence(_Precedence.ATOM, node.value)
self.traverse(node.value)
# Special case: 3.__abs__() is a syntax error, so if node.value
# is an integer literal then we need to either parenthesize
Expand All @@ -1102,6 +1205,7 @@ def visit_Attribute(self, node):
self.write(node.attr)

def visit_Call(self, node):
self.set_precedence(_Precedence.ATOM, node.func)
self.traverse(node.func)
with self.delimit("(", ")"):
comma = False
Expand All @@ -1119,18 +1223,21 @@ def visit_Call(self, node):
self.traverse(e)

def visit_Subscript(self, node):
self.set_precedence(_Precedence.ATOM, node.value)
self.traverse(node.value)
with self.delimit("[", "]"):
self.traverse(node.slice)

def visit_Starred(self, node):
self.write("*")
self.set_precedence(_Precedence.EXPR, node.value)
self.traverse(node.value)

def visit_Ellipsis(self, node):
self.write("...")

def visit_Index(self, node):
self.set_precedence(_Precedence.TUPLE, node.value)
self.traverse(node.value)

def visit_Slice(self, node):
Expand Down Expand Up @@ -1212,10 +1319,11 @@ def visit_keyword(self, node):
self.traverse(node.value)

def visit_Lambda(self, node):
with self.delimit("(", ")"):
with self.require_parens(_Precedence.TEST, node):
self.write("lambda ")
self.traverse(node.args)
self.write(": ")
self.set_precedence(_Precedence.TEST, node.body)
self.traverse(node.body)

def visit_alias(self, node):
Expand Down
9 changes: 8 additions & 1 deletion Lib/test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,13 @@ def to_tuple(t):

class AST_Tests(unittest.TestCase):

def _is_ast_node(self, name, node):
if not isinstance(node, type):
return False
if "ast" not in node.__module__:
return False
return name != 'AST' and name[0].isupper()

def _assertTrueorder(self, ast_node, parent_pos):
if not isinstance(ast_node, ast.AST) or ast_node._fields is None:
return
Expand Down Expand Up @@ -331,7 +338,7 @@ def test_base_classes(self):

def test_field_attr_existence(self):
for name, item in ast.__dict__.items():
if isinstance(item, type) and name != 'AST' and name[0].isupper():
if self._is_ast_node(name, item):
x = item()
if isinstance(x, ast.AST):
self.assertEqual(type(x._fields), tuple)
Expand Down
41 changes: 41 additions & 0 deletions Lib/test/test_unparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,13 @@ def check_roundtrip(self, code1):
def check_invalid(self, node, raises=ValueError):
self.assertRaises(raises, ast.unparse, node)

def check_src_roundtrip(self, code1, code2=None, strip=True):
code2 = code2 or code1
code1 = ast.unparse(ast.parse(code1))
if strip:
code1 = code1.strip()
self.assertEqual(code2, code1)


class UnparseTestCase(ASTTestCase):
# Tests for specific bugs found in earlier versions of unparse
Expand Down Expand Up @@ -281,6 +288,40 @@ def test_invalid_set(self):
def test_invalid_yield_from(self):
self.check_invalid(ast.YieldFrom(value=None))

class CosmeticTestCase(ASTTestCase):
"""Test if there are cosmetic issues caused by unnecesary additions"""

def test_simple_expressions_parens(self):
self.check_src_roundtrip("(a := b)")
self.check_src_roundtrip("await x")
self.check_src_roundtrip("x if x else y")
self.check_src_roundtrip("lambda x: x")
self.check_src_roundtrip("1 + 1")
self.check_src_roundtrip("1 + 2 / 3")
self.check_src_roundtrip("(1 + 2) / 3")
self.check_src_roundtrip("(1 + 2) * 3 + 4 * (5 + 2)")
self.check_src_roundtrip("(1 + 2) * 3 + 4 * (5 + 2) ** 2")
self.check_src_roundtrip("~ x")
self.check_src_roundtrip("x and y")
self.check_src_roundtrip("x and y and z")
self.check_src_roundtrip("x and (y and x)")
self.check_src_roundtrip("(x and y) and z")
self.check_src_roundtrip("(x ** y) ** z ** q")
self.check_src_roundtrip("x >> y")
self.check_src_roundtrip("x << y")
self.check_src_roundtrip("x >> y and x >> z")
self.check_src_roundtrip("x + y - z * q ^ t ** k")
self.check_src_roundtrip("P * V if P and V else n * R * T")
self.check_src_roundtrip("lambda P, V, n: P * V == n * R * T")
self.check_src_roundtrip("flag & (other | foo)")
self.check_src_roundtrip("not x == y")
self.check_src_roundtrip("x == (not y)")
self.check_src_roundtrip("yield x")
self.check_src_roundtrip("yield from x")
self.check_src_roundtrip("call((yield x))")
self.check_src_roundtrip("return x + (yield x)")


class DirectoryTestCase(ASTTestCase):
"""Test roundtrip behaviour on all files in Lib and Lib/test."""

Expand Down