Skip to content

Commit 11878d6

Browse files
[mlir][sparse] Extract StorageSpecifierToLLVMPass from bufferization pipeline
`StorageSpecifierToLLVMPass` does not have to be part of the bufferization mini pipeline. It can run after the bufferization pipeline. This is desirable because it keeps the bufferization pipeline smaller.
1 parent 9d34c05 commit 11878d6

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ void mlir::sparse_tensor::buildSparseCompiler(
4242
/*enableSIMDIndex32=*/options.force32BitVectorIndices));
4343
if (options.testBufferizationAnalysisOnly)
4444
return;
45+
46+
pm.addPass(createStorageSpecifierToLLVMPass());
4547
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
4648
pm.addNestedPass<func::FuncOp>(
4749
mlir::bufferization::createFinalizingBufferizePass());

mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ namespace sparse_tensor {
3939
/// Return `true` if one of the given types is a sparse tensor type.
4040
static bool containsSparseTensor(TypeRange types) {
4141
for (Type t : types)
42-
if (getSparseTensorEncoding(t))
42+
if (isa<TensorType>(t) && getSparseTensorEncoding(t))
4343
return true;
4444
return false;
4545
}
@@ -97,7 +97,8 @@ class SparsificationAndBufferizationPass
9797
return false;
9898
});
9999

100-
if (failed(bufferization::bufferizeOp(getOperation(), updatedOptions)))
100+
if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
101+
updatedOptions)))
101102
return failure();
102103

103104
bufferization::removeBufferizationAttributesInModule(getOperation());
@@ -154,7 +155,6 @@ class SparsificationAndBufferizationPass
154155
pm.addPass(createSparseTensorCodegenPass(createSparseDeallocs,
155156
enableBufferInitialization));
156157
pm.addPass(createSparseBufferRewritePass(enableBufferInitialization));
157-
pm.addPass(createStorageSpecifierToLLVMPass());
158158
}
159159
if (failed(runPipeline(pm, getOperation())))
160160
return signalPassFailure();

0 commit comments

Comments
 (0)