27
27
import sys
28
28
from _ast import *
29
29
from contextlib import contextmanager , nullcontext
30
+ from enum import IntEnum , auto
30
31
31
32
32
33
def parse (source , filename = '<unknown>' , mode = 'exec' , * ,
@@ -560,6 +561,35 @@ def __new__(cls, *args, **kwargs):
560
561
# We unparse those infinities to INFSTR.
561
562
_INFSTR = "1e" + repr (sys .float_info .max_10_exp + 1 )
562
563
564
+ class _Precedence (IntEnum ):
565
+ """Precedence table that originated from python grammar."""
566
+
567
+ TUPLE = auto ()
568
+ YIELD = auto () # 'yield', 'yield from'
569
+ TEST = auto () # 'if'-'else', 'lambda'
570
+ OR = auto () # 'or'
571
+ AND = auto () # 'and'
572
+ NOT = auto () # 'not'
573
+ CMP = auto () # '<', '>', '==', '>=', '<=', '!=',
574
+ # 'in', 'not in', 'is', 'is not'
575
+ EXPR = auto ()
576
+ BOR = EXPR # '|'
577
+ BXOR = auto () # '^'
578
+ BAND = auto () # '&'
579
+ SHIFT = auto () # '<<', '>>'
580
+ ARITH = auto () # '+', '-'
581
+ TERM = auto () # '*', '@', '/', '%', '//'
582
+ FACTOR = auto () # unary '+', '-', '~'
583
+ POWER = auto () # '**'
584
+ AWAIT = auto () # 'await'
585
+ ATOM = auto ()
586
+
587
+ def next (self ):
588
+ try :
589
+ return self .__class__ (self + 1 )
590
+ except ValueError :
591
+ return self
592
+
563
593
class _Unparser (NodeVisitor ):
564
594
"""Methods in this class recursively traverse an AST and
565
595
output source code for the abstract syntax; original formatting
@@ -568,6 +598,7 @@ class _Unparser(NodeVisitor):
568
598
def __init__ (self ):
569
599
self ._source = []
570
600
self ._buffer = []
601
+ self ._precedences = {}
571
602
self ._indent = 0
572
603
573
604
def interleave (self , inter , f , seq ):
@@ -625,6 +656,17 @@ def delimit_if(self, start, end, condition):
625
656
else :
626
657
return nullcontext ()
627
658
659
+ def require_parens (self , precedence , node ):
660
+ """Shortcut to adding precedence related parens"""
661
+ return self .delimit_if ("(" , ")" , self .get_precedence (node ) > precedence )
662
+
663
+ def get_precedence (self , node ):
664
+ return self ._precedences .get (node , _Precedence .TEST )
665
+
666
+ def set_precedence (self , precedence , * nodes ):
667
+ for node in nodes :
668
+ self ._precedences [node ] = precedence
669
+
628
670
def traverse (self , node ):
629
671
if isinstance (node , list ):
630
672
for item in node :
@@ -645,10 +687,12 @@ def visit_Module(self, node):
645
687
646
688
def visit_Expr (self , node ):
647
689
self .fill ()
690
+ self .set_precedence (_Precedence .YIELD , node .value )
648
691
self .traverse (node .value )
649
692
650
693
def visit_NamedExpr (self , node ):
651
- with self .delimit ("(" , ")" ):
694
+ with self .require_parens (_Precedence .TUPLE , node ):
695
+ self .set_precedence (_Precedence .ATOM , node .target , node .value )
652
696
self .traverse (node .target )
653
697
self .write (" := " )
654
698
self .traverse (node .value )
@@ -723,24 +767,27 @@ def visit_Nonlocal(self, node):
723
767
self .interleave (lambda : self .write (", " ), self .write , node .names )
724
768
725
769
def visit_Await (self , node ):
726
- with self .delimit ( "(" , ")" ):
770
+ with self .require_parens ( _Precedence . AWAIT , node ):
727
771
self .write ("await" )
728
772
if node .value :
729
773
self .write (" " )
774
+ self .set_precedence (_Precedence .ATOM , node .value )
730
775
self .traverse (node .value )
731
776
732
777
def visit_Yield (self , node ):
733
- with self .delimit ( "(" , ")" ):
778
+ with self .require_parens ( _Precedence . YIELD , node ):
734
779
self .write ("yield" )
735
780
if node .value :
736
781
self .write (" " )
782
+ self .set_precedence (_Precedence .ATOM , node .value )
737
783
self .traverse (node .value )
738
784
739
785
def visit_YieldFrom (self , node ):
740
- with self .delimit ( "(" , ")" ):
786
+ with self .require_parens ( _Precedence . YIELD , node ):
741
787
self .write ("yield from " )
742
788
if not node .value :
743
789
raise ValueError ("Node can't be used without a value attribute." )
790
+ self .set_precedence (_Precedence .ATOM , node .value )
744
791
self .traverse (node .value )
745
792
746
793
def visit_Raise (self , node ):
@@ -907,7 +954,9 @@ def _fstring_Constant(self, node, write):
907
954
908
955
def _fstring_FormattedValue (self , node , write ):
909
956
write ("{" )
910
- expr = type (self )().visit (node .value ).rstrip ("\n " )
957
+ unparser = type (self )()
958
+ unparser .set_precedence (_Precedence .TEST .next (), node .value )
959
+ expr = unparser .visit (node .value ).rstrip ("\n " )
911
960
if expr .startswith ("{" ):
912
961
write (" " ) # Separate pair of opening brackets as "{ {"
913
962
write (expr )
@@ -983,19 +1032,23 @@ def visit_comprehension(self, node):
983
1032
self .write (" async for " )
984
1033
else :
985
1034
self .write (" for " )
1035
+ self .set_precedence (_Precedence .TUPLE , node .target )
986
1036
self .traverse (node .target )
987
1037
self .write (" in " )
1038
+ self .set_precedence (_Precedence .TEST .next (), node .iter , * node .ifs )
988
1039
self .traverse (node .iter )
989
1040
for if_clause in node .ifs :
990
1041
self .write (" if " )
991
1042
self .traverse (if_clause )
992
1043
993
1044
def visit_IfExp (self , node ):
994
- with self .delimit ("(" , ")" ):
1045
+ with self .require_parens (_Precedence .TEST , node ):
1046
+ self .set_precedence (_Precedence .TEST .next (), node .body , node .test )
995
1047
self .traverse (node .body )
996
1048
self .write (" if " )
997
1049
self .traverse (node .test )
998
1050
self .write (" else " )
1051
+ self .set_precedence (_Precedence .TEST , node .orelse )
999
1052
self .traverse (node .orelse )
1000
1053
1001
1054
def visit_Set (self , node ):
@@ -1016,6 +1069,7 @@ def write_item(item):
1016
1069
# for dictionary unpacking operator in dicts {**{'y': 2}}
1017
1070
# see PEP 448 for details
1018
1071
self .write ("**" )
1072
+ self .set_precedence (_Precedence .EXPR , v )
1019
1073
self .traverse (v )
1020
1074
else :
1021
1075
write_key_value_pair (k , v )
@@ -1035,11 +1089,20 @@ def visit_Tuple(self, node):
1035
1089
self .interleave (lambda : self .write (", " ), self .traverse , node .elts )
1036
1090
1037
1091
unop = {"Invert" : "~" , "Not" : "not" , "UAdd" : "+" , "USub" : "-" }
1092
+ unop_precedence = {
1093
+ "~" : _Precedence .FACTOR ,
1094
+ "not" : _Precedence .NOT ,
1095
+ "+" : _Precedence .FACTOR ,
1096
+ "-" : _Precedence .FACTOR
1097
+ }
1038
1098
1039
1099
def visit_UnaryOp (self , node ):
1040
- with self .delimit ("(" , ")" ):
1041
- self .write (self .unop [node .op .__class__ .__name__ ])
1100
+ operator = self .unop [node .op .__class__ .__name__ ]
1101
+ operator_precedence = self .unop_precedence [operator ]
1102
+ with self .require_parens (operator_precedence , node ):
1103
+ self .write (operator )
1042
1104
self .write (" " )
1105
+ self .set_precedence (operator_precedence , node .operand )
1043
1106
self .traverse (node .operand )
1044
1107
1045
1108
binop = {
@@ -1058,10 +1121,38 @@ def visit_UnaryOp(self, node):
1058
1121
"Pow" : "**" ,
1059
1122
}
1060
1123
1124
+ binop_precedence = {
1125
+ "+" : _Precedence .ARITH ,
1126
+ "-" : _Precedence .ARITH ,
1127
+ "*" : _Precedence .TERM ,
1128
+ "@" : _Precedence .TERM ,
1129
+ "/" : _Precedence .TERM ,
1130
+ "%" : _Precedence .TERM ,
1131
+ "<<" : _Precedence .SHIFT ,
1132
+ ">>" : _Precedence .SHIFT ,
1133
+ "|" : _Precedence .BOR ,
1134
+ "^" : _Precedence .BXOR ,
1135
+ "&" : _Precedence .BAND ,
1136
+ "//" : _Precedence .TERM ,
1137
+ "**" : _Precedence .POWER ,
1138
+ }
1139
+
1140
+ binop_rassoc = frozenset (("**" ,))
1061
1141
def visit_BinOp (self , node ):
1062
- with self .delimit ("(" , ")" ):
1142
+ operator = self .binop [node .op .__class__ .__name__ ]
1143
+ operator_precedence = self .binop_precedence [operator ]
1144
+ with self .require_parens (operator_precedence , node ):
1145
+ if operator in self .binop_rassoc :
1146
+ left_precedence = operator_precedence .next ()
1147
+ right_precedence = operator_precedence
1148
+ else :
1149
+ left_precedence = operator_precedence
1150
+ right_precedence = operator_precedence .next ()
1151
+
1152
+ self .set_precedence (left_precedence , node .left )
1063
1153
self .traverse (node .left )
1064
- self .write (" " + self .binop [node .op .__class__ .__name__ ] + " " )
1154
+ self .write (f" { operator } " )
1155
+ self .set_precedence (right_precedence , node .right )
1065
1156
self .traverse (node .right )
1066
1157
1067
1158
cmpops = {
@@ -1078,20 +1169,32 @@ def visit_BinOp(self, node):
1078
1169
}
1079
1170
1080
1171
def visit_Compare (self , node ):
1081
- with self .delimit ("(" , ")" ):
1172
+ with self .require_parens (_Precedence .CMP , node ):
1173
+ self .set_precedence (_Precedence .CMP .next (), node .left , * node .comparators )
1082
1174
self .traverse (node .left )
1083
1175
for o , e in zip (node .ops , node .comparators ):
1084
1176
self .write (" " + self .cmpops [o .__class__ .__name__ ] + " " )
1085
1177
self .traverse (e )
1086
1178
1087
1179
boolops = {"And" : "and" , "Or" : "or" }
1180
+ boolop_precedence = {"and" : _Precedence .AND , "or" : _Precedence .OR }
1088
1181
1089
1182
def visit_BoolOp (self , node ):
1090
- with self .delimit ("(" , ")" ):
1091
- s = " %s " % self .boolops [node .op .__class__ .__name__ ]
1092
- self .interleave (lambda : self .write (s ), self .traverse , node .values )
1183
+ operator = self .boolops [node .op .__class__ .__name__ ]
1184
+ operator_precedence = self .boolop_precedence [operator ]
1185
+
1186
+ def increasing_level_traverse (node ):
1187
+ nonlocal operator_precedence
1188
+ operator_precedence = operator_precedence .next ()
1189
+ self .set_precedence (operator_precedence , node )
1190
+ self .traverse (node )
1191
+
1192
+ with self .require_parens (operator_precedence , node ):
1193
+ s = f" { operator } "
1194
+ self .interleave (lambda : self .write (s ), increasing_level_traverse , node .values )
1093
1195
1094
1196
def visit_Attribute (self , node ):
1197
+ self .set_precedence (_Precedence .ATOM , node .value )
1095
1198
self .traverse (node .value )
1096
1199
# Special case: 3.__abs__() is a syntax error, so if node.value
1097
1200
# is an integer literal then we need to either parenthesize
@@ -1102,6 +1205,7 @@ def visit_Attribute(self, node):
1102
1205
self .write (node .attr )
1103
1206
1104
1207
def visit_Call (self , node ):
1208
+ self .set_precedence (_Precedence .ATOM , node .func )
1105
1209
self .traverse (node .func )
1106
1210
with self .delimit ("(" , ")" ):
1107
1211
comma = False
@@ -1119,18 +1223,21 @@ def visit_Call(self, node):
1119
1223
self .traverse (e )
1120
1224
1121
1225
def visit_Subscript (self , node ):
1226
+ self .set_precedence (_Precedence .ATOM , node .value )
1122
1227
self .traverse (node .value )
1123
1228
with self .delimit ("[" , "]" ):
1124
1229
self .traverse (node .slice )
1125
1230
1126
1231
def visit_Starred (self , node ):
1127
1232
self .write ("*" )
1233
+ self .set_precedence (_Precedence .EXPR , node .value )
1128
1234
self .traverse (node .value )
1129
1235
1130
1236
def visit_Ellipsis (self , node ):
1131
1237
self .write ("..." )
1132
1238
1133
1239
def visit_Index (self , node ):
1240
+ self .set_precedence (_Precedence .TUPLE , node .value )
1134
1241
self .traverse (node .value )
1135
1242
1136
1243
def visit_Slice (self , node ):
@@ -1212,10 +1319,11 @@ def visit_keyword(self, node):
1212
1319
self .traverse (node .value )
1213
1320
1214
1321
def visit_Lambda (self , node ):
1215
- with self .delimit ( "(" , ")" ):
1322
+ with self .require_parens ( _Precedence . TEST , node ):
1216
1323
self .write ("lambda " )
1217
1324
self .traverse (node .args )
1218
1325
self .write (": " )
1326
+ self .set_precedence (_Precedence .TEST , node .body )
1219
1327
self .traverse (node .body )
1220
1328
1221
1329
def visit_alias (self , node ):
0 commit comments