Skip to content

Commit a43641c

Browse files
[mlir][bufferization] Fix regionOperatesOnMemrefValues (#75016)
`Region::walk([](Block *b) {...})` does not enumerate blocks that are direct children of the region. These blocks must be checked manually.
1 parent 6ed1daa commit a43641c

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder, Value memref,
463463
}
464464

465465
static bool regionOperatesOnMemrefValues(Region &region) {
466-
WalkResult result = region.walk([](Block *block) {
466+
auto checkBlock = [](Block *block) {
467467
if (llvm::any_of(block->getArguments(), isMemref))
468468
return WalkResult::interrupt();
469469
for (Operation &op : *block) {
@@ -473,8 +473,18 @@ static bool regionOperatesOnMemrefValues(Region &region) {
473473
return WalkResult::interrupt();
474474
}
475475
return WalkResult::advance();
476-
});
477-
return result.wasInterrupted();
476+
};
477+
WalkResult result = region.walk(checkBlock);
478+
if (result.wasInterrupted())
479+
return true;
480+
481+
// Note: Block::walk/Region::walk visits only blocks that are nested under
482+
// nested operations, but not direct children.
483+
for (Block &block : region)
484+
if (checkBlock(&block).wasInterrupted())
485+
return true;
486+
487+
return false;
478488
}
479489

480490
LogicalResult

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,8 +531,8 @@ func.func @noRegionBranchOpInterface() {
531531
// This is not allowed in buffer deallocation.
532532

533533
func.func @noRegionBranchOpInterface() {
534-
// expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
535534
%0 = "test.bar"() ({
535+
// expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
536536
%1 = "test.bar"() ({
537537
%2 = "test.get_memref"() : () -> memref<2xi32>
538538
"test.yield"(%2) : (memref<2xi32>) -> ()
@@ -544,6 +544,21 @@ func.func @noRegionBranchOpInterface() {
544544

545545
// -----
546546

547+
// Test Case: The op "test.bar" does not implement the RegionBranchOpInterface.
548+
// This is not allowed in buffer deallocation.
549+
550+
func.func @noRegionBranchOpInterface() {
551+
// expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
552+
%0 = "test.bar"() ({
553+
%2 = "test.get_memref"() : () -> memref<2xi32>
554+
%3 = "test.foo"(%2) : (memref<2xi32>) -> (i32)
555+
"test.yield"(%3) : (i32) -> ()
556+
}) : () -> (i32)
557+
"test.terminator"() : () -> ()
558+
}
559+
560+
// -----
561+
547562
func.func @while_two_arg(%arg0: index) {
548563
%a = memref.alloc(%arg0) : memref<?xf32>
549564
scf.while (%arg1 = %a, %arg2 = %a) : (memref<?xf32>, memref<?xf32>) -> (memref<?xf32>, memref<?xf32>) {

0 commit comments

Comments
 (0)