Skip to content

Commit 0921bfd

Browse files
[mlir][Transforms] Dialect conversion: Add missing erasure notifications (#145030)
Add missing listener notifications when erasing nested blocks/operations. This commit also moves some of the functionality from `ConversionPatternRewriter` to `ConversionPatternRewriterImpl`. This is in preparation of the One-Shot Dialect Conversion refactoring: The implementations in `ConversionPatternRewriter` should be as simple as possible, so that a switch between "rollback allowed" and "rollback not allowed" can be inserted at that level. (In the latter case, `ConversionPatternRewriterImpl` can be bypassed to some degree, and `PatternRewriter::eraseBlock` etc. can be used.) Depends on #145018.
1 parent 4a4582d commit 0921bfd

File tree

2 files changed

+51
-19
lines changed

2 files changed

+51
-19
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,26 @@ struct RewriterState {
274274
// IR rewrites
275275
//===----------------------------------------------------------------------===//
276276

277+
static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
278+
279+
/// Notify the listener that the given block and its contents are being erased.
280+
static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
281+
for (Operation &op : b)
282+
notifyIRErased(listener, op);
283+
listener->notifyBlockErased(&b);
284+
}
285+
286+
/// Notify the listener that the given operation and its contents are being
287+
/// erased.
288+
static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
289+
for (Region &r : op.getRegions()) {
290+
for (Block &b : r) {
291+
notifyIRErased(listener, b);
292+
}
293+
}
294+
listener->notifyOperationErased(&op);
295+
}
296+
277297
/// An IR rewrite that can be committed (upon success) or rolled back (upon
278298
/// failure).
279299
///
@@ -422,17 +442,20 @@ class EraseBlockRewrite : public BlockRewrite {
422442
}
423443

424444
void commit(RewriterBase &rewriter) override {
425-
// Erase the block.
426445
assert(block && "expected block");
427-
assert(block->empty() && "expected empty block");
428446

429-
// Notify the listener that the block is about to be erased.
447+
// Notify the listener that the block and its contents are being erased.
430448
if (auto *listener =
431449
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
432-
listener->notifyBlockErased(block);
450+
notifyIRErased(listener, *block);
433451
}
434452

435453
void cleanup(RewriterBase &rewriter) override {
454+
// Erase the contents of the block.
455+
for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
456+
rewriter.eraseOp(&op);
457+
assert(block->empty() && "expected empty block");
458+
436459
// Erase the block.
437460
block->dropAllDefinedValueUses();
438461
delete block;
@@ -1147,12 +1170,9 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
11471170
if (getConfig().unlegalizedOps)
11481171
getConfig().unlegalizedOps->erase(op);
11491172

1150-
// Notify the listener that the operation (and its nested operations) was
1151-
// erased.
1152-
if (listener) {
1153-
op->walk<WalkOrder::PostOrder>(
1154-
[&](Operation *op) { listener->notifyOperationErased(op); });
1155-
}
1173+
// Notify the listener that the operation and its contents are being erased.
1174+
if (listener)
1175+
notifyIRErased(listener, *op);
11561176

11571177
// Do not erase the operation yet. It may still be referenced in `mapping`.
11581178
// Just unlink it for now and erase it during cleanup.
@@ -1605,13 +1625,18 @@ void ConversionPatternRewriterImpl::replaceOp(
16051625
}
16061626

16071627
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
1628+
assert(!wasOpReplaced(block->getParentOp()) &&
1629+
"attempting to erase a block within a replaced/erased op");
16081630
appendRewrite<EraseBlockRewrite>(block);
16091631

16101632
// Unlink the block from its parent region. The block is kept in the rewrite
16111633
// object and will be actually destroyed when rewrites are applied. This
16121634
// allows us to keep the operations in the block live and undo the removal by
16131635
// re-inserting the block.
16141636
block->getParent()->getBlocks().remove(block);
1637+
1638+
// Mark all nested ops as erased.
1639+
block->walk([&](Operation *op) { replacedOps.insert(op); });
16151640
}
16161641

16171642
void ConversionPatternRewriterImpl::notifyBlockInserted(
@@ -1709,13 +1734,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
17091734
}
17101735

17111736
void ConversionPatternRewriter::eraseBlock(Block *block) {
1712-
assert(!impl->wasOpReplaced(block->getParentOp()) &&
1713-
"attempting to erase a block within a replaced/erased op");
1714-
1715-
// Mark all ops for erasure.
1716-
for (Operation &op : *block)
1717-
eraseOp(&op);
1718-
17191737
impl->eraseBlock(block);
17201738
}
17211739

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,12 +461,26 @@ func.func @convert_detached_signature() {
461461

462462
// -----
463463

464+
// CHECK: notifyOperationReplaced: test.erase_op
465+
// CHECK: notifyOperationErased: test.dummy_op_lvl_2
466+
// CHECK: notifyBlockErased
467+
// CHECK: notifyOperationErased: test.dummy_op_lvl_1
468+
// CHECK: notifyBlockErased
469+
// CHECK: notifyOperationErased: test.erase_op
470+
// CHECK: notifyOperationInserted: test.valid, was unlinked
471+
// CHECK: notifyOperationReplaced: test.drop_operands_and_replace_with_valid
472+
// CHECK: notifyOperationErased: test.drop_operands_and_replace_with_valid
473+
464474
// CHECK-LABEL: func @circular_mapping()
465475
// CHECK-NEXT: "test.valid"() : () -> ()
466476
func.func @circular_mapping() {
467477
// Regression test that used to crash due to circular
468-
// unrealized_conversion_cast ops.
469-
%0 = "test.erase_op"() : () -> (i64)
478+
// unrealized_conversion_cast ops.
479+
%0 = "test.erase_op"() ({
480+
"test.dummy_op_lvl_1"() ({
481+
"test.dummy_op_lvl_2"() : () -> ()
482+
}) : () -> ()
483+
}): () -> (i64)
470484
"test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
471485
}
472486

0 commit comments

Comments
 (0)