Skip to content

[mlir][bufferization] BufferDeallocationOpInterface: support custom ownership update logic #66350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===- BufferDeallocationOpInterfaceImpl.h ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
#define MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H

namespace mlir {

class DialectRegistry;

namespace arith {
void registerBufferDeallocationOpInterfaceExternalModels(
DialectRegistry &registry);
} // namespace arith
} // namespace mlir

#endif // MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ class DeallocationState {
/// a new SSA value, returned as the first element of the pair, which has
/// 'Unique' ownership and can be used instead of the passed Value with the
/// the ownership indicator returned as the second element of the pair.
std::pair<Value, Value> getMemrefWithUniqueOwnership(OpBuilder &builder,
Value memref);
std::pair<Value, Value>
getMemrefWithUniqueOwnership(OpBuilder &builder, Value memref, Block *block);

/// Given two basic blocks and the values passed via block arguments to the
/// destination block, compute the list of MemRefs that have to be retained in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,34 @@ def BufferDeallocationOpInterface :
/*retType=*/"FailureOr<Operation *>",
/*methodName=*/"process",
/*args=*/(ins "DeallocationState &":$state,
"const DeallocationOptions &":$options)>
"const DeallocationOptions &":$options)>,
InterfaceMethod<
/*desc=*/[{
This method allows the implementing operation to specify custom logic
to materialize an ownership indicator value for the given MemRef typed
value it defines (including block arguments of nested regions). Since
the operation itself has more information about its semantics the
materialized IR can be more efficient compared to the default
implementation and avoid cloning MemRefs and/or doing alias checking
at runtime.
Note that the same logic could also be implemented in the 'process'
method above, however, the IR is always materialized then. If
it's desirable to only materialize the IR to compute an updated
ownership indicator when needed, it should be implemented using this
method (which is especially important if operations are created that
cannot be easily canonicalized away anymore).
}],
/*retType=*/"std::pair<Value, Value>",
/*methodName=*/"materializeUniqueOwnershipForMemref",
/*args=*/(ins "DeallocationState &":$state,
"const DeallocationOptions &":$options,
"OpBuilder &":$builder,
"Value":$memref),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
return state.getMemrefWithUniqueOwnership(
builder, memref, memref.getParentBlock());
}]>,
];
}

Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
Expand Down Expand Up @@ -133,6 +134,7 @@ inline void registerAllDialects(DialectRegistry &registry) {

// Register all external models.
affine::registerValueBoundsOpInterfaceExternalModels(registry);
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerValueBoundsOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;

namespace {
/// Provides custom logic to materialize ownership indicator values for the
/// result value of 'arith.select'. Instead of cloning or runtime alias
/// checking, this implementation inserts another `arith.select` to choose the
/// ownership indicator of the operand in the same way the original
/// `arith.select` chooses the MemRef operand. If at least one of the operand's
/// ownerships is 'Unknown', fall back to the default implementation.
///
/// Example:
/// ```mlir
/// // let ownership(%m0) := %o0
/// // let ownership(%m1) := %o1
/// %res = arith.select %cond, %m0, %m1
/// ```
/// The default implementation would insert a clone and replace all uses of the
/// result of `arith.select` with that clone:
/// ```mlir
/// %res = arith.select %cond, %m0, %m1
/// %clone = bufferization.clone %res
/// // let ownership(%res) := 'Unknown'
/// // let ownership(%clone) := %true
/// // replace all uses of %res with %clone
/// ```
/// This implementation, on the other hand, materializes the following:
/// ```mlir
/// %res = arith.select %cond, %m0, %m1
/// %res_ownership = arith.select %cond, %o0, %o1
/// // let ownership(%res) := %res_ownership
/// ```
struct SelectOpInterface
: public BufferDeallocationOpInterface::ExternalModel<SelectOpInterface,
arith::SelectOp> {
FailureOr<Operation *> process(Operation *op, DeallocationState &state,
const DeallocationOptions &options) const {
return op; // nothing to do
}

std::pair<Value, Value>
materializeUniqueOwnershipForMemref(Operation *op, DeallocationState &state,
const DeallocationOptions &options,
OpBuilder &builder, Value value) const {
auto selectOp = cast<arith::SelectOp>(op);
assert(value == selectOp.getResult() &&
"Value not defined by this operation");

Block *block = value.getParentBlock();
if (!state.getOwnership(selectOp.getTrueValue(), block).isUnique() ||
!state.getOwnership(selectOp.getFalseValue(), block).isUnique())
return state.getMemrefWithUniqueOwnership(builder, value,
value.getParentBlock());

Value ownership = builder.create<arith::SelectOp>(
op->getLoc(), selectOp.getCondition(),
state.getOwnership(selectOp.getTrueValue(), block).getIndicator(),
state.getOwnership(selectOp.getFalseValue(), block).getIndicator());
return {selectOp.getResult(), ownership};
}
};

} // namespace

void mlir::arith::registerBufferDeallocationOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
SelectOp::attachInterface<SelectOpInterface>(*ctx);
});
}
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRArithTransforms
BufferDeallocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
EmulateUnsupportedFloats.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ void DeallocationState::getLiveMemrefsIn(Block *block,

std::pair<Value, Value>
DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder,
Value memref) {
auto iter = ownershipMap.find({memref, memref.getParentBlock()});
Value memref, Block *block) {
auto iter = ownershipMap.find({memref, block});
assert(iter != ownershipMap.end() &&
"Value must already have been registered in the ownership map");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,13 +376,24 @@ class BufferDeallocation {
/// Given an SSA value of MemRef type, returns the same of a new SSA value
/// which has 'Unique' ownership where the ownership indicator is guaranteed
/// to be always 'true'.
Value getMemrefWithGuaranteedOwnership(OpBuilder &builder, Value memref);
Value materializeMemrefWithGuaranteedOwnership(OpBuilder &builder,
Value memref, Block *block);

/// Returns whether the given operation implements FunctionOpInterface, has
/// private visibility, and the private-function-dynamic-ownership pass option
/// is enabled.
bool isFunctionWithoutDynamicOwnership(Operation *op);

/// Given an SSA value of MemRef type, this function queries the
/// BufferDeallocationOpInterface of the defining operation of 'memref' for a
/// materialized ownership indicator for 'memref'. If the op does not
/// implement the interface or if the block for which the materialized value
/// is requested does not match the block in which 'memref' is defined, the
/// default implementation in
/// `DeallocationState::getMemrefWithUniqueOwnership` is queried instead.
std::pair<Value, Value>
materializeUniqueOwnership(OpBuilder &builder, Value memref, Block *block);

/// Checks all the preconditions for operations implementing the
/// FunctionOpInterface that have to hold for the deallocation to be
/// applicable:
Expand Down Expand Up @@ -428,6 +439,28 @@ class BufferDeallocation {
// BufferDeallocation Implementation
//===----------------------------------------------------------------------===//

std::pair<Value, Value>
BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder, Value memref,
Block *block) {
// The interface can only materialize ownership indicators in the same block
// as the defining op.
if (memref.getParentBlock() != block)
return state.getMemrefWithUniqueOwnership(builder, memref, block);

Operation *owner = memref.getDefiningOp();
if (!owner)
owner = memref.getParentBlock()->getParentOp();

// If the op implements the interface, query it for a materialized ownership
// value.
if (auto deallocOpInterface = dyn_cast<BufferDeallocationOpInterface>(owner))
return deallocOpInterface.materializeUniqueOwnershipForMemref(
state, options, builder, memref);

// Otherwise use the default implementation.
return state.getMemrefWithUniqueOwnership(builder, memref, block);
}

static bool regionOperatesOnMemrefValues(Region &region) {
WalkResult result = region.walk([](Block *block) {
if (llvm::any_of(block->getArguments(), isMemref))
Expand Down Expand Up @@ -677,11 +710,11 @@ BufferDeallocation::handleInterface(RegionBranchOpInterface op) {
return newOp.getOperation();
}

Value BufferDeallocation::getMemrefWithGuaranteedOwnership(OpBuilder &builder,
Value memref) {
Value BufferDeallocation::materializeMemrefWithGuaranteedOwnership(
OpBuilder &builder, Value memref, Block *block) {
// First, make sure we at least have 'Unique' ownership already.
std::pair<Value, Value> newMemrefAndOnwership =
state.getMemrefWithUniqueOwnership(builder, memref);
materializeUniqueOwnership(builder, memref, block);
Value newMemref = newMemrefAndOnwership.first;
Value condition = newMemrefAndOnwership.second;

Expand Down Expand Up @@ -785,7 +818,7 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
continue;
}
auto [memref, condition] =
state.getMemrefWithUniqueOwnership(builder, operand);
materializeUniqueOwnership(builder, operand, op->getBlock());
newOperands.push_back(memref);
ownershipIndicatorsToAdd.push_back(condition);
}
Expand Down Expand Up @@ -868,7 +901,8 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
if (!isMemref(val.get()))
continue;

val.set(getMemrefWithGuaranteedOwnership(builder, val.get()));
val.set(materializeMemrefWithGuaranteedOwnership(builder, val.get(),
op->getBlock()));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,15 @@ func.func @function_call_requries_merged_ownership_mid_block(%arg0: i1) {
// CHECK-NEXT: return

// CHECK-DYNAMIC-LABEL: func @function_call_requries_merged_ownership_mid_block
// CHECK-DYNAMIC-SAME: ([[ARG0:%.+]]: i1)
// CHECK-DYNAMIC: [[ALLOC0:%.+]] = memref.alloc(
// CHECK-DYNAMIC-NEXT: [[ALLOC1:%.+]] = memref.alloca(
// CHECK-DYNAMIC-NEXT: [[SELECT:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]]
// CHECK-DYNAMIC-NEXT: [[CLONE:%.+]] = bufferization.clone [[SELECT]]
// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[CLONE]], %true{{[0-9_]*}})
// CHECK-DYNAMIC-NEXT: [[SELECT:%.+]] = arith.select [[ARG0]], [[ALLOC0]], [[ALLOC1]]
// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[SELECT]], [[ARG0]])
// CHECK-DYNAMIC-NEXT: test.copy
// CHECK-DYNAMIC-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC0]], [[CLONE]], [[BASE]] :
// CHECK-DYNAMIC-SAME: if (%true{{[0-9_]*}}, %true{{[0-9_]*}}, [[RET]]#1)
// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC0]], [[BASE]] :
// CHECK-DYNAMIC-SAME: if (%true{{[0-9_]*}}, [[RET]]#1)
// CHECK-DYNAMIC-NOT: retain
// CHECK-DYNAMIC-NEXT: return

Expand Down