[MLIR] Support interrupting AffineExpr walks #74792

Conversation
@llvm/pr-subscribers-mlir-affine @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core

Author: Uday Bondhugula (bondhugula)

Changes: Support WalkResult for AffineExpr walk and support interrupting walks.

Full diff: https://github.com/llvm/llvm-project/pull/74792.diff

4 Files Affected:
diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 40e9d28ce5d3a..181a24472473a 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -14,6 +14,7 @@
#ifndef MLIR_IR_AFFINEEXPR_H
#define MLIR_IR_AFFINEEXPR_H
+#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/Hashing.h"
@@ -123,8 +124,19 @@ class AffineExpr {
/// Return true if the affine expression involves AffineSymbolExpr `position`.
bool isFunctionOfSymbol(unsigned position) const;
- /// Walk all of the AffineExpr's in this expression in postorder.
- void walk(std::function<void(AffineExpr)> callback) const;
+ /// Walk all of the AffineExpr's in this expression in postorder. This allows
+ /// a lambda walk function that can either return `void` or a WalkResult. With
+ /// a WalkResult, interrupting is supported.
+ template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ std::enable_if_t<std::is_same<RetT, void>::value, RetT>
+ walk(FnT &&callback) const {
+ return walk<void>(*this, callback);
+ }
+ template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
+ walk(FnT &&callback) const {
+ return walk<WalkResult>(*this, callback);
+ }
/// This method substitutes any uses of dimensions and symbols (e.g.
/// dim#0 with dimReplacements[0]) and returns the modified expression tree.
@@ -202,6 +214,15 @@ class AffineExpr {
protected:
ImplType *expr{nullptr};
+
+private:
+ /// A trampoline for the templated non-static AffineExpr::walk method to
+ /// dispatch lambda `callback`'s of either a void result type or a
+ /// WalkResult type. Walk all of the AffineExprs in `e` in postorder. Users
+ /// should use the regular (non-static) `walk` method.
+ template <typename WalkRetTy>
+ static WalkRetTy walk(AffineExpr e,
+ function_ref<WalkRetTy(AffineExpr)> callback);
};
/// Affine binary operation expression. An affine binary operation could be an
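To make the new API concrete, here is a minimal usage sketch (illustrative only, not part of the diff): the void overload visits every sub-expression, while the WalkResult overload can stop early.

#include "mlir/IR/AffineExpr.h"
using namespace mlir;

// Hypothetical helper: count sub-expressions with the void overload, then
// use the WalkResult overload to stop at the first modulo operation.
static bool countAndCheckForMod(AffineExpr expr, unsigned &count) {
  expr.walk([&](AffineExpr e) { ++count; });
  WalkResult result = expr.walk([](AffineExpr e) {
    return e.getKind() == AffineExprKind::Mod ? WalkResult::interrupt()
                                              : WalkResult::advance();
  });
  return result.wasInterrupted();
}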
diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index 2860e73c8f428..5b3663d1dea7e 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -30,6 +30,9 @@ namespace mlir {
/// functions in your class. This class is defined in terms of statically
/// resolved overloading, not virtual functions.
///
+/// The visitor is templated on its return type (`RetTy`). With a WalkResult
+/// return type, the visitor supports interrupting walks.
+///
/// For example, here is a visitor that counts the number of AffineDimExprs
/// in an AffineExpr.
///
@@ -65,7 +68,6 @@ namespace mlir {
/// virtual function call overhead. Defining and using an AffineExprVisitor is
/// just as efficient as having your own switch instruction over the instruction
/// opcode.
-
template <typename SubClass, typename RetTy>
class AffineExprVisitorBase {
public:
@@ -136,6 +138,8 @@ class AffineExprVisitorBase {
RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
};
+/// See documentation for AffineExprVisitorBase. This visitor supports
+/// interrupting walks when a `WalkResult` is used for `RetTy`.
template <typename SubClass, typename RetTy = void>
class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
//===--------------------------------------------------------------------===//
@@ -150,27 +154,52 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
switch (expr.getKind()) {
case AffineExprKind::Add: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
- walkOperandsPostOrder(binOpExpr);
+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
+ if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ walkOperandsPostOrder(binOpExpr);
+ }
return self->visitAddExpr(binOpExpr);
}
case AffineExprKind::Mul: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
- walkOperandsPostOrder(binOpExpr);
+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
+ if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ walkOperandsPostOrder(binOpExpr);
+ }
return self->visitMulExpr(binOpExpr);
}
case AffineExprKind::Mod: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
- walkOperandsPostOrder(binOpExpr);
+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
+ if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ walkOperandsPostOrder(binOpExpr);
+ }
return self->visitModExpr(binOpExpr);
}
case AffineExprKind::FloorDiv: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
- walkOperandsPostOrder(binOpExpr);
+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
+ if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ walkOperandsPostOrder(binOpExpr);
+ }
return self->visitFloorDivExpr(binOpExpr);
}
case AffineExprKind::CeilDiv: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
- walkOperandsPostOrder(binOpExpr);
+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
+ if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ walkOperandsPostOrder(binOpExpr);
+ }
return self->visitCeilDivExpr(binOpExpr);
}
case AffineExprKind::Constant:
@@ -186,8 +215,19 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
private:
// Walk the operands - each operand is itself walked in post order.
RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
- walkPostOrder(expr.getLHS());
- walkPostOrder(expr.getRHS());
+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
+ if (walkPostOrder(expr.getLHS()).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ walkPostOrder(expr.getLHS());
+ }
+ if constexpr (std::is_same<RetTy, WalkResult>::value) {
+      if (walkPostOrder(expr.getRHS()).wasInterrupted())
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ } else {
+ walkPostOrder(expr.getRHS());
+ }
}
};
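As an illustration of the templated return type (a hypothetical visitor, not from the patch): a subclass can return WalkResult from its visit methods, and walkPostOrder will propagate the interrupt.

// Hypothetical visitor that stops as soon as it sees any division.
struct DivFinder : public AffineExprVisitor<DivFinder, WalkResult> {
  WalkResult visitFloorDivExpr(AffineBinaryOpExpr expr) {
    return WalkResult::interrupt();
  }
  WalkResult visitCeilDivExpr(AffineBinaryOpExpr expr) {
    return WalkResult::interrupt();
  }
  // Other visit methods inherit the base-class defaults, which
  // default-construct WalkResult, i.e., advance.
};

static bool containsDiv(AffineExpr expr) {
  return DivFinder().walkPostOrder(expr).wasInterrupted();
}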
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 50a052fb8b74e..578d03c629285 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1561,22 +1561,21 @@ static LogicalResult getTileSizePos(
/// memref<4x?xf32, #map0> ==> memref<4x?x?xf32>
static bool
isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
- SmallVectorImpl<unsigned> &inMemrefTypeDynDims,
- MLIRContext *context) {
- bool isDynamicDim = false;
+ SmallVectorImpl<unsigned> &inMemrefTypeDynDims) {
AffineExpr expr = layoutMap.getResults()[dim];
// Check if affine expr of the dimension includes dynamic dimension of input
// memrefType.
- expr.walk([&inMemrefTypeDynDims, &isDynamicDim, &context](AffineExpr e) {
- if (isa<AffineDimExpr>(e)) {
- for (unsigned dm : inMemrefTypeDynDims) {
- if (e == getAffineDimExpr(dm, context)) {
- isDynamicDim = true;
- }
- }
- }
- });
- return isDynamicDim;
+ MLIRContext *context = layoutMap.getContext();
+ return expr
+ .walk([&](AffineExpr e) {
+ if (isa<AffineDimExpr>(e) &&
+ llvm::any_of(inMemrefTypeDynDims, [&](unsigned dim) {
+ return e == getAffineDimExpr(dim, context);
+ }))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ })
+ .wasInterrupted();
}
/// Create affine expr to calculate dimension size for a tiled-layout map.
@@ -1792,29 +1791,28 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
MLIRContext *context = memrefType.getContext();
for (unsigned d = 0; d < newRank; ++d) {
// Check if this dimension is dynamic.
- bool isDynDim =
- isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims, context);
- if (isDynDim) {
+ if (bool isDynDim =
+ isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims)) {
newShape[d] = ShapedType::kDynamic;
- } else {
- // The lower bound for the shape is always zero.
- std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
- // For a static memref and an affine map with no symbols, this is
- // always bounded. However, when we have symbols, we may not be able to
- // obtain a constant upper bound. Also, mapping to a negative space is
- // invalid for normalization.
- if (!ubConst.has_value() || *ubConst < 0) {
- LLVM_DEBUG(llvm::dbgs()
- << "can't normalize map due to unknown/invalid upper bound");
- return memrefType;
- }
- // If dimension of new memrefType is dynamic, the value is -1.
- newShape[d] = *ubConst + 1;
+ continue;
+ }
+ // The lower bound for the shape is always zero.
+ std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
+ // For a static memref and an affine map with no symbols, this is
+ // always bounded. However, when we have symbols, we may not be able to
+ // obtain a constant upper bound. Also, mapping to a negative space is
+ // invalid for normalization.
+ if (!ubConst.has_value() || *ubConst < 0) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "can't normalize map due to unknown/invalid upper bound");
+ return memrefType;
}
+ // If dimension of new memrefType is dynamic, the value is -1.
+ newShape[d] = *ubConst + 1;
}
// Create the new memref type after trivializing the old layout map.
- MemRefType newMemRefType =
+ auto newMemRefType =
MemRefType::Builder(memrefType)
.setShape(newShape)
.setLayout(AffineMapAttr::get(
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 038ceea286a36..a90b264a8edd2 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -26,22 +26,37 @@ MLIRContext *AffineExpr::getContext() const { return expr->context; }
AffineExprKind AffineExpr::getKind() const { return expr->kind; }
-/// Walk all of the AffineExprs in this subgraph in postorder.
-void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
- struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
- std::function<void(AffineExpr)> callback;
-
- AffineExprWalker(std::function<void(AffineExpr)> callback)
- : callback(std::move(callback)) {}
-
- void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
- void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
- void visitDimExpr(AffineDimExpr expr) { callback(expr); }
- void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
+/// Walk all of the AffineExprs in `e` in postorder. This is a private factory
+/// method to help handle lambda walk functions. Users should use the regular
+/// (non-static) `walk` method.
+template <typename WalkRetTy>
+WalkRetTy mlir::AffineExpr::walk(AffineExpr e,
+ function_ref<WalkRetTy(AffineExpr)> callback) {
+ struct AffineExprWalker
+ : public AffineExprVisitor<AffineExprWalker, WalkRetTy> {
+ function_ref<WalkRetTy(AffineExpr)> callback;
+
+ AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback)
+ : callback(callback) {}
+
+ WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
+ return callback(expr);
+ }
+ WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
+ return callback(expr);
+ }
+ WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); }
+ WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); }
};
- AffineExprWalker(std::move(callback)).walkPostOrder(*this);
+ return AffineExprWalker(callback).walkPostOrder(e);
}
+// Explicitly instantiate for the two supported return types.
+template void mlir::AffineExpr::walk(AffineExpr e,
+ function_ref<void(AffineExpr)> callback);
+template WalkResult
+mlir::AffineExpr::walk(AffineExpr e,
+ function_ref<WalkResult(AffineExpr)> callback);
// Dispatch affine expression construction based on kind.
AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
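The explicit instantiations above are the standard trick for keeping a template's definition out of the header: only the two supported callback signatures are instantiated in the .cpp file, so the header needs only the declaration. A generic sketch of the pattern (hypothetical names):

// convert.h
template <typename T> T convert(int x); // declaration only

// convert.cpp
template <typename T> T convert(int x) { return T(x); }
// Explicit instantiations: these two specializations are emitted here
// and are the only ones that will link.
template float convert<float>(int x);
template double convert<double>(int x);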
Force-pushed bd399fd to 230f468.
LG, but it would be better with some tests showing the interruption: do we have unit tests for AffineExpr?
It's being exercised by an affine utility - the change to it is included here, but that doesn't really prove it works; it only shows that compilation (of the compiler) succeeds. We'll need to think of a new unit test where the interrupt emits a diagnostic that is checked. Is this what you had in mind?
Yes, something that shows we interrupt and propagate correctly. I saw we have « code coverage » with the utility, but not « path coverage » for the new behavior, I think?
That's right - I think we'll need a test pass. I can think of a simple utility, e.g., "check whether an affine expression has a modulo in it" and emit a diagnostic. There should be just one diagnostic emitted as a result, even in the presence of multiple matches.
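For concreteness, a rough sketch of the utility being proposed (a hypothetical helper; the actual test pass may differ): interrupt on the first mod so that at most one diagnostic is emitted per map.

// Emit a single remark if any result expression of `map` contains a mod.
// Interrupting on the first match guarantees one diagnostic even when
// several sub-expressions are modulo operations.
static void checkForMod(Operation *op, AffineMap map) {
  for (AffineExpr expr : map.getResults()) {
    WalkResult result = expr.walk([](AffineExpr e) {
      return e.getKind() == AffineExprKind::Mod ? WalkResult::interrupt()
                                                : WalkResult::advance();
    });
    if (result.wasInterrupted()) {
      op->emitRemark("map has a mod expression");
      return;
    }
  }
}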
We can also have a proper unit test here: https://github.com/llvm/llvm-project/tree/main/mlir/unittests/IR. This looks like an API-level feature, and having to roll a test pass for it may be unwarranted complexity.
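For reference, a rough sketch of what such a unit test might look like (hypothetical; the PR ultimately went with a test pass instead):

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/MLIRContext.h"
#include "gtest/gtest.h"

using namespace mlir;

TEST(AffineExprTest, WalkInterrupt) {
  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  AffineExpr d1 = getAffineDimExpr(1, &ctx);
  AffineExpr expr = d0 + d1 % 4; // a full postorder visits 5 sub-expressions

  unsigned visited = 0;
  WalkResult result = expr.walk([&](AffineExpr e) {
    ++visited;
    return e.getKind() == AffineExprKind::Mod ? WalkResult::interrupt()
                                              : WalkResult::advance();
  });
  // The walk reports the interrupt and skips everything after the mod.
  EXPECT_TRUE(result.wasInterrupted());
  EXPECT_LT(visited, 5u);
}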
Force-pushed 230f468 to cba3731.
Force-pushed cba3731 to 3bbfe08.
I've now added a test pass to exercise this.
Makes sense, but I couldn't immediately see an obvious way to test the interrupt via the unit tests, while it was easy via diagnostics. I can still create an equivalent unit test if the test pass looks too heavy (e.g., if it increases build time compared to unit tests).
Force-pushed 3bbfe08 to 38a05d4.
Force-pushed 38a05d4 to 32627e9.
Support WalkResult for AffineExpr walk and support interrupting walks
along the lines of Operation::walk. This allows interrupted walks when a
condition is met. Also, switch from std::function to llvm::function_ref
for the walk function.
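A brief note on the std::function to llvm::function_ref switch mentioned above: function_ref is a non-owning view of a callable, so the walk entry points can take lambdas without type-erasing copies or possible heap allocation; the caveat is that a function_ref must not outlive the callable it refers to. A tiny illustrative sketch (hypothetical names):

#include "llvm/ADT/STLExtras.h"
#include <functional>

// std::function owns (and may heap-allocate) a copy of the callable;
// llvm::function_ref just points at it, so it is cheap to pass by value
// but must not outlive the callable it references.
static int applyOwned(const std::function<int(int)> &f) { return f(1); }
static int applyViewed(llvm::function_ref<int(int)> f) { return f(1); }

static int demo() {
  auto inc = [](int x) { return x + 1; };
  return applyOwned(inc) + applyViewed(inc); // both return 2
}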