Skip to content

Commit 4a653b4

Browse files
authored
[mlir][sparse] Support pretty print to debug sparse iteration. (#80207)
1 parent 4eac146 commit 4a653b4

File tree

9 files changed

+293
-349
lines changed

9 files changed

+293
-349
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,15 @@ struct SparsificationPass
9595
SparsificationPass(const SparsificationPass &pass) = default;
9696
SparsificationPass(const SparsificationOptions &options) {
9797
parallelization = options.parallelizationStrategy;
98+
sparseEmitStrategy = options.sparseEmitStrategy;
9899
enableRuntimeLibrary = options.enableRuntimeLibrary;
99100
}
100101

101102
void runOnOperation() override {
102103
auto *ctx = &getContext();
103104
// Translate strategy flags to strategy options.
104-
SparsificationOptions options(parallelization, enableRuntimeLibrary);
105+
SparsificationOptions options(parallelization, sparseEmitStrategy,
106+
enableRuntimeLibrary);
105107
// Apply sparsification and cleanup rewriting.
106108
RewritePatternSet patterns(ctx);
107109
populateSparsificationPatterns(patterns, options);

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1369,7 +1369,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
13691369
return failure();
13701370

13711371
// Recursively generates code if admissible.
1372-
env.startEmit();
1372+
env.startEmit(options.sparseEmitStrategy);
13731373
genBuffers(env, rewriter);
13741374
// TODO: Constant affine expression should be handled differently when using
13751375
// slice-based codegen, it does not matter now because we already reject the

mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ LogicalResult CodegenEnv::initTensorExp() {
5959
return success();
6060
}
6161

62-
void CodegenEnv::startEmit() {
62+
void CodegenEnv::startEmit(SparseEmitStrategy emitStrategy) {
6363
assert(insChain == nullptr && "must only start emitting once");
6464
if (sparseOut) {
6565
insChain = sparseOut->get();
@@ -96,7 +96,8 @@ void CodegenEnv::startEmit() {
9696
/*dependentLvlGetter=*/
9797
[this](TensorId t, Level lvl) -> std::vector<LoopCoeffPair> {
9898
return merger().getDependentLoops(t, lvl);
99-
});
99+
},
100+
emitStrategy);
100101
}
101102

102103
std::optional<Operation *> CodegenEnv::genLoopBoundary(

mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class CodegenEnv {
5252
Merger &merger() { return latticeMerger; }
5353
LoopEmitter &emitter() { return loopEmitter; }
5454

55-
void startEmit();
55+
void startEmit(SparseEmitStrategy emitStrategy);
5656

5757
/// Generates loop boundary statements (entering/exiting loops). The function
5858
/// passes and updates the passed-in parameters.

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,20 @@ static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
8181

8282
LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
8383
bool isSparseOut, unsigned numLoops,
84-
DependentLvlGetter dimGetter) {
84+
DependentLvlGetter dimGetter,
85+
SparseEmitStrategy emitStrategy) {
8586
initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter);
8687
}
8788

8889
void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
8990
bool isSparseOut, unsigned numLoops,
90-
DependentLvlGetter dimGetter) {
91+
DependentLvlGetter dimGetter,
92+
SparseEmitStrategy emitStrategy) {
9193
// First initialize the top-level type of the fields.
9294
this->loopTag = loopTag;
9395
this->hasOutput = hasOutput;
9496
this->isSparseOut = isSparseOut;
97+
SparseIterator::setSparseEmitStrategy(emitStrategy);
9598

9699
const unsigned numManifestTensors = ts.size();
97100
const unsigned synTensorId = numManifestTensors;
@@ -169,7 +172,7 @@ LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
169172
Value offset = genSliceOffset(builder, loc, tensors[t], l);
170173
Value stride = genSliceStride(builder, loc, tensors[t], l);
171174
auto slicedIt = makeSlicedLevelIterator(std::move(it), offset, stride,
172-
lvls[t][l]->size());
175+
lvls[t][l]->getSize());
173176
return slicedIt;
174177
}
175178
return it;
@@ -465,7 +468,7 @@ std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
465468

466469
// Construct the while-loop with a parameter for each coordinate.
467470
for (SparseIterator *it : spIters) {
468-
ValueRange itVals = it->getItVals();
471+
ValueRange itVals = it->getCursor();
469472
ivs.append(itVals.begin(), itVals.end());
470473
}
471474

@@ -724,7 +727,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
724727
// Forward the sparse iterator.
725728
Value cmp = CMPI(eq, it.getCrd(), iv);
726729
it.forwardIf(builder, loc, cmp);
727-
operands.append(it.getItVals().begin(), it.getItVals().end());
730+
operands.append(it.getCursor().begin(), it.getCursor().end());
728731
// const Value newPos = whileOp->getResult(o++);
729732
// Following loops continue iteration from the break point of the
730733
// current while loop.

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
1717
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
18+
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
1819
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
1920
#include "mlir/IR/PatternMatch.h"
2021

@@ -84,14 +85,17 @@ class LoopEmitter {
8485
/// `isSparseOut` indicates that the sparse output tensor is empty,
8586
/// so the loop emitter will generate loops over it according to the
8687
/// level-sizes.
87-
void initialize(ValueRange tensors, StringAttr loopTag = nullptr,
88-
bool hasOutput = false, bool isSparseOut = false,
89-
unsigned numLoops = 0, DependentLvlGetter getter = nullptr);
90-
91-
explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr,
92-
bool hasOutput = false, bool isSparseOut = false,
93-
unsigned numLoops = 0,
94-
DependentLvlGetter getter = nullptr);
88+
void
89+
initialize(ValueRange tensors, StringAttr loopTag = nullptr,
90+
bool hasOutput = false, bool isSparseOut = false,
91+
unsigned numLoops = 0, DependentLvlGetter getter = nullptr,
92+
SparseEmitStrategy emitStrategy = SparseEmitStrategy::kFunctional);
93+
94+
explicit LoopEmitter(
95+
ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false,
96+
bool isSparseOut = false, unsigned numLoops = 0,
97+
DependentLvlGetter getter = nullptr,
98+
SparseEmitStrategy emitStrategy = SparseEmitStrategy::kFunctional);
9599

96100
/// Starts a loop emitting session by generating all the buffers needed
97101
/// for iterating over the tensors.

0 commit comments

Comments
 (0)