Skip to content

Commit 32627e9

Browse files
committed
[MLIR] Support interrupting AffineExpr walks
Support WalkResult for AffineExpr walk and support interrupting walks along the lines of Operation::walk. This allows interrupted walks when a condition is met. Also, switch from std::function to llvm::function_ref for the walk function.
1 parent 5ed11e7 commit 32627e9

File tree

8 files changed

+196
-59
lines changed

8 files changed

+196
-59
lines changed

mlir/include/mlir/IR/AffineExpr.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef MLIR_IR_AFFINEEXPR_H
1515
#define MLIR_IR_AFFINEEXPR_H
1616

17+
#include "mlir/IR/Visitors.h"
1718
#include "mlir/Support/LLVM.h"
1819
#include "llvm/ADT/DenseMapInfo.h"
1920
#include "llvm/ADT/Hashing.h"
@@ -123,8 +124,13 @@ class AffineExpr {
123124
/// Return true if the affine expression involves AffineSymbolExpr `position`.
124125
bool isFunctionOfSymbol(unsigned position) const;
125126

126-
/// Walk all of the AffineExpr's in this expression in postorder.
127-
void walk(std::function<void(AffineExpr)> callback) const;
127+
/// Walk all of the AffineExpr's in this expression in postorder. This allows
128+
/// a lambda walk function that can either return `void` or a WalkResult. With
129+
/// a WalkResult, interrupting is supported.
130+
template <typename FnT, typename RetT = detail::walkResultType<FnT>>
131+
RetT walk(FnT &&callback) const {
132+
return walk<RetT>(*this, callback);
133+
}
128134

129135
/// This method substitutes any uses of dimensions and symbols (e.g.
130136
/// dim#0 with dimReplacements[0]) and returns the modified expression tree.
@@ -202,6 +208,15 @@ class AffineExpr {
202208

203209
protected:
204210
ImplType *expr{nullptr};
211+
212+
private:
213+
/// A trampoline for the templated non-static AffineExpr::walk method to
214+
/// dispatch lambda `callback`'s of either a void result type or a
215+
/// WalkResult type. Walk all of the AffineExprs in `e` in postorder. Users
216+
/// should use the regular (non-static) `walk` method.
217+
template <typename WalkRetTy>
218+
static WalkRetTy walk(AffineExpr e,
219+
function_ref<WalkRetTy(AffineExpr)> callback);
205220
};
206221

207222
/// Affine binary operation expression. An affine binary operation could be an

mlir/include/mlir/IR/AffineExprVisitor.h

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ namespace mlir {
3030
/// functions in your class. This class is defined in terms of statically
3131
/// resolved overloading, not virtual functions.
3232
///
33+
/// The visitor is templated on its return type (`RetTy`). With a WalkResult
34+
/// return type, the visitor supports interrupting walks.
35+
///
3336
/// For example, here is a visitor that counts the number of for AffineDimExprs
3437
/// in an AffineExpr.
3538
///
@@ -65,7 +68,6 @@ namespace mlir {
6568
/// virtual function call overhead. Defining and using a AffineExprVisitor is
6669
/// just as efficient as having your own switch instruction over the instruction
6770
/// opcode.
68-
6971
template <typename SubClass, typename RetTy>
7072
class AffineExprVisitorBase {
7173
public:
@@ -136,6 +138,8 @@ class AffineExprVisitorBase {
136138
RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
137139
};
138140

141+
/// See documentation for AffineExprVisitorBase. This visitor supports
142+
/// interrupting walks when a `WalkResult` is used for `RetTy`.
139143
template <typename SubClass, typename RetTy = void>
140144
class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
141145
//===--------------------------------------------------------------------===//
@@ -150,27 +154,52 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
150154
switch (expr.getKind()) {
151155
case AffineExprKind::Add: {
152156
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
153-
walkOperandsPostOrder(binOpExpr);
157+
if constexpr (std::is_same<RetTy, WalkResult>::value) {
158+
if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
159+
return WalkResult::interrupt();
160+
} else {
161+
walkOperandsPostOrder(binOpExpr);
162+
}
154163
return self->visitAddExpr(binOpExpr);
155164
}
156165
case AffineExprKind::Mul: {
157166
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
158-
walkOperandsPostOrder(binOpExpr);
167+
if constexpr (std::is_same<RetTy, WalkResult>::value) {
168+
if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
169+
return WalkResult::interrupt();
170+
} else {
171+
walkOperandsPostOrder(binOpExpr);
172+
}
159173
return self->visitMulExpr(binOpExpr);
160174
}
161175
case AffineExprKind::Mod: {
162176
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
163-
walkOperandsPostOrder(binOpExpr);
177+
if constexpr (std::is_same<RetTy, WalkResult>::value) {
178+
if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
179+
return WalkResult::interrupt();
180+
} else {
181+
walkOperandsPostOrder(binOpExpr);
182+
}
164183
return self->visitModExpr(binOpExpr);
165184
}
166185
case AffineExprKind::FloorDiv: {
167186
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
168-
walkOperandsPostOrder(binOpExpr);
187+
if constexpr (std::is_same<RetTy, WalkResult>::value) {
188+
if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
189+
return WalkResult::interrupt();
190+
} else {
191+
walkOperandsPostOrder(binOpExpr);
192+
}
169193
return self->visitFloorDivExpr(binOpExpr);
170194
}
171195
case AffineExprKind::CeilDiv: {
172196
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
173-
walkOperandsPostOrder(binOpExpr);
197+
if constexpr (std::is_same<RetTy, WalkResult>::value) {
198+
if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
199+
return WalkResult::interrupt();
200+
} else {
201+
walkOperandsPostOrder(binOpExpr);
202+
}
174203
return self->visitCeilDivExpr(binOpExpr);
175204
}
176205
case AffineExprKind::Constant:
@@ -186,8 +215,19 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
186215
private:
187216
// Walk the operands - each operand is itself walked in post order.
188217
RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
189-
walkPostOrder(expr.getLHS());
190-
walkPostOrder(expr.getRHS());
218+
if constexpr (std::is_same<RetTy, WalkResult>::value) {
219+
if (walkPostOrder(expr.getLHS()).wasInterrupted())
220+
return WalkResult::interrupt();
221+
} else {
222+
walkPostOrder(expr.getLHS());
223+
}
224+
if constexpr (std::is_same<RetTy, WalkResult>::value) {
225+
if (walkPostOrder(expr.getLHS()).wasInterrupted())
226+
return WalkResult::interrupt();
227+
return WalkResult::advance();
228+
} else {
229+
return walkPostOrder(expr.getRHS());
230+
}
191231
}
192232
};
193233

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1561,22 +1561,21 @@ static LogicalResult getTileSizePos(
15611561
/// memref<4x?xf32, #map0> ==> memref<4x?x?xf32>
15621562
static bool
15631563
isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
1564-
SmallVectorImpl<unsigned> &inMemrefTypeDynDims,
1565-
MLIRContext *context) {
1566-
bool isDynamicDim = false;
1564+
SmallVectorImpl<unsigned> &inMemrefTypeDynDims) {
15671565
AffineExpr expr = layoutMap.getResults()[dim];
15681566
// Check if affine expr of the dimension includes dynamic dimension of input
15691567
// memrefType.
1570-
expr.walk([&inMemrefTypeDynDims, &isDynamicDim, &context](AffineExpr e) {
1571-
if (isa<AffineDimExpr>(e)) {
1572-
for (unsigned dm : inMemrefTypeDynDims) {
1573-
if (e == getAffineDimExpr(dm, context)) {
1574-
isDynamicDim = true;
1575-
}
1576-
}
1577-
}
1578-
});
1579-
return isDynamicDim;
1568+
MLIRContext *context = layoutMap.getContext();
1569+
return expr
1570+
.walk([&](AffineExpr e) {
1571+
if (isa<AffineDimExpr>(e) &&
1572+
llvm::any_of(inMemrefTypeDynDims, [&](unsigned dim) {
1573+
return e == getAffineDimExpr(dim, context);
1574+
}))
1575+
return WalkResult::interrupt();
1576+
return WalkResult::advance();
1577+
})
1578+
.wasInterrupted();
15801579
}
15811580

15821581
/// Create affine expr to calculate dimension size for a tiled-layout map.
@@ -1792,29 +1791,28 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
17921791
MLIRContext *context = memrefType.getContext();
17931792
for (unsigned d = 0; d < newRank; ++d) {
17941793
// Check if this dimension is dynamic.
1795-
bool isDynDim =
1796-
isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims, context);
1797-
if (isDynDim) {
1794+
if (bool isDynDim =
1795+
isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims)) {
17981796
newShape[d] = ShapedType::kDynamic;
1799-
} else {
1800-
// The lower bound for the shape is always zero.
1801-
std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
1802-
// For a static memref and an affine map with no symbols, this is
1803-
// always bounded. However, when we have symbols, we may not be able to
1804-
// obtain a constant upper bound. Also, mapping to a negative space is
1805-
// invalid for normalization.
1806-
if (!ubConst.has_value() || *ubConst < 0) {
1807-
LLVM_DEBUG(llvm::dbgs()
1808-
<< "can't normalize map due to unknown/invalid upper bound");
1809-
return memrefType;
1810-
}
1811-
// If dimension of new memrefType is dynamic, the value is -1.
1812-
newShape[d] = *ubConst + 1;
1797+
continue;
1798+
}
1799+
// The lower bound for the shape is always zero.
1800+
std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
1801+
// For a static memref and an affine map with no symbols, this is
1802+
// always bounded. However, when we have symbols, we may not be able to
1803+
// obtain a constant upper bound. Also, mapping to a negative space is
1804+
// invalid for normalization.
1805+
if (!ubConst.has_value() || *ubConst < 0) {
1806+
LLVM_DEBUG(llvm::dbgs()
1807+
<< "can't normalize map due to unknown/invalid upper bound");
1808+
return memrefType;
18131809
}
1810+
// If dimension of new memrefType is dynamic, the value is -1.
1811+
newShape[d] = *ubConst + 1;
18141812
}
18151813

18161814
// Create the new memref type after trivializing the old layout map.
1817-
MemRefType newMemRefType =
1815+
auto newMemRefType =
18181816
MemRefType::Builder(memrefType)
18191817
.setShape(newShape)
18201818
.setLayout(AffineMapAttr::get(

mlir/lib/IR/AffineExpr.cpp

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,37 @@ MLIRContext *AffineExpr::getContext() const { return expr->context; }
2626

2727
AffineExprKind AffineExpr::getKind() const { return expr->kind; }
2828

29-
/// Walk all of the AffineExprs in this subgraph in postorder.
30-
void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
31-
struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
32-
std::function<void(AffineExpr)> callback;
33-
34-
AffineExprWalker(std::function<void(AffineExpr)> callback)
35-
: callback(std::move(callback)) {}
36-
37-
void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
38-
void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
39-
void visitDimExpr(AffineDimExpr expr) { callback(expr); }
40-
void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
29+
/// Walk all of the AffineExprs in `e` in postorder. This is a private factory
30+
/// method to help handle lambda walk functions. Users should use the regular
31+
/// (non-static) `walk` method.
32+
template <typename WalkRetTy>
33+
WalkRetTy mlir::AffineExpr::walk(AffineExpr e,
34+
function_ref<WalkRetTy(AffineExpr)> callback) {
35+
struct AffineExprWalker
36+
: public AffineExprVisitor<AffineExprWalker, WalkRetTy> {
37+
function_ref<WalkRetTy(AffineExpr)> callback;
38+
39+
AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback)
40+
: callback(callback) {}
41+
42+
WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
43+
return callback(expr);
44+
}
45+
WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
46+
return callback(expr);
47+
}
48+
WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); }
49+
WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); }
4150
};
4251

43-
AffineExprWalker(std::move(callback)).walkPostOrder(*this);
52+
return AffineExprWalker(callback).walkPostOrder(e);
4453
}
54+
// Explicitly instantiate for the two supported return types.
55+
template void mlir::AffineExpr::walk(AffineExpr e,
56+
function_ref<void(AffineExpr)> callback);
57+
template WalkResult
58+
mlir::AffineExpr::walk(AffineExpr e,
59+
function_ref<WalkResult(AffineExpr)> callback);
4560

4661
// Dispatch affine expression construction based on kind.
4762
AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,

mlir/test/IR/affine-walk.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: mlir-opt -test-affine-walk -verify-diagnostics %s
2+
3+
// Test affine walk interrupt. A remark should be printed only for the first mod
4+
// expression encountered in post order.
5+
6+
#map = affine_map<(i, j) -> ((i mod 4) mod 2, j)>
7+
8+
"test.check_first_mod"() {"map" = #map} : () -> ()
9+
// expected-remark@-1 {{mod expression}}

mlir/test/lib/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Exclude tests from libMLIR.so
22
add_mlir_library(MLIRTestIR
3+
TestAffineWalk.cpp
34
TestBytecodeRoundtrip.cpp
45
TestBuiltinAttributeInterfaces.cpp
56
TestBuiltinDistinctAttributes.cpp

mlir/test/lib/IR/TestAffineWalk.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
//===- TestAffineWalk.cpp - Pass to test affine walks
2+
//----------------------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include "mlir/Pass/Pass.h"
11+
12+
#include "mlir/IR/BuiltinOps.h"
13+
14+
using namespace mlir;
15+
16+
namespace {
17+
/// A test pass for verifying walk interrupts.
18+
struct TestAffineWalk
19+
: public PassWrapper<TestAffineWalk, OperationPass<ModuleOp>> {
20+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAffineWalk)
21+
22+
void runOnOperation() override;
23+
StringRef getArgument() const final { return "test-affine-walk"; }
24+
StringRef getDescription() const final { return "Test affine walk method."; }
25+
};
26+
} // namespace
27+
28+
/// Emits a remark for the first `map`'s result expression that contains a
29+
/// mod expression.
30+
static void checkMod(AffineMap map, Location loc) {
31+
for (AffineExpr e : map.getResults()) {
32+
e.walk([&](AffineExpr s) {
33+
if (s.getKind() == mlir::AffineExprKind::Mod) {
34+
emitRemark(loc, "mod expression: ");
35+
return WalkResult::interrupt();
36+
}
37+
return WalkResult::advance();
38+
});
39+
}
40+
}
41+
42+
void TestAffineWalk::runOnOperation() {
43+
auto m = getOperation();
44+
// Test whether the walk is being correctly interrupted.
45+
m.walk([](Operation *op) {
46+
for (NamedAttribute attr : op->getAttrs()) {
47+
auto mapAttr = attr.getValue().dyn_cast<AffineMapAttr>();
48+
if (!mapAttr)
49+
return;
50+
checkMod(mapAttr.getAffineMap(), op->getLoc());
51+
}
52+
});
53+
}
54+
55+
namespace mlir {
56+
void registerTestAffineWalk() { PassRegistration<TestAffineWalk>(); }
57+
} // namespace mlir

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,12 @@ void registerSymbolTestPasses();
4444
void registerRegionTestPasses();
4545
void registerTestAffineDataCopyPass();
4646
void registerTestAffineReifyValueBoundsPass();
47+
void registerTestAffineLoopUnswitchingPass();
48+
void registerTestAffineWalk();
4749
void registerTestBytecodeRoundtripPasses();
4850
void registerTestDecomposeAffineOpPass();
49-
void registerTestAffineLoopUnswitchingPass();
50-
void registerTestGpuLoweringPasses();
5151
void registerTestFunc();
52+
void registerTestGpuLoweringPasses();
5253
void registerTestGpuMemoryPromotionPass();
5354
void registerTestLoopPermutationPass();
5455
void registerTestMatchers();
@@ -167,12 +168,13 @@ void registerTestPasses() {
167168
registerSymbolTestPasses();
168169
registerRegionTestPasses();
169170
registerTestAffineDataCopyPass();
170-
registerTestAffineReifyValueBoundsPass();
171-
registerTestDecomposeAffineOpPass();
172171
registerTestAffineLoopUnswitchingPass();
173-
registerTestGpuLoweringPasses();
172+
registerTestAffineReifyValueBoundsPass();
173+
registerTestAffineWalk();
174174
registerTestBytecodeRoundtripPasses();
175+
registerTestDecomposeAffineOpPass();
175176
registerTestFunc();
177+
registerTestGpuLoweringPasses();
176178
registerTestGpuMemoryPromotionPass();
177179
registerTestLoopPermutationPass();
178180
registerTestMatchers();

0 commit comments

Comments
 (0)