Skip to content

[mlir][IR][NFC] Listener::notifyMatchFailure returns void #80704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

There are two notifyMatchFailure methods: one in the rewriter and one in the listener. The one in the rewriter notifies the listener and returns "failure" for convenience. The one in the listener should not return anything; it is just a notification. It can currently be abused to return "success" from the rewriter function. That would be a violation of the rewriter API rules.

Also make sure that the listener is always notified about match failures, not just with NDEBUG. The current implementation is consistent: one notifyMatchFailure overload notifies only in debug mode and another one notifies all the time.

There are two `notifyMatchFailure` methods: one in the rewriter and one in the listener. The one in the rewriter notifies the listener and returns "failure" for convenience. The one in the listener should not return anything; it is just a notification. It can currently be abused to return "success". That would be a violation of the rewriter API rules.
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Feb 5, 2024
@llvmbot
Copy link
Member

llvmbot commented Feb 5, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

There are two notifyMatchFailure methods: one in the rewriter and one in the listener. The one in the rewriter notifies the listener and returns "failure" for convenience. The one in the listener should not return anything; it is just a notification. It can currently be abused to return "success" from the rewriter function. That would be a violation of the rewriter API rules.

Also make sure that the listener is always notified about match failures, not just with NDEBUG. The current implementation is consistent: one notifyMatchFailure overload notifies only in debug mode and another one notifies all the time.


Full diff: https://github.com/llvm/llvm-project/pull/80704.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+1-1)
  • (modified) mlir/include/mlir/IR/PatternMatch.h (+6-13)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+1-1)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+1-2)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+7-8)
  • (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+3-4)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index c2e3cde8ebc69f..2e096e1f552924 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -992,7 +992,7 @@ class TrackingListener : public RewriterBase::Listener,
   /// Notify the listener that the pattern failed to match the given operation,
   /// and provide a callback to populate a diagnostic with the reason why the
   /// failure occurred.
-  LogicalResult
+  void
   notifyMatchFailure(Location loc,
                      function_ref<void(Diagnostic &)> reasonCallback) override;
 
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 61da27825e870c..78dcfe7f6fc3d2 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -437,11 +437,9 @@ class RewriterBase : public OpBuilder {
     /// reason why the failure occurred. This method allows for derived
     /// listeners to optionally hook into the reason why a rewrite failed, and
     /// display it to users.
-    virtual LogicalResult
+    virtual void
     notifyMatchFailure(Location loc,
-                       function_ref<void(Diagnostic &)> reasonCallback) {
-      return failure();
-    }
+                       function_ref<void(Diagnostic &)> reasonCallback) {}
 
     static bool classof(const OpBuilder::Listener *base);
   };
@@ -480,12 +478,11 @@ class RewriterBase : public OpBuilder {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
         rewriteListener->notifyOperationRemoved(op);
     }
-    LogicalResult notifyMatchFailure(
+    void notifyMatchFailure(
         Location loc,
         function_ref<void(Diagnostic &)> reasonCallback) override {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
-        return rewriteListener->notifyMatchFailure(loc, reasonCallback);
-      return failure();
+        rewriteListener->notifyMatchFailure(loc, reasonCallback);
     }
 
   private:
@@ -688,20 +685,16 @@ class RewriterBase : public OpBuilder {
   template <typename CallbackT>
   std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
   notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
-#ifndef NDEBUG
     if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-      return rewriteListener->notifyMatchFailure(
+      rewriteListener->notifyMatchFailure(
           loc, function_ref<void(Diagnostic &)>(reasonCallback));
     return failure();
-#else
-    return failure();
-#endif
   }
   template <typename CallbackT>
   std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
   notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
     if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-      return rewriteListener->notifyMatchFailure(
+      rewriteListener->notifyMatchFailure(
           op->getLoc(), function_ref<void(Diagnostic &)>(reasonCallback));
     return failure();
   }
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 51e3e413b516f4..b1ec1fe4ecd51a 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -740,7 +740,7 @@ class ConversionPatternRewriter final : public PatternRewriter,
   void cancelOpModification(Operation *op) override;
 
   /// PatternRewriter hook for notifying match failure reasons.
-  LogicalResult
+  void
   notifyMatchFailure(Location loc,
                      function_ref<void(Diagnostic &)> reasonCallback) override;
   using PatternRewriter::notifyMatchFailure;
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 371ad904dcae5a..a964c205b62e84 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1265,14 +1265,13 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
   return diag;
 }
 
-LogicalResult transform::TrackingListener::notifyMatchFailure(
+void transform::TrackingListener::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
     reasonCallback(diag);
     DBGS() << "Match Failure : " << diag.str() << "\n";
   });
-  return failure();
 }
 
 void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 346135fb447227..e90447084d68bd 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -921,9 +921,8 @@ struct ConversionPatternRewriterImpl {
                                Block::iterator before);
 
   /// Notifies that a pattern match failed for the given reason.
-  LogicalResult
-  notifyMatchFailure(Location loc,
-                     function_ref<void(Diagnostic &)> reasonCallback);
+  void notifyMatchFailure(Location loc,
+                          function_ref<void(Diagnostic &)> reasonCallback);
 
   //===--------------------------------------------------------------------===//
   // State
@@ -1236,10 +1235,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
       legalTypes.clear();
       if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
         Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
-        return notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
+        notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
           diag << "unable to convert type for " << valueDiagTag << " #"
                << it.index() << ", type was " << origType;
         });
+        return failure();
       }
       // TODO: There currently isn't any mechanism to do 1->N type conversion
       // via the PatternRewriter replacement API, so for now we just ignore it.
@@ -1419,7 +1419,7 @@ void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
   blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
 }
 
-LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
+void ConversionPatternRewriterImpl::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
@@ -1428,7 +1428,6 @@ LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
     if (notifyCallback)
       notifyCallback(diag);
   });
-  return failure();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1615,9 +1614,9 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
   rootUpdates.erase(rootUpdates.begin() + updateIdx);
 }
 
-LogicalResult ConversionPatternRewriter::notifyMatchFailure(
+void ConversionPatternRewriter::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
-  return impl->notifyMatchFailure(loc, reasonCallback);
+  impl->notifyMatchFailure(loc, reasonCallback);
 }
 
 void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index d5395045af434d..bde8c290e774bc 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -387,7 +387,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   void notifyBlockRemoved(Block *block) override;
 
   /// For debugging only: Notify the driver of a pattern match failure.
-  LogicalResult
+  void
   notifyMatchFailure(Location loc,
                      function_ref<void(Diagnostic &)> reasonCallback) override;
 
@@ -726,7 +726,7 @@ void GreedyPatternRewriteDriver::notifyOperationReplaced(
     config.listener->notifyOperationReplaced(op, replacement);
 }
 
-LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
+void GreedyPatternRewriteDriver::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
@@ -734,8 +734,7 @@ LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
     logger.startLine() << "** Failure : " << diag.str() << "\n";
   });
   if (config.listener)
-    return config.listener->notifyMatchFailure(loc, reasonCallback);
-  return failure();
+    config.listener->notifyMatchFailure(loc, reasonCallback);
 }
 
 //===----------------------------------------------------------------------===//

@joker-eph
Copy link
Collaborator

NFC?

@matthias-springer matthias-springer changed the title [mlir][IR] Listener::notifyMatchFailure returns void [mlir][IR][NFC] Listener::notifyMatchFailure returns void Feb 7, 2024
@matthias-springer matthias-springer merged commit 9a028af into llvm:main Feb 7, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants