@@ -65,8 +65,9 @@ module {
65
65
//
66
66
// Kernel that uses index in the index notation (conjunction).
67
67
//
68
- func.func @sparse_index_1d_conj (%arga: tensor <8 xi64 , #SparseVector >,
69
- %out: tensor <8 xi64 >) -> tensor <8 xi64 > {
68
+ func.func @sparse_index_1d_conj (%arga: tensor <8 xi64 , #SparseVector >)
69
+ -> tensor <8 xi64 > {
70
+ %out = tensor.empty () : tensor <8 xi64 >
70
71
%r = linalg.generic #trait_1d
71
72
ins (%arga: tensor <8 xi64 , #SparseVector >)
72
73
outs (%out: tensor <8 xi64 >) {
@@ -82,8 +83,9 @@ module {
82
83
//
83
84
// Kernel that uses index in the index notation (disjunction).
84
85
//
85
- func.func @sparse_index_1d_disj (%arga: tensor <8 xi64 , #SparseVector >,
86
- %out: tensor <8 xi64 >) -> tensor <8 xi64 > {
86
+ func.func @sparse_index_1d_disj (%arga: tensor <8 xi64 , #SparseVector >)
87
+ -> tensor <8 xi64 > {
88
+ %out = tensor.empty () : tensor <8 xi64 >
87
89
%r = linalg.generic #trait_1d
88
90
ins (%arga: tensor <8 xi64 , #SparseVector >)
89
91
outs (%out: tensor <8 xi64 >) {
@@ -99,8 +101,9 @@ module {
99
101
//
100
102
// Kernel that uses indices in the index notation (conjunction).
101
103
//
102
- func.func @sparse_index_2d_conj (%arga: tensor <3 x4 xi64 , #SparseMatrix >,
103
- %out: tensor <3 x4 xi64 >) -> tensor <3 x4 xi64 > {
104
+ func.func @sparse_index_2d_conj (%arga: tensor <3 x4 xi64 , #SparseMatrix >)
105
+ -> tensor <3 x4 xi64 > {
106
+ %out = tensor.empty () : tensor <3 x4 xi64 >
104
107
%r = linalg.generic #trait_2d
105
108
ins (%arga: tensor <3 x4 xi64 , #SparseMatrix >)
106
109
outs (%out: tensor <3 x4 xi64 >) {
@@ -119,8 +122,9 @@ module {
119
122
//
120
123
// Kernel that uses indices in the index notation (disjunction).
121
124
//
122
- func.func @sparse_index_2d_disj (%arga: tensor <3 x4 xi64 , #SparseMatrix >,
123
- %out: tensor <3 x4 xi64 >) -> tensor <3 x4 xi64 > {
125
+ func.func @sparse_index_2d_disj (%arga: tensor <3 x4 xi64 , #SparseMatrix >)
126
+ -> tensor <3 x4 xi64 > {
127
+ %out = tensor.empty () : tensor <3 x4 xi64 >
124
128
%r = linalg.generic #trait_2d
125
129
ins (%arga: tensor <3 x4 xi64 , #SparseMatrix >)
126
130
outs (%out: tensor <3 x4 xi64 >) {
@@ -161,20 +165,15 @@ module {
161
165
[ 1 , 1 , 3 , 4 ] ]> : tensor <3 x4 xi64 >
162
166
%dm = sparse_tensor.convert %m2 : tensor <3 x4 xi64 > to tensor <3 x4 xi64 , #SparseMatrix >
163
167
164
- // Setup out tensors.
165
- // Note: Constants bufferize to read-only buffers.
166
- %init_8 = tensor.empty () : tensor <8 xi64 >
167
- %init_3_4 = tensor.empty () : tensor <3 x4 xi64 >
168
-
169
168
// Call the kernels.
170
- %0 = call @sparse_index_1d_conj (%sv , %init_8 ) : (tensor <8 xi64 , #SparseVector >, tensor < 8 x i64 >) -> tensor <8 xi64 >
171
- %1 = call @sparse_index_1d_disj (%sv , %init_8 ) : (tensor <8 xi64 , #SparseVector >, tensor < 8 x i64 >) -> tensor <8 xi64 >
172
- %2 = call @sparse_index_1d_conj (%dv , %init_8 ) : (tensor <8 xi64 , #SparseVector >, tensor < 8 x i64 >) -> tensor <8 xi64 >
173
- %3 = call @sparse_index_1d_disj (%dv , %init_8 ) : (tensor <8 xi64 , #SparseVector >, tensor < 8 x i64 >) -> tensor <8 xi64 >
174
- %4 = call @sparse_index_2d_conj (%sm , %init_3_4 ) : (tensor <3 x4 xi64 , #SparseMatrix >, tensor < 3 x 4 x i64 >) -> tensor <3 x4 xi64 >
175
- %5 = call @sparse_index_2d_disj (%sm , %init_3_4 ) : (tensor <3 x4 xi64 , #SparseMatrix >, tensor < 3 x 4 x i64 >) -> tensor <3 x4 xi64 >
176
- %6 = call @sparse_index_2d_conj (%dm , %init_3_4 ) : (tensor <3 x4 xi64 , #SparseMatrix >, tensor < 3 x 4 x i64 >) -> tensor <3 x4 xi64 >
177
- %7 = call @sparse_index_2d_disj (%dm , %init_3_4 ) : (tensor <3 x4 xi64 , #SparseMatrix >, tensor < 3 x 4 x i64 >) -> tensor <3 x4 xi64 >
169
+ %0 = call @sparse_index_1d_conj (%sv ) : (tensor <8 xi64 , #SparseVector >) -> tensor <8 xi64 >
170
+ %1 = call @sparse_index_1d_disj (%sv ) : (tensor <8 xi64 , #SparseVector >) -> tensor <8 xi64 >
171
+ %2 = call @sparse_index_1d_conj (%dv ) : (tensor <8 xi64 , #SparseVector >) -> tensor <8 xi64 >
172
+ %3 = call @sparse_index_1d_disj (%dv ) : (tensor <8 xi64 , #SparseVector >) -> tensor <8 xi64 >
173
+ %4 = call @sparse_index_2d_conj (%sm ) : (tensor <3 x4 xi64 , #SparseMatrix >) -> tensor <3 x4 xi64 >
174
+ %5 = call @sparse_index_2d_disj (%sm ) : (tensor <3 x4 xi64 , #SparseMatrix >) -> tensor <3 x4 xi64 >
175
+ %6 = call @sparse_index_2d_conj (%dm ) : (tensor <3 x4 xi64 , #SparseMatrix >) -> tensor <3 x4 xi64 >
176
+ %7 = call @sparse_index_2d_disj (%dm ) : (tensor <3 x4 xi64 , #SparseMatrix >) -> tensor <3 x4 xi64 >
178
177
179
178
//
180
179
// Verify result.
0 commit comments