Skip to content

Commit 8906b7b

Browse files
authored
Enable custom alloc-like ops in promoteBufferResultsToOutParams (#120288)
In `buffer-results-to-out-params`, when `hoist-static-allocs` option is enabled the pass was looking for `memref.alloc`s in order to attempt to avoid copies when it can. Which makes it not extensible to external ops that have allocation like properties. This patch simply changes `memref::AllocOp` to `AllocationOpInterface` in the check to enable for any allocation op. Moreover, for function call updates, we enable setting an allocation function callback in `BufferResultsToOutParamsOpts` to allow users to emit their own alloc-like op.
1 parent 831e1ac commit 8906b7b

File tree

2 files changed

+40
-19
lines changed

2 files changed

+40
-19
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H
33

44
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
5+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
56
#include "mlir/Pass/Pass.h"
67

78
namespace mlir {
89
class FunctionOpInterface;
10+
class MemRefType;
911
class ModuleOp;
1012
class RewritePatternSet;
1113
class OpBuilder;
@@ -38,7 +40,7 @@ std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass(
3840
DeallocationOptions options = DeallocationOptions());
3941

4042
/// Creates a pass that finds all temporary allocations
41-
/// and attempts to move the deallocation after the last user/dependency
43+
/// and attempts to move the deallocation after the last user/dependency
4244
/// of the allocation, thereby optimizing allocation liveness.
4345
std::unique_ptr<Pass> createOptimizeAllocationLivenessPass();
4446

@@ -157,6 +159,12 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
157159
// Options struct for BufferResultsToOutParams pass.
158160
// Note: defined only here, not in tablegen.
159161
struct BufferResultsToOutParamsOpts {
162+
/// Allocator function: Generate a memref allocation with the given type.
163+
/// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped
164+
/// results, we don't allow passing a range of values for dynamic dims.
165+
using AllocationFn =
166+
std::function<FailureOr<Value>(OpBuilder &, Location, MemRefType)>;
167+
160168
/// Memcpy function: Generate a memcpy between two memrefs.
161169
using MemCpyFn =
162170
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -167,9 +175,20 @@ struct BufferResultsToOutParamsOpts {
167175
return true;
168176
};
169177

178+
/// Allocation function; used to allocate a memref.
179+
/// Default memref.alloc is used
180+
AllocationFn allocationFn = [](OpBuilder &builder, Location loc,
181+
MemRefType type) {
182+
return builder.create<memref::AllocOp>(loc, type).getResult();
183+
};
184+
170185
/// Memcpy function; used to create a copy between two memrefs.
171-
/// If this is empty, memref.copy is used.
172-
std::optional<MemCpyFn> memCpyFn;
186+
/// Default memref.copy is used.
187+
MemCpyFn memCpyFn = [](OpBuilder &builder, Location loc, Value from,
188+
Value to) {
189+
builder.create<memref::CopyOp>(loc, from, to);
190+
return success();
191+
};
173192

174193
/// If true, the pass adds a "bufferize.result" attribute to each output
175194
/// parameter.

mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
910
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
1011

1112
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -21,6 +22,7 @@ namespace bufferization {
2122
} // namespace mlir
2223

2324
using namespace mlir;
25+
using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
2426
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
2527

2628
/// Return `true` if the given MemRef type has a fully dynamic layout.
@@ -105,10 +107,9 @@ updateFuncOp(func::FuncOp func,
105107
// Updates all ReturnOps in the scope of the given func::FuncOp by either
106108
// keeping them as return values or copying the associated buffer contents into
107109
// the given out-params.
108-
static LogicalResult updateReturnOps(func::FuncOp func,
109-
ArrayRef<BlockArgument> appendedEntryArgs,
110-
MemCpyFn memCpyFn,
111-
bool hoistStaticAllocs) {
110+
static LogicalResult
111+
updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
112+
const bufferization::BufferResultsToOutParamsOpts &options) {
112113
auto res = func.walk([&](func::ReturnOp op) {
113114
SmallVector<Value, 6> copyIntoOutParams;
114115
SmallVector<Value, 6> keepAsReturnOperands;
@@ -120,13 +121,14 @@ static LogicalResult updateReturnOps(func::FuncOp func,
120121
}
121122
OpBuilder builder(op);
122123
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
123-
if (hoistStaticAllocs &&
124-
isa_and_nonnull<memref::AllocOp>(orig.getDefiningOp()) &&
124+
if (options.hoistStaticAllocs &&
125+
isa_and_nonnull<bufferization::AllocationOpInterface>(
126+
orig.getDefiningOp()) &&
125127
mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
126128
orig.replaceAllUsesWith(arg);
127129
orig.getDefiningOp()->erase();
128130
} else {
129-
if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
131+
if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
130132
return WalkResult::interrupt();
131133
}
132134
}
@@ -175,7 +177,14 @@ updateCalls(ModuleOp module,
175177
auto allocType =
176178
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
177179
AffineMap(), memrefType.getMemorySpace());
178-
Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
180+
auto maybeOutParam =
181+
options.allocationFn(builder, op.getLoc(), allocType);
182+
if (failed(maybeOutParam)) {
183+
op.emitError() << "failed to create allocation op";
184+
didFail = true;
185+
return;
186+
}
187+
Value outParam = maybeOutParam.value();
179188
if (!hasStaticIdentityLayout(memrefType)) {
180189
// Layout maps are already checked in `updateFuncOp`.
181190
assert(hasFullyDynamicLayoutMap(memrefType) &&
@@ -213,14 +222,7 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
213222
return failure();
214223
if (func.isExternal())
215224
continue;
216-
auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
217-
Value to) {
218-
builder.create<memref::CopyOp>(loc, from, to);
219-
return success();
220-
};
221-
if (failed(updateReturnOps(func, appendedEntryArgs,
222-
options.memCpyFn.value_or(defaultMemCpyFn),
223-
options.hoistStaticAllocs))) {
225+
if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
224226
return failure();
225227
}
226228
}

0 commit comments

Comments
 (0)