Skip to content

Commit 6b6c4b1

Browse files
[mlir][Transforms][NFC] Do not use SingleEraseRewriter during rollback
1 parent ab78e8c commit 6b6c4b1

File tree

1 file changed

+8
-14
lines changed

1 file changed

+8
-14
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -153,19 +153,16 @@ namespace {
153153
/// This is useful when saving and undoing a set of rewrites.
154154
struct RewriterState {
155155
RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
156-
unsigned numErased, unsigned numReplacedOps)
156+
unsigned numReplacedOps)
157157
: numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
158-
numErased(numErased), numReplacedOps(numReplacedOps) {}
158+
numReplacedOps(numReplacedOps) {}
159159

160160
/// The current number of rewrites performed.
161161
unsigned numRewrites;
162162

163163
/// The current number of ignored operations.
164164
unsigned numIgnoredOperations;
165165

166-
/// The current number of erased operations/blocks.
167-
unsigned numErased;
168-
169166
/// The current number of replaced ops that are scheduled for erasure.
170167
unsigned numReplacedOps;
171168
};
@@ -273,8 +270,9 @@ class CreateBlockRewrite : public BlockRewrite {
273270
auto &blockOps = block->getOperations();
274271
while (!blockOps.empty())
275272
blockOps.remove(blockOps.begin());
273+
block->dropAllUses();
276274
if (block->getParent())
277-
eraseBlock(block);
275+
block->erase();
278276
else
279277
delete block;
280278
}
@@ -858,7 +856,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
858856
void notifyBlockErased(Block *block) override { erased.insert(block); }
859857

860858
/// Pointers to all erased operations and blocks.
861-
SetVector<void *> erased;
859+
DenseSet<void *> erased;
862860
};
863861

864862
//===--------------------------------------------------------------------===//
@@ -1044,15 +1042,15 @@ void CreateOperationRewrite::rollback() {
10441042
region.getBlocks().remove(region.getBlocks().begin());
10451043
}
10461044
op->dropAllUses();
1047-
eraseOp(op);
1045+
op->erase();
10481046
}
10491047

10501048
void UnresolvedMaterializationRewrite::rollback() {
10511049
if (getMaterializationKind() == MaterializationKind::Target) {
10521050
for (Value input : op->getOperands())
10531051
rewriterImpl.mapping.erase(input);
10541052
}
1055-
eraseOp(op);
1053+
op->erase();
10561054
}
10571055

10581056
void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
@@ -1069,8 +1067,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
10691067
// State Management
10701068

10711069
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
1072-
return RewriterState(rewrites.size(), ignoredOps.size(),
1073-
eraseRewriter.erased.size(), replacedOps.size());
1070+
return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
10741071
}
10751072

10761073
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1081,9 +1078,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
10811078
while (ignoredOps.size() != state.numIgnoredOperations)
10821079
ignoredOps.pop_back();
10831080

1084-
while (eraseRewriter.erased.size() != state.numErased)
1085-
eraseRewriter.erased.pop_back();
1086-
10871081
while (replacedOps.size() != state.numReplacedOps)
10881082
replacedOps.pop_back();
10891083
}

0 commit comments

Comments
 (0)