8
8
9
9
#include " mlir/Dialect/SparseTensor/Transforms/Passes.h"
10
10
11
+ #include " mlir/Dialect/Affine/IR/AffineOps.h"
11
12
#include " mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12
13
#include " mlir/Dialect/Bufferization/IR/Bufferization.h"
13
14
#include " mlir/Dialect/Bufferization/Transforms/Bufferize.h"
18
19
#include " mlir/Dialect/Func/IR/FuncOps.h"
19
20
#include " mlir/Dialect/GPU/IR/GPUDialect.h"
20
21
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
22
+ #include " mlir/Dialect/Linalg/IR/Linalg.h"
23
+ #include " mlir/Dialect/MemRef/IR/MemRef.h"
24
+ #include " mlir/Dialect/SCF/IR/SCF.h"
21
25
#include " mlir/Dialect/SparseTensor/IR/SparseTensor.h"
22
26
#include " mlir/Dialect/SparseTensor/Transforms/Passes.h"
23
27
#include " mlir/Pass/PassManager.h"
24
28
#include " mlir/Transforms/Passes.h"
25
29
26
30
using namespace mlir ;
27
- using namespace mlir ::func;
28
31
29
32
namespace mlir {
33
+
34
+ #define GEN_PASS_DEF_SPARSIFICATIONANDBUFFERIZATION
35
+ #include " mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
36
+
30
37
namespace sparse_tensor {
31
38
32
39
// / Return `true` if one of the given types is a sparse tensor type.
@@ -50,8 +57,8 @@ static bool containsSparseTensor(TypeRange types) {
50
57
// / * Dense tensor ops are lowered through BufferizableOpInterface
51
58
// / implementations.
52
59
class SparsificationAndBufferizationPass
53
- : public PassWrapper<SparsificationAndBufferizationPass,
54
- OperationPass<ModuleOp> > {
60
+ : public impl::SparsificationAndBufferizationBase<
61
+ SparsificationAndBufferizationPass > {
55
62
public:
56
63
SparsificationAndBufferizationPass (
57
64
const bufferization::OneShotBufferizationOptions &bufferizationOptions,
@@ -97,12 +104,6 @@ class SparsificationAndBufferizationPass
97
104
return success ();
98
105
}
99
106
100
- void getDependentDialects (::mlir::DialectRegistry ®istry) const override {
101
- registry.insert <bufferization::BufferizationDialect>();
102
- registry.insert <gpu::GPUDialect>();
103
- registry.insert <LLVM::LLVMDialect>();
104
- }
105
-
106
107
void runOnOperation () override {
107
108
{
108
109
// Run enabling transformations.
@@ -179,7 +180,42 @@ class SparsificationAndBufferizationPass
179
180
} // namespace sparse_tensor
180
181
} // namespace mlir
181
182
182
- std::unique_ptr<Pass> mlir::createSparsificationAndBufferizationPass (
183
+ mlir::bufferization::OneShotBufferizationOptions
184
+ mlir::getBufferizationOptionsForSparsification (bool analysisOnly) {
185
+ using namespace mlir ::bufferization;
186
+ OneShotBufferizationOptions options;
187
+ options.bufferizeFunctionBoundaries = true ;
188
+ // TODO(springerm): To spot memory leaks more easily, returning dense allocs
189
+ // should be disallowed.
190
+ options.allowReturnAllocs = true ;
191
+ options.setFunctionBoundaryTypeConversion (LayoutMapOption::IdentityLayoutMap);
192
+ options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
193
+ const BufferizationOptions &options) {
194
+ return getMemRefTypeWithStaticIdentityLayout (
195
+ cast<TensorType>(value.getType ()), memorySpace);
196
+ };
197
+ if (analysisOnly) {
198
+ options.testAnalysisOnly = true ;
199
+ options.printConflicts = true ;
200
+ }
201
+ return options;
202
+ }
203
+
204
+ std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass () {
205
+ SparsificationOptions sparseOptions;
206
+ SparseTensorConversionOptions convOptions;
207
+ return createSparsificationAndBufferizationPass (
208
+ getBufferizationOptionsForSparsification (/* analysisOnly=*/ false ),
209
+ sparseOptions, convOptions,
210
+ /* createSparseDeallocs=*/ false ,
211
+ /* enableRuntimeLibrary=*/ false ,
212
+ /* enableBufferInitialization=*/ false ,
213
+ /* vectorLength=*/ 0 ,
214
+ /* enableVLAVectorization=*/ false ,
215
+ /* enableSIMDIndex32=*/ false );
216
+ }
217
+
218
+ std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass (
183
219
const bufferization::OneShotBufferizationOptions &bufferizationOptions,
184
220
const SparsificationOptions &sparsificationOptions,
185
221
const SparseTensorConversionOptions &sparseTensorConversionOptions,
0 commit comments