Skip to content

Commit 590efce

Browse files
committed
[mlir][bufferization] Add deallocation option to remove existing dealloc operations, add option to specify the kind of alloc operations to consider
1 parent 7026a8c commit 590efce

File tree

10 files changed

+260
-7
lines changed

10 files changed

+260
-7
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERDEALLOCATIONOPINTERFACE_H_
1111

1212
#include "mlir/Analysis/Liveness.h"
13+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
14+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1315
#include "mlir/IR/Operation.h"
1416
#include "mlir/IR/SymbolTable.h"
1517
#include "mlir/Support/LLVM.h"
@@ -92,6 +94,9 @@ class Ownership {
9294

9395
/// Options for BufferDeallocationOpInterface-based buffer deallocation.
9496
struct DeallocationOptions {
97+
using DetectionFn = std::function<bool(Operation *)>;
98+
using ReplaceDeallocFn = std::function<FailureOr<ValueRange>(Operation *)>;
99+
95100
// A pass option indicating whether private functions should be modified to
96101
// pass the ownership of MemRef values instead of adhering to the function
97102
// boundary ABI.
@@ -106,6 +111,48 @@ struct DeallocationOptions {
106111
/// to, an error will already be emitted at compile time. This cannot be
107112
/// changed with this option.
108113
bool verifyFunctionBoundaryABI = true;
114+
115+
/// Given an allocation side-effect on the passed operation, determine whether
116+
/// this allocation operation is of relevance (i.e., should assign ownership
117+
/// to the allocated value). If it is determined to not be relevant,
118+
/// ownership will be set to 'false', i.e., it will be leaked. This is useful
119+
/// to support deallocation of multiple different kinds of allocation ops.
120+
DetectionFn isRelevantAllocOp = [](Operation *op) {
121+
return isa<memref::MemRefDialect, bufferization::BufferizationDialect>(
122+
op->getDialect());
123+
};
124+
125+
/// Given a free side-effect on the passed operation, determine whether this
126+
/// deallocation operation is of relevance (i.e., should be removed if the
127+
/// `removeExistingDeallocations` option is enabled or otherwise an error
128+
/// should be emitted because existing deallocation operations are not
129+
/// supported without that flag). If it is determined to not be relevant,
130+
/// the operation will be ignored. This is useful to support deallocation of
131+
/// multiple different kinds of allocation ops where deallocations for some of
132+
/// them are already present in the IR.
133+
DetectionFn isRelevantDeallocOp = [](Operation *op) {
134+
return isa<memref::MemRefDialect, bufferization::BufferizationDialect>(
135+
op->getDialect());
136+
};
137+
138+
/// When enabled, remove deallocation operations determined to be relevant
139+
/// according to `isRelevantDeallocOp`. `getDeallocReplacement` is called for
140+
/// each such operation to obtain the SSA values that replace its results (an
141+
/// empty range if the operation has no results).
142+
bool removeExistingDeallocations = false;
143+
144+
/// Provides SSA values for deallocation operations when
145+
/// `removeExistingDeallocations` is enabled. May return a failure when the
146+
/// given deallocation operation is not supported (e.g., because no
147+
/// replacement for a result value can be determined). A failure will directly
148+
/// lead to a failure emitted by the deallocation pass.
149+
ReplaceDeallocFn getDeallocReplacement =
150+
[](Operation *op) -> FailureOr<ValueRange> {
151+
if (isa<memref::DeallocOp>(op))
152+
return ValueRange{};
153+
// ReallocOp has to be expanded before running the dealloc pass.
154+
return failure();
155+
};
109156
};
110157

111158
/// This class collects all the state that we need to perform the buffer

mlir/include/mlir/Dialect/Bufferization/Pipelines/Passes.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ struct BufferDeallocationPipelineOptions
4040
"statically that the ABI is not adhered to, an error will already be "
4141
"emitted at compile time. This cannot be changed with this option."),
4242
llvm::cl::init(true)};
43+
PassOptions::Option<bool> removeExistingDeallocations{
44+
*this, "remove-existing-deallocations",
45+
llvm::cl::desc("Removes all pre-existing memref.dealloc operations and "
46+
"insert all deallocations according to the buffer "
47+
"deallocation rules."),
48+
llvm::cl::init(false)};
4349

4450
/// Convert this BufferDeallocationPipelineOptions struct to a
4551
/// DeallocationOptions struct to be passed to the

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,11 @@ def OwnershipBasedBufferDeallocation : Pass<
233233
"If it can be determined statically that the ABI is not adhered "
234234
"to, an error will already be emitted at compile time. This cannot "
235235
"be changed with this option.">,
236+
Option<"removeExistingDeallocations", "remove-existing-deallocations",
237+
"bool", /*default=*/"false",
238+
"Remove already existing MemRef deallocation operations and let the "
239+
"deallocation pass insert the deallocation operations according to "
240+
"its rules.">,
236241
];
237242
let constructor = "mlir::bufferization::createOwnershipBasedBufferDeallocationPass()";
238243

mlir/lib/Dialect/Bufferization/Pipelines/BufferizationPipelines.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ BufferDeallocationPipelineOptions::asDeallocationOptions() const {
2727
DeallocationOptions opts;
2828
opts.privateFuncDynamicOwnership = privateFunctionDynamicOwnership.getValue();
2929
opts.verifyFunctionBoundaryABI = verifyFunctionBoundaryABI.getValue();
30+
opts.removeExistingDeallocations = removeExistingDeallocations.getValue();
3031
return opts;
3132
}
3233

mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -816,15 +816,28 @@ FailureOr<Operation *>
816816
BufferDeallocation::handleInterface(MemoryEffectOpInterface op) {
817817
auto *block = op->getBlock();
818818

819-
for (auto operand : llvm::make_filter_range(op->getOperands(), isMemref))
820-
if (op.getEffectOnValue<MemoryEffects::Free>(operand).has_value())
819+
for (auto operand : llvm::make_filter_range(op->getOperands(), isMemref)) {
  // Only free side-effects on ops the client classified as relevant
  // deallocations are handled here; everything else is left untouched.
  if (!op.getEffectOnValue<MemoryEffects::Free>(operand).has_value() ||
      !options.isRelevantDeallocOp(op))
    continue;

  // The replacement callback is user-provided and may create IR, so it is
  // only invoked when removal of existing deallocations was requested.
  if (options.removeExistingDeallocations) {
    if (auto repl = options.getDeallocReplacement(op); succeeded(repl)) {
      // Reroute all users of the dealloc's results, then drop the op. A null
      // operation is returned because `op` no longer exists.
      op->replaceAllUsesWith(repl.value());
      op.erase();
      return FailureOr<Operation *>(nullptr);
    }
  }

  // Either removal is disabled or no replacement could be determined.
  return op->emitError(
      "memory free side-effect on MemRef value not supported!");
}
823833

824834
OpBuilder builder = OpBuilder::atBlockBegin(block);
825835
for (auto res : llvm::make_filter_range(op->getResults(), isMemref)) {
826836
auto allocEffect = op.getEffectOnValue<MemoryEffects::Allocate>(res);
827837
if (allocEffect.has_value()) {
838+
// Assuming that an alloc effect is interpreted as MUST and not MAY.
839+
state.resetOwnerships(res, block);
840+
828841
if (isa<SideEffects::AutomaticAllocationScopeResource>(
829842
allocEffect->getResource())) {
830843
// Make sure that the ownership of auto-managed allocations is set to
@@ -839,8 +852,15 @@ BufferDeallocation::handleInterface(MemoryEffectOpInterface op) {
839852
continue;
840853
}
841854

842-
state.updateOwnership(res, buildBoolValue(builder, op.getLoc(), true));
843-
state.addMemrefToDeallocate(res, block);
855+
if (options.isRelevantAllocOp(op)) {
856+
state.updateOwnership(res, buildBoolValue(builder, op.getLoc(), true));
857+
state.addMemrefToDeallocate(res, block);
858+
continue;
859+
}
860+
861+
// Alloc operations from other dialects are expected to have matching
862+
// deallocation operations inserted by another pass.
863+
state.updateOwnership(res, buildBoolValue(builder, op.getLoc(), false));
844864
}
845865
}
846866

@@ -943,16 +963,18 @@ struct OwnershipBasedBufferDeallocationPass
943963
: public bufferization::impl::OwnershipBasedBufferDeallocationBase<
944964
OwnershipBasedBufferDeallocationPass> {
945965
OwnershipBasedBufferDeallocationPass() = default;
946-
OwnershipBasedBufferDeallocationPass(const DeallocationOptions &options)
947-
: OwnershipBasedBufferDeallocationPass() {
966+
OwnershipBasedBufferDeallocationPass(const DeallocationOptions &options) {
948967
privateFuncDynamicOwnership.setValue(options.privateFuncDynamicOwnership);
949968
verifyFunctionBoundaryABI.setValue(options.verifyFunctionBoundaryABI);
969+
removeExistingDeallocations.setValue(options.removeExistingDeallocations);
950970
}
951971
void runOnOperation() override {
952972
DeallocationOptions options;
953973
options.privateFuncDynamicOwnership =
954974
privateFuncDynamicOwnership.getValue();
955975
options.verifyFunctionBoundaryABI = verifyFunctionBoundaryABI.getValue();
976+
options.removeExistingDeallocations =
977+
removeExistingDeallocations.getValue();
956978

957979
auto status = getOperation()->walk([&](func::FuncOp func) {
958980
if (func.isExternal())
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: mlir-opt %s --test-ownership-based-buffer-deallocation -split-input-file | FileCheck %s
2+
3+
// Test input: one buffer from `memref.alloc` and one from `gpu.alloc`, an
// `arith.select` between the two, and a pre-existing deallocation for each
// buffer. The expectations below require the `gpu.dealloc async` to be
// replaced by a `gpu.wait async` and both buffers to be released
// conditionally through `memref.dealloc`.
func.func @mixed_allocations(%cond: i1) -> (memref<f32>, !gpu.async.token) {
4+
%a1 = memref.alloc() : memref<f32>
5+
%a2 = gpu.alloc() : memref<f32>
6+
%0 = arith.select %cond, %a1, %a2 : memref<f32>
7+
%token = gpu.dealloc async [] %a2 : memref<f32>
8+
memref.dealloc %a1 : memref<f32>
9+
return %0, %token : memref<f32>, !gpu.async.token
10+
}
11+
12+
// CHECK: [[A1:%.+]] = memref.alloc(
13+
// CHECK: [[A2:%.+]] = gpu.alloc
14+
// CHECK: [[SELECT:%.+]] = arith.select {{.*}}, [[A1]], [[A2]]
15+
// CHECK: [[TOKEN:%.+]] = gpu.wait async
16+
// CHECK: [[A1_BASE:%.+]],{{.*}} = memref.extract_strided_metadata [[A1]]
17+
// CHECK: [[A1_PTR:%.+]] = memref.extract_aligned_pointer_as_index [[A1_BASE]]
18+
// CHECK: [[SELECT_PTR:%.+]] = memref.extract_aligned_pointer_as_index [[SELECT]]
19+
// CHECK: [[ALIAS0:%.+]] = arith.cmpi ne, [[A1_PTR]], [[SELECT_PTR]]
20+
// CHECK: [[COND0:%.+]] = arith.andi [[ALIAS0]], %true
21+
// CHECK: scf.if [[COND0]] {
22+
// CHECK: memref.dealloc [[A1_BASE]]
23+
// CHECK: }
24+
// CHECK: [[A2_BASE:%.+]],{{.*}} = memref.extract_strided_metadata [[A2]]
25+
// CHECK: [[A2_PTR:%.+]] = memref.extract_aligned_pointer_as_index [[A2_BASE]]
26+
// CHECK: [[SELECT_PTR:%.+]] = memref.extract_aligned_pointer_as_index [[SELECT]]
27+
// CHECK: [[ALIAS1:%.+]] = arith.cmpi ne, [[A2_PTR]], [[SELECT_PTR]]
28+
// CHECK: [[COND1:%.+]] = arith.andi [[ALIAS1]], %true
29+
// CHECK: scf.if [[COND1]] {
30+
// TODO: add pass option to lower-deallocation to insert gpu.dealloc here
31+
// CHECK: memref.dealloc [[A2_BASE]]
32+
// CHECK: }
33+
// CHECK: return [[SELECT]], [[TOKEN]]

mlir/test/lib/Dialect/Bufferization/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,12 @@ add_mlir_library(MLIRBufferizationTestPasses
55
EXCLUDE_FROM_LIBMLIR
66

77
LINK_LIBS PUBLIC
8+
MLIRArithDialect
89
MLIRBufferizationDialect
910
MLIRBufferizationTransforms
11+
MLIRFuncDialect
12+
MLIRGPUDialect
13+
MLIRSCFDialect
1014
MLIRIR
1115
MLIRPass
1216
)
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
//===- TestOwnershipBasedBufferDeallocation.cpp -----------------*- c++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
10+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
11+
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
13+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
14+
#include "mlir/Dialect/SCF/IR/SCF.h"
15+
#include "mlir/Pass/Pass.h"
16+
#include "mlir/Transforms/DialectConversion.h"
17+
18+
using namespace mlir;
19+
20+
namespace {
21+
/// This pass runs the ownership based deallocation pass once for `memref.alloc`
22+
/// operations, then lowers the `bufferization.dealloc` operations, and
23+
/// afterwards runs the deallocation pass again for `gpu.alloc` operations and
24+
/// lowers the inserted `bufferization.dealloc` operations again to the
25+
/// corresponding deallocation operations.
26+
struct TestOwnershipBasedBufferDeallocationPass
27+
: public PassWrapper<TestOwnershipBasedBufferDeallocationPass,
28+
OperationPass<ModuleOp>> {
29+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
30+
TestOwnershipBasedBufferDeallocationPass)
31+
32+
TestOwnershipBasedBufferDeallocationPass() = default;
33+
TestOwnershipBasedBufferDeallocationPass(
34+
const TestOwnershipBasedBufferDeallocationPass &pass)
35+
: TestOwnershipBasedBufferDeallocationPass() {}
36+
37+
void getDependentDialects(DialectRegistry &registry) const override {
38+
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
39+
scf::SCFDialect, func::FuncDialect, arith::ArithDialect>();
40+
}
41+
StringRef getArgument() const final {
42+
return "test-ownership-based-buffer-deallocation";
43+
}
44+
StringRef getDescription() const final {
45+
return "Module pass to test the Ownership-based Buffer Deallocation pass";
46+
}
47+
48+
void runOnOperation() override {
49+
ModuleOp module = getOperation();
50+
51+
// Build the library function for the lowering of `bufferization.dealloc`.
52+
OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
53+
SymbolTable symbolTable(module);
54+
func::FuncOp helper = bufferization::buildDeallocationLibraryFunction(
55+
builder, module.getLoc(), symbolTable);
56+
57+
RewritePatternSet patterns(module->getContext());
58+
bufferization::populateBufferizationDeallocLoweringPattern(patterns,
59+
helper);
60+
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
61+
62+
WalkResult result = getOperation()->walk([&](FunctionOpInterface funcOp) {
63+
// Deallocate the `memref.alloc` operations.
64+
bufferization::DeallocationOptions options;
65+
options.removeExistingDeallocations = true;
66+
if (failed(
67+
bufferization::deallocateBuffersOwnershipBased(funcOp, options)))
68+
return WalkResult::interrupt();
69+
70+
// Lower the inserted `bufferization.dealloc` operations.
71+
ConversionTarget target(getContext());
72+
target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
73+
scf::SCFDialect, func::FuncDialect>();
74+
target.addIllegalOp<bufferization::DeallocOp>();
75+
76+
if (failed(applyPartialConversion(funcOp, target, frozenPatterns)))
77+
return WalkResult::interrupt();
78+
79+
// Deallocate the `gpu.alloc` operations.
80+
options.isRelevantAllocOp = [](Operation *op) {
81+
return isa<gpu::GPUDialect>(op->getDialect());
82+
};
83+
options.isRelevantDeallocOp = [](Operation *op) {
84+
return isa<gpu::GPUDialect>(op->getDialect());
85+
};
86+
options.getDeallocReplacement =
87+
[](Operation *op) -> FailureOr<ValueRange> {
88+
if (auto gpuDealloc = dyn_cast<gpu::DeallocOp>(op)) {
89+
if (gpuDealloc.getAsyncToken()) {
90+
OpBuilder builder(op);
91+
ValueRange token =
92+
builder
93+
.create<gpu::WaitOp>(
94+
op->getLoc(),
95+
gpu::AsyncTokenType::get(builder.getContext()),
96+
ValueRange{})
97+
.getResults();
98+
return token;
99+
}
100+
return ValueRange{};
101+
}
102+
return failure();
103+
};
104+
if (failed(
105+
bufferization::deallocateBuffersOwnershipBased(funcOp, options)))
106+
return WalkResult::interrupt();
107+
108+
// Lower the `bufferization.dealloc` operations inserted in the second
109+
// deallocation run.
110+
// TODO: they are currently also lowered to memref.dealloc, we need to
111+
// add pass options to the lowering pass that allow us to select the
112+
// dealloc operation to be inserted.
113+
if (failed(applyPartialConversion(funcOp, target, frozenPatterns)))
114+
return WalkResult::interrupt();
115+
116+
return WalkResult::advance();
117+
});
118+
if (result.wasInterrupted())
119+
signalPassFailure();
120+
}
121+
};
122+
} // namespace
123+
124+
namespace mlir::test {
125+
void registerTestOwnershipBasedBufferDeallocationPass() {
126+
PassRegistration<TestOwnershipBasedBufferDeallocationPass>();
127+
}
128+
} // namespace mlir::test

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ void registerTestMemRefStrideCalculation();
121121
void registerTestNextAccessPass();
122122
void registerTestOneToNTypeConversionPass();
123123
void registerTestOpaqueLoc();
124+
void registerTestOwnershipBasedBufferDeallocationPass();
124125
void registerTestPadFusion();
125126
void registerTestPDLByteCodePass();
126127
void registerTestPDLLPasses();
@@ -241,6 +242,7 @@ void registerTestPasses() {
241242
mlir::test::registerTestNextAccessPass();
242243
mlir::test::registerTestOneToNTypeConversionPass();
243244
mlir::test::registerTestOpaqueLoc();
245+
mlir::test::registerTestOwnershipBasedBufferDeallocationPass();
244246
mlir::test::registerTestPadFusion();
245247
mlir::test::registerTestPDLByteCodePass();
246248
mlir::test::registerTestPDLLPasses();

utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

55
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
6+
load("//llvm:lit_test.bzl", "package_path")
67
load("//mlir:build_defs.bzl", "if_cuda_available")
78
load("//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
8-
load("//llvm:lit_test.bzl", "package_path")
99

1010
package(
1111
default_visibility = ["//visibility:public"],
@@ -843,10 +843,15 @@ cc_library(
843843
defines = ["MLIR_CUDA_CONVERSIONS_ENABLED"],
844844
includes = ["lib/Dialect/Test"],
845845
deps = [
846+
"//mlir:ArithDialect",
846847
"//mlir:BufferizationDialect",
847848
"//mlir:BufferizationTransforms",
849+
"//mlir:FuncDialect",
850+
"//mlir:GPUDialect",
848851
"//mlir:IR",
849852
"//mlir:Pass",
853+
"//mlir:SCFDialect",
854+
"//mlir:Transforms",
850855
],
851856
)
852857

0 commit comments

Comments
 (0)