Skip to content

Commit afac64c

Browse files
committed
[MLIR] BufferResultsToOutParams: Allow to configure memCpyFn
This allows us to configure the pass to emit linalg.copy instead of memref.copy. This is consistent with one-shot-bufferize, which also allows to configure the `memCpyFn`, see https://discord.com/channels/636084430946959380/642426447167881246/1211698722438783087
1 parent 469c5e3 commit afac64c

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,19 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
149149
// Options struct for BufferResultsToOutParams pass.
150150
// Note: defined only here, not in tablegen.
151151
struct BufferResultsToOutParamsOptions {
152+
/// Memcpy function: Generate a memcpy between two memrefs.
153+
using MemCpyFn =
154+
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
155+
152156
// Filter function; returns true if the function should be converted.
153157
// Defaults to true, i.e. all functions are converted.
154158
llvm::function_ref<bool(func::FuncOp *)> filterFn = [](func::FuncOp *func) {
155159
return true;
156160
};
161+
162+
/// Memcpy function; used to create a copy between two memrefs.
163+
/// If this is empty, memref.copy is used.
164+
std::optional<MemCpyFn> memCpyFn;
157165
};
158166

159167
/// Creates a pass that converts memref function results to out-params.

mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ namespace bufferization {
2121
} // namespace mlir
2222

2323
using namespace mlir;
24+
using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn;
2425

2526
/// Return `true` if the given MemRef type has a fully dynamic layout.
2627
static bool hasFullyDynamicLayoutMap(MemRefType type) {
@@ -97,9 +98,10 @@ updateFuncOp(func::FuncOp func,
9798
// Updates all ReturnOps in the scope of the given func::FuncOp by either
9899
// keeping them as return values or copying the associated buffer contents into
99100
// the given out-params.
100-
static void updateReturnOps(func::FuncOp func,
101-
ArrayRef<BlockArgument> appendedEntryArgs) {
102-
func.walk([&](func::ReturnOp op) {
101+
static LogicalResult updateReturnOps(func::FuncOp func,
102+
ArrayRef<BlockArgument> appendedEntryArgs,
103+
MemCpyFn memCpyFn) {
104+
auto res = func.walk([&](func::ReturnOp op) {
103105
SmallVector<Value, 6> copyIntoOutParams;
104106
SmallVector<Value, 6> keepAsReturnOperands;
105107
for (Value operand : op.getOperands()) {
@@ -109,12 +111,16 @@ static void updateReturnOps(func::FuncOp func,
109111
keepAsReturnOperands.push_back(operand);
110112
}
111113
OpBuilder builder(op);
112-
for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
113-
builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t),
114-
std::get<1>(t));
114+
for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
115+
if (failed(
116+
memCpyFn(builder, op.getLoc(), std::get<0>(t), std::get<1>(t))))
117+
return WalkResult::interrupt();
118+
}
115119
builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
116120
op.erase();
121+
return WalkResult::advance();
117122
});
123+
return failure(res.wasInterrupted());
118124
}
119125

120126
// Updates all CallOps in the scope of the given ModuleOp by allocating
@@ -192,7 +198,15 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
192198
return failure();
193199
if (func.isExternal())
194200
continue;
195-
updateReturnOps(func, appendedEntryArgs);
201+
auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
202+
Value to) {
203+
builder.create<memref::CopyOp>(loc, from, to);
204+
return success();
205+
};
206+
if (failed(updateReturnOps(func, appendedEntryArgs,
207+
options.memCpyFn.value_or(defaultMemCpyFn)))) {
208+
return failure();
209+
}
196210
}
197211
if (failed(updateCalls(module, options)))
198212
return failure();

0 commit comments

Comments
 (0)