
Commit 8206b75

[mlir][sparse] fix crash when generating rotated convolution kernels. (#74146)
1 parent 3d89f2a commit 8206b75

2 files changed: +40 −5 lines changed

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 4 additions & 4 deletions
@@ -1329,11 +1329,11 @@ void LoopEmitter::enterTensorsAtDenseLvls(
       // Update the slice information as we enter the new loop.
       info.minCrd = info.offset = MULI(iv, C_IDX(stride));
       info.isNonEmpty = constantI1(builder, loc, true);
-      levelReducedDep[tid][lvl]++;
     } else {
       posits[tid][lvl] =
           genAddress(builder, loc, tid, lvl, ADDI(info.offset, iv));
     }
+    levelReducedDep[tid][lvl]++;
   } else {
     // Skips the synthetic tensor
     if (isSynTensor(tid))
@@ -1369,11 +1369,11 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
       // moves forward to the next slice.
       invalidateSliceIterIdx(rewriter, loc, tid, lvl);
       info.minCrd = info.offset = info.isNonEmpty = Value();
-      levelReducedDep[tid][lvl]--;
     } else {
       forwardsReducedSliceLevelTreeIt(rewriter, loc, tid, lvl,
                                       constantIndex(rewriter, loc, 1));
     }
+    levelReducedDep[tid][lvl]--;
   }
   if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
     if (!reduc.empty()) {
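
The two hunks above make the same bookkeeping fix on the enter and exit sides: levelReducedDep[tid][lvl] used to be incremented (in enterTensorsAtDenseLvls) and decremented (in exitForLoop) only on one branch of the if/else, so taking the other branch left the enter/exit counts unbalanced and later queries saw a stale reduction depth. Moving the update past the branch makes the two paths symmetric. A minimal standalone sketch of the invariant, with DepCounter, enterLoop, and exitLoop as hypothetical stand-ins for the actual LoopEmitter API:

#include <cassert>
#include <vector>

// Hypothetical stand-in for LoopEmitter's per-(tensor, level) counter.
struct DepCounter {
  std::vector<std::vector<int>> levelReducedDep;

  void enterLoop(unsigned tid, unsigned lvl, bool freshSlice) {
    if (freshSlice) {
      // ... initialize minCrd/offset/isNonEmpty for a new slice ...
    } else {
      // ... compute the position from the already-reduced offset ...
    }
    // The fix: count the reduced dependence on *both* paths, so every
    // enterLoop is matched by exactly one decrement in exitLoop.
    levelReducedDep[tid][lvl]++;
  }

  void exitLoop(unsigned tid, unsigned lvl, bool lastSlice) {
    if (lastSlice) {
      // ... invalidate the slice iterator ...
    } else {
      // ... forward the slice-level tree iterator ...
    }
    levelReducedDep[tid][lvl]--;
    assert(levelReducedDep[tid][lvl] >= 0 && "unbalanced enter/exit");
  }
};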
@@ -1460,8 +1460,8 @@ void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
   // level (but not resolved). Since we forward an iterator at higher level of
   // the tree, the subtree need to be pruned.
   Level leafLvl = rootLvl + 1;
-  while (leafLvl < stt.getLvlRank() && !dependentLvlMap[tid][leafLvl].empty()) {
-    assert(depFullyReduced(tid, leafLvl));
+  while (leafLvl < stt.getLvlRank() && !dependentLvlMap[tid][leafLvl].empty() &&
+         depFullyReduced(tid, leafLvl)) {
     leafLvl++;
   }
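
The third hunk fixes the companion crash in the leaf-level scan: the old code asserted that every dependent level below the root was already fully reduced, but with a rotated kernel the reduction loops are interleaved with parallel loops, so a dependent-but-not-yet-reduced level is legal. Folding the predicate into the loop condition ends the scan there instead of tripping the assert. A simplified sketch of the search, with isDependent and isFullyReduced as hypothetical stand-ins for dependentLvlMap and depFullyReduced:

#include <functional>

unsigned findLeafLevel(unsigned rootLvl, unsigned lvlRank,
                       const std::function<bool(unsigned)> &isDependent,
                       const std::function<bool(unsigned)> &isFullyReduced) {
  unsigned leafLvl = rootLvl + 1;
  // Old behavior: assert(isFullyReduced(leafLvl)) inside the loop body.
  // New behavior: a not-yet-reduced dependent level simply ends the scan.
  while (leafLvl < lvlRank && isDependent(leafLvl) && isFullyReduced(leafLvl))
    leafLvl++;
  return leafLvl;
}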

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_2d.mlir

Lines changed: 36 additions & 1 deletion
@@ -38,6 +38,10 @@
   map = (d0, d1) -> (d1 : dense, d0 : compressed)
 }>
 
+#map = affine_map<(d0, d1, d2, d3) -> (d0 + d1, d3 + d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>
+
 // An example of a 2D convolution with a sparse filter.
 module {

@@ -50,6 +54,21 @@ module {
     return %0 : tensor<6x6xi32>
   }
 
+  func.func @conv2d_CSR_dense_rotated(%arg0: tensor<8x8xi32, #CSR>,
+                                      %arg1: tensor<3x3xi32>) -> tensor<6x6xi32> {
+    %s = tensor.empty() : tensor<6x6xi32>
+    %0 = linalg.generic {indexing_maps = [#map, #map1, #map2],
+                         iterator_types = ["parallel", "reduction", "reduction", "parallel"]}
+      ins(%arg0, %arg1 : tensor<8x8xi32, #CSR>, tensor<3x3xi32>)
+      outs(%s : tensor<6x6xi32>) attrs = {sorted = true} {
+    ^bb0(%in: i32, %in_0: i32, %out: i32):
+      %1 = arith.muli %in, %in_0 : i32
+      %2 = arith.addi %out, %1 : i32
+      linalg.yield %2 : i32
+    } -> tensor<6x6xi32>
+    return %0 : tensor<6x6xi32>
+  }
+
   func.func @conv2d_sparse_out(%input: tensor<8x8xi32>,
                                %filter: tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR> {
     %s = tensor.empty() : tensor<6x6xi32, #DCSR>
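
The new test computes the same 3x3 convolution as the existing @conv2d, but with the iteration space "rotated": the two reduction iterators d1 and d2 sit between the parallel iterators d0 and d3, and the indexing maps read in[d0 + d1][d3 + d2] and filter[d1][d2] while writing out[d0][d3]. As a plain scalar loop nest (a sketch of the semantics only, not the code the sparsifier generates), the linalg.generic is:

#include <array>
#include <cstdint>

// Scalar semantics of @conv2d_CSR_dense_rotated; iterator order
// (d0, d1, d2, d3) = (parallel, reduction, reduction, parallel).
void conv2dRotated(const std::array<std::array<int32_t, 8>, 8> &in,
                   const std::array<std::array<int32_t, 3>, 3> &filter,
                   std::array<std::array<int32_t, 6>, 6> &out) {
  for (int d0 = 0; d0 < 6; ++d0)        // parallel:  output row
    for (int d1 = 0; d1 < 3; ++d1)      // reduction: filter row
      for (int d2 = 0; d2 < 3; ++d2)    // reduction: filter col
        for (int d3 = 0; d3 < 6; ++d3)  // parallel:  output col
          // #map reads in[d0 + d1][d3 + d2], #map1 reads filter[d1][d2],
          // #map2 writes out[d0][d3].
          out[d0][d3] += in[d0 + d1][d3 + d2] * filter[d1][d2];
}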
@@ -146,7 +165,9 @@ module {
     %5 = call @conv2d_all_sparse_CSC(%sparse_input_CSC, %filter)
       : (tensor<8x8xi32, #CSC>,
          tensor<3x3xi32>) -> tensor<6x6xi32, #CSC>
-
+    %6 = call @conv2d_CSR_dense_rotated(%sparse_input_CSR, %filter)
+      : (tensor<8x8xi32, #CSR>,
+         tensor<3x3xi32>) -> tensor<6x6xi32>
 
     // Verify the output.
     //
@@ -236,6 +257,20 @@ module {
       : tensor<6x6xi32>, vector<6x6xi32>
     vector.print %v5 : vector<6x6xi32>
 
+    //
+    // Should be the same as dense output
+    // CHECK:      ( ( 0, 0, -1, -6, -1, 6 ),
+    // CHECK-SAME:   ( -1, 0, 1, 0, 1, 0 ),
+    // CHECK-SAME:   ( 0, -1, 1, 0, 0, 0 ),
+    // CHECK-SAME:   ( -1, 0, 0, 0, 0, 0 ),
+    // CHECK-SAME:   ( 0, 0, 3, 6, -3, -6 ),
+    // CHECK-SAME:   ( 2, -1, 3, 0, -3, 0 ) )
+    //
+    %v6 = vector.transfer_read %6[%c0, %c0], %i0
+      : tensor<6x6xi32>, vector<6x6xi32>
+    vector.print %v6 : vector<6x6xi32>
+
     // Release the resources.
     bufferization.dealloc_tensor %sparse_input_DCSR : tensor<8x8xi32, #DCSR>
     bufferization.dealloc_tensor %sparse_input_CSR : tensor<8x8xi32, #CSR>
