Skip to content

[mlir][bufferization] Buffer deallocation: skip ops that do not operate on buffers #75126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,32 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {

static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }

/// Return "true" if the given op is guaranteed to have no "Allocate" or "Free"
/// side effect.
static bool hasNoAllocateOrFreeSideEffect(Operation *op) {
if (isa<MemoryEffectOpInterface>(op))
return hasEffect<MemoryEffects::Allocate>(op) ||
hasEffect<MemoryEffects::Free>(op);
// If the op does not implement the MemoryEffectOpInterface but has has
// recursive memory effects, then this op in isolation (without its body) does
// not have any side effects. The ops inside the regions of this op will be
// processed separately.
return op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
}

/// Return "true" if the given op has buffer semantics. I.e., it has buffer
/// operands, buffer results and/or buffer region entry block arguments.
static bool hasBufferSemantics(Operation *op) {
if (llvm::any_of(op->getOperands(), isMemref) ||
llvm::any_of(op->getResults(), isMemref))
return true;
for (Region &region : op->getRegions())
if (!region.empty())
if (llvm::any_of(region.front().getArguments(), isMemref))
return true;
return false;
}

//===----------------------------------------------------------------------===//
// Backedges analysis
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -462,31 +488,6 @@ BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder, Value memref,
return state.getMemrefWithUniqueOwnership(builder, memref, block);
}

static bool regionOperatesOnMemrefValues(Region &region) {
auto checkBlock = [](Block *block) {
if (llvm::any_of(block->getArguments(), isMemref))
return WalkResult::interrupt();
for (Operation &op : *block) {
if (llvm::any_of(op.getOperands(), isMemref))
return WalkResult::interrupt();
if (llvm::any_of(op.getResults(), isMemref))
return WalkResult::interrupt();
}
return WalkResult::advance();
};
WalkResult result = region.walk(checkBlock);
if (result.wasInterrupted())
return true;

// Note: Block::walk/Region::walk visits only blocks that are nested under
// nested operations, but not direct children.
for (Block &block : region)
if (checkBlock(&block).wasInterrupted())
return true;

return false;
}

LogicalResult
BufferDeallocation::verifyFunctionPreconditions(FunctionOpInterface op) {
// (1) Ensure that there are supported loops only (no explicit control flow
Expand All @@ -512,9 +513,8 @@ LogicalResult BufferDeallocation::verifyOperationPreconditions(Operation *op) {
size_t size = regions.size();
if (((size == 1 && !op->getResults().empty()) || size > 1) &&
!dyn_cast<RegionBranchOpInterface>(op)) {
if (llvm::any_of(regions, regionOperatesOnMemrefValues))
return op->emitError("All operations with attached regions need to "
"implement the RegionBranchOpInterface.");
return op->emitError("All operations with attached regions need to "
"implement the RegionBranchOpInterface.");
}

// (2) The pass does not work properly when deallocations are already present.
Expand Down Expand Up @@ -648,6 +648,12 @@ LogicalResult BufferDeallocation::deallocate(Block *block) {
// For each operation in the block, handle the interfaces that affect aliasing
// and ownership of memrefs.
for (Operation &op : llvm::make_early_inc_range(*block)) {
// Skip ops that do not operate on buffers, have no Allocate/Free side
// effect and are not terminators. (bufferization.dealloc ops are inserted
// in front of terminators, so terminators cannot be skipped.)
if (!op.hasTrait<OpTrait::IsTerminator>() && !hasBufferSemantics(&op) &&
hasNoAllocateOrFreeSideEffect(&op))
continue;
FailureOr<Operation *> result = handleAllInterfaces(&op);
if (failed(result))
return failure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,8 +516,8 @@ func.func @assumingOp(
// does not deal with any MemRef values.

func.func @noRegionBranchOpInterface() {
%0 = "test.bar"() ({
%1 = "test.bar"() ({
%0 = "test.one_region_with_recursive_memory_effects"() ({
%1 = "test.one_region_with_recursive_memory_effects"() ({
"test.yield"() : () -> ()
}) : () -> (i32)
"test.yield"() : () -> ()
Expand All @@ -531,7 +531,7 @@ func.func @noRegionBranchOpInterface() {
// This is not allowed in buffer deallocation.

func.func @noRegionBranchOpInterface() {
%0 = "test.bar"() ({
%0 = "test.one_region_with_recursive_memory_effects"() ({
// expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
%1 = "test.bar"() ({
%2 = "test.get_memref"() : () -> memref<2xi32>
Expand All @@ -545,11 +545,10 @@ func.func @noRegionBranchOpInterface() {
// -----

// Test Case: The op "test.bar" does not implement the RegionBranchOpInterface.
// This is not allowed in buffer deallocation.
// But it also does not operate on buffers, so we don't care.

func.func @noRegionBranchOpInterface() {
// expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
%0 = "test.bar"() ({
%0 = "test.one_region_with_recursive_memory_effects"() ({
%2 = "test.get_memref"() : () -> memref<2xi32>
%3 = "test.foo"(%2) : (memref<2xi32>) -> (i32)
"test.yield"(%3) : (i32) -> ()
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,17 @@ def IsolatedRegionsOp : TEST_Op<"isolated_regions", [IsolatedFromAbove]> {
let assemblyFormat = "attr-dict-with-keyword $regions";
}

def OneRegionWithRecursiveMemoryEffectsOp
: TEST_Op<"one_region_with_recursive_memory_effects", [
RecursiveMemoryEffects]> {
let description = [{
Op that has one region and recursive side effects. The
RegionBranchOpInterface is not implemented on this op.
}];
let results = (outs AnyType:$result);
let regions = (region SizedRegion<1>:$body);
}

//===----------------------------------------------------------------------===//
// NoTerminator Operation
//===----------------------------------------------------------------------===//
Expand Down