38
38
map = (d0 , d1 ) -> (d1 : dense , d0 : compressed)
39
39
}>
40
40
41
// Indexing maps for the linalg.generic in @conv2d_CSR_dense_rotated, over the
// four loop dimensions (d0, d1, d2, d3):
//   #map  indexes the input operand at (d0 + d1, d3 + d2),
//   #map1 indexes the filter operand at (d1, d2),
//   #map2 indexes the output operand at (d0, d3).
+ #map = affine_map <(d0 , d1 , d2 , d3 ) -> (d0 + d1 , d3 + d2 )>
42
+ #map1 = affine_map <(d0 , d1 , d2 , d3 ) -> (d1 , d2 )>
43
+ #map2 = affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d3 )>
44
+
41
45
// An example of a 2D convolution with a sparse filter.
42
46
module {
43
47
@@ -50,6 +54,21 @@ module {
50
54
return %0 : tensor <6 x6 xi32 >
51
55
}
52
56
57
// 2D convolution of an 8x8 i32 input annotated with #CSR (%arg0) with a dense
// 3x3 i32 filter (%arg1), returning a dense 6x6 i32 tensor.
//
// The computation is a single linalg.generic over four loops with iterator
// types (parallel, reduction, reduction, parallel). Per #map/#map1/#map2 the
// body accumulates
//   out[d0, d3] += in[d0 + d1, d3 + d2] * filter[d1, d2]
// so the two reduction loops (d1, d2) sit between the two parallel loops.
//
// NOTE(review): the accumulator %s comes from tensor.empty, and the body reads
// %out before writing it. This presumably relies on the sparsifier
// zero-initializing a freshly materialized dense output -- TODO confirm.
+ func.func @conv2d_CSR_dense_rotated (%arg0: tensor <8 x8 xi32 , #CSR >,
58
+ %arg1: tensor <3 x3 xi32 >) -> tensor <6 x6 xi32 > {
59
+ %s = tensor.empty () : tensor <6 x6 xi32 >
60
+ %0 = linalg.generic {index ing_maps = [#map , #map1 , #map2 ],
61
+ iterator_types = [" parallel" , " reduction" , " reduction" , " parallel" ]}
62
+ ins (%arg0 , %arg1 : tensor <8 x8 xi32 , #CSR >, tensor <3 x3 xi32 >)
63
+ outs (%s : tensor <6 x6 xi32 >) attrs = {sorted = true } {
64
+ ^bb0 (%in: i32 , %in_0: i32 , %out: i32 ):
// Multiply the input element by the filter element, then add the product
// to the running output value.
65
+ %1 = arith.muli %in , %in_0 : i32
66
+ %2 = arith.addi %out , %1 : i32
67
+ linalg.yield %2 : i32
68
+ } -> tensor <6 x6 xi32 >
69
+ return %0 : tensor <6 x6 xi32 >
70
+ }
71
+
53
72
func.func @conv2d_sparse_out (%input: tensor <8 x8 xi32 >,
54
73
%filter: tensor <3 x3 xi32 >) -> tensor <6 x6 xi32 , #DCSR > {
55
74
%s = tensor.empty () : tensor <6 x6 xi32 , #DCSR >
@@ -146,7 +165,9 @@ module {
146
165
%5 = call @conv2d_all_sparse_CSC (%sparse_input_CSC , %filter )
147
166
: (tensor <8 x8 xi32 , #CSC >,
148
167
tensor <3 x3 xi32 >) -> tensor <6 x6 xi32 , #CSC >
149
-
168
+ %6 = call @conv2d_CSR_dense_rotated (%sparse_input_CSR , %filter )
169
+ : (tensor <8 x8 xi32 , #CSR >,
170
+ tensor <3 x3 xi32 >) -> tensor <6 x6 xi32 >
150
171
151
172
// Verify the output.
152
173
//
@@ -236,6 +257,20 @@ module {
236
257
: tensor <6 x6 xi32 >, vector <6 x6 xi32 >
237
258
vector.print %v5 : vector <6 x6 xi32 >
238
259
260
+ //
261
+ // Result of @conv2d_CSR_dense_rotated (%6); should be the same as the dense output.
262
+ // CHECK: ( ( 0, 0, -1, -6, -1, 6 ),
263
+ // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ),
264
+ // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ),
265
+ // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ),
266
+ // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
267
+ // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
268
+ //
269
+ %v6 = vector.transfer_read %6 [%c0 , %c0 ], %i0
270
+ : tensor <6 x6 xi32 >, vector <6 x6 xi32 >
// Print %v6 (read from %6 just above) so the CHECK lines verify this kernel's
// own output rather than a previously printed vector.
271
+ vector.print %v6 : vector <6 x6 xi32 >
272
+
273
+
239
274
// Release the resources.
240
275
bufferization.dealloc_tensor %sparse_input_DCSR : tensor <8 x8 xi32 , #DCSR >
241
276
bufferization.dealloc_tensor %sparse_input_CSR : tensor <8 x8 xi32 , #CSR >
0 commit comments