Skip to content

Commit de5e4d7

Browse files
authored
[mlir][sparse] fix error when convolution stride is applied on a dens… (#79521)
…e level.
1 parent 59e9060 commit de5e4d7

File tree

4 files changed

+30
-6
lines changed

4 files changed

+30
-6
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,8 @@ void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
313313
Value loopHi = loopHighs[loop];
314314
size = ADDI(size, MULI(loopHi, C_IDX(stride)));
315315
}
316-
it = makeNonEmptySubSectIterator(builder, loc, parent, std::move(lvlIt),
317-
size, curDep.second);
316+
it = makeNonEmptySubSectIterator(builder, loc, parent, loopHighs[loop],
317+
std::move(lvlIt), size, curDep.second);
318318
} else {
319319
Value size = loopHighs[loop];
320320
const SparseIterator &subSectIter = *iters[t][lvl].back();

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,17 +1271,20 @@ static const IterType *unwrapFilter(const SparseIterator *it) {
12711271
}
12721272

12731273
std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
1274-
OpBuilder &b, Location l, const SparseIterator *parent,
1274+
OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
12751275
std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride) {
12761276

12771277
// Try unwrap the NonEmptySubSectIterator from a filter parent.
12781278
parent = tryUnwrapFilter<NonEmptySubSectIterator>(parent);
12791279
auto it = std::make_unique<NonEmptySubSectIterator>(
12801280
b, l, parent, std::move(delegate), size);
12811281

1282-
if (stride != 1)
1282+
if (stride != 1) {
1283+
// TODO: We can safely skip bound checking on sparse levels, but for dense
1284+
// iteration space, we need the bound to infer the dense loop range.
12831285
return std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1284-
C_IDX(stride), /*size=*/C_IDX(-1));
1286+
C_IDX(stride), /*size=*/loopBound);
1287+
}
12851288
return it;
12861289
}
12871290

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
246246
/// Helper function to create a SparseIterator object that iterate over the
247247
/// non-empty subsections set.
248248
std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
249-
OpBuilder &b, Location l, const SparseIterator *parent,
249+
OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
250250
std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride);
251251

252252
/// Helper function to create a SparseIterator object that iterate over a

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ func.func @conv_2d_nhwc_hwcf_CDCC(%arg0: tensor<?x?x?x?xf32, #CDCC>, %arg1: tens
6969
return %ret : tensor<?x?x?x?xf32>
7070
}
7171

72+
func.func @conv_2d_nhwc_hwcf_dual_CDCC(%arg0: tensor<?x?x?x?xf32, #CDCC>, %arg1: tensor<?x?x?x?xf32, #CDCC>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
73+
%ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
74+
strides = dense<2> : tensor<2xi64>}
75+
ins (%arg0, %arg1: tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32, #CDCC>)
76+
outs (%arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
77+
return %ret : tensor<?x?x?x?xf32>
78+
}
79+
7280

7381
func.func @entry() {
7482
%c0 = arith.constant 0 : index
@@ -87,16 +95,28 @@ func.func @entry() {
8795

8896
%in2D_nhwc_CCCC = sparse_tensor.convert %in2D_nhwc
8997
: tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CCCC>
98+
%filter2D_nhwc_CDCC = sparse_tensor.convert %filter2D_nhwc
99+
: tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CDCC>
90100
%in2D_nhwc_CDCC = sparse_tensor.convert %in2D_nhwc
91101
: tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CDCC>
92102

93103
%dense_ret = call @conv_2d_nhwc_hwcf(%in2D_nhwc, %filter2D_nhwc, %out2D_nhwc) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
94104
%CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC(%in2D_nhwc_CCCC, %filter2D_nhwc, %out2D_nhwc) : (tensor<?x?x?x?xf32, #CCCC>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
95105
%CDCC_ret = call @conv_2d_nhwc_hwcf_CDCC(%in2D_nhwc_CDCC, %filter2D_nhwc, %out2D_nhwc) : (tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
96106

107+
%dual_CDCC_ret = call @conv_2d_nhwc_hwcf_dual_CDCC(%in2D_nhwc_CDCC, %filter2D_nhwc_CDCC, %out2D_nhwc)
108+
: (tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
109+
97110
// CHECK: ( ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 20 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
98111
// CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
99112
// CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ) )
113+
%v_dual = vector.transfer_read %dual_CDCC_ret[%c0, %c0, %c0, %c0], %zero
114+
: tensor<?x?x?x?xf32>, vector<3x3x3x1xf32>
115+
vector.print %v_dual : vector<3x3x3x1xf32>
116+
117+
// CHECK-NEXT: ( ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 20 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
118+
// CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
119+
// CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ) )
100120
%dense_v = vector.transfer_read %dense_ret[%c0, %c0, %c0, %c0], %zero
101121
: tensor<?x?x?x?xf32>, vector<3x3x3x1xf32>
102122
vector.print %dense_v : vector<3x3x3x1xf32>
@@ -120,6 +140,7 @@ func.func @entry() {
120140
bufferization.dealloc_tensor %filter2D_nhwc : tensor<?x?x?x?xf32>
121141
bufferization.dealloc_tensor %out2D_nhwc : tensor<?x?x?x?xf32>
122142

143+
bufferization.dealloc_tensor %filter2D_nhwc_CDCC : tensor<?x?x?x?xf32, #CDCC>
123144
bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor<?x?x?x?xf32, #CCCC>
124145
bufferization.dealloc_tensor %in2D_nhwc_CDCC : tensor<?x?x?x?xf32, #CDCC>
125146
return

0 commit comments

Comments
 (0)