Skip to content

Commit a02a0e8

Browse files
[mlir][Transforms] GreedyPatternRewriteDriver: Better expensive checks encapsulation (#78175)
This change moves most IR verification logic (which is part of the expensive checks) into `DebugFingerPrints` and renames the struct to `ExpensiveChecks`. This isolates the debugging logic better from the remaining code. This commit also removes a redundant check: the IR is no longer verified after a failed pattern application. We already assert that the IR did not change. (We know that the IR was valid before the attempted pattern application.)
1 parent af1463d commit a02a0e8

File tree

1 file changed

+42
-24
lines changed

1 file changed

+42
-24
lines changed

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,18 @@ namespace {
4343
//===----------------------------------------------------------------------===//
4444

4545
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
46-
/// A helper struct that stores finger prints of ops in order to detect broken
47-
/// RewritePatterns. A rewrite pattern is broken if it modifies IR without
48-
/// using the rewriter API or if it returns an inconsistent return value.
49-
struct DebugFingerPrints : public RewriterBase::ForwardingListener {
50-
DebugFingerPrints(RewriterBase::Listener *driver)
51-
: RewriterBase::ForwardingListener(driver) {}
46+
/// A helper struct that performs various "expensive checks" to detect broken
47+
/// rewrite patterns use the rewriter API incorrectly. A rewrite pattern is
48+
/// broken if:
49+
/// * IR does not verify after pattern application / folding.
50+
/// * Pattern returns "failure" but the IR has changed.
51+
/// * Pattern returns "success" but the IR has not changed.
52+
///
53+
/// This struct stores finger prints of ops to determine whether the IR has
54+
/// changed or not.
55+
struct ExpensiveChecks : public RewriterBase::ForwardingListener {
56+
ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
57+
: RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
5258

5359
/// Compute finger prints of the given op and its nested ops.
5460
void computeFingerPrints(Operation *topLevel) {
@@ -65,6 +71,13 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
6571
}
6672

6773
void notifyRewriteSuccess() {
74+
if (!topLevel)
75+
return;
76+
77+
// Make sure that the IR still verifies.
78+
if (failed(verify(topLevel)))
79+
llvm::report_fatal_error("IR failed to verify after pattern application");
80+
6881
// Pattern application success => IR must have changed.
6982
OperationFingerPrint afterFingerPrint(topLevel);
7083
if (*topLevelFingerPrint == afterFingerPrint) {
@@ -90,6 +103,9 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
90103
}
91104

92105
void notifyRewriteFailure() {
106+
if (!topLevel)
107+
return;
108+
93109
// Pattern application failure => IR must not have changed.
94110
OperationFingerPrint afterFingerPrint(topLevel);
95111
if (*topLevelFingerPrint != afterFingerPrint) {
@@ -98,6 +114,15 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
98114
}
99115
}
100116

117+
void notifyFoldingSuccess() {
118+
if (!topLevel)
119+
return;
120+
121+
// Make sure that the IR still verifies.
122+
if (failed(verify(topLevel)))
123+
llvm::report_fatal_error("IR failed to verify after folding");
124+
}
125+
101126
protected:
102127
/// Invalidate the finger print of the given op, i.e., remove it from the map.
103128
void invalidateFingerPrint(Operation *op) {
@@ -362,7 +387,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
362387
PatternApplicator matcher;
363388

364389
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
365-
DebugFingerPrints debugFingerPrints;
390+
ExpensiveChecks expensiveChecks;
366391
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
367392
};
368393
} // namespace
@@ -373,7 +398,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
373398
: PatternRewriter(ctx), config(config), matcher(patterns)
374399
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
375400
// clang-format off
376-
, debugFingerPrints(this)
401+
, expensiveChecks(
402+
/*driver=*/this,
403+
/*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr)
377404
// clang-format on
378405
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
379406
{
@@ -384,7 +411,7 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
384411
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
385412
// Send IR notifications to the debug handler. This handler will then forward
386413
// all notifications to this GreedyPatternRewriteDriver.
387-
setListener(&debugFingerPrints);
414+
setListener(&expensiveChecks);
388415
#else
389416
setListener(this);
390417
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -458,8 +485,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
458485
changed = true;
459486
LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
460487
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
461-
if (config.scope && failed(verify(config.scope->getParentOp())))
462-
llvm::report_fatal_error("IR failed to verify after folding");
488+
expensiveChecks.notifyFoldingSuccess();
463489
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
464490
continue;
465491
}
@@ -513,8 +539,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
513539
changed = true;
514540
LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
515541
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
516-
if (config.scope && failed(verify(config.scope->getParentOp())))
517-
llvm::report_fatal_error("IR failed to verify after folding");
542+
expensiveChecks.notifyFoldingSuccess();
518543
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
519544
continue;
520545
}
@@ -551,33 +576,26 @@ bool GreedyPatternRewriteDriver::processWorklist() {
551576

552577
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
553578
if (config.scope) {
554-
debugFingerPrints.computeFingerPrints(config.scope->getParentOp());
579+
expensiveChecks.computeFingerPrints(config.scope->getParentOp());
555580
}
556581
auto clearFingerprints =
557-
llvm::make_scope_exit([&]() { debugFingerPrints.clear(); });
582+
llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
558583
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
559584

560585
LogicalResult matchResult =
561586
matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
562587

563-
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
564-
if (config.scope && failed(verify(config.scope->getParentOp())))
565-
llvm::report_fatal_error("IR failed to verify after pattern application");
566-
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
567-
568588
if (succeeded(matchResult)) {
569589
LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
570590
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
571-
if (config.scope)
572-
debugFingerPrints.notifyRewriteSuccess();
591+
expensiveChecks.notifyRewriteSuccess();
573592
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
574593
changed = true;
575594
++numRewrites;
576595
} else {
577596
LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
578597
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
579-
if (config.scope)
580-
debugFingerPrints.notifyRewriteFailure();
598+
expensiveChecks.notifyRewriteFailure();
581599
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
582600
}
583601
}

0 commit comments

Comments
 (0)