Skip to content

Commit 1ac6846

Browse files
authored
[mlir][sparse] support sparse dilated convolution. (#80470)
1 parent e12be9c commit 1ac6846

File tree

4 files changed

+179
-34
lines changed

4 files changed

+179
-34
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -313,16 +313,16 @@ void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
313313
// Compute the subsection size.
314314
Value size = c0;
315315
for (auto [loop, stride] : remDepStack[t][lvl]) {
316-
Value loopHi = loopHighs[loop];
317-
size = ADDI(size, MULI(loopHi, C_IDX(stride)));
316+
Value idxMax = SUBI(loopHighs[loop], C_IDX(1));
317+
size = ADDI(size, ADDI(MULI(idxMax, C_IDX(stride)), C_IDX(1)));
318318
}
319319
it = makeNonEmptySubSectIterator(builder, loc, parent, loopHighs[loop],
320320
std::move(lvlIt), size, curDep.second);
321321
} else {
322-
Value size = loopHighs[loop];
323322
const SparseIterator &subSectIter = *iters[t][lvl].back();
324-
it = makeTraverseSubSectIterator(subSectIter, *parent, std::move(lvlIt),
325-
size, curDep.second);
323+
it = makeTraverseSubSectIterator(builder, loc, subSectIter, *parent,
324+
std::move(lvlIt), loopHighs[loop],
325+
curDep.second);
326326
}
327327
lastIter[t] = it.get();
328328
iters[t][lvl].emplace_back(std::move(it));

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -665,13 +665,10 @@ class SubSectIterator : public SparseIterator {
665665
public:
666666
SubSectIterator(const NonEmptySubSectIterator &subSect,
667667
const SparseIterator &parent,
668-
std::unique_ptr<SparseIterator> &&wrap, Value size,
669-
unsigned stride)
668+
std::unique_ptr<SparseIterator> &&wrap)
670669
: SparseIterator(IterKind::kSubSect, *wrap,
671670
/*extraCursorCnt=*/wrap->randomAccessible() ? 0 : 1),
672-
subSect(subSect), wrap(std::move(wrap)), parent(parent), size(size),
673-
stride(stride), helper(*this) {
674-
assert(stride == 1 && "Not implemented.");
671+
subSect(subSect), wrap(std::move(wrap)), parent(parent), helper(*this) {
675672
assert(subSect.tid == tid && subSect.lvl == lvl);
676673
assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl);
677674
};
@@ -693,7 +690,9 @@ class SubSectIterator : public SparseIterator {
693690

694691
bool randomAccessible() const override { return wrap->randomAccessible(); };
695692
bool iteratableByFor() const override { return randomAccessible(); };
696-
Value upperBound(OpBuilder &b, Location l) const override { return size; }
693+
Value upperBound(OpBuilder &b, Location l) const override {
694+
return subSect.subSectSz;
695+
}
697696
std::pair<Value, Value> getCurPosition() const override {
698697
return wrap->getCurPosition();
699698
};
@@ -711,7 +710,7 @@ class SubSectIterator : public SparseIterator {
711710
assert(p->lvl + 1 == lvl);
712711
wrap->genInit(b, l, p);
713712
// Linearize the dense subsection index.
714-
nxLvlTupleStart = MULI(size, p->getNxLvlTupleId(b, l));
713+
nxLvlTupleStart = MULI(subSect.subSectSz, p->getNxLvlTupleId(b, l));
715714
} else {
716715
assert(subSect.lvl == lvl && subSect.isSubSectRoot());
717716
wrap->deserialize(subSect.delegate->serialize());
@@ -765,9 +764,6 @@ class SubSectIterator : public SparseIterator {
765764
std::unique_ptr<SparseIterator> wrap;
766765
const SparseIterator &parent;
767766

768-
Value size;
769-
unsigned stride;
770-
771767
SubSectIterHelper helper;
772768
};
773769

@@ -1330,29 +1326,19 @@ sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
13301326
return std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
13311327
}
13321328

1333-
template <typename IterType>
13341329
static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
13351330
auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
1336-
if (filter && llvm::isa<IterType>(filter->wrap.get())) {
1331+
if (filter)
13371332
return filter->wrap.get();
1338-
}
13391333
return it;
13401334
}
1341-
template <typename IterType>
1342-
static const IterType *unwrapFilter(const SparseIterator *it) {
1343-
auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
1344-
if (filter) {
1345-
return llvm::cast<IterType>(filter->wrap.get());
1346-
}
1347-
return llvm::cast<IterType>(it);
1348-
}
13491335

13501336
std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
13511337
OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
13521338
std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride) {
13531339

13541340
// Try unwrap the NonEmptySubSectIterator from a filter parent.
1355-
parent = tryUnwrapFilter<NonEmptySubSectIterator>(parent);
1341+
parent = tryUnwrapFilter(parent);
13561342
auto it = std::make_unique<NonEmptySubSectIterator>(
13571343
b, l, parent, std::move(delegate), size);
13581344

@@ -1366,12 +1352,22 @@ std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
13661352
}
13671353

13681354
std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
1369-
const SparseIterator &subSectIter, const SparseIterator &parent,
1370-
std::unique_ptr<SparseIterator> &&wrap, Value size, unsigned stride) {
1355+
OpBuilder &b, Location l, const SparseIterator &subSectIter,
1356+
const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
1357+
Value loopBound, unsigned stride) {
1358+
13711359
// This must be a subsection iterator or a filtered subsection iterator.
1372-
auto &subSect = *unwrapFilter<NonEmptySubSectIterator>(&subSectIter);
1373-
return std::make_unique<SubSectIterator>(subSect, parent, std::move(wrap),
1374-
size, stride);
1360+
auto &subSect =
1361+
llvm::cast<NonEmptySubSectIterator>(*tryUnwrapFilter(&subSectIter));
1362+
1363+
auto it = std::make_unique<SubSectIterator>(
1364+
subSect, *tryUnwrapFilter(&parent), std::move(wrap));
1365+
1366+
if (stride != 1) {
1367+
return std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1368+
C_IDX(stride), /*size=*/loopBound);
1369+
}
1370+
return it;
13751371
}
13761372

13771373
#undef CMPI

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,9 @@ std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
300300
/// Helper function to create a SparseIterator object that iterate over a
301301
/// non-empty subsection created by NonEmptySubSectIterator.
302302
std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
303-
const SparseIterator &subsectIter, const SparseIterator &parent,
304-
std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride);
303+
OpBuilder &b, Location l, const SparseIterator &subsectIter,
304+
const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
305+
Value loopBound, unsigned stride);
305306

306307
} // namespace sparse_tensor
307308
} // namespace mlir
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
//--------------------------------------------------------------------------------------------------
2+
// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
3+
//
4+
// Set-up that's shared across all tests in this directory. In principle, this
5+
// config could be moved to lit.local.cfg. However, there are downstream users that
6+
// do not use these LIT config files. Hence why this is kept inline.
7+
//
8+
// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
9+
// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
10+
// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
11+
// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
12+
// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
13+
// DEFINE: %{run_opts} = -e entry -entry-point-result=void
14+
// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
15+
// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
16+
//
17+
// DEFINE: %{env} =
18+
//--------------------------------------------------------------------------------------------------
19+
20+
// RUN: %{compile} | %{run} | FileCheck %s
21+
//
22+
// Do the same run, but now with direct IR generation.
23+
// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true
24+
// RUN: %{compile} | %{run} | FileCheck %s
25+
//
26+
// Do the same run, but now with direct IR generation and vectorization.
27+
// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
28+
// RUN: %{compile} | %{run} | FileCheck %s
29+
//
30+
// Do the same run, but now with direct IR generation and VLA vectorization.
31+
// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
32+
33+
#CCCC = #sparse_tensor.encoding<{
34+
map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : compressed)
35+
}>
36+
37+
#CDCC = #sparse_tensor.encoding<{
38+
map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense, d2 : compressed, d3 : compressed)
39+
}>
40+
41+
// Creates and returns 4-D buffer of size (%s1, %s2, %s3, %s4) filled with the value %f
42+
func.func @alloc_4d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %f : f32) -> tensor<?x?x?x?xf32> {
43+
%buf = tensor.empty(%s1, %s2, %s3, %s4) : tensor<?x?x?x?xf32>
44+
%ret = linalg.fill ins(%f : f32) outs(%buf : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
45+
return %ret : tensor<?x?x?x?xf32>
46+
}
47+
48+
func.func @conv_2d_nhwc_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
49+
%ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<2> : tensor<2xi64>,
50+
strides = dense<1> : tensor<2xi64>}
51+
ins (%arg0, %arg1: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
52+
outs (%arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
53+
return %ret : tensor<?x?x?x?xf32>
54+
}
55+
56+
func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<?x?x?x?xf32, #CCCC>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
57+
%ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<2> : tensor<2xi64>,
58+
strides = dense<1> : tensor<2xi64>}
59+
ins (%arg0, %arg1: tensor<?x?x?x?xf32, #CCCC>, tensor<?x?x?x?xf32>)
60+
outs (%arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
61+
return %ret : tensor<?x?x?x?xf32>
62+
}
63+
64+
func.func @conv_2d_nhwc_hwcf_CDCC(%arg0: tensor<?x?x?x?xf32, #CDCC>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
65+
%ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<2> : tensor<2xi64>,
66+
strides = dense<1> : tensor<2xi64>}
67+
ins (%arg0, %arg1: tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32>)
68+
outs (%arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
69+
return %ret : tensor<?x?x?x?xf32>
70+
}
71+
72+
func.func @conv_2d_nhwc_hwcf_dual_CDCC(%arg0: tensor<?x?x?x?xf32, #CDCC>, %arg1: tensor<?x?x?x?xf32, #CDCC>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
73+
%ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<2> : tensor<2xi64>,
74+
strides = dense<1> : tensor<2xi64>}
75+
ins (%arg0, %arg1: tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32, #CDCC>)
76+
outs (%arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
77+
return %ret : tensor<?x?x?x?xf32>
78+
}
79+
80+
81+
func.func @entry() {
82+
%c0 = arith.constant 0 : index
83+
%c1 = arith.constant 1 : index
84+
%c3 = arith.constant 3 : index
85+
%c5 = arith.constant 5 : index
86+
%c6 = arith.constant 6 : index
87+
%c7 = arith.constant 7 : index
88+
%f10 = arith.constant 10.00000e+00 : f32
89+
%val = arith.constant 2.00000e+00 : f32
90+
%zero = arith.constant 0.00000e+00 : f32
91+
92+
%filter2D_nhwc = call @alloc_4d_filled_f32(%c3, %c3, %c3, %c1, %val) :(index, index, index, index, f32) -> (tensor<?x?x?x?xf32>)
93+
%in2D_tmp = call @alloc_4d_filled_f32(%c3, %c7, %c7, %c3, %f10) : (index, index, index, index, f32) -> (tensor<?x?x?x?xf32>)
94+
%in2D_nhwc = tensor.insert %zero into %in2D_tmp[%c0, %c1, %c1, %c0] : tensor<?x?x?x?xf32>
95+
%out2D_nhwc = call @alloc_4d_filled_f32(%c3, %c3, %c3, %c1, %zero) : (index, index, index, index, f32) -> (tensor<?x?x?x?xf32>)
96+
97+
%in2D_nhwc_CCCC = sparse_tensor.convert %in2D_nhwc
98+
: tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CCCC>
99+
%filter2D_nhwc_CDCC = sparse_tensor.convert %filter2D_nhwc
100+
: tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CDCC>
101+
%in2D_nhwc_CDCC = sparse_tensor.convert %in2D_nhwc
102+
: tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CDCC>
103+
104+
%dense_ret = call @conv_2d_nhwc_hwcf(%in2D_nhwc, %filter2D_nhwc, %out2D_nhwc) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
105+
%CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC(%in2D_nhwc_CCCC, %filter2D_nhwc, %out2D_nhwc) : (tensor<?x?x?x?xf32, #CCCC>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
106+
%CDCC_ret = call @conv_2d_nhwc_hwcf_CDCC(%in2D_nhwc_CDCC, %filter2D_nhwc, %out2D_nhwc) : (tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
107+
108+
%dual_CDCC_ret = call @conv_2d_nhwc_hwcf_dual_CDCC(%in2D_nhwc_CDCC, %filter2D_nhwc_CDCC, %out2D_nhwc)
109+
: (tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
110+
111+
// CHECK: ( ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 520 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ),
112+
// CHECK-SAME: ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ),
113+
// CHECK-SAME: ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ) )
114+
%dense_v = vector.transfer_read %dense_ret[%c0, %c0, %c0, %c0], %zero
115+
: tensor<?x?x?x?xf32>, vector<3x3x3x1xf32>
116+
vector.print %dense_v : vector<3x3x3x1xf32>
117+
118+
// CHECK-NEXT: ( ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 520 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ),
119+
// CHECK-SAME: ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ),
120+
// CHECK-SAME: ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ) )
121+
%v_dual = vector.transfer_read %dual_CDCC_ret[%c0, %c0, %c0, %c0], %zero
122+
: tensor<?x?x?x?xf32>, vector<3x3x3x1xf32>
123+
vector.print %v_dual : vector<3x3x3x1xf32>
124+
125+
// CHECK-NEXT: ( ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 520 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ),
126+
// CHECK-SAME: ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ),
127+
// CHECK-SAME: ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ) )
128+
%v1 = vector.transfer_read %CCCC_ret[%c0, %c0, %c0, %c0], %zero
129+
: tensor<?x?x?x?xf32>, vector<3x3x3x1xf32>
130+
vector.print %v1 : vector<3x3x3x1xf32>
131+
132+
// CHECK-NEXT: ( ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 520 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ),
133+
// CHECK-SAME: ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ),
134+
// CHECK-SAME: ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ) )
135+
%v2 = vector.transfer_read %CDCC_ret[%c0, %c0, %c0, %c0], %zero
136+
: tensor<?x?x?x?xf32>, vector<3x3x3x1xf32>
137+
vector.print %v1 : vector<3x3x3x1xf32>
138+
139+
// Free the resources.
140+
bufferization.dealloc_tensor %in2D_nhwc : tensor<?x?x?x?xf32>
141+
bufferization.dealloc_tensor %filter2D_nhwc : tensor<?x?x?x?xf32>
142+
bufferization.dealloc_tensor %out2D_nhwc : tensor<?x?x?x?xf32>
143+
144+
bufferization.dealloc_tensor %filter2D_nhwc_CDCC : tensor<?x?x?x?xf32, #CDCC>
145+
bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor<?x?x?x?xf32, #CCCC>
146+
bufferization.dealloc_tensor %in2D_nhwc_CDCC : tensor<?x?x?x?xf32, #CDCC>
147+
return
148+
}

0 commit comments

Comments
 (0)