32
32
//
33
33
// Traits for tensor operations.
34
34
//
35
- #trait_vec_scale = {
35
+ #trait_vec = {
// Trait for 1-D linalg.generic kernels: one input (a) and one output (x), both
// accessed through the identity map over a single parallel dimension i.
36
36
indexing_maps = [
37
37
affine_map <(i ) -> (i )>, // a (in)
38
38
affine_map <(i ) -> (i )> // x (out)
39
39
],
40
40
iterator_types = [" parallel" ]
41
41
}
42
- #trait_mat_scale = {
42
+ #trait_mat = {
43
43
indexing_maps = [
44
44
affine_map <(i ,j ) -> (i ,j )>, // A (in)
45
45
affine_map <(i ,j ) -> (i ,j )> // X (out)
49
49
50
50
module {
51
51
// Invert the structure of a sparse vector. Present values become missing.
52
- // Missing values are filled with 1 (i32).
53
- func.func @vector_complement (%arga: tensor <?xf64 , #SparseVector >) -> tensor <?xi32 , #SparseVector > {
52
+ // Missing values are filled with 1 (i32). Output is sparse.
53
+ func.func @vector_complement_sparse (%arga: tensor <?xf64 , #SparseVector >) -> tensor <?xi32 , #SparseVector > {
54
54
%c = arith.constant 0 : index
55
55
%ci1 = arith.constant 1 : i32
56
56
%d = tensor.dim %arga , %c : tensor <?xf64 , #SparseVector >
57
57
%xv = bufferization.alloc_tensor (%d ) : tensor <?xi32 , #SparseVector >
58
- %0 = linalg.generic #trait_vec_scale
58
+ %0 = linalg.generic #trait_vec
59
59
ins (%arga: tensor <?xf64 , #SparseVector >)
60
60
outs (%xv: tensor <?xi32 , #SparseVector >) {
61
61
^bb (%a: f64 , %x: i32 ):
@@ -69,13 +69,35 @@ module {
69
69
return %0 : tensor <?xi32 , #SparseVector >
70
70
}
71
71
72
+ // Invert the structure of a sparse vector: positions with no stored value
73
+ // are filled with 1 (i32). Because the output is dense, the sparse compiler
74
+ // zero-initializes the buffer, so positions that hold a value in the input read 0.
75
+ func.func @vector_complement_dense (%arga: tensor <?xf64 , #SparseVector >) -> tensor <?xi32 > {
// Takes a sparse f64 vector and returns a dense (unannotated) i32 tensor
// whose dynamic size equals dimension 0 of the input.
76
+ %c = arith.constant 0 : index
77
+ %d = tensor.dim %arga , %c : tensor <?xf64 , #SparseVector >
78
// Dense output buffer, sized from the input's dimension 0.
+ %xv = bufferization.alloc_tensor (%d ) : tensor <?xi32 >
79
+ %0 = linalg.generic #trait_vec
80
+ ins (%arga: tensor <?xf64 , #SparseVector >)
81
+ outs (%xv: tensor <?xi32 >) {
82
+ ^bb (%a: f64 , %x: i32 ):
83
// The "present" region is empty, so no value is yielded for stored entries
// of %a; the "absent" region yields the i32 constant 1 for missing entries.
+ %1 = sparse_tensor.unary %a : f64 to i32
84
+ present ={}
85
+ absent ={
86
+ %ci1 = arith.constant 1 : i32
87
+ sparse_tensor.yield %ci1 : i32
88
+ }
89
+ linalg.yield %1 : i32
90
+ } -> tensor <?xi32 >
91
+ return %0 : tensor <?xi32 >
92
+ }
93
+
72
94
// Negate existing values. Fill missing ones with +1.
73
95
func.func @vector_negation (%arga: tensor <?xf64 , #SparseVector >) -> tensor <?xf64 , #SparseVector > {
74
96
%c = arith.constant 0 : index
75
97
%cf1 = arith.constant 1.0 : f64
76
98
%d = tensor.dim %arga , %c : tensor <?xf64 , #SparseVector >
77
99
%xv = bufferization.alloc_tensor (%d ) : tensor <?xf64 , #SparseVector >
78
- %0 = linalg.generic #trait_vec_scale
100
+ %0 = linalg.generic #trait_vec
79
101
ins (%arga: tensor <?xf64 , #SparseVector >)
80
102
outs (%xv: tensor <?xf64 , #SparseVector >) {
81
103
^bb (%a: f64 , %x: f64 ):
@@ -98,7 +120,7 @@ module {
98
120
%c = arith.constant 0 : index
99
121
%d = tensor.dim %arga , %c : tensor <?xf64 , #SparseVector >
100
122
%xv = bufferization.alloc_tensor (%d ) : tensor <?xf64 , #SparseVector >
101
- %0 = linalg.generic #trait_vec_scale
123
+ %0 = linalg.generic #trait_vec
102
124
ins (%arga: tensor <?xf64 , #SparseVector >)
103
125
outs (%xv: tensor <?xf64 , #SparseVector >) {
104
126
^bb (%a: f64 , %x: f64 ):
@@ -126,7 +148,7 @@ module {
126
148
%d0 = tensor.dim %argx , %c0 : tensor <?x?xf64 , #DCSR >
127
149
%d1 = tensor.dim %argx , %c1 : tensor <?x?xf64 , #DCSR >
128
150
%xv = bufferization.alloc_tensor (%d0 , %d1 ) : tensor <?x?xf64 , #DCSR >
129
- %0 = linalg.generic #trait_mat_scale
151
+ %0 = linalg.generic #trait_mat
130
152
ins (%argx: tensor <?x?xf64 , #DCSR >)
131
153
outs (%xv: tensor <?x?xf64 , #DCSR >) {
132
154
^bb (%a: f64 , %x: f64 ):
@@ -153,7 +175,7 @@ module {
153
175
%d0 = tensor.dim %argx , %c0 : tensor <?x?xf64 , #DCSR >
154
176
%d1 = tensor.dim %argx , %c1 : tensor <?x?xf64 , #DCSR >
155
177
%xv = bufferization.alloc_tensor (%d0 , %d1 ) : tensor <?x?xf64 , #DCSR >
156
- %0 = linalg.generic #trait_mat_scale
178
+ %0 = linalg.generic #trait_mat
157
179
ins (%argx: tensor <?x?xf64 , #DCSR >)
158
180
outs (%xv: tensor <?x?xf64 , #DCSR >) {
159
181
^bb (%a: f64 , %x: f64 ):
@@ -223,6 +245,7 @@ module {
223
245
224
246
// Driver method to call and verify vector kernels.
225
247
func.func @entry () {
248
+ %cmu = arith.constant -99 : i32
226
249
%c0 = arith.constant 0 : index
227
250
228
251
// Setup sparse vectors.
@@ -240,7 +263,7 @@ module {
240
263
%sm1 = sparse_tensor.convert %m1 : tensor <4 x8 xf64 > to tensor <?x?xf64 , #DCSR >
241
264
242
265
// Call sparse vector kernels.
243
- %0 = call @vector_complement (%sv1 )
266
+ %0 = call @vector_complement_sparse (%sv1 )
244
267
: (tensor <?xf64 , #SparseVector >) -> tensor <?xi32 , #SparseVector >
245
268
%1 = call @vector_negation (%sv1 )
246
269
: (tensor <?xf64 , #SparseVector >) -> tensor <?xf64 , #SparseVector >
@@ -253,6 +276,9 @@ module {
253
276
%4 = call @matrix_slice (%sm1 )
254
277
: (tensor <?x?xf64 , #DCSR >) -> tensor <?x?xf64 , #DCSR >
255
278
279
+ // Call kernel with dense output.
280
+ %5 = call @vector_complement_dense (%sv1 ) : (tensor <?xf64 , #SparseVector >) -> tensor <?xi32 >
281
+
256
282
//
257
283
// Verify the results.
258
284
//
@@ -268,13 +294,16 @@ module {
268
294
// CHECK-NEXT: ( ( 3, 3, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 3 ), ( 0, 0, 4, 0, 5, 0, 0, 6 ), ( 7, 0, 7, 7, 0, 0, 0, 0 ) )
269
295
// CHECK-NEXT: ( 99, 99, 99, 99, 5, 6, 99, 99, 99, 0, 0, 0, 0, 0, 0, 0 )
270
296
// CHECK-NEXT: ( ( 99, 99, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 99 ), ( 0, 0, 99, 0, 5, 0, 0, 6 ), ( 99, 0, 99, 99, 0, 0, 0, 0 ) )
297
+ // CHECK-NEXT: ( 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0 )
271
298
//
272
299
call @dump_vec_f64 (%sv1 ) : (tensor <?xf64 , #SparseVector >) -> ()
273
300
call @dump_vec_i32 (%0 ) : (tensor <?xi32 , #SparseVector >) -> ()
274
301
call @dump_vec_f64 (%1 ) : (tensor <?xf64 , #SparseVector >) -> ()
275
302
call @dump_vec_f64 (%2 ) : (tensor <?xf64 , #SparseVector >) -> ()
276
303
call @dump_mat (%3 ) : (tensor <?x?xf64 , #DCSR >) -> ()
277
304
call @dump_mat (%4 ) : (tensor <?x?xf64 , #DCSR >) -> ()
305
+ %v = vector.transfer_read %5 [%c0 ], %cmu: tensor <?xi32 >, vector <32 xi32 >
306
+ vector.print %v : vector <32 xi32 >
278
307
279
308
// Release the resources.
280
309
bufferization.dealloc_tensor %sv1 : tensor <?xf64 , #SparseVector >
@@ -284,6 +313,7 @@ module {
284
313
bufferization.dealloc_tensor %2 : tensor <?xf64 , #SparseVector >
285
314
bufferization.dealloc_tensor %3 : tensor <?x?xf64 , #DCSR >
286
315
bufferization.dealloc_tensor %4 : tensor <?x?xf64 , #DCSR >
316
+ bufferization.dealloc_tensor %5 : tensor <?xi32 >
287
317
return
288
318
}
289
319
}
0 commit comments