Commit 8c8aecd

Author: Peiming Liu
[mlir][sparse] Supporting (non)uniqueness in SparseTensorStorage::lexDiff.
Fix copied from https://reviews.llvm.org/D156946, but with a legitimate test case that actually triggers the bug.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D158578
1 parent: cd7af14

2 files changed: 42 additions, 13 deletions


mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h

Lines changed: 10 additions & 5 deletions
@@ -704,12 +704,17 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// in the argument differ from those in the current cursor.
   uint64_t lexDiff(const uint64_t *lvlCoords) const {
     const uint64_t lvlRank = getLvlRank();
-    for (uint64_t l = 0; l < lvlRank; ++l)
-      if (lvlCoords[l] > lvlCursor[l])
+    for (uint64_t l = 0; l < lvlRank; ++l) {
+      const auto crd = lvlCoords[l];
+      const auto cur = lvlCursor[l];
+      if (crd > cur || (crd == cur && !isUniqueLvl(l)))
         return l;
-      else
-        assert(lvlCoords[l] == lvlCursor[l] && "non-lexicographic insertion");
-    assert(0 && "duplicate insertion");
+      if (crd < cur) {
+        assert(false && "non-lexicographic insertion");
+        return -1u;
+      }
+    }
+    assert(false && "duplicate insertion");
     return -1u;
   }
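For readers skimming the patch: below is a minimal, self-contained sketch of the fixed comparison logic. The function lexDiffSketch and its vector parameters are hypothetical stand-ins; the real method lives on SparseTensorStorage and queries isUniqueLvl(l) rather than taking uniqueness as an argument. The essence of the change is that an equal coordinate at a non-unique level now counts as the first level of difference (a legal repeated entry) instead of falling through toward the "duplicate insertion" assertion.

#include <cassert>
#include <cstdint>
#include <vector>

using std::uint64_t;

// Hypothetical standalone sketch of the patched lexDiff: return the first
// level at which the new coordinates differ from the cursor, where "differ"
// now also covers an *equal* coordinate at a non-unique level.
uint64_t lexDiffSketch(const std::vector<uint64_t> &coords,
                       const std::vector<uint64_t> &cursor,
                       const std::vector<bool> &isUniqueLvl) {
  for (uint64_t l = 0, rank = coords.size(); l < rank; ++l) {
    const uint64_t crd = coords[l];
    const uint64_t cur = cursor[l];
    // A repeated coordinate is a legal new entry at a non-unique level.
    if (crd > cur || (crd == cur && !isUniqueLvl[l]))
      return l;
    if (crd < cur) {
      assert(false && "non-lexicographic insertion");
      return -1u;
    }
  }
  // All levels equal and all levels unique: a genuine duplicate insertion.
  assert(false && "duplicate insertion");
  return -1u;
}

int main() {
  // 2-D sorted COO: the first level is non-unique, the last level is unique.
  const std::vector<bool> unique = {false, true};
  // Cursor sits at (0, 0); inserting (0, 5) shares the row coordinate.
  // The patched logic reports the difference at level 0, so insertion also
  // appends a fresh level-0 coordinate. The old code returned level 1, since
  // it only treated a strictly greater coordinate as a difference.
  assert(lexDiffSketch({0, 5}, {0, 0}, unique) == 0);
  return 0;
}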

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_coo_test.mlir

Lines changed: 32 additions & 8 deletions
@@ -87,6 +87,22 @@ module {
     return %0 : tensor<8x8xf32>
   }
 
+  func.func @add_coo_coo_out_coo(%arga: tensor<8x8xf32, #SortedCOO>,
+                                 %argb: tensor<8x8xf32, #SortedCOO>)
+                                   -> tensor<8x8xf32, #SortedCOO> {
+    %init = tensor.empty() : tensor<8x8xf32, #SortedCOO>
+    %0 = linalg.generic #trait
+         ins(%arga, %argb: tensor<8x8xf32, #SortedCOO>,
+                           tensor<8x8xf32, #SortedCOO>)
+         outs(%init: tensor<8x8xf32, #SortedCOO>) {
+      ^bb(%a: f32, %b: f32, %x: f32):
+        %0 = arith.addf %a, %b : f32
+        linalg.yield %0 : f32
+    } -> tensor<8x8xf32, #SortedCOO>
+    return %0 : tensor<8x8xf32, #SortedCOO>
+  }
+
+
   func.func @add_coo_dense(%arga: tensor<8x8xf32>,
                            %argb: tensor<8x8xf32, #SortedCOO>)
                              -> tensor<8x8xf32>
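The new @add_coo_coo_out_coo kernel mirrors the existing @add_coo_coo, except that its result is materialized directly as #SortedCOO rather than a dense tensor. Inserting results element by element into a sorted COO output, whose leading level is non-unique, is precisely the path that exercises the lexDiff fix above.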
@@ -149,17 +165,21 @@ module {
     %C3 = call @add_coo_coo(%COO_A, %COO_B) : (tensor<8x8xf32, #SortedCOO>,
                                                tensor<8x8xf32, #SortedCOO>)
                                             -> tensor<8x8xf32>
+    %COO_RET = call @add_coo_coo_out_coo(%COO_A, %COO_B) : (tensor<8x8xf32, #SortedCOO>,
+                                                            tensor<8x8xf32, #SortedCOO>)
+                                                         -> tensor<8x8xf32, #SortedCOO>
+    %C4 = sparse_tensor.convert %COO_RET : tensor<8x8xf32, #SortedCOO> to tensor<8x8xf32>
     //
     // Verify computed matrix C.
     //
-    // CHECK-COUNT-3: ( 8.8, 4.8, 6.8, 4.8, 8.8, 6.1, 14.8, 16.8 )
-    // CHECK-NEXT-COUNT-3: ( 4.4, 4.4, 4.4, 8.4, 8.4, 12.4, 16.4, 16.4 )
-    // CHECK-NEXT-COUNT-3: ( 8.8, 4.8, 6.8, 8.8, 8.8, 12.8, 14.8, 15.8 )
-    // CHECK-NEXT-COUNT-3: ( 4.3, 5.3, 6.3, 8.3, 8.3, 12.3, 14.3, 16.3 )
-    // CHECK-NEXT-COUNT-3: ( 4.5, 4.5, 6.5, 8.5, 8.5, 12.5, 14.5, 16.5 )
-    // CHECK-NEXT-COUNT-3: ( 9.9, 4.9, 6.9, 8.9, 8.9, 12.9, 15.9, 16.9 )
-    // CHECK-NEXT-COUNT-3: ( 12.1, 6.1, 5.1, 9.1, 9.1, 13.1, 15.1, 17.1 )
-    // CHECK-NEXT-COUNT-3: ( 15.4, 5.4, 7.4, 5.4, 11.4, 10.4, 11.4, 9.4 )
+    // CHECK-COUNT-4: ( 8.8, 4.8, 6.8, 4.8, 8.8, 6.1, 14.8, 16.8 )
+    // CHECK-NEXT-COUNT-4: ( 4.4, 4.4, 4.4, 8.4, 8.4, 12.4, 16.4, 16.4 )
+    // CHECK-NEXT-COUNT-4: ( 8.8, 4.8, 6.8, 8.8, 8.8, 12.8, 14.8, 15.8 )
+    // CHECK-NEXT-COUNT-4: ( 4.3, 5.3, 6.3, 8.3, 8.3, 12.3, 14.3, 16.3 )
+    // CHECK-NEXT-COUNT-4: ( 4.5, 4.5, 6.5, 8.5, 8.5, 12.5, 14.5, 16.5 )
+    // CHECK-NEXT-COUNT-4: ( 9.9, 4.9, 6.9, 8.9, 8.9, 12.9, 15.9, 16.9 )
+    // CHECK-NEXT-COUNT-4: ( 12.1, 6.1, 5.1, 9.1, 9.1, 13.1, 15.1, 17.1 )
+    // CHECK-NEXT-COUNT-4: ( 15.4, 5.4, 7.4, 5.4, 11.4, 10.4, 11.4, 9.4 )
     //
     %f0 = arith.constant 0.0 : f32
     scf.for %i = %c0 to %c8 step %c1 {
@@ -169,9 +189,12 @@ module {
           : tensor<8x8xf32>, vector<8xf32>
       %v3 = vector.transfer_read %C3[%i, %c0], %f0
           : tensor<8x8xf32>, vector<8xf32>
+      %v4 = vector.transfer_read %C4[%i, %c0], %f0
+          : tensor<8x8xf32>, vector<8xf32>
      vector.print %v1 : vector<8xf32>
       vector.print %v2 : vector<8xf32>
       vector.print %v3 : vector<8xf32>
+      vector.print %v4 : vector<8xf32>
     }
 
     // Release resources.
@@ -181,6 +204,7 @@ module {
     bufferization.dealloc_tensor %CSR_A : tensor<8x8xf32, #CSR>
     bufferization.dealloc_tensor %COO_A : tensor<8x8xf32, #SortedCOO>
     bufferization.dealloc_tensor %COO_B : tensor<8x8xf32, #SortedCOO>
+    bufferization.dealloc_tensor %COO_RET : tensor<8x8xf32, #SortedCOO>
 
 
     return
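Since the COO result is converted back to dense as %C4 and printed alongside %v1 through %v3, each expected row is now checked four times instead of three, hence the CHECK-COUNT-3 to CHECK-COUNT-4 updates. Before this fix, a lexDiff level reported one level too deep would have corrupted the COO coordinate buffers and made the fourth print diverge.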
