Skip to content

Commit 12e3812

Browse files
author
Peiming Liu
committed
[mlir][sparse] support tensor.pad on CSR tensors
1 parent 987c036 commit 12e3812

File tree

3 files changed

+111
-57
lines changed

3 files changed

+111
-57
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp

Lines changed: 97 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,9 @@ class DenseLevel : public SparseTensorLevel {
9595
ValueRange getLvlBuffers() const override { return {}; }
9696

9797
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
98-
ValueRange parentPos) const override {
98+
ValueRange parentPos, Value inPadZone) const override {
9999
assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
100+
assert(!inPadZone && "Not implemented");
100101
Value p = parentPos.front();
101102
Value posLo = MULI(p, lvlSize);
102103
return {posLo, lvlSize};
@@ -115,7 +116,8 @@ class BatchLevel : public SparseTensorLevel {
115116
ValueRange getLvlBuffers() const override { return {}; }
116117

117118
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
118-
ValueRange parentPos) const override {
119+
ValueRange parentPos, Value inPadZone) const override {
120+
assert(!inPadZone && "Not implemented");
119121
assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
120122
// No need to linearize the position for non-annotated tensors.
121123
return {C_IDX(0), lvlSize};
@@ -129,18 +131,41 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
129131
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
130132

131133
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
132-
ValueRange parentPos) const override {
134+
ValueRange parentPos, Value inPadZone) const override {
133135

134136
assert(parentPos.size() == 1 &&
135137
"compressed level must be the first non-unique level.");
136-
Value p = parentPos.front();
137138

138-
SmallVector<Value> memCrd(batchPrefix);
139-
memCrd.push_back(p);
140-
Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
141-
memCrd.back() = ADDI(p, C_IDX(1));
142-
Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
143-
return {pLo, pHi};
139+
auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
140+
Value p = parentPos.front();
141+
SmallVector<Value> memCrd(batchPrefix);
142+
memCrd.push_back(p);
143+
Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
144+
memCrd.back() = ADDI(p, C_IDX(1));
145+
Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
146+
return {pLo, pHi};
147+
};
148+
149+
if (inPadZone == nullptr)
150+
return loadRange();
151+
152+
SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
153+
scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
154+
// True branch.
155+
b.setInsertionPointToStart(posRangeIf.thenBlock());
156+
// Returns a "fake" empty range [0, 0) if parent iterator is in pad zone.
157+
SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
158+
b.create<scf::YieldOp>(l, emptyRange);
159+
160+
// False branch.
161+
b.setInsertionPointToStart(posRangeIf.elseBlock());
162+
auto [pLo, pHi] = loadRange();
163+
SmallVector<Value, 2> loadedRange{pLo, pHi};
164+
b.create<scf::YieldOp>(l, loadedRange);
165+
166+
b.setInsertionPointAfter(posRangeIf);
167+
ValueRange posRange = posRangeIf.getResults();
168+
return {posRange.front(), posRange.back()};
144169
}
145170
};
146171

@@ -151,9 +176,10 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
151176
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
152177

153178
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
154-
ValueRange parentPos) const override {
179+
ValueRange parentPos, Value inPadZone) const override {
155180
assert(parentPos.size() == 1 &&
156181
"loose-compressed level must be the first non-unique level.");
182+
assert(!inPadZone && "Not implemented");
157183
SmallVector<Value> memCrd(batchPrefix);
158184
Value p = parentPos.front();
159185
p = MULI(p, C_IDX(2));
@@ -172,8 +198,9 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
172198
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
173199

174200
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
175-
ValueRange parentPos) const override {
201+
ValueRange parentPos, Value inPadZone) const override {
176202
assert(parentPos.size() == 1 || parentPos.size() == 2);
203+
assert(!inPadZone && "Not implemented");
177204
Value p = parentPos.front();
178205
Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
179206

@@ -191,9 +218,10 @@ class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
191218
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
192219

193220
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
194-
ValueRange parentPos) const override {
221+
ValueRange parentPos, Value inPadZone) const override {
195222
assert(parentPos.size() == 1 && isUnique() &&
196223
"n:m level can not be non-unique.");
224+
assert(!inPadZone && "Not implemented");
197225
// Each n:m blk has exactly n specified elements.
198226
auto n = getN(lt);
199227
Value posLo = MULI(parentPos.front(), C_IDX(n));
@@ -325,23 +353,7 @@ class TrivialIterator : public ConcreteIterator {
325353
};
326354

327355
void genInitImpl(OpBuilder &b, Location l,
328-
const SparseIterator *parent) override {
329-
330-
if (isBatchIterator() && batchCrds.size() <= stl.lvl)
331-
batchCrds.resize(stl.lvl + 1, nullptr);
332-
333-
Value c0 = C_IDX(0);
334-
ValueRange pPos = c0;
335-
// If the parent iterator is a batch iterator, we also start from 0 (but
336-
// on a different batch).
337-
if (parent && !parent->isBatchIterator())
338-
pPos = parent->getCurPosition();
339-
340-
ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
341-
std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
342-
// Seek to the lowest position.
343-
seek(posLo);
344-
}
356+
const SparseIterator *parent) override;
345357

346358
ValuePair genForCond(OpBuilder &b, Location l) override {
347359
if (randomAccessible())
@@ -465,8 +477,9 @@ class DedupIterator : public ConcreteIterator {
465477
// A util base-iterator that delegates all methods to the wrapped iterator.
466478
class SimpleWrapIterator : public SparseIterator {
467479
public:
468-
SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind)
469-
: SparseIterator(kind, *wrap), wrap(std::move(wrap)) {}
480+
SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
481+
unsigned extraCursorVal = 0)
482+
: SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
470483

471484
SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
472485
return wrap->getCursorValTypes(b);
@@ -586,9 +599,10 @@ class PadIterator : public SimpleWrapIterator {
586599
public:
587600
PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
588601
Value padHigh)
589-
: SimpleWrapIterator(std::move(wrap), IterKind::kPad), padLow(padLow),
590-
padHigh(padHigh) {
591-
assert(!randomAccessible() && "Not implemented.");
602+
: SimpleWrapIterator(std::move(wrap), IterKind::kPad,
603+
wrap->randomAccessible() ? 1 : 0),
604+
padLow(padLow), padHigh(padHigh) {
605+
// assert(!randomAccessible());
592606
}
593607

594608
// For LLVM-style RTTI.
@@ -600,6 +614,19 @@ class PadIterator : public SimpleWrapIterator {
600614
return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
601615
}
602616

617+
// For padded dense iterator, we append a `inPadZone: bool` in addition to
618+
// values used by the wrapped iterator.
619+
ValueRange getCurPosition() const override { return getCursor(); }
620+
621+
SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
622+
SmallVector<Type> ret = wrap->getCursorValTypes(b);
623+
// Need an extra boolean value `inPadZone` for padded dense iterator.
624+
if (randomAccessible())
625+
ret.push_back(b.getI1Type());
626+
627+
return ret;
628+
}
629+
603630
// The upper bound after padding becomes `size + padLow + padHigh`.
604631
Value upperBound(OpBuilder &b, Location l) const override {
605632
return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
@@ -613,6 +640,14 @@ class PadIterator : public SimpleWrapIterator {
613640

614641
void locateImpl(OpBuilder &b, Location l, Value crd) override {
615642
assert(randomAccessible());
643+
wrap->locate(b, l, SUBI(crd, padLow));
644+
645+
// inPadZone = crd < padLow || crd >= size + padLow.
646+
Value inPadLow = CMPI(ult, crd, padLow);
647+
Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
648+
getMutCursorVals().back() = ORI(inPadLow, inPadHigh);
649+
650+
updateCrd(crd);
616651
}
617652

618653
Value padLow, padHigh;
@@ -1227,6 +1262,33 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree(
12271262
return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
12281263
}
12291264

1265+
void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
1266+
const SparseIterator *parent) {
1267+
1268+
if (isBatchIterator() && batchCrds.size() <= stl.lvl)
1269+
batchCrds.resize(stl.lvl + 1, nullptr);
1270+
1271+
Value c0 = C_IDX(0);
1272+
ValueRange pPos = c0;
1273+
Value inPadZone = nullptr;
1274+
// If the parent iterator is a batch iterator, we also start from 0 (but
1275+
// on a different batch).
1276+
if (parent && !parent->isBatchIterator()) {
1277+
pPos = parent->getCurPosition();
1278+
if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
1279+
// A padded dense iterator creates a "sparse" padded zone, which needs to be
1280+
// handled specially.
1281+
inPadZone = pPos.back();
1282+
pPos = pPos.drop_back();
1283+
}
1284+
}
1285+
1286+
ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
1287+
std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
1288+
// Seek to the lowest position.
1289+
seek(posLo);
1290+
}
1291+
12301292
void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
12311293
const SparseIterator *) {
12321294
Value c0 = C_IDX(0);

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ class SparseTensorLevel {
4646
///
4747
/// For a sparse level, [posLo, loopHi) specifies the range of index pointer
4848
/// to load coordinate from the coordinate buffer.
49-
virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
50-
ValueRange batchPrefix,
51-
ValueRange parentPos) const = 0;
49+
virtual std::pair<Value, Value>
50+
peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
51+
ValueRange parentPos, Value inPadZone = nullptr) const = 0;
5252

5353
Level getLevel() const { return lvl; }
5454
LevelType getLT() const { return lt; }

mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,8 @@
3030
// Do the same run, but now with direct IR generation and VLA vectorization.
3131
// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
3232

33-
#CCCC = #sparse_tensor.encoding<{
34-
map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : compressed)
35-
}>
36-
37-
#CDCD = #sparse_tensor.encoding<{
38-
map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense, d2 : compressed, d3 : dense)
39-
}>
40-
41-
#DCCD = #sparse_tensor.encoding<{
42-
map = (d0, d1, d2, d3) -> (d0 : dense, d1 : compressed, d2 : compressed, d3 : dense)
33+
#CDCC_NHWC = #sparse_tensor.encoding<{
34+
map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense, d2 : compressed, d3 : compressed)
4335
}>
4436

4537
// Creates and returns 4-D buffer of size (%s1, %s2, %s3, %s4) filled with the value %f
@@ -66,19 +58,19 @@ func.func @conv_2d_nhwc_hwcf(%arg0: tensor<3x8x8x3xf32>, %arg1: tensor<5x5x3x1xf
6658
return %ret : tensor<3x8x8x1xf32>
6759
}
6860

69-
func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<3x8x8x3xf32, #CCCC>, %arg1: tensor<5x5x3x1xf32>) -> tensor<3x8x8x1xf32> {
61+
func.func @conv_2d_nhwc_hwcf_CDCC_NHWC(%arg0: tensor<3x8x8x3xf32, #CDCC_NHWC>, %arg1: tensor<5x5x3x1xf32>) -> tensor<3x8x8x1xf32> {
7062
%cst_0 = arith.constant 0.00000e+00 : f32
7163
%buf = tensor.empty() : tensor<3x8x8x1xf32>
7264
%s = linalg.fill ins(%cst_0 : f32) outs(%buf : tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
7365

7466
%padded = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
7567
^bb0(%arg75: index, %arg76: index, %arg77: index, %arg78: index):
7668
tensor.yield %cst_0 : f32
77-
} : tensor<3x8x8x3xf32, #CCCC> to tensor<3x12x12x3xf32, #CCCC>
69+
} : tensor<3x8x8x3xf32, #CDCC_NHWC> to tensor<3x12x12x3xf32, #CDCC_NHWC>
7870

7971
%ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
8072
strides = dense<1> : tensor<2xi64>}
81-
ins (%padded, %arg1: tensor<3x12x12x3xf32, #CCCC>, tensor<5x5x3x1xf32>)
73+
ins (%padded, %arg1: tensor<3x12x12x3xf32, #CDCC_NHWC>, tensor<5x5x3x1xf32>)
8274
outs (%s: tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
8375
return %ret : tensor<3x8x8x1xf32>
8476
}
@@ -105,8 +97,8 @@ func.func @main() {
10597

10698
%dense_ret = call @conv_2d_nhwc_hwcf(%static_input, %static_filter, %static_output) : (tensor<3x8x8x3xf32>, tensor<5x5x3x1xf32>, tensor<3x8x8x1xf32>) -> (tensor<3x8x8x1xf32>)
10799

108-
%in2D_nhwc_CCCC = sparse_tensor.convert %static_input : tensor<3x8x8x3xf32> to tensor<3x8x8x3xf32, #CCCC>
109-
%CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC(%in2D_nhwc_CCCC, %static_filter) : (tensor<3x8x8x3xf32, #CCCC>, tensor<5x5x3x1xf32>) -> (tensor<3x8x8x1xf32>)
100+
%in2D_nhwc_CDCC_NHWC = sparse_tensor.convert %static_input : tensor<3x8x8x3xf32> to tensor<3x8x8x3xf32, #CDCC_NHWC>
101+
%CDCC_NHWC_ret = call @conv_2d_nhwc_hwcf_CDCC_NHWC(%in2D_nhwc_CDCC_NHWC, %static_filter) : (tensor<3x8x8x3xf32, #CDCC_NHWC>, tensor<5x5x3x1xf32>) -> (tensor<3x8x8x1xf32>)
110102

111103

112104
// CHECK: ( ( ( ( 108 ), ( 160 ), ( 196 ), ( 196 ), ( 196 ), ( 196 ), ( 144 ), ( 108 ) ),
@@ -161,17 +153,17 @@ func.func @main() {
161153
// CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
162154
// CHECK-SAME: ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
163155
// CHECK-SAME: ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ) )
164-
%CCCC_v = vector.transfer_read %CCCC_ret[%c0, %c0, %c0, %c0], %zero
156+
%CDCC_NHWC_v = vector.transfer_read %CDCC_NHWC_ret[%c0, %c0, %c0, %c0], %zero
165157
: tensor<3x8x8x1xf32>, vector<3x8x8x1xf32>
166-
vector.print %CCCC_v : vector<3x8x8x1xf32>
158+
vector.print %CDCC_NHWC_v : vector<3x8x8x1xf32>
167159

168160
bufferization.dealloc_tensor %static_filter : tensor<5x5x3x1xf32>
169161
bufferization.dealloc_tensor %static_input : tensor<3x8x8x3xf32>
170162
bufferization.dealloc_tensor %static_output : tensor<3x8x8x1xf32>
171163

172-
bufferization.dealloc_tensor %CCCC_ret : tensor<3x8x8x1xf32>
164+
bufferization.dealloc_tensor %CDCC_NHWC_ret : tensor<3x8x8x1xf32>
173165

174-
bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor<3x8x8x3xf32, #CCCC>
166+
bufferization.dealloc_tensor %in2D_nhwc_CDCC_NHWC : tensor<3x8x8x3xf32, #CDCC_NHWC>
175167

176168
return
177169
}

0 commit comments

Comments
 (0)