Skip to content

Commit 5953f19

Browse files
[mlir][sparse_tensor] Fix memory leak in sparse_index_dense.mlir (#137454)
1 parent 1b5cd1d commit 5953f19

File tree

1 file changed

+20
-21
lines changed

1 file changed

+20
-21
lines changed

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index_dense.mlir

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@ module {
6565
//
6666
// Kernel that uses index in the index notation (conjunction).
6767
//
68-
func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>,
69-
%out: tensor<8xi64>) -> tensor<8xi64> {
68+
func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>)
69+
-> tensor<8xi64> {
70+
%out = tensor.empty() : tensor<8xi64>
7071
%r = linalg.generic #trait_1d
7172
ins(%arga: tensor<8xi64, #SparseVector>)
7273
outs(%out: tensor<8xi64>) {
@@ -82,8 +83,9 @@ module {
8283
//
8384
// Kernel that uses index in the index notation (disjunction).
8485
//
85-
func.func @sparse_index_1d_disj(%arga: tensor<8xi64, #SparseVector>,
86-
%out: tensor<8xi64>) -> tensor<8xi64> {
86+
func.func @sparse_index_1d_disj(%arga: tensor<8xi64, #SparseVector>)
87+
-> tensor<8xi64> {
88+
%out = tensor.empty() : tensor<8xi64>
8789
%r = linalg.generic #trait_1d
8890
ins(%arga: tensor<8xi64, #SparseVector>)
8991
outs(%out: tensor<8xi64>) {
@@ -99,8 +101,9 @@ module {
99101
//
100102
// Kernel that uses indices in the index notation (conjunction).
101103
//
102-
func.func @sparse_index_2d_conj(%arga: tensor<3x4xi64, #SparseMatrix>,
103-
%out: tensor<3x4xi64>) -> tensor<3x4xi64> {
104+
func.func @sparse_index_2d_conj(%arga: tensor<3x4xi64, #SparseMatrix>)
105+
-> tensor<3x4xi64> {
106+
%out = tensor.empty() : tensor<3x4xi64>
104107
%r = linalg.generic #trait_2d
105108
ins(%arga: tensor<3x4xi64, #SparseMatrix>)
106109
outs(%out: tensor<3x4xi64>) {
@@ -119,8 +122,9 @@ module {
119122
//
120123
// Kernel that uses indices in the index notation (disjunction).
121124
//
122-
func.func @sparse_index_2d_disj(%arga: tensor<3x4xi64, #SparseMatrix>,
123-
%out: tensor<3x4xi64>) -> tensor<3x4xi64> {
125+
func.func @sparse_index_2d_disj(%arga: tensor<3x4xi64, #SparseMatrix>)
126+
-> tensor<3x4xi64> {
127+
%out = tensor.empty() : tensor<3x4xi64>
124128
%r = linalg.generic #trait_2d
125129
ins(%arga: tensor<3x4xi64, #SparseMatrix>)
126130
outs(%out: tensor<3x4xi64>) {
@@ -161,20 +165,15 @@ module {
161165
[ 1, 1, 3, 4 ] ]> : tensor<3x4xi64>
162166
%dm = sparse_tensor.convert %m2 : tensor<3x4xi64> to tensor<3x4xi64, #SparseMatrix>
163167

164-
// Setup out tensors.
165-
// Note: Constants bufferize to read-only buffers.
166-
%init_8 = tensor.empty() : tensor<8xi64>
167-
%init_3_4 = tensor.empty() : tensor<3x4xi64>
168-
169168
// Call the kernels.
170-
%0 = call @sparse_index_1d_conj(%sv, %init_8) : (tensor<8xi64, #SparseVector>, tensor<8xi64>) -> tensor<8xi64>
171-
%1 = call @sparse_index_1d_disj(%sv, %init_8) : (tensor<8xi64, #SparseVector>, tensor<8xi64>) -> tensor<8xi64>
172-
%2 = call @sparse_index_1d_conj(%dv, %init_8) : (tensor<8xi64, #SparseVector>, tensor<8xi64>) -> tensor<8xi64>
173-
%3 = call @sparse_index_1d_disj(%dv, %init_8) : (tensor<8xi64, #SparseVector>, tensor<8xi64>) -> tensor<8xi64>
174-
%4 = call @sparse_index_2d_conj(%sm, %init_3_4) : (tensor<3x4xi64, #SparseMatrix>, tensor<3x4xi64>) -> tensor<3x4xi64>
175-
%5 = call @sparse_index_2d_disj(%sm, %init_3_4) : (tensor<3x4xi64, #SparseMatrix>, tensor<3x4xi64>) -> tensor<3x4xi64>
176-
%6 = call @sparse_index_2d_conj(%dm, %init_3_4) : (tensor<3x4xi64, #SparseMatrix>, tensor<3x4xi64>) -> tensor<3x4xi64>
177-
%7 = call @sparse_index_2d_disj(%dm, %init_3_4) : (tensor<3x4xi64, #SparseMatrix>, tensor<3x4xi64>) -> tensor<3x4xi64>
169+
%0 = call @sparse_index_1d_conj(%sv) : (tensor<8xi64, #SparseVector>) -> tensor<8xi64>
170+
%1 = call @sparse_index_1d_disj(%sv) : (tensor<8xi64, #SparseVector>) -> tensor<8xi64>
171+
%2 = call @sparse_index_1d_conj(%dv) : (tensor<8xi64, #SparseVector>) -> tensor<8xi64>
172+
%3 = call @sparse_index_1d_disj(%dv) : (tensor<8xi64, #SparseVector>) -> tensor<8xi64>
173+
%4 = call @sparse_index_2d_conj(%sm) : (tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64>
174+
%5 = call @sparse_index_2d_disj(%sm) : (tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64>
175+
%6 = call @sparse_index_2d_conj(%dm) : (tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64>
176+
%7 = call @sparse_index_2d_disj(%dm) : (tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64>
178177

179178
//
180179
// Verify result.

0 commit comments

Comments
 (0)