Skip to content

Commit 62bf771

Browse files
[mlir][IR] Add notifyBlockRemoved callback to listener (#78306)
There is already a "block inserted" notification (in `OpBuilder::Listener`), so there should also be a "block removed" notification. The purpose of this change is to make the listener API more mature. There is currently a gap between what kind of IR changes can be made and what IR changes can be listened to. At the moment, the only way to inform listeners about "block removal" is to send a manual `notifyOperationModified` for the parent op (e.g., by wrapping the `eraseBlock(b)` method call in `updateRootInPlace(b->getParentOp())`). This tells the listener that *something* has changed, but it is somewhat of an API abuse.
1 parent d023044 commit 62bf771

File tree

4 files changed

+26
-3
lines changed

4 files changed

+26
-3
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,10 @@ class RewriterBase : public OpBuilder {
402402
Listener()
403403
: OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
404404

405+
/// Notify the listener that the specified block is about to be erased.
406+
/// At this point, the block has zero uses.
407+
virtual void notifyBlockRemoved(Block *block) {}
408+
405409
/// Notify the listener that the specified operation was modified in-place.
406410
virtual void notifyOperationModified(Operation *op) {}
407411

@@ -452,6 +456,10 @@ class RewriterBase : public OpBuilder {
452456
void notifyBlockCreated(Block *block) override {
453457
listener->notifyBlockCreated(block);
454458
}
459+
void notifyBlockRemoved(Block *block) override {
460+
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
461+
rewriteListener->notifyBlockRemoved(block);
462+
}
455463
void notifyOperationModified(Operation *op) override {
456464
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
457465
rewriteListener->notifyOperationModified(op);

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1114,7 +1114,7 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
11141114
scf::ForOp newLoop = rewriter.create<scf::ForOp>(
11151115
loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
11161116
operands);
1117-
newLoop.getBody()->erase();
1117+
rewriter.eraseBlock(newLoop.getBody());
11181118

11191119
newLoop.getRegion().getBlocks().splice(
11201120
newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());

mlir/lib/IR/PatternMatch.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ void RewriterBase::eraseOp(Operation *op) {
244244
for (BlockArgument bbArg : b->getArguments())
245245
bbArg.dropAllUses();
246246
b->dropAllUses();
247-
b->erase();
247+
eraseBlock(b);
248248
}
249249
}
250250
}
@@ -256,10 +256,17 @@ void RewriterBase::eraseOp(Operation *op) {
256256
}
257257

258258
void RewriterBase::eraseBlock(Block *block) {
259+
assert(block->use_empty() && "expected 'block' to have no uses");
260+
259261
for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
260262
assert(op.use_empty() && "expected 'op' to have no uses");
261263
eraseOp(&op);
262264
}
265+
266+
// Notify the listener that the block is about to be removed.
267+
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
268+
rewriteListener->notifyBlockRemoved(block);
269+
263270
block->erase();
264271
}
265272

@@ -311,7 +318,7 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
311318
// Move operations from the source block to the dest block and erase the
312319
// source block.
313320
dest->getOperations().splice(before, source->getOperations());
314-
source->erase();
321+
eraseBlock(source);
315322
}
316323

317324
void RewriterBase::inlineBlockBefore(Block *source, Operation *op,

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
373373
/// Notify the driver that the given block was created.
374374
void notifyBlockCreated(Block *block) override;
375375

376+
/// Notify the driver that the given block is about to be removed.
377+
void notifyBlockRemoved(Block *block) override;
378+
376379
/// For debugging only: Notify the driver of a pattern match failure.
377380
LogicalResult
378381
notifyMatchFailure(Location loc,
@@ -633,6 +636,11 @@ void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) {
633636
config.listener->notifyBlockCreated(block);
634637
}
635638

639+
void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
640+
if (config.listener)
641+
config.listener->notifyBlockRemoved(block);
642+
}
643+
636644
void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
637645
LLVM_DEBUG({
638646
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op

0 commit comments

Comments
 (0)