Skip to content

Commit 43198b0

Browse files
[mlir][bufferization] Better analysis around allocs and block arguments (#67923)
Values that are the result of buffer allocation ops are guaranteed to *not* be the same allocation as block arguments of containing blocks. This fact can be used to allow for more aggressive simplification of `bufferization.dealloc` ops.
1 parent 5317912 commit 43198b0

File tree

3 files changed

+72
-14
lines changed

3 files changed

+72
-14
lines changed

mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,19 +49,49 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
4949
return success();
5050
}
5151

52-
/// Checks if 'memref' may or must alias a MemRef in 'memrefList'. It is often a
52+
/// Given a memref value, return the "base" value by skipping over all
53+
/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
54+
// Takes `value` by copy and repeatedly reassigns it, so the caller's Value is
// not modified; the final reassigned value is what gets returned.
static Value getViewBase(Value value) {
55+
// Keep stepping while the current value is produced by an op implementing
// ViewLikeOpInterface; the loop ends at the first value whose defining op
// does not implement that interface (or that has no such defining op).
while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
56+
// Move one step up the use-def chain: from the view result to its source.
value = viewLikeOp.getViewSource();
57+
return value;
58+
}
59+
60+
/// Return "true" if the given values are guaranteed to be different (and
61+
/// non-aliasing) allocations based on the fact that one value is the result
62+
/// of an allocation and the other value is a block argument of a parent block.
63+
/// Note: This is a best-effort analysis that will eventually be replaced by a
64+
/// proper "is same allocation" analysis. This function may return "false" even
65+
/// though the two values are distinct allocations.
66+
// Implementation outline: strip view-like ops from both inputs, then test the
// "allocation result vs. block argument" condition in both argument orders.
static bool distinctAllocAndBlockArgument(Value v1, Value v2) {
67+
// Compare the underlying bases rather than views derived from them.
Value v1Base = getViewBase(v1);
68+
Value v2Base = getViewBase(v2);
69+
// One-directional check. Note that the lambda parameters shadow the outer
// `v1`/`v2`; the lambda is only ever invoked on the view bases below.
auto areDistinct = [](Value v1, Value v2) {
70+
// `v1` must be an op result (it has a defining op) ...
if (Operation *op = v1.getDefiningOp())
71+
// ... for which that op has an Allocate memory effect ...
if (hasEffect<MemoryEffects::Allocate>(op, v1))
72+
// ... and `v2` must be a block argument ...
if (auto bbArg = dyn_cast<BlockArgument>(v2))
73+
// ... whose owning block contains `op`, either directly or through an
// ancestor op (findAncestorOpInBlock returns non-null in those cases).
if (bbArg.getOwner()->findAncestorOpInBlock(*op))
74+
return true;
75+
// Any failed condition means distinctness could not be established here.
return false;
76+
};
77+
// The relation is symmetric in the inputs, so try both role assignments.
return areDistinct(v1Base, v2Base) || areDistinct(v2Base, v1Base);
78+
}
79+
80+
/// Checks if `memref` may or must alias a MemRef in `otherList`. It is often a
5381
/// requirement of optimization patterns that there cannot be any aliasing
54-
/// memref in order to perform the desired simplification. The 'allowSelfAlias'
55-
/// argument indicates whether 'memref' may be present in 'memrefList' which
82+
/// memref in order to perform the desired simplification. The `allowSelfAlias`
83+
/// argument indicates whether `memref` may be present in `otherList` which
5684
/// makes this helper function applicable to situations where we already know
57-
/// that 'memref' is in the list but also when we don't want it in the list.
85+
/// that `memref` is in the list but also when we don't want it in the list.
5886
static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
59-
ValueRange memrefList, Value memref,
87+
ValueRange otherList, Value memref,
6088
bool allowSelfAlias) {
61-
for (auto mr : memrefList) {
62-
if (allowSelfAlias && mr == memref)
89+
for (auto other : otherList) {
90+
if (allowSelfAlias && other == memref)
91+
continue;
92+
if (distinctAllocAndBlockArgument(other, memref))
6393
continue;
64-
if (!analysis.alias(mr, memref).isNo())
94+
if (!analysis.alias(other, memref).isNo())
6595
return true;
6696
}
6797
return false;

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,8 @@ func.func @loop_alloc(
270270
// CHECK: [[V0:%.+]]:2 = scf.for {{.*}} iter_args([[ARG6:%.+]] = [[ARG3]], [[ARG7:%.+]] = %false
271271
// CHECK: [[ALLOC1:%.+]] = memref.alloc()
272272
// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG6]]
273-
// CHECK: bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG7]]) retain ([[ALLOC1]] :
273+
// CHECK: bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG7]])
274+
// CHECK-NOT: retain
274275
// CHECK: scf.yield [[ALLOC1]], %true
275276
// CHECK: test.copy
276277
// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
@@ -563,8 +564,8 @@ func.func @while_two_arg(%arg0: index) {
563564
// CHECK: ^bb0([[ARG1:%.+]]: memref<?xf32>, [[ARG2:%.+]]: memref<?xf32>, [[ARG3:%.+]]: i1, [[ARG4:%.+]]: i1):
564565
// CHECK: [[ALLOC1:%.+]] = memref.alloc(
565566
// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG2]]
566-
// CHECK: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG4]]) retain ([[ARG1]], [[ALLOC1]] :
567-
// CHECK: [[OWN_AGG:%.+]] = arith.ori [[OWN]]#0, [[ARG3]]
567+
// CHECK: [[OWN:%.+]] = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG4]]) retain ([[ARG1]] :
568+
// CHECK: [[OWN_AGG:%.+]] = arith.ori [[OWN]], [[ARG3]]
568569
// CHECK: scf.yield [[ARG1]], [[ALLOC1]], [[OWN_AGG]], %true
569570
// CHECK: [[BASE0:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
570571
// CHECK: [[BASE1:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#1
@@ -594,10 +595,10 @@ func.func @while_three_arg(%arg0: index) {
594595
// CHECK: [[ALLOC1:%.+]] = memref.alloc(
595596
// CHECK: [[ALLOC2:%.+]] = memref.alloc(
596597
// CHECK: [[BASE0:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG1]]
597-
// CHECK: [[BASE1:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG2]]
598598
// CHECK: [[BASE2:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG3]]
599-
// CHECK: [[OWN:%.+]]:3 = bufferization.dealloc ([[BASE0]], [[BASE1]], [[BASE2]], [[ALLOC1]] :{{.*}}) if ([[ARG4]], [[ARG5]], [[ARG6]], %true{{[0-9_]*}}) retain ([[ALLOC2]], [[ALLOC1]], [[ARG2]] :
600-
// CHECK: scf.yield [[ALLOC2]], [[ALLOC1]], [[ARG2]], %true{{[0-9_]*}}, %true{{[0-9_]*}}, [[OWN]]#2 :
599+
// CHECK: [[OWN:%.+]] = bufferization.dealloc ([[BASE0]], [[BASE2]] :{{.*}}) if ([[ARG4]], [[ARG6]]) retain ([[ARG2]] :
600+
// CHECK: [[OWN_AGG:%.+]] = arith.ori [[OWN]], [[ARG5]]
601+
// CHECK: scf.yield [[ALLOC2]], [[ALLOC1]], [[ARG2]], %true{{[0-9_]*}}, %true{{[0-9_]*}}, [[OWN_AGG]] :
601602
// CHECK: }
602603
// CHECK: [[BASE0:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
603604
// CHECK: [[BASE1:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#1

mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,30 @@ func.func @dealloc_remove_dealloc_memref_contained_in_retained_with_const_true_c
133133
// CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG2]]
134134
// CHECK: bufferization.dealloc ([[BASE]] :{{.*}}) if (%true{{[0-9_]*}})
135135
// CHECK-NEXT: return [[ARG0]], [[ARG1]], %true{{[0-9_]*}}, %true{{[0-9_]*}} :
136+
137+
// -----
138+
139+
// Test input: a buffer allocated inside an scf.for body is combined in one
// bufferization.dealloc with the base buffer of the loop's memref iter_arg.
func.func @alloc_and_bbarg(%arg0: memref<5xf32>, %arg1: index, %arg2: index, %arg3: index) -> f32 {
140+
%true = arith.constant true
141+
%false = arith.constant false
142+
// %arg5 (memref) and %arg6 (i1) are iter_args, i.e. block arguments of the
// loop body; they start as %arg0 and %false.
%0:2 = scf.for %arg4 = %arg1 to %arg2 step %arg3 iter_args(%arg5 = %arg0, %arg6 = %false) -> (memref<5xf32>, i1) {
143+
// %alloc is defined by an allocation op nested inside the block that owns
// %arg5.
%alloc = memref.alloc() : memref<5xf32>
144+
memref.copy %arg5, %alloc : memref<5xf32> to memref<5xf32>
145+
%base_buffer_0, %offset_1, %sizes_2, %strides_3 = memref.extract_strided_metadata %arg5 : memref<5xf32> -> memref<f32>, index, index, index
146+
// %alloc appears both in the dealloc list (condition %true) and in the
// retain list. The CHECK lines for this function expect the simplified op to
// list only the base buffer of %arg5 and to have no retain clause.
%2 = bufferization.dealloc (%base_buffer_0, %alloc : memref<f32>, memref<5xf32>) if (%arg6, %true) retain (%alloc : memref<5xf32>)
147+
scf.yield %alloc, %2 : memref<5xf32>, i1
148+
}
149+
%1 = memref.load %0#0[%arg1] : memref<5xf32>
150+
%base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %0#0 : memref<5xf32> -> memref<f32>, index, index, index
151+
// Deallocate the buffer yielded by the loop, conditioned on the i1 value
// the loop yields alongside it.
bufferization.dealloc (%base_buffer : memref<f32>) if (%0#1)
152+
return %1 : f32
153+
}
154+
155+
// CHECK-LABEL: func @alloc_and_bbarg
156+
// CHECK: %[[true:.*]] = arith.constant true
157+
// CHECK: scf.for {{.*}} iter_args(%[[iter:.*]] = %{{.*}}, %{{.*}} = %{{.*}})
158+
// CHECK: %[[alloc:.*]] = memref.alloc
159+
// CHECK: %[[view:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[iter]]
160+
// CHECK: bufferization.dealloc (%[[view]] : memref<f32>)
161+
// CHECK-NOT: retain
162+
// CHECK: scf.yield %[[alloc]], %[[true]]

0 commit comments

Comments
 (0)