Skip to content

[mlir][IR] Add notifyBlockRemoved callback to listener #78306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,10 @@ class RewriterBase : public OpBuilder {
Listener()
: OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}

/// Notify the listener that the specified block is about to be erased.
/// At this point, the block has zero uses.
virtual void notifyBlockRemoved(Block *block) {}

/// Notify the listener that the specified operation was modified in-place.
virtual void notifyOperationModified(Operation *op) {}

Expand Down Expand Up @@ -452,6 +456,10 @@ class RewriterBase : public OpBuilder {
void notifyBlockCreated(Block *block) override {
listener->notifyBlockCreated(block);
}
void notifyBlockRemoved(Block *block) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
rewriteListener->notifyBlockRemoved(block);
}
void notifyOperationModified(Operation *op) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
rewriteListener->notifyOperationModified(op);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,7 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
scf::ForOp newLoop = rewriter.create<scf::ForOp>(
loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
operands);
newLoop.getBody()->erase();
rewriter.eraseBlock(newLoop.getBody());

newLoop.getRegion().getBlocks().splice(
newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
Expand Down
11 changes: 9 additions & 2 deletions mlir/lib/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ void RewriterBase::eraseOp(Operation *op) {
for (BlockArgument bbArg : b->getArguments())
bbArg.dropAllUses();
b->dropAllUses();
b->erase();
eraseBlock(b);
}
}
}
Expand All @@ -256,10 +256,17 @@ void RewriterBase::eraseOp(Operation *op) {
}

void RewriterBase::eraseBlock(Block *block) {
assert(block->use_empty() && "expected 'block' to have no uses");

for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
assert(op.use_empty() && "expected 'op' to have no uses");
eraseOp(&op);
}

// Notify the listener that the block is about to be removed.
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
rewriteListener->notifyBlockRemoved(block);

block->erase();
}

Expand Down Expand Up @@ -311,7 +318,7 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
// Move operations from the source block to the dest block and erase the
// source block.
dest->getOperations().splice(before, source->getOperations());
source->erase();
eraseBlock(source);
}

void RewriterBase::inlineBlockBefore(Block *source, Operation *op,
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
/// Notify the driver that the given block was created.
void notifyBlockCreated(Block *block) override;

/// Notify the driver that the given block is about to be removed.
void notifyBlockRemoved(Block *block) override;

/// For debugging only: Notify the driver of a pattern match failure.
LogicalResult
notifyMatchFailure(Location loc,
Expand Down Expand Up @@ -633,6 +636,11 @@ void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) {
config.listener->notifyBlockCreated(block);
}

void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
if (config.listener)
config.listener->notifyBlockRemoved(block);
}

void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
LLVM_DEBUG({
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
Expand Down