Skip to content

bpo-47131: Speedup AST comparisons in test_unparse by using node traversal #32132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 2, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion Lib/test/test_unparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,43 @@ class Foo: pass

class ASTTestCase(unittest.TestCase):
def assertASTEqual(self, ast1, ast2):
self.assertEqual(ast.dump(ast1), ast.dump(ast2))
# Ensure the comparisons start at an AST node
self.assertIsInstance(ast1, ast.AST)
self.assertIsInstance(ast2, ast.AST)

# An AST comparison routine modeled after ast.dump(), but
# instead of string building, it traverses the two trees
# in lock-step.
def traverse_compare(a, b, missing=object()):
Copy link
Member

@vstinner vstinner Mar 31, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you moved missing inside the function definition, now you may move the function definition at the class level (method) of the module level (function, you should pass the testcase as an argument in this case), to avoid the cost of defining a new function at each assertASTEqual() call.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you are concerned about performance, using an inner function is faster than calling a method. True, the setup cost is greater, but that time is more than recovered due to attribute lookup after a few hundred comparisons. Global (module level) functions suffer the same fate as well but require more comparisons before getting overtaken.

After instrumenting test_unparse to get the number of calls to assertASTEqual and compares, I whipped up a reproducible minimal test script to show the timings. The script is here for you to run yourself.

Using main on Win10:
100 -> inner: 29 usec method: 27.4 usec
200 -> inner: 56.8 usec method: 48 usec
300 -> inner: 76.3 usec method: 86.4 usec
CosmeticTestCase (14 compares)
inner: 88.9 usec
method: 77.9 usec
DirectoryTestCase (400 compares)
inner: 74.1 msec
method: 79.5 msec
UnparseTestCase (400 compares)
inner: 91 msec
method: 97.6 msec

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I wrote previously, if there is a good reason to define a "local" (nested?) function, please add a comment here for future readers. Otherwise, someone who doesn't like nested functions can move the function (during a random refactoring) without knowing that you ran benchmarks and the code is faster when written like that. It's very easy to lose information when code is read/modified 5, 10 or 20 years later. And yes, it happens to me often to read 20 years old Python code and I have to dig into Git history (with unpleasant issues about CVS, SVN and HG commit numbers) and very old bug tracker issues to attempt to rebuild the rationale for the existing code before feeling safe to change it.

The performance doesn't make sense (method/module level code) to me. It should be as fast or faster. But my only god are benchmarks: I only trust benchmarks numbers :-)

if type(a) is not type(b):
self.fail(f"{type(a)!r} is not {type(b)!r}")
if isinstance(a, ast.AST):
for field in a._fields:
value1 = getattr(a, field, missing)
value2 = getattr(b, field, missing)
# Singletons are equal by definition, so further
# testing can be skipped.
if value1 is not value2:
traverse_compare(value1, value2)
elif isinstance(a, list):
try:
for node1, node2 in zip(a, b, strict=True):
traverse_compare(node1, node2)
except ValueError:
# Attempt a "pretty" error ala assertSequenceEqual()
len1 = len(a)
len2 = len(b)
if len1 > len2:
what = "First"
diff = len1 - len2
else:
what = "Second"
diff = len2 - len1
msg = f"{what} list contains {diff} additional elements."
raise self.failureException(msg) from None
elif a != b:
self.fail(f"{a!r} != {b!r}")
traverse_compare(ast1, ast2)

def check_ast_roundtrip(self, code1, **kwargs):
with self.subTest(code1=code1, ast_parse_kwargs=kwargs):
Expand Down