Skip to content

Commit 0f8a6b7

Browse files
authored
[mlir] Add fast walk-based pattern rewrite driver (#113825)
This is intended as a fast pattern rewrite driver for the cases when a simple walk gets the job done but we would still want to implement it in terms of rewrite patterns (that can be used with the greedy pattern rewrite driver downstream). The new driver is inspired by the discussion in #112454 and the LLVM Dev presentation from @matthias-springer earlier this week. This limitation comes with some limitations: * It does not repeat until a fixpoint or revisit ops modified in place or newly created ops. In general, it only walks forward (in the post-order). * `matchAndRewrite` can only erase the matched op or its descendants. This is verified under expensive checks. * It does not perform folding / DCE. We could probably relax some of these in the future without sacrificing too much performance.
1 parent 1d03708 commit 0f8a6b7

15 files changed

+455
-67
lines changed

mlir/docs/ActionTracing.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ An action can also carry arbitrary payload, for example we can extend the
8686

8787
```c++
8888
/// A custom Action can be defined minimally by deriving from
89-
/// `tracing::ActionImpl`. It can has any members!
89+
/// `tracing::ActionImpl`. It can have any members!
9090
class MyCustomAction : public tracing::ActionImpl<MyCustomAction> {
9191
public:
9292
using Base = tracing::ActionImpl<MyCustomAction>;

mlir/docs/PatternRewriter.md

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -320,15 +320,41 @@ conversion target, via a set of pattern-based operation rewriting patterns. This
320320
framework also provides support for type conversions. More information on this
321321
driver can be found [here](DialectConversion.md).
322322

323+
### Walk Pattern Rewrite Driver
324+
325+
This is a fast and simple driver that walks the given op and applies patterns
326+
that locally have the most benefit. The benefit of a pattern is decided solely
327+
by the benefit specified on the pattern, and the relative order of the pattern
328+
within the pattern list (when two patterns have the same local benefit).
329+
330+
The driver performs a post-order traversal. Note that it walks regions of the
331+
given op but does not visit the op.
332+
333+
This driver does not (re)visit modified or newly replaced ops, and does not
334+
allow for progressive rewrites of the same op. Op and block erasure is only
335+
supported for the currently matched op and its descendant. If your pattern
336+
set requires these, consider using the Greedy Pattern Rewrite Driver instead,
337+
at the expense of extra overhead.
338+
339+
This driver is exposed using the `walkAndApplyPatterns` function.
340+
341+
Note: This driver listens for IR changes via the callbacks provided by
342+
`RewriterBase`. It is important that patterns announce all IR changes to the
343+
rewriter and do not bypass the rewriter API by modifying ops directly.
344+
345+
#### Debugging
346+
347+
You can debug the Walk Pattern Rewrite Driver by passing the
348+
`--debug-only=walk-rewriter` CLI flag. This will print the visited and matched
349+
ops.
350+
323351
### Greedy Pattern Rewrite Driver
324352

325353
This driver processes ops in a worklist-driven fashion and greedily applies the
326-
patterns that locally have the most benefit. The benefit of a pattern is decided
327-
solely by the benefit specified on the pattern, and the relative order of the
328-
pattern within the pattern list (when two patterns have the same local benefit).
329-
Patterns are iteratively applied to operations until a fixed point is reached or
330-
until the configurable maximum number of iterations exhausted, at which point
331-
the driver finishes.
354+
patterns that locally have the most benefit (same as the Walk Pattern Rewrite
355+
Driver). Patterns are iteratively applied to operations until a fixed point is
356+
reached or until the configurable maximum number of iterations exhausted, at
357+
which point the driver finishes.
332358

333359
This driver comes in two fashions:
334360

@@ -368,7 +394,7 @@ rewriter and do not bypass the rewriter API by modifying ops directly.
368394
Note: This driver is the one used by the [canonicalization](Canonicalization.md)
369395
[pass](Passes.md/#-canonicalize) in MLIR.
370396

371-
### Debugging
397+
#### Debugging
372398

373399
To debug the execution of the greedy pattern rewrite driver,
374400
`-debug-only=greedy-rewriter` may be used. This command line flag activates

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -461,54 +461,60 @@ class RewriterBase : public OpBuilder {
461461
/// struct can be used as a base to create listener chains, so that multiple
462462
/// listeners can be notified of IR changes.
463463
struct ForwardingListener : public RewriterBase::Listener {
464-
ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
464+
ForwardingListener(OpBuilder::Listener *listener)
465+
: listener(listener),
466+
rewriteListener(
467+
dyn_cast_if_present<RewriterBase::Listener>(listener)) {}
465468

466469
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
467-
listener->notifyOperationInserted(op, previous);
470+
if (listener)
471+
listener->notifyOperationInserted(op, previous);
468472
}
469473
void notifyBlockInserted(Block *block, Region *previous,
470474
Region::iterator previousIt) override {
471-
listener->notifyBlockInserted(block, previous, previousIt);
475+
if (listener)
476+
listener->notifyBlockInserted(block, previous, previousIt);
472477
}
473478
void notifyBlockErased(Block *block) override {
474-
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
479+
if (rewriteListener)
475480
rewriteListener->notifyBlockErased(block);
476481
}
477482
void notifyOperationModified(Operation *op) override {
478-
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
483+
if (rewriteListener)
479484
rewriteListener->notifyOperationModified(op);
480485
}
481486
void notifyOperationReplaced(Operation *op, Operation *newOp) override {
482-
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
487+
if (rewriteListener)
483488
rewriteListener->notifyOperationReplaced(op, newOp);
484489
}
485490
void notifyOperationReplaced(Operation *op,
486491
ValueRange replacement) override {
487-
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
492+
if (rewriteListener)
488493
rewriteListener->notifyOperationReplaced(op, replacement);
489494
}
490495
void notifyOperationErased(Operation *op) override {
491-
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
496+
if (rewriteListener)
492497
rewriteListener->notifyOperationErased(op);
493498
}
494499
void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
495-
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
500+
if (rewriteListener)
496501
rewriteListener->notifyPatternBegin(pattern, op);
497502
}
498503
void notifyPatternEnd(const Pattern &pattern,
499504
LogicalResult status) override {
500-
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
505+
if (rewriteListener)
501506
rewriteListener->notifyPatternEnd(pattern, status);
502507
}
503508
void notifyMatchFailure(
504509
Location loc,
505510
function_ref<void(Diagnostic &)> reasonCallback) override {
506-
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
511+
if (rewriteListener)
507512
rewriteListener->notifyMatchFailure(loc, reasonCallback);
508513
}
509514

510515
private:
511516
OpBuilder::Listener *listener;
517+
RewriterBase::Listener *rewriteListener;
512518
};
513519

514520
/// Move the blocks that belong to "region" before the given position in
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//===- WALKPATTERNREWRITEDRIVER.h - Walk Pattern Rewrite Driver -*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Declares a helper function to walk the given op and apply rewrite patterns.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
14+
#define MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
15+
16+
#include "mlir/IR/Visitors.h"
17+
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
18+
19+
namespace mlir {
20+
21+
/// A fast walk-based pattern rewrite driver. Rewrites ops nested under the
22+
/// given operation by walking it and applying the highest benefit patterns.
23+
/// This rewriter *does not* wait until a fixpoint is reached and *does not*
24+
/// visit modified or newly replaced ops. Also *does not* perform folding or
25+
/// dead-code elimination.
26+
///
27+
/// This is intended as the simplest and most lightweight pattern rewriter in
28+
/// cases when a simple walk gets the job done.
29+
///
30+
/// Note: Does not apply patterns to the given operation itself.
31+
void walkAndApplyPatterns(Operation *op,
32+
const FrozenRewritePatternSet &patterns,
33+
RewriterBase::Listener *listener = nullptr);
34+
35+
} // namespace mlir
36+
37+
#endif // MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_

mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/IR/PatternMatch.h"
17-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17+
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
1818

1919
namespace mlir {
2020
namespace arith {
@@ -157,11 +157,7 @@ struct ArithUnsignedWhenEquivalentPass
157157
RewritePatternSet patterns(ctx);
158158
populateUnsignedWhenEquivalentPatterns(patterns, solver);
159159

160-
GreedyRewriteConfig config;
161-
config.listener = &listener;
162-
163-
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
164-
signalPassFailure();
160+
walkAndApplyPatterns(op, std::move(patterns), &listener);
165161
}
166162
};
167163
} // end anonymous namespace

mlir/lib/Transforms/Utils/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ add_mlir_library(MLIRTransformUtils
1010
LoopInvariantCodeMotionUtils.cpp
1111
OneToNTypeConversion.cpp
1212
RegionUtils.cpp
13+
WalkPatternRewriteDriver.cpp
1314

1415
ADDITIONAL_HEADER_DIRS
1516
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Implements mlir::walkAndApplyPatterns.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
14+
15+
#include "mlir/IR/MLIRContext.h"
16+
#include "mlir/IR/OperationSupport.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/IR/Verifier.h"
19+
#include "mlir/IR/Visitors.h"
20+
#include "mlir/Rewrite/PatternApplicator.h"
21+
#include "llvm/Support/Debug.h"
22+
#include "llvm/Support/ErrorHandling.h"
23+
24+
#define DEBUG_TYPE "walk-rewriter"
25+
26+
namespace mlir {
27+
28+
namespace {
29+
struct WalkAndApplyPatternsAction final
30+
: tracing::ActionImpl<WalkAndApplyPatternsAction> {
31+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction)
32+
using ActionImpl::ActionImpl;
33+
static constexpr StringLiteral tag = "walk-and-apply-patterns";
34+
void print(raw_ostream &os) const override { os << tag; }
35+
};
36+
37+
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
38+
// Forwarding listener to guard against unsupported erasures of non-descendant
39+
// ops/blocks. Because we use walk-based pattern application, erasing the
40+
// op/block from the *next* iteration (e.g., a user of the visited op) is not
41+
// valid. Note that this is only used with expensive pattern API checks.
42+
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
43+
using RewriterBase::ForwardingListener::ForwardingListener;
44+
45+
void notifyOperationErased(Operation *op) override {
46+
checkErasure(op);
47+
ForwardingListener::notifyOperationErased(op);
48+
}
49+
50+
void notifyBlockErased(Block *block) override {
51+
checkErasure(block->getParentOp());
52+
ForwardingListener::notifyBlockErased(block);
53+
}
54+
55+
void checkErasure(Operation *op) const {
56+
Operation *ancestorOp = op;
57+
while (ancestorOp && ancestorOp != visitedOp)
58+
ancestorOp = ancestorOp->getParentOp();
59+
60+
if (ancestorOp != visitedOp)
61+
llvm::report_fatal_error(
62+
"unsupported erasure in WalkPatternRewriter; "
63+
"erasure is only supported for matched ops and their descendants");
64+
}
65+
66+
Operation *visitedOp = nullptr;
67+
};
68+
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
69+
} // namespace
70+
71+
void walkAndApplyPatterns(Operation *op,
72+
const FrozenRewritePatternSet &patterns,
73+
RewriterBase::Listener *listener) {
74+
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
75+
if (failed(verify(op)))
76+
llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
77+
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
78+
79+
MLIRContext *ctx = op->getContext();
80+
PatternRewriter rewriter(ctx);
81+
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
82+
ErasedOpsListener erasedListener(listener);
83+
rewriter.setListener(&erasedListener);
84+
#else
85+
rewriter.setListener(listener);
86+
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
87+
88+
PatternApplicator applicator(patterns);
89+
applicator.applyDefaultCostModel();
90+
91+
ctx->executeAction<WalkAndApplyPatternsAction>(
92+
[&] {
93+
for (Region &region : op->getRegions()) {
94+
region.walk([&](Operation *visitedOp) {
95+
LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
96+
llvm::dbgs(), OpPrintingFlags().skipRegions());
97+
llvm::dbgs() << "\n";);
98+
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
99+
erasedListener.visitedOp = visitedOp;
100+
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
101+
if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
102+
LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
103+
}
104+
});
105+
}
106+
},
107+
{op});
108+
109+
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
110+
if (failed(verify(op)))
111+
llvm::report_fatal_error(
112+
"walk pattern rewriter result IR failed to verify");
113+
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
114+
}
115+
116+
} // namespace mlir

mlir/test/IR/enum-attr-roundtrip.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s | mlir-opt -test-patterns | FileCheck %s
1+
// RUN: mlir-opt %s | mlir-opt -test-greedy-patterns | FileCheck %s
22

33
// CHECK-LABEL: @test_enum_attr_roundtrip
44
func.func @test_enum_attr_roundtrip() -> () {

mlir/test/IR/greedy-pattern-rewrite-driver-bottom-up.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-patterns="max-iterations=1" \
1+
// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1" \
22
// RUN: -allow-unregistered-dialect --split-input-file | FileCheck %s
33

44
// CHECK-LABEL: func @add_to_worklist_after_inplace_update()

mlir/test/IR/greedy-pattern-rewrite-driver-top-down.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-patterns="max-iterations=1 top-down=true" \
1+
// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1 top-down=true" \
22
// RUN: --split-input-file | FileCheck %s
33

44
// Tests for https://github.com/llvm/llvm-project/issues/86765. Ensure

0 commit comments

Comments
 (0)