Skip to content

Commit 5cabd6c

Browse files
[mlir][Transforms][NFC] Turn in-place op modifications into RewriteActions
This commit simplifies the internal state of the dialect conversion. A separate field for the previous state of in-place op modifications is no longer needed. BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC
1 parent 7318585 commit 5cabd6c

File tree

2 files changed

+74
-76
lines changed

2 files changed

+74
-76
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -744,8 +744,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
744744

745745
/// PatternRewriter hook for updating the given operation in-place.
746746
/// Note: These methods only track updates to the given operation itself,
747-
/// and not nested regions. Updates to regions will still require
748-
/// notification through other more specific hooks above.
747+
/// and not nested regions. Updates to regions will still require notification
748+
/// through other more specific hooks above.
749749
void startOpModification(Operation *op) override;
750750

751751
/// PatternRewriter hook for updating the given operation in-place.

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 72 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,12 @@ namespace {
154154
struct RewriterState {
155155
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
156156
unsigned numReplacements, unsigned numArgReplacements,
157-
unsigned numRewrites, unsigned numIgnoredOperations,
158-
unsigned numRootUpdates)
157+
unsigned numRewrites, unsigned numIgnoredOperations)
159158
: numCreatedOps(numCreatedOps),
160159
numUnresolvedMaterializations(numUnresolvedMaterializations),
161160
numReplacements(numReplacements),
162161
numArgReplacements(numArgReplacements), numRewrites(numRewrites),
163-
numIgnoredOperations(numIgnoredOperations),
164-
numRootUpdates(numRootUpdates) {}
162+
numIgnoredOperations(numIgnoredOperations) {}
165163

166164
/// The current number of created operations.
167165
unsigned numCreatedOps;
@@ -180,44 +178,6 @@ struct RewriterState {
180178

181179
/// The current number of ignored operations.
182180
unsigned numIgnoredOperations;
183-
184-
/// The current number of operations that were updated in place.
185-
unsigned numRootUpdates;
186-
};
187-
188-
//===----------------------------------------------------------------------===//
189-
// OperationTransactionState
190-
191-
/// The state of an operation that was updated by a pattern in-place. This
192-
/// contains all of the necessary information to reconstruct an operation that
193-
/// was updated in place.
194-
class OperationTransactionState {
195-
public:
196-
OperationTransactionState() = default;
197-
OperationTransactionState(Operation *op)
198-
: op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
199-
operands(op->operand_begin(), op->operand_end()),
200-
successors(op->successor_begin(), op->successor_end()) {}
201-
202-
/// Discard the transaction state and reset the state of the original
203-
/// operation.
204-
void resetOperation() const {
205-
op->setLoc(loc);
206-
op->setAttrs(attrs);
207-
op->setOperands(operands);
208-
for (const auto &it : llvm::enumerate(successors))
209-
op->setSuccessor(it.value(), it.index());
210-
}
211-
212-
/// Return the original operation of this state.
213-
Operation *getOperation() const { return op; }
214-
215-
private:
216-
Operation *op;
217-
LocationAttr loc;
218-
DictionaryAttr attrs;
219-
SmallVector<Value, 8> operands;
220-
SmallVector<Block *, 2> successors;
221181
};
222182

223183
//===----------------------------------------------------------------------===//
@@ -754,14 +714,19 @@ namespace {
754714
class IRRewrite {
755715
public:
756716
/// The kind of the rewrite. Rewrites can be undone if the conversion fails.
717+
/// Enum values are ordered, so that they can be used in `classof`: first all
718+
/// block rewrites, then all operation rewrites.
757719
enum class Kind {
720+
// Block rewrites
758721
CreateBlock,
759722
EraseBlock,
760723
InlineBlock,
761724
MoveBlock,
762725
SplitBlock,
763726
BlockTypeConversion,
764-
MoveOperation
727+
// Operation rewrites
728+
MoveOperation,
729+
ModifyOperation
765730
};
766731

767732
virtual ~IRRewrite() = default;
@@ -992,7 +957,7 @@ class OperationRewrite : public IRRewrite {
992957

993958
static bool classof(const IRRewrite *rewrite) {
994959
return rewrite->getKind() >= Kind::MoveOperation &&
995-
rewrite->getKind() <= Kind::MoveOperation;
960+
rewrite->getKind() <= Kind::ModifyOperation;
996961
}
997962

998963
protected:
@@ -1031,8 +996,48 @@ class MoveOperationRewrite : public OperationRewrite {
1031996
// this operation was the only operation in the region.
1032997
Operation *insertBeforeOp;
1033998
};
999+
1000+
/// In-place modification of an op. This rewrite is immediately reflected in
1001+
/// the IR. The previous state of the operation is stored in this object.
1002+
class ModifyOperationRewrite : public OperationRewrite {
1003+
public:
1004+
ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
1005+
Operation *op)
1006+
: OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
1007+
loc(op->getLoc()), attrs(op->getAttrDictionary()),
1008+
operands(op->operand_begin(), op->operand_end()),
1009+
successors(op->successor_begin(), op->successor_end()) {}
1010+
1011+
static bool classof(const IRRewrite *rewrite) {
1012+
return rewrite->getKind() == Kind::ModifyOperation;
1013+
}
1014+
1015+
void rollback() override {
1016+
op->setLoc(loc);
1017+
op->setAttrs(attrs);
1018+
op->setOperands(operands);
1019+
for (const auto &it : llvm::enumerate(successors))
1020+
op->setSuccessor(it.value(), it.index());
1021+
}
1022+
1023+
private:
1024+
LocationAttr loc;
1025+
DictionaryAttr attrs;
1026+
SmallVector<Value, 8> operands;
1027+
SmallVector<Block *, 2> successors;
1028+
};
10341029
} // namespace
10351030

1031+
/// Return "true" if there is an operation rewrite that matches the specified
1032+
/// rewrite type and operation among the given rewrites.
1033+
template <typename RewriteTy, typename R>
1034+
static bool hasRewrite(R &&rewrites, Operation *op) {
1035+
return any_of(std::move(rewrites), [&](auto &rewrite) {
1036+
auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
1037+
return rewriteTy && rewriteTy->getOperation() == op;
1038+
});
1039+
}
1040+
10361041
//===----------------------------------------------------------------------===//
10371042
// ConversionPatternRewriterImpl
10381043
//===----------------------------------------------------------------------===//
@@ -1184,9 +1189,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
11841189
/// operation was ignored.
11851190
SetVector<Operation *> ignoredOps;
11861191

1187-
/// A transaction state for each of operations that were updated in-place.
1188-
SmallVector<OperationTransactionState, 4> rootUpdates;
1189-
11901192
/// A vector of indices into `replacements` of operations that were replaced
11911193
/// with values with different result types than the original operation, e.g.
11921194
/// 1->N conversion of some kind.
@@ -1238,10 +1240,6 @@ static void detachNestedAndErase(Operation *op) {
12381240
}
12391241

12401242
void ConversionPatternRewriterImpl::discardRewrites() {
1241-
// Reset any operations that were updated in place.
1242-
for (auto &state : rootUpdates)
1243-
state.resetOperation();
1244-
12451243
undoRewrites();
12461244

12471245
// Remove any newly created ops.
@@ -1316,15 +1314,10 @@ void ConversionPatternRewriterImpl::applyRewrites() {
13161314
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
13171315
return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
13181316
replacements.size(), argReplacements.size(),
1319-
rewrites.size(), ignoredOps.size(), rootUpdates.size());
1317+
rewrites.size(), ignoredOps.size());
13201318
}
13211319

13221320
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
1323-
// Reset any operations that were updated in place.
1324-
for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
1325-
rootUpdates[i].resetOperation();
1326-
rootUpdates.resize(state.numRootUpdates);
1327-
13281321
// Reset any replaced arguments.
13291322
for (BlockArgument replacedArg :
13301323
llvm::drop_begin(argReplacements, state.numArgReplacements))
@@ -1750,7 +1743,7 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
17501743
#ifndef NDEBUG
17511744
impl->pendingRootUpdates.insert(op);
17521745
#endif
1753-
impl->rootUpdates.emplace_back(op);
1746+
impl->appendRewrite<ModifyOperationRewrite>(op);
17541747
}
17551748

17561749
void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
@@ -1769,13 +1762,15 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
17691762
"operation did not have a pending in-place update");
17701763
#endif
17711764
// Erase the last update for this operation.
1772-
auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
1773-
auto &rootUpdates = impl->rootUpdates;
1774-
auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
1775-
assert(it != rootUpdates.rend() && "no root update started on op");
1776-
(*it).resetOperation();
1777-
int updateIdx = std::prev(rootUpdates.rend()) - it;
1778-
rootUpdates.erase(rootUpdates.begin() + updateIdx);
1765+
auto it = llvm::find_if(
1766+
llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
1767+
auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
1768+
return modifyRewrite && modifyRewrite->getOperation() == op;
1769+
});
1770+
assert(it != impl->rewrites.rend() && "no root update started on op");
1771+
(*it)->rollback();
1772+
int updateIdx = std::prev(impl->rewrites.rend()) - it;
1773+
impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
17791774
}
17801775

17811776
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
@@ -2059,6 +2054,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
20592054
// Functor that cleans up the rewriter state after a pattern failed to match.
20602055
RewriterState curState = rewriterImpl.getCurrentState();
20612056
auto onFailure = [&](const Pattern &pattern) {
2057+
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
20622058
LLVM_DEBUG({
20632059
logFailure(rewriterImpl.logger, "pattern failed to match");
20642060
if (rewriterImpl.notifyCallback) {
@@ -2076,6 +2072,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
20762072
// Functor that performs additional legalization when a pattern is
20772073
// successfully applied.
20782074
auto onSuccess = [&](const Pattern &pattern) {
2075+
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
20792076
auto result = legalizePatternResult(op, pattern, rewriter, curState);
20802077
appliedPatterns.erase(&pattern);
20812078
if (failed(result))
@@ -2118,7 +2115,6 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
21182115

21192116
#ifndef NDEBUG
21202117
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
2121-
#endif
21222118

21232119
// Check that the root was either replaced or updated in place.
21242120
auto replacedRoot = [&] {
@@ -2127,14 +2123,12 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
21272123
[op](auto &it) { return it.first == op; });
21282124
};
21292125
auto updatedRootInPlace = [&] {
2130-
return llvm::any_of(
2131-
llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
2132-
[op](auto &state) { return state.getOperation() == op; });
2126+
return hasRewrite<ModifyOperationRewrite>(
2127+
llvm::drop_begin(impl.rewrites, curState.numRewrites), op);
21332128
};
2134-
(void)replacedRoot;
2135-
(void)updatedRootInPlace;
21362129
assert((replacedRoot() || updatedRootInPlace()) &&
21372130
"expected pattern to replace the root operation");
2131+
#endif // NDEBUG
21382132

21392133
// Legalize each of the actions registered during application.
21402134
RewriterState newState = impl.getCurrentState();
@@ -2221,8 +2215,11 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
22212215
LogicalResult OperationLegalizer::legalizePatternRootUpdates(
22222216
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
22232217
RewriterState &state, RewriterState &newState) {
2224-
for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
2225-
Operation *op = impl.rootUpdates[i].getOperation();
2218+
for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2219+
auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
2220+
if (!rewrite)
2221+
continue;
2222+
Operation *op = rewrite->getOperation();
22262223
if (failed(legalize(op, rewriter))) {
22272224
LLVM_DEBUG(logFailure(
22282225
impl.logger, "failed to legalize operation updated in-place '{0}'",
@@ -3562,7 +3559,8 @@ mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
35623559
// Full Conversion
35633560

35643561
LogicalResult
3565-
mlir::applyFullConversion(ArrayRef<Operation *> ops, const ConversionTarget &target,
3562+
mlir::applyFullConversion(ArrayRef<Operation *> ops,
3563+
const ConversionTarget &target,
35663564
const FrozenRewritePatternSet &patterns) {
35673565
OperationConverter opConverter(target, patterns, OpConversionMode::Full);
35683566
return opConverter.convertOperations(ops);

0 commit comments

Comments
 (0)