Skip to content

Commit cc788e0

Browse files
committed
Move default alloc/copy fns to BufferResultsToOutParamsOpts struct
1 parent bf145cf commit cc788e0

File tree

2 files changed

+25
-29
lines changed

2 files changed

+25
-29
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H
33

44
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
5+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
56
#include "mlir/Pass/Pass.h"
67

78
namespace mlir {
@@ -175,12 +176,19 @@ struct BufferResultsToOutParamsOpts {
175176
};
176177

177178
/// Allocation function; used to allocate a memref.
178-
/// If this is empty, memref.alloc is used
179-
std::optional<AllocationFn> allocationFn;
179+
/// Default memref.alloc is used
180+
AllocationFn allocationFn = [](OpBuilder &builder, Location loc,
181+
MemRefType type) {
182+
return builder.create<memref::AllocOp>(loc, type).getResult();
183+
};
180184

181185
/// Memcpy function; used to create a copy between two memrefs.
182-
/// If this is empty, memref.copy is used.
183-
std::optional<MemCpyFn> memCpyFn;
186+
/// Default memref.copy is used.
187+
MemCpyFn memCpyFn = [](OpBuilder &builder, Location loc, Value from,
188+
Value to) {
189+
builder.create<memref::CopyOp>(loc, from, to);
190+
return success();
191+
};
184192

185193
/// If true, the pass adds a "bufferize.result" attribute to each output
186194
/// parameter.

mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,9 @@ updateFuncOp(func::FuncOp func,
107107
// Updates all ReturnOps in the scope of the given func::FuncOp by either
108108
// keeping them as return values or copying the associated buffer contents into
109109
// the given out-params.
110-
static LogicalResult updateReturnOps(func::FuncOp func,
111-
ArrayRef<BlockArgument> appendedEntryArgs,
112-
MemCpyFn memCpyFn,
113-
bool hoistStaticAllocs) {
110+
static LogicalResult
111+
updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
112+
const bufferization::BufferResultsToOutParamsOpts &options) {
114113
auto res = func.walk([&](func::ReturnOp op) {
115114
SmallVector<Value, 6> copyIntoOutParams;
116115
SmallVector<Value, 6> keepAsReturnOperands;
@@ -122,14 +121,14 @@ static LogicalResult updateReturnOps(func::FuncOp func,
122121
}
123122
OpBuilder builder(op);
124123
for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
125-
if (hoistStaticAllocs &&
124+
if (options.hoistStaticAllocs &&
126125
isa_and_nonnull<bufferization::AllocationOpInterface>(
127126
orig.getDefiningOp()) &&
128127
mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
129128
orig.replaceAllUsesWith(arg);
130129
orig.getDefiningOp()->erase();
131130
} else {
132-
if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
131+
if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
133132
return WalkResult::interrupt();
134133
}
135134
}
@@ -142,8 +141,9 @@ static LogicalResult updateReturnOps(func::FuncOp func,
142141

143142
// Updates all CallOps in the scope of the given ModuleOp by allocating
144143
// temporary buffers for newly introduced out params.
145-
static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn,
146-
std::function<bool(func::FuncOp *)> filterFn) {
144+
static LogicalResult
145+
updateCalls(ModuleOp module,
146+
const bufferization::BufferResultsToOutParamsOpts &options) {
147147
bool didFail = false;
148148
SymbolTable symtab(module);
149149
module.walk([&](func::CallOp op) {
@@ -154,7 +154,7 @@ static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn,
154154
didFail = true;
155155
return;
156156
}
157-
if (!filterFn(&callee))
157+
if (!options.filterFn(&callee))
158158
return;
159159
SmallVector<Value, 6> replaceWithNewCallResults;
160160
SmallVector<Value, 6> replaceWithOutParams;
@@ -177,7 +177,8 @@ static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn,
177177
auto allocType =
178178
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
179179
AffineMap(), memrefType.getMemorySpace());
180-
auto maybeOutParam = allocationFn(builder, op.getLoc(), allocType);
180+
auto maybeOutParam =
181+
options.allocationFn(builder, op.getLoc(), allocType);
181182
if (failed(maybeOutParam)) {
182183
op.emitError() << "failed to create allocation op";
183184
didFail = true;
@@ -221,24 +222,11 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
221222
return failure();
222223
if (func.isExternal())
223224
continue;
224-
auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
225-
Value to) {
226-
builder.create<memref::CopyOp>(loc, from, to);
227-
return success();
228-
};
229-
if (failed(updateReturnOps(func, appendedEntryArgs,
230-
options.memCpyFn.value_or(defaultMemCpyFn),
231-
options.hoistStaticAllocs))) {
225+
if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
232226
return failure();
233227
}
234228
}
235-
auto defaultAllocationFn = [](OpBuilder &builder, Location loc,
236-
MemRefType type) {
237-
return builder.create<memref::AllocOp>(loc, type).getResult();
238-
};
239-
if (failed(updateCalls(module,
240-
options.allocationFn.value_or(defaultAllocationFn),
241-
options.filterFn)))
229+
if (failed(updateCalls(module, options)))
242230
return failure();
243231
return success();
244232
}

0 commit comments

Comments
 (0)