Skip to content

Commit 5618d2b

Browse files
committed
[mlir][sparse] Add option enable-buffer-initialization to initialize the memory buffers for sparse tensors to support debugging.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D137592
1 parent 20d6f63 commit 5618d2b

File tree

6 files changed

+40
-8
lines changed

6 files changed

+40
-8
lines changed

mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ struct SparseCompilerOptions
6363
*this, "test-bufferization-analysis-only",
6464
desc("Run only the inplacability analysis"), init(false)};
6565

66+
PassOptions::Option<bool> enableBufferInitialization{
67+
*this, "enable-buffer-initialization",
68+
desc("Enable zero-initialization of memory buffers"), init(false)};
69+
6670
/// Projects out the options for `createSparsificationPass`.
6771
SparsificationOptions sparsificationOptions() const {
6872
return SparsificationOptions(parallelization);

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,10 @@ std::unique_ptr<Pass> createSparseTensorRewritePass(bool enableRT,
153153
std::unique_ptr<Pass> createDenseBufferizationPass(
154154
const bufferization::OneShotBufferizationOptions &options);
155155

156-
void populateSparseBufferRewriting(RewritePatternSet &patterns);
157-
std::unique_ptr<Pass> createSparseBufferRewritePass();
156+
void populateSparseBufferRewriting(RewritePatternSet &patterns,
157+
bool enableBufferInitialization);
158+
std::unique_ptr<Pass>
159+
createSparseBufferRewritePass(bool enableBufferInitialization = false);
158160

159161
//===----------------------------------------------------------------------===//
160162
// Registration.

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ def SparseBufferRewrite : Pass<"sparse-buffer-rewrite", "ModuleOp"> {
198198
"scf::SCFDialect",
199199
"sparse_tensor::SparseTensorDialect",
200200
];
201+
let options = [
202+
Option<"enableBufferInitialization", "enable-buffer-initialization", "bool",
203+
"false", "Enable zero-initialization of the memory buffers">,
204+
];
201205
}
202206

203207
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES

mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ void mlir::sparse_tensor::buildSparseCompiler(
6565
options.sparseTensorConversionOptions()));
6666
else
6767
pm.addPass(createSparseTensorCodegenPass());
68-
pm.addPass(createSparseBufferRewritePass());
68+
pm.addPass(createSparseBufferRewritePass(options.enableBufferInitialization));
6969
pm.addPass(createDenseBufferizationPass(
7070
getBufferizationOptions(/*analysisOnly=*/false)));
7171
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());

mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,8 @@ namespace {
635635
struct PushBackRewriter : OpRewritePattern<PushBackOp> {
636636
public:
637637
using OpRewritePattern<PushBackOp>::OpRewritePattern;
638+
PushBackRewriter(MLIRContext *context, bool enableInit)
639+
: OpRewritePattern(context), enableBufferInitialization(enableInit) {}
638640
LogicalResult matchAndRewrite(PushBackOp op,
639641
PatternRewriter &rewriter) const override {
640642
// Rewrite push_back(buffer, value, n) to:
@@ -705,6 +707,16 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
705707

706708
Value newBuffer =
707709
rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
710+
if (enableBufferInitialization) {
711+
Value fillSize = rewriter.create<arith::SubIOp>(loc, capacity, newSize);
712+
Value fillValue = rewriter.create<arith::ConstantOp>(
713+
loc, value.getType(), rewriter.getZeroAttr(value.getType()));
714+
Value subBuffer = rewriter.create<memref::SubViewOp>(
715+
loc, newBuffer, /*offset=*/ValueRange{newSize},
716+
/*size=*/ValueRange{fillSize},
717+
/*step=*/ValueRange{constantIndex(rewriter, loc, 1)});
718+
rewriter.create<linalg::FillOp>(loc, fillValue, subBuffer);
719+
}
708720
rewriter.create<scf::YieldOp>(loc, newBuffer);
709721

710722
// False branch.
@@ -731,6 +743,9 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
731743
rewriter.replaceOp(op, buffer);
732744
return success();
733745
}
746+
747+
private:
748+
bool enableBufferInitialization;
734749
};
735750

736751
/// Sparse rewriting rule for the sort operator.
@@ -777,6 +792,9 @@ struct SortRewriter : public OpRewritePattern<SortOp> {
777792
// Methods that add patterns described in this file to a pattern list.
778793
//===---------------------------------------------------------------------===//
779794

780-
void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns) {
781-
patterns.add<PushBackRewriter, SortRewriter>(patterns.getContext());
795+
void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
796+
bool enableBufferInitialization) {
797+
patterns.add<PushBackRewriter>(patterns.getContext(),
798+
enableBufferInitialization);
799+
patterns.add<SortRewriter>(patterns.getContext());
782800
}

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,11 +215,14 @@ struct SparseBufferRewritePass
215215

216216
SparseBufferRewritePass() = default;
217217
SparseBufferRewritePass(const SparseBufferRewritePass &pass) = default;
218+
SparseBufferRewritePass(bool enableInit) {
219+
enableBufferInitialization = enableInit;
220+
}
218221

219222
void runOnOperation() override {
220223
auto *ctx = &getContext();
221224
RewritePatternSet patterns(ctx);
222-
populateSparseBufferRewriting(patterns);
225+
populateSparseBufferRewriting(patterns, enableBufferInitialization);
223226
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
224227
}
225228
};
@@ -279,6 +282,7 @@ std::unique_ptr<Pass> mlir::createSparseTensorCodegenPass() {
279282
return std::make_unique<SparseTensorCodegenPass>();
280283
}
281284

282-
std::unique_ptr<Pass> mlir::createSparseBufferRewritePass() {
283-
return std::make_unique<SparseBufferRewritePass>();
285+
std::unique_ptr<Pass>
286+
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
287+
return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
284288
}

0 commit comments

Comments
 (0)