
Commit 1f3c482

[mlir][sparse] more test cases for linalg.index
Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D121660
1 parent c62746a
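For context, linalg.index d yields the current iteration index along dimension d of the enclosing linalg.generic; the new kernels cast it to i64 and fold it into the computation, either multiplicatively (conjunction) or additively (disjunction). A minimal dense sketch of the pattern under test, with a made-up trait and function name used purely for illustration (not part of this commit):

    #trait_fill = {
      indexing_maps = [ affine_map<(i) -> (i)> ],  // x (out)
      iterator_types = ["parallel"]
    }
    // Fills x(i) = i by reading the iteration index inside the body.
    func @fill_with_index(%argx: tensor<8xi64>) -> tensor<8xi64> {
      %r = linalg.generic #trait_fill outs(%argx: tensor<8xi64>) {
        ^bb(%x: i64):
          %i = linalg.index 0 : index              // index along dimension 0
          %ii = arith.index_cast %i : index to i64
          linalg.yield %ii : i64
      } -> tensor<8xi64>
      return %r : tensor<8xi64>
    }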


mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir

Lines changed: 161 additions & 24 deletions
@@ -3,30 +3,85 @@
 // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
 
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"]
+}>
+
 #SparseMatrix = #sparse_tensor.encoding<{
   dimLevelType = ["compressed", "compressed"]
 }>
 
-#trait = {
+#trait_1d = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "X(i) = a(i) op i"
+}
+
+#trait_2d = {
   indexing_maps = [
     affine_map<(i,j) -> (i,j)>, // A
     affine_map<(i,j) -> (i,j)>  // X (out)
   ],
   iterator_types = ["parallel", "parallel"],
-  doc = "X(i,j) = A(i,j) * i * j"
+  doc = "X(i,j) = A(i,j) op i op j"
 }
 
+//
+// Test with indices. Note that a lot of results are actually
+// dense, but this is done to stress test all the operations.
+//
 module {
 
   //
-  // Kernel that uses indices in the index notation.
+  // Kernel that uses index in the index notation (conjunction).
+  //
+  func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>)
+      -> tensor<8xi64, #SparseVector> {
+    %d0 = arith.constant 8 : index
+    %init = sparse_tensor.init [%d0] : tensor<8xi64, #SparseVector>
+    %r = linalg.generic #trait_1d
+        ins(%arga: tensor<8xi64, #SparseVector>)
+       outs(%init: tensor<8xi64, #SparseVector>) {
+        ^bb(%a: i64, %x: i64):
+          %i = linalg.index 0 : index
+          %ii = arith.index_cast %i : index to i64
+          %m1 = arith.muli %a, %ii : i64
+          linalg.yield %m1 : i64
+    } -> tensor<8xi64, #SparseVector>
+    return %r : tensor<8xi64, #SparseVector>
+  }
+
+  //
+  // Kernel that uses index in the index notation (disjunction).
+  //
+  func @sparse_index_1d_disj(%arga: tensor<8xi64, #SparseVector>)
+      -> tensor<8xi64, #SparseVector> {
+    %d0 = arith.constant 8 : index
+    %init = sparse_tensor.init [%d0] : tensor<8xi64, #SparseVector>
+    %r = linalg.generic #trait_1d
+        ins(%arga: tensor<8xi64, #SparseVector>)
+       outs(%init: tensor<8xi64, #SparseVector>) {
+        ^bb(%a: i64, %x: i64):
+          %i = linalg.index 0 : index
+          %ii = arith.index_cast %i : index to i64
+          %m1 = arith.addi %a, %ii : i64
+          linalg.yield %m1 : i64
+    } -> tensor<8xi64, #SparseVector>
+    return %r : tensor<8xi64, #SparseVector>
+  }
+
+  //
+  // Kernel that uses indices in the index notation (conjunction).
   //
-  func @sparse_index(%arga: tensor<3x4xi64, #SparseMatrix>)
-      -> tensor<3x4xi64, #SparseMatrix> {
+  func @sparse_index_2d_conj(%arga: tensor<3x4xi64, #SparseMatrix>)
+      -> tensor<3x4xi64, #SparseMatrix> {
     %d0 = arith.constant 3 : index
     %d1 = arith.constant 4 : index
     %init = sparse_tensor.init [%d0, %d1] : tensor<3x4xi64, #SparseMatrix>
-    %r = linalg.generic #trait
+    %r = linalg.generic #trait_2d
         ins(%arga: tensor<3x4xi64, #SparseMatrix>)
        outs(%init: tensor<3x4xi64, #SparseMatrix>) {
         ^bb(%a: i64, %x: i64):
@@ -41,40 +96,122 @@ module {
     return %r : tensor<3x4xi64, #SparseMatrix>
   }
 
+  //
+  // Kernel that uses indices in the index notation (disjunction).
+  //
+  func @sparse_index_2d_disj(%arga: tensor<3x4xi64, #SparseMatrix>)
+      -> tensor<3x4xi64, #SparseMatrix> {
+    %d0 = arith.constant 3 : index
+    %d1 = arith.constant 4 : index
+    %init = sparse_tensor.init [%d0, %d1] : tensor<3x4xi64, #SparseMatrix>
+    %r = linalg.generic #trait_2d
+        ins(%arga: tensor<3x4xi64, #SparseMatrix>)
+       outs(%init: tensor<3x4xi64, #SparseMatrix>) {
+        ^bb(%a: i64, %x: i64):
+          %i = linalg.index 0 : index
+          %j = linalg.index 1 : index
+          %ii = arith.index_cast %i : index to i64
+          %jj = arith.index_cast %j : index to i64
+          %m1 = arith.addi %ii, %a : i64
+          %m2 = arith.addi %jj, %m1 : i64
+          linalg.yield %m2 : i64
+    } -> tensor<3x4xi64, #SparseMatrix>
+    return %r : tensor<3x4xi64, #SparseMatrix>
+  }
+
   //
   // Main driver.
   //
   func @entry() {
     %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c4 = arith.constant 4 : index
     %du = arith.constant -1 : i64
 
+    // Setup input sparse vector.
+    %v1 = arith.constant sparse<[[2], [4]], [ 10, 20]> : tensor<8xi64>
+    %sv = sparse_tensor.convert %v1 : tensor<8xi64> to tensor<8xi64, #SparseVector>
+
+    // Setup input "sparse" vector.
+    %v2 = arith.constant dense<[ 1, 2, 4, 8, 16, 32, 64, 128 ]> : tensor<8xi64>
+    %dv = sparse_tensor.convert %v2 : tensor<8xi64> to tensor<8xi64, #SparseVector>
+
+    // Setup input sparse matrix.
+    %m1 = arith.constant sparse<[[1,1], [2,3]], [10, 20]> : tensor<3x4xi64>
+    %sm = sparse_tensor.convert %m1 : tensor<3x4xi64> to tensor<3x4xi64, #SparseMatrix>
+
     // Setup input "sparse" matrix.
-    %d = arith.constant dense <[
-      [ 1, 1, 1, 1 ],
-      [ 1, 1, 1, 1 ],
-      [ 1, 1, 1, 1 ]
-    ]> : tensor<3x4xi64>
-    %a = sparse_tensor.convert %d : tensor<3x4xi64> to tensor<3x4xi64, #SparseMatrix>
+    %m2 = arith.constant dense <[ [ 1, 1, 1, 1 ],
+                                  [ 1, 2, 1, 1 ],
+                                  [ 1, 1, 3, 4 ] ]> : tensor<3x4xi64>
+    %dm = sparse_tensor.convert %m2 : tensor<3x4xi64> to tensor<3x4xi64, #SparseMatrix>
 
-    // Call the kernel.
-    %0 = call @sparse_index(%a) : (tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64, #SparseMatrix>
+    // Call the kernels.
+    %0 = call @sparse_index_1d_conj(%sv) : (tensor<8xi64, #SparseVector>)
+        -> tensor<8xi64, #SparseVector>
+    %1 = call @sparse_index_1d_disj(%sv) : (tensor<8xi64, #SparseVector>)
+        -> tensor<8xi64, #SparseVector>
+    %2 = call @sparse_index_1d_conj(%dv) : (tensor<8xi64, #SparseVector>)
+        -> tensor<8xi64, #SparseVector>
+    %3 = call @sparse_index_1d_disj(%dv) : (tensor<8xi64, #SparseVector>)
+        -> tensor<8xi64, #SparseVector>
+    %4 = call @sparse_index_2d_conj(%sm) : (tensor<3x4xi64, #SparseMatrix>)
+        -> tensor<3x4xi64, #SparseMatrix>
+    %5 = call @sparse_index_2d_disj(%sm) : (tensor<3x4xi64, #SparseMatrix>)
+        -> tensor<3x4xi64, #SparseMatrix>
+    %6 = call @sparse_index_2d_conj(%dm) : (tensor<3x4xi64, #SparseMatrix>)
+        -> tensor<3x4xi64, #SparseMatrix>
+    %7 = call @sparse_index_2d_disj(%dm) : (tensor<3x4xi64, #SparseMatrix>)
+        -> tensor<3x4xi64, #SparseMatrix>
 
     //
     // Verify result.
     //
-    // CHECK: ( ( 0, 0, 0, 0 ), ( 0, 1, 2, 3 ), ( 0, 2, 4, 6 ) )
+    // CHECK: ( 20, 80, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 0, 1, 12, 3, 24, 5, 6, 7 )
+    // CHECK-NEXT: ( 0, 2, 8, 24, 64, 160, 384, 896 )
+    // CHECK-NEXT: ( 1, 3, 6, 11, 20, 37, 70, 135 )
+    // CHECK-NEXT: ( 10, 120, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 0, 1, 2, 3, 1, 12, 3, 4, 2, 3, 4, 25 )
+    // CHECK-NEXT: ( 0, 0, 0, 0, 0, 2, 2, 3, 0, 2, 12, 24 )
+    // CHECK-NEXT: ( 1, 2, 3, 4, 2, 4, 4, 5, 3, 4, 7, 9 )
     //
-    %x = sparse_tensor.convert %0 : tensor<3x4xi64, #SparseMatrix> to tensor<3x4xi64>
-    %m = bufferization.to_memref %x : memref<3x4xi64>
-    %v = vector.transfer_read %m[%c0, %c0], %du: memref<3x4xi64>, vector<3x4xi64>
-    vector.print %v : vector<3x4xi64>
+    %8 = sparse_tensor.values %0 : tensor<8xi64, #SparseVector> to memref<?xi64>
+    %9 = sparse_tensor.values %1 : tensor<8xi64, #SparseVector> to memref<?xi64>
+    %10 = sparse_tensor.values %2 : tensor<8xi64, #SparseVector> to memref<?xi64>
+    %11 = sparse_tensor.values %3 : tensor<8xi64, #SparseVector> to memref<?xi64>
+    %12 = sparse_tensor.values %4 : tensor<3x4xi64, #SparseMatrix> to memref<?xi64>
+    %13 = sparse_tensor.values %5 : tensor<3x4xi64, #SparseMatrix> to memref<?xi64>
+    %14 = sparse_tensor.values %6 : tensor<3x4xi64, #SparseMatrix> to memref<?xi64>
+    %15 = sparse_tensor.values %7 : tensor<3x4xi64, #SparseMatrix> to memref<?xi64>
+    %16 = vector.transfer_read %8[%c0], %du: memref<?xi64>, vector<8xi64>
+    %17 = vector.transfer_read %9[%c0], %du: memref<?xi64>, vector<8xi64>
+    %18 = vector.transfer_read %10[%c0], %du: memref<?xi64>, vector<8xi64>
+    %19 = vector.transfer_read %11[%c0], %du: memref<?xi64>, vector<8xi64>
+    %20 = vector.transfer_read %12[%c0], %du: memref<?xi64>, vector<12xi64>
+    %21 = vector.transfer_read %13[%c0], %du: memref<?xi64>, vector<12xi64>
+    %22 = vector.transfer_read %14[%c0], %du: memref<?xi64>, vector<12xi64>
+    %23 = vector.transfer_read %15[%c0], %du: memref<?xi64>, vector<12xi64>
+    vector.print %16 : vector<8xi64>
+    vector.print %17 : vector<8xi64>
+    vector.print %18 : vector<8xi64>
+    vector.print %19 : vector<8xi64>
+    vector.print %20 : vector<12xi64>
+    vector.print %21 : vector<12xi64>
+    vector.print %22 : vector<12xi64>
+    vector.print %23 : vector<12xi64>
 
     // Release resources.
-    sparse_tensor.release %a : tensor<3x4xi64, #SparseMatrix>
-    sparse_tensor.release %0 : tensor<3x4xi64, #SparseMatrix>
-    memref.dealloc %m : memref<3x4xi64>
+    sparse_tensor.release %sv : tensor<8xi64, #SparseVector>
+    sparse_tensor.release %dv : tensor<8xi64, #SparseVector>
+    sparse_tensor.release %0 : tensor<8xi64, #SparseVector>
+    sparse_tensor.release %1 : tensor<8xi64, #SparseVector>
+    sparse_tensor.release %2 : tensor<8xi64, #SparseVector>
+    sparse_tensor.release %3 : tensor<8xi64, #SparseVector>
+    sparse_tensor.release %sm : tensor<3x4xi64, #SparseMatrix>
+    sparse_tensor.release %dm : tensor<3x4xi64, #SparseMatrix>
+    sparse_tensor.release %4 : tensor<3x4xi64, #SparseMatrix>
+    sparse_tensor.release %5 : tensor<3x4xi64, #SparseMatrix>
+    sparse_tensor.release %6 : tensor<3x4xi64, #SparseMatrix>
+    sparse_tensor.release %7 : tensor<3x4xi64, #SparseMatrix>
 
     return
   }
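For reference, the CHECK lines follow directly from the inputs. The 1-D conjunctive kernel computes x(i) = a(i) * i, and multiplication annihilates on implicit zeros, so the truly sparse vector (10 at index 2, 20 at index 4) stores only 10 * 2 = 20 and 20 * 4 = 80; the vector.transfer_read pads the remaining lanes with the %du sentinel -1, producing the first CHECK line. The 1-D disjunctive kernel computes x(i) = a(i) + i, and addition must also visit the implicit zeros, so the result is dense: x(i) = i everywhere except 10 + 2 = 12 and 20 + 4 = 24, producing the second CHECK line. The 2-D kernels behave analogously per coordinate; for example, row i = 2 of the dense-matrix conjunction X(i,j) = A(i,j) * i * j evaluates to 1*2*0 = 0, 1*2*1 = 2, 3*2*2 = 12, 4*2*3 = 24, the last four entries of the seventh CHECK line.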
