Skip to content

Commit 8154494

Browse files
committed
[mlir][sparse] refactor sparsification and bufferization pass into proper TD pass
Registering the SparsificationAndBufferization into a proper TD pass has the advantage that it can be invoked and tested in isolation. This change also moves some bufferization specific set up from the pipeline file into the pass file, keeping the logic more locally. Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D158219
1 parent 110d141 commit 8154494

File tree

4 files changed

+73
-40
lines changed

4 files changed

+73
-40
lines changed

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
//===----------------------------------------------------------------------===//
2323

2424
namespace mlir {
25+
2526
namespace bufferization {
2627
struct OneShotBufferizationOptions;
2728
} // namespace bufferization
@@ -215,12 +216,13 @@ std::unique_ptr<Pass> createStorageSpecifierToLLVMPass();
215216

216217
//===----------------------------------------------------------------------===//
217218
// The mini-pipeline for sparsification and bufferization.
218-
//
219-
// Note that this mini-pipeline is not defined through the tablegen pass
220-
// mechanism, and, thus, is not individually available through the command-line.
221-
// It is solely used as part of the full sparse compiler pipeline.
222219
//===----------------------------------------------------------------------===//
223220

221+
bufferization::OneShotBufferizationOptions
222+
getBufferizationOptionsForSparsification(bool analysisOnly);
223+
224+
std::unique_ptr<Pass> createSparsificationAndBufferizationPass();
225+
224226
std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
225227
const bufferization::OneShotBufferizationOptions &bufferizationOptions,
226228
const SparsificationOptions &sparsificationOptions,

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,4 +373,23 @@ def StorageSpecifierToLLVM : Pass<"sparse-storage-specifier-to-llvm", "ModuleOp"
373373
];
374374
}
375375

376+
def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "ModuleOp"> {
377+
let summary = "Mini-pipeline that combines bufferization and sparsifiation";
378+
let description = [{
379+
This pass forms a mini-pipeline that combines bufferization and sparsifiation.
380+
}];
381+
let constructor = "mlir::createSparsificationAndBufferizationPass()";
382+
let dependentDialects = [
383+
"affine::AffineDialect",
384+
"arith::ArithDialect",
385+
"bufferization::BufferizationDialect",
386+
"gpu::GPUDialect",
387+
"LLVM::LLVMDialect",
388+
"linalg::LinalgDialect",
389+
"memref::MemRefDialect",
390+
"scf::SCFDialect",
391+
"sparse_tensor::SparseTensorDialect",
392+
];
393+
}
394+
376395
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES

mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,31 +25,6 @@
2525
#include "mlir/Pass/PassManager.h"
2626
#include "mlir/Transforms/Passes.h"
2727

28-
using namespace mlir;
29-
using namespace mlir::sparse_tensor;
30-
31-
/// Return configuration options for One-Shot Bufferize.
32-
static bufferization::OneShotBufferizationOptions
33-
getBufferizationOptions(bool analysisOnly) {
34-
using namespace bufferization;
35-
OneShotBufferizationOptions options;
36-
options.bufferizeFunctionBoundaries = true;
37-
// TODO(springerm): To spot memory leaks more easily, returning dense allocs
38-
// should be disallowed.
39-
options.allowReturnAllocs = true;
40-
options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
41-
options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
42-
const BufferizationOptions &options) {
43-
return getMemRefTypeWithStaticIdentityLayout(
44-
cast<TensorType>(value.getType()), memorySpace);
45-
};
46-
if (analysisOnly) {
47-
options.testAnalysisOnly = true;
48-
options.printConflicts = true;
49-
}
50-
return options;
51-
}
52-
5328
//===----------------------------------------------------------------------===//
5429
// Pipeline implementation.
5530
//===----------------------------------------------------------------------===//
@@ -58,7 +33,8 @@ void mlir::sparse_tensor::buildSparseCompiler(
5833
OpPassManager &pm, const SparseCompilerOptions &options) {
5934
pm.addNestedPass<func::FuncOp>(createLinalgGeneralizationPass());
6035
pm.addPass(createSparsificationAndBufferizationPass(
61-
getBufferizationOptions(options.testBufferizationAnalysisOnly),
36+
getBufferizationOptionsForSparsification(
37+
options.testBufferizationAnalysisOnly),
6238
options.sparsificationOptions(), options.sparseTensorConversionOptions(),
6339
options.createSparseDeallocs, options.enableRuntimeLibrary,
6440
options.enableBufferInitialization, options.vectorLength,

mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
1010

11+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1112
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1213
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1314
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
@@ -18,15 +19,21 @@
1819
#include "mlir/Dialect/Func/IR/FuncOps.h"
1920
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
2021
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
23+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
24+
#include "mlir/Dialect/SCF/IR/SCF.h"
2125
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
2226
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
2327
#include "mlir/Pass/PassManager.h"
2428
#include "mlir/Transforms/Passes.h"
2529

2630
using namespace mlir;
27-
using namespace mlir::func;
2831

2932
namespace mlir {
33+
34+
#define GEN_PASS_DEF_SPARSIFICATIONANDBUFFERIZATION
35+
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
36+
3037
namespace sparse_tensor {
3138

3239
/// Return `true` if one of the given types is a sparse tensor type.
@@ -50,8 +57,8 @@ static bool containsSparseTensor(TypeRange types) {
5057
/// * Dense tensor ops are lowered through BufferizableOpInterface
5158
/// implementations.
5259
class SparsificationAndBufferizationPass
53-
: public PassWrapper<SparsificationAndBufferizationPass,
54-
OperationPass<ModuleOp>> {
60+
: public impl::SparsificationAndBufferizationBase<
61+
SparsificationAndBufferizationPass> {
5562
public:
5663
SparsificationAndBufferizationPass(
5764
const bufferization::OneShotBufferizationOptions &bufferizationOptions,
@@ -97,12 +104,6 @@ class SparsificationAndBufferizationPass
97104
return success();
98105
}
99106

100-
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
101-
registry.insert<bufferization::BufferizationDialect>();
102-
registry.insert<gpu::GPUDialect>();
103-
registry.insert<LLVM::LLVMDialect>();
104-
}
105-
106107
void runOnOperation() override {
107108
{
108109
// Run enabling transformations.
@@ -179,7 +180,42 @@ class SparsificationAndBufferizationPass
179180
} // namespace sparse_tensor
180181
} // namespace mlir
181182

182-
std::unique_ptr<Pass> mlir::createSparsificationAndBufferizationPass(
183+
mlir::bufferization::OneShotBufferizationOptions
184+
mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
185+
using namespace mlir::bufferization;
186+
OneShotBufferizationOptions options;
187+
options.bufferizeFunctionBoundaries = true;
188+
// TODO(springerm): To spot memory leaks more easily, returning dense allocs
189+
// should be disallowed.
190+
options.allowReturnAllocs = true;
191+
options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
192+
options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
193+
const BufferizationOptions &options) {
194+
return getMemRefTypeWithStaticIdentityLayout(
195+
cast<TensorType>(value.getType()), memorySpace);
196+
};
197+
if (analysisOnly) {
198+
options.testAnalysisOnly = true;
199+
options.printConflicts = true;
200+
}
201+
return options;
202+
}
203+
204+
std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass() {
205+
SparsificationOptions sparseOptions;
206+
SparseTensorConversionOptions convOptions;
207+
return createSparsificationAndBufferizationPass(
208+
getBufferizationOptionsForSparsification(/*analysisOnly=*/false),
209+
sparseOptions, convOptions,
210+
/*createSparseDeallocs=*/false,
211+
/*enableRuntimeLibrary=*/false,
212+
/*enableBufferInitialization=*/false,
213+
/*vectorLength=*/0,
214+
/*enableVLAVectorization=*/false,
215+
/*enableSIMDIndex32=*/false);
216+
}
217+
218+
std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass(
183219
const bufferization::OneShotBufferizationOptions &bufferizationOptions,
184220
const SparsificationOptions &sparsificationOptions,
185221
const SparseTensorConversionOptions &sparseTensorConversionOptions,

0 commit comments

Comments
 (0)