Skip to content

Commit 2124fa4

Browse files
committed
[mlir][utils] Add script to verify canonicalizations agains Alive2
This script takes IR before and after canonicalization, translates it into llvm IR and converts to format suitablle for Alive2 https://alive2.llvm.org/ce/ This is primarily for arith canonicalizations verification, but technically it can be adapted for any dialect translatable to llvm. Usage `python verify_canon.py canonicalize.mlir func1 func2 ...` Example output: https://alive2.llvm.org/ce/z/KhQs4J Initial discussion: llvm#91646 (review)
1 parent 7c1b289 commit 2124fa4

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
# Run canonicalization, convert IR to LLVM and convert to format suitable to
6+
# verification against Alive2 https://alive2.llvm.org/ce/.
7+
8+
import subprocess
9+
import tempfile
10+
import sys
11+
from pathlib import Path
12+
13+
14+
def filter_funcs(ir, funcs):
15+
if not funcs:
16+
return ir
17+
18+
funcs_str = ",".join(funcs)
19+
return subprocess.check_output(
20+
["mlir-opt", f"--symbol-privatize=exclude={funcs_str}", "--symbol-dce"],
21+
input=ir,
22+
)
23+
24+
25+
def add_func_prefix(src, prefix):
26+
return src.replace("@", "@" + prefix)
27+
28+
29+
def merge_ir(chunks):
30+
files = []
31+
for chunk in chunks:
32+
tmp = tempfile.NamedTemporaryFile(suffix=".ll")
33+
tmp.write(chunk)
34+
tmp.flush()
35+
files.append(tmp)
36+
37+
return subprocess.check_output(["llvm-link", "-S"] + [f.name for f in files])
38+
39+
40+
if __name__ == "__main__":
41+
argv = sys.argv
42+
if len(argv) < 2:
43+
print(f"usage: {argv[0]} canonicalize.mlir [func1] [func2] ...")
44+
exit(0)
45+
46+
file = argv[1]
47+
funcs = argv[2:]
48+
49+
orig_ir = Path(file).read_bytes()
50+
orig_ir = filter_funcs(orig_ir, funcs)
51+
52+
to_llvm_args = [
53+
"--convert-arith-to-llvm",
54+
"--convert-func-to-llvm",
55+
"--convert-ub-to-llvm",
56+
"--convert-vector-to-llvm",
57+
]
58+
orig_args = ["mlir-opt"] + to_llvm_args
59+
canon_args = ["mlir-opt", "-canonicalize"] + to_llvm_args
60+
translate_args = ["mlir-translate", "-mlir-to-llvmir"]
61+
62+
orig = subprocess.check_output(orig_args, input=orig_ir)
63+
canonicalized = subprocess.check_output(canon_args, input=orig_ir)
64+
65+
orig = subprocess.check_output(translate_args, input=orig)
66+
canonicalized = subprocess.check_output(translate_args, input=canonicalized)
67+
68+
enc = "utf-8"
69+
orig = bytes(add_func_prefix(orig.decode(enc), "src_"), enc)
70+
canonicalized = bytes(add_func_prefix(canonicalized.decode(enc), "tgt_"), enc)
71+
72+
res = merge_ir([orig, canonicalized])
73+
74+
print(res.decode(enc))

0 commit comments

Comments
 (0)