@@ -43,12 +43,18 @@ namespace {
43
43
// ===----------------------------------------------------------------------===//
44
44
45
45
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
46
- // / A helper struct that stores finger prints of ops in order to detect broken
47
- // / RewritePatterns. A rewrite pattern is broken if it modifies IR without
48
- // / using the rewriter API or if it returns an inconsistent return value.
49
- struct DebugFingerPrints : public RewriterBase ::ForwardingListener {
50
- DebugFingerPrints (RewriterBase::Listener *driver)
51
- : RewriterBase::ForwardingListener(driver) {}
46
+ // / A helper struct that performs various "expensive checks" to detect broken
47
+ // / rewrite patterns use the rewriter API incorrectly. A rewrite pattern is
48
+ // / broken if:
49
+ // / * IR does not verify after pattern application / folding.
50
+ // / * Pattern returns "failure" but the IR has changed.
51
+ // / * Pattern returns "success" but the IR has not changed.
52
+ // /
53
+ // / This struct stores finger prints of ops to determine whether the IR has
54
+ // / changed or not.
55
+ struct ExpensiveChecks : public RewriterBase ::ForwardingListener {
56
+ ExpensiveChecks (RewriterBase::Listener *driver, Operation *topLevel)
57
+ : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
52
58
53
59
// / Compute finger prints of the given op and its nested ops.
54
60
void computeFingerPrints (Operation *topLevel) {
@@ -65,6 +71,13 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
65
71
}
66
72
67
73
void notifyRewriteSuccess () {
74
+ if (!topLevel)
75
+ return ;
76
+
77
+ // Make sure that the IR still verifies.
78
+ if (failed (verify (topLevel)))
79
+ llvm::report_fatal_error (" IR failed to verify after pattern application" );
80
+
68
81
// Pattern application success => IR must have changed.
69
82
OperationFingerPrint afterFingerPrint (topLevel);
70
83
if (*topLevelFingerPrint == afterFingerPrint) {
@@ -90,6 +103,9 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
90
103
}
91
104
92
105
void notifyRewriteFailure () {
106
+ if (!topLevel)
107
+ return ;
108
+
93
109
// Pattern application failure => IR must not have changed.
94
110
OperationFingerPrint afterFingerPrint (topLevel);
95
111
if (*topLevelFingerPrint != afterFingerPrint) {
@@ -98,6 +114,15 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
98
114
}
99
115
}
100
116
117
+ void notifyFoldingSuccess () {
118
+ if (!topLevel)
119
+ return ;
120
+
121
+ // Make sure that the IR still verifies.
122
+ if (failed (verify (topLevel)))
123
+ llvm::report_fatal_error (" IR failed to verify after folding" );
124
+ }
125
+
101
126
protected:
102
127
// / Invalidate the finger print of the given op, i.e., remove it from the map.
103
128
void invalidateFingerPrint (Operation *op) {
@@ -362,7 +387,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
362
387
PatternApplicator matcher;
363
388
364
389
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
365
- DebugFingerPrints debugFingerPrints ;
390
+ ExpensiveChecks expensiveChecks ;
366
391
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
367
392
};
368
393
} // namespace
@@ -373,7 +398,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
373
398
: PatternRewriter(ctx), config(config), matcher(patterns)
374
399
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
375
400
// clang-format off
376
- , debugFingerPrints(this )
401
+ , expensiveChecks(
402
+ /* driver=*/ this ,
403
+ /* topLevel=*/ config.scope ? config.scope->getParentOp () : nullptr)
377
404
// clang-format on
378
405
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
379
406
{
@@ -384,7 +411,7 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
384
411
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
385
412
// Send IR notifications to the debug handler. This handler will then forward
386
413
// all notifications to this GreedyPatternRewriteDriver.
387
- setListener (&debugFingerPrints );
414
+ setListener (&expensiveChecks );
388
415
#else
389
416
setListener (this );
390
417
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -458,8 +485,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
458
485
changed = true ;
459
486
LLVM_DEBUG (logSuccessfulFolding (dumpRootOp));
460
487
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
461
- if (config.scope && failed (verify (config.scope ->getParentOp ())))
462
- llvm::report_fatal_error (" IR failed to verify after folding" );
488
+ expensiveChecks.notifyFoldingSuccess ();
463
489
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
464
490
continue ;
465
491
}
@@ -513,8 +539,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
513
539
changed = true ;
514
540
LLVM_DEBUG (logSuccessfulFolding (dumpRootOp));
515
541
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
516
- if (config.scope && failed (verify (config.scope ->getParentOp ())))
517
- llvm::report_fatal_error (" IR failed to verify after folding" );
542
+ expensiveChecks.notifyFoldingSuccess ();
518
543
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
519
544
continue ;
520
545
}
@@ -551,33 +576,26 @@ bool GreedyPatternRewriteDriver::processWorklist() {
551
576
552
577
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
553
578
if (config.scope ) {
554
- debugFingerPrints .computeFingerPrints (config.scope ->getParentOp ());
579
+ expensiveChecks .computeFingerPrints (config.scope ->getParentOp ());
555
580
}
556
581
auto clearFingerprints =
557
- llvm::make_scope_exit ([&]() { debugFingerPrints .clear (); });
582
+ llvm::make_scope_exit ([&]() { expensiveChecks .clear (); });
558
583
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
559
584
560
585
LogicalResult matchResult =
561
586
matcher.matchAndRewrite (op, *this , canApply, onFailure, onSuccess);
562
587
563
- #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
564
- if (config.scope && failed (verify (config.scope ->getParentOp ())))
565
- llvm::report_fatal_error (" IR failed to verify after pattern application" );
566
- #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
567
-
568
588
if (succeeded (matchResult)) {
569
589
LLVM_DEBUG (logResultWithLine (" success" , " pattern matched" ));
570
590
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
571
- if (config.scope )
572
- debugFingerPrints.notifyRewriteSuccess ();
591
+ expensiveChecks.notifyRewriteSuccess ();
573
592
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
574
593
changed = true ;
575
594
++numRewrites;
576
595
} else {
577
596
LLVM_DEBUG (logResultWithLine (" failure" , " pattern failed to match" ));
578
597
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
579
- if (config.scope )
580
- debugFingerPrints.notifyRewriteFailure ();
598
+ expensiveChecks.notifyRewriteFailure ();
581
599
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
582
600
}
583
601
}
0 commit comments