Skip to content

[mlir] Add fast walk-based pattern rewrite driver #113825

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/docs/ActionTracing.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ An action can also carry arbitrary payload, for example we can extend the

```c++
/// A custom Action can be defined minimally by deriving from
/// `tracing::ActionImpl`. It can has any members!
/// `tracing::ActionImpl`. It can have any members!
class MyCustomAction : public tracing::ActionImpl<MyCustomAction> {
public:
using Base = tracing::ActionImpl<MyCustomAction>;
Expand Down
40 changes: 33 additions & 7 deletions mlir/docs/PatternRewriter.md
Original file line number Diff line number Diff line change
Expand Up @@ -320,15 +320,41 @@ conversion target, via a set of pattern-based operation rewriting patterns. This
framework also provides support for type conversions. More information on this
driver can be found [here](DialectConversion.md).

### Walk Pattern Rewrite Driver

This is a fast and simple driver that walks the given op and applies patterns
that locally have the most benefit. The benefit of a pattern is decided solely
by the benefit specified on the pattern, and the relative order of the pattern
within the pattern list (when two patterns have the same local benefit).

The driver performs a post-order traversal. Note that it walks regions of the
given op but does not visit the op.

This driver does not (re)visit modified or newly replaced ops, and does not
allow for progressive rewrites of the same op. Op and block erasure is only
supported for the currently matched op and its descendant. If your pattern
set requires these, consider using the Greedy Pattern Rewrite Driver instead,
at the expense of extra overhead.

This driver is exposed using the `walkAndApplyPatterns` function.

Note: This driver listens for IR changes via the callbacks provided by
`RewriterBase`. It is important that patterns announce all IR changes to the
rewriter and do not bypass the rewriter API by modifying ops directly.

#### Debugging

You can debug the Walk Pattern Rewrite Driver by passing the
`--debug-only=walk-rewriter` CLI flag. This will print the visited and matched
ops.

### Greedy Pattern Rewrite Driver

This driver processes ops in a worklist-driven fashion and greedily applies the
patterns that locally have the most benefit. The benefit of a pattern is decided
solely by the benefit specified on the pattern, and the relative order of the
pattern within the pattern list (when two patterns have the same local benefit).
Patterns are iteratively applied to operations until a fixed point is reached or
until the configurable maximum number of iterations exhausted, at which point
the driver finishes.
patterns that locally have the most benefit (same as the Walk Pattern Rewrite
Driver). Patterns are iteratively applied to operations until a fixed point is
reached or until the configurable maximum number of iterations exhausted, at
which point the driver finishes.

This driver comes in two fashions:

Expand Down Expand Up @@ -368,7 +394,7 @@ rewriter and do not bypass the rewriter API by modifying ops directly.
Note: This driver is the one used by the [canonicalization](Canonicalization.md)
[pass](Passes.md/#-canonicalize) in MLIR.

### Debugging
#### Debugging

To debug the execution of the greedy pattern rewrite driver,
`-debug-only=greedy-rewriter` may be used. This command line flag activates
Expand Down
28 changes: 17 additions & 11 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -461,54 +461,60 @@ class RewriterBase : public OpBuilder {
/// struct can be used as a base to create listener chains, so that multiple
/// listeners can be notified of IR changes.
struct ForwardingListener : public RewriterBase::Listener {
ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
ForwardingListener(OpBuilder::Listener *listener)
: listener(listener),
rewriteListener(
dyn_cast_if_present<RewriterBase::Listener>(listener)) {}

void notifyOperationInserted(Operation *op, InsertPoint previous) override {
listener->notifyOperationInserted(op, previous);
if (listener)
listener->notifyOperationInserted(op, previous);
}
void notifyBlockInserted(Block *block, Region *previous,
Region::iterator previousIt) override {
listener->notifyBlockInserted(block, previous, previousIt);
if (listener)
listener->notifyBlockInserted(block, previous, previousIt);
}
void notifyBlockErased(Block *block) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
if (rewriteListener)
rewriteListener->notifyBlockErased(block);
}
void notifyOperationModified(Operation *op) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
if (rewriteListener)
rewriteListener->notifyOperationModified(op);
}
void notifyOperationReplaced(Operation *op, Operation *newOp) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
if (rewriteListener)
rewriteListener->notifyOperationReplaced(op, newOp);
}
void notifyOperationReplaced(Operation *op,
ValueRange replacement) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
if (rewriteListener)
rewriteListener->notifyOperationReplaced(op, replacement);
}
void notifyOperationErased(Operation *op) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
if (rewriteListener)
rewriteListener->notifyOperationErased(op);
}
void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
if (rewriteListener)
rewriteListener->notifyPatternBegin(pattern, op);
}
void notifyPatternEnd(const Pattern &pattern,
LogicalResult status) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
if (rewriteListener)
rewriteListener->notifyPatternEnd(pattern, status);
}
void notifyMatchFailure(
Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
if (rewriteListener)
rewriteListener->notifyMatchFailure(loc, reasonCallback);
}

private:
OpBuilder::Listener *listener;
RewriterBase::Listener *rewriteListener;
};

/// Move the blocks that belong to "region" before the given position in
Expand Down
37 changes: 37 additions & 0 deletions mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===- WALKPATTERNREWRITEDRIVER.h - Walk Pattern Rewrite Driver -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Declares a helper function to walk the given op and apply rewrite patterns.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
#define MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_

#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"

namespace mlir {

/// A fast walk-based pattern rewrite driver. Rewrites ops nested under the
/// given operation by walking it and applying the highest benefit patterns.
/// This rewriter *does not* wait until a fixpoint is reached and *does not*
/// visit modified or newly replaced ops. Also *does not* perform folding or
/// dead-code elimination.
///
/// This is intended as the simplest and most lightweight pattern rewriter in
/// cases when a simple walk gets the job done.
///
/// Note: Does not apply patterns to the given operation itself.
void walkAndApplyPatterns(Operation *op,
const FrozenRewritePatternSet &patterns,
RewriterBase::Listener *listener = nullptr);

} // namespace mlir

#endif // MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
8 changes: 2 additions & 6 deletions mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

namespace mlir {
namespace arith {
Expand Down Expand Up @@ -157,11 +157,7 @@ struct ArithUnsignedWhenEquivalentPass
RewritePatternSet patterns(ctx);
populateUnsignedWhenEquivalentPatterns(patterns, solver);

GreedyRewriteConfig config;
config.listener = &listener;

if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
signalPassFailure();
walkAndApplyPatterns(op, std::move(patterns), &listener);
}
};
} // end anonymous namespace
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Transforms/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ add_mlir_library(MLIRTransformUtils
LoopInvariantCodeMotionUtils.cpp
OneToNTypeConversion.cpp
RegionUtils.cpp
WalkPatternRewriteDriver.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
Expand Down
116 changes: 116 additions & 0 deletions mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements mlir::walkAndApplyPatterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/WalkPatternRewriteDriver.h"

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"

#define DEBUG_TYPE "walk-rewriter"

namespace mlir {

namespace {
struct WalkAndApplyPatternsAction final
: tracing::ActionImpl<WalkAndApplyPatternsAction> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction)
using ActionImpl::ActionImpl;
static constexpr StringLiteral tag = "walk-and-apply-patterns";
void print(raw_ostream &os) const override { os << tag; }
};

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Forwarding listener to guard against unsupported erasures of non-descendant
// ops/blocks. Because we use walk-based pattern application, erasing the
// op/block from the *next* iteration (e.g., a user of the visited op) is not
// valid. Note that this is only used with expensive pattern API checks.
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
using RewriterBase::ForwardingListener::ForwardingListener;

void notifyOperationErased(Operation *op) override {
checkErasure(op);
ForwardingListener::notifyOperationErased(op);
}

void notifyBlockErased(Block *block) override {
checkErasure(block->getParentOp());
ForwardingListener::notifyBlockErased(block);
}

void checkErasure(Operation *op) const {
Operation *ancestorOp = op;
while (ancestorOp && ancestorOp != visitedOp)
ancestorOp = ancestorOp->getParentOp();

if (ancestorOp != visitedOp)
llvm::report_fatal_error(
"unsupported erasure in WalkPatternRewriter; "
"erasure is only supported for matched ops and their descendants");
}

Operation *visitedOp = nullptr;
};
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
} // namespace

void walkAndApplyPatterns(Operation *op,
const FrozenRewritePatternSet &patterns,
RewriterBase::Listener *listener) {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (failed(verify(op)))
llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

MLIRContext *ctx = op->getContext();
PatternRewriter rewriter(ctx);
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
ErasedOpsListener erasedListener(listener);
rewriter.setListener(&erasedListener);
#else
rewriter.setListener(listener);
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();

ctx->executeAction<WalkAndApplyPatternsAction>(
[&] {
for (Region &region : op->getRegions()) {
region.walk([&](Operation *visitedOp) {
LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
llvm::dbgs(), OpPrintingFlags().skipRegions());
llvm::dbgs() << "\n";);
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
erasedListener.visitedOp = visitedOp;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
}
});
}
},
{op});

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (failed(verify(op)))
llvm::report_fatal_error(
"walk pattern rewriter result IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}

} // namespace mlir
2 changes: 1 addition & 1 deletion mlir/test/IR/enum-attr-roundtrip.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s | mlir-opt -test-patterns | FileCheck %s
// RUN: mlir-opt %s | mlir-opt -test-greedy-patterns | FileCheck %s

// CHECK-LABEL: @test_enum_attr_roundtrip
func.func @test_enum_attr_roundtrip() -> () {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-patterns="max-iterations=1" \
// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1" \
// RUN: -allow-unregistered-dialect --split-input-file | FileCheck %s

// CHECK-LABEL: func @add_to_worklist_after_inplace_update()
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-patterns="max-iterations=1 top-down=true" \
// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1 top-down=true" \
// RUN: --split-input-file | FileCheck %s

// Tests for https://github.com/llvm/llvm-project/issues/86765. Ensure
Expand Down
Loading