Skip to content

Commit 687cf24

Browse files
matthias-springeragozillon
authored andcommitted
[mlir][bufferization] Never pass ownership to functions (llvm#80655)
Even when `private-function-dynamic-ownership` is set, ownership should never be passed to the callee. This can lead to double deallocs (llvm#77096) or use-after-free in the caller because ownership is currently passed regardless of whether there are any further uses of the buffer in the caller or not. Note: This is consistent with the fact that ownership is never passed to nested regions. This commit fixes llvm#77096.
1 parent cebe93f commit 687cf24

File tree

5 files changed

+56
-59
lines changed

5 files changed

+56
-59
lines changed

mlir/docs/Bufferization.md

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -504,10 +504,8 @@ accordingly:
504504

505505
The following example contains a few interesting cases:
506506
* Basic block arguments are modified to also pass along the ownership
507-
indicator, but not for entry bocks of non-private functions (assuming the
508-
`private-function-dynamic-ownership` pass option is disabled) where the
509-
function boundary ABI is applied instead. "Private" in this context refers
510-
to functions that cannot be called externally.
507+
indicator, but not for entry blocks, where the function boundary ABI
508+
is applied instead.
511509
* The result of `arith.select` initially has 'Unknown' assigned as ownership,
512510
but once the `bufferization.dealloc` operation is inserted it is put in the
513511
'retained' list (since it has uses in a later basic block) and thus the
@@ -543,10 +541,7 @@ func.func @example(%memref: memref<?xi8>, %select_cond: i1, %br_cond: i1) {
543541
After running `--ownership-based-buffer-deallocation`, it looks as follows:
544542

545543
```mlir
546-
// Since this is not a private function, the signature will not be modified even
547-
// when private-function-dynamic-ownership is enabled. Instead the function
548-
// boundary ABI has to be applied which means that ownership of `%memref` will
549-
// never be acquired.
544+
// Function boundary ABI: ownership of `%memref` will never be acquired.
550545
func.func @example(%memref: memref<?xi8>, %select_cond: i1, %br_cond: i1) {
551546
%false = arith.constant false
552547
%true = arith.constant true
@@ -602,7 +597,7 @@ func.func @example(%memref: memref<?xi8>, %select_cond: i1, %br_cond: i1) {
602597
: memref<i8>, memref<i8>, memref<i8>)
603598
if (%false, %not_br_cond, %false)
604599
retain (%memref, %select : memref<?xi8>, memref<?xi8>)
605-
600+
606601
// Because %select is used in ^bb1 without passing it via block argument, we
607602
// need to update it's ownership value here by merging the ownership values
608603
// returned by the dealloc operations

mlir/include/mlir/Dialect/Bufferization/Pipelines/Passes.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ struct BufferDeallocationPipelineOptions
2424
PassOptions::Option<bool> privateFunctionDynamicOwnership{
2525
*this, "private-function-dynamic-ownership",
2626
llvm::cl::desc(
27-
"Allows to add additional arguments to private functions to "
28-
"dynamically pass ownership of memrefs to callees. This can enable "
29-
"earlier deallocations."),
27+
"Allows to add additional results to private functions to return "
28+
"ownership of returned memrefs to callers. This can avoid spurious "
29+
"buffer clones in the callee."),
3030
llvm::cl::init(false)};
3131
};
3232

mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -292,11 +292,10 @@ class BufferDeallocation {
292292
FailureOr<Operation *> handleInterface(RegionBranchOpInterface op);
293293

294294
/// If the private-function-dynamic-ownership pass option is enabled and the
295-
/// called function is private, additional arguments and results are added for
296-
/// each MemRef argument/result to pass the dynamic ownership indicator along.
297-
/// Otherwise, updates the ownership map and list of memrefs to be deallocated
298-
/// according to the function boundary ABI, i.e., assume ownership of all
299-
/// returned MemRefs.
295+
/// called function is private, additional results are added for each MemRef
296+
/// result to pass the dynamic ownership indicator along. Otherwise, updates
297+
/// the ownership map and list of memrefs to be deallocated according to the
298+
/// function boundary ABI, i.e., assume ownership of all returned MemRefs.
300299
///
301300
/// Example (assume `private-function-dynamic-ownership` is enabled):
302301
/// ```
@@ -309,17 +308,15 @@ class BufferDeallocation {
309308
/// becomes
310309
/// ```
311310
/// func.func @f(%arg0: memref<2xi32>) -> memref<2xi32> {...}
312-
/// func.func private @g(%arg0: memref<2xi32>) -> memref<2xi32> {...}
311+
/// func.func private @g(%arg0: memref<2xi32>) -> (memref<2xi32>, i1) {...}
313312
///
314313
/// %ret_f = func.call @f(%memref) : (memref<2xi32>) -> memref<2xi32>
315314
/// // set ownership(%ret_f) := true
316315
/// // remember to deallocate %ret_f
317316
///
318-
/// // (new_memref, own) = getmemrefWithUniqueOwnership(%memref)
319-
/// %ret_g:2 = func.call @g(new_memref, own) :
320-
/// (memref<2xi32>, i1) -> (memref<2xi32>, i1)
317+
/// %ret_g:2 = func.call @g(%memref) : (memref<2xi32>) -> (memref<2xi32>, i1)
321318
/// // set ownership(%ret_g#0) := %ret_g#1
322-
/// // remember to deallocate %ret_g
319+
/// // remember to deallocate %ret_g if it comes with ownership
323320
/// ```
324321
FailureOr<Operation *> handleInterface(CallOpInterface op);
325322

@@ -444,8 +441,8 @@ class BufferDeallocation {
444441
static LogicalResult verifyOperationPreconditions(Operation *op);
445442

446443
/// When the 'private-function-dynamic-ownership' pass option is enabled,
447-
/// additional `i1` arguments and return values are added for each MemRef
448-
/// value in the function signature. This function takes care of updating the
444+
/// additional `i1` return values are added for each MemRef result in the
445+
/// function signature. This function takes care of updating the
449446
/// `function_type` attribute of the function according to the actually
450447
/// returned values from the terminators.
451448
static LogicalResult updateFunctionSignature(FunctionOpInterface op);
@@ -650,7 +647,7 @@ LogicalResult BufferDeallocation::deallocate(Block *block) {
650647

651648
// Adhere to function boundary ABI: no ownership of function argument
652649
// MemRefs is taken.
653-
if (isFunctionWithoutDynamicOwnership(block->getParentOp()) &&
650+
if (isa<FunctionOpInterface>(block->getParentOp()) &&
654651
block->isEntryBlock()) {
655652
Value newArg = buildBoolValue(builder, arg.getLoc(), false);
656653
state.updateOwnership(arg, newArg);
@@ -838,26 +835,10 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
838835
isPrivate = symbol.isPrivate() && !symbol.isDeclaration();
839836

840837
// If the private-function-dynamic-ownership option is enabled and we are
841-
// calling a private function, we need to add an additional `i1`
842-
// argument/result for each MemRef argument/result to dynamically pass the
843-
// current ownership indicator rather than adhering to the function boundary
844-
// ABI.
838+
// calling a private function, we need to add an additional `i1` result for
839+
// each MemRef result to dynamically pass the current ownership indicator
840+
// rather than adhering to the function boundary ABI.
845841
if (options.privateFuncDynamicOwnership && isPrivate) {
846-
SmallVector<Value> newOperands, ownershipIndicatorsToAdd;
847-
for (Value operand : op.getArgOperands()) {
848-
if (!isMemref(operand)) {
849-
newOperands.push_back(operand);
850-
continue;
851-
}
852-
auto [memref, condition] =
853-
materializeUniqueOwnership(builder, operand, op->getBlock());
854-
newOperands.push_back(memref);
855-
ownershipIndicatorsToAdd.push_back(condition);
856-
}
857-
newOperands.append(ownershipIndicatorsToAdd.begin(),
858-
ownershipIndicatorsToAdd.end());
859-
op.getArgOperandsMutable().assign(newOperands);
860-
861842
unsigned numMemrefs = llvm::count_if(op->getResults(), isMemref);
862843
SmallVector<Type> ownershipTypesToAppend(numMemrefs, builder.getI1Type());
863844
unsigned ownershipCounter = op->getNumResults();

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func.func @function_call() {
3333
// CHECK-DYNAMIC-LABEL: func @function_call()
3434
// CHECK-DYNAMIC: [[ALLOC0:%.+]] = memref.alloc(
3535
// CHECK-DYNAMIC-NEXT: [[ALLOC1:%.+]] = memref.alloc(
36-
// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[ALLOC0]], %true{{[0-9_]*}}) : (memref<f64>, i1) -> (memref<f64>, i1)
36+
// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[ALLOC0]]) : (memref<f64>) -> (memref<f64>, i1)
3737
// CHECK-DYNAMIC-NEXT: test.copy
3838
// CHECK-DYNAMIC-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
3939
// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC0]], [[ALLOC1]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, %true{{[0-9_]*}}, [[RET]]#1)
@@ -102,7 +102,7 @@ func.func @function_call_requries_merged_ownership_mid_block(%arg0: i1) {
102102
// CHECK-DYNAMIC: [[ALLOC0:%.+]] = memref.alloc(
103103
// CHECK-DYNAMIC-NEXT: [[ALLOC1:%.+]] = memref.alloca(
104104
// CHECK-DYNAMIC-NEXT: [[SELECT:%.+]] = arith.select [[ARG0]], [[ALLOC0]], [[ALLOC1]]
105-
// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[SELECT]], [[ARG0]])
105+
// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[SELECT]])
106106
// CHECK-DYNAMIC-NEXT: test.copy
107107
// CHECK-DYNAMIC-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
108108
// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC0]], [[BASE]] :

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-function-boundaries.mlir

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,9 @@ func.func private @emptyUsesValue(%arg0: memref<4xf32>) {
2424
// CHECK-NEXT: return
2525

2626
// CHECK-DYNAMIC-LABEL: func private @emptyUsesValue(
27-
// CHECK-DYNAMIC-SAME: [[ARG0:%.+]]: memref<4xf32>, [[ARG1:%.+]]: i1)
27+
// CHECK-DYNAMIC-SAME: [[ARG0:%.+]]: memref<4xf32>)
2828
// CHECK-DYNAMIC: [[ALLOC:%.*]] = memref.alloc()
29-
// CHECK-DYNAMIC: [[BASE:%[a-zA-Z0-9_]+]], {{.*}} = memref.extract_strided_metadata [[ARG0]]
30-
// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG1]])
31-
// CHECK-DYNAMIC-NOT: retain
29+
// CHECK-DYNAMIC-NEXT: "test.read_buffer"
3230
// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC]] :{{.*}}) if (%true{{[0-9_]*}})
3331
// CHECK-DYNAMIC-NOT: retain
3432
// CHECK-DYNAMIC-NEXT: return
@@ -74,13 +72,11 @@ func.func private @redundantOperations(%arg0: memref<2xf32>) {
7472
// CHECK-NEXT: return
7573

7674
// CHECK-DYNAMIC-LABEL: func private @redundantOperations
77-
// CHECK-DYNAMIC: (%[[ARG0:.*]]: memref{{.*}}, %[[ARG1:.*]]: i1)
75+
// CHECK-DYNAMIC: (%[[ARG0:.*]]: memref{{.*}})
7876
// CHECK-DYNAMIC: %[[FIRST_ALLOC:.*]] = memref.alloc()
7977
// CHECK-DYNAMIC-NEXT: test.buffer_based
8078
// CHECK-DYNAMIC: %[[SECOND_ALLOC:.*]] = memref.alloc()
8179
// CHECK-DYNAMIC-NEXT: test.buffer_based
82-
// CHECK-DYNAMIC-NEXT: %[[BASE:[a-zA-Z0-9_]+]], {{.*}} = memref.extract_strided_metadata %[[ARG0]]
83-
// CHECK-DYNAMIC-NEXT: bufferization.dealloc (%[[BASE]] : {{.*}}) if (%[[ARG1]])
8480
// CHECK-DYNAMIC-NEXT: bufferization.dealloc (%[[FIRST_ALLOC]] : {{.*}}) if (%true{{[0-9_]*}})
8581
// CHECK-DYNAMIC-NEXT: bufferization.dealloc (%[[SECOND_ALLOC]] : {{.*}}) if (%true{{[0-9_]*}})
8682
// CHECK-DYNAMIC-NEXT: return
@@ -121,14 +117,39 @@ func.func private @memref_in_function_results(
121117

122118
// CHECK-DYNAMIC-LABEL: func private @memref_in_function_results
123119
// CHECK-DYNAMIC: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>,
124-
// CHECK-DYNAMIC-SAME: %[[RESULT:.*]]: memref<5xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: i1, %[[ARG5:.*]]: i1)
120+
// CHECK-DYNAMIC-SAME: %[[RESULT:.*]]: memref<5xf32>)
125121
// CHECK-DYNAMIC: %[[X:.*]] = memref.alloc()
126122
// CHECK-DYNAMIC: %[[Y:.*]] = memref.alloc()
127123
// CHECK-DYNAMIC: test.copy
128-
// CHECK-DYNAMIC: %[[BASE0:[a-zA-Z0-9_]+]], {{.+}} = memref.extract_strided_metadata %[[ARG0]]
129-
// CHECK-DYNAMIC: %[[BASE1:[a-zA-Z0-9_]+]], {{.+}} = memref.extract_strided_metadata %[[RESULT]]
130124
// CHECK-DYNAMIC: bufferization.dealloc (%[[Y]] : {{.*}}) if (%true{{[0-9_]*}})
131125
// CHECK-DYNAMIC-NOT: retain
132-
// CHECK-DYNAMIC: [[OWN:%.+]] = bufferization.dealloc (%[[BASE0]], %[[BASE1]] : {{.*}}) if (%[[ARG3]], %[[ARG5]]) retain (%[[ARG1]] :
133-
// CHECK-DYNAMIC: [[OR:%.+]] = arith.ori [[OWN]], %[[ARG4]]
134-
// CHECK-DYNAMIC: return %[[ARG1]], %[[X]], [[OR]], %true
126+
// CHECK-DYNAMIC: return %[[ARG1]], %[[X]], %false, %true
127+
128+
// -----
129+
130+
// CHECK-DYNAMIC-LABEL: func private @private_callee(
131+
// CHECK-DYNAMIC-SAME: %[[arg0:.*]]: memref<f32>) -> (memref<f32>, i1)
132+
// CHECK-DYNAMIC: %[[true:.*]] = arith.constant true
133+
// CHECK-DYNAMIC: %[[alloc:.*]] = memref.alloc() : memref<f32>
134+
// CHECK-DYNAMIC-NOT: bufferization.dealloc
135+
// CHECK-DYNAMIC: return %[[alloc]], %[[true]]
136+
func.func private @private_callee(%arg0: memref<f32>) -> memref<f32> {
137+
%alloc = memref.alloc() : memref<f32>
138+
return %alloc : memref<f32>
139+
}
140+
141+
// CHECK-DYNAMIC: func @caller() -> f32
142+
// CHECK-DYNAMIC: %[[true:.*]] = arith.constant true
143+
// CHECK-DYNAMIC: %[[alloc:.*]] = memref.alloc() : memref<f32>
144+
// CHECK-DYNAMIC: %[[call:.*]]:2 = call @private_callee(%[[alloc]])
145+
// CHECK-DYNAMIC: memref.load
146+
// CHECK-DYNAMIC: %[[base:.*]], %[[offset:.*]] = memref.extract_strided_metadata %[[call]]#0
147+
// CHECK-DYNAMIC: bufferization.dealloc (%[[alloc]], %[[base]] : {{.*}}) if (%[[true]], %[[call]]#1)
148+
// CHECK-DYNAMIC-NOT: retain
149+
func.func @caller() -> (f32) {
150+
%alloc = memref.alloc() : memref<f32>
151+
%ret = call @private_callee(%alloc) : (memref<f32>) -> memref<f32>
152+
153+
%val = memref.load %ret[] : memref<f32>
154+
return %val : f32
155+
}

0 commit comments

Comments
 (0)