@@ -87,6 +87,22 @@ module {
87
87
return %0 : tensor <8 x8 xf32 >
88
88
}
89
89
90
+ func.func @add_coo_coo_out_coo (%arga: tensor <8 x8 xf32 , #SortedCOO >,
91
+ %argb: tensor <8 x8 xf32 , #SortedCOO >)
92
+ -> tensor <8 x8 xf32 , #SortedCOO > {
93
+ %init = tensor.empty () : tensor <8 x8 xf32 , #SortedCOO >
94
+ %0 = linalg.generic #trait
95
+ ins (%arga , %argb: tensor <8 x8 xf32 , #SortedCOO >,
96
+ tensor <8 x8 xf32 , #SortedCOO >)
97
+ outs (%init: tensor <8 x8 xf32 , #SortedCOO >) {
98
+ ^bb (%a: f32 , %b: f32 , %x: f32 ):
99
+ %0 = arith.addf %a , %b : f32
100
+ linalg.yield %0 : f32
101
+ } -> tensor <8 x8 xf32 , #SortedCOO >
102
+ return %0 : tensor <8 x8 xf32 , #SortedCOO >
103
+ }
104
+
105
+
90
106
func.func @add_coo_dense (%arga: tensor <8 x8 xf32 >,
91
107
%argb: tensor <8 x8 xf32 , #SortedCOO >)
92
108
-> tensor <8 x8 xf32 > {
@@ -149,17 +165,21 @@ module {
149
165
%C3 = call @add_coo_coo (%COO_A , %COO_B ) : (tensor <8 x8 xf32 , #SortedCOO >,
150
166
tensor <8 x8 xf32 , #SortedCOO >)
151
167
-> tensor <8 x8 xf32 >
168
+ %COO_RET = call @add_coo_coo_out_coo (%COO_A , %COO_B ) : (tensor <8 x8 xf32 , #SortedCOO >,
169
+ tensor <8 x8 xf32 , #SortedCOO >)
170
+ -> tensor <8 x8 xf32 , #SortedCOO >
171
+ %C4 = sparse_tensor.convert %COO_RET : tensor <8 x8 xf32 , #SortedCOO > to tensor <8 x8 xf32 >
152
172
//
153
173
// Verify computed matrix C.
154
174
//
155
- // CHECK-COUNT-3 : ( 8.8, 4.8, 6.8, 4.8, 8.8, 6.1, 14.8, 16.8 )
156
- // CHECK-NEXT-COUNT-3 : ( 4.4, 4.4, 4.4, 8.4, 8.4, 12.4, 16.4, 16.4 )
157
- // CHECK-NEXT-COUNT-3 : ( 8.8, 4.8, 6.8, 8.8, 8.8, 12.8, 14.8, 15.8 )
158
- // CHECK-NEXT-COUNT-3 : ( 4.3, 5.3, 6.3, 8.3, 8.3, 12.3, 14.3, 16.3 )
159
- // CHECK-NEXT-COUNT-3 : ( 4.5, 4.5, 6.5, 8.5, 8.5, 12.5, 14.5, 16.5 )
160
- // CHECK-NEXT-COUNT-3 : ( 9.9, 4.9, 6.9, 8.9, 8.9, 12.9, 15.9, 16.9 )
161
- // CHECK-NEXT-COUNT-3 : ( 12.1, 6.1, 5.1, 9.1, 9.1, 13.1, 15.1, 17.1 )
162
- // CHECK-NEXT-COUNT-3 : ( 15.4, 5.4, 7.4, 5.4, 11.4, 10.4, 11.4, 9.4 )
175
+ // CHECK-COUNT-4 : ( 8.8, 4.8, 6.8, 4.8, 8.8, 6.1, 14.8, 16.8 )
176
+ // CHECK-NEXT-COUNT-4 : ( 4.4, 4.4, 4.4, 8.4, 8.4, 12.4, 16.4, 16.4 )
177
+ // CHECK-NEXT-COUNT-4 : ( 8.8, 4.8, 6.8, 8.8, 8.8, 12.8, 14.8, 15.8 )
178
+ // CHECK-NEXT-COUNT-4 : ( 4.3, 5.3, 6.3, 8.3, 8.3, 12.3, 14.3, 16.3 )
179
+ // CHECK-NEXT-COUNT-4 : ( 4.5, 4.5, 6.5, 8.5, 8.5, 12.5, 14.5, 16.5 )
180
+ // CHECK-NEXT-COUNT-4 : ( 9.9, 4.9, 6.9, 8.9, 8.9, 12.9, 15.9, 16.9 )
181
+ // CHECK-NEXT-COUNT-4 : ( 12.1, 6.1, 5.1, 9.1, 9.1, 13.1, 15.1, 17.1 )
182
+ // CHECK-NEXT-COUNT-4 : ( 15.4, 5.4, 7.4, 5.4, 11.4, 10.4, 11.4, 9.4 )
163
183
//
164
184
%f0 = arith.constant 0.0 : f32
165
185
scf.for %i = %c0 to %c8 step %c1 {
@@ -169,9 +189,12 @@ module {
169
189
: tensor <8 x8 xf32 >, vector <8 xf32 >
170
190
%v3 = vector.transfer_read %C3 [%i , %c0 ], %f0
171
191
: tensor <8 x8 xf32 >, vector <8 xf32 >
192
+ %v4 = vector.transfer_read %C4 [%i , %c0 ], %f0
193
+ : tensor <8 x8 xf32 >, vector <8 xf32 >
172
194
vector.print %v1 : vector <8 xf32 >
173
195
vector.print %v2 : vector <8 xf32 >
174
196
vector.print %v3 : vector <8 xf32 >
197
+ vector.print %v4 : vector <8 xf32 >
175
198
}
176
199
177
200
// Release resources.
@@ -181,6 +204,7 @@ module {
181
204
bufferization.dealloc_tensor %CSR_A : tensor <8 x8 xf32 , #CSR >
182
205
bufferization.dealloc_tensor %COO_A : tensor <8 x8 xf32 , #SortedCOO >
183
206
bufferization.dealloc_tensor %COO_B : tensor <8 x8 xf32 , #SortedCOO >
207
+ bufferization.dealloc_tensor %COO_RET : tensor <8 x8 xf32 , #SortedCOO >
184
208
185
209
186
210
return
0 commit comments