Skip to content

Commit dbe3766

Browse files
author
Peiming Liu
authored
[mlir][sparse] handle padding on sparse levels. (#90527)
1 parent 4cd11c9 commit dbe3766

File tree

4 files changed

+339
-29
lines changed

4 files changed

+339
-29
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,40 @@ static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
7575
return createOrFoldSliceStrideOp(builder, loc, tensor, toDim(enc, lvl));
7676
}
7777

78+
/// Returns true iff `attr` is a FloatAttr or an IntegerAttr whose value is
/// zero; any other attribute kind yields false.
static bool isIntOrFPZero(Attribute attr) {
  if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attr)) {
    if (floatAttr.getValue().isZero())
      return true;
  }
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
    if (intAttr.getValue().isZero())
      return true;
  }
  return false;
}
85+
86+
/// Turns an OpFoldResult back into an SSA value: when `ofr` has a known
/// constant integer value, an index constant is created via `constantIndex`;
/// otherwise the Value held by `ofr` is returned as-is.
static Value unFoldOpIntResult(OpBuilder &builder, Location loc,
                               OpFoldResult ofr) {
  std::optional<int64_t> cst = getConstantIntValue(ofr);
  if (cst)
    return constantIndex(builder, loc, *cst);
  return ofr.get<Value>();
}
92+
93+
/// Looks through a `tensor.pad` that produces `t` and returns the pad's source
/// when all of the following hold; otherwise returns `t` unchanged:
///   - `t` has a sparse tensor type with an encoding,
///   - the pad source carries the same encoding, and that encoding is identity,
///   - the pad body yields a constant that is an integer/float zero.
static Value tryFoldTensors(Value t) {
  // TODO: this should be done through a folding pass after switching to
  // `sparse_tensor.iterate`-based sparsification.
  auto stt = tryGetSparseTensorType(t);
  auto padOp = t.getDefiningOp<tensor::PadOp>();

  // Nothing to fold unless `t` is the result of a pad on an encoded tensor.
  if (!padOp || !stt.has_value() || !stt->hasEncoding())
    return t;
  // The pad must not change the encoding.
  if (!(padOp.getSourceType().getEncoding() == stt->getEncoding()))
    return t;
  if (!stt->getEncoding().isIdentity())
    return t;

  // Try fusing padOp with zeros: only a constant-zero padding value qualifies.
  Attribute padCst;
  bool yieldsConstant =
      matchPattern(padOp.getBody()->getTerminator(),
                   m_Op<tensor::YieldOp>(m_Constant(&padCst)));
  if (yieldsConstant && isIntOrFPZero(padCst))
    return padOp.getSource();

  return t;
}
111+
78112
//===----------------------------------------------------------------------===//
79113
// Sparse tensor loop emitter class implementations
80114
//===----------------------------------------------------------------------===//
@@ -166,15 +200,30 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
166200
std::unique_ptr<SparseIterator>
167201
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
168202
Level l) {
203+
Value tensor = tensors[t];
204+
auto stt = getSparseTensorType(tensor);
169205
auto it = makeSimpleIterator(*lvls[t][l], emitStrategy);
170-
auto stt = getSparseTensorType(tensors[t]);
206+
207+
Value folded = tryFoldTensors(tensor);
208+
if (folded != tensor) {
209+
auto padOp = tensor.getDefiningOp<tensor::PadOp>();
210+
assert(padOp);
211+
if (padOp.getPaddedDims().test(l)) {
212+
Value low = unFoldOpIntResult(builder, loc, padOp.getMixedLowPad()[l]);
213+
Value high = unFoldOpIntResult(builder, loc, padOp.getMixedHighPad()[l]);
214+
auto padIt = makePaddedIterator(std::move(it), low, high, emitStrategy);
215+
return padIt;
216+
}
217+
}
218+
171219
if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
172-
Value offset = genSliceOffset(builder, loc, tensors[t], l);
173-
Value stride = genSliceStride(builder, loc, tensors[t], l);
220+
Value offset = genSliceOffset(builder, loc, tensor, l);
221+
Value stride = genSliceStride(builder, loc, tensor, l);
174222
auto slicedIt = makeSlicedLevelIterator(
175223
std::move(it), offset, stride, lvls[t][l]->getSize(), emitStrategy);
176224
return slicedIt;
177225
}
226+
178227
return it;
179228
}
180229

@@ -200,7 +249,9 @@ void LoopEmitter::initializeLoopEmit(
200249
// on positions.
201250
for (TensorId t = 0, numTensors = getNumManifestTensors(); t < numTensors;
202251
t++) {
203-
const Value tensor = tensors[t];
252+
// TODO: this should be done through a folding pass after switching to
253+
// `sparse_tensor.iterate`-based sparsification.
254+
const Value tensor = tryFoldTensors(tensors[t]);
204255
const auto rtp = dyn_cast<RankedTensorType>(tensor.getType());
205256
if (!rtp)
206257
// Skips only scalar, zero ranked tensor still need to be bufferized and
@@ -213,14 +264,6 @@ void LoopEmitter::initializeLoopEmit(
213264
const Level lvlRank = stt.getLvlRank();
214265
const auto shape = rtp.getShape();
215266

216-
SmallVector<Value> lvlSzs;
217-
for (Level l = 0; l < stt.getLvlRank(); l++) {
218-
if (stt.hasEncoding())
219-
lvlSzs.push_back(builder.create<LvlOp>(loc, tensor, l));
220-
else
221-
lvlSzs.push_back(builder.create<tensor::DimOp>(loc, tensor, l));
222-
}
223-
224267
// Scan all levels of current tensor.
225268
for (Level l = 0; l < lvlRank; l++) {
226269
// Find upper bound in current dimension.

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp

Lines changed: 97 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -462,11 +462,54 @@ class DedupIterator : public ConcreteIterator {
462462
Value posHi;
463463
};
464464

465+
// A util base-iterator that delegates all methods to the wrapped iterator.
466+
class SimpleWrapIterator : public SparseIterator {
467+
public:
468+
SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind)
469+
: SparseIterator(kind, *wrap), wrap(std::move(wrap)) {}
470+
471+
SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
472+
return wrap->getCursorValTypes(b);
473+
}
474+
bool isBatchIterator() const override { return wrap->isBatchIterator(); }
475+
bool randomAccessible() const override { return wrap->randomAccessible(); };
476+
bool iteratableByFor() const override { return wrap->iteratableByFor(); };
477+
SmallVector<Value> serialize() const override { return wrap->serialize(); };
478+
void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
479+
ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
480+
void genInitImpl(OpBuilder &b, Location l,
481+
const SparseIterator *parent) override {
482+
wrap->genInit(b, l, parent);
483+
}
484+
Value genNotEndImpl(OpBuilder &b, Location l) override {
485+
return wrap->genNotEndImpl(b, l);
486+
}
487+
ValueRange forwardImpl(OpBuilder &b, Location l) override {
488+
return wrap->forward(b, l);
489+
};
490+
Value upperBound(OpBuilder &b, Location l) const override {
491+
return wrap->upperBound(b, l);
492+
};
493+
494+
Value derefImpl(OpBuilder &b, Location l) override {
495+
return wrap->derefImpl(b, l);
496+
}
497+
498+
void locateImpl(OpBuilder &b, Location l, Value crd) override {
499+
return wrap->locate(b, l, crd);
500+
}
501+
502+
SparseIterator &getWrappedIterator() const { return *wrap; }
503+
504+
protected:
505+
std::unique_ptr<SparseIterator> wrap;
506+
};
507+
465508
//
466509
// A filter iterator wrapped from another iterator. The filter iterator updates
467510
// the wrapped iterator *in-place*.
468511
//
469-
class FilterIterator : public SparseIterator {
512+
class FilterIterator : public SimpleWrapIterator {
470513
// Coordinate translation between crd loaded from the wrap iterator and the
471514
// filter iterator.
472515
Value fromWrapCrd(OpBuilder &b, Location l, Value wrapCrd) const {
@@ -487,8 +530,8 @@ class FilterIterator : public SparseIterator {
487530
// when crd always < size.
488531
FilterIterator(std::unique_ptr<SparseIterator> &&wrap, Value offset,
489532
Value stride, Value size)
490-
: SparseIterator(IterKind::kFilter, *wrap), offset(offset),
491-
stride(stride), size(size), wrap(std::move(wrap)) {}
533+
: SimpleWrapIterator(std::move(wrap), IterKind::kFilter), offset(offset),
534+
stride(stride), size(size) {}
492535

493536
// For LLVM-style RTTI.
494537
static bool classof(const SparseIterator *from) {
@@ -498,19 +541,10 @@ class FilterIterator : public SparseIterator {
498541
std::string getDebugInterfacePrefix() const override {
499542
return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
500543
}
501-
SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
502-
return wrap->getCursorValTypes(b);
503-
}
504544

505-
bool isBatchIterator() const override { return wrap->isBatchIterator(); }
506-
bool randomAccessible() const override { return wrap->randomAccessible(); };
507545
bool iteratableByFor() const override { return randomAccessible(); };
508546
Value upperBound(OpBuilder &b, Location l) const override { return size; };
509547

510-
SmallVector<Value> serialize() const override { return wrap->serialize(); };
511-
void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
512-
ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
513-
514548
void genInitImpl(OpBuilder &b, Location l,
515549
const SparseIterator *parent) override {
516550
wrap->genInit(b, l, parent);
@@ -541,7 +575,47 @@ class FilterIterator : public SparseIterator {
541575
ValueRange forwardImpl(OpBuilder &b, Location l) override;
542576

543577
Value offset, stride, size;
544-
std::unique_ptr<SparseIterator> wrap;
578+
};
579+
580+
//
581+
// A pad iterator wrapped from another iterator. The pad iterator updates
582+
// the wrapped iterator *in-place*.
583+
//
584+
class PadIterator : public SimpleWrapIterator {
585+
586+
public:
587+
PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
588+
Value padHigh)
589+
: SimpleWrapIterator(std::move(wrap), IterKind::kPad), padLow(padLow),
590+
padHigh(padHigh) {
591+
assert(!randomAccessible() && "Not implemented.");
592+
}
593+
594+
// For LLVM-style RTTI.
595+
static bool classof(const SparseIterator *from) {
596+
return from->kind == IterKind::kPad;
597+
}
598+
599+
std::string getDebugInterfacePrefix() const override {
600+
return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
601+
}
602+
603+
// The upper bound after padding becomes `size + padLow + padHigh`.
604+
Value upperBound(OpBuilder &b, Location l) const override {
605+
return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
606+
};
607+
608+
// The pad_coord = coord + pad_lo
609+
Value derefImpl(OpBuilder &b, Location l) override {
610+
updateCrd(ADDI(wrap->deref(b, l), padLow));
611+
return getCrd();
612+
}
613+
614+
void locateImpl(OpBuilder &b, Location l, Value crd) override {
615+
assert(randomAccessible());
616+
}
617+
618+
Value padLow, padHigh;
545619
};
546620

547621
class NonEmptySubSectIterator : public SparseIterator {
@@ -1408,10 +1482,19 @@ sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
14081482
return ret;
14091483
}
14101484

1485+
std::unique_ptr<SparseIterator>
1486+
sparse_tensor::makePaddedIterator(std::unique_ptr<SparseIterator> &&sit,
1487+
Value padLow, Value padHigh,
1488+
SparseEmitStrategy strategy) {
1489+
auto ret = std::make_unique<PadIterator>(std::move(sit), padLow, padHigh);
1490+
ret->setSparseEmitStrategy(strategy);
1491+
return ret;
1492+
}
1493+
14111494
static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
14121495
auto *filter = llvm::dyn_cast_or_null<FilterIterator>(it);
14131496
if (filter)
1414-
return filter->wrap.get();
1497+
return &filter->getWrappedIterator();
14151498
return it;
14161499
}
14171500

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ enum class IterKind : uint8_t {
7676
kSubSect,
7777
kNonEmptySubSect,
7878
kFilter,
79+
kPad,
7980
};
8081

8182
/// Helper class that generates loop conditions, etc, to traverse a
@@ -291,26 +292,32 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
291292
std::unique_ptr<SparseIterator> makeSimpleIterator(const SparseTensorLevel &stl,
292293
SparseEmitStrategy strategy);
293294

294-
/// Helper function to create a synthetic SparseIterator object that iterate
295+
/// Helper function to create a synthetic SparseIterator object that iterates
295296
/// over a dense space specified by [0,`sz`).
296297
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
297298
makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
298299
SparseEmitStrategy strategy);
299300

300-
/// Helper function to create a SparseIterator object that iterate over a
301+
/// Helper function to create a SparseIterator object that iterates over a
301302
/// sliced space, the original space (before slicing) is traversed by `sit`.
302303
std::unique_ptr<SparseIterator>
303304
makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
304305
Value stride, Value size, SparseEmitStrategy strategy);
305306

307+
/// Helper function to create a SparseIterator object that iterates over a
308+
/// padded sparse level (the padded value must be zero).
309+
std::unique_ptr<SparseIterator>
310+
makePaddedIterator(std::unique_ptr<SparseIterator> &&sit, Value padLow,
311+
Value padHigh, SparseEmitStrategy strategy);
312+
306313
/// Helper function to create a SparseIterator object that iterates over the
307314
/// non-empty subsections set.
308315
std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
309316
OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
310317
std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
311318
SparseEmitStrategy strategy);
312319

313-
/// Helper function to create a SparseIterator object that iterate over a
320+
/// Helper function to create a SparseIterator object that iterates over a
314321
/// non-empty subsection created by NonEmptySubSectIterator.
315322
std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
316323
OpBuilder &b, Location l, const SparseIterator &subsectIter,

0 commit comments

Comments
 (0)