Skip to content

[MLIR] Support interrupting AffineExpr walks #74792

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 5, 2024

Conversation

bondhugula
Copy link
Contributor

@bondhugula bondhugula commented Dec 8, 2023

Support WalkResult for AffineExpr walk and support interrupting walks
along the lines of Operation::walk. This allows interrupted walks when a
condition is met. Also, switch from std::function to llvm::function_ref
for the walk function.

@bondhugula bondhugula requested a review from River707 December 8, 2023 01:02
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:affine mlir labels Dec 8, 2023
@llvmbot
Copy link
Member

llvmbot commented Dec 8, 2023

@llvm/pr-subscribers-mlir-affine

@llvm/pr-subscribers-mlir

Author: Uday Bondhugula (bondhugula)

Changes

Support WalkResult for AffineExpr walk and support interrupting walks
along the lines of Operation::walk. This allows interrupted walks when a
condition is met. Also, switch from std::function to llvm::function_ref
for the walk function.


Full diff: https://github.com/llvm/llvm-project/pull/74792.diff

4 Files Affected:

  • (modified) mlir/include/mlir/IR/AffineExpr.h (+23-2)
  • (modified) mlir/include/mlir/IR/AffineExprVisitor.h (+48-8)
  • (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+29-31)
  • (modified) mlir/lib/IR/AffineExpr.cpp (+28-13)
diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 40e9d28ce5d3a..181a24472473a 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_IR_AFFINEEXPR_H
 #define MLIR_IR_AFFINEEXPR_H
 
+#include "mlir/IR/Visitors.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/Hashing.h"
@@ -123,8 +124,19 @@ class AffineExpr {
   /// Return true if the affine expression involves AffineSymbolExpr `position`.
   bool isFunctionOfSymbol(unsigned position) const;
 
-  /// Walk all of the AffineExpr's in this expression in postorder.
-  void walk(std::function<void(AffineExpr)> callback) const;
+  /// Walk all of the AffineExpr's in this expression in postorder. This allows
+  /// a lambda walk function that can either return `void` or a WalkResult. With
+  /// a WalkResult, interrupting is supported.
+  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  std::enable_if_t<std::is_same<RetT, void>::value, RetT>
+  walk(FnT &&callback) const {
+    return walk<void>(*this, callback);
+  }
+  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
+  walk(FnT &&callback) const {
+    return walk<WalkResult>(*this, callback);
+  }
 
   /// This method substitutes any uses of dimensions and symbols (e.g.
   /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
@@ -202,6 +214,15 @@ class AffineExpr {
 
 protected:
   ImplType *expr{nullptr};
+
+private:
+  /// A trampoline for the templated non-static AffineExpr::walk method to
+  /// dispatch lambda `callback`'s of either a void result type or a
+  /// WalkResult type. Walk all of the AffineExprs in `e` in postorder. Users
+  /// should use the regular (non-static) `walk` method.
+  template <typename WalkRetTy>
+  static WalkRetTy walk(AffineExpr e,
+                        function_ref<WalkRetTy(AffineExpr)> callback);
 };
 
 /// Affine binary operation expression. An affine binary operation could be an
diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index 2860e73c8f428..5b3663d1dea7e 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -30,6 +30,9 @@ namespace mlir {
 /// functions in your class. This class is defined in terms of statically
 /// resolved overloading, not virtual functions.
 ///
+/// The visitor is templated on its return type (`RetTy`). With a WalkResult
+/// return type, the visitor supports interrupting walks.
+///
 /// For example, here is a visitor that counts the number of for AffineDimExprs
 /// in an AffineExpr.
 ///
@@ -65,7 +68,6 @@ namespace mlir {
 /// virtual function call overhead. Defining and using a AffineExprVisitor is
 /// just as efficient as having your own switch instruction over the instruction
 /// opcode.
-
 template <typename SubClass, typename RetTy>
 class AffineExprVisitorBase {
 public:
@@ -136,6 +138,8 @@ class AffineExprVisitorBase {
   RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
 };
 
+/// See documentation for AffineExprVisitorBase. This visitor supports
+/// interrupting walks when a `WalkResult` is used for `RetTy`.
 template <typename SubClass, typename RetTy = void>
 class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
   //===--------------------------------------------------------------------===//
@@ -150,27 +154,52 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
     switch (expr.getKind()) {
     case AffineExprKind::Add: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      walkOperandsPostOrder(binOpExpr);
+      if constexpr (std::is_same<RetTy, WalkResult>::value) {
+        if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        walkOperandsPostOrder(binOpExpr);
+      }
       return self->visitAddExpr(binOpExpr);
     }
     case AffineExprKind::Mul: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      walkOperandsPostOrder(binOpExpr);
+      if constexpr (std::is_same<RetTy, WalkResult>::value) {
+        if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        walkOperandsPostOrder(binOpExpr);
+      }
       return self->visitMulExpr(binOpExpr);
     }
     case AffineExprKind::Mod: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      walkOperandsPostOrder(binOpExpr);
+      if constexpr (std::is_same<RetTy, WalkResult>::value) {
+        if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        walkOperandsPostOrder(binOpExpr);
+      }
       return self->visitModExpr(binOpExpr);
     }
     case AffineExprKind::FloorDiv: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      walkOperandsPostOrder(binOpExpr);
+      if constexpr (std::is_same<RetTy, WalkResult>::value) {
+        if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        walkOperandsPostOrder(binOpExpr);
+      }
       return self->visitFloorDivExpr(binOpExpr);
     }
     case AffineExprKind::CeilDiv: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      walkOperandsPostOrder(binOpExpr);
+      if constexpr (std::is_same<RetTy, WalkResult>::value) {
+        if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        walkOperandsPostOrder(binOpExpr);
+      }
       return self->visitCeilDivExpr(binOpExpr);
     }
     case AffineExprKind::Constant:
@@ -186,8 +215,19 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
 private:
   // Walk the operands - each operand is itself walked in post order.
   RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
-    walkPostOrder(expr.getLHS());
-    walkPostOrder(expr.getRHS());
+    if constexpr (std::is_same<RetTy, WalkResult>::value) {
+      if (walkPostOrder(expr.getLHS()).wasInterrupted())
+        return WalkResult::interrupt();
+    } else {
+      walkPostOrder(expr.getLHS());
+    }
+    if constexpr (std::is_same<RetTy, WalkResult>::value) {
+      if (walkPostOrder(expr.getLHS()).wasInterrupted())
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    } else {
+      walkPostOrder(expr.getRHS());
+    }
   }
 };
 
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 50a052fb8b74e..578d03c629285 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1561,22 +1561,21 @@ static LogicalResult getTileSizePos(
 /// memref<4x?xf32, #map0>  ==>  memref<4x?x?xf32>
 static bool
 isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
-                             SmallVectorImpl<unsigned> &inMemrefTypeDynDims,
-                             MLIRContext *context) {
-  bool isDynamicDim = false;
+                             SmallVectorImpl<unsigned> &inMemrefTypeDynDims) {
   AffineExpr expr = layoutMap.getResults()[dim];
   // Check if affine expr of the dimension includes dynamic dimension of input
   // memrefType.
-  expr.walk([&inMemrefTypeDynDims, &isDynamicDim, &context](AffineExpr e) {
-    if (isa<AffineDimExpr>(e)) {
-      for (unsigned dm : inMemrefTypeDynDims) {
-        if (e == getAffineDimExpr(dm, context)) {
-          isDynamicDim = true;
-        }
-      }
-    }
-  });
-  return isDynamicDim;
+  MLIRContext *context = layoutMap.getContext();
+  return expr
+      .walk([&](AffineExpr e) {
+        if (isa<AffineDimExpr>(e) &&
+            llvm::any_of(inMemrefTypeDynDims, [&](unsigned dim) {
+              return e == getAffineDimExpr(dim, context);
+            }))
+          return WalkResult::interrupt();
+        return WalkResult::advance();
+      })
+      .wasInterrupted();
 }
 
 /// Create affine expr to calculate dimension size for a tiled-layout map.
@@ -1792,29 +1791,28 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
   MLIRContext *context = memrefType.getContext();
   for (unsigned d = 0; d < newRank; ++d) {
     // Check if this dimension is dynamic.
-    bool isDynDim =
-        isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims, context);
-    if (isDynDim) {
+    if (bool isDynDim =
+            isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims)) {
       newShape[d] = ShapedType::kDynamic;
-    } else {
-      // The lower bound for the shape is always zero.
-      std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
-      // For a static memref and an affine map with no symbols, this is
-      // always bounded. However, when we have symbols, we may not be able to
-      // obtain a constant upper bound. Also, mapping to a negative space is
-      // invalid for normalization.
-      if (!ubConst.has_value() || *ubConst < 0) {
-        LLVM_DEBUG(llvm::dbgs()
-                   << "can't normalize map due to unknown/invalid upper bound");
-        return memrefType;
-      }
-      // If dimension of new memrefType is dynamic, the value is -1.
-      newShape[d] = *ubConst + 1;
+      continue;
+    }
+    // The lower bound for the shape is always zero.
+    std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
+    // For a static memref and an affine map with no symbols, this is
+    // always bounded. However, when we have symbols, we may not be able to
+    // obtain a constant upper bound. Also, mapping to a negative space is
+    // invalid for normalization.
+    if (!ubConst.has_value() || *ubConst < 0) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "can't normalize map due to unknown/invalid upper bound");
+      return memrefType;
     }
+    // If dimension of new memrefType is dynamic, the value is -1.
+    newShape[d] = *ubConst + 1;
   }
 
   // Create the new memref type after trivializing the old layout map.
-  MemRefType newMemRefType =
+  auto newMemRefType =
       MemRefType::Builder(memrefType)
           .setShape(newShape)
           .setLayout(AffineMapAttr::get(
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 038ceea286a36..a90b264a8edd2 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -26,22 +26,37 @@ MLIRContext *AffineExpr::getContext() const { return expr->context; }
 
 AffineExprKind AffineExpr::getKind() const { return expr->kind; }
 
-/// Walk all of the AffineExprs in this subgraph in postorder.
-void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
-  struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
-    std::function<void(AffineExpr)> callback;
-
-    AffineExprWalker(std::function<void(AffineExpr)> callback)
-        : callback(std::move(callback)) {}
-
-    void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
-    void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
-    void visitDimExpr(AffineDimExpr expr) { callback(expr); }
-    void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
+/// Walk all of the AffineExprs in `e` in postorder. This is a private factory
+/// method to help handle lambda walk functions. Users should use the regular
+/// (non-static) `walk` method.
+template <typename WalkRetTy>
+WalkRetTy mlir::AffineExpr::walk(AffineExpr e,
+                                 function_ref<WalkRetTy(AffineExpr)> callback) {
+  struct AffineExprWalker
+      : public AffineExprVisitor<AffineExprWalker, WalkRetTy> {
+    function_ref<WalkRetTy(AffineExpr)> callback;
+
+    AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback)
+        : callback(callback) {}
+
+    WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
+      return callback(expr);
+    }
+    WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
+      return callback(expr);
+    }
+    WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); }
+    WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); }
   };
 
-  AffineExprWalker(std::move(callback)).walkPostOrder(*this);
+  return AffineExprWalker(callback).walkPostOrder(e);
 }
+// Explicitly instantiate for the two supported return types.
+template void mlir::AffineExpr::walk(AffineExpr e,
+                                     function_ref<void(AffineExpr)> callback);
+template WalkResult
+mlir::AffineExpr::walk(AffineExpr e,
+                       function_ref<WalkResult(AffineExpr)> callback);
 
 // Dispatch affine expression construction based on kind.
 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,

@llvmbot
Copy link
Member

llvmbot commented Dec 8, 2023

@llvm/pr-subscribers-mlir-core

Author: Uday Bondhugula (bondhugula)

Changes

Support WalkResult for AffineExpr walk and support interrupting walks
along the lines of Operation::walk. This allows interrupted walks when a
condition is met. Also, switch from std::function to llvm::function_ref
for the walk function.


Full diff: https://github.com/llvm/llvm-project/pull/74792.diff

4 Files Affected:

  • (modified) mlir/include/mlir/IR/AffineExpr.h (+23-2)
  • (modified) mlir/include/mlir/IR/AffineExprVisitor.h (+48-8)
  • (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+29-31)
  • (modified) mlir/lib/IR/AffineExpr.cpp (+28-13)
diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 40e9d28ce5d3a0..181a24472473a6 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_IR_AFFINEEXPR_H
 #define MLIR_IR_AFFINEEXPR_H
 
+#include "mlir/IR/Visitors.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/Hashing.h"
@@ -123,8 +124,19 @@ class AffineExpr {
   /// Return true if the affine expression involves AffineSymbolExpr `position`.
   bool isFunctionOfSymbol(unsigned position) const;
 
-  /// Walk all of the AffineExpr's in this expression in postorder.
-  void walk(std::function<void(AffineExpr)> callback) const;
+  /// Walk all of the AffineExpr's in this expression in postorder. This allows
+  /// a lambda walk function that can either return `void` or a WalkResult. With
+  /// a WalkResult, interrupting is supported.
+  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  std::enable_if_t<std::is_same<RetT, void>::value, RetT>
+  walk(FnT &&callback) const {
+    return walk<void>(*this, callback);
+  }
+  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
+  walk(FnT &&callback) const {
+    return walk<WalkResult>(*this, callback);
+  }
 
   /// This method substitutes any uses of dimensions and symbols (e.g.
   /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
@@ -202,6 +214,15 @@ class AffineExpr {
 
 protected:
   ImplType *expr{nullptr};
+
+private:
+  /// A trampoline for the templated non-static AffineExpr::walk method to
+  /// dispatch lambda `callback`'s of either a void result type or a
+  /// WalkResult type. Walk all of the AffineExprs in `e` in postorder. Users
+  /// should use the regular (non-static) `walk` method.
+  template <typename WalkRetTy>
+  static WalkRetTy walk(AffineExpr e,
+                        function_ref<WalkRetTy(AffineExpr)> callback);
 };
 
 /// Affine binary operation expression. An affine binary operation could be an
diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index 2860e73c8f4283..5b3663d1dea7ea 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -30,6 +30,9 @@ namespace mlir {
 /// functions in your class. This class is defined in terms of statically
 /// resolved overloading, not virtual functions.
 ///
+/// The visitor is templated on its return type (`RetTy`). With a WalkResult
+/// return type, the visitor supports interrupting walks.
+///
 /// For example, here is a visitor that counts the number of for AffineDimExprs
 /// in an AffineExpr.
 ///
@@ -65,7 +68,6 @@ namespace mlir {
 /// virtual function call overhead. Defining and using a AffineExprVisitor is
 /// just as efficient as having your own switch instruction over the instruction
 /// opcode.
-
 template <typename SubClass, typename RetTy>
 class AffineExprVisitorBase {
 public:
@@ -136,6 +138,8 @@ class AffineExprVisitorBase {
   RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
 };
 
+/// See documentation for AffineExprVisitorBase. This visitor supports
+/// interrupting walks when a `WalkResult` is used for `RetTy`.
 template <typename SubClass, typename RetTy = void>
 class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
   //===--------------------------------------------------------------------===//
@@ -150,27 +154,52 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
     switch (expr.getKind()) {
     case AffineExprKind::Add: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      walkOperandsPostOrder(binOpExpr);
+      if constexpr (std::is_same<RetTy, WalkResult>::value) {
+        if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        walkOperandsPostOrder(binOpExpr);
+      }
       return self->visitAddExpr(binOpExpr);
     }
     case AffineExprKind::Mul: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      walkOperandsPostOrder(binOpExpr);
+      if constexpr (std::is_same<RetTy, WalkResult>::value) {
+        if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        walkOperandsPostOrder(binOpExpr);
+      }
       return self->visitMulExpr(binOpExpr);
     }
     case AffineExprKind::Mod: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      walkOperandsPostOrder(binOpExpr);
+      if constexpr (std::is_same<RetTy, WalkResult>::value) {
+        if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        walkOperandsPostOrder(binOpExpr);
+      }
       return self->visitModExpr(binOpExpr);
     }
     case AffineExprKind::FloorDiv: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      walkOperandsPostOrder(binOpExpr);
+      if constexpr (std::is_same<RetTy, WalkResult>::value) {
+        if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        walkOperandsPostOrder(binOpExpr);
+      }
       return self->visitFloorDivExpr(binOpExpr);
     }
     case AffineExprKind::CeilDiv: {
       auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
-      walkOperandsPostOrder(binOpExpr);
+      if constexpr (std::is_same<RetTy, WalkResult>::value) {
+        if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        walkOperandsPostOrder(binOpExpr);
+      }
       return self->visitCeilDivExpr(binOpExpr);
     }
     case AffineExprKind::Constant:
@@ -186,8 +215,19 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
 private:
   // Walk the operands - each operand is itself walked in post order.
   RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
-    walkPostOrder(expr.getLHS());
-    walkPostOrder(expr.getRHS());
+    if constexpr (std::is_same<RetTy, WalkResult>::value) {
+      if (walkPostOrder(expr.getLHS()).wasInterrupted())
+        return WalkResult::interrupt();
+    } else {
+      walkPostOrder(expr.getLHS());
+    }
+    if constexpr (std::is_same<RetTy, WalkResult>::value) {
+      if (walkPostOrder(expr.getLHS()).wasInterrupted())
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    } else {
+      walkPostOrder(expr.getRHS());
+    }
   }
 };
 
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 50a052fb8b74e7..578d03c629285a 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1561,22 +1561,21 @@ static LogicalResult getTileSizePos(
 /// memref<4x?xf32, #map0>  ==>  memref<4x?x?xf32>
 static bool
 isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
-                             SmallVectorImpl<unsigned> &inMemrefTypeDynDims,
-                             MLIRContext *context) {
-  bool isDynamicDim = false;
+                             SmallVectorImpl<unsigned> &inMemrefTypeDynDims) {
   AffineExpr expr = layoutMap.getResults()[dim];
   // Check if affine expr of the dimension includes dynamic dimension of input
   // memrefType.
-  expr.walk([&inMemrefTypeDynDims, &isDynamicDim, &context](AffineExpr e) {
-    if (isa<AffineDimExpr>(e)) {
-      for (unsigned dm : inMemrefTypeDynDims) {
-        if (e == getAffineDimExpr(dm, context)) {
-          isDynamicDim = true;
-        }
-      }
-    }
-  });
-  return isDynamicDim;
+  MLIRContext *context = layoutMap.getContext();
+  return expr
+      .walk([&](AffineExpr e) {
+        if (isa<AffineDimExpr>(e) &&
+            llvm::any_of(inMemrefTypeDynDims, [&](unsigned dim) {
+              return e == getAffineDimExpr(dim, context);
+            }))
+          return WalkResult::interrupt();
+        return WalkResult::advance();
+      })
+      .wasInterrupted();
 }
 
 /// Create affine expr to calculate dimension size for a tiled-layout map.
@@ -1792,29 +1791,28 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
   MLIRContext *context = memrefType.getContext();
   for (unsigned d = 0; d < newRank; ++d) {
     // Check if this dimension is dynamic.
-    bool isDynDim =
-        isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims, context);
-    if (isDynDim) {
+    if (bool isDynDim =
+            isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims)) {
       newShape[d] = ShapedType::kDynamic;
-    } else {
-      // The lower bound for the shape is always zero.
-      std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
-      // For a static memref and an affine map with no symbols, this is
-      // always bounded. However, when we have symbols, we may not be able to
-      // obtain a constant upper bound. Also, mapping to a negative space is
-      // invalid for normalization.
-      if (!ubConst.has_value() || *ubConst < 0) {
-        LLVM_DEBUG(llvm::dbgs()
-                   << "can't normalize map due to unknown/invalid upper bound");
-        return memrefType;
-      }
-      // If dimension of new memrefType is dynamic, the value is -1.
-      newShape[d] = *ubConst + 1;
+      continue;
+    }
+    // The lower bound for the shape is always zero.
+    std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
+    // For a static memref and an affine map with no symbols, this is
+    // always bounded. However, when we have symbols, we may not be able to
+    // obtain a constant upper bound. Also, mapping to a negative space is
+    // invalid for normalization.
+    if (!ubConst.has_value() || *ubConst < 0) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "can't normalize map due to unknown/invalid upper bound");
+      return memrefType;
     }
+    // If dimension of new memrefType is dynamic, the value is -1.
+    newShape[d] = *ubConst + 1;
   }
 
   // Create the new memref type after trivializing the old layout map.
-  MemRefType newMemRefType =
+  auto newMemRefType =
       MemRefType::Builder(memrefType)
           .setShape(newShape)
           .setLayout(AffineMapAttr::get(
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 038ceea286a363..a90b264a8edd26 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -26,22 +26,37 @@ MLIRContext *AffineExpr::getContext() const { return expr->context; }
 
 AffineExprKind AffineExpr::getKind() const { return expr->kind; }
 
-/// Walk all of the AffineExprs in this subgraph in postorder.
-void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
-  struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
-    std::function<void(AffineExpr)> callback;
-
-    AffineExprWalker(std::function<void(AffineExpr)> callback)
-        : callback(std::move(callback)) {}
-
-    void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
-    void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
-    void visitDimExpr(AffineDimExpr expr) { callback(expr); }
-    void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
+/// Walk all of the AffineExprs in `e` in postorder. This is a private factory
+/// method to help handle lambda walk functions. Users should use the regular
+/// (non-static) `walk` method.
+template <typename WalkRetTy>
+WalkRetTy mlir::AffineExpr::walk(AffineExpr e,
+                                 function_ref<WalkRetTy(AffineExpr)> callback) {
+  struct AffineExprWalker
+      : public AffineExprVisitor<AffineExprWalker, WalkRetTy> {
+    function_ref<WalkRetTy(AffineExpr)> callback;
+
+    AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback)
+        : callback(callback) {}
+
+    WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
+      return callback(expr);
+    }
+    WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
+      return callback(expr);
+    }
+    WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); }
+    WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); }
   };
 
-  AffineExprWalker(std::move(callback)).walkPostOrder(*this);
+  return AffineExprWalker(callback).walkPostOrder(e);
 }
+// Explicitly instantiate for the two supported return types.
+template void mlir::AffineExpr::walk(AffineExpr e,
+                                     function_ref<void(AffineExpr)> callback);
+template WalkResult
+mlir::AffineExpr::walk(AffineExpr e,
+                       function_ref<WalkResult(AffineExpr)> callback);
 
 // Dispatch affine expression construction based on kind.
 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,

@bondhugula bondhugula changed the title [MLIR] Support interupting AffineExpr walks [MLIR] Support interrupting AffineExpr walks Dec 8, 2023
@bondhugula bondhugula force-pushed the uday/support_affine_walk_result branch from bd399fd to 230f468 Compare December 8, 2023 01:03
Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG, but would be better with some tests showing the interruption: do we have unit-tests for AffineExpr?

@bondhugula
Copy link
Contributor Author

LG, but would be better with some tests showing the interruption: do we have unit-tests for AffineExpr?

It's being exercised by an affine utility - the change to it is included here, but it doesn't really prove it's working, but only that compilation (of the compiler) is succeeding. We'll need to think of a new unit test where the interrupt emits a diagnostic which is checked. Is this what you had in mind?

@joker-eph
Copy link
Collaborator

Yes, something that shows we interrupt and propagate correctly, I saw we have « code coverage » with the utility, but not « path coverage » for the new behavior I think?

@bondhugula
Copy link
Contributor Author

Yes, something that shows we interrupt and propagate correctly, I saw we have « code coverage » with the utility, but not « path coverage » for the new behavior I think?

That's right - I think we'll need a test pass. I can think of a simple utility -- for eg. "check whether an affine expression has a modulo in it" and emit a diagnostic. There should be just one diagnostic emitted as a result even in the presence of multiple, etc.

@ftynse
Copy link
Member

ftynse commented Dec 8, 2023

We can also have a proper unit test here https://github.com/llvm/llvm-project/tree/main/mlir/unittests/IR. This looks like an API-level feature and having to roll a test pass for it may be too much of unwarranted complexity.

@bondhugula bondhugula force-pushed the uday/support_affine_walk_result branch from 230f468 to cba3731 Compare December 9, 2023 01:46
@bondhugula bondhugula force-pushed the uday/support_affine_walk_result branch from cba3731 to 3bbfe08 Compare December 28, 2023 10:53
@bondhugula
Copy link
Contributor Author

Yes, something that shows we interrupt and propagate correctly, I saw we have « code coverage » with the utility, but not « path coverage » for the new behavior I think?

I've now added a test pass to exercise this.

@bondhugula
Copy link
Contributor Author

We can also have a proper unit test here https://github.com/llvm/llvm-project/tree/main/mlir/unittests/IR. This looks like an API-level feature and having to roll a test pass for it may be too much of unwarranted complexity.

Makes sense, but I couldn't immediately see an obvious way to test the interrupt via the unit tests while it was easy via diagnostics. I can still create an equivalent unit test if the test pass looks too heavy (for e.g. increases build time etc. when comapred to unit tests).

@bondhugula bondhugula force-pushed the uday/support_affine_walk_result branch from 3bbfe08 to 38a05d4 Compare December 28, 2023 11:12
Support WalkResult for AffineExpr walk and support interrupting walks
along the lines of Operation::walk. This allows interrupted walks when a
condition is met. Also, switch from std::function to llvm::function_ref
for the walk function.
@bondhugula bondhugula force-pushed the uday/support_affine_walk_result branch from 38a05d4 to 32627e9 Compare January 4, 2024 09:55
@bondhugula bondhugula merged commit c1eef48 into llvm:main Jan 5, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:affine mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants