Skip to content

[mlir][Transforms][NFC] Improve listener layering in dialect conversion #80825

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 1 addition & 16 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -632,8 +632,7 @@ struct ConversionPatternRewriterImpl;
/// This class implements a pattern rewriter for use with ConversionPatterns. It
/// extends the base PatternRewriter and provides special conversion specific
/// hooks.
class ConversionPatternRewriter final : public PatternRewriter,
public RewriterBase::Listener {
class ConversionPatternRewriter final : public PatternRewriter {
public:
explicit ConversionPatternRewriter(MLIRContext *ctx);
~ConversionPatternRewriter() override;
Expand Down Expand Up @@ -712,10 +711,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
/// implemented for dialect conversion.
void eraseBlock(Block *block) override;

/// PatternRewriter hook creating a new block.
void notifyBlockInserted(Block *block, Region *previous,
Region::iterator previousIt) override;

/// PatternRewriter hook for splitting a block into two parts.
Block *splitBlock(Block *block, Block::iterator before) override;

Expand All @@ -724,9 +719,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
ValueRange argValues = std::nullopt) override;
using PatternRewriter::inlineBlockBefore;

/// PatternRewriter hook for inserting a new operation.
void notifyOperationInserted(Operation *op, InsertPoint previous) override;

/// PatternRewriter hook for updating the given operation in-place.
/// Note: These methods only track updates to the given operation itself,
/// and not nested regions. Updates to regions will still require notification
Expand All @@ -739,18 +731,11 @@ class ConversionPatternRewriter final : public PatternRewriter,
/// PatternRewriter hook for updating the given operation in-place.
void cancelOpModification(Operation *op) override;

/// PatternRewriter hook for notifying match failure reasons.
void
notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override;
using PatternRewriter::notifyMatchFailure;

/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &getImpl();

private:
// Hide unsupported pattern rewriter API.
using OpBuilder::getListener;
using OpBuilder::setListener;

void moveOpBefore(Operation *op, Block *block,
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
// Inside regular functions we use the blocking wait operation to wait for
// the async object (token, value or group) to become available.
if (!isInCoroutine) {
ImplicitLocOpBuilder builder(loc, op, &rewriter);
ImplicitLocOpBuilder builder(loc, rewriter);
builder.create<RuntimeAwaitOp>(loc, operand);

// Assert that the awaited operands is not in the error state.
Expand All @@ -601,7 +601,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
CoroMachinery &coro = funcCoro->getSecond();
Block *suspended = op->getBlock();

ImplicitLocOpBuilder builder(loc, op, &rewriter);
ImplicitLocOpBuilder builder(loc, rewriter);
MLIRContext *ctx = op->getContext();

// Save the coroutine state and resume on a runtime managed thread when
Expand Down
53 changes: 24 additions & 29 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ void ArgConverter::insertConversion(Block *newBlock,
//===----------------------------------------------------------------------===//
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
: argConverter(rewriter, unresolvedMaterializations),
notifyCallback(nullptr) {}
Expand Down Expand Up @@ -903,15 +903,19 @@ struct ConversionPatternRewriterImpl {
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//

/// PatternRewriter hook for replacing the results of an operation.
//// Notifies that an op was inserted.
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override;

/// Notifies that an op is about to be replaced with the given values.
void notifyOpReplaced(Operation *op, ValueRange newValues);

/// Notifies that a block is about to be erased.
void notifyBlockIsBeingErased(Block *block);

/// Notifies that a block was created.
void notifyInsertedBlock(Block *block, Region *previous,
Region::iterator previousIt);
/// Notifies that a block was inserted.
void notifyBlockInserted(Block *block, Region *previous,
Region::iterator previousIt) override;

/// Notifies that a block was split.
void notifySplitBlock(Block *block, Block *continuation);
Expand All @@ -921,8 +925,9 @@ struct ConversionPatternRewriterImpl {
Block::iterator before);

/// Notifies that a pattern match failed for the given reason.
void notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback);
void
notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override;

//===--------------------------------------------------------------------===//
// State
Expand Down Expand Up @@ -1363,6 +1368,16 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks

void ConversionPatternRewriterImpl::notifyOperationInserted(
Operation *op, OpBuilder::InsertPoint previous) {
assert(!previous.isSet() && "expected newly created op");
LLVM_DEBUG({
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
createdOps.push_back(op);
}

void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
ValueRange newValues) {
assert(newValues.size() == op->getNumResults());
Expand Down Expand Up @@ -1398,7 +1413,7 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock}));
}

void ConversionPatternRewriterImpl::notifyInsertedBlock(
void ConversionPatternRewriterImpl::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
if (!previous) {
// This is a newly created block.
Expand Down Expand Up @@ -1437,7 +1452,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
: PatternRewriter(ctx),
impl(new detail::ConversionPatternRewriterImpl(*this)) {
setListener(this);
setListener(impl.get());
}

ConversionPatternRewriter::~ConversionPatternRewriter() = default;
Expand Down Expand Up @@ -1540,11 +1555,6 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
results);
}

void ConversionPatternRewriter::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
impl->notifyInsertedBlock(block, previous, previousIt);
}

Block *ConversionPatternRewriter::splitBlock(Block *block,
Block::iterator before) {
auto *continuation = block->splitBlock(before);
Expand Down Expand Up @@ -1572,16 +1582,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
eraseBlock(source);
}

void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
InsertPoint previous) {
assert(!previous.isSet() && "expected newly created op");
LLVM_DEBUG({
impl->logger.startLine()
<< "** Insert : '" << op->getName() << "'(" << op << ")\n";
});
impl->createdOps.push_back(op);
}

void ConversionPatternRewriter::startOpModification(Operation *op) {
#ifndef NDEBUG
impl->pendingRootUpdates.insert(op);
Expand Down Expand Up @@ -1614,11 +1614,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
rootUpdates.erase(rootUpdates.begin() + updateIdx);
}

void ConversionPatternRewriter::notifyMatchFailure(
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
impl->notifyMatchFailure(loc, reasonCallback);
}

void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) {
llvm_unreachable(
Expand Down