Skip to content

[mlir][IR] Add listener notifications for pattern begin/end #84131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 10, 2024

Conversation

matthias-springer
Copy link
Member

This commit adds two new notifications to RewriterBase::Listener:

  • notifyPatternBegin: Called when a pattern application begins during a greedy pattern rewrite or dialect conversion.
  • notifyPatternEnd: Called when a pattern application finishes during a greedy pattern rewrite or dialect conversion.

The listener infrastructure already provides a notifyMatchFailure callback that notifies about the reason for a pattern match failure. The two new notifications provide additional information about pattern applications.

This change is in preparation of improving the handle update mechanism in the apply_conversion_patterns transform op.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Mar 6, 2024
@llvmbot
Copy link
Member

llvmbot commented Mar 6, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit adds two new notifications to RewriterBase::Listener:

  • notifyPatternBegin: Called when a pattern application begins during a greedy pattern rewrite or dialect conversion.
  • notifyPatternEnd: Called when a pattern application finishes during a greedy pattern rewrite or dialect conversion.

The listener infrastructure already provides a notifyMatchFailure callback that notifies about the reason for a pattern match failure. The two new notifications provide additional information about pattern applications.

This change is in preparation of improving the handle update mechanism in the apply_conversion_patterns transform op.


Full diff: https://github.com/llvm/llvm-project/pull/84131.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/PatternMatch.h (+25-5)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+21-8)
  • (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+31-22)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index f8d22cfb22afd0..838b4947648f5e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder {
     /// Note: This notification is not triggered when unlinking an operation.
     virtual void notifyOperationErased(Operation *op) {}
 
-    /// Notify the listener that the pattern failed to match the given
-    /// operation, and provide a callback to populate a diagnostic with the
-    /// reason why the failure occurred. This method allows for derived
-    /// listeners to optionally hook into the reason why a rewrite failed, and
-    /// display it to users.
+    /// Notify the listener that the specified pattern is about to be applied
+    /// at the specified root operation.
+    virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {}
+
+    /// Notify the listener that a pattern application finished with the
+    /// specified status. "success" indicates that the pattern was applied
+    /// successfully. "failure" indicates that the pattern could not be
+    /// applied. The pattern may have communicated the reason for the failure
+    /// with `notifyMatchFailure`.
+    virtual void notifyPatternEnd(const Pattern &pattern,
+                                  LogicalResult status) {}
+
+    /// Notify the listener that the pattern failed to match, and provide a
+    /// callback to populate a diagnostic with the reason why the failure
+    /// occurred. This method allows for derived listeners to optionally hook
+    /// into the reason why a rewrite failed, and display it to users.
     virtual void
     notifyMatchFailure(Location loc,
                        function_ref<void(Diagnostic &)> reasonCallback) {}
@@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
         rewriteListener->notifyOperationErased(op);
     }
+    void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
+      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+        rewriteListener->notifyPatternBegin(pattern, op);
+    }
+    void notifyPatternEnd(const Pattern &pattern,
+                          LogicalResult status) override {
+      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+        rewriteListener->notifyPatternEnd(pattern, status);
+    }
     void notifyMatchFailure(
         Location loc,
         function_ref<void(Diagnostic &)> reasonCallback) override {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a5145246bc30c4..587fbe209b58af 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1863,7 +1863,8 @@ class OperationLegalizer {
   using LegalizationAction = ConversionTarget::LegalizationAction;
 
   OperationLegalizer(const ConversionTarget &targetInfo,
-                     const FrozenRewritePatternSet &patterns);
+                     const FrozenRewritePatternSet &patterns,
+                     const ConversionConfig &config);
 
   /// Returns true if the given operation is known to be illegal on the target.
   bool isIllegal(Operation *op) const;
@@ -1955,12 +1956,16 @@ class OperationLegalizer {
 
   /// The pattern applicator to use for conversions.
   PatternApplicator applicator;
+
+  /// Dialect conversion configuration.
+  const ConversionConfig &config;
 };
 } // namespace
 
 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
-                                       const FrozenRewritePatternSet &patterns)
-    : target(targetInfo), applicator(patterns) {
+                                       const FrozenRewritePatternSet &patterns,
+                                       const ConversionConfig &config)
+    : target(targetInfo), applicator(patterns), config(config) {
   // The set of patterns that can be applied to illegal operations to transform
   // them into legal ones.
   DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2105,7 +2110,10 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
 
   // Functor that returns if the given pattern may be applied.
   auto canApply = [&](const Pattern &pattern) {
-    return canApplyPattern(op, pattern, rewriter);
+    bool canApply = canApplyPattern(op, pattern, rewriter);
+    if (canApply && config.listener)
+      config.listener->notifyPatternBegin(pattern, op);
+    return canApply;
   };
 
   // Functor that cleans up the rewriter state after a pattern failed to match.
@@ -2122,6 +2130,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
         rewriterImpl.config.notifyCallback(diag);
       }
     });
+    if (config.listener)
+      config.listener->notifyPatternEnd(pattern, failure());
     rewriterImpl.resetState(curState);
     appliedPatterns.erase(&pattern);
   };
@@ -2134,6 +2144,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
     appliedPatterns.erase(&pattern);
     if (failed(result))
       rewriterImpl.resetState(curState);
+    if (config.listener)
+      config.listener->notifyPatternEnd(pattern, result);
     return result;
   };
 
@@ -2509,7 +2521,8 @@ struct OperationConverter {
                               const FrozenRewritePatternSet &patterns,
                               const ConversionConfig &config,
                               OpConversionMode mode)
-      : opLegalizer(target, patterns), config(config), mode(mode) {}
+      : config(config), opLegalizer(target, patterns, this->config),
+        mode(mode) {}
 
   /// Converts the given operations to the conversion target.
   LogicalResult convertOperations(ArrayRef<Operation *> ops);
@@ -2546,12 +2559,12 @@ struct OperationConverter {
       ConversionPatternRewriterImpl &rewriterImpl,
       const DenseMap<Value, SmallVector<Value>> &inverseMapping);
 
-  /// The legalizer to use when converting operations.
-  OperationLegalizer opLegalizer;
-
   /// Dialect conversion configuration.
   ConversionConfig config;
 
+  /// The legalizer to use when converting operations.
+  OperationLegalizer opLegalizer;
+
   /// The conversion mode to use when legalizing operations.
   OpConversionMode mode;
 };
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 51d2f5e01b7235..5fda6f87196f94 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -562,30 +562,39 @@ bool GreedyPatternRewriteDriver::processWorklist() {
     // Try to match one of the patterns. The rewriter is automatically
     // notified of any necessary changes, so there is nothing else to do
     // here.
-#ifndef NDEBUG
-    auto canApply = [&](const Pattern &pattern) {
-      LLVM_DEBUG({
-        logger.getOStream() << "\n";
-        logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
-                           << op->getName() << " -> (";
-        llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
-        logger.getOStream() << ")' {\n";
-        logger.indent();
-      });
-      return true;
-    };
-    auto onFailure = [&](const Pattern &pattern) {
-      LLVM_DEBUG(logResult("failure", "pattern failed to match"));
-    };
-    auto onSuccess = [&](const Pattern &pattern) {
-      LLVM_DEBUG(logResult("success", "pattern applied successfully"));
-      return success();
-    };
-#else
     function_ref<bool(const Pattern &)> canApply = {};
     function_ref<void(const Pattern &)> onFailure = {};
     function_ref<LogicalResult(const Pattern &)> onSuccess = {};
-#endif
+    bool debugBuild = false;
+#ifdef NDEBUG
+    debugBuild = true;
+#endif // NDEBUG
+    if (debugBuild || config.listener) {
+      canApply = [&](const Pattern &pattern) {
+        LLVM_DEBUG({
+          logger.getOStream() << "\n";
+          logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
+                             << op->getName() << " -> (";
+          llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
+          logger.getOStream() << ")' {\n";
+          logger.indent();
+        });
+        if (config.listener)
+          config.listener->notifyPatternBegin(pattern, op);
+        return true;
+      };
+      onFailure = [&](const Pattern &pattern) {
+        LLVM_DEBUG(logResult("failure", "pattern failed to match"));
+        if (config.listener)
+          config.listener->notifyPatternEnd(pattern, failure());
+      };
+      onSuccess = [&](const Pattern &pattern) {
+        LLVM_DEBUG(logResult("success", "pattern applied successfully"));
+        if (config.listener)
+          config.listener->notifyPatternEnd(pattern, success());
+        return success();
+      };
+    }
 
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
     if (config.scope) {
@@ -731,7 +740,7 @@ void GreedyPatternRewriteDriver::notifyMatchFailure(
   LLVM_DEBUG({
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
     reasonCallback(diag);
-    logger.startLine() << "** Failure : " << diag.str() << "\n";
+    logger.startLine() << "** Match Failure : " << diag.str() << "\n";
   });
   if (config.listener)
     config.listener->notifyMatchFailure(loc, reasonCallback);

Until now, `transform.apply_conversion_patterns` consumed the target handle and potentially invalidated handles. This commit adds tracking functionality similar to `transform.apply_patterns`, such that handles are no longer invalidated, but updated based on op replacements performed by the dialect conversion.

This new functionality is hidden behind a `preserve_handles` attribute for now.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/apply_conversion_patterns_listener branch from 270ade8 to c698363 Compare March 8, 2024 01:59
@matthias-springer matthias-springer force-pushed the users/matthias-springer/pattern_listener branch 2 times, most recently from 407c7f7 to 0aef4b9 Compare March 8, 2024 03:43
function_ref<LogicalResult(const Pattern &)> onSuccess = {});
std::function<bool(const Pattern &)> canApply = {},
std::function<void(const Pattern &)> onFailure = {},
std::function<LogicalResult(const Pattern &)> onSuccess = {});
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this change?

Copy link
Member Author

@matthias-springer matthias-springer Mar 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a stack-use-after-scope (reported by ASAN, also crashes in opt mode) with function_ref. The callback in the greedy pattern rewriter is defined inside of an if check:

    function_ref<bool(const Pattern &)> canApply = {};
    function_ref<void(const Pattern &)> onFailure = {};
    function_ref<LogicalResult(const Pattern &)> onSuccess = {};
    bool debugBuild = false;
#ifdef NDEBUG
    debugBuild = true;
#endif // NDEBUG
    if (debugBuild || config.listener) {
      // Lambda captures "this".
      canApply = [&](const Pattern &pattern) {
        if (this->config.listener) { /* ... */ }
        // ...
      }
      // ...
    }

    // `canApply` points to a lambda that went out of scope.
    LogicalResult matchResult =
        matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);

function_ref is a "non-owning function wrapper", but the lambda captures this.

Changing to std::function is one way to fix it. I could also just always pass a lambda. That would actually be my preferred solution, but there is a slight overhead when running in opt mode and without listener because the callback would always be called (even if it does not do anything):

    LogicalResult matchResult = matcher.matchAndRewrite(
        op, *this,
        /*canApply=*/[&](const Pattern &pattern) {
          if (this->listener) { /* ... */ }
          // ...
        },
        /*onFailure=*/[&](const Pattern &pattern) { /* ... */},
        /*onSuccess=*/[&](const Pattern &pattern) { /* ... */});

What do you think?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(The existing code seemed to care about performance here; There were different canApply etc. depending on NDEBUG.)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That can explain why you changed it at the call-site, but I'm puzzled about this function: it does not capture the callback as far as I can tell.

Copy link
Member Author

@matthias-springer matthias-springer Mar 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are you referring to with this function?

The problem here is really just caused by the fact that the canApply = assignment is inside of a nested scope. And the lambda object is dead by the time matcher.matchAndRewrite is called. I.e., the canApply function_ref points to an already free'd lambda. At least that's my understanding.

What's the C++ guidelines wrt. to function vs. function_ref. This is the first time I ran into such an issue, and assigning lambdas to function_ref feels "dangerous" to me now (because they can capture stuff). When using function, I don't have to think about the lifetime of an object.

Copy link
Collaborator

@joker-eph joker-eph Mar 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are you referring to with this function?

Where this comment thread is anchored: matchAndRewrite

The problem here is really just caused by the fact that the canApply = assignment is inside of a nested scope. And the lambda object is dead by the time matcher.matchAndRewrite is called. I.e., the canApply function_ref points to an already free'd lambda. At least that's my understanding.

Yes, but that's a problem for the call-site (processWorklist()), I don't quite see where you make the connection to the signature of matchAndRewrite?

What's the C++ guidelines wrt. to function vs. function_ref. This is the first time I ran into such an issue, and assigning lambdas to function_ref feels "dangerous" to me now (because they can capture stuff). When using function, I don't have to think about the lifetime of an object.

It is the same a StringRef: in general you use them in the function API / signature. More rarely for local variable (usually lambda gets assigned an auto variable)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good way to think about it.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/pattern_listener branch from 0aef4b9 to 24a56ca Compare March 8, 2024 07:20
onFailure = nullptr;
onSuccess = nullptr;
}
#endif // NDEBUG
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I didn't suggest changing this, what you had here was reasonable!

@matthias-springer matthias-springer force-pushed the users/matthias-springer/pattern_listener branch from 24a56ca to a65d640 Compare March 8, 2024 07:26
Base automatically changed from users/matthias-springer/apply_conversion_patterns_listener to main March 10, 2024 03:10
@matthias-springer matthias-springer merged commit 9b6bd70 into main Mar 10, 2024
@matthias-springer matthias-springer deleted the users/matthias-springer/pattern_listener branch March 10, 2024 03:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants