Skip to content

Commit 3b2004e

Browse files
[mlir][bufferization] Add TensorCopyInsertion pass
This pass runs the One-Shot Analysis to find out which tensor OpOperands must bufferize out-of-place. It then rewrites those tensor OpOperands to explicit allocations with a copy in the form of `bufferization.alloc_tensor`. The resulting IR can then be bufferized without having to care about read-after-write conflicts. This change makes it possible to connect One-Shot Analysis to other bufferizations such as the sparse compiler. Differential Revision: https://reviews.llvm.org/D126573
1 parent 6d890a0 commit 3b2004e

File tree

6 files changed

+206
-0
lines changed

6 files changed

+206
-0
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,11 @@ std::unique_ptr<Pass> createAllocTensorEliminationPass();
7878
/// Create a pass that bufferizes ops from the bufferization dialect.
7979
std::unique_ptr<Pass> createBufferizationBufferizePass();
8080

81+
/// Create a pass that resolves out-of-place tensor OpOperands with copies.
/// This overload takes no arguments; the pass is configured through its
/// declared pass options.
82+
std::unique_ptr<Pass> createTensorCopyInsertionPass();
83+
/// Create a pass that resolves out-of-place tensor OpOperands with copies,
/// configured by the given One-Shot Bufferize `options`.
std::unique_ptr<Pass>
84+
createTensorCopyInsertionPass(const OneShotBufferizationOptions &options);
85+
8186
//===----------------------------------------------------------------------===//
8287
// Registration
8388
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,25 @@ def PromoteBuffersToStack : Pass<"promote-buffers-to-stack", "func::FuncOp"> {
324324
];
325325
}
326326

327+
// Declares the `tensor-copy-insertion` pass. The pass is not anchored on a
// specific operation type.
def TensorCopyInsertion : Pass<"tensor-copy-insertion"> {
  let summary = "Make all tensor IR inplaceable by inserting copies";
  let description = [{
    This pass runs One-Shot Analysis and inserts copies for all OpOperands that
    were decided to bufferize out-of-place. After running this pass, a
    bufferization can write to buffers directly (without making copies) and no
    longer has to care about potential read-after-write conflicts.

    When `bufferize-function-boundaries` is set, the module-level analysis is
    used and the pass must be run on a `builtin.module`.
  }];
  // Both options are booleans; spell their defaults the same way.
  let options = [
    Option<"allowReturnAllocs", "allow-return-allocs", "bool",
           /*default=*/"false",
           "Allows returning/yielding new allocations from a block.">,
    Option<"bufferizeFunctionBoundaries", "bufferize-function-boundaries",
           "bool", /*default=*/"false",
           "Bufferize function boundaries (experimental).">,
  ];
  let constructor = "mlir::bufferization::createTensorCopyInsertionPass()";
}
345+
327346
def AllocTensorElimination : Pass<"eliminate-alloc-tensors"> {
328347
let summary = "Try to eliminate all alloc_tensor ops.";
329348
let description = [{

mlir/include/mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- TensorCopyInsertion.h - Resolve Bufferization Conflicts w/ Copies --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TENSORCOPYINSERTION_H
10+
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TENSORCOPYINSERTION_H
11+
12+
#include "mlir/IR/Operation.h"
13+
14+
namespace mlir {
15+
namespace bufferization {
16+
class AnalysisState;
17+
struct OneShotBufferizationOptions;
18+
19+
/// Run One-Shot Analysis on `op` with the given `options` and then resolve the
/// tensor OpOperands that were decided to bufferize out-of-place by inserting
/// explicit copies. If `options.bufferizeFunctionBoundaries` is set, the
/// module-level analysis is used and `op` is expected to be a ModuleOp. If
/// `options.testAnalysisOnly` is set, only the analysis is run and no copies
/// are inserted. Returns failure if the analysis or the copy insertion fails.
LogicalResult insertTensorCopies(Operation *op,
20+
const OneShotBufferizationOptions &options);
21+
22+
/// Resolve out-of-place tensor OpOperands of the bufferizable ops nested under
/// `op` with explicit copies in the form of AllocTensorOps, based on the
/// decisions recorded in `state`. AllocTensorOps that have no `escape`
/// attribute get one assigned from the analysis results. Returns failure if
/// an unsupported operand is encountered (copies of unranked tensors are not
/// supported).
LogicalResult insertTensorCopies(Operation *op, const AnalysisState &state);
23+
} // namespace bufferization
24+
} // namespace mlir
25+
26+
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TENSORCOPYINSERTION_H

mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
99
FuncBufferizableOpInterfaceImpl.cpp
1010
OneShotAnalysis.cpp
1111
OneShotModuleBufferize.cpp
12+
TensorCopyInsertion.cpp
1213

1314
ADDITIONAL_HEADER_DIRS
1415
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization

mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
//===- TensorCopyInsertion.cpp - Resolve Bufferization Conflicts w/ Copies ===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Bufferization/Transforms/TensorCopyInsertion.h"
10+
11+
#include "PassDetail.h"
12+
13+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
14+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
15+
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
16+
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
17+
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
18+
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
19+
20+
using namespace mlir;
21+
using namespace mlir::bufferization;
22+
23+
/// Runs One-Shot Analysis on `op` and then resolves every tensor OpOperand that
/// was decided to bufferize out-of-place with an explicit tensor copy.
///
/// If `options.bufferizeFunctionBoundaries` is set, the module-level analysis
/// is used; `op` must then be a ModuleOp, otherwise an error is emitted on `op`
/// and failure is returned. If `options.testAnalysisOnly` is set, success is
/// returned right after the analysis and no copies are inserted.
///
/// Returns failure if the analysis or the copy insertion fails.
LogicalResult mlir::bufferization::insertTensorCopies(
    Operation *op, const OneShotBufferizationOptions &options) {
  OneShotAnalysisState state(op, options);
  // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
  // analysis depending on whether function boundary bufferization is enabled or
  // not.
  if (options.bufferizeFunctionBoundaries) {
    // Callers (including the op-agnostic pass) may hand in any operation, so
    // diagnose a non-module root instead of asserting inside `cast<>`.
    auto moduleOp = dyn_cast<ModuleOp>(op);
    if (!moduleOp)
      return op->emitError(
          "expected a module op when bufferizing function boundaries");
    if (failed(analyzeModuleOp(moduleOp, state)))
      return failure();
  } else {
    if (failed(analyzeOp(op, state)))
      return failure();
  }

  if (options.testAnalysisOnly)
    return success();

  return insertTensorCopies(op, state);
}
42+
43+
/// Rewrites the bufferizable ops nested under `op` based on the analysis
/// results in `state`:
///  * AllocTensorOps without an `escape` attribute get the attribute set
///    according to whether their result is yielded.
///  * Every tensor OpOperand that is not in-place is replaced with an
///    AllocTensorOp that copies the original operand.
///
/// Returns failure (after emitting an error on the offending op) if an
/// out-of-place operand is an unranked tensor, because such copies are not
/// supported. In-place unranked operands need no copy and are left untouched.
LogicalResult
mlir::bufferization::insertTensorCopies(Operation *op,
                                        const AnalysisState &state) {
  OpBuilder builder(op->getContext());
  WalkResult result = op->walk([&](Operation *nestedOp) {
    auto bufferizableOp = state.getOptions().dynCastBufferizableOp(nestedOp);
    if (!bufferizableOp)
      return WalkResult::skip();

    // Find AllocTensorOps without an `escape` attribute and add the attribute
    // based on analysis results.
    if (auto allocTensorOp = dyn_cast<AllocTensorOp>(nestedOp)) {
      if (allocTensorOp.escape())
        return WalkResult::advance();
      bool escape = state.isTensorYielded(allocTensorOp.result());
      allocTensorOp.escapeAttr(builder.getBoolAttr(escape));
      return WalkResult::advance();
    }

    // Find out-of-place tensor OpOperands and resolve them with an explicit
    // tensor copy in the form of an AllocTensorOp.
    builder.setInsertionPoint(nestedOp);
    for (OpOperand &opOperand : nestedOp->getOpOperands()) {
      Type operandType = opOperand.get().getType();
      if (!operandType.isa<TensorType>())
        continue;
      // In-place operands need no copy, regardless of whether they are ranked.
      if (state.isInPlace(opOperand))
        continue;
      // Only reject unranked tensors once a copy is actually required.
      auto tensorType = operandType.dyn_cast<RankedTensorType>();
      if (!tensorType) {
        nestedOp->emitError("copies of unranked tensors are not supported");
        return WalkResult::interrupt();
      }
      SmallVector<OpResult> aliasingOpResults =
          state.getAliasingOpResult(opOperand);
      // The copy escapes if any result aliasing this operand is yielded.
      bool escape = llvm::any_of(
          aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
      Value copy = builder.create<AllocTensorOp>(
          nestedOp->getLoc(), tensorType, ValueRange(), opOperand.get(),
          escape);
      opOperand.set(copy);
    }

    return WalkResult::advance();
  });

  return failure(result.wasInterrupted());
}
89+
90+
namespace {
91+
struct TensorCopyInsertionPass
92+
: TensorCopyInsertionBase<TensorCopyInsertionPass> {
93+
TensorCopyInsertionPass()
94+
: TensorCopyInsertionBase<TensorCopyInsertionPass>(),
95+
options(llvm::None) {}
96+
TensorCopyInsertionPass(const OneShotBufferizationOptions &options)
97+
: TensorCopyInsertionBase<TensorCopyInsertionPass>(), options(options) {}
98+
99+
void getDependentDialects(DialectRegistry &registry) const override {
100+
registry.insert<bufferization::BufferizationDialect>();
101+
}
102+
103+
void runOnOperation() override {
104+
if (options.hasValue()) {
105+
if (failed(insertTensorCopies(getOperation(), *options)))
106+
signalPassFailure();
107+
} else {
108+
OneShotBufferizationOptions options;
109+
options.allowReturnAllocs = allowReturnAllocs;
110+
options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
111+
if (failed(insertTensorCopies(getOperation(), options)))
112+
signalPassFailure();
113+
}
114+
}
115+
116+
private:
117+
Optional<OneShotBufferizationOptions> options;
118+
};
119+
} // namespace
120+
121+
/// Creates a TensorCopyInsertionPass without a preset options object; the pass
/// then builds its One-Shot Bufferize options from its pass options when run.
std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass() {
122+
return std::make_unique<TensorCopyInsertionPass>();
123+
}
124+
125+
/// Creates a TensorCopyInsertionPass that stores a copy of `options` and uses
/// it when run, instead of building options from its pass options.
std::unique_ptr<Pass> mlir::bufferization::createTensorCopyInsertionPass(
126+
const OneShotBufferizationOptions &options) {
127+
return std::make_unique<TensorCopyInsertionPass>(options);
128+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: mlir-opt %s -tensor-copy-insertion -split-input-file | FileCheck %s
2+
// RUN: mlir-opt %s -tensor-copy-insertion="bufferize-function-boundaries allow-return-allocs" -split-input-file | FileCheck %s --check-prefix=CHECK-FUNC
3+
4+
// CHECK-LABEL: func @read_after_write_conflict(
5+
// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
6+
// CHECK-FUNC-LABEL: func @read_after_write_conflict(
7+
func.func @read_after_write_conflict(%t: tensor<?xf32>, %idx: index, %f: f32)
8+
-> (tensor<?xf32>, tensor<?xf32>)
9+
{
10+
// CHECK: %[[copy:.*]] = bufferization.alloc_tensor() copy(%[[t]]) {escape = false} : tensor<?xf32>
11+
// CHECK-FUNC: bufferization.alloc_tensor() copy(%{{.*}}) {escape = true} : tensor<?xf32>
12+
// CHECK: %[[insert:.*]] = tensor.insert %{{.*}} into %[[copy]]
13+
%0 = tensor.insert %f into %t[%idx] : tensor<?xf32>
14+
// CHECK: return %[[insert]], %[[t]]
15+
return %0, %t : tensor<?xf32>, tensor<?xf32>
16+
}
17+
18+
// -----
19+
20+
// CHECK-LABEL: func @return_alloc_tensor
21+
// CHECK-FUNC-LABEL: func @return_alloc_tensor
22+
func.func @return_alloc_tensor() -> (tensor<5xf32>) {
23+
// CHECK: bufferization.alloc_tensor() {escape = false} : tensor<5xf32>
24+
// CHECK-FUNC: bufferization.alloc_tensor() {escape = true} : tensor<5xf32>
25+
%0 = bufferization.alloc_tensor() : tensor<5xf32>
26+
return %0 : tensor<5xf32>
27+
}

0 commit comments

Comments
 (0)