
[MLIR] Support interrupting AffineExpr walks #74792

Merged 1 commit on Jan 5, 2024

19 changes: 17 additions & 2 deletions mlir/include/mlir/IR/AffineExpr.h
@@ -14,6 +14,7 @@
#ifndef MLIR_IR_AFFINEEXPR_H
#define MLIR_IR_AFFINEEXPR_H

#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/Hashing.h"
@@ -123,8 +124,13 @@ class AffineExpr {
/// Return true if the affine expression involves AffineSymbolExpr `position`.
bool isFunctionOfSymbol(unsigned position) const;

/// Walk all of the AffineExpr's in this expression in postorder.
void walk(std::function<void(AffineExpr)> callback) const;
/// Walk all of the AffineExprs in this expression in postorder. The callback
/// can return either `void` or a `WalkResult`; returning a `WalkResult` makes
/// the walk interruptible.
template <typename FnT, typename RetT = detail::walkResultType<FnT>>
RetT walk(FnT &&callback) const {
return walk<RetT>(*this, callback);
}

/// This method substitutes any uses of dimensions and symbols (e.g.
/// dim#0 with dimReplacements[0]) and returns the modified expression tree.
@@ -202,6 +208,15 @@ class AffineExpr {

protected:
ImplType *expr{nullptr};

private:
/// A trampoline for the templated non-static AffineExpr::walk method to
/// dispatch lambda `callback`s that return either `void` or a `WalkResult`.
/// Walks all of the AffineExprs in `e` in postorder. Users
/// should use the regular (non-static) `walk` method.
template <typename WalkRetTy>
static WalkRetTy walk(AffineExpr e,
function_ref<WalkRetTy(AffineExpr)> callback);
};

/// Affine binary operation expression. An affine binary operation could be an
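For illustration (not part of the diff), here is a minimal sketch of how the new interruptible walk can be used; the helper name `exprContainsMod` is hypothetical, and the pattern mirrors the test pass added later in this change:

// Sketch only: stop at the first mod expression found in postorder.
#include "mlir/IR/AffineExpr.h"

using namespace mlir;

static bool exprContainsMod(AffineExpr expr) {
  // A callback returning WalkResult makes the walk interruptible; a callback
  // returning void still works and visits every sub-expression.
  WalkResult result = expr.walk([](AffineExpr e) {
    if (e.getKind() == AffineExprKind::Mod)
      return WalkResult::interrupt();
    return WalkResult::advance();
  });
  return result.wasInterrupted();
}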
56 changes: 48 additions & 8 deletions mlir/include/mlir/IR/AffineExprVisitor.h
@@ -30,6 +30,9 @@ namespace mlir {
/// functions in your class. This class is defined in terms of statically
/// resolved overloading, not virtual functions.
///
/// The visitor is templated on its return type (`RetTy`). With a WalkResult
/// return type, the visitor supports interrupting walks.
///
/// For example, here is a visitor that counts the number of AffineDimExprs
/// in an AffineExpr.
///
@@ -65,7 +68,6 @@ namespace mlir {
/// virtual function call overhead. Defining and using a AffineExprVisitor is
/// just as efficient as having your own switch instruction over the instruction
/// opcode.

template <typename SubClass, typename RetTy>
class AffineExprVisitorBase {
public:
@@ -136,6 +138,8 @@ class AffineExprVisitorBase {
RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
};

/// See documentation for AffineExprVisitorBase. This visitor supports
/// interrupting walks when a `WalkResult` is used for `RetTy`.
template <typename SubClass, typename RetTy = void>
class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
//===--------------------------------------------------------------------===//
@@ -150,27 +154,52 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
switch (expr.getKind()) {
case AffineExprKind::Add: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
walkOperandsPostOrder(binOpExpr);
if constexpr (std::is_same<RetTy, WalkResult>::value) {
if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
return WalkResult::interrupt();
} else {
walkOperandsPostOrder(binOpExpr);
}
return self->visitAddExpr(binOpExpr);
}
case AffineExprKind::Mul: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
walkOperandsPostOrder(binOpExpr);
if constexpr (std::is_same<RetTy, WalkResult>::value) {
if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
return WalkResult::interrupt();
} else {
walkOperandsPostOrder(binOpExpr);
}
return self->visitMulExpr(binOpExpr);
}
case AffineExprKind::Mod: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
walkOperandsPostOrder(binOpExpr);
if constexpr (std::is_same<RetTy, WalkResult>::value) {
if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
return WalkResult::interrupt();
} else {
walkOperandsPostOrder(binOpExpr);
}
return self->visitModExpr(binOpExpr);
}
case AffineExprKind::FloorDiv: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
walkOperandsPostOrder(binOpExpr);
if constexpr (std::is_same<RetTy, WalkResult>::value) {
if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
return WalkResult::interrupt();
} else {
walkOperandsPostOrder(binOpExpr);
}
return self->visitFloorDivExpr(binOpExpr);
}
case AffineExprKind::CeilDiv: {
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
walkOperandsPostOrder(binOpExpr);
if constexpr (std::is_same<RetTy, WalkResult>::value) {
if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
return WalkResult::interrupt();
} else {
walkOperandsPostOrder(binOpExpr);
}
return self->visitCeilDivExpr(binOpExpr);
}
case AffineExprKind::Constant:
Expand All @@ -186,8 +215,19 @@ class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
private:
// Walk the operands - each operand is itself walked in post order.
RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
walkPostOrder(expr.getLHS());
walkPostOrder(expr.getRHS());
if constexpr (std::is_same<RetTy, WalkResult>::value) {
if (walkPostOrder(expr.getLHS()).wasInterrupted())
return WalkResult::interrupt();
} else {
walkPostOrder(expr.getLHS());
}
if constexpr (std::is_same<RetTy, WalkResult>::value) {
if (walkPostOrder(expr.getRHS()).wasInterrupted())
return WalkResult::interrupt();
return WalkResult::advance();
} else {
return walkPostOrder(expr.getRHS());
}
}
};

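When writing a visitor directly against AffineExprVisitor, the same mechanism applies by instantiating the visitor with `WalkResult` as `RetTy`. A minimal sketch (the `ModFinder` name is hypothetical), assuming the default visit methods keep returning `WalkResult()`, i.e. advance:

// Sketch only: a visitor that interrupts at the first mod expression.
#include "mlir/IR/AffineExprVisitor.h"

using namespace mlir;

struct ModFinder : public AffineExprVisitor<ModFinder, WalkResult> {
  // Interrupt as soon as a mod expression is visited; every other visit
  // method falls back to the base-class default, which advances.
  WalkResult visitModExpr(AffineBinaryOpExpr expr) {
    return WalkResult::interrupt();
  }
};

static bool hasMod(AffineExpr expr) {
  return ModFinder().walkPostOrder(expr).wasInterrupted();
}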
60 changes: 29 additions & 31 deletions mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1561,22 +1561,21 @@ static LogicalResult getTileSizePos(
/// memref<4x?xf32, #map0> ==> memref<4x?x?xf32>
static bool
isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
SmallVectorImpl<unsigned> &inMemrefTypeDynDims,
MLIRContext *context) {
bool isDynamicDim = false;
SmallVectorImpl<unsigned> &inMemrefTypeDynDims) {
AffineExpr expr = layoutMap.getResults()[dim];
// Check if affine expr of the dimension includes dynamic dimension of input
// memrefType.
expr.walk([&inMemrefTypeDynDims, &isDynamicDim, &context](AffineExpr e) {
if (isa<AffineDimExpr>(e)) {
for (unsigned dm : inMemrefTypeDynDims) {
if (e == getAffineDimExpr(dm, context)) {
isDynamicDim = true;
}
}
}
});
return isDynamicDim;
MLIRContext *context = layoutMap.getContext();
return expr
.walk([&](AffineExpr e) {
if (isa<AffineDimExpr>(e) &&
llvm::any_of(inMemrefTypeDynDims, [&](unsigned dim) {
return e == getAffineDimExpr(dim, context);
}))
return WalkResult::interrupt();
return WalkResult::advance();
})
.wasInterrupted();
}

/// Create affine expr to calculate dimension size for a tiled-layout map.
@@ -1792,29 +1791,28 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
MLIRContext *context = memrefType.getContext();
for (unsigned d = 0; d < newRank; ++d) {
// Check if this dimension is dynamic.
bool isDynDim =
isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims, context);
if (isDynDim) {
if (bool isDynDim =
isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims)) {
newShape[d] = ShapedType::kDynamic;
} else {
// The lower bound for the shape is always zero.
std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
// For a static memref and an affine map with no symbols, this is
// always bounded. However, when we have symbols, we may not be able to
// obtain a constant upper bound. Also, mapping to a negative space is
// invalid for normalization.
if (!ubConst.has_value() || *ubConst < 0) {
LLVM_DEBUG(llvm::dbgs()
<< "can't normalize map due to unknown/invalid upper bound");
return memrefType;
}
// If dimension of new memrefType is dynamic, the value is -1.
newShape[d] = *ubConst + 1;
continue;
}
// The lower bound for the shape is always zero.
std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
// For a static memref and an affine map with no symbols, this is
// always bounded. However, when we have symbols, we may not be able to
// obtain a constant upper bound. Also, mapping to a negative space is
// invalid for normalization.
if (!ubConst.has_value() || *ubConst < 0) {
LLVM_DEBUG(llvm::dbgs()
<< "can't normalize map due to unknown/invalid upper bound");
return memrefType;
}
// If dimension of new memrefType is dynamic, the value is -1.
newShape[d] = *ubConst + 1;
}

// Create the new memref type after trivializing the old layout map.
MemRefType newMemRefType =
auto newMemRefType =
MemRefType::Builder(memrefType)
.setShape(newShape)
.setLayout(AffineMapAttr::get(
41 changes: 28 additions & 13 deletions mlir/lib/IR/AffineExpr.cpp
@@ -26,22 +26,37 @@ MLIRContext *AffineExpr::getContext() const { return expr->context; }

AffineExprKind AffineExpr::getKind() const { return expr->kind; }

/// Walk all of the AffineExprs in this subgraph in postorder.
void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
std::function<void(AffineExpr)> callback;

AffineExprWalker(std::function<void(AffineExpr)> callback)
: callback(std::move(callback)) {}

void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
void visitDimExpr(AffineDimExpr expr) { callback(expr); }
void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
/// Walk all of the AffineExprs in `e` in postorder. This is a private
/// trampoline that helps dispatch lambda walk callbacks. Users should use the
/// regular (non-static) `walk` method.
template <typename WalkRetTy>
WalkRetTy mlir::AffineExpr::walk(AffineExpr e,
function_ref<WalkRetTy(AffineExpr)> callback) {
struct AffineExprWalker
: public AffineExprVisitor<AffineExprWalker, WalkRetTy> {
function_ref<WalkRetTy(AffineExpr)> callback;

AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback)
: callback(callback) {}

WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
return callback(expr);
}
WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
return callback(expr);
}
WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); }
WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); }
};

AffineExprWalker(std::move(callback)).walkPostOrder(*this);
return AffineExprWalker(callback).walkPostOrder(e);
}
// Explicitly instantiate for the two supported return types.
template void mlir::AffineExpr::walk(AffineExpr e,
function_ref<void(AffineExpr)> callback);
template WalkResult
mlir::AffineExpr::walk(AffineExpr e,
function_ref<WalkResult(AffineExpr)> callback);

// Dispatch affine expression construction based on kind.
AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
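The header-side `walk` stays a thin template that forwards to this static trampoline, and the two explicit instantiations are the only definitions callers can link against. Below is a standalone sketch of that idiom, using hypothetical `Node`/`Result` types and plain C++ in place of the MLIR classes:

// Standalone sketch of the trampoline + explicit-instantiation idiom; the
// names here are hypothetical and unrelated to MLIR.
#include <cstdio>
#include <functional>
#include <type_traits>
#include <utility>

struct Result {
  bool interrupted = false;
};

struct Node {
  int value = 0;

  // Public entry point: deduces the callback's return type and forwards to
  // the out-of-line trampoline below.
  template <typename FnT,
            typename RetT = std::invoke_result_t<FnT, const Node &>>
  RetT walk(FnT &&cb) const {
    return walkImpl<RetT>(*this, std::forward<FnT>(cb));
  }

private:
  // Declared here, defined once below; the explicit instantiations limit the
  // supported callback return types to void and Result.
  template <typename RetT>
  static RetT walkImpl(const Node &n, std::function<RetT(const Node &)> cb);
};

template <typename RetT>
RetT Node::walkImpl(const Node &n, std::function<RetT(const Node &)> cb) {
  // A real walker would recurse over sub-expressions before invoking the
  // callback on the current node.
  return cb(n);
}

// The two supported instantiations, mirroring the void/WalkResult pair above.
template void Node::walkImpl(const Node &, std::function<void(const Node &)>);
template Result Node::walkImpl(const Node &,
                               std::function<Result(const Node &)>);

int main() {
  Node n{42};
  n.walk([](const Node &m) { std::printf("visited %d\n", m.value); });
  Result r = n.walk([](const Node &) { return Result{true}; });
  std::printf("interrupted: %d\n", r.interrupted);
  return 0;
}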
9 changes: 9 additions & 0 deletions mlir/test/IR/affine-walk.mlir
@@ -0,0 +1,9 @@
// RUN: mlir-opt -test-affine-walk -verify-diagnostics %s

// Test affine walk interrupt. A remark should be printed only for the first mod
// expression encountered in post order.

#map = affine_map<(i, j) -> ((i mod 4) mod 2, j)>

"test.check_first_mod"() {"map" = #map} : () -> ()
// expected-remark@-1 {{mod expression}}
1 change: 1 addition & 0 deletions mlir/test/lib/IR/CMakeLists.txt
@@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestIR
TestAffineWalk.cpp
TestBytecodeRoundtrip.cpp
TestBuiltinAttributeInterfaces.cpp
TestBuiltinDistinctAttributes.cpp
57 changes: 57 additions & 0 deletions mlir/test/lib/IR/TestAffineWalk.cpp
@@ -0,0 +1,57 @@
//===- TestAffineWalk.cpp - Pass to test affine walks --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Pass/Pass.h"

#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

namespace {
/// A test pass for verifying walk interrupts.
struct TestAffineWalk
: public PassWrapper<TestAffineWalk, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAffineWalk)

void runOnOperation() override;
StringRef getArgument() const final { return "test-affine-walk"; }
StringRef getDescription() const final { return "Test affine walk method."; }
};
} // namespace

/// Emits a remark at the first mod expression (if any) encountered in each of
/// `map`'s result expressions; the walk is interrupted after the first match.
static void checkMod(AffineMap map, Location loc) {
for (AffineExpr e : map.getResults()) {
e.walk([&](AffineExpr s) {
if (s.getKind() == mlir::AffineExprKind::Mod) {
emitRemark(loc, "mod expression: ");
return WalkResult::interrupt();
}
return WalkResult::advance();
});
}
}

void TestAffineWalk::runOnOperation() {
auto m = getOperation();
// Test whether the walk is being correctly interrupted.
m.walk([](Operation *op) {
for (NamedAttribute attr : op->getAttrs()) {
auto mapAttr = attr.getValue().dyn_cast<AffineMapAttr>();
if (!mapAttr)
return;
checkMod(mapAttr.getAffineMap(), op->getLoc());
}
});
}

namespace mlir {
void registerTestAffineWalk() { PassRegistration<TestAffineWalk>(); }
} // namespace mlir
12 changes: 7 additions & 5 deletions mlir/tools/mlir-opt/mlir-opt.cpp
@@ -44,11 +44,12 @@ void registerSymbolTestPasses();
void registerRegionTestPasses();
void registerTestAffineDataCopyPass();
void registerTestAffineReifyValueBoundsPass();
void registerTestAffineLoopUnswitchingPass();
void registerTestAffineWalk();
void registerTestBytecodeRoundtripPasses();
void registerTestDecomposeAffineOpPass();
void registerTestAffineLoopUnswitchingPass();
void registerTestGpuLoweringPasses();
void registerTestFunc();
void registerTestGpuLoweringPasses();
void registerTestGpuMemoryPromotionPass();
void registerTestLoopPermutationPass();
void registerTestMatchers();
@@ -167,12 +168,13 @@ void registerTestPasses() {
registerSymbolTestPasses();
registerRegionTestPasses();
registerTestAffineDataCopyPass();
registerTestAffineReifyValueBoundsPass();
registerTestDecomposeAffineOpPass();
registerTestAffineLoopUnswitchingPass();
registerTestGpuLoweringPasses();
registerTestAffineReifyValueBoundsPass();
registerTestAffineWalk();
registerTestBytecodeRoundtripPasses();
registerTestDecomposeAffineOpPass();
registerTestFunc();
registerTestGpuLoweringPasses();
registerTestGpuMemoryPromotionPass();
registerTestLoopPermutationPass();
registerTestMatchers();