// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s

#SparseVector = #sparse_tensor.encoding<{
  dimLevelType = ["compressed"]
}>

#SparseMatrix = #sparse_tensor.encoding<{
  dimLevelType = ["compressed", "compressed"]
}>

#trait_1d = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "X(i) = a(i) op i"
}

#trait_2d = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // A
    affine_map<(i,j) -> (i,j)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) = A(i,j) op i op j"
}

//
// Test with indices. Note that a lot of results are actually
// dense, but this is done to stress test all the operations.
//
module {

//
22
- // Kernel that uses indices in the index notation.
39
+ // Kernel that uses index in the index notation (conjunction).
40
+ //
41
+ func @sparse_index_1d_conj (%arga: tensor <8 xi64 , #SparseVector >)
42
+ -> tensor <8 xi64 , #SparseVector > {
43
+ %d0 = arith.constant 8 : index
44
+ %init = sparse_tensor.init [%d0 ] : tensor <8 xi64 , #SparseVector >
45
+ %r = linalg.generic #trait_1d
46
+ ins (%arga: tensor <8 xi64 , #SparseVector >)
47
+ outs (%init: tensor <8 xi64 , #SparseVector >) {
48
+ ^bb (%a: i64 , %x: i64 ):
49
+ %i = linalg.index 0 : index
50
+ %ii = arith.index_cast %i : index to i64
51
+ %m1 = arith.muli %a , %ii : i64
52
+ linalg.yield %m1 : i64
53
+ } -> tensor <8 xi64 , #SparseVector >
54
+ return %r : tensor <8 xi64 , #SparseVector >
55
+ }
56
+
57
+ //
58
+ // Kernel that uses index in the index notation (disjunction).
59
+ //
60
+ func @sparse_index_1d_disj (%arga: tensor <8 xi64 , #SparseVector >)
61
+ -> tensor <8 xi64 , #SparseVector > {
62
+ %d0 = arith.constant 8 : index
63
+ %init = sparse_tensor.init [%d0 ] : tensor <8 xi64 , #SparseVector >
64
+ %r = linalg.generic #trait_1d
65
+ ins (%arga: tensor <8 xi64 , #SparseVector >)
66
+ outs (%init: tensor <8 xi64 , #SparseVector >) {
67
+ ^bb (%a: i64 , %x: i64 ):
68
+ %i = linalg.index 0 : index
69
+ %ii = arith.index_cast %i : index to i64
70
+ %m1 = arith.addi %a , %ii : i64
71
+ linalg.yield %m1 : i64
72
+ } -> tensor <8 xi64 , #SparseVector >
73
+ return %r : tensor <8 xi64 , #SparseVector >
74
+ }
75
+
76
+ //
77
+ // Kernel that uses indices in the index notation (conjunction).
23
78
//
24
- func @sparse_index (%arga: tensor <3 x4 xi64 , #SparseMatrix >)
25
- -> tensor <3 x4 xi64 , #SparseMatrix > {
79
+ func @sparse_index_2d_conj (%arga: tensor <3 x4 xi64 , #SparseMatrix >)
80
+ -> tensor <3 x4 xi64 , #SparseMatrix > {
26
81
%d0 = arith.constant 3 : index
27
82
%d1 = arith.constant 4 : index
28
83
%init = sparse_tensor.init [%d0 , %d1 ] : tensor <3 x4 xi64 , #SparseMatrix >
29
- %r = linalg.generic #trait
84
+ %r = linalg.generic #trait_2d
30
85
ins (%arga: tensor <3 x4 xi64 , #SparseMatrix >)
31
86
outs (%init: tensor <3 x4 xi64 , #SparseMatrix >) {
32
87
^bb (%a: i64 , %x: i64 ):
@@ -41,40 +96,122 @@ module {
41
96
return %r : tensor <3 x4 xi64 , #SparseMatrix >
42
97
}
43
98
99
+ //
100
+ // Kernel that uses indices in the index notation (disjunction).
101
+ //
102
+ func @sparse_index_2d_disj (%arga: tensor <3 x4 xi64 , #SparseMatrix >)
103
+ -> tensor <3 x4 xi64 , #SparseMatrix > {
104
+ %d0 = arith.constant 3 : index
105
+ %d1 = arith.constant 4 : index
106
+ %init = sparse_tensor.init [%d0 , %d1 ] : tensor <3 x4 xi64 , #SparseMatrix >
107
+ %r = linalg.generic #trait_2d
108
+ ins (%arga: tensor <3 x4 xi64 , #SparseMatrix >)
109
+ outs (%init: tensor <3 x4 xi64 , #SparseMatrix >) {
110
+ ^bb (%a: i64 , %x: i64 ):
111
+ %i = linalg.index 0 : index
112
+ %j = linalg.index 1 : index
113
+ %ii = arith.index_cast %i : index to i64
114
+ %jj = arith.index_cast %j : index to i64
115
+ %m1 = arith.addi %ii , %a : i64
116
+ %m2 = arith.addi %jj , %m1 : i64
117
+ linalg.yield %m2 : i64
118
+ } -> tensor <3 x4 xi64 , #SparseMatrix >
119
+ return %r : tensor <3 x4 xi64 , #SparseMatrix >
120
+ }
121
+
44
122
//
45
123
// Main driver.
46
124
//
47
125
func @entry () {
48
126
%c0 = arith.constant 0 : index
49
- %c1 = arith.constant 1 : index
50
- %c4 = arith.constant 4 : index
51
127
%du = arith.constant -1 : i64
52
128
129
+ // Setup input sparse vector.
130
+ %v1 = arith.constant sparse <[[2 ], [4 ]], [ 10 , 20 ]> : tensor <8 xi64 >
131
+ %sv = sparse_tensor.convert %v1 : tensor <8 xi64 > to tensor <8 xi64 , #SparseVector >
132
+
133
+ // Setup input "sparse" vector.
134
+ %v2 = arith.constant dense <[ 1 , 2 , 4 , 8 , 16 , 32 , 64 , 128 ]> : tensor <8 xi64 >
135
+ %dv = sparse_tensor.convert %v2 : tensor <8 xi64 > to tensor <8 xi64 , #SparseVector >
136
+
137
+ // Setup input sparse matrix.
138
+ %m1 = arith.constant sparse <[[1 ,1 ], [2 ,3 ]], [10 , 20 ]> : tensor <3 x4 xi64 >
139
+ %sm = sparse_tensor.convert %m1 : tensor <3 x4 xi64 > to tensor <3 x4 xi64 , #SparseMatrix >
140
+
53
141
// Setup input "sparse" matrix.
54
- %d = arith.constant dense <[
55
- [ 1 , 1 , 1 , 1 ],
56
- [ 1 , 1 , 1 , 1 ],
57
- [ 1 , 1 , 1 , 1 ]
58
- ]> : tensor <3 x4 xi64 >
59
- %a = sparse_tensor.convert %d : tensor <3 x4 xi64 > to tensor <3 x4 xi64 , #SparseMatrix >
142
+ %m2 = arith.constant dense <[ [ 1 , 1 , 1 , 1 ],
143
+ [ 1 , 2 , 1 , 1 ],
144
+ [ 1 , 1 , 3 , 4 ] ]> : tensor <3 x4 xi64 >
145
+ %dm = sparse_tensor.convert %m2 : tensor <3 x4 xi64 > to tensor <3 x4 xi64 , #SparseMatrix >
60
146
61
- // Call the kernel.
62
- %0 = call @sparse_index (%a ) : (tensor <3 x4 xi64 , #SparseMatrix >) -> tensor <3 x4 xi64 , #SparseMatrix >
147
+ // Call the kernels.
148
+ %0 = call @sparse_index_1d_conj (%sv ) : (tensor <8 xi64 , #SparseVector >)
149
+ -> tensor <8 xi64 , #SparseVector >
150
+ %1 = call @sparse_index_1d_disj (%sv ) : (tensor <8 xi64 , #SparseVector >)
151
+ -> tensor <8 xi64 , #SparseVector >
152
+ %2 = call @sparse_index_1d_conj (%dv ) : (tensor <8 xi64 , #SparseVector >)
153
+ -> tensor <8 xi64 , #SparseVector >
154
+ %3 = call @sparse_index_1d_disj (%dv ) : (tensor <8 xi64 , #SparseVector >)
155
+ -> tensor <8 xi64 , #SparseVector >
156
+ %4 = call @sparse_index_2d_conj (%sm ) : (tensor <3 x4 xi64 , #SparseMatrix >)
157
+ -> tensor <3 x4 xi64 , #SparseMatrix >
158
+ %5 = call @sparse_index_2d_disj (%sm ) : (tensor <3 x4 xi64 , #SparseMatrix >)
159
+ -> tensor <3 x4 xi64 , #SparseMatrix >
160
+ %6 = call @sparse_index_2d_conj (%dm ) : (tensor <3 x4 xi64 , #SparseMatrix >)
161
+ -> tensor <3 x4 xi64 , #SparseMatrix >
162
+ %7 = call @sparse_index_2d_disj (%dm ) : (tensor <3 x4 xi64 , #SparseMatrix >)
163
+ -> tensor <3 x4 xi64 , #SparseMatrix >
63
164
64
165
//
65
166
// Verify result.
66
167
//
67
- // CHECK: ( ( 0, 0, 0, 0 ), ( 0, 1, 2, 3 ), ( 0, 2, 4, 6 ) )
168
+ // CHECK: ( 20, 80, -1, -1, -1, -1, -1, -1 )
169
+ // CHECK-NEXT: ( 0, 1, 12, 3, 24, 5, 6, 7 )
170
+ // CHECK-NEXT: ( 0, 2, 8, 24, 64, 160, 384, 896 )
171
+ // CHECK-NEXT: ( 1, 3, 6, 11, 20, 37, 70, 135 )
172
+ // CHECK-NEXT: ( 10, 120, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
173
+ // CHECK-NEXT: ( 0, 1, 2, 3, 1, 12, 3, 4, 2, 3, 4, 25 )
174
+ // CHECK-NEXT: ( 0, 0, 0, 0, 0, 2, 2, 3, 0, 2, 12, 24 )
175
+ // CHECK-NEXT: ( 1, 2, 3, 4, 2, 4, 4, 5, 3, 4, 7, 9 )
68
176
//
69
- %x = sparse_tensor.convert %0 : tensor <3 x4 xi64 , #SparseMatrix > to tensor <3 x4 xi64 >
70
- %m = bufferization.to_memref %x : memref <3 x4 xi64 >
71
- %v = vector.transfer_read %m [%c0 , %c0 ], %du: memref <3 x4 xi64 >, vector <3 x4 xi64 >
72
- vector.print %v : vector <3 x4 xi64 >
177
+ %8 = sparse_tensor.values %0 : tensor <8 xi64 , #SparseVector > to memref <?xi64 >
178
+ %9 = sparse_tensor.values %1 : tensor <8 xi64 , #SparseVector > to memref <?xi64 >
179
+ %10 = sparse_tensor.values %2 : tensor <8 xi64 , #SparseVector > to memref <?xi64 >
180
+ %11 = sparse_tensor.values %3 : tensor <8 xi64 , #SparseVector > to memref <?xi64 >
181
+ %12 = sparse_tensor.values %4 : tensor <3 x4 xi64 , #SparseMatrix > to memref <?xi64 >
182
+ %13 = sparse_tensor.values %5 : tensor <3 x4 xi64 , #SparseMatrix > to memref <?xi64 >
183
+ %14 = sparse_tensor.values %6 : tensor <3 x4 xi64 , #SparseMatrix > to memref <?xi64 >
184
+ %15 = sparse_tensor.values %7 : tensor <3 x4 xi64 , #SparseMatrix > to memref <?xi64 >
185
+ %16 = vector.transfer_read %8 [%c0 ], %du: memref <?xi64 >, vector <8 xi64 >
186
+ %17 = vector.transfer_read %9 [%c0 ], %du: memref <?xi64 >, vector <8 xi64 >
187
+ %18 = vector.transfer_read %10 [%c0 ], %du: memref <?xi64 >, vector <8 xi64 >
188
+ %19 = vector.transfer_read %11 [%c0 ], %du: memref <?xi64 >, vector <8 xi64 >
189
+ %20 = vector.transfer_read %12 [%c0 ], %du: memref <?xi64 >, vector <12 xi64 >
190
+ %21 = vector.transfer_read %13 [%c0 ], %du: memref <?xi64 >, vector <12 xi64 >
191
+ %22 = vector.transfer_read %14 [%c0 ], %du: memref <?xi64 >, vector <12 xi64 >
192
+ %23 = vector.transfer_read %15 [%c0 ], %du: memref <?xi64 >, vector <12 xi64 >
193
+ vector.print %16 : vector <8 xi64 >
194
+ vector.print %17 : vector <8 xi64 >
195
+ vector.print %18 : vector <8 xi64 >
196
+ vector.print %19 : vector <8 xi64 >
197
+ vector.print %20 : vector <12 xi64 >
198
+ vector.print %21 : vector <12 xi64 >
199
+ vector.print %22 : vector <12 xi64 >
200
+ vector.print %23 : vector <12 xi64 >
73
201
74
202
// Release resources.
75
- sparse_tensor.release %a : tensor <3 x4 xi64 , #SparseMatrix >
76
- sparse_tensor.release %0 : tensor <3 x4 xi64 , #SparseMatrix >
77
- memref.dealloc %m : memref <3 x4 xi64 >
203
+ sparse_tensor.release %sv : tensor <8 xi64 , #SparseVector >
204
+ sparse_tensor.release %dv : tensor <8 xi64 , #SparseVector >
205
+ sparse_tensor.release %0 : tensor <8 xi64 , #SparseVector >
206
+ sparse_tensor.release %1 : tensor <8 xi64 , #SparseVector >
207
+ sparse_tensor.release %2 : tensor <8 xi64 , #SparseVector >
208
+ sparse_tensor.release %3 : tensor <8 xi64 , #SparseVector >
209
+ sparse_tensor.release %sm : tensor <3 x4 xi64 , #SparseMatrix >
210
+ sparse_tensor.release %dm : tensor <3 x4 xi64 , #SparseMatrix >
211
+ sparse_tensor.release %4 : tensor <3 x4 xi64 , #SparseMatrix >
212
+ sparse_tensor.release %5 : tensor <3 x4 xi64 , #SparseMatrix >
213
+ sparse_tensor.release %6 : tensor <3 x4 xi64 , #SparseMatrix >
214
+ sparse_tensor.release %7 : tensor <3 x4 xi64 , #SparseMatrix >
78
215
79
216
return
80
217
}
}