Skip to content

Commit 78e2b74

Browse files
authored
[mlir][sparse] fix bugs when generating sparse conv_3d kernels. (#74561)
1 parent 861600f commit 78e2b74

File tree

2 files changed

+66
-15
lines changed

2 files changed

+66
-15
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,28 +1454,19 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
14541454
void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
14551455
Location loc, TensorId tid,
14561456
Level rootLvl, Value fcnt) {
1457+
14571458
auto stt = getSparseTensorType(tensors[tid]);
14581459

14591460
// Finds a [Lvl, leafLvl) range, and all levels in between are fully reduced
1460-
// level (but not resolved). Since we forward an iterator at higher level of
1461-
// the tree, the subtree need to be pruned.
1461+
// sparse levels (but not resolved). Since we forward an iterator at higher
1462+
// level of the tree, the subtree needs to be pruned.
14621463
Level leafLvl = rootLvl + 1;
1463-
while (leafLvl < stt.getLvlRank() && !dependentLvlMap[tid][leafLvl].empty() &&
1464-
depFullyReduced(tid, leafLvl)) {
1464+
while (leafLvl < stt.getLvlRank() && depFullyReduced(tid, leafLvl) &&
1465+
!stt.isDenseLvl(leafLvl)) {
14651466
leafLvl++;
14661467
}
14671468

14681469
Level curLvl = rootLvl + 1;
1469-
// Prunes all denses subtree.
1470-
while (curLvl < leafLvl && isDenseLT(lvlTypes[tid][curLvl])) {
1471-
// One step forward in parent level results in forwarding `slice.size` step
1472-
// in child dense level.
1473-
auto [size, stride] = sliceMeta[tid][curLvl].back();
1474-
assert(stride == 1 && "Not yet implemented");
1475-
fcnt = MULI(size, fcnt);
1476-
curLvl++;
1477-
}
1478-
14791470
Value nxPosPtr = nullptr;
14801471
if (curLvl < leafLvl) {
14811472
assert(!isDenseLT(lvlTypes[tid][curLvl]));

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,14 @@
3838
map = (d0, d1, d2) -> (d0 : compressed, d1 : dense, d2 : compressed)
3939
}>
4040

41-
#DDC = #sparse_tensor.encoding<{
41+
#DCC = #sparse_tensor.encoding<{
4242
map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : compressed)
4343
}>
4444

45+
#DDC = #sparse_tensor.encoding<{
46+
map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed)
47+
}>
48+
4549
// Creates and returns 3-D buffer of size (%s1, %s2, %s3) filled with the value %f
4650
func.func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> tensor<?x?x?xf32> {
4751
%buf = tensor.empty(%s1, %s2, %s3) : tensor<?x?x?xf32>
@@ -74,6 +78,15 @@ func.func @conv_3d_CDC(%arg0: tensor<?x?x?xf32, #CDC>, %arg1: tensor<?x?x?xf32>)
7478
return %ret : tensor<?x?x?xf32, #CDC>
7579
}
7680

81+
// Runs linalg.conv_3d on the #DCC-encoded input %arg0 and the unannotated
// filter %arg1, using a #DCC tensor created by tensor.empty with sizes
// (6, 6, 6) as the output operand, and returns the #DCC result.
func.func @conv_3d_DCC(%arg0: tensor<?x?x?xf32, #DCC>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32, #DCC> {
82+
%c6 = arith.constant 6 : index
83+
%s = tensor.empty(%c6, %c6, %c6) : tensor<?x?x?xf32, #DCC>
84+
%ret = linalg.conv_3d
85+
ins (%arg0, %arg1: tensor<?x?x?xf32, #DCC>, tensor<?x?x?xf32>)
86+
outs (%s: tensor<?x?x?xf32, #DCC>) -> tensor<?x?x?xf32, #DCC>
87+
return %ret : tensor<?x?x?xf32, #DCC>
88+
}
89+
7790
func.func @conv_3d_DDC(%arg0: tensor<?x?x?xf32, #DDC>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32, #DDC> {
7891
%c6 = arith.constant 6 : index
7992
%s = tensor.empty(%c6, %c6, %c6) : tensor<?x?x?xf32, #DDC>
@@ -102,12 +115,15 @@ func.func @entry() {
102115
: tensor<?x?x?xf32> to tensor<?x?x?xf32, #CCC>
103116
%in3D_CDC = sparse_tensor.convert %in3D
104117
: tensor<?x?x?xf32> to tensor<?x?x?xf32, #CDC>
118+
%in3D_DCC = sparse_tensor.convert %in3D
119+
: tensor<?x?x?xf32> to tensor<?x?x?xf32, #DCC>
105120
%in3D_DDC = sparse_tensor.convert %in3D
106121
: tensor<?x?x?xf32> to tensor<?x?x?xf32, #DDC>
107122

108123
%dense_ret = call @conv_3d(%in3D, %filter3D, %out3D) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>)
109124
%CCC_ret = call @conv_3d_CCC(%in3D_CCC, %filter3D) : (tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #CCC>)
110125
%CDC_ret = call @conv_3d_CDC(%in3D_CDC, %filter3D) : (tensor<?x?x?xf32, #CDC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #CDC>)
126+
%DCC_ret = call @conv_3d_DCC(%in3D_DCC, %filter3D) : (tensor<?x?x?xf32, #DCC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #DCC>)
111127
%DDC_ret = call @conv_3d_DDC(%in3D_DDC, %filter3D) : (tensor<?x?x?xf32, #DDC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #DDC>)
112128

113129
// CHECK:( ( ( 108, 108, 108, 108, 108, 108 ),
@@ -276,6 +292,48 @@ func.func @entry() {
276292
: tensor<?x?x?xf32>, vector<6x6x6xf32>
277293
vector.print %v2 : vector<6x6x6xf32>
278294

295+
// CHECK-NEXT:( ( ( 108, 108, 108, 108, 108, 108 ),
296+
// CHECK-SAME: ( 124, 108, 108, 108, 108, 108 ),
297+
// CHECK-SAME: ( 124, 108, 108, 108, 108, 108 ),
298+
// CHECK-SAME: ( 124, 108, 108, 108, 108, 108 ),
299+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
300+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ),
301+
// CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ),
302+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
303+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
304+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
305+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
306+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ),
307+
// CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ),
308+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
309+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
310+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
311+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
312+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ),
313+
// CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ),
314+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
315+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
316+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
317+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
318+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ),
319+
// CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ),
320+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
321+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
322+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
323+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
324+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ),
325+
// CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ),
326+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
327+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
328+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
329+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
330+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ) )
331+
%4 = sparse_tensor.convert %DCC_ret
332+
: tensor<?x?x?xf32, #DCC> to tensor<?x?x?xf32>
333+
%v4 = vector.transfer_read %4[%c0, %c0, %c0], %zero
334+
: tensor<?x?x?xf32>, vector<6x6x6xf32>
335+
vector.print %v4 : vector<6x6x6xf32>
336+
279337
// Free the resources
280338
bufferization.dealloc_tensor %in3D : tensor<?x?x?xf32>
281339
bufferization.dealloc_tensor %filter3D : tensor<?x?x?xf32>
@@ -284,9 +342,11 @@ func.func @entry() {
284342
bufferization.dealloc_tensor %in3D_CDC : tensor<?x?x?xf32, #CDC>
285343
bufferization.dealloc_tensor %in3D_CCC : tensor<?x?x?xf32, #CCC>
286344
bufferization.dealloc_tensor %in3D_DDC : tensor<?x?x?xf32, #DDC>
345+
bufferization.dealloc_tensor %in3D_DCC : tensor<?x?x?xf32, #DCC>
287346

288347
bufferization.dealloc_tensor %CCC_ret : tensor<?x?x?xf32, #CCC>
289348
bufferization.dealloc_tensor %CDC_ret : tensor<?x?x?xf32, #CDC>
290349
bufferization.dealloc_tensor %DDC_ret : tensor<?x?x?xf32, #DDC>
350+
bufferization.dealloc_tensor %DCC_ret : tensor<?x?x?xf32, #DCC>
291351
return
292352
}

0 commit comments

Comments
 (0)