Skip to content

[mlir][IR] Change block/region walkers to enumerate this block/region #75020

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Dec 11, 2023

This change makes block/region walkers consistent with operation walkers. An operation walk enumerates the current operation. Similarly, block/region walks should enumerate the current block/region.

Example:

// Current behavior:
op1->walk([](Operation *op2) { /* op1 is enumerated */ });
block1->walk([](Block *block2) { /* block1 is NOT enumerated */ });
region1->walk([](Block *block) { /* blocks of region1 are NOT enumerated */ });
region1->walk([](Region *region2) { /* region1 is NOT enumerated });

// New behavior:
op1->walk([](Operation *op2) { /* op1 is enumerated */ });
block1->walk([](Block *block2) { /* block1 IS enumerated */ });
region1->walk([](Block *block) { /* blocks of region1 ARE enumerated */ });
region1->walk([](Region *region2) { /* region1 IS enumerated });

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:bufferization Bufferization infrastructure labels Dec 11, 2023
@llvmbot
Copy link
Member

llvmbot commented Dec 11, 2023

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-bufferization

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This change makes block/region walkers consistent with operation walkers. An operation walk enumerates the current operation. Similarly, block/region walks should enumerate the current block/region.

Example:

// Current behavior:
op1->walk([](Operation *op2) { /* op1 is enumerated */ });
block1->walk([](Block *block2) { /* block1 is NOT enumerated */ });
region1->walk([](Block *block) { /* blocks of region1 are NOT enumerated */ });
region1->walk([](Region *region2) { /* region1 is NOT enumerated });

// New behavior:
op1->walk([](Operation *op2) { /* op1 is enumerated */ });
block1->walk([](Block *block2) { /* block1 IS enumerated */ });
region1->walk([](Block *block) { /* blocks of region1 ARE enumerated */ });
region1->walk([](Region *region2) { /* region1 IS enumerated });

Depends on #75016. Only review the top commit.


Full diff: https://github.com/llvm/llvm-project/pull/75020.diff

5 Files Affected:

  • (modified) mlir/include/mlir/IR/Block.h (+67-44)
  • (modified) mlir/include/mlir/IR/Region.h (+47-35)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir (+16-1)
  • (modified) mlir/test/IR/visitors.mlir (+21-1)
  • (modified) mlir/test/lib/IR/TestVisitors.cpp (+55)
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 3d00c405ead374..e58b87774b8658 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -260,68 +260,91 @@ class Block : public IRObjectWithUseList<BlockOperand>,
   SuccessorRange getSuccessors() { return SuccessorRange(this); }
 
   //===--------------------------------------------------------------------===//
-  // Operation Walkers
+  // Walkers
   //===--------------------------------------------------------------------===//
 
-  /// Walk the operations in this block. The callback method is called for each
-  /// nested region, block or operation, depending on the callback provided.
-  /// The order in which regions, blocks and operations at the same nesting
+  /// Walk all nested operations, blocks (including this block) or regions,
+  /// depending on the type of callback.
+  ///
+  /// The order in which operations, blocks or regions at the same nesting
   /// level are visited (e.g., lexicographical or reverse lexicographical order)
-  /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
-  /// and operations with respect to their nested ones is specified by 'Order'
-  /// (post-order by default). A callback on a block or operation is allowed to
-  /// erase that block or operation if either:
+  /// is determined by `Iterator`. The walk order for enclosing operations,
+  /// blocks or regions with respect to their nested ones is specified by
+  /// `Order` (post-order by default).
+  ///
+  /// A callback on a operation or block is allowed to erase that operation or
+  /// block if either:
   ///   * the walk is in post-order, or
   ///   * the walk is in pre-order and the walk is skipped after the erasure.
+  ///
   /// See Operation::walk for more details.
   template <WalkOrder Order = WalkOrder::PostOrder,
             typename Iterator = ForwardIterator, typename FnT,
+            typename ArgT = detail::first_argument<FnT>,
             typename RetT = detail::walkResultType<FnT>>
   RetT walk(FnT &&callback) {
-    return walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback));
-  }
-
-  /// Walk the operations in the specified [begin, end) range of this block. The
-  /// callback method is called for each nested region, block or operation,
-  /// depending on the callback provided. The order in which regions, blocks and
-  /// operations at the same nesting level are visited (e.g., lexicographical or
-  /// reverse lexicographical order) is determined by 'Iterator'. The walk order
-  /// for enclosing regions, blocks and operations with respect to their nested
-  /// ones is specified by 'Order' (post-order by default). This method is
-  /// invoked for void-returning callbacks. A callback on a block or operation
-  /// is allowed to erase that block or operation only if the walk is in
-  /// post-order. See non-void method for pre-order erasure.
-  /// See Operation::walk for more details.
-  template <WalkOrder Order = WalkOrder::PostOrder,
-            typename Iterator = ForwardIterator, typename FnT,
-            typename RetT = detail::walkResultType<FnT>>
-  std::enable_if_t<std::is_same<RetT, void>::value, RetT>
-  walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
-    for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
-      detail::walk<Order, Iterator>(&op, callback);
+    if constexpr (std::is_same<ArgT, Block *>::value &&
+                  Order == WalkOrder::PreOrder) {
+      // Pre-order walk on blocks: invoke the callback on this block.
+      if constexpr (std::is_same<RetT, WalkResult>::value) {
+        RetT result = callback(this);
+        if (result.wasSkipped())
+          return WalkResult::advance();
+        if (result.wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        callback(this);
+      }
+    }
+
+    // Walk nested operations, blocks or regions.
+    if constexpr (std::is_same<RetT, WalkResult>::value) {
+      if (walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback))
+              .wasInterrupted())
+        return WalkResult::interrupt();
+    } else {
+      walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback));
+    }
+
+    if constexpr (std::is_same<ArgT, Block *>::value &&
+                  Order == WalkOrder::PostOrder) {
+      // Post-order walk on blocks: invoke the callback on this block.
+      return callback(this);
+    }
+    if constexpr (std::is_same<RetT, WalkResult>::value)
+      return WalkResult::advance();
   }
 
-  /// Walk the operations in the specified [begin, end) range of this block. The
-  /// callback method is called for each nested region, block or operation,
-  /// depending on the callback provided. The order in which regions, blocks and
-  /// operations at the same nesting level are visited (e.g., lexicographical or
-  /// reverse lexicographical order) is determined by 'Iterator'. The walk order
-  /// for enclosing regions, blocks and operations with respect to their nested
-  /// ones is specified by 'Order' (post-order by default). This method is
-  /// invoked for skippable or interruptible callbacks. A callback on a block or
-  /// operation is allowed to erase that block or operation if either:
+  /// Walk all nested operations, blocks (excluding this block) or regions,
+  /// depending on the type of callback, in the specified [begin, end) range of
+  /// this block.
+  ///
+  /// The order in which operations, blocks or regions at the same nesting
+  /// level are visited (e.g., lexicographical or reverse lexicographical order)
+  /// is determined by `Iterator`. The walk order for enclosing operations,
+  /// blocks or regions with respect to their nested ones is specified by
+  /// `Order` (post-order by default).
+  ///
+  /// A callback on a operation or block is allowed to erase that operation or
+  /// block if either:
   ///   * the walk is in post-order, or
   ///   * the walk is in pre-order and the walk is skipped after the erasure.
+  ///
   /// See Operation::walk for more details.
   template <WalkOrder Order = WalkOrder::PostOrder,
             typename Iterator = ForwardIterator, typename FnT,
             typename RetT = detail::walkResultType<FnT>>
-  std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
-  walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
-    for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
-      if (detail::walk<Order, Iterator>(&op, callback).wasInterrupted())
-        return WalkResult::interrupt();
-    return WalkResult::advance();
+  RetT walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
+    for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) {
+      if constexpr (std::is_same<RetT, WalkResult>::value) {
+        if (detail::walk<Order, Iterator>(&op, callback).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        detail::walk<Order, Iterator>(&op, callback);
+      }
+    }
+    if constexpr (std::is_same<RetT, WalkResult>::value)
+      return WalkResult::advance();
   }
 
   //===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 4f4812dda79b89..b626350d1b657d 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -260,48 +260,60 @@ class Region {
   void dropAllReferences();
 
   //===--------------------------------------------------------------------===//
-  // Operation Walkers
+  // Walkers
   //===--------------------------------------------------------------------===//
 
-  /// Walk the operations in this region. The callback method is called for each
-  /// nested region, block or operation, depending on the callback provided.
-  /// The order in which regions, blocks and operations at the same nesting
-  /// level are visited (e.g., lexicographical or reverse lexicographical order)
-  /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
-  /// and operations with respect to their nested ones is specified by 'Order'
-  /// (post-order by default). This method is invoked for void-returning
-  /// callbacks. A callback on a block or operation is allowed to erase that
-  /// block or operation only if the walk is in post-order. See non-void method
-  /// for pre-order erasure. See Operation::walk for more details.
-  template <WalkOrder Order = WalkOrder::PostOrder,
-            typename Iterator = ForwardIterator, typename FnT,
-            typename RetT = detail::walkResultType<FnT>>
-  std::enable_if_t<std::is_same<RetT, void>::value, RetT> walk(FnT &&callback) {
-    for (auto &block : *this)
-      block.walk<Order, Iterator>(callback);
-  }
-
-  /// Walk the operations in this region. The callback method is called for each
-  /// nested region, block or operation, depending on the callback provided.
-  /// The order in which regions, blocks and operations at the same nesting
+  /// Walk all nested operations, blocks or regions (including this region),
+  /// depending on the type of callback.
+  ///
+  /// The order in which operations, blocks or regions at the same nesting
   /// level are visited (e.g., lexicographical or reverse lexicographical order)
-  /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
-  /// and operations with respect to their nested ones is specified by 'Order'
-  /// (post-order by default). This method is invoked for skippable or
-  /// interruptible callbacks. A callback on a block or operation is allowed to
-  /// erase that block or operation if either:
-  ///   * the walk is in post-order,
-  ///   * or the walk is in pre-order and the walk is skipped after the erasure.
+  /// is determined by `Iterator`. The walk order for enclosing operations,
+  /// blocks or regions with respect to their nested ones is specified by
+  /// `Order` (post-order by default).
+  ///
+  /// A callback on a operation or block is allowed to erase that operation or
+  /// block if either:
+  ///   * the walk is in post-order, or
+  ///   * the walk is in pre-order and the walk is skipped after the erasure.
+  ///
   /// See Operation::walk for more details.
   template <WalkOrder Order = WalkOrder::PostOrder,
             typename Iterator = ForwardIterator, typename FnT,
+            typename ArgT = detail::first_argument<FnT>,
             typename RetT = detail::walkResultType<FnT>>
-  std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
-  walk(FnT &&callback) {
-    for (auto &block : *this)
-      if (block.walk<Order, Iterator>(callback).wasInterrupted())
-        return WalkResult::interrupt();
-    return WalkResult::advance();
+  RetT walk(FnT &&callback) {
+    if constexpr (std::is_same<ArgT, Region *>::value &&
+                  Order == WalkOrder::PreOrder) {
+      // Pre-order walk on regions: invoke the callback on this region.
+      if constexpr (std::is_same<RetT, WalkResult>::value) {
+        RetT result = callback(this);
+        if (result.wasSkipped())
+          return WalkResult::advance();
+        if (result.wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        callback(this);
+      }
+    }
+
+    // Walk nested operations, blocks or regions.
+    for (auto &block : *this) {
+      if constexpr (std::is_same<RetT, WalkResult>::value) {
+        if (block.walk<Order, Iterator>(callback).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        block.walk<Order, Iterator>(callback);
+      }
+    }
+
+    if constexpr (std::is_same<ArgT, Region *>::value &&
+                  Order == WalkOrder::PostOrder) {
+      // Post-order walk on regions: invoke the callback on this block.
+      return callback(this);
+    }
+    if constexpr (std::is_same<RetT, WalkResult>::value)
+      return WalkResult::advance();
   }
 
   //===--------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
index ad7c4c783e907f..1a8a930bc9002b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
@@ -531,8 +531,8 @@ func.func @noRegionBranchOpInterface() {
 // This is not allowed in buffer deallocation.
 
 func.func @noRegionBranchOpInterface() {
-  // expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
   %0 = "test.bar"() ({
+    // expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
     %1 = "test.bar"() ({
       %2 = "test.get_memref"() : () -> memref<2xi32>
       "test.yield"(%2) : (memref<2xi32>) -> ()
@@ -544,6 +544,21 @@ func.func @noRegionBranchOpInterface() {
 
 // -----
 
+// Test Case: The op "test.bar" does not implement the RegionBranchOpInterface.
+// This is not allowed in buffer deallocation.
+
+func.func @noRegionBranchOpInterface() {
+  // expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
+  %0 = "test.bar"() ({
+    %2 = "test.get_memref"() : () -> memref<2xi32>
+    %3 = "test.foo"(%2) : (memref<2xi32>) -> (i32)
+    "test.yield"(%3) : (i32) -> ()
+  }) : () -> (i32)
+  "test.terminator"() : () -> ()
+}
+
+// -----
+
 func.func @while_two_arg(%arg0: index) {
   %a = memref.alloc(%arg0) : memref<?xf32>
   scf.while (%arg1 = %a, %arg2 = %a) : (memref<?xf32>, memref<?xf32>) -> (memref<?xf32>, memref<?xf32>) {
diff --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir
index 2d83d6922e0cd0..ec7712a45d3882 100644
--- a/mlir/test/IR/visitors.mlir
+++ b/mlir/test/IR/visitors.mlir
@@ -17,7 +17,7 @@ func.func @structured_cfg() {
       "use2"(%i) : (index) -> ()
     }
     "use3"(%i) : (index) -> ()
-  }
+  } {walk_blocks, walk_regions}
   return
 }
 
@@ -88,6 +88,26 @@ func.func @structured_cfg() {
 // CHECK:       Visiting op 'func.func'
 // CHECK:       Visiting op 'builtin.module'
 
+// CHECK-LABEL: Invoke block pre-order visits on blocks
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'scf.for'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'scf.if'
+// CHECK:       Visiting block ^bb0 from region 1 from operation 'scf.if'
+
+// CHECK-LABEL: Invoke block post-order visits on blocks
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'scf.if'
+// CHECK:       Visiting block ^bb0 from region 1 from operation 'scf.if'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'scf.for'
+
+// CHECK-LABEL: Invoke region pre-order visits on region
+// CHECK:       Visiting region 0 from operation 'scf.for'
+// CHECK:       Visiting region 0 from operation 'scf.if'
+// CHECK:       Visiting region 1 from operation 'scf.if'
+
+// CHECK-LABEL: Invoke region post-order visits on region
+// CHECK:       Visiting region 0 from operation 'scf.if'
+// CHECK:       Visiting region 1 from operation 'scf.if'
+// CHECK:       Visiting region 0 from operation 'scf.for'
+
 // CHECK-LABEL: Op pre-order erasures
 // CHECK:       Erasing op 'scf.for'
 // CHECK:       Erasing op 'func.return'
diff --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp
index a3ef3f35159534..f4cff39cf2e523 100644
--- a/mlir/test/lib/IR/TestVisitors.cpp
+++ b/mlir/test/lib/IR/TestVisitors.cpp
@@ -204,6 +204,60 @@ static void testNoSkipErasureCallbacks(Operation *op) {
   cloned->erase();
 }
 
+/// Invoke region/block walks on regions/blocks.
+static void testBlockAndRegionWalkers(Operation *op) {
+  auto blockPure = [](Block *block) {
+    llvm::outs() << "Visiting ";
+    printBlock(block);
+    llvm::outs() << "\n";
+  };
+  auto regionPure = [](Region *region) {
+    llvm::outs() << "Visiting ";
+    printRegion(region);
+    llvm::outs() << "\n";
+  };
+
+  llvm::outs() << "Invoke block pre-order visits on blocks\n";
+  op->walk([&](Operation *op) {
+    if (!op->hasAttr("walk_blocks"))
+      return;
+    for (Region &region : op->getRegions()) {
+      for (Block &block : region.getBlocks()) {
+        block.walk<WalkOrder::PreOrder>(blockPure);
+      }
+    }
+  });
+
+  llvm::outs() << "Invoke block post-order visits on blocks\n";
+  op->walk([&](Operation *op) {
+    if (!op->hasAttr("walk_blocks"))
+      return;
+    for (Region &region : op->getRegions()) {
+      for (Block &block : region.getBlocks()) {
+        block.walk<WalkOrder::PostOrder>(blockPure);
+      }
+    }
+  });
+
+  llvm::outs() << "Invoke region pre-order visits on region\n";
+  op->walk([&](Operation *op) {
+    if (!op->hasAttr("walk_regions"))
+      return;
+    for (Region &region : op->getRegions()) {
+      region.walk<WalkOrder::PreOrder>(regionPure);
+    }
+  });
+
+  llvm::outs() << "Invoke region post-order visits on region\n";
+  op->walk([&](Operation *op) {
+    if (!op->hasAttr("walk_regions"))
+      return;
+    for (Region &region : op->getRegions()) {
+      region.walk<WalkOrder::PostOrder>(regionPure);
+    }
+  });
+}
+
 namespace {
 /// This pass exercises the different configurations of the IR visitors.
 struct TestIRVisitorsPass
@@ -215,6 +269,7 @@ struct TestIRVisitorsPass
   void runOnOperation() override {
     Operation *op = getOperation();
     testPureCallbacks(op);
+    testBlockAndRegionWalkers(op);
     testSkipErasureCallbacks(op);
     testNoSkipErasureCallbacks(op);
   }

This change makes block/region walkers consistent with operation walkers. An operation walk enumerates the current operation. Similarly, block/region walks should enumerate the current block/region.

Example:
```
// Current behavior:
op1->walk([](Operation *op2) { /* op1 is enumerated */ });
block1->walk([](Block *block2) { /* block1 is NOT enumerated */ });
region1->walk([](Block *block) { /* blocks of region1 are NOT enumerated */ });
region1->walk([](Region *region2) { /* region1 is NOT enumerated });

// New behavior:
op1->walk([](Operation *op2) { /* op1 is enumerated */ });
block1->walk([](Block *block2) { /* block1 IS enumerated */ });
region1->walk([](Block *block) { /* blocks of region1 ARE enumerated */ });
region1->walk([](Region *region2) { /* region1 IS enumerated });
```
Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks

/// ones is specified by 'Order' (post-order by default). This method is
/// invoked for skippable or interruptible callbacks. A callback on a block or
/// operation is allowed to erase that block or operation if either:
/// Walk all nested operations, blocks (excluding this block) or regions,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a little weird that these have the same name but one inclusive and the other exclusive

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A slight inconsistency, that's true. But I don't have any better name. I think it should be clear from the API that this block is excluded because the function takes begin and end block iterators.

@matthias-springer matthias-springer merged commit c4457e1 into llvm:main Dec 20, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants