
Commit 7888539

Author: Peiming Liu

[mlir][sparse] support tensor.pad on CSR tensors (#90687)

1 parent 0af415d · commit 7888539

File tree

4 files changed (+198 additions, -58 deletions)


mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp

Lines changed: 105 additions & 36 deletions
@@ -95,8 +95,9 @@ class DenseLevel : public SparseTensorLevel {
   ValueRange getLvlBuffers() const override { return {}; }

   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
+    assert(!inPadZone && "Not implemented");
     Value p = parentPos.front();
     Value posLo = MULI(p, lvlSize);
     return {posLo, lvlSize};
@@ -115,7 +116,8 @@ class BatchLevel : public SparseTensorLevel {
   ValueRange getLvlBuffers() const override { return {}; }

   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
+    assert(!inPadZone && "Not implemented");
     assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
     // No need to linearize the position for non-annotated tensors.
     return {C_IDX(0), lvlSize};
@@ -129,18 +131,42 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
       : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}

   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {

     assert(parentPos.size() == 1 &&
            "compressed level must be the first non-unique level.");
-    Value p = parentPos.front();

-    SmallVector<Value> memCrd(batchPrefix);
-    memCrd.push_back(p);
-    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
-    memCrd.back() = ADDI(p, C_IDX(1));
-    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
-    return {pLo, pHi};
+    auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
+      Value p = parentPos.front();
+      SmallVector<Value> memCrd(batchPrefix);
+      memCrd.push_back(p);
+      Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
+      memCrd.back() = ADDI(p, C_IDX(1));
+      Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
+      return {pLo, pHi};
+    };
+
+    if (inPadZone == nullptr)
+      return loadRange();
+
+    SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
+    scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
+    // True branch, returns a "fake" empty range [0, 0) if parent
+    // iterator is in pad zone.
+    b.setInsertionPointToStart(posRangeIf.thenBlock());
+
+    SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
+    b.create<scf::YieldOp>(l, emptyRange);
+
+    // False branch, returns the actual range.
+    b.setInsertionPointToStart(posRangeIf.elseBlock());
+    auto [pLo, pHi] = loadRange();
+    SmallVector<Value, 2> loadedRange{pLo, pHi};
+    b.create<scf::YieldOp>(l, loadedRange);
+
+    b.setInsertionPointAfter(posRangeIf);
+    ValueRange posRange = posRangeIf.getResults();
+    return {posRange.front(), posRange.back()};
   }
 };
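For orientation, the builder calls above emit an scf.if that either yields a "fake" empty position range or loads the real one from the positions buffer. Below is a minimal standalone MLIR sketch of that shape; the function and SSA names are assumptions for illustration only, not IR produced verbatim by this commit.

// Hypothetical helper illustrating the emitted guard: %in_pad_zone is the
// boolean carried by the parent pad iterator, %positions the CSR positions
// buffer, and %p the parent (row) position.
func.func @peek_range_sketch(%in_pad_zone: i1, %positions: memref<?xindex>,
                             %p: index) -> (index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %lo, %hi = scf.if %in_pad_zone -> (index, index) {
    // Row lies entirely in the pad zone: yield the empty range [0, 0).
    scf.yield %c0, %c0 : index, index
  } else {
    // Regular row: load pos[p] and pos[p + 1].
    %p1 = arith.addi %p, %c1 : index
    %pos_lo = memref.load %positions[%p] : memref<?xindex>
    %pos_hi = memref.load %positions[%p1] : memref<?xindex>
    scf.yield %pos_lo, %pos_hi : index, index
  }
  return %lo, %hi : index, index
}

The same pattern appears in the FileCheck output of the new test below (the scf.if yielding a pair of index values).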

@@ -151,9 +177,10 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
       : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}

   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 &&
            "loose-compressed level must be the first non-unique level.");
+    assert(!inPadZone && "Not implemented");
     SmallVector<Value> memCrd(batchPrefix);
     Value p = parentPos.front();
     p = MULI(p, C_IDX(2));
@@ -172,8 +199,9 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
       : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}

   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 || parentPos.size() == 2);
+    assert(!inPadZone && "Not implemented");
     Value p = parentPos.front();
     Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;

@@ -191,9 +219,10 @@ class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
       : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}

   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 && isUnique() &&
            "n:m level can not be non-unique.");
+    assert(!inPadZone && "Not implemented");
     // Each n:m blk has exactly n specified elements.
     auto n = getN(lt);
     Value posLo = MULI(parentPos.front(), C_IDX(n));
@@ -325,23 +354,7 @@ class TrivialIterator : public ConcreteIterator {
   };

   void genInitImpl(OpBuilder &b, Location l,
-                   const SparseIterator *parent) override {
-
-    if (isBatchIterator() && batchCrds.size() <= stl.lvl)
-      batchCrds.resize(stl.lvl + 1, nullptr);
-
-    Value c0 = C_IDX(0);
-    ValueRange pPos = c0;
-    // If the parent iterator is a batch iterator, we also start from 0 (but
-    // on a different batch).
-    if (parent && !parent->isBatchIterator())
-      pPos = parent->getCurPosition();
-
-    ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
-    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
-    // Seek to the lowest position.
-    seek(posLo);
-  }
+                   const SparseIterator *parent) override;

   ValuePair genForCond(OpBuilder &b, Location l) override {
     if (randomAccessible())
@@ -465,15 +478,17 @@ class DedupIterator : public ConcreteIterator {
 // A util base-iterator that delegates all methods to the wrapped iterator.
 class SimpleWrapIterator : public SparseIterator {
 public:
-  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind)
-      : SparseIterator(kind, *wrap), wrap(std::move(wrap)) {}
+  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
+                     unsigned extraCursorVal = 0)
+      : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}

   SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
     return wrap->getCursorValTypes(b);
   }
   bool isBatchIterator() const override { return wrap->isBatchIterator(); }
   bool randomAccessible() const override { return wrap->randomAccessible(); };
   bool iteratableByFor() const override { return wrap->iteratableByFor(); };
+
   SmallVector<Value> serialize() const override { return wrap->serialize(); };
   void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
   ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
@@ -586,10 +601,9 @@ class PadIterator : public SimpleWrapIterator {
 public:
   PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
               Value padHigh)
-      : SimpleWrapIterator(std::move(wrap), IterKind::kPad), padLow(padLow),
-        padHigh(padHigh) {
-    assert(!randomAccessible() && "Not implemented.");
-  }
+      : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
+                           wrap->randomAccessible() ? 1 : 0),
+        padLow(padLow), padHigh(padHigh) {}

   // For LLVM-style RTTI.
   static bool classof(const SparseIterator *from) {
@@ -600,6 +614,26 @@ class PadIterator : public SimpleWrapIterator {
     return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
   }

+  // Returns a pair of values for *upper*, *lower* bound respectively.
+  ValuePair genForCond(OpBuilder &b, Location l) override {
+    if (randomAccessible())
+      return {getCrd(), upperBound(b, l)};
+    return wrap->genForCond(b, l);
+  }
+
+  // For padded dense iterator, we append a `inPadZone: bool` in addition to
+  // values used by the wrapped iterator.
+  ValueRange getCurPosition() const override { return getCursor(); }
+
+  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+    SmallVector<Type> ret = wrap->getCursorValTypes(b);
+    // Need an extra boolean value `inPadZone` for padded dense iterator.
+    if (randomAccessible())
+      ret.push_back(b.getI1Type());
+
+    return ret;
+  }
+
   // The upper bound after padding becomes `size + padLow + padHigh`.
   Value upperBound(OpBuilder &b, Location l) const override {
     return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
@@ -613,6 +647,14 @@ class PadIterator : public SimpleWrapIterator {

   void locateImpl(OpBuilder &b, Location l, Value crd) override {
     assert(randomAccessible());
+    wrap->locate(b, l, SUBI(crd, padLow));
+
+    // inPadZone = crd < padLow || crd >= size + padLow.
+    Value inPadLow = CMPI(ult, crd, padLow);
+    Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
+    getMutCursorVals().back() = ORI(inPadLow, inPadHigh);
+
+    updateCrd(crd);
   }

   Value padLow, padHigh;
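The boolean written into the last cursor slot here is what later reaches CompressedLevel::peekRangeAt as inPadZone. A minimal MLIR sketch of that predicate, under assumed names (the commit builds these values through the CMPI/ADDI/ORI helper macros above):

// Hypothetical helper: a coordinate is in the pad zone iff it falls below
// padLow or at/after size + padLow.
func.func @in_pad_zone_sketch(%crd: index, %pad_low: index, %size: index) -> i1 {
  %upper = arith.addi %size, %pad_low : index
  %below = arith.cmpi ult, %crd, %pad_low : index
  %above = arith.cmpi uge, %crd, %upper : index
  %in_pad = arith.ori %below, %above : i1
  return %in_pad : i1
}

In the CSR test added by this commit, padLow is 2 and the wrapped level size is 4, so the predicate reduces to crd < 2 || crd >= 6.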
@@ -1227,6 +1269,33 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree(
   return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
 }

+void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
+                                  const SparseIterator *parent) {
+
+  if (isBatchIterator() && batchCrds.size() <= stl.lvl)
+    batchCrds.resize(stl.lvl + 1, nullptr);
+
+  Value c0 = C_IDX(0);
+  ValueRange pPos = c0;
+  Value inPadZone = nullptr;
+  // If the parent iterator is a batch iterator, we also start from 0 (but
+  // on a different batch).
+  if (parent && !parent->isBatchIterator()) {
+    pPos = parent->getCurPosition();
+    if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
+      // A padded dense iterator create "sparse" padded zone, which need to be
+      // handled specially.
+      inPadZone = pPos.back();
+      pPos = pPos.drop_back();
+    }
+  }
+
+  ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
+  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
+  // Seek to the lowest position.
+  seek(posLo);
+}
+
 void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
                                           const SparseIterator *) {
   Value c0 = C_IDX(0);

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h

Lines changed: 3 additions & 3 deletions
@@ -46,9 +46,9 @@ class SparseTensorLevel {
   ///
   /// For a sparse level, [posLo, loopHi) specifies the range of index pointer
   /// to load coordinate from the coordinate buffer.
-  virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
-                                              ValueRange batchPrefix,
-                                              ValueRange parentPos) const = 0;
+  virtual std::pair<Value, Value>
+  peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+              ValueRange parentPos, Value inPadZone = nullptr) const = 0;

   Level getLevel() const { return lvl; }
   LevelType getLT() const { return lt; }
New test file — Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -canonicalize | FileCheck %s

#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed)
}>

#elemwise = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>, // A
    affine_map<(i,j) -> (i,j)>, // B
    affine_map<(i,j) -> (i,j)>  // X (out)
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) = A(i,j) OP B(i,j)"
}


// CHECK-LABEL: func.func @padded_mul(
// CHECK-SAME:    %[[VAL_0:.*]]: tensor<4x4xf32, #sparse>,
// CHECK-SAME:    %[[VAL_1:.*]]: tensor<8x8xf32>) -> tensor<8x8xf32> {
// CHECK-DAG:     %[[VAL_2:.*]] = arith.constant -1 : index
// CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 6 : index
// CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 8 : index
// CHECK-DAG:     %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[VAL_7:.*]] = arith.constant 2 : index
// CHECK-DAG:     %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[VAL_9:.*]] = tensor.empty() : tensor<8x8xf32>
// CHECK:         %[[VAL_10:.*]] = linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_9]] : tensor<8x8xf32>) -> tensor<8x8xf32>
// CHECK:         %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
// CHECK:         %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
// CHECK:         %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf32, #sparse> to memref<?xf32>
// CHECK:         %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_10]] : memref<8x8xf32>
// CHECK:         linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_14]] : memref<8x8xf32>)
// CHECK:         scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_5]] {
// CHECK:           %[[VAL_16:.*]] = arith.subi %[[VAL_15]], %[[VAL_7]] : index
// CHECK:           %[[VAL_17:.*]] = arith.cmpi ult, %[[VAL_15]], %[[VAL_7]] : index
// CHECK:           %[[VAL_18:.*]] = arith.cmpi uge, %[[VAL_15]], %[[VAL_3]] : index
// CHECK:           %[[VAL_19:.*]] = arith.ori %[[VAL_17]], %[[VAL_18]] : i1
// CHECK:           %[[VAL_20:.*]]:2 = scf.if %[[VAL_19]] -> (index, index) {
// CHECK:             scf.yield %[[VAL_6]], %[[VAL_6]] : index, index
// CHECK:           } else {
// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_16]]] : memref<?xindex>
// CHECK:             %[[VAL_22:.*]] = arith.addi %[[VAL_15]], %[[VAL_2]] : index
// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref<?xindex>
// CHECK:             scf.yield %[[VAL_21]], %[[VAL_23]] : index, index
// CHECK:           }
// CHECK:           scf.for %[[VAL_24:.*]] = %[[VAL_20]]#0 to %[[VAL_20]]#1 step %[[VAL_5]] {
// CHECK:             %[[VAL_26:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
// CHECK:             %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_7]] : index
// CHECK:             %[[VAL_28:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32>
// CHECK:             %[[VAL_29:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : tensor<8x8xf32>
// CHECK:             %[[VAL_30:.*]] = arith.mulf %[[VAL_28]], %[[VAL_29]] : f32
// CHECK:             memref.store %[[VAL_30]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : memref<8x8xf32>
// CHECK:           } {"Emitted from" = "linalg.generic"}
// CHECK:         } {"Emitted from" = "linalg.generic"}
// CHECK:         %[[VAL_31:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<8x8xf32>
// CHECK:         return %[[VAL_31]] : tensor<8x8xf32>
// CHECK:       }
func.func @padded_mul(%arg0: tensor<4x4xf32, #CSR>, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> {
  %cst_0 = arith.constant 0.00000e+00 : f32
  %buf = tensor.empty() : tensor<8x8xf32>
  %s = linalg.fill ins(%cst_0 : f32) outs(%buf : tensor<8x8xf32>) -> tensor<8x8xf32>

  %padded = tensor.pad %arg0 low[2, 2] high[2, 2] {
  ^bb0(%arg75: index, %arg76: index):
    tensor.yield %cst_0 : f32
  } : tensor<4x4xf32, #CSR> to tensor<8x8xf32, #CSR>

  %0 = linalg.generic #elemwise
    ins(%padded, %arg1: tensor<8x8xf32, #CSR>, tensor<8x8xf32>)
    outs(%s: tensor<8x8xf32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = arith.mulf %a, %b : f32
        linalg.yield %0 : f32
  } -> tensor<8x8xf32>

  return %0 : tensor<8x8xf32>
}
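For readers tracing the CHECK lines: with low[2, 2] padding of the 4x4 CSR operand, output row i maps to row i - 2 of the sparse input; the pad-zone predicate is i < 2 || i >= 6 (where 6 = 4 + 2); and the guarded loads read pos[i - 2] and pos[i - 1], the latter because canonicalization folds (i - 2) + 1 into an add of the -1 constant. Stored column coordinates are shifted by 2 the same way before indexing the dense 8x8 operand.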
