-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Transforms][NFC] Improve listener layering in dialect conversion #81236
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms][NFC] Improve listener layering in dialect conversion #81236
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesContext: Conversion patterns provide a In the current design, With this commit, Note: This is a re-upload of #80825 onto the llvm repository (instead of private repository), to enable dependent PRs. Full diff: https://github.com/llvm/llvm-project/pull/81236.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index b1ec1fe4ecd51a..f061d761ecefbb 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -632,8 +632,7 @@ struct ConversionPatternRewriterImpl;
/// This class implements a pattern rewriter for use with ConversionPatterns. It
/// extends the base PatternRewriter and provides special conversion specific
/// hooks.
-class ConversionPatternRewriter final : public PatternRewriter,
- public RewriterBase::Listener {
+class ConversionPatternRewriter final : public PatternRewriter {
public:
explicit ConversionPatternRewriter(MLIRContext *ctx);
~ConversionPatternRewriter() override;
@@ -712,10 +711,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
/// implemented for dialect conversion.
void eraseBlock(Block *block) override;
- /// PatternRewriter hook creating a new block.
- void notifyBlockInserted(Block *block, Region *previous,
- Region::iterator previousIt) override;
-
/// PatternRewriter hook for splitting a block into two parts.
Block *splitBlock(Block *block, Block::iterator before) override;
@@ -724,9 +719,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
ValueRange argValues = std::nullopt) override;
using PatternRewriter::inlineBlockBefore;
- /// PatternRewriter hook for inserting a new operation.
- void notifyOperationInserted(Operation *op, InsertPoint previous) override;
-
/// PatternRewriter hook for updating the given operation in-place.
/// Note: These methods only track updates to the given operation itself,
/// and not nested regions. Updates to regions will still require notification
@@ -739,18 +731,11 @@ class ConversionPatternRewriter final : public PatternRewriter,
/// PatternRewriter hook for updating the given operation in-place.
void cancelOpModification(Operation *op) override;
- /// PatternRewriter hook for notifying match failure reasons.
- void
- notifyMatchFailure(Location loc,
- function_ref<void(Diagnostic &)> reasonCallback) override;
- using PatternRewriter::notifyMatchFailure;
-
/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &getImpl();
private:
// Hide unsupported pattern rewriter API.
- using OpBuilder::getListener;
using OpBuilder::setListener;
void moveOpBefore(Operation *op, Block *block,
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 828f53c16d8f86..31e81107f655c0 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -582,7 +582,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
// Inside regular functions we use the blocking wait operation to wait for
// the async object (token, value or group) to become available.
if (!isInCoroutine) {
- ImplicitLocOpBuilder builder(loc, op, &rewriter);
+ ImplicitLocOpBuilder builder(loc, rewriter);
builder.create<RuntimeAwaitOp>(loc, operand);
// Assert that the awaited operands is not in the error state.
@@ -601,7 +601,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
CoroMachinery &coro = funcCoro->getSecond();
Block *suspended = op->getBlock();
- ImplicitLocOpBuilder builder(loc, op, &rewriter);
+ ImplicitLocOpBuilder builder(loc, rewriter);
MLIRContext *ctx = op->getContext();
// Save the coroutine state and resume on a runtime managed thread when
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index e90447084d68bd..e41231d7cbd390 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -825,7 +825,7 @@ void ArgConverter::insertConversion(Block *newBlock,
//===----------------------------------------------------------------------===//
namespace mlir {
namespace detail {
-struct ConversionPatternRewriterImpl {
+struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
: argConverter(rewriter, unresolvedMaterializations),
notifyCallback(nullptr) {}
@@ -903,15 +903,19 @@ struct ConversionPatternRewriterImpl {
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
- /// PatternRewriter hook for replacing the results of an operation.
+ //// Notifies that an op was inserted.
+ void notifyOperationInserted(Operation *op,
+ OpBuilder::InsertPoint previous) override;
+
+ /// Notifies that an op is about to be replaced with the given values.
void notifyOpReplaced(Operation *op, ValueRange newValues);
/// Notifies that a block is about to be erased.
void notifyBlockIsBeingErased(Block *block);
- /// Notifies that a block was created.
- void notifyInsertedBlock(Block *block, Region *previous,
- Region::iterator previousIt);
+ /// Notifies that a block was inserted.
+ void notifyBlockInserted(Block *block, Region *previous,
+ Region::iterator previousIt) override;
/// Notifies that a block was split.
void notifySplitBlock(Block *block, Block *continuation);
@@ -921,8 +925,9 @@ struct ConversionPatternRewriterImpl {
Block::iterator before);
/// Notifies that a pattern match failed for the given reason.
- void notifyMatchFailure(Location loc,
- function_ref<void(Diagnostic &)> reasonCallback);
+ void
+ notifyMatchFailure(Location loc,
+ function_ref<void(Diagnostic &)> reasonCallback) override;
//===--------------------------------------------------------------------===//
// State
@@ -1363,6 +1368,16 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
+void ConversionPatternRewriterImpl::notifyOperationInserted(
+ Operation *op, OpBuilder::InsertPoint previous) {
+ assert(!previous.isSet() && "expected newly created op");
+ LLVM_DEBUG({
+ logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
+ << ")\n";
+ });
+ createdOps.push_back(op);
+}
+
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
ValueRange newValues) {
assert(newValues.size() == op->getNumResults());
@@ -1398,7 +1413,7 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock}));
}
-void ConversionPatternRewriterImpl::notifyInsertedBlock(
+void ConversionPatternRewriterImpl::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
if (!previous) {
// This is a newly created block.
@@ -1437,7 +1452,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
: PatternRewriter(ctx),
impl(new detail::ConversionPatternRewriterImpl(*this)) {
- setListener(this);
+ setListener(impl.get());
}
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
@@ -1540,11 +1555,6 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
results);
}
-void ConversionPatternRewriter::notifyBlockInserted(
- Block *block, Region *previous, Region::iterator previousIt) {
- impl->notifyInsertedBlock(block, previous, previousIt);
-}
-
Block *ConversionPatternRewriter::splitBlock(Block *block,
Block::iterator before) {
auto *continuation = block->splitBlock(before);
@@ -1572,16 +1582,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
eraseBlock(source);
}
-void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
- InsertPoint previous) {
- assert(!previous.isSet() && "expected newly created op");
- LLVM_DEBUG({
- impl->logger.startLine()
- << "** Insert : '" << op->getName() << "'(" << op << ")\n";
- });
- impl->createdOps.push_back(op);
-}
-
void ConversionPatternRewriter::startOpModification(Operation *op) {
#ifndef NDEBUG
impl->pendingRootUpdates.insert(op);
@@ -1614,11 +1614,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
rootUpdates.erase(rootUpdates.begin() + updateIdx);
}
-void ConversionPatternRewriter::notifyMatchFailure(
- Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
- impl->notifyMatchFailure(loc, reasonCallback);
-}
-
void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) {
llvm_unreachable(
|
@llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesContext: Conversion patterns provide a In the current design, With this commit, Note: This is a re-upload of #80825 onto the llvm repository (instead of private repository), to enable dependent PRs. Full diff: https://github.com/llvm/llvm-project/pull/81236.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index b1ec1fe4ecd51a..f061d761ecefbb 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -632,8 +632,7 @@ struct ConversionPatternRewriterImpl;
/// This class implements a pattern rewriter for use with ConversionPatterns. It
/// extends the base PatternRewriter and provides special conversion specific
/// hooks.
-class ConversionPatternRewriter final : public PatternRewriter,
- public RewriterBase::Listener {
+class ConversionPatternRewriter final : public PatternRewriter {
public:
explicit ConversionPatternRewriter(MLIRContext *ctx);
~ConversionPatternRewriter() override;
@@ -712,10 +711,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
/// implemented for dialect conversion.
void eraseBlock(Block *block) override;
- /// PatternRewriter hook creating a new block.
- void notifyBlockInserted(Block *block, Region *previous,
- Region::iterator previousIt) override;
-
/// PatternRewriter hook for splitting a block into two parts.
Block *splitBlock(Block *block, Block::iterator before) override;
@@ -724,9 +719,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
ValueRange argValues = std::nullopt) override;
using PatternRewriter::inlineBlockBefore;
- /// PatternRewriter hook for inserting a new operation.
- void notifyOperationInserted(Operation *op, InsertPoint previous) override;
-
/// PatternRewriter hook for updating the given operation in-place.
/// Note: These methods only track updates to the given operation itself,
/// and not nested regions. Updates to regions will still require notification
@@ -739,18 +731,11 @@ class ConversionPatternRewriter final : public PatternRewriter,
/// PatternRewriter hook for updating the given operation in-place.
void cancelOpModification(Operation *op) override;
- /// PatternRewriter hook for notifying match failure reasons.
- void
- notifyMatchFailure(Location loc,
- function_ref<void(Diagnostic &)> reasonCallback) override;
- using PatternRewriter::notifyMatchFailure;
-
/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &getImpl();
private:
// Hide unsupported pattern rewriter API.
- using OpBuilder::getListener;
using OpBuilder::setListener;
void moveOpBefore(Operation *op, Block *block,
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 828f53c16d8f86..31e81107f655c0 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -582,7 +582,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
// Inside regular functions we use the blocking wait operation to wait for
// the async object (token, value or group) to become available.
if (!isInCoroutine) {
- ImplicitLocOpBuilder builder(loc, op, &rewriter);
+ ImplicitLocOpBuilder builder(loc, rewriter);
builder.create<RuntimeAwaitOp>(loc, operand);
// Assert that the awaited operands is not in the error state.
@@ -601,7 +601,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
CoroMachinery &coro = funcCoro->getSecond();
Block *suspended = op->getBlock();
- ImplicitLocOpBuilder builder(loc, op, &rewriter);
+ ImplicitLocOpBuilder builder(loc, rewriter);
MLIRContext *ctx = op->getContext();
// Save the coroutine state and resume on a runtime managed thread when
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index e90447084d68bd..e41231d7cbd390 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -825,7 +825,7 @@ void ArgConverter::insertConversion(Block *newBlock,
//===----------------------------------------------------------------------===//
namespace mlir {
namespace detail {
-struct ConversionPatternRewriterImpl {
+struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
: argConverter(rewriter, unresolvedMaterializations),
notifyCallback(nullptr) {}
@@ -903,15 +903,19 @@ struct ConversionPatternRewriterImpl {
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
- /// PatternRewriter hook for replacing the results of an operation.
+ //// Notifies that an op was inserted.
+ void notifyOperationInserted(Operation *op,
+ OpBuilder::InsertPoint previous) override;
+
+ /// Notifies that an op is about to be replaced with the given values.
void notifyOpReplaced(Operation *op, ValueRange newValues);
/// Notifies that a block is about to be erased.
void notifyBlockIsBeingErased(Block *block);
- /// Notifies that a block was created.
- void notifyInsertedBlock(Block *block, Region *previous,
- Region::iterator previousIt);
+ /// Notifies that a block was inserted.
+ void notifyBlockInserted(Block *block, Region *previous,
+ Region::iterator previousIt) override;
/// Notifies that a block was split.
void notifySplitBlock(Block *block, Block *continuation);
@@ -921,8 +925,9 @@ struct ConversionPatternRewriterImpl {
Block::iterator before);
/// Notifies that a pattern match failed for the given reason.
- void notifyMatchFailure(Location loc,
- function_ref<void(Diagnostic &)> reasonCallback);
+ void
+ notifyMatchFailure(Location loc,
+ function_ref<void(Diagnostic &)> reasonCallback) override;
//===--------------------------------------------------------------------===//
// State
@@ -1363,6 +1368,16 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
+void ConversionPatternRewriterImpl::notifyOperationInserted(
+ Operation *op, OpBuilder::InsertPoint previous) {
+ assert(!previous.isSet() && "expected newly created op");
+ LLVM_DEBUG({
+ logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
+ << ")\n";
+ });
+ createdOps.push_back(op);
+}
+
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
ValueRange newValues) {
assert(newValues.size() == op->getNumResults());
@@ -1398,7 +1413,7 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock}));
}
-void ConversionPatternRewriterImpl::notifyInsertedBlock(
+void ConversionPatternRewriterImpl::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
if (!previous) {
// This is a newly created block.
@@ -1437,7 +1452,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
: PatternRewriter(ctx),
impl(new detail::ConversionPatternRewriterImpl(*this)) {
- setListener(this);
+ setListener(impl.get());
}
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
@@ -1540,11 +1555,6 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
results);
}
-void ConversionPatternRewriter::notifyBlockInserted(
- Block *block, Region *previous, Region::iterator previousIt) {
- impl->notifyInsertedBlock(block, previous, previousIt);
-}
-
Block *ConversionPatternRewriter::splitBlock(Block *block,
Block::iterator before) {
auto *continuation = block->splitBlock(before);
@@ -1572,16 +1582,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
eraseBlock(source);
}
-void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
- InsertPoint previous) {
- assert(!previous.isSet() && "expected newly created op");
- LLVM_DEBUG({
- impl->logger.startLine()
- << "** Insert : '" << op->getName() << "'(" << op << ")\n";
- });
- impl->createdOps.push_back(op);
-}
-
void ConversionPatternRewriter::startOpModification(Operation *op) {
#ifndef NDEBUG
impl->pendingRootUpdates.insert(op);
@@ -1614,11 +1614,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
rootUpdates.erase(rootUpdates.begin() + updateIdx);
}
-void ConversionPatternRewriter::notifyMatchFailure(
- Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
- impl->notifyMatchFailure(loc, reasonCallback);
-}
-
void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) {
llvm_unreachable(
|
Context: Conversion patterns provide a `ConversionPatternRewriter` to modify the IR. `ConversionPatternRewriter` provides the public API. Most function calls are forwarded/handled by `ConversionPatternRewriterImpl`. The dialect conversion uses the listener infrastructure to get notified about op/block insertions. In the current design, `ConversionPatternRewriter` inherits from both `PatternRewriter` and `Listener`. The conversion rewriter registers itself as a listener. This is problematic because listener functions such as `notifyOperationInserted` are now part of the public API and can be called from conversion patterns; that would bring the dialect conversion into an inconsistent state. With this commit, `ConversionPatternRewriter` no longer inherits from `Listener`. Instead `ConversionPatternRewriterImpl` inherits from `Listener`. This removes the problematic public API and also simplifies the code a bit: block/op insertion notifications were previously forwarded to the `ConversionPatternRewriterImpl`. This is no longer needed.
37fe558
to
baa249a
Compare
This is ready for review. If this looks good, I can start landing the first pieces. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This has always bothered be, thanks for improving.
Context: Conversion patterns provide a
ConversionPatternRewriter
to modify the IR.ConversionPatternRewriter
provides the public API. Most function calls are forwarded/handled byConversionPatternRewriterImpl
. The dialect conversion uses the listener infrastructure to get notified about op/block insertions.In the current design,
ConversionPatternRewriter
inherits from bothPatternRewriter
andListener
. The conversion rewriter registers itself as a listener. This is problematic because listener functions such asnotifyOperationInserted
are now part of the public API and can be called from conversion patterns; that would bring the dialect conversion into an inconsistent state.With this commit,
ConversionPatternRewriter
no longer inherits fromListener
. InsteadConversionPatternRewriterImpl
inherits fromListener
. This removes the problematic public API and also simplifies the code a bit: block/op insertion notifications were previously forwarded to theConversionPatternRewriterImpl
. This is no longer needed.Note: This is a re-upload of #80825 onto the llvm repository (instead of private repository), to enable dependent PRs.