Skip to content

Commit 310a278

Browse files
[mlir][Transforms][NFC] Simplify handling of erased IR (#83423)
The dialect conversion uses a `SingleEraseRewriter` to ensure that an op/block is not erased twice. This can happen during the "commit" phase when an unresolved materialization is inserted into a block and the enclosing op is erased by the user. In that case, the unresolved materialization should not be erased a second time later in the "commit" phase. This problem cannot happen during "rollback", so ops/block can be erased directly without using the rewriter. With this change, the `SingleEraseRewriter` is used only during "commit"/"cleanup". At that point, the dialect conversion is guaranteed to succeed and no rollback can happen. Therefore, it is not necessary to store the number of erased IR objects (because we will never "reset" the rewriter to previous a previous state).
1 parent 71c2a13 commit 310a278

File tree

1 file changed

+8
-14
lines changed

1 file changed

+8
-14
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -153,19 +153,16 @@ namespace {
153153
/// This is useful when saving and undoing a set of rewrites.
154154
struct RewriterState {
155155
RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
156-
unsigned numErased, unsigned numReplacedOps)
156+
unsigned numReplacedOps)
157157
: numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
158-
numErased(numErased), numReplacedOps(numReplacedOps) {}
158+
numReplacedOps(numReplacedOps) {}
159159

160160
/// The current number of rewrites performed.
161161
unsigned numRewrites;
162162

163163
/// The current number of ignored operations.
164164
unsigned numIgnoredOperations;
165165

166-
/// The current number of erased operations/blocks.
167-
unsigned numErased;
168-
169166
/// The current number of replaced ops that are scheduled for erasure.
170167
unsigned numReplacedOps;
171168
};
@@ -274,8 +271,9 @@ class CreateBlockRewrite : public BlockRewrite {
274271
auto &blockOps = block->getOperations();
275272
while (!blockOps.empty())
276273
blockOps.remove(blockOps.begin());
274+
block->dropAllUses();
277275
if (block->getParent())
278-
eraseBlock(block);
276+
block->erase();
279277
else
280278
delete block;
281279
}
@@ -905,7 +903,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
905903
void notifyBlockErased(Block *block) override { erased.insert(block); }
906904

907905
/// Pointers to all erased operations and blocks.
908-
SetVector<void *> erased;
906+
DenseSet<void *> erased;
909907
};
910908

911909
//===--------------------------------------------------------------------===//
@@ -1091,15 +1089,15 @@ void CreateOperationRewrite::rollback() {
10911089
region.getBlocks().remove(region.getBlocks().begin());
10921090
}
10931091
op->dropAllUses();
1094-
eraseOp(op);
1092+
op->erase();
10951093
}
10961094

10971095
void UnresolvedMaterializationRewrite::rollback() {
10981096
if (getMaterializationKind() == MaterializationKind::Target) {
10991097
for (Value input : op->getOperands())
11001098
rewriterImpl.mapping.erase(input);
11011099
}
1102-
eraseOp(op);
1100+
op->erase();
11031101
}
11041102

11051103
void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
@@ -1116,8 +1114,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
11161114
// State Management
11171115

11181116
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
1119-
return RewriterState(rewrites.size(), ignoredOps.size(),
1120-
eraseRewriter.erased.size(), replacedOps.size());
1117+
return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
11211118
}
11221119

11231120
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1128,9 +1125,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
11281125
while (ignoredOps.size() != state.numIgnoredOperations)
11291126
ignoredOps.pop_back();
11301127

1131-
while (eraseRewriter.erased.size() != state.numErased)
1132-
eraseRewriter.erased.pop_back();
1133-
11341128
while (replacedOps.size() != state.numReplacedOps)
11351129
replacedOps.pop_back();
11361130
}

0 commit comments

Comments
 (0)