Skip to content

Commit 87c770b

Browse files
[mlir][bufferization][NFC] Put inplacability conflict resolution in op interface
The TensorCopyInsertion pass resolves out-of-place bufferization decisions by inserting explicit `bufferization.alloc_tensor` ops. This change moves that functionality into a new BufferizableOpInterface method, so that it can be overridden by op implementations. Some op bufferizations must insert additional `alloc_tensor` ops to make sure that certain aliasing invariants are not violated (e.g., scf::ForOp). This will be addressed in a subsequent change. Differential Revision: https://reviews.llvm.org/D126817
1 parent 9f12215 commit 87c770b

File tree

3 files changed

+67
-23
lines changed

3 files changed

+67
-23
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,32 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
192192
llvm_unreachable("bufferRelation not implemented");
193193
}]
194194
>,
195+
InterfaceMethod<
196+
/*desc=*/[{
197+
Resolve all inplacability conflicts by inserting explicit
198+
`bufferization.alloc_tensor` ops. Examples of inplacability conflicts
199+
are read-after-write conflicts or writes into non-writable buffers.
200+
201+
This method should rewrite the IR in such a way that for each tensor
202+
OpOperand t, buffer(t) can be directly used when during bufferization.
203+
The bufferization does no longer have to care about inplacability
204+
conflicts.
205+
206+
This method can query analysis information from the given analysis
207+
state.
208+
}],
209+
/*retType=*/"LogicalResult",
210+
/*methodName=*/"resolveConflicts",
211+
/*args=*/(ins "RewriterBase &":$rewriter,
212+
"const AnalysisState &":$state),
213+
/*methodBody=*/"",
214+
/*defaultImplementation=*/[{
215+
auto bufferizableOp =
216+
cast<BufferizableOpInterface>($_op.getOperation());
217+
return bufferizableOp.resolveTensorOpOperandConflicts(
218+
rewriter, state);
219+
}]
220+
>,
195221
InterfaceMethod<
196222
/*desc=*/[{
197223
Bufferize this op, i.e., rewrite it into a memref-based equivalent.
@@ -301,6 +327,11 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
301327
];
302328

303329
let extraClassDeclaration = [{
330+
/// Resolve out-of-place tensor OpOperands with explicit allocations in the
331+
/// form of `bufferization.alloc_tensor` ops.
332+
LogicalResult resolveTensorOpOperandConflicts(
333+
RewriterBase &rewriter, const AnalysisState &state);
334+
304335
/// Return `true` if the given OpOperand creates an alias but does neither
305336
/// read nor write. This implies that `bufferizesToMemoryRead` and
306337
/// `bufferizesToMemoryWrite` must return `false`. This method will never

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
#include "mlir/IR/Value.h"
1919
#include "llvm/Support/Debug.h"
2020

21+
//===----------------------------------------------------------------------===//
22+
// BufferizableOpInterface
23+
//===----------------------------------------------------------------------===//
24+
2125
namespace mlir {
2226
namespace bufferization {
2327

@@ -38,6 +42,31 @@ using namespace bufferization;
3842
constexpr const ::llvm::StringLiteral
3943
bufferization::BufferizableOpInterface::kInplaceableAttrName;
4044

45+
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
46+
RewriterBase &rewriter, const AnalysisState &state) {
47+
Operation *op = getOperation();
48+
for (OpOperand &opOperand : op->getOpOperands()) {
49+
Type operandType = opOperand.get().getType();
50+
if (!operandType.isa<TensorType>())
51+
continue;
52+
if (state.isInPlace(opOperand))
53+
continue;
54+
if (operandType.isa<UnrankedTensorType>())
55+
return op->emitError("copies of unranked tensors are not supported");
56+
auto tensorType = operandType.dyn_cast<RankedTensorType>();
57+
if (!tensorType)
58+
continue;
59+
SmallVector<OpResult> aliasingOpResults =
60+
state.getAliasingOpResult(opOperand);
61+
bool escape = llvm::any_of(
62+
aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
63+
Value copy = rewriter.create<AllocTensorOp>(
64+
op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape);
65+
rewriter.updateRootInPlace(op, [&]() { opOperand.set(copy); });
66+
}
67+
return success();
68+
}
69+
4170
//===----------------------------------------------------------------------===//
4271
// OpFilter
4372
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ LogicalResult mlir::bufferization::insertTensorCopies(
4343
LogicalResult
4444
mlir::bufferization::insertTensorCopies(Operation *op,
4545
const AnalysisState &state) {
46-
OpBuilder builder(op->getContext());
46+
IRRewriter rewriter(op->getContext());
4747
WalkResult result = op->walk([&](Operation *op) {
4848
auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
4949
if (!bufferizableOp)
@@ -55,31 +55,15 @@ mlir::bufferization::insertTensorCopies(Operation *op,
5555
if (allocTensorOp.escape())
5656
return WalkResult::advance();
5757
bool escape = state.isTensorYielded(allocTensorOp.result());
58-
allocTensorOp.escapeAttr(builder.getBoolAttr(escape));
58+
allocTensorOp.escapeAttr(rewriter.getBoolAttr(escape));
5959
return WalkResult::advance();
6060
}
6161

62-
// Find out-of-place tensor OpOperands and resolve them with an explicit
63-
// tensor copy in the form of an AllocTensorOp.
64-
builder.setInsertionPoint(op);
65-
for (OpOperand &opOperand : op->getOpOperands()) {
66-
if (opOperand.get().getType().isa<UnrankedTensorType>()) {
67-
op->emitError("copies of unranked tensors are not supported");
68-
return WalkResult::interrupt();
69-
}
70-
auto tensorType = opOperand.get().getType().dyn_cast<RankedTensorType>();
71-
if (!tensorType)
72-
continue;
73-
if (state.isInPlace(opOperand))
74-
continue;
75-
SmallVector<OpResult> aliasingOpResults =
76-
state.getAliasingOpResult(opOperand);
77-
bool escape = llvm::any_of(
78-
aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
79-
Value copy = builder.create<AllocTensorOp>(
80-
op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape);
81-
opOperand.set(copy);
82-
}
62+
// Find inplacability conflicts and resolve them. (Typically with explicit
63+
// tensor copies in the form of AllocTensorOps.)
64+
rewriter.setInsertionPoint(op);
65+
if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
66+
return WalkResult::interrupt();
8367

8468
return WalkResult::advance();
8569
});

0 commit comments

Comments
 (0)