Skip to content

[mlir][bufferization] Never pass ownership to functions #80655

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions mlir/docs/Bufferization.md
Original file line number Diff line number Diff line change
Expand Up @@ -504,10 +504,8 @@ accordingly:

The following example contains a few interesting cases:
* Basic block arguments are modified to also pass along the ownership
indicator, but not for entry bocks of non-private functions (assuming the
`private-function-dynamic-ownership` pass option is disabled) where the
function boundary ABI is applied instead. "Private" in this context refers
to functions that cannot be called externally.
indicator, but not for entry blocks, where the function boundary ABI
is applied instead.
* The result of `arith.select` initially has 'Unknown' assigned as ownership,
but once the `bufferization.dealloc` operation is inserted it is put in the
'retained' list (since it has uses in a later basic block) and thus the
Expand Down Expand Up @@ -543,10 +541,7 @@ func.func @example(%memref: memref<?xi8>, %select_cond: i1, %br_cond: i1) {
After running `--ownership-based-buffer-deallocation`, it looks as follows:

```mlir
// Since this is not a private function, the signature will not be modified even
// when private-function-dynamic-ownership is enabled. Instead the function
// boundary ABI has to be applied which means that ownership of `%memref` will
// never be acquired.
// Function boundary ABI: ownership of `%memref` will never be acquired.
func.func @example(%memref: memref<?xi8>, %select_cond: i1, %br_cond: i1) {
%false = arith.constant false
%true = arith.constant true
Expand Down Expand Up @@ -602,7 +597,7 @@ func.func @example(%memref: memref<?xi8>, %select_cond: i1, %br_cond: i1) {
: memref<i8>, memref<i8>, memref<i8>)
if (%false, %not_br_cond, %false)
retain (%memref, %select : memref<?xi8>, memref<?xi8>)

// Because %select is used in ^bb1 without passing it via block argument, we
// need to update it's ownership value here by merging the ownership values
// returned by the dealloc operations
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/Bufferization/Pipelines/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ struct BufferDeallocationPipelineOptions
PassOptions::Option<bool> privateFunctionDynamicOwnership{
*this, "private-function-dynamic-ownership",
llvm::cl::desc(
"Allows to add additional arguments to private functions to "
"dynamically pass ownership of memrefs to callees. This can enable "
"earlier deallocations."),
"Allows to add additional results to private functions to return "
"ownership of returned memrefs to callers. This can avoid spurious "
"buffer clones in the callee."),
llvm::cl::init(false)};
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,11 +292,10 @@ class BufferDeallocation {
FailureOr<Operation *> handleInterface(RegionBranchOpInterface op);

/// If the private-function-dynamic-ownership pass option is enabled and the
/// called function is private, additional arguments and results are added for
/// each MemRef argument/result to pass the dynamic ownership indicator along.
/// Otherwise, updates the ownership map and list of memrefs to be deallocated
/// according to the function boundary ABI, i.e., assume ownership of all
/// returned MemRefs.
/// called function is private, additional results are added for each MemRef
/// result to pass the dynamic ownership indicator along. Otherwise, updates
/// the ownership map and list of memrefs to be deallocated according to the
/// function boundary ABI, i.e., assume ownership of all returned MemRefs.
///
/// Example (assume `private-function-dynamic-ownership` is enabled):
/// ```
Expand All @@ -309,17 +308,15 @@ class BufferDeallocation {
/// becomes
/// ```
/// func.func @f(%arg0: memref<2xi32>) -> memref<2xi32> {...}
/// func.func private @g(%arg0: memref<2xi32>) -> memref<2xi32> {...}
/// func.func private @g(%arg0: memref<2xi32>) -> (memref<2xi32>, i1) {...}
///
/// %ret_f = func.call @f(%memref) : (memref<2xi32>) -> memref<2xi32>
/// // set ownership(%ret_f) := true
/// // remember to deallocate %ret_f
///
/// // (new_memref, own) = getmemrefWithUniqueOwnership(%memref)
/// %ret_g:2 = func.call @g(new_memref, own) :
/// (memref<2xi32>, i1) -> (memref<2xi32>, i1)
/// %ret_g:2 = func.call @g(%memref) : (memref<2xi32>) -> (memref<2xi32>, i1)
/// // set ownership(%ret_g#0) := %ret_g#1
/// // remember to deallocate %ret_g
/// // remember to deallocate %ret_g if it comes with ownership
/// ```
FailureOr<Operation *> handleInterface(CallOpInterface op);

Expand Down Expand Up @@ -444,8 +441,8 @@ class BufferDeallocation {
static LogicalResult verifyOperationPreconditions(Operation *op);

/// When the 'private-function-dynamic-ownership' pass option is enabled,
/// additional `i1` arguments and return values are added for each MemRef
/// value in the function signature. This function takes care of updating the
/// additional `i1` return values are added for each MemRef result in the
/// function signature. This function takes care of updating the
/// `function_type` attribute of the function according to the actually
/// returned values from the terminators.
static LogicalResult updateFunctionSignature(FunctionOpInterface op);
Expand Down Expand Up @@ -650,7 +647,7 @@ LogicalResult BufferDeallocation::deallocate(Block *block) {

// Adhere to function boundary ABI: no ownership of function argument
// MemRefs is taken.
if (isFunctionWithoutDynamicOwnership(block->getParentOp()) &&
if (isa<FunctionOpInterface>(block->getParentOp()) &&
block->isEntryBlock()) {
Value newArg = buildBoolValue(builder, arg.getLoc(), false);
state.updateOwnership(arg, newArg);
Expand Down Expand Up @@ -838,26 +835,10 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
isPrivate = symbol.isPrivate() && !symbol.isDeclaration();

// If the private-function-dynamic-ownership option is enabled and we are
// calling a private function, we need to add an additional `i1`
// argument/result for each MemRef argument/result to dynamically pass the
// current ownership indicator rather than adhering to the function boundary
// ABI.
// calling a private function, we need to add an additional `i1` result for
// each MemRef result to dynamically pass the current ownership indicator
// rather than adhering to the function boundary ABI.
if (options.privateFuncDynamicOwnership && isPrivate) {
SmallVector<Value> newOperands, ownershipIndicatorsToAdd;
for (Value operand : op.getArgOperands()) {
if (!isMemref(operand)) {
newOperands.push_back(operand);
continue;
}
auto [memref, condition] =
materializeUniqueOwnership(builder, operand, op->getBlock());
newOperands.push_back(memref);
ownershipIndicatorsToAdd.push_back(condition);
}
newOperands.append(ownershipIndicatorsToAdd.begin(),
ownershipIndicatorsToAdd.end());
op.getArgOperandsMutable().assign(newOperands);

unsigned numMemrefs = llvm::count_if(op->getResults(), isMemref);
SmallVector<Type> ownershipTypesToAppend(numMemrefs, builder.getI1Type());
unsigned ownershipCounter = op->getNumResults();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func.func @function_call() {
// CHECK-DYNAMIC-LABEL: func @function_call()
// CHECK-DYNAMIC: [[ALLOC0:%.+]] = memref.alloc(
// CHECK-DYNAMIC-NEXT: [[ALLOC1:%.+]] = memref.alloc(
// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[ALLOC0]], %true{{[0-9_]*}}) : (memref<f64>, i1) -> (memref<f64>, i1)
// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[ALLOC0]]) : (memref<f64>) -> (memref<f64>, i1)
// CHECK-DYNAMIC-NEXT: test.copy
// CHECK-DYNAMIC-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC0]], [[ALLOC1]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, %true{{[0-9_]*}}, [[RET]]#1)
Expand Down Expand Up @@ -102,7 +102,7 @@ func.func @function_call_requries_merged_ownership_mid_block(%arg0: i1) {
// CHECK-DYNAMIC: [[ALLOC0:%.+]] = memref.alloc(
// CHECK-DYNAMIC-NEXT: [[ALLOC1:%.+]] = memref.alloca(
// CHECK-DYNAMIC-NEXT: [[SELECT:%.+]] = arith.select [[ARG0]], [[ALLOC0]], [[ALLOC1]]
// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[SELECT]], [[ARG0]])
// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[SELECT]])
// CHECK-DYNAMIC-NEXT: test.copy
// CHECK-DYNAMIC-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC0]], [[BASE]] :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@ func.func private @emptyUsesValue(%arg0: memref<4xf32>) {
// CHECK-NEXT: return

// CHECK-DYNAMIC-LABEL: func private @emptyUsesValue(
// CHECK-DYNAMIC-SAME: [[ARG0:%.+]]: memref<4xf32>, [[ARG1:%.+]]: i1)
// CHECK-DYNAMIC-SAME: [[ARG0:%.+]]: memref<4xf32>)
// CHECK-DYNAMIC: [[ALLOC:%.*]] = memref.alloc()
// CHECK-DYNAMIC: [[BASE:%[a-zA-Z0-9_]+]], {{.*}} = memref.extract_strided_metadata [[ARG0]]
// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG1]])
// CHECK-DYNAMIC-NOT: retain
// CHECK-DYNAMIC-NEXT: "test.read_buffer"
// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC]] :{{.*}}) if (%true{{[0-9_]*}})
// CHECK-DYNAMIC-NOT: retain
// CHECK-DYNAMIC-NEXT: return
Expand Down Expand Up @@ -74,13 +72,11 @@ func.func private @redundantOperations(%arg0: memref<2xf32>) {
// CHECK-NEXT: return

// CHECK-DYNAMIC-LABEL: func private @redundantOperations
// CHECK-DYNAMIC: (%[[ARG0:.*]]: memref{{.*}}, %[[ARG1:.*]]: i1)
// CHECK-DYNAMIC: (%[[ARG0:.*]]: memref{{.*}})
// CHECK-DYNAMIC: %[[FIRST_ALLOC:.*]] = memref.alloc()
// CHECK-DYNAMIC-NEXT: test.buffer_based
// CHECK-DYNAMIC: %[[SECOND_ALLOC:.*]] = memref.alloc()
// CHECK-DYNAMIC-NEXT: test.buffer_based
// CHECK-DYNAMIC-NEXT: %[[BASE:[a-zA-Z0-9_]+]], {{.*}} = memref.extract_strided_metadata %[[ARG0]]
// CHECK-DYNAMIC-NEXT: bufferization.dealloc (%[[BASE]] : {{.*}}) if (%[[ARG1]])
// CHECK-DYNAMIC-NEXT: bufferization.dealloc (%[[FIRST_ALLOC]] : {{.*}}) if (%true{{[0-9_]*}})
// CHECK-DYNAMIC-NEXT: bufferization.dealloc (%[[SECOND_ALLOC]] : {{.*}}) if (%true{{[0-9_]*}})
// CHECK-DYNAMIC-NEXT: return
Expand Down Expand Up @@ -121,14 +117,39 @@ func.func private @memref_in_function_results(

// CHECK-DYNAMIC-LABEL: func private @memref_in_function_results
// CHECK-DYNAMIC: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>,
// CHECK-DYNAMIC-SAME: %[[RESULT:.*]]: memref<5xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: i1, %[[ARG5:.*]]: i1)
// CHECK-DYNAMIC-SAME: %[[RESULT:.*]]: memref<5xf32>)
// CHECK-DYNAMIC: %[[X:.*]] = memref.alloc()
// CHECK-DYNAMIC: %[[Y:.*]] = memref.alloc()
// CHECK-DYNAMIC: test.copy
// CHECK-DYNAMIC: %[[BASE0:[a-zA-Z0-9_]+]], {{.+}} = memref.extract_strided_metadata %[[ARG0]]
// CHECK-DYNAMIC: %[[BASE1:[a-zA-Z0-9_]+]], {{.+}} = memref.extract_strided_metadata %[[RESULT]]
// CHECK-DYNAMIC: bufferization.dealloc (%[[Y]] : {{.*}}) if (%true{{[0-9_]*}})
// CHECK-DYNAMIC-NOT: retain
// CHECK-DYNAMIC: [[OWN:%.+]] = bufferization.dealloc (%[[BASE0]], %[[BASE1]] : {{.*}}) if (%[[ARG3]], %[[ARG5]]) retain (%[[ARG1]] :
// CHECK-DYNAMIC: [[OR:%.+]] = arith.ori [[OWN]], %[[ARG4]]
// CHECK-DYNAMIC: return %[[ARG1]], %[[X]], [[OR]], %true
// CHECK-DYNAMIC: return %[[ARG1]], %[[X]], %false, %true

// -----

// CHECK-DYNAMIC-LABEL: func private @private_callee(
// CHECK-DYNAMIC-SAME: %[[arg0:.*]]: memref<f32>) -> (memref<f32>, i1)
// CHECK-DYNAMIC: %[[true:.*]] = arith.constant true
// CHECK-DYNAMIC: %[[alloc:.*]] = memref.alloc() : memref<f32>
// CHECK-DYNAMIC-NOT: bufferization.dealloc
// CHECK-DYNAMIC: return %[[alloc]], %[[true]]
func.func private @private_callee(%arg0: memref<f32>) -> memref<f32> {
%alloc = memref.alloc() : memref<f32>
return %alloc : memref<f32>
}

// CHECK-DYNAMIC: func @caller() -> f32
// CHECK-DYNAMIC: %[[true:.*]] = arith.constant true
// CHECK-DYNAMIC: %[[alloc:.*]] = memref.alloc() : memref<f32>
// CHECK-DYNAMIC: %[[call:.*]]:2 = call @private_callee(%[[alloc]])
// CHECK-DYNAMIC: memref.load
// CHECK-DYNAMIC: %[[base:.*]], %[[offset:.*]] = memref.extract_strided_metadata %[[call]]#0
// CHECK-DYNAMIC: bufferization.dealloc (%[[alloc]], %[[base]] : {{.*}}) if (%[[true]], %[[call]]#1)
// CHECK-DYNAMIC-NOT: retain
func.func @caller() -> (f32) {
%alloc = memref.alloc() : memref<f32>
%ret = call @private_callee(%alloc) : (memref<f32>) -> memref<f32>

%val = memref.load %ret[] : memref<f32>
return %val : f32
}