Skip to content

[mlir][utils] Update generate-test-checks.py (use SSA names) #136819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 25, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions mlir/utils/generate-test-checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,20 @@ def generate_in_parent_scope(self, n):
self.generate_in_parent_scope_left = n

# Generate a substitution name for the given ssa value name.
def generate_name(self, source_variable_name):
def generate_name(self, source_variable_name, use_ssa_name):

# Compute variable name
variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else ''
if variable_name == '':
variable_name = "VAL_" + str(self.name_counter)
self.name_counter += 1
# If `use_ssa_name` is set, use the MLIR SSA value name to generate
# a FileCHeck substation string. As FileCheck requires these
# strings to start with a character, skip MLIR variables starting
# with a digit (e.g. `%0`).
if use_ssa_name and source_variable_name[0].isalpha():
variable_name = source_variable_name.upper()
else:
variable_name = "VAL_" + str(self.name_counter)
self.name_counter += 1

# Scope where variable name is saved
scope = len(self.scopes) - 1
Expand Down Expand Up @@ -158,7 +165,7 @@ def get_num_ssa_results(input_line):


# Process a line of input that has been split at each SSA identifier '%'.
def process_line(line_chunks, variable_namer, strict_name_re=False):
def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re=False):
output_line = ""

# Process the rest that contained an SSA value name.
Expand All @@ -178,7 +185,7 @@ def process_line(line_chunks, variable_namer, strict_name_re=False):
output_line += "%[[" + variable + "]]"
else:
# Otherwise, generate a new variable.
variable = variable_namer.generate_name(ssa_name)
variable = variable_namer.generate_name(ssa_name, use_ssa_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So for an SSA value like %0, will it create a match like %[[0:.*]]?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, though it's worth noting that this block of code is only effectively used for function declarations (i.e., CHECK-SAME labels). So if a function argument is %0, then yes - the auto-generated LIT label would be %[[0:.*]].

However, in practice you'd typically use generate-test-checks.py like this:

$ mlir-opt --some-transformation file.mlir | generate-test-checks.py

Since mlir-opt renames function arguments to something like %arg0, you'd end up with LIT labels like %[[ARG0:.*]]. That already feels like an improvement over %[[VAL_0:.*]], but still not quite as readable as %[[FILTER:.*]] from my original example - which is what I’d ideally like to achieve.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaining!

To be on the safe side I suggest making sure that the variable name is a valid FileCheck identifier, e.g. %[[0:.*]] is invalid.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, updated!

if strict_name_re:
# Use stricter regexp for the variable name, if requested.
# Greedy matching may cause issues with the generic '.*'
Expand Down Expand Up @@ -415,9 +422,11 @@ def main():
pad_depth = label_length if label_length < 21 else 4
output_line += " " * pad_depth

# Process the rest of the line.
# Process the rest of the line. Use the original SSA name to generate the LIT
# variable names.
use_ssa_names = True
output_line += process_line(
[argument], variable_namer, args.strict_name_re
[argument], variable_namer, use_ssa_names, args.strict_name_re
)

# Append the output line.
Expand Down
Loading