@@ -130,7 +130,43 @@ class Foo: pass
130
130
131
131
class ASTTestCase (unittest .TestCase ):
132
132
def assertASTEqual (self , ast1 , ast2 ):
133
- self .assertEqual (ast .dump (ast1 ), ast .dump (ast2 ))
133
+ # Ensure the comparisons start at an AST node
134
+ self .assertIsInstance (ast1 , ast .AST )
135
+ self .assertIsInstance (ast2 , ast .AST )
136
+
137
+ # An AST comparison routine modeled after ast.dump(), but
138
+ # instead of string building, it traverses the two trees
139
+ # in lock-step.
140
+ def traverse_compare (a , b , missing = object ()):
141
+ if type (a ) is not type (b ):
142
+ self .fail (f"{ type (a )!r} is not { type (b )!r} " )
143
+ if isinstance (a , ast .AST ):
144
+ for field in a ._fields :
145
+ value1 = getattr (a , field , missing )
146
+ value2 = getattr (b , field , missing )
147
+ # Singletons are equal by definition, so further
148
+ # testing can be skipped.
149
+ if value1 is not value2 :
150
+ traverse_compare (value1 , value2 )
151
+ elif isinstance (a , list ):
152
+ try :
153
+ for node1 , node2 in zip (a , b , strict = True ):
154
+ traverse_compare (node1 , node2 )
155
+ except ValueError :
156
+ # Attempt a "pretty" error ala assertSequenceEqual()
157
+ len1 = len (a )
158
+ len2 = len (b )
159
+ if len1 > len2 :
160
+ what = "First"
161
+ diff = len1 - len2
162
+ else :
163
+ what = "Second"
164
+ diff = len2 - len1
165
+ msg = f"{ what } list contains { diff } additional elements."
166
+ raise self .failureException (msg ) from None
167
+ elif a != b :
168
+ self .fail (f"{ a !r} != { b !r} " )
169
+ traverse_compare (ast1 , ast2 )
134
170
135
171
def check_ast_roundtrip (self , code1 , ** kwargs ):
136
172
with self .subTest (code1 = code1 , ast_parse_kwargs = kwargs ):
0 commit comments