Skip to content

Commit c0db8d5

Browse files
committed
[mlir] Expose a function to populate tensor constant bufferization patterns
This makes it easier to use it from other bufferization passes. Differential Revision: https://reviews.llvm.org/D103838
1 parent 639b397 commit c0db8d5

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
namespace mlir {
2121

22+
class GlobalCreator;
2223
class RewritePatternSet;
2324
using OwningRewritePatternList = RewritePatternSet;
2425

@@ -31,6 +32,12 @@ std::unique_ptr<Pass> createStdBufferizePass();
3132
/// Creates an instance of func bufferization pass.
3233
std::unique_ptr<Pass> createFuncBufferizePass();
3334

35+
/// Add patterns to bufferize tensor constants into global memrefs to the given
36+
/// pattern list.
37+
void populateTensorConstantBufferizePatterns(
38+
GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter,
39+
RewritePatternSet &patterns);
40+
3441
/// Creates an instance of tensor constant bufferization pass.
3542
std::unique_ptr<Pass> createTensorConstantBufferizePass();
3643

mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@ class BufferizeTensorConstantOp : public OpConversionPattern<ConstantOp> {
8181
};
8282
} // namespace
8383

84+
void mlir::populateTensorConstantBufferizePatterns(
85+
GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter,
86+
RewritePatternSet &patterns) {
87+
patterns.add<BufferizeTensorConstantOp>(globalCreator, typeConverter,
88+
patterns.getContext());
89+
}
90+
8491
namespace {
8592
struct TensorConstantBufferizePass
8693
: public TensorConstantBufferizeBase<TensorConstantBufferizePass> {
@@ -94,7 +101,7 @@ struct TensorConstantBufferizePass
94101
ConversionTarget target(*context);
95102

96103
target.addLegalDialect<memref::MemRefDialect>();
97-
patterns.add<BufferizeTensorConstantOp>(globals, typeConverter, context);
104+
populateTensorConstantBufferizePatterns(globals, typeConverter, patterns);
98105
target.addDynamicallyLegalOp<ConstantOp>(
99106
[&](ConstantOp op) { return typeConverter.isLegal(op.getType()); });
100107
if (failed(applyPartialConversion(module, target, std::move(patterns))))

0 commit comments

Comments
 (0)