Skip to content

Commit 060c8be

Browse files
committed
[mlir][OneShotModuleBufferize] Add a new flag: no-analysis-func-filter
OneShotModuleBufferize fails if the input IR cannot be analyzed. One can set CopyBeforeWrite=true in order to skip analysis. In that case, a buffer copy is inserted on every write. This leads to many copies, also in FuncOps that could be analyzed. This change aims to copy buffers only when it is a must. When running OneShotModuleBufferize with CopyBeforeWrite=false, FuncOps whose names are specified in noAnalysisFuncFilter will not be analyzed. Ops in these FuncOps will not be analyzed as well. They will be bufferized with CopyBeforeWrite=true, while the other ops will be bufferized with CopyBeforeWrite=false. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D142631
1 parent 54b40a1 commit 060c8be

File tree

5 files changed

+93
-13
lines changed

5 files changed

+93
-13
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
1010
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
1111

12+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1213
namespace mlir {
1314

1415
struct LogicalResult;
@@ -27,11 +28,16 @@ LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
2728
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
2829
///
2930
/// Note: This function does not run One-Shot Analysis. No buffer copies are
30-
/// inserted unless `options.copyBeforeWrite` is set, in which case buffers are
31-
/// copied before every write.
32-
LogicalResult bufferizeModuleOp(ModuleOp moduleOp,
33-
const OneShotBufferizationOptions &options,
34-
BufferizationStatistics *statistics = nullptr);
31+
/// inserted except two cases:
32+
/// - `options.copyBeforeWrite` is set, in which case buffers are copied before
33+
/// every write.
34+
/// - `options.copyBeforeWrite` is not set and `analysisFilterFn` returns true
35+
/// for some FuncOps. These FuncOps were not analyzed. Buffer copies will be
36+
/// inserted only to these FuncOps.
37+
LogicalResult
38+
bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
39+
BufferizationStatistics *statistics = nullptr,
40+
OpFilter::Entry::FilterFn analysisFilterFn = nullptr);
3541

3642
/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
3743
void removeBufferizationAttributesInModule(ModuleOp moduleOp);
@@ -43,7 +49,8 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
4349
LogicalResult runOneShotModuleBufferize(
4450
ModuleOp moduleOp,
4551
const bufferization::OneShotBufferizationOptions &options,
46-
BufferizationStatistics *statistics = nullptr);
52+
BufferizationStatistics *statistics = nullptr,
53+
OpFilter::Entry::FilterFn analysisFilterFn = nullptr);
4754

4855
} // namespace bufferization
4956
} // namespace mlir

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,9 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
297297
"core bufferization passes.">,
298298
ListOption<"dialectFilter", "dialect-filter", "std::string",
299299
"Restrict bufferization to ops from these dialects.">,
300+
ListOption<"noAnalysisFuncFilter", "no-analysis-func-filter", "std::string",
301+
"Skip analysis of functions with these symbol names."
302+
"Set copyBeforeWrite to true when bufferizing them.">,
300303
Option<"functionBoundaryTypeConversion",
301304
"function-boundary-type-conversion", "std::string",
302305
/*default=*/"\"infer-layout-map\"",

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,11 +249,26 @@ struct OneShotBufferizePass
249249
BufferizationStatistics statistics;
250250
ModuleOp moduleOp = getOperation();
251251
if (opt.bufferizeFunctionBoundaries) {
252-
if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
252+
OpFilter::Entry::FilterFn analysisFilterFn = nullptr;
253+
// FuncOps whose names are specified in noAnalysisFuncFilter will not be
254+
// analyzed. Ops in these FuncOps will not be analyzed as well.
255+
if (this->noAnalysisFuncFilter.hasValue())
256+
analysisFilterFn = [=](Operation *op) {
257+
auto func = dyn_cast<func::FuncOp>(op);
258+
if (!func)
259+
func = op->getParentOfType<func::FuncOp>();
260+
if (func)
261+
return llvm::is_contained(noAnalysisFuncFilter, func.getSymName());
262+
return false;
263+
};
264+
if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics,
265+
analysisFilterFn))) {
253266
signalPassFailure();
254267
return;
255268
}
256269
} else {
270+
assert(!this->noAnalysisFuncFilter.hasValue() &&
271+
"invalid combination of bufferization flags");
257272
if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
258273
signalPassFailure();
259274
return;

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,9 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
378378

379379
// Analyze ops.
380380
for (func::FuncOp funcOp : orderedFuncOps) {
381+
if (!state.getOptions().isOpAllowed(funcOp))
382+
continue;
383+
381384
// Now analyzing function.
382385
funcState.startFunctionAnalysis(funcOp);
383386

@@ -410,7 +413,8 @@ void mlir::bufferization::removeBufferizationAttributesInModule(
410413

411414
LogicalResult mlir::bufferization::bufferizeModuleOp(
412415
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
413-
BufferizationStatistics *statistics) {
416+
BufferizationStatistics *statistics,
417+
OpFilter::Entry::FilterFn analysisFilterFn) {
414418
assert(options.bufferizeFunctionBoundaries &&
415419
"expected that function boundary bufferization is activated");
416420
IRRewriter rewriter(moduleOp.getContext());
@@ -428,7 +432,9 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
428432
for (func::FuncOp funcOp : orderedFuncOps) {
429433
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
430434
// would be invalidated.
431-
if (failed(bufferizeOp(funcOp, options, options.copyBeforeWrite,
435+
bool copyBeforeWrite = options.copyBeforeWrite ||
436+
(analysisFilterFn && analysisFilterFn(funcOp));
437+
if (failed(bufferizeOp(funcOp, options, copyBeforeWrite,
432438
/*opFilter=*/nullptr, statistics)))
433439
return failure();
434440
// Change buffer return types to more precise layout maps.
@@ -445,18 +451,28 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
445451

446452
LogicalResult mlir::bufferization::runOneShotModuleBufferize(
447453
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
448-
BufferizationStatistics *statistics) {
454+
BufferizationStatistics *statistics,
455+
OpFilter::Entry::FilterFn analysisFilterFn) {
449456
assert(options.bufferizeFunctionBoundaries &&
450457
"expected that function boundary bufferization is activated");
451458
assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
452459
"invalid combination of bufferization flags");
453460
if (!options.copyBeforeWrite) {
454-
if (failed(insertTensorCopies(moduleOp, options, statistics)))
455-
return failure();
461+
if (!analysisFilterFn) {
462+
if (failed(insertTensorCopies(moduleOp, options, statistics)))
463+
return failure();
464+
} else {
465+
OneShotBufferizationOptions updatedOptions(options);
466+
updatedOptions.opFilter.denyOperation(analysisFilterFn);
467+
if (failed(insertTensorCopies(moduleOp, updatedOptions, statistics)))
468+
return failure();
469+
}
456470
}
457471
if (options.testAnalysisOnly)
458472
return success();
459-
if (failed(bufferizeModuleOp(moduleOp, options, statistics)))
473+
474+
if (failed(
475+
bufferizeModuleOp(moduleOp, options, statistics, analysisFilterFn)))
460476
return failure();
461477
return success();
462478
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 no-analysis-func-filter=contains_to_memref_op" -drop-equivalent-buffer-results --split-input-file | FileCheck %s
2+
3+
// ToMemref ops do not pass analysis step. CopyBeforeWrite will be true only for the
4+
// FuncOp "contains_to_memref_op" since it is specified in no-analysis-func-filter.
5+
6+
module {
7+
// CHECK-LABEL: func.func @foo(
8+
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32, strided<[?], offset: ?>>) {
9+
func.func @foo(%arg0: tensor<?xf32>) -> tensor<?xf32> {
10+
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
11+
%cst = arith.constant 1.000000e+00 : f32
12+
13+
// CHECK-NEXT: %[[c1:.*]] = arith.constant 1.000000e+00 : f32
14+
%c0 = arith.constant 0 : index
15+
16+
// CHECK-NEXT: memref.store %[[c1]], %[[arg0]]{{\[}}%[[c0]]] : memref<?xf32, strided<[?], offset: ?>>
17+
%inserted = tensor.insert %cst into %arg0[%c0] : tensor<?xf32>
18+
19+
return %inserted : tensor<?xf32>
20+
}
21+
22+
// CHECK-LABEL: func.func @contains_to_memref_op(
23+
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32, strided<[?], offset: ?>>,
24+
// CHECK-SAME: %[[arg1:.*]]: index) -> vector<5xf32> {
25+
func.func @contains_to_memref_op(%arg0: tensor<?xf32> {bufferization.writable = true}, %arg1: index) -> vector<5xf32> {
26+
27+
%0 = bufferization.to_memref %arg0 : memref<?xf32>
28+
29+
// CHECK: %[[c0:.*]] = arith.constant 0 : index
30+
%cst = arith.constant 0.000000e+00 : f32
31+
32+
// CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c0]] : memref<?xf32, strided<[?], offset: ?>>
33+
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]]) : memref<?xf32>
34+
// CHECK: memref.copy %[[arg0]], %[[alloc]] : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32>
35+
// CHECK: vector.transfer_read
36+
%1 = vector.transfer_read %0[%arg1], %cst : memref<?xf32>, vector<5xf32>
37+
return %1 : vector<5xf32>
38+
}
39+
}

0 commit comments

Comments
 (0)