Skip to content

Commit 89fd111

Browse files
[mlir][bufferization] Buffer deallocation: Disallow unregistered ops
Memory side effects of unregistered ops are unknown. In particular, we do not know whether an unregistered op allocates memory or not. Therefore, unregistered ops cannot be handled safely in the buffer deallocation pass.
1 parent 9803de0 commit 89fd111

File tree

9 files changed

+150
-77
lines changed

9 files changed

+150
-77
lines changed

mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,32 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
4848

4949
static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
5050

51+
/// Return "true" if the given op is guaranteed to have neither an "Allocate"
52+
/// nor a "Free" side effect.
53+
static bool hasNoAllocateOrFreeSideEffect(Operation *op) {
54+
if (isa<MemoryEffectOpInterface>(op))
55+
return !hasEffect<MemoryEffects::Allocate>(op) &&
56+
!hasEffect<MemoryEffects::Free>(op);
57+
// If the op does not implement the MemoryEffectOpInterface but has
58+
// recursive memory effects, then this op in isolation (without its body) does
59+
// not have any side effects. The ops inside the regions of this op will be
60+
// processed separately.
61+
return op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
62+
}
63+
64+
/// Return "true" if the given op has buffer semantics, i.e., at least one
65+
/// operand, result, or entry block argument of a non-empty region is a memref.
66+
static bool hasBufferSemantics(Operation *op) {
67+
if (llvm::any_of(op->getOperands(), isMemref) ||
68+
llvm::any_of(op->getResults(), isMemref))
69+
return true;
70+
for (Region &region : op->getRegions())
71+
if (!region.empty())
72+
if (llvm::any_of(region.front().getArguments(), isMemref))
73+
return true;
74+
return false;
75+
}
76+
5177
//===----------------------------------------------------------------------===//
5278
// Backedges analysis
5379
//===----------------------------------------------------------------------===//
@@ -462,21 +488,6 @@ BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder, Value memref,
462488
return state.getMemrefWithUniqueOwnership(builder, memref, block);
463489
}
464490

/// Return "true" if any block visited by walking the given region has a memref
/// block argument or contains an op with a memref operand or memref result.
465-
static bool regionOperatesOnMemrefValues(Region &region) {
466-
WalkResult result = region.walk([](Block *block) {
467-
if (llvm::any_of(block->getArguments(), isMemref))
468-
return WalkResult::interrupt();
469-
for (Operation &op : *block) {
470-
if (llvm::any_of(op.getOperands(), isMemref))
471-
return WalkResult::interrupt();
472-
if (llvm::any_of(op.getResults(), isMemref))
473-
return WalkResult::interrupt();
474-
}
475-
return WalkResult::advance();
476-
});
477-
return result.wasInterrupted();
478-
}
479-
480491
LogicalResult
481492
BufferDeallocation::verifyFunctionPreconditions(FunctionOpInterface op) {
482493
// (1) Ensure that there are supported loops only (no explicit control flow
@@ -491,7 +502,29 @@ BufferDeallocation::verifyFunctionPreconditions(FunctionOpInterface op) {
491502
}
492503

493504
LogicalResult BufferDeallocation::verifyOperationPreconditions(Operation *op) {
494-
// (1) Check that the control flow structures are supported.
505+
// (1) The pass does not work properly when deallocations are already present.
506+
// Alternatively, we could also remove all deallocations as a pre-pass.
507+
if (isa<DeallocOp>(op))
508+
return op->emitError(
509+
"No deallocation operations must be present when running this pass!");
510+
511+
// (2) Memory side effects of unregistered ops are unknown. In particular, we
512+
// do not know whether an unregistered op allocates memory or not. Call ops
513+
// typically do not implement the MemoryEffectOpInterface but usually do not
514+
// have side effects (apart from the callee, which will be analyzed
515+
// separately), so they are allowed.
516+
if (!isa<MemoryEffectOpInterface>(op) &&
517+
!op->hasTrait<OpTrait::HasRecursiveMemoryEffects>() &&
518+
!isa<CallOpInterface>(op))
519+
return op->emitError(
520+
"ops with unknown memory side effects are not supported");
521+
522+
// We do not care about ops that do not operate on buffers and have no
523+
// Allocate/Free side effect.
524+
if (!hasBufferSemantics(op) && hasNoAllocateOrFreeSideEffect(op))
525+
return success();
526+
527+
// (3) Check that the control flow structures are supported.
495528
auto regions = op->getRegions();
496529
// Check that if the operation has at
497530
// least one region it implements the RegionBranchOpInterface. If there
@@ -502,17 +535,10 @@ LogicalResult BufferDeallocation::verifyOperationPreconditions(Operation *op) {
502535
size_t size = regions.size();
503536
if (((size == 1 && !op->getResults().empty()) || size > 1) &&
504537
!dyn_cast<RegionBranchOpInterface>(op)) {
505-
if (llvm::any_of(regions, regionOperatesOnMemrefValues))
506-
return op->emitError("All operations with attached regions need to "
507-
"implement the RegionBranchOpInterface.");
538+
return op->emitError("All operations with attached regions need to "
539+
"implement the RegionBranchOpInterface.");
508540
}
509541

510-
// (2) The pass does not work properly when deallocations are already present.
511-
// Alternatively, we could also remove all deallocations as a pre-pass.
512-
if (isa<DeallocOp>(op))
513-
return op->emitError(
514-
"No deallocation operations must be present when running this pass!");
515-
516542
// (4) Check that terminators with more than one successor except `cf.cond_br`
517543
// are not present and that either BranchOpInterface or
518544
// RegionBranchTerminatorOpInterface is implemented.

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ func.func @select_captured_in_next_block(%arg0: index, %arg1: memref<?xi8>, %arg
570570
func.func @blocks_not_preordered_by_dominance() {
571571
cf.br ^bb1
572572
^bb2:
573-
"test.memref_user"(%alloc) : (memref<2xi32>) -> ()
573+
"test.read_buffer"(%alloc) : (memref<2xi32>) -> ()
574574
return
575575
^bb1:
576576
%alloc = memref.alloc() : memref<2xi32>
@@ -581,7 +581,7 @@ func.func @blocks_not_preordered_by_dominance() {
581581
// CHECK-NEXT: [[TRUE:%.+]] = arith.constant true
582582
// CHECK-NEXT: cf.br [[BB1:\^.+]]
583583
// CHECK-NEXT: [[BB2:\^[a-zA-Z0-9_]+]]:
584-
// CHECK-NEXT: "test.memref_user"([[ALLOC:%[a-zA-Z0-9_]+]])
584+
// CHECK-NEXT: "test.read_buffer"([[ALLOC:%[a-zA-Z0-9_]+]])
585585
// CHECK-NEXT: bufferization.dealloc ([[ALLOC]] : {{.*}}) if ([[TRUE]])
586586
// CHECK-NOT: retain
587587
// CHECK-NEXT: return

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-existing-deallocs.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ func.func @auto_dealloc() {
88
%c100 = arith.constant 100 : index
99
%alloc = memref.alloc(%c10) : memref<?xi32>
1010
%realloc = memref.realloc %alloc(%c100) : memref<?xi32> to memref<?xi32>
11-
"test.memref_user"(%realloc) : (memref<?xi32>) -> ()
11+
"test.read_buffer"(%realloc) : (memref<?xi32>) -> ()
1212
return
1313
}
1414

@@ -17,7 +17,7 @@ func.func @auto_dealloc() {
1717
// CHECK-NOT: bufferization.dealloc
1818
// CHECK: [[V0:%.+]]:2 = scf.if
1919
// CHECK-NOT: bufferization.dealloc
20-
// CHECK: test.memref_user
20+
// CHECK: test.read_buffer
2121
// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[V0]]#0
2222
// CHECK-NEXT: bufferization.dealloc ([[ALLOC]], [[BASE]] :{{.*}}) if (%true{{[0-9_]*}}, [[V0]]#1)
2323
// CHECK-NEXT: return
@@ -32,14 +32,14 @@ func.func @auto_dealloc_inside_nested_region(%arg0: memref<?xi32>, %arg1: i1) {
3232
} else {
3333
scf.yield %arg0 : memref<?xi32>
3434
}
35-
"test.memref_user"(%0) : (memref<?xi32>) -> ()
35+
"test.read_buffer"(%0) : (memref<?xi32>) -> ()
3636
return
3737
}
3838

3939
// CHECK-LABEL: func @auto_dealloc_inside_nested_region
4040
// CHECK-SAME: (%arg0: memref<?xi32>, %arg1: i1)
4141
// CHECK-NOT: dealloc
42-
// CHECK: "test.memref_user"([[V0:%.+]]#0)
42+
// CHECK: "test.read_buffer"([[V0:%.+]]#0)
4343
// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
4444
// CHECK-NEXT: bufferization.dealloc ([[BASE]] : memref<i32>) if ([[V0]]#1)
4545
// CHECK-NEXT: return

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-function-boundaries.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
func.func private @emptyUsesValue(%arg0: memref<4xf32>) {
1414
%0 = memref.alloc() : memref<4xf32>
15-
"test.memref_user"(%0) : (memref<4xf32>) -> ()
15+
"test.read_buffer"(%0) : (memref<4xf32>) -> ()
1616
return
1717
}
1818

@@ -37,7 +37,7 @@ func.func private @emptyUsesValue(%arg0: memref<4xf32>) {
3737

3838
func.func @emptyUsesValue(%arg0: memref<4xf32>) {
3939
%0 = memref.alloc() : memref<4xf32>
40-
"test.memref_user"(%0) : (memref<4xf32>) -> ()
40+
"test.read_buffer"(%0) : (memref<4xf32>) -> ()
4141
return
4242
}
4343

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,20 @@ func.func private @no_interface_no_operands(%t : tensor<?x?x?xf16>) -> memref<?x
2121
// CHECK-LABEL: func private @no_interface(
2222
// CHECK: %[[true:.*]] = arith.constant true
2323
// CHECK: %[[alloc:.*]] = memref.alloc
24-
// CHECK: %[[foo:.*]] = "test.foo"(%[[alloc]])
24+
// CHECK: %[[foo:.*]] = "test.forward_buffer"(%[[alloc]])
2525
// CHECK: %[[r:.*]] = bufferization.dealloc (%[[alloc]] : {{.*}}) if (%[[true]]) retain (%[[foo]] : {{.*}})
2626
// CHECK: return %[[foo]]
2727
func.func private @no_interface() -> memref<5xf32> {
2828
%0 = memref.alloc() : memref<5xf32>
29-
%1 = "test.foo"(%0) : (memref<5xf32>) -> (memref<5xf32>)
29+
%1 = "test.forward_buffer"(%0) : (memref<5xf32>) -> (memref<5xf32>)
3030
return %1 : memref<5xf32>
3131
}
32+
33+
// -----
34+
35+
func.func @no_side_effects() {
36+
%0 = memref.alloc() : memref<5xf32>
37+
// expected-error @below{{ops with unknown memory side effects are not supported}}
38+
"test.unregistered_op_foo"(%0) : (memref<5xf32>) -> ()
39+
return
40+
}

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir

Lines changed: 30 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func.func @nested_region_control_flow(
7171
scf.yield %1 : memref<?x?xf32>
7272
} else {
7373
%3 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
74-
"test.memref_user"(%3) : (memref<?x?xf32>) -> ()
74+
"test.read_buffer"(%3) : (memref<?x?xf32>) -> ()
7575
scf.yield %1 : memref<?x?xf32>
7676
}
7777
return %2 : memref<?x?xf32>
@@ -253,7 +253,7 @@ func.func @loop_alloc(
253253
%buf: memref<2xf32>,
254254
%res: memref<2xf32>) {
255255
%0 = memref.alloc() : memref<2xf32>
256-
"test.memref_user"(%0) : (memref<2xf32>) -> ()
256+
"test.read_buffer"(%0) : (memref<2xf32>) -> ()
257257
%1 = scf.for %i = %lb to %ub step %step
258258
iter_args(%iterBuf = %buf) -> memref<2xf32> {
259259
%2 = arith.cmpi eq, %i, %ub : index
@@ -385,15 +385,15 @@ func.func @loop_nested_alloc(
385385
%buf: memref<2xf32>,
386386
%res: memref<2xf32>) {
387387
%0 = memref.alloc() : memref<2xf32>
388-
"test.memref_user"(%0) : (memref<2xf32>) -> ()
388+
"test.read_buffer"(%0) : (memref<2xf32>) -> ()
389389
%1 = scf.for %i = %lb to %ub step %step
390390
iter_args(%iterBuf = %buf) -> memref<2xf32> {
391391
%2 = scf.for %i2 = %lb to %ub step %step
392392
iter_args(%iterBuf2 = %iterBuf) -> memref<2xf32> {
393393
%3 = scf.for %i3 = %lb to %ub step %step
394394
iter_args(%iterBuf3 = %iterBuf2) -> memref<2xf32> {
395395
%4 = memref.alloc() : memref<2xf32>
396-
"test.memref_user"(%4) : (memref<2xf32>) -> ()
396+
"test.read_buffer"(%4) : (memref<2xf32>) -> ()
397397
%5 = arith.cmpi eq, %i, %ub : index
398398
%6 = scf.if %5 -> (memref<2xf32>) {
399399
%7 = memref.alloc() : memref<2xf32>
@@ -476,7 +476,7 @@ func.func @assumingOp(
476476
// Confirm the alloc will be dealloc'ed in the block.
477477
%1 = shape.assuming %arg0 -> memref<2xf32> {
478478
%0 = memref.alloc() : memref<2xf32>
479-
"test.memref_user"(%0) : (memref<2xf32>) -> ()
479+
"test.read_buffer"(%0) : (memref<2xf32>) -> ()
480480
shape.assuming_yield %arg2 : memref<2xf32>
481481
}
482482
// Confirm the alloc will be returned and dealloc'ed after its use.
@@ -511,58 +511,49 @@ func.func @assumingOp(
511511

512512
// -----
513513

514-
// Test Case: The op "test.bar" does not implement the RegionBranchOpInterface.
515-
// This is only allowed in buffer deallocation because the operation's region
516-
// does not deal with any MemRef values.
514+
// Test Case: The op "test.one_region_with_recursive_memory_effects" does not
515+
// implement the RegionBranchOpInterface. This is allowed during buffer
516+
// deallocation because the operation's region does not deal with any MemRef
517+
// values.
517518

518519
func.func @noRegionBranchOpInterface() {
519-
%0 = "test.bar"() ({
520-
%1 = "test.bar"() ({
521-
"test.yield"() : () -> ()
520+
%0 = "test.one_region_with_recursive_memory_effects"() ({
521+
%1 = "test.one_region_with_recursive_memory_effects"() ({
522+
%2 = memref.alloc() : memref<2xi32>
523+
"test.read_buffer"(%2) : (memref<2xi32>) -> ()
524+
"test.return"() : () -> ()
522525
}) : () -> (i32)
523-
"test.yield"() : () -> ()
526+
"test.return"() : () -> ()
524527
}) : () -> (i32)
525-
"test.terminator"() : () -> ()
528+
"test.return"() : () -> ()
526529
}
527530

528531
// -----
529532

530-
// Test Case: The op "test.bar" does not implement the RegionBranchOpInterface.
531-
// This is not allowed in buffer deallocation.
533+
// Test Case: The second op "test.one_region_with_recursive_memory_effects" does
534+
// not implement the RegionBranchOpInterface but has buffer semantics. This is
535+
// not allowed during buffer deallocation.
532536

533537
func.func @noRegionBranchOpInterface() {
534-
%0 = "test.bar"() ({
538+
%0 = "test.one_region_with_recursive_memory_effects"() ({
535539
// expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
536-
%1 = "test.bar"() ({
537-
%2 = "test.get_memref"() : () -> memref<2xi32>
538-
"test.yield"(%2) : (memref<2xi32>) -> ()
540+
%1 = "test.one_region_with_recursive_memory_effects"() ({
541+
%2 = memref.alloc() : memref<2xi32>
542+
"test.read_buffer"(%2) : (memref<2xi32>) -> ()
543+
"test.return"(%2) : (memref<2xi32>) -> ()
539544
}) : () -> (memref<2xi32>)
540-
"test.yield"() : () -> ()
545+
"test.return"() : () -> ()
541546
}) : () -> (i32)
542-
"test.terminator"() : () -> ()
543-
}
544-
545-
// -----
546-
547-
// Test Case: The op "test.bar" does not implement the RegionBranchOpInterface.
548-
// This is not allowed in buffer deallocation.
549-
550-
func.func @noRegionBranchOpInterface() {
551-
// expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
552-
%0 = "test.bar"() ({
553-
%2 = "test.get_memref"() : () -> memref<2xi32>
554-
%3 = "test.foo"(%2) : (memref<2xi32>) -> (i32)
555-
"test.yield"(%3) : (i32) -> ()
556-
}) : () -> (i32)
557-
"test.terminator"() : () -> ()
547+
"test.return"() : () -> ()
558548
}
559549

560550
// -----
561551

562552
func.func @while_two_arg(%arg0: index) {
563553
%a = memref.alloc(%arg0) : memref<?xf32>
564554
scf.while (%arg1 = %a, %arg2 = %a) : (memref<?xf32>, memref<?xf32>) -> (memref<?xf32>, memref<?xf32>) {
565-
%0 = "test.make_condition"() : () -> i1
555+
// This op has a side effect, but it's not an allocate/free side effect.
556+
%0 = "test.side_effect_op"() {effects = [{effect="read"}]} : () -> i1
566557
scf.condition(%0) %arg1, %arg2 : memref<?xf32>, memref<?xf32>
567558
} do {
568559
^bb0(%arg1: memref<?xf32>, %arg2: memref<?xf32>):
@@ -591,7 +582,8 @@ func.func @while_two_arg(%arg0: index) {
591582
func.func @while_three_arg(%arg0: index) {
592583
%a = memref.alloc(%arg0) : memref<?xf32>
593584
scf.while (%arg1 = %a, %arg2 = %a, %arg3 = %a) : (memref<?xf32>, memref<?xf32>, memref<?xf32>) -> (memref<?xf32>, memref<?xf32>, memref<?xf32>) {
594-
%0 = "test.make_condition"() : () -> i1
585+
// This op has a side effect but it's not an Allocate or Free side effect.
586+
%0 = "test.side_effect_op"() {effects = [{effect="read"}]} : () -> i1
595587
scf.condition(%0) %arg1, %arg2, %arg3 : memref<?xf32>, memref<?xf32>, memref<?xf32>
596588
} do {
597589
^bb0(%arg1: memref<?xf32>, %arg2: memref<?xf32>, %arg3: memref<?xf32>):

mlir/test/Dialect/GPU/bufferization-buffer-deallocation.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ func.func @gpu_launch() {
55
gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %c1, %arg7 = %c1, %arg8 = %c1)
66
threads(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c1, %arg11 = %c1) {
77
%alloc = memref.alloc() : memref<2xf32>
8-
"test.memref_user"(%alloc) : (memref<2xf32>) -> ()
8+
"test.read_buffer"(%alloc) : (memref<2xf32>) -> ()
99
gpu.terminator
1010
}
1111
return

mlir/test/lib/Dialect/Test/TestDialect.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,6 +1272,21 @@ static void printSumProperty(OpAsmPrinter &printer, Operation *op,
12721272
printer << second << " = " << (second + first);
12731273
}
12741274

1275+
//===----------------------------------------------------------------------===//
1276+
// Tensor/Buffer Ops
1277+
//===----------------------------------------------------------------------===//
1278+
/// Append the memory effects of this op to `effects`: one Read and one Write.
1279+
void ReadBufferOp::getEffects(
1280+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1281+
&effects) {
1282+
// The buffer operand is read: a Read effect attached to the buffer value.
1283+
effects.emplace_back(MemoryEffects::Read::get(), getBuffer(),
1284+
SideEffects::DefaultResource::get());
1285+
// The buffer contents are dumped: a Write effect that is not tied to a value.
1286+
effects.emplace_back(MemoryEffects::Write::get(),
1287+
SideEffects::DefaultResource::get());
1288+
}
1289+
12751290
//===----------------------------------------------------------------------===//
12761291
// Test Dataflow
12771292
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)