Skip to content

Commit c076fa1

Browse files
[mlir][bufferize] Deallocate returned buffers with BufferDeallocation
New buffer allocations can now be returned/yielded from blocks with `allow-return-allocs`. One-Shot Bufferize deallocates all buffers at the end of the block. If this is not possible (because the buffer escapes the block), this is now done by the existing BufferDeallocation pass. Differential Revision: https://reviews.llvm.org/D121527
1 parent 30adb9f commit c076fa1

File tree

11 files changed

+145
-62
lines changed

11 files changed

+145
-62
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -516,11 +516,16 @@ LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
516516
LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to,
517517
const BufferizationOptions &options);
518518

519-
/// Finalize all buffer allocations.
520-
/// * Hoist buffer allocations as much as possible.
521-
/// * Create alloc/dealloc ops as specified by the bufferization options.
522-
LogicalResult finalizeBuffers(Operation *op,
523-
const BufferizationOptions &options);
519+
/// Try to hoist all new buffer allocations until the next hoisting barrier.
520+
LogicalResult hoistBufferAllocations(Operation *op,
521+
const BufferizationOptions &options);
522+
523+
/// Create alloc/dealloc ops as specified in the bufferization options. If
524+
/// `onlyLeakingAllocs`, only those buffer allocations are processed for which no
525+
/// buffer deallocation can be created.
526+
LogicalResult createAllocDeallocOps(Operation *op,
527+
const BufferizationOptions &options,
528+
bool onlyLeakingAllocs = false);
524529
} // namespace bufferization
525530
} // namespace mlir
526531

mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@ BufferizationOptions getPartialBufferizationOptions();
9090
/// Note: This function overload is useful for extending the bufferization.
9191
LogicalResult bufferizeOp(Operation *op,
9292
BufferizationState &bufferizationState);
93+
94+
/// Finalize all buffer allocations.
95+
/// * Hoist buffer allocations as much as possible.
96+
/// * Create alloc/dealloc ops as specified by the bufferization options.
97+
LogicalResult finalizeBuffers(Operation *op,
98+
const BufferizationOptions &options);
9399
} // namespace bufferization
94100
} // namespace mlir
95101

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ struct OneShotBufferizationOptions;
1515
/// buffers.
1616
std::unique_ptr<Pass> createBufferDeallocationPass();
1717

18+
/// Run buffer deallocation.
19+
LogicalResult deallocateBuffers(Operation *op);
20+
1821
/// Creates a pass that moves allocations upwards to reduce the number of
1922
/// required copies that are inserted during the BufferDeallocation pass.
2023
std::unique_ptr<Pass> createBufferHoistingPass();
@@ -55,6 +58,9 @@ createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
5558
// Registration
5659
//===----------------------------------------------------------------------===//
5760

61+
/// Register external models for AllocationOpInterface.
62+
void registerAllocationOpInterfaceExternalModels(DialectRegistry &registry);
63+
5864
/// Generate the code for registering passes.
5965
#define GEN_PASS_REGISTRATION
6066
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,11 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
170170
example, `tensor.generate` is not in destination-passing style and always
171171
results in a new buffer allocation.
172172

173-
One-Shot Bufferize deallocates all buffers that it allocates. Yielding newly
174-
allocated buffers from a block is not supported yet and such IR will be
175-
rejected. For testing purposes and compatibility with partial bufferization,
176-
One-Shot Bufferize can be run with `allow-return-allocs=1 create-dealloc=0`
177-
to allow such IR.
173+
One-Shot Bufferize deallocates all buffers that it allocates. Returning or
174+
yielding newly allocated buffers from a block can lead to bad performance
175+
because additional buffer copies would be inserted. By default, such IR is
176+
rejected by One-Shot Bufferize. If performance is not important, such IR can
177+
be allowed with `allow-return-allocs=1`.
178178

179179
One-Shot Bufferize will by default reject IR that contains non-bufferizable
180180
ops, i.e., ops that do not implement BufferizableOpInterface. Such IR can
@@ -204,7 +204,7 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
204204
let options = [
205205
Option<"allowReturnAllocs", "allow-return-allocs", "bool",
206206
/*default=*/"false",
207-
"Allows the return of new allocations (for testing purposes only)">,
207+
"Allows returning/yielding new allocations from a block.">,
208208
Option<"allowUnknownOps", "allow-unknown-ops", "bool",
209209
/*default=*/"false",
210210
"Allows unknown (not bufferizable) ops in the input IR.">,

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def LinalgComprehensiveModuleBufferize :
4545
"Annotates IR with RaW conflicts. Requires test-analysis-only.">,
4646
Option<"allowReturnAllocs", "allow-return-allocs", "bool",
4747
/*default=*/"false",
48-
"Allows the return of new allocations (for testing purposes only)">,
48+
"Allows returning/yielding new allocations from a block.">,
4949
Option<"allowUnknownOps", "allow-unknown-ops", "bool",
5050
/*default=*/"false",
5151
"Allows unknown (not bufferizable) ops in the input IR.">,

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -508,8 +508,10 @@ LogicalResult bufferization::createMemCpy(OpBuilder &b, Location loc,
508508
return success();
509509
}
510510

511-
static LogicalResult
512-
createAllocDeallocOps(Operation *op, const BufferizationOptions &options) {
511+
LogicalResult
512+
bufferization::createAllocDeallocOps(Operation *op,
513+
const BufferizationOptions &options,
514+
bool onlyLeakingAllocs) {
513515
IRRewriter rewriter(op->getContext());
514516

515517
// Bufferization creates memref.alloca ops. After bufferization, these must be
@@ -518,7 +520,11 @@ createAllocDeallocOps(Operation *op, const BufferizationOptions &options) {
518520
// Ignore memref.alloca ops that were not created by the bufferization.
519521
if (!allocaOp->hasAttr(kBufferAllocationAttr))
520522
return WalkResult::skip();
523+
// If `onlyLeakingAllocs`, process only ops that are marked as
524+
// "skip dealloc".
521525
bool skipDealloc = allocaOp->hasAttr(kSkipDeallocAttr);
526+
if (onlyLeakingAllocs && !skipDealloc)
527+
return WalkResult::skip();
522528

523529
// Create alloc.
524530
Block *block = allocaOp->getBlock();
@@ -547,8 +553,9 @@ createAllocDeallocOps(Operation *op, const BufferizationOptions &options) {
547553

548554
/// Try to hoist all new buffer allocations until the next hoisting barrier.
549555
// TODO: Consolidate this function with the existing buffer hoisting pass.
550-
static LogicalResult
551-
hoistBufferAllocations(Operation *op, const BufferizationOptions &options) {
556+
LogicalResult
557+
bufferization::hoistBufferAllocations(Operation *op,
558+
const BufferizationOptions &options) {
552559
// Nothing to do if allocation hoisting is deactivated.
553560
if (!options.hoistAllocations)
554561
return success();
@@ -601,17 +608,6 @@ hoistBufferAllocations(Operation *op, const BufferizationOptions &options) {
601608
return success();
602609
}
603610

604-
LogicalResult
605-
bufferization::finalizeBuffers(Operation *op,
606-
const BufferizationOptions &options) {
607-
if (failed(hoistBufferAllocations(op, options)))
608-
return failure();
609-
if (failed(createAllocDeallocOps(op, options)))
610-
return failure();
611-
612-
return success();
613-
}
614-
615611
//===----------------------------------------------------------------------===//
616612
// Bufferization-specific BlockAndValueMapping support with debugging.
617613
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp

Lines changed: 55 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,15 @@ walkReturnOperations(Region *region,
7777
return success();
7878
}
7979

80-
/// Checks if all operations in a given region that have at least one attached
81-
/// region implement the RegionBranchOpInterface. This is not required in edge
82-
/// cases, where we have a single attached region and the parent operation has
83-
/// no results.
84-
static bool validateSupportedControlFlow(Region &region) {
85-
bool success = true;
86-
region.walk([&success](Operation *operation) {
80+
/// Checks if all operations that have at least one attached region implement
81+
/// the RegionBranchOpInterface. This is not required in edge cases, where we
82+
/// have a single attached region and the parent operation has no results.
83+
static bool validateSupportedControlFlow(Operation *op) {
84+
WalkResult result = op->walk([&](Operation *operation) {
85+
// Only check ops that are inside a function.
86+
if (!operation->getParentOfType<FuncOp>())
87+
return WalkResult::advance();
88+
8789
auto regions = operation->getRegions();
8890
// Walk over all operations in a region and check if the operation has at
8991
// least one region and implements the RegionBranchOpInterface. If there
@@ -96,10 +98,11 @@ static bool validateSupportedControlFlow(Region &region) {
9698
!dyn_cast<RegionBranchOpInterface>(operation)) {
9799
operation->emitError("All operations with attached regions need to "
98100
"implement the RegionBranchOpInterface.");
99-
success = false;
100101
}
102+
103+
return WalkResult::advance();
101104
});
102-
return success;
105+
return !result.wasSkipped();
103106
}
104107

105108
namespace {
@@ -639,40 +642,61 @@ struct BufferDeallocationPass : BufferDeallocationBase<BufferDeallocationPass> {
639642
void getDependentDialects(DialectRegistry &registry) const override {
640643
registry.insert<bufferization::BufferizationDialect>();
641644
registry.insert<memref::MemRefDialect>();
642-
registry.addOpInterface<memref::AllocOp, DefaultAllocationInterface>();
645+
registerAllocationOpInterfaceExternalModels(registry);
643646
}
644647

645648
void runOnOperation() override {
646649
FuncOp func = getOperation();
647650
if (func.isExternal())
648651
return;
649652

650-
// Ensure that there are supported loops only.
651-
Backedges backedges(func);
652-
if (backedges.size()) {
653-
func.emitError("Only structured control-flow loops are supported.");
654-
return signalPassFailure();
655-
}
656-
657-
// Check that the control flow structures are supported.
658-
if (!validateSupportedControlFlow(func.getRegion()))
659-
return signalPassFailure();
653+
if (failed(deallocateBuffers(func)))
654+
signalPassFailure();
655+
}
656+
};
660657

661-
// Gather all required allocation nodes and prepare the deallocation phase.
662-
BufferDeallocation deallocation(func);
658+
} // namespace
663659

664-
// Check for supported AllocationOpInterface implementations and prepare the
665-
// internal deallocation pass.
666-
if (failed(deallocation.prepare()))
667-
return signalPassFailure();
660+
LogicalResult bufferization::deallocateBuffers(Operation *op) {
661+
if (isa<ModuleOp>(op)) {
662+
WalkResult result = op->walk([&](FuncOp funcOp) {
663+
if (failed(deallocateBuffers(funcOp)))
664+
return WalkResult::interrupt();
665+
return WalkResult::advance();
666+
});
667+
return success(!result.wasInterrupted());
668+
}
668669

669-
// Place all required temporary clone and dealloc nodes.
670-
if (failed(deallocation.deallocate()))
671-
return signalPassFailure();
670+
// Ensure that there are supported loops only.
671+
Backedges backedges(op);
672+
if (backedges.size()) {
673+
op->emitError("Only structured control-flow loops are supported.");
674+
return failure();
672675
}
673-
};
674676

675-
} // namespace
677+
// Check that the control flow structures are supported.
678+
if (!validateSupportedControlFlow(op))
679+
return failure();
680+
681+
// Gather all required allocation nodes and prepare the deallocation phase.
682+
BufferDeallocation deallocation(op);
683+
684+
// Check for supported AllocationOpInterface implementations and prepare the
685+
// internal deallocation pass.
686+
if (failed(deallocation.prepare()))
687+
return failure();
688+
689+
// Place all required temporary clone and dealloc nodes.
690+
if (failed(deallocation.deallocate()))
691+
return failure();
692+
693+
return success();
694+
}
695+
696+
void bufferization::registerAllocationOpInterfaceExternalModels(
697+
DialectRegistry &registry) {
698+
registry.addOpInterface<memref::AllocOp, DefaultAllocationInterface>();
699+
}
676700

677701
//===----------------------------------------------------------------------===//
678702
// BufferDeallocationPass construction

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,9 @@ struct OneShotBufferizePass
157157
: options(options) {}
158158

159159
void getDependentDialects(DialectRegistry &registry) const override {
160-
registry.insert<bufferization::BufferizationDialect>();
160+
registry
161+
.insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
162+
registerAllocationOpInterfaceExternalModels(registry);
161163
}
162164

163165
void runOnOperation() override {
@@ -299,6 +301,21 @@ checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
299301
return success();
300302
}
301303

304+
LogicalResult
305+
bufferization::finalizeBuffers(Operation *op,
306+
const BufferizationOptions &options) {
307+
if (failed(hoistBufferAllocations(op, options)))
308+
return failure();
309+
if (failed(createAllocDeallocOps(op, options, /*onlyLeakingAllocs=*/true)))
310+
return failure();
311+
if (options.createDeallocs && failed(deallocateBuffers(op)))
312+
return failure();
313+
if (failed(createAllocDeallocOps(op, options)))
314+
return failure();
315+
316+
return success();
317+
}
318+
302319
LogicalResult bufferization::bufferizeOp(Operation *op,
303320
const AnalysisState &analysisState) {
304321
BufferizationState bufferizationState(analysisState);

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1313
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1414
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
15+
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
1516
#include "mlir/Dialect/Func/IR/FuncOps.h"
1617
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
1718
#include "mlir/Dialect/Linalg/Passes.h"
@@ -51,6 +52,7 @@ struct LinalgComprehensiveModuleBufferize
5152
vector::VectorDialect, scf::SCFDialect,
5253
arith::ArithmeticDialect, func::FuncDialect, AffineDialect>();
5354
arith::registerBufferizableOpInterfaceExternalModels(registry);
55+
bufferization::registerAllocationOpInterfaceExternalModels(registry);
5456
linalg::registerBufferizableOpInterfaceExternalModels(registry);
5557
scf::registerBufferizableOpInterfaceExternalModels(registry);
5658
std_ext::registerModuleBufferizationExternalModels(registry);

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@ func @buffer_not_deallocated(%t : tensor<?xf32>, %c : i1) -> tensor<?xf32> {
2121
} else {
2222
// CHECK: } else {
2323
// CHECK: %[[m:.*]] = bufferization.to_memref %[[t]]
24-
// CHECK: scf.yield %[[m]]
24+
// CHECK: %[[cloned:.*]] = bufferization.clone %[[m]]
25+
// CHECK: scf.yield %[[cloned]]
2526
scf.yield %t : tensor<?xf32>
2627
}
2728
// CHECK: }
28-
// CHECK-NOT: dealloc
2929
// CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
30+
// CHECK: memref.dealloc %[[r]]
3031
// CHECK: return %[[r_tensor]]
3132
return %r : tensor<?xf32>
3233
}

0 commit comments

Comments
 (0)