@@ -172,18 +172,19 @@ module {
172
172
%COO_RET = call @add_coo_coo_out_coo (%COO_A , %COO_B ) : (tensor <8 x8 xf32 , #SortedCOO >,
173
173
tensor <8 x8 xf32 , #SortedCOOSoA >)
174
174
-> tensor <8 x8 xf32 , #SortedCOOSoA >
175
+ %C4 = sparse_tensor.convert %COO_RET : tensor <8 x8 xf32 , #SortedCOOSoA > to tensor <8 x8 xf32 >
175
176
176
177
//
177
178
// Verify computed matrix C.
178
179
//
179
- // CHECK-COUNT-3 : ( 8.8, 4.8, 6.8, 4.8, 8.8, 6.1, 14.8, 16.8 )
180
- // CHECK-NEXT-COUNT-3 : ( 4.4, 4.4, 4.4, 8.4, 8.4, 12.4, 16.4, 16.4 )
181
- // CHECK-NEXT-COUNT-3 : ( 8.8, 4.8, 6.8, 8.8, 8.8, 12.8, 14.8, 15.8 )
182
- // CHECK-NEXT-COUNT-3 : ( 4.3, 5.3, 6.3, 8.3, 8.3, 12.3, 14.3, 16.3 )
183
- // CHECK-NEXT-COUNT-3 : ( 4.5, 4.5, 6.5, 8.5, 8.5, 12.5, 14.5, 16.5 )
184
- // CHECK-NEXT-COUNT-3 : ( 9.9, 4.9, 6.9, 8.9, 8.9, 12.9, 15.9, 16.9 )
185
- // CHECK-NEXT-COUNT-3 : ( 12.1, 6.1, 5.1, 9.1, 9.1, 13.1, 15.1, 17.1 )
186
- // CHECK-NEXT-COUNT-3 : ( 15.4, 5.4, 7.4, 5.4, 11.4, 10.4, 11.4, 9.4 )
180
+ // CHECK-COUNT-4 : ( 8.8, 4.8, 6.8, 4.8, 8.8, 6.1, 14.8, 16.8 )
181
+ // CHECK-NEXT-COUNT-4 : ( 4.4, 4.4, 4.4, 8.4, 8.4, 12.4, 16.4, 16.4 )
182
+ // CHECK-NEXT-COUNT-4 : ( 8.8, 4.8, 6.8, 8.8, 8.8, 12.8, 14.8, 15.8 )
183
+ // CHECK-NEXT-COUNT-4 : ( 4.3, 5.3, 6.3, 8.3, 8.3, 12.3, 14.3, 16.3 )
184
+ // CHECK-NEXT-COUNT-4 : ( 4.5, 4.5, 6.5, 8.5, 8.5, 12.5, 14.5, 16.5 )
185
+ // CHECK-NEXT-COUNT-4 : ( 9.9, 4.9, 6.9, 8.9, 8.9, 12.9, 15.9, 16.9 )
186
+ // CHECK-NEXT-COUNT-4 : ( 12.1, 6.1, 5.1, 9.1, 9.1, 13.1, 15.1, 17.1 )
187
+ // CHECK-NEXT-COUNT-4 : ( 15.4, 5.4, 7.4, 5.4, 11.4, 10.4, 11.4, 9.4 )
187
188
//
188
189
%f0 = arith.constant 0.0 : f32
189
190
scf.for %i = %c0 to %c8 step %c1 {
@@ -193,9 +194,12 @@ module {
193
194
: tensor <8 x8 xf32 >, vector <8 xf32 >
194
195
%v3 = vector.transfer_read %C3 [%i , %c0 ], %f0
195
196
: tensor <8 x8 xf32 >, vector <8 xf32 >
197
+ %v4 = vector.transfer_read %C4 [%i , %c0 ], %f0
198
+ : tensor <8 x8 xf32 >, vector <8 xf32 >
196
199
vector.print %v1 : vector <8 xf32 >
197
200
vector.print %v2 : vector <8 xf32 >
198
201
vector.print %v3 : vector <8 xf32 >
202
+ vector.print %v4 : vector <8 xf32 >
199
203
}
200
204
201
205
//
@@ -228,6 +232,7 @@ module {
228
232
bufferization.dealloc_tensor %C1 : tensor <8 x8 xf32 >
229
233
bufferization.dealloc_tensor %C2 : tensor <8 x8 xf32 >
230
234
bufferization.dealloc_tensor %C3 : tensor <8 x8 xf32 >
235
+ bufferization.dealloc_tensor %C4 : tensor <8 x8 xf32 >
231
236
bufferization.dealloc_tensor %CSR_A : tensor <8 x8 xf32 , #CSR >
232
237
bufferization.dealloc_tensor %COO_A : tensor <8 x8 xf32 , #SortedCOO >
233
238
bufferization.dealloc_tensor %COO_B : tensor <8 x8 xf32 , #SortedCOOSoA >
0 commit comments