Skip to content

Commit bf145cf

Browse files
committed
add and use custom allocation function
1 parent 1db8e79 commit bf145cf

File tree

2 files changed

+30
-7
lines changed

2 files changed

+30
-7
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
namespace mlir {
88
class FunctionOpInterface;
9+
class MemRefType;
910
class ModuleOp;
1011
class RewritePatternSet;
1112
class OpBuilder;
@@ -38,7 +39,7 @@ std::unique_ptr<Pass> createOwnershipBasedBufferDeallocationPass(
3839
DeallocationOptions options = DeallocationOptions());
3940

4041
/// Creates a pass that finds all temporary allocations
41-
/// and attempts to move the deallocation after the last user/dependency
42+
/// and attempts to move the deallocation after the last user/dependency
4243
/// of the allocation, thereby optimizing allocation liveness.
4344
std::unique_ptr<Pass> createOptimizeAllocationLivenessPass();
4445

@@ -157,6 +158,12 @@ std::unique_ptr<Pass> createBufferLoopHoistingPass();
157158
// Options struct for BufferResultsToOutParams pass.
158159
// Note: defined only here, not in tablegen.
159160
struct BufferResultsToOutParamsOpts {
161+
/// Allocator function: Generate a memref allocation with the given type.
162+
/// Since `promoteBufferResultsToOutParams` doesn't allow dynamically shaped
163+
/// results, we don't allow passing a range of values for dynamic dims.
164+
using AllocationFn =
165+
std::function<FailureOr<Value>(OpBuilder &, Location, MemRefType)>;
166+
160167
/// Memcpy function: Generate a memcpy between two memrefs.
161168
using MemCpyFn =
162169
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
@@ -167,6 +174,10 @@ struct BufferResultsToOutParamsOpts {
167174
return true;
168175
};
169176

177+
/// Allocation function; used to allocate a memref.
178+
/// If this is empty, memref.alloc is used
179+
std::optional<AllocationFn> allocationFn;
180+
170181
/// Memcpy function; used to create a copy between two memrefs.
171182
/// If this is empty, memref.copy is used.
172183
std::optional<MemCpyFn> memCpyFn;

mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ namespace bufferization {
2222
} // namespace mlir
2323

2424
using namespace mlir;
25+
using AllocationFn = bufferization::BufferResultsToOutParamsOpts::AllocationFn;
2526
using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
2627

2728
/// Return `true` if the given MemRef type has a fully dynamic layout.
@@ -141,9 +142,8 @@ static LogicalResult updateReturnOps(func::FuncOp func,
141142

142143
// Updates all CallOps in the scope of the given ModuleOp by allocating
143144
// temporary buffers for newly introduced out params.
144-
static LogicalResult
145-
updateCalls(ModuleOp module,
146-
const bufferization::BufferResultsToOutParamsOpts &options) {
145+
static LogicalResult updateCalls(ModuleOp module, AllocationFn allocationFn,
146+
std::function<bool(func::FuncOp *)> filterFn) {
147147
bool didFail = false;
148148
SymbolTable symtab(module);
149149
module.walk([&](func::CallOp op) {
@@ -154,7 +154,7 @@ updateCalls(ModuleOp module,
154154
didFail = true;
155155
return;
156156
}
157-
if (!options.filterFn(&callee))
157+
if (!filterFn(&callee))
158158
return;
159159
SmallVector<Value, 6> replaceWithNewCallResults;
160160
SmallVector<Value, 6> replaceWithOutParams;
@@ -177,7 +177,13 @@ updateCalls(ModuleOp module,
177177
auto allocType =
178178
MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
179179
AffineMap(), memrefType.getMemorySpace());
180-
Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
180+
auto maybeOutParam = allocationFn(builder, op.getLoc(), allocType);
181+
if (failed(maybeOutParam)) {
182+
op.emitError() << "failed to create allocation op";
183+
didFail = true;
184+
return;
185+
}
186+
Value outParam = maybeOutParam.value();
181187
if (!hasStaticIdentityLayout(memrefType)) {
182188
// Layout maps are already checked in `updateFuncOp`.
183189
assert(hasFullyDynamicLayoutMap(memrefType) &&
@@ -226,7 +232,13 @@ LogicalResult mlir::bufferization::promoteBufferResultsToOutParams(
226232
return failure();
227233
}
228234
}
229-
if (failed(updateCalls(module, options)))
235+
auto defaultAllocationFn = [](OpBuilder &builder, Location loc,
236+
MemRefType type) {
237+
return builder.create<memref::AllocOp>(loc, type).getResult();
238+
};
239+
if (failed(updateCalls(module,
240+
options.allocationFn.value_or(defaultAllocationFn),
241+
options.filterFn)))
230242
return failure();
231243
return success();
232244
}

0 commit comments

Comments
 (0)