[mlir][sparse] Support pretty print to debug sparse iteration. #80207
Merged
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Patch is 71.57 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/80207.diff

11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index e93e2aefb344f..8b2875a751d4a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -47,6 +47,12 @@ enum class ReinterpretMapScope {
kExceptGeneric, // reinterprets operation other than linalg.generic
};
+/// Defines a strategy for emitting sparse iteration code (for debugging).
+enum class DebugSparseIteration {
+  kNone,          // generate fully inlined (and functional) sparse iteration
+  kInterfaceOnly, // generate only placeholders for sparse iteration
+};
+
#define GEN_PASS_DECL
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
@@ -74,11 +80,20 @@ std::unique_ptr<Pass> createPreSparsificationRewritePass();
/// Options for the Sparsification pass.
struct SparsificationOptions {
+ SparsificationOptions(SparseParallelizationStrategy p, DebugSparseIteration d,
+ bool enableRT)
+ : parallelizationStrategy(p), debugSparseIteration(d),
+ enableRuntimeLibrary(enableRT) {}
+
SparsificationOptions(SparseParallelizationStrategy p, bool enableRT)
- : parallelizationStrategy(p), enableRuntimeLibrary(enableRT) {}
+ : SparsificationOptions(p, DebugSparseIteration::kNone, enableRT) {}
+
SparsificationOptions()
- : SparsificationOptions(SparseParallelizationStrategy::kNone, true) {}
+ : SparsificationOptions(SparseParallelizationStrategy::kNone,
+ DebugSparseIteration::kNone, true) {}
+
SparseParallelizationStrategy parallelizationStrategy;
+ DebugSparseIteration debugSparseIteration;
bool enableRuntimeLibrary;
};
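The delegating constructors keep existing two-argument callers source-compatible while defaulting the new field to `kNone`. For completeness, a minimal sketch (not part of the patch) of constructing the pass with the new strategy in C++, assuming the `createSparsificationPass(const SparsificationOptions &)` entry point declared elsewhere in this header:

```cpp
// Sketch only: build the sparsification pass with interface-only emission.
// createSparsificationPass(const SparsificationOptions &) is assumed from
// the surrounding Passes.h API; it is not shown in this hunk.
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"

std::unique_ptr<mlir::Pass> makeInterfaceOnlySparsification() {
  mlir::SparsificationOptions options(
      mlir::SparseParallelizationStrategy::kNone,
      mlir::DebugSparseIteration::kInterfaceOnly,
      /*enableRT=*/true);
  return mlir::createSparsificationPass(options);
}
```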
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index f38779ed9ed2b..126b91510d391 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -130,6 +130,13 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
clEnumValN(mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
"any-storage-any-loop",
"Enable sparse parallelization for any storage and loop."))}]>,
+ Option<"debugSparseIteration", "debug-sparse-iteration", "mlir::DebugSparseIteration",
+ "mlir::DebugSparseIteration::kNone",
+ "Pretty print sparse loops to debug sparse iteration", [{llvm::cl::values(
+ clEnumValN(mlir::DebugSparseIteration::kNone, "none",
+ "Turn off pretty printing and generates functional code."),
+ clEnumValN(mlir::DebugSparseIteration::kInterfaceOnly, "interface-only",
+ "Generate non-functional interfaces for sparse iteration."))}]>,
Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
"true", "Enable runtime library for manipulating sparse tensors">,
];
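With the TableGen option in place, the strategy is also reachable from the command line; going by the pass and option names in this hunk, an invocation would look like `mlir-opt --sparsification="debug-sparse-iteration=interface-only" input.mlir`, while `none` keeps the default functional codegen.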
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 375e10f9068e4..0ae9f6483588d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -82,13 +82,15 @@ struct SparsificationPass
SparsificationPass(const SparsificationPass &pass) = default;
SparsificationPass(const SparsificationOptions &options) {
parallelization = options.parallelizationStrategy;
+ debugSparseIteration = options.debugSparseIteration;
enableRuntimeLibrary = options.enableRuntimeLibrary;
}
void runOnOperation() override {
auto *ctx = &getContext();
// Translate strategy flags to strategy options.
- SparsificationOptions options(parallelization, enableRuntimeLibrary);
+ SparsificationOptions options(parallelization, debugSparseIteration,
+ enableRuntimeLibrary);
// Apply sparsification and cleanup rewriting.
RewritePatternSet patterns(ctx);
populateSparsificationPatterns(patterns, options);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 5266ca7213bfc..2ceb214052aa3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1369,7 +1369,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
return failure();
// Recursively generates code if admissible.
- env.startEmit();
+ env.startEmit(options.debugSparseIteration);
genBuffers(env, rewriter);
// TODO: Constant affine expression should be handled differently when using
// slice-based codegen, it does not matter now because we already reject the
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
index d3de55e4d59bd..0af1cc1745f51 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
@@ -59,7 +59,7 @@ LogicalResult CodegenEnv::initTensorExp() {
return success();
}
-void CodegenEnv::startEmit() {
+void CodegenEnv::startEmit(DebugSparseIteration emitStrategy) {
assert(insChain == nullptr && "must only start emitting once");
if (sparseOut) {
insChain = sparseOut->get();
@@ -96,7 +96,8 @@ void CodegenEnv::startEmit() {
/*dependentLvlGetter=*/
[this](TensorId t, Level lvl) -> std::vector<LoopCoeffPair> {
return merger().getDependentLoops(t, lvl);
- });
+ },
+ emitStrategy);
}
std::optional<Operation *> CodegenEnv::genLoopBoundary(
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
index 728af841cc7b1..7eeddac48f4f1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
@@ -52,7 +52,7 @@ class CodegenEnv {
Merger &merger() { return latticeMerger; }
LoopEmitter &emitter() { return loopEmitter; }
- void startEmit();
+ void startEmit(DebugSparseIteration emitStrategy);
/// Generates loop boundary statements (entering/exiting loops). The function
/// passes and updates the passed-in parameters.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 3fa4004ae460e..8c1680a393181 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -81,17 +81,20 @@ static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
bool isSparseOut, unsigned numLoops,
- DependentLvlGetter dimGetter) {
+ DependentLvlGetter dimGetter,
+ DebugSparseIteration emitStrategy) {
  initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter,
             emitStrategy);
}
void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
bool isSparseOut, unsigned numLoops,
- DependentLvlGetter dimGetter) {
+ DependentLvlGetter dimGetter,
+ DebugSparseIteration emitStrategy) {
// First initialize the top-level type of the fields.
this->loopTag = loopTag;
this->hasOutput = hasOutput;
this->isSparseOut = isSparseOut;
+ SparseIterator::setDebugSparseIteration(emitStrategy);
const unsigned numManifestTensors = ts.size();
const unsigned synTensorId = numManifestTensors;
@@ -169,7 +172,7 @@ LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
Value offset = genSliceOffset(builder, loc, tensors[t], l);
Value stride = genSliceStride(builder, loc, tensors[t], l);
auto slicedIt = makeSlicedLevelIterator(std::move(it), offset, stride,
- lvls[t][l]->size());
+ lvls[t][l]->getSize());
return slicedIt;
}
return it;
@@ -465,7 +468,7 @@ std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
// Construct the while-loop with a parameter for each coordinate.
for (SparseIterator *it : spIters) {
- ValueRange itVals = it->getItVals();
+ ValueRange itVals = it->getCursor();
ivs.append(itVals.begin(), itVals.end());
}
@@ -724,7 +727,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
// Forward the sparse iterator.
Value cmp = CMPI(eq, it.getCrd(), iv);
it.forwardIf(builder, loc, cmp);
- operands.append(it.getItVals().begin(), it.getItVals().end());
+ operands.append(it.getCursor().begin(), it.getCursor().end());
// const Value newPos = whileOp->getResult(o++);
// Following loops continue iteration from the break point of the
// current while loop.
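Note that the strategy is installed through a static setter rather than threaded through every iterator constructor. A hedged sketch of what that setter plausibly looks like (the `SparseIterator` declaration is in the truncated part of the diff, so the static storage here is an assumption inferred from the call site above):

```cpp
// Sketch: a pass-run-wide toggle consulted by all iterators created after
// initialize() runs; the static member is an assumption, not shown in the
// visible diff.
class SparseIterator {
public:
  static void setDebugSparseIteration(DebugSparseIteration strategy) {
    emitStrategy = strategy;
  }

protected:
  static DebugSparseIteration emitStrategy; // assumed static storage
};
```

A static flag keeps the many iterator constructors unchanged, at the cost of making the strategy global to the pass run.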
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index d0f447d926f71..e0b4f81487a68 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -15,6 +15,7 @@
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/IR/PatternMatch.h"
@@ -84,14 +85,17 @@ class LoopEmitter {
/// `isSparseOut` indicates that the sparse output tensor is empty,
/// so the loop emitter will generate loops over it according to the
/// level-sizes.
- void initialize(ValueRange tensors, StringAttr loopTag = nullptr,
- bool hasOutput = false, bool isSparseOut = false,
- unsigned numLoops = 0, DependentLvlGetter getter = nullptr);
-
- explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr,
- bool hasOutput = false, bool isSparseOut = false,
- unsigned numLoops = 0,
- DependentLvlGetter getter = nullptr);
+ void
+ initialize(ValueRange tensors, StringAttr loopTag = nullptr,
+ bool hasOutput = false, bool isSparseOut = false,
+ unsigned numLoops = 0, DependentLvlGetter getter = nullptr,
+ DebugSparseIteration emitStrategy = DebugSparseIteration::kNone);
+
+ explicit LoopEmitter(
+ ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false,
+ bool isSparseOut = false, unsigned numLoops = 0,
+ DependentLvlGetter getter = nullptr,
+ DebugSparseIteration emitStrategy = DebugSparseIteration::kNone);
/// Starts a loop emitting session by generating all the buffers needed
/// for iterating over the tensors.
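Defaulting `emitStrategy` to `DebugSparseIteration::kNone` in both overloads keeps the numerous existing `LoopEmitter` call sites compiling unchanged; only the sparsifier path shown earlier opts in explicitly.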
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 98323c2195461..bdaf794744bea 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -46,20 +46,6 @@ using ValueTuple = std::tuple<Value, Value, Value>;
namespace {
-class SparseLevel : public SparseTensorLevel {
-public:
- SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value crdBuffer)
- : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
-
- Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
- return genIndexLoad(b, l, crdBuffer, iv);
- }
-
-protected:
- const Value crdBuffer;
-};
-
class DenseLevel : public SparseTensorLevel {
public:
DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded)
@@ -74,53 +60,27 @@ class DenseLevel : public SparseTensorLevel {
Value max) const override {
assert(max == nullptr && "Dense level can not be non-unique.");
if (encoded) {
- Value posLo = MULI(p, lvlSize);
- return {posLo, lvlSize};
+ Value posLo = MULI(p, getSize());
+ return {posLo, getSize()};
}
// No need to linearize the position for non-annotated tensors.
- return {C_IDX(0), lvlSize};
+ return {C_IDX(0), getSize()};
}
const bool encoded;
};
-class CompressedLevel : public SparseLevel {
+class SparseLevel : public SparseTensorLevel {
public:
- CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
-
- ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
- Value max) const override {
- if (max == nullptr) {
- Value pLo = genIndexLoad(b, l, posBuffer, p);
- Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
- return {pLo, pHi};
- }
- llvm_unreachable("compressed-nu should be the first non-unique level.");
+ SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ ValueRange lvlBuf)
+ : SparseTensorLevel(tid, lvl, lt, lvlSize, lvlBuf) {
+ assert(!lvlBuf.empty());
}
-private:
- const Value posBuffer;
-};
-
-class LooseCompressedLevel : public SparseLevel {
-public:
- LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
-
- ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
- Value max) const override {
- assert(max == nullptr && "loss compressed level can not be non-unique.");
- p = MULI(p, C_IDX(2));
- Value pLo = genIndexLoad(b, l, posBuffer, p);
- Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
- return {pLo, pHi};
+ Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
+ return genIndexLoad(b, l, getLvlBufs().front(), iv);
}
-
-private:
- const Value posBuffer;
};
class SingletonLevel : public SparseLevel {
@@ -142,8 +102,8 @@ class SingletonLevel : public SparseLevel {
class TwoOutFourLevel : public SparseLevel {
public:
TwoOutFourLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+ Value crdBuf)
+ : SparseLevel(tid, lvl, lt, lvlSize, crdBuf) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
Value max) const override {
@@ -154,6 +114,39 @@ class TwoOutFourLevel : public SparseLevel {
}
};
+class CompressedLevel : public SparseLevel {
+public:
+ CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ Value posBuffer, Value crdBuffer)
+ : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
+
+ ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+ Value max) const override {
+ if (max == nullptr) {
+ Value pLo = genIndexLoad(b, l, getPosBuf(), p);
+ Value pHi = genIndexLoad(b, l, getPosBuf(), ADDI(p, C_IDX(1)));
+ return {pLo, pHi};
+ }
+ llvm_unreachable("compressed-nu should be the first non-unique level.");
+ }
+};
+
+class LooseCompressedLevel : public SparseLevel {
+public:
+ LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ Value posBuffer, Value crdBuffer)
+ : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
+
+ ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+ Value max) const override {
+    assert(max == nullptr && "loose compressed level can not be non-unique.");
+ p = MULI(p, C_IDX(2));
+ Value pLo = genIndexLoad(b, l, getPosBuf(), p);
+ Value pHi = genIndexLoad(b, l, getPosBuf(), ADDI(p, C_IDX(1)));
+ return {pLo, pHi};
+ }
+};
+
} // namespace
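The reshuffle above folds the per-class `posBuffer`/`crdBuffer` members into a single buffer list owned by the base level class; `peekCrdAt` reads `getLvlBufs().front()` and the compressed constructors pass `{crdBuffer, posBuffer}`. A hedged sketch of the accessors those call sites imply (the real declarations are in the truncated `SparseTensorLevel.h` change):

```cpp
// Sketch inferred from this hunk's call sites; buffer order follows the
// constructors above: {crdBuffer[, posBuffer]}.
class SparseTensorLevel {
public:
  Value getSize() const { return lvlSize; }
  ValueRange getLvlBufs() const { return lvlBufs; }
  Value getPosBuf() const {
    assert(lvlBufs.size() >= 2 && "level has no position buffer");
    return lvlBufs.back(); // positions follow the coordinate buffer
  }

protected:
  Value lvlSize;              // per-level size, set at construction
  SmallVector<Value> lvlBufs; // {crdBuffer[, posBuffer]}
};
```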
//===----------------------------------------------------------------------===//
@@ -203,7 +196,8 @@ static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
// SparseIterator derived classes.
//===----------------------------------------------------------------------===//
-namespace {
+namespace mlir {
+namespace sparse_tensor {
// The iterator that traverses a concrete sparse tensor levels. High-level
// abstract iterators wrap it to achieve more complex goals (such as collapsing
@@ -212,12 +206,11 @@ namespace {
class ConcreteIterator : public SparseIterator {
protected:
ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
- unsigned itValCnt)
- : SparseIterator(kind, stl.tid, stl.lvl, itValCnt, itValsStorage),
- stl(stl) {
- // Allocate enough storage for iterator values.
- itValsStorage.resize(itValCnt);
- }
+ unsigned cursorValCnt)
+ : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage),
+ stl(stl), cursorValsStorage(cursorValCnt, nullptr) {
+ assert(getCursor().size() == cursorValCnt);
+ };
public:
// For LLVM-style RTTI.
@@ -228,22 +221,34 @@ class ConcreteIterator : public SparseIterator {
bool randomAccessible() const override { return isDenseLT(stl.getLT()); };
bool iteratableByFor() const override { return kind != IterKind::kDedup; };
Value upperBound(OpBuilder &b, Location l) const override {
- return stl.size();
+ return stl.getSize();
};
protected:
+ const SparseTensorLevel &stl;
// Owner of the storage, all wrappers build on top of a concrete iterator
// share the same storage such that the iterator values are always
// synchronized.
- SmallVector<Value> itValsStorage;
- const SparseTensorLevel &stl;
+ SmallVector<Value> cursorValsStorage;
};
+} // namespace sparse_tensor
+} // namespace mlir
+
+namespace {
+
class TrivialIterator : public ConcreteIterator {
public:
TrivialIterator(const SparseTensorLevel &stl)
: ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
+ std::string getDebugInterfacePrefix() const override {
+ return std::string("trivial<") + stl.toString() + ">";
+ }
+ SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+ return {b.getIndexType()};
+ }
+
SmallVector<Value> serialize() const override {
SmallVector<Value> ret;
ret.push_back(getItPos());
@@ -286,12 +291,12 @@ class TrivialIterator : public ConcreteIterator {
return std::make_pair(getItPos(), posHi);
}
- Value genNotEnd(OpBuilder &b, Location l) override {
+ Value genNotEndImpl(OpBuilder &b, Location l) override {
// We used the first level bound as the bound the collapsed set of levels.
return CMPI(ult, getItPos(), posHi);
}
- Value deref(OpBuilder &b, Location l) override {
+ Value derefImpl(OpBuilder &b, Location l) override {
if (randomAccessible()) {
updateCrd(SUBI(getItPos(), posLo));
} else {
@@ -302,24 +307,24 @@ class TrivialIterator : public ConcreteIterator {
ValueRange forwardImpl(OpBuilder &b, Location l) override {
seek(ADDI(getItPos(), C_IDX(1)));
- return getItVals();
+ return getCursor();
}
ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
- Value curPos = getItVals().front();
+ Value curPos = getCursor().front();
Value nxPos = forward(b, l).front();
seek(SELECT(cond, nxPos, curPos));
- return getItVals();
+ return getCursor();
}
- void locate(OpBuilder &b, Location l, Value crd) override {
+ void locateImpl(OpBuilder &b, Location l, Value crd) override {
assert(randomAccessible());
// Seek to the linearized position.
seek(ADDI(crd, posLo));
updateCrd(crd);
}
- Value getItPos() const { return getItVals().front(); }
+ Value getItPos() const { return getCursor().front(); }
Value posLo, posHi;
};
@@ -337,6 +342,13 @@ class DedupIterator : public ...
[truncated]
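One pattern worth noting in the visible portion: `genNotEnd`, `deref`, and `locate` become `genNotEndImpl`, `derefImpl`, and `locateImpl`, suggesting the base class now owns non-virtual wrappers that consult the emit strategy. A hedged sketch of that dispatch (the wrapper bodies are in the truncated part of the diff; `genDebugPlaceholder` is a hypothetical stand-in for whatever placeholder the patch actually emits):

```cpp
// Sketch of the non-virtual-interface split implied by the *Impl renames.
Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
  if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
    // Emit a compact, non-functional "interface" op labeled with
    // getDebugInterfacePrefix() and typed via getCursorValTypes(),
    // instead of the real loop-condition IR.
    return genDebugPlaceholder(b, l); // hypothetical helper
  }
  return genNotEndImpl(b, l); // real implementation
}
```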
Force-pushed 9c6fc88 to 3695e6f
yinying-lisa-li approved these changes on Feb 1, 2024
aartbik approved these changes on Feb 1, 2024
Force-pushed 3695e6f to f0c9c74
Force-pushed 6d99e7e to 131471a
agozillon pushed a commit to agozillon/llvm-project that referenced this pull request on Feb 5, 2024