Skip to content

Commit a89c15a

Browse files
authored
[mlir][sparse] enable Python BSR test (#72325)
1 parent 8aed916 commit a89c15a

File tree

1 file changed

+49
-12
lines changed

1 file changed

+49
-12
lines changed

mlir/test/Integration/Dialect/SparseTensor/python/test_output.py

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,28 @@ def boilerplate(attr: st.EncodingAttr):
3232
def expected(id_map):
3333
"""Returns expected contents of output.
3434
35+
+-----+-----+-----+-----+-----+
36+
| 1 0 | . . | . . | . . | 0 3 |
37+
| 0 2 | . . | . . | . . | 0 0 |
38+
+-----+-----+-----+-----+-----+
39+
| . . | . . | . . | . . | . . |
40+
| . . | . . | . . | . . | . . |
41+
+-----+-----+-----+-----+-----+
42+
| . . | . . | 5 0 | . . | . . |
43+
| . . | . . | 0 0 | . . | . . |
44+
+-----+-----+-----+-----+-----+
45+
| . . | . . | . . | . . | . . |
46+
| . . | . . | . . | . . | . . |
47+
+-----+-----+-----+-----+-----+
48+
| 0 0 | . . | . . | . . | . . |
49+
| 4 0 | . . | . . | . . | . . |
50+
+-----+-----+-----+-----+-----+
51+
3552
Output appears as dimension coordinates but lexicographically
36-
sorted by level coordinates.
53+
sorted by level coordinates. For BSR, the blocks are filled.
3754
"""
38-
return (
39-
f"""# extended FROSTT format
55+
if id_map is 0:
56+
return f"""# extended FROSTT format
4057
2 5
4158
10 10
4259
1 1 1
@@ -45,8 +62,8 @@ def expected(id_map):
4562
5 5 5
4663
10 1 4
4764
"""
48-
if id_map
49-
else f"""# extended FROSTT format
65+
if id_map is 1:
66+
return f"""# extended FROSTT format
5067
2 5
5168
10 10
5269
1 1 1
@@ -55,7 +72,28 @@ def expected(id_map):
5572
5 5 5
5673
1 10 3
5774
"""
58-
)
75+
if id_map is 2:
76+
return f"""# extended FROSTT format
77+
2 16
78+
10 10
79+
1 1 1
80+
1 2 0
81+
2 1 0
82+
2 2 2
83+
1 9 0
84+
1 10 3
85+
2 9 0
86+
2 10 0
87+
5 5 5
88+
5 6 0
89+
6 5 0
90+
6 6 0
91+
9 1 0
92+
9 2 0
93+
10 1 4
94+
10 2 0
95+
"""
96+
raise AssertionError("unexpected id_map")
5997

6098

6199
def build_compile_and_run_output(attr: st.EncodingAttr, compiler, expected):
@@ -93,10 +131,10 @@ def main():
93131
[st.DimLevelType.compressed, st.DimLevelType.compressed],
94132
]
95133
orderings = [
96-
(ir.AffineMap.get_permutation([0, 1]), True),
97-
(ir.AffineMap.get_permutation([1, 0]), False),
134+
(ir.AffineMap.get_permutation([0, 1]), 0),
135+
(ir.AffineMap.get_permutation([1, 0]), 1),
98136
]
99-
bitwidths = [8, 16, 32, 64]
137+
bitwidths = [8, 64]
100138
compiler = sparse_compiler.SparseCompiler(
101139
options="", opt_level=2, shared_libs=[support_lib]
102140
)
@@ -135,11 +173,10 @@ def main():
135173
l3 = ir.AffineDimExpr.get(3)
136174
lvl2dim = ir.AffineMap.get(4, 0, [2 * l0 + l2, 2 * l1 + l3])
137175
attr = st.EncodingAttr.get(level, dim2lvl, lvl2dim, 0, 0)
138-
# TODO: enable this one CONVERSION on BSR is working
139-
# build_compile_and_run_output(attr, compiler, block_expected())
176+
build_compile_and_run_output(attr, compiler, expected(2))
140177
count = count + 1
141178

142-
# CHECK: Passed 33 tests
179+
# CHECK: Passed 17 tests
143180
print("Passed", count, "tests")
144181

145182

0 commit comments

Comments
 (0)