Skip to content

[mlir][Transforms] GreedyPatternRewriteDriver: Better expensive checks encapsulation #78175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 42 additions & 24 deletions mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,18 @@ namespace {
//===----------------------------------------------------------------------===//

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// A helper struct that stores finger prints of ops in order to detect broken
/// RewritePatterns. A rewrite pattern is broken if it modifies IR without
/// using the rewriter API or if it returns an inconsistent return value.
struct DebugFingerPrints : public RewriterBase::ForwardingListener {
DebugFingerPrints(RewriterBase::Listener *driver)
: RewriterBase::ForwardingListener(driver) {}
/// A helper struct that performs various "expensive checks" to detect broken
/// rewrite patterns use the rewriter API incorrectly. A rewrite pattern is
/// broken if:
/// * IR does not verify after pattern application / folding.
/// * Pattern returns "failure" but the IR has changed.
/// * Pattern returns "success" but the IR has not changed.
///
/// This struct stores finger prints of ops to determine whether the IR has
/// changed or not.
struct ExpensiveChecks : public RewriterBase::ForwardingListener {
ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
: RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}

/// Compute finger prints of the given op and its nested ops.
void computeFingerPrints(Operation *topLevel) {
Expand All @@ -65,6 +71,13 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
}

void notifyRewriteSuccess() {
if (!topLevel)
return;

// Make sure that the IR still verifies.
if (failed(verify(topLevel)))
llvm::report_fatal_error("IR failed to verify after pattern application");

// Pattern application success => IR must have changed.
OperationFingerPrint afterFingerPrint(topLevel);
if (*topLevelFingerPrint == afterFingerPrint) {
Expand All @@ -90,6 +103,9 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
}

void notifyRewriteFailure() {
if (!topLevel)
return;

// Pattern application failure => IR must not have changed.
OperationFingerPrint afterFingerPrint(topLevel);
if (*topLevelFingerPrint != afterFingerPrint) {
Expand All @@ -98,6 +114,15 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
}
}

void notifyFoldingSuccess() {
if (!topLevel)
return;

// Make sure that the IR still verifies.
if (failed(verify(topLevel)))
llvm::report_fatal_error("IR failed to verify after folding");
}

protected:
/// Invalidate the finger print of the given op, i.e., remove it from the map.
void invalidateFingerPrint(Operation *op) {
Expand Down Expand Up @@ -362,7 +387,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
PatternApplicator matcher;

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
DebugFingerPrints debugFingerPrints;
ExpensiveChecks expensiveChecks;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
};
} // namespace
Expand All @@ -373,7 +398,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
: PatternRewriter(ctx), config(config), matcher(patterns)
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// clang-format off
, debugFingerPrints(this)
, expensiveChecks(
/*driver=*/this,
/*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr)
// clang-format on
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
{
Expand All @@ -384,7 +411,7 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Send IR notifications to the debug handler. This handler will then forward
// all notifications to this GreedyPatternRewriteDriver.
setListener(&debugFingerPrints);
setListener(&expensiveChecks);
#else
setListener(this);
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
Expand Down Expand Up @@ -458,8 +485,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
changed = true;
LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope && failed(verify(config.scope->getParentOp())))
llvm::report_fatal_error("IR failed to verify after folding");
expensiveChecks.notifyFoldingSuccess();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
continue;
}
Expand Down Expand Up @@ -513,8 +539,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
changed = true;
LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope && failed(verify(config.scope->getParentOp())))
llvm::report_fatal_error("IR failed to verify after folding");
expensiveChecks.notifyFoldingSuccess();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
continue;
}
Expand Down Expand Up @@ -551,33 +576,26 @@ bool GreedyPatternRewriteDriver::processWorklist() {

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope) {
debugFingerPrints.computeFingerPrints(config.scope->getParentOp());
expensiveChecks.computeFingerPrints(config.scope->getParentOp());
}
auto clearFingerprints =
llvm::make_scope_exit([&]() { debugFingerPrints.clear(); });
llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

LogicalResult matchResult =
matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);

#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope && failed(verify(config.scope->getParentOp())))
llvm::report_fatal_error("IR failed to verify after pattern application");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

if (succeeded(matchResult)) {
LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope)
debugFingerPrints.notifyRewriteSuccess();
expensiveChecks.notifyRewriteSuccess();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
changed = true;
++numRewrites;
} else {
LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope)
debugFingerPrints.notifyRewriteFailure();
expensiveChecks.notifyRewriteFailure();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}
}
Expand Down