Skip to content

[mlir][Transforms][NFC] Improve listener layering in dialect conversion #80825

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Feb 6, 2024

Context: Conversion patterns provide a ConversionPatternRewriter to modify the IR. ConversionPatternRewriter provides the public API. Most function calls are forwarded/handled by ConversionPatternRewriterImpl. The dialect conversion uses the listener infrastructure to get notified about op/block insertions.

In the current design, ConversionPatternRewriter inherits from both PatternRewriter and Listener. The conversion rewriter registers itself as a listener. This is problematic because listener functions such as notifyOperationInserted are now part of the public API and can be called from conversion patterns; that would bring the dialect conversion into an inconsistent state.

With this commit, ConversionPatternRewriter no longer inherits from Listener. Instead ConversionPatternRewriterImpl inherits from Listener. This removes the problematic public API and also simplifies the code a bit: block/op insertion notifications were previously forwarded to the ConversionPatternRewriterImpl. This is no longer needed.

@llvmbot
Copy link
Member

llvmbot commented Feb 6, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-async

Author: Matthias Springer (matthias-springer)

Changes

Context: Conversion patterns provide a ConversionPatternRewriter to modify the IR. ConversionPatternRewriter provides the public API. Most function calls are forwarded/handled by ConversionPatternRewriterImpl. The dialect conversion uses the listener infrastructure to get notified about op/block insertions.

In the current design, ConversionPatternRewriter inherits from both PatternRewriter and Listener. The conversion rewriter registers itself as a listener. This is problematic because listener functions such as notifyOperationInserted are now part of the public API and can be called from conversion patterns; that would bring the dialect conversion into an inconsistent state.

With this commit, ConversionPatternRewriter no longer inherits from Listener. Instead ConversionPatternRewriterImpl inherits from Listener. This removes the problematic public API and also simplifies the code a bit: block/op insertion notifications were previously forwarded to the ConversionPatternRewriterImpl. This is no longer needed.

Depends on #80704. Review only the top commit.


Full diff: https://github.com/llvm/llvm-project/pull/80825.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+1-1)
  • (modified) mlir/include/mlir/IR/PatternMatch.h (+6-13)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+1-15)
  • (modified) mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+1-2)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+26-32)
  • (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+3-4)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index c2e3cde8ebc69f..2e096e1f552924 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -992,7 +992,7 @@ class TrackingListener : public RewriterBase::Listener,
   /// Notify the listener that the pattern failed to match the given operation,
   /// and provide a callback to populate a diagnostic with the reason why the
   /// failure occurred.
-  LogicalResult
+  void
   notifyMatchFailure(Location loc,
                      function_ref<void(Diagnostic &)> reasonCallback) override;
 
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 61da27825e870c..78dcfe7f6fc3d2 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -437,11 +437,9 @@ class RewriterBase : public OpBuilder {
     /// reason why the failure occurred. This method allows for derived
     /// listeners to optionally hook into the reason why a rewrite failed, and
     /// display it to users.
-    virtual LogicalResult
+    virtual void
     notifyMatchFailure(Location loc,
-                       function_ref<void(Diagnostic &)> reasonCallback) {
-      return failure();
-    }
+                       function_ref<void(Diagnostic &)> reasonCallback) {}
 
     static bool classof(const OpBuilder::Listener *base);
   };
@@ -480,12 +478,11 @@ class RewriterBase : public OpBuilder {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
         rewriteListener->notifyOperationRemoved(op);
     }
-    LogicalResult notifyMatchFailure(
+    void notifyMatchFailure(
         Location loc,
         function_ref<void(Diagnostic &)> reasonCallback) override {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
-        return rewriteListener->notifyMatchFailure(loc, reasonCallback);
-      return failure();
+        rewriteListener->notifyMatchFailure(loc, reasonCallback);
     }
 
   private:
@@ -688,20 +685,16 @@ class RewriterBase : public OpBuilder {
   template <typename CallbackT>
   std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
   notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
-#ifndef NDEBUG
     if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-      return rewriteListener->notifyMatchFailure(
+      rewriteListener->notifyMatchFailure(
           loc, function_ref<void(Diagnostic &)>(reasonCallback));
     return failure();
-#else
-    return failure();
-#endif
   }
   template <typename CallbackT>
   std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
   notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
     if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-      return rewriteListener->notifyMatchFailure(
+      rewriteListener->notifyMatchFailure(
           op->getLoc(), function_ref<void(Diagnostic &)>(reasonCallback));
     return failure();
   }
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 51e3e413b516f4..dd1d1a3f707edb 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -632,8 +632,7 @@ struct ConversionPatternRewriterImpl;
 /// This class implements a pattern rewriter for use with ConversionPatterns. It
 /// extends the base PatternRewriter and provides special conversion specific
 /// hooks.
-class ConversionPatternRewriter final : public PatternRewriter,
-                                        public RewriterBase::Listener {
+class ConversionPatternRewriter final : public PatternRewriter {
 public:
   explicit ConversionPatternRewriter(MLIRContext *ctx);
   ~ConversionPatternRewriter() override;
@@ -712,10 +711,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
   /// implemented for dialect conversion.
   void eraseBlock(Block *block) override;
 
-  /// PatternRewriter hook creating a new block.
-  void notifyBlockInserted(Block *block, Region *previous,
-                           Region::iterator previousIt) override;
-
   /// PatternRewriter hook for splitting a block into two parts.
   Block *splitBlock(Block *block, Block::iterator before) override;
 
@@ -724,9 +719,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
                          ValueRange argValues = std::nullopt) override;
   using PatternRewriter::inlineBlockBefore;
 
-  /// PatternRewriter hook for inserting a new operation.
-  void notifyOperationInserted(Operation *op, InsertPoint previous) override;
-
   /// PatternRewriter hook for updating the given operation in-place.
   /// Note: These methods only track updates to the given operation itself,
   /// and not nested regions. Updates to regions will still require notification
@@ -739,12 +731,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
   /// PatternRewriter hook for updating the given operation in-place.
   void cancelOpModification(Operation *op) override;
 
-  /// PatternRewriter hook for notifying match failure reasons.
-  LogicalResult
-  notifyMatchFailure(Location loc,
-                     function_ref<void(Diagnostic &)> reasonCallback) override;
-  using PatternRewriter::notifyMatchFailure;
-
   /// Return a reference to the internal implementation.
   detail::ConversionPatternRewriterImpl &getImpl();
 
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 828f53c16d8f86..6dc6a8bc8ccc76 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -582,7 +582,8 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
     // Inside regular functions we use the blocking wait operation to wait for
     // the async object (token, value or group) to become available.
     if (!isInCoroutine) {
-      ImplicitLocOpBuilder builder(loc, op, &rewriter);
+      ImplicitLocOpBuilder builder(loc, rewriter);
+      builder.setInsertionPoint(op);
       builder.create<RuntimeAwaitOp>(loc, operand);
 
       // Assert that the awaited operands is not in the error state.
@@ -601,7 +602,8 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
       CoroMachinery &coro = funcCoro->getSecond();
       Block *suspended = op->getBlock();
 
-      ImplicitLocOpBuilder builder(loc, op, &rewriter);
+      ImplicitLocOpBuilder builder(loc, rewriter);
+      builder.setInsertionPoint(op);
       MLIRContext *ctx = op->getContext();
 
       // Save the coroutine state and resume on a runtime managed thread when
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 371ad904dcae5a..a964c205b62e84 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1265,14 +1265,13 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
   return diag;
 }
 
-LogicalResult transform::TrackingListener::notifyMatchFailure(
+void transform::TrackingListener::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
     reasonCallback(diag);
     DBGS() << "Match Failure : " << diag.str() << "\n";
   });
-  return failure();
 }
 
 void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 346135fb447227..e41231d7cbd390 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -825,7 +825,7 @@ void ArgConverter::insertConversion(Block *newBlock,
 //===----------------------------------------------------------------------===//
 namespace mlir {
 namespace detail {
-struct ConversionPatternRewriterImpl {
+struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
       : argConverter(rewriter, unresolvedMaterializations),
         notifyCallback(nullptr) {}
@@ -903,15 +903,19 @@ struct ConversionPatternRewriterImpl {
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
 
-  /// PatternRewriter hook for replacing the results of an operation.
+  //// Notifies that an op was inserted.
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override;
+
+  /// Notifies that an op is about to be replaced with the given values.
   void notifyOpReplaced(Operation *op, ValueRange newValues);
 
   /// Notifies that a block is about to be erased.
   void notifyBlockIsBeingErased(Block *block);
 
-  /// Notifies that a block was created.
-  void notifyInsertedBlock(Block *block, Region *previous,
-                           Region::iterator previousIt);
+  /// Notifies that a block was inserted.
+  void notifyBlockInserted(Block *block, Region *previous,
+                           Region::iterator previousIt) override;
 
   /// Notifies that a block was split.
   void notifySplitBlock(Block *block, Block *continuation);
@@ -921,9 +925,9 @@ struct ConversionPatternRewriterImpl {
                                Block::iterator before);
 
   /// Notifies that a pattern match failed for the given reason.
-  LogicalResult
+  void
   notifyMatchFailure(Location loc,
-                     function_ref<void(Diagnostic &)> reasonCallback);
+                     function_ref<void(Diagnostic &)> reasonCallback) override;
 
   //===--------------------------------------------------------------------===//
   // State
@@ -1236,10 +1240,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
       legalTypes.clear();
       if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
         Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
-        return notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
+        notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
           diag << "unable to convert type for " << valueDiagTag << " #"
                << it.index() << ", type was " << origType;
         });
+        return failure();
       }
       // TODO: There currently isn't any mechanism to do 1->N type conversion
       // via the PatternRewriter replacement API, so for now we just ignore it.
@@ -1363,6 +1368,16 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
+void ConversionPatternRewriterImpl::notifyOperationInserted(
+    Operation *op, OpBuilder::InsertPoint previous) {
+  assert(!previous.isSet() && "expected newly created op");
+  LLVM_DEBUG({
+    logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
+                       << ")\n";
+  });
+  createdOps.push_back(op);
+}
+
 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
                                                      ValueRange newValues) {
   assert(newValues.size() == op->getNumResults());
@@ -1398,7 +1413,7 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
   blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock}));
 }
 
-void ConversionPatternRewriterImpl::notifyInsertedBlock(
+void ConversionPatternRewriterImpl::notifyBlockInserted(
     Block *block, Region *previous, Region::iterator previousIt) {
   if (!previous) {
     // This is a newly created block.
@@ -1419,7 +1434,7 @@ void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
   blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
 }
 
-LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
+void ConversionPatternRewriterImpl::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
@@ -1428,7 +1443,6 @@ LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
     if (notifyCallback)
       notifyCallback(diag);
   });
-  return failure();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1438,7 +1452,7 @@ LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
 ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
     : PatternRewriter(ctx),
       impl(new detail::ConversionPatternRewriterImpl(*this)) {
-  setListener(this);
+  setListener(impl.get());
 }
 
 ConversionPatternRewriter::~ConversionPatternRewriter() = default;
@@ -1541,11 +1555,6 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
                            results);
 }
 
-void ConversionPatternRewriter::notifyBlockInserted(
-    Block *block, Region *previous, Region::iterator previousIt) {
-  impl->notifyInsertedBlock(block, previous, previousIt);
-}
-
 Block *ConversionPatternRewriter::splitBlock(Block *block,
                                              Block::iterator before) {
   auto *continuation = block->splitBlock(before);
@@ -1573,16 +1582,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
   eraseBlock(source);
 }
 
-void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
-                                                        InsertPoint previous) {
-  assert(!previous.isSet() && "expected newly created op");
-  LLVM_DEBUG({
-    impl->logger.startLine()
-        << "** Insert  : '" << op->getName() << "'(" << op << ")\n";
-  });
-  impl->createdOps.push_back(op);
-}
-
 void ConversionPatternRewriter::startOpModification(Operation *op) {
 #ifndef NDEBUG
   impl->pendingRootUpdates.insert(op);
@@ -1615,11 +1614,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
   rootUpdates.erase(rootUpdates.begin() + updateIdx);
 }
 
-LogicalResult ConversionPatternRewriter::notifyMatchFailure(
-    Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
-  return impl->notifyMatchFailure(loc, reasonCallback);
-}
-
 void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
                                              Block::iterator iterator) {
   llvm_unreachable(
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index d5395045af434d..bde8c290e774bc 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -387,7 +387,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   void notifyBlockRemoved(Block *block) override;
 
   /// For debugging only: Notify the driver of a pattern match failure.
-  LogicalResult
+  void
   notifyMatchFailure(Location loc,
                      function_ref<void(Diagnostic &)> reasonCallback) override;
 
@@ -726,7 +726,7 @@ void GreedyPatternRewriteDriver::notifyOperationReplaced(
     config.listener->notifyOperationReplaced(op, replacement);
 }
 
-LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
+void GreedyPatternRewriteDriver::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
@@ -734,8 +734,7 @@ LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
     logger.startLine() << "** Failure : " << diag.str() << "\n";
   });
   if (config.listener)
-    return config.listener->notifyMatchFailure(loc, reasonCallback);
-  return failure();
+    config.listener->notifyMatchFailure(loc, reasonCallback);
 }
 
 //===----------------------------------------------------------------------===//

@llvmbot
Copy link
Member

llvmbot commented Feb 6, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Context: Conversion patterns provide a ConversionPatternRewriter to modify the IR. ConversionPatternRewriter provides the public API. Most function calls are forwarded/handled by ConversionPatternRewriterImpl. The dialect conversion uses the listener infrastructure to get notified about op/block insertions.

In the current design, ConversionPatternRewriter inherits from both PatternRewriter and Listener. The conversion rewriter registers itself as a listener. This is problematic because listener functions such as notifyOperationInserted are now part of the public API and can be called from conversion patterns; that would bring the dialect conversion into an inconsistent state.

With this commit, ConversionPatternRewriter no longer inherits from Listener. Instead ConversionPatternRewriterImpl inherits from Listener. This removes the problematic public API and also simplifies the code a bit: block/op insertion notifications were previously forwarded to the ConversionPatternRewriterImpl. This is no longer needed.

Depends on #80704. Review only the top commit.


Full diff: https://github.com/llvm/llvm-project/pull/80825.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+1-1)
  • (modified) mlir/include/mlir/IR/PatternMatch.h (+6-13)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+1-15)
  • (modified) mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+1-2)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+26-32)
  • (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+3-4)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index c2e3cde8ebc69..2e096e1f55292 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -992,7 +992,7 @@ class TrackingListener : public RewriterBase::Listener,
   /// Notify the listener that the pattern failed to match the given operation,
   /// and provide a callback to populate a diagnostic with the reason why the
   /// failure occurred.
-  LogicalResult
+  void
   notifyMatchFailure(Location loc,
                      function_ref<void(Diagnostic &)> reasonCallback) override;
 
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 61da27825e870..78dcfe7f6fc3d 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -437,11 +437,9 @@ class RewriterBase : public OpBuilder {
     /// reason why the failure occurred. This method allows for derived
     /// listeners to optionally hook into the reason why a rewrite failed, and
     /// display it to users.
-    virtual LogicalResult
+    virtual void
     notifyMatchFailure(Location loc,
-                       function_ref<void(Diagnostic &)> reasonCallback) {
-      return failure();
-    }
+                       function_ref<void(Diagnostic &)> reasonCallback) {}
 
     static bool classof(const OpBuilder::Listener *base);
   };
@@ -480,12 +478,11 @@ class RewriterBase : public OpBuilder {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
         rewriteListener->notifyOperationRemoved(op);
     }
-    LogicalResult notifyMatchFailure(
+    void notifyMatchFailure(
         Location loc,
         function_ref<void(Diagnostic &)> reasonCallback) override {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
-        return rewriteListener->notifyMatchFailure(loc, reasonCallback);
-      return failure();
+        rewriteListener->notifyMatchFailure(loc, reasonCallback);
     }
 
   private:
@@ -688,20 +685,16 @@ class RewriterBase : public OpBuilder {
   template <typename CallbackT>
   std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
   notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
-#ifndef NDEBUG
     if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-      return rewriteListener->notifyMatchFailure(
+      rewriteListener->notifyMatchFailure(
           loc, function_ref<void(Diagnostic &)>(reasonCallback));
     return failure();
-#else
-    return failure();
-#endif
   }
   template <typename CallbackT>
   std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
   notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
     if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-      return rewriteListener->notifyMatchFailure(
+      rewriteListener->notifyMatchFailure(
           op->getLoc(), function_ref<void(Diagnostic &)>(reasonCallback));
     return failure();
   }
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 51e3e413b516f..dd1d1a3f707ed 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -632,8 +632,7 @@ struct ConversionPatternRewriterImpl;
 /// This class implements a pattern rewriter for use with ConversionPatterns. It
 /// extends the base PatternRewriter and provides special conversion specific
 /// hooks.
-class ConversionPatternRewriter final : public PatternRewriter,
-                                        public RewriterBase::Listener {
+class ConversionPatternRewriter final : public PatternRewriter {
 public:
   explicit ConversionPatternRewriter(MLIRContext *ctx);
   ~ConversionPatternRewriter() override;
@@ -712,10 +711,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
   /// implemented for dialect conversion.
   void eraseBlock(Block *block) override;
 
-  /// PatternRewriter hook creating a new block.
-  void notifyBlockInserted(Block *block, Region *previous,
-                           Region::iterator previousIt) override;
-
   /// PatternRewriter hook for splitting a block into two parts.
   Block *splitBlock(Block *block, Block::iterator before) override;
 
@@ -724,9 +719,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
                          ValueRange argValues = std::nullopt) override;
   using PatternRewriter::inlineBlockBefore;
 
-  /// PatternRewriter hook for inserting a new operation.
-  void notifyOperationInserted(Operation *op, InsertPoint previous) override;
-
   /// PatternRewriter hook for updating the given operation in-place.
   /// Note: These methods only track updates to the given operation itself,
   /// and not nested regions. Updates to regions will still require notification
@@ -739,12 +731,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
   /// PatternRewriter hook for updating the given operation in-place.
   void cancelOpModification(Operation *op) override;
 
-  /// PatternRewriter hook for notifying match failure reasons.
-  LogicalResult
-  notifyMatchFailure(Location loc,
-                     function_ref<void(Diagnostic &)> reasonCallback) override;
-  using PatternRewriter::notifyMatchFailure;
-
   /// Return a reference to the internal implementation.
   detail::ConversionPatternRewriterImpl &getImpl();
 
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 828f53c16d8f8..6dc6a8bc8ccc7 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -582,7 +582,8 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
     // Inside regular functions we use the blocking wait operation to wait for
     // the async object (token, value or group) to become available.
     if (!isInCoroutine) {
-      ImplicitLocOpBuilder builder(loc, op, &rewriter);
+      ImplicitLocOpBuilder builder(loc, rewriter);
+      builder.setInsertionPoint(op);
       builder.create<RuntimeAwaitOp>(loc, operand);
 
       // Assert that the awaited operands is not in the error state.
@@ -601,7 +602,8 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
       CoroMachinery &coro = funcCoro->getSecond();
       Block *suspended = op->getBlock();
 
-      ImplicitLocOpBuilder builder(loc, op, &rewriter);
+      ImplicitLocOpBuilder builder(loc, rewriter);
+      builder.setInsertionPoint(op);
       MLIRContext *ctx = op->getContext();
 
       // Save the coroutine state and resume on a runtime managed thread when
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 371ad904dcae5..a964c205b62e8 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1265,14 +1265,13 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
   return diag;
 }
 
-LogicalResult transform::TrackingListener::notifyMatchFailure(
+void transform::TrackingListener::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
     reasonCallback(diag);
     DBGS() << "Match Failure : " << diag.str() << "\n";
   });
-  return failure();
 }
 
 void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 346135fb44722..e41231d7cbd39 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -825,7 +825,7 @@ void ArgConverter::insertConversion(Block *newBlock,
 //===----------------------------------------------------------------------===//
 namespace mlir {
 namespace detail {
-struct ConversionPatternRewriterImpl {
+struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
       : argConverter(rewriter, unresolvedMaterializations),
         notifyCallback(nullptr) {}
@@ -903,15 +903,19 @@ struct ConversionPatternRewriterImpl {
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
 
-  /// PatternRewriter hook for replacing the results of an operation.
+  //// Notifies that an op was inserted.
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override;
+
+  /// Notifies that an op is about to be replaced with the given values.
   void notifyOpReplaced(Operation *op, ValueRange newValues);
 
   /// Notifies that a block is about to be erased.
   void notifyBlockIsBeingErased(Block *block);
 
-  /// Notifies that a block was created.
-  void notifyInsertedBlock(Block *block, Region *previous,
-                           Region::iterator previousIt);
+  /// Notifies that a block was inserted.
+  void notifyBlockInserted(Block *block, Region *previous,
+                           Region::iterator previousIt) override;
 
   /// Notifies that a block was split.
   void notifySplitBlock(Block *block, Block *continuation);
@@ -921,9 +925,9 @@ struct ConversionPatternRewriterImpl {
                                Block::iterator before);
 
   /// Notifies that a pattern match failed for the given reason.
-  LogicalResult
+  void
   notifyMatchFailure(Location loc,
-                     function_ref<void(Diagnostic &)> reasonCallback);
+                     function_ref<void(Diagnostic &)> reasonCallback) override;
 
   //===--------------------------------------------------------------------===//
   // State
@@ -1236,10 +1240,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
       legalTypes.clear();
       if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
         Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
-        return notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
+        notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
           diag << "unable to convert type for " << valueDiagTag << " #"
                << it.index() << ", type was " << origType;
         });
+        return failure();
       }
       // TODO: There currently isn't any mechanism to do 1->N type conversion
       // via the PatternRewriter replacement API, so for now we just ignore it.
@@ -1363,6 +1368,16 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
+void ConversionPatternRewriterImpl::notifyOperationInserted(
+    Operation *op, OpBuilder::InsertPoint previous) {
+  assert(!previous.isSet() && "expected newly created op");
+  LLVM_DEBUG({
+    logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
+                       << ")\n";
+  });
+  createdOps.push_back(op);
+}
+
 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
                                                      ValueRange newValues) {
   assert(newValues.size() == op->getNumResults());
@@ -1398,7 +1413,7 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
   blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock}));
 }
 
-void ConversionPatternRewriterImpl::notifyInsertedBlock(
+void ConversionPatternRewriterImpl::notifyBlockInserted(
     Block *block, Region *previous, Region::iterator previousIt) {
   if (!previous) {
     // This is a newly created block.
@@ -1419,7 +1434,7 @@ void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
   blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
 }
 
-LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
+void ConversionPatternRewriterImpl::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
@@ -1428,7 +1443,6 @@ LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
     if (notifyCallback)
       notifyCallback(diag);
   });
-  return failure();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1438,7 +1452,7 @@ LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
 ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
     : PatternRewriter(ctx),
       impl(new detail::ConversionPatternRewriterImpl(*this)) {
-  setListener(this);
+  setListener(impl.get());
 }
 
 ConversionPatternRewriter::~ConversionPatternRewriter() = default;
@@ -1541,11 +1555,6 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
                            results);
 }
 
-void ConversionPatternRewriter::notifyBlockInserted(
-    Block *block, Region *previous, Region::iterator previousIt) {
-  impl->notifyInsertedBlock(block, previous, previousIt);
-}
-
 Block *ConversionPatternRewriter::splitBlock(Block *block,
                                              Block::iterator before) {
   auto *continuation = block->splitBlock(before);
@@ -1573,16 +1582,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
   eraseBlock(source);
 }
 
-void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
-                                                        InsertPoint previous) {
-  assert(!previous.isSet() && "expected newly created op");
-  LLVM_DEBUG({
-    impl->logger.startLine()
-        << "** Insert  : '" << op->getName() << "'(" << op << ")\n";
-  });
-  impl->createdOps.push_back(op);
-}
-
 void ConversionPatternRewriter::startOpModification(Operation *op) {
 #ifndef NDEBUG
   impl->pendingRootUpdates.insert(op);
@@ -1615,11 +1614,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
   rootUpdates.erase(rootUpdates.begin() + updateIdx);
 }
 
-LogicalResult ConversionPatternRewriter::notifyMatchFailure(
-    Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
-  return impl->notifyMatchFailure(loc, reasonCallback);
-}
-
 void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
                                              Block::iterator iterator) {
   llvm_unreachable(
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index d5395045af434..bde8c290e774b 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -387,7 +387,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   void notifyBlockRemoved(Block *block) override;
 
   /// For debugging only: Notify the driver of a pattern match failure.
-  LogicalResult
+  void
   notifyMatchFailure(Location loc,
                      function_ref<void(Diagnostic &)> reasonCallback) override;
 
@@ -726,7 +726,7 @@ void GreedyPatternRewriteDriver::notifyOperationReplaced(
     config.listener->notifyOperationReplaced(op, replacement);
 }
 
-LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
+void GreedyPatternRewriteDriver::notifyMatchFailure(
     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
@@ -734,8 +734,7 @@ LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
     logger.startLine() << "** Failure : " << diag.str() << "\n";
   });
   if (config.listener)
-    return config.listener->notifyMatchFailure(loc, reasonCallback);
-  return failure();
+    config.listener->notifyMatchFailure(loc, reasonCallback);
 }
 
 //===----------------------------------------------------------------------===//

@matthias-springer matthias-springer changed the title [mlir][IR][NFC] Improve listener layering in dialect conversion [mlir][Transforms][NFC] Improve listener layering in dialect conversion Feb 6, 2024
@matthias-springer matthias-springer force-pushed the conversion_pattern_rewriter_listener branch from 941a492 to 5371363 Compare February 6, 2024 11:18
@joker-eph
Copy link
Collaborator

Review only the top commit.

What about using proper stacked PR by pushing your branch into the repo?

@@ -582,7 +582,8 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
// Inside regular functions we use the blocking wait operation to wait for
// the async object (token, value or group) to become available.
if (!isInCoroutine) {
ImplicitLocOpBuilder builder(loc, op, &rewriter);
ImplicitLocOpBuilder builder(loc, rewriter);
builder.setInsertionPoint(op);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ImplicitLocOpBuilder builder(loc, op, &rewriter);

The above constructor sets the insertion point before op and attaches &rewriter as a listener.

After this commit, ConversionPatternRewriter no longer inherits from Listener, but we must still attach the same listener that is attached to rewriter. (rewriter used to have itself attached as a listener.) That's what ImplicitLocOpBuilder builder(loc, rewriter); does. There is no constructor that takes an existing OpBuilder (or subclass thereof) and op (to set the insertion point), so I have to manually set the insertion point with setInsertionPoint.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't the rewriter already having the insertion point set? Why isn't ImplicitLocOpBuilder builder(loc, rewriter); taking the insertion point from there?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, this line is not needed. I didn't check what the insertion point of the rewriter was. It is also op.

@matthias-springer matthias-springer force-pushed the conversion_pattern_rewriter_listener branch from 5371363 to f088103 Compare February 7, 2024 08:18
@matthias-springer
Copy link
Member Author

Review only the top commit.

What about using proper stacked PR by pushing your branch into the repo?

Oh right, I'll give it a try on the next stack of commits.

@matthias-springer matthias-springer force-pushed the conversion_pattern_rewriter_listener branch from f088103 to c06724c Compare February 7, 2024 08:53
Context: Conversion patterns provide a `ConversionPatternRewriter` to modify the IR. `ConversionPatternRewriter` provides the public API. Most function calls are forwarded/handled by `ConversionPatternRewriterImpl`. The dialect conversion uses the listener infrastructure to get notified about op/block insertions.

In the current design, `ConversionPatternRewriter` inherits from both `PatternRewriter` and `Listener`. The conversion rewriter registers itself as a listener. This is problematic because listener functions such as `notifyOperationInserted` are now part of the public API and can be called from conversion patterns; that would bring the dialect conversion into an inconsistent state.

With this commit, `ConversionPatternRewriter` no longer inherits from `Listener`. Instead `ConversionPatternRewriterImpl` inherits from `Listener`. This removes the problematic public API and also simplifies the code a bit: block/op insertion notifications were previously forwarded to the `ConversionPatternRewriterImpl`. This is no longer needed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:async mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants