@@ -154,14 +154,12 @@ namespace {
154
154
struct RewriterState {
155
155
RewriterState (unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
156
156
unsigned numReplacements, unsigned numArgReplacements,
157
- unsigned numRewrites, unsigned numIgnoredOperations,
158
- unsigned numRootUpdates)
157
+ unsigned numRewrites, unsigned numIgnoredOperations)
159
158
: numCreatedOps(numCreatedOps),
160
159
numUnresolvedMaterializations (numUnresolvedMaterializations),
161
160
numReplacements(numReplacements),
162
161
numArgReplacements(numArgReplacements), numRewrites(numRewrites),
163
- numIgnoredOperations(numIgnoredOperations),
164
- numRootUpdates(numRootUpdates) {}
162
+ numIgnoredOperations(numIgnoredOperations) {}
165
163
166
164
// / The current number of created operations.
167
165
unsigned numCreatedOps;
@@ -180,44 +178,6 @@ struct RewriterState {
180
178
181
179
// / The current number of ignored operations.
182
180
unsigned numIgnoredOperations;
183
-
184
- // / The current number of operations that were updated in place.
185
- unsigned numRootUpdates;
186
- };
187
-
188
- // ===----------------------------------------------------------------------===//
189
- // OperationTransactionState
190
-
191
- // / The state of an operation that was updated by a pattern in-place. This
192
- // / contains all of the necessary information to reconstruct an operation that
193
- // / was updated in place.
194
- class OperationTransactionState {
195
- public:
196
- OperationTransactionState () = default ;
197
- OperationTransactionState (Operation *op)
198
- : op(op), loc(op->getLoc ()), attrs(op->getAttrDictionary ()),
199
- operands(op->operand_begin (), op->operand_end()),
200
- successors(op->successor_begin (), op->successor_end()) {}
201
-
202
- // / Discard the transaction state and reset the state of the original
203
- // / operation.
204
- void resetOperation () const {
205
- op->setLoc (loc);
206
- op->setAttrs (attrs);
207
- op->setOperands (operands);
208
- for (const auto &it : llvm::enumerate (successors))
209
- op->setSuccessor (it.value (), it.index ());
210
- }
211
-
212
- // / Return the original operation of this state.
213
- Operation *getOperation () const { return op; }
214
-
215
- private:
216
- Operation *op;
217
- LocationAttr loc;
218
- DictionaryAttr attrs;
219
- SmallVector<Value, 8 > operands;
220
- SmallVector<Block *, 2 > successors;
221
181
};
222
182
223
183
// ===----------------------------------------------------------------------===//
@@ -754,14 +714,19 @@ namespace {
754
714
class IRRewrite {
755
715
public:
756
716
// / The kind of the rewrite. Rewrites can be undone if the conversion fails.
717
+ // / Enum values are ordered, so that they can be used in `classof`: first all
718
+ // / block rewrites, then all operation rewrites.
757
719
enum class Kind {
720
+ // Block rewrites
758
721
CreateBlock,
759
722
EraseBlock,
760
723
InlineBlock,
761
724
MoveBlock,
762
725
SplitBlock,
763
726
BlockTypeConversion,
764
- MoveOperation
727
+ // Operation rewrites
728
+ MoveOperation,
729
+ ModifyOperation
765
730
};
766
731
767
732
virtual ~IRRewrite () = default ;
@@ -992,7 +957,7 @@ class OperationRewrite : public IRRewrite {
992
957
993
958
static bool classof (const IRRewrite *rewrite) {
994
959
return rewrite->getKind () >= Kind::MoveOperation &&
995
- rewrite->getKind () <= Kind::MoveOperation ;
960
+ rewrite->getKind () <= Kind::ModifyOperation ;
996
961
}
997
962
998
963
protected:
@@ -1031,8 +996,48 @@ class MoveOperationRewrite : public OperationRewrite {
1031
996
// this operation was the only operation in the region.
1032
997
Operation *insertBeforeOp;
1033
998
};
999
+
1000
+ // / In-place modification of an op. This rewrite is immediately reflected in
1001
+ // / the IR. The previous state of the operation is stored in this object.
1002
+ class ModifyOperationRewrite : public OperationRewrite {
1003
+ public:
1004
+ ModifyOperationRewrite (ConversionPatternRewriterImpl &rewriterImpl,
1005
+ Operation *op)
1006
+ : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
1007
+ loc (op->getLoc ()), attrs(op->getAttrDictionary ()),
1008
+ operands(op->operand_begin (), op->operand_end()),
1009
+ successors(op->successor_begin (), op->successor_end()) {}
1010
+
1011
+ static bool classof (const IRRewrite *rewrite) {
1012
+ return rewrite->getKind () == Kind::ModifyOperation;
1013
+ }
1014
+
1015
+ void rollback () override {
1016
+ op->setLoc (loc);
1017
+ op->setAttrs (attrs);
1018
+ op->setOperands (operands);
1019
+ for (const auto &it : llvm::enumerate (successors))
1020
+ op->setSuccessor (it.value (), it.index ());
1021
+ }
1022
+
1023
+ private:
1024
+ LocationAttr loc;
1025
+ DictionaryAttr attrs;
1026
+ SmallVector<Value, 8 > operands;
1027
+ SmallVector<Block *, 2 > successors;
1028
+ };
1034
1029
} // namespace
1035
1030
1031
+ // / Return "true" if there is an operation rewrite that matches the specified
1032
+ // / rewrite type and operation among the given rewrites.
1033
+ template <typename RewriteTy, typename R>
1034
+ static bool hasRewrite (R &&rewrites, Operation *op) {
1035
+ return any_of (std::move (rewrites), [&](auto &rewrite) {
1036
+ auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get ());
1037
+ return rewriteTy && rewriteTy->getOperation () == op;
1038
+ });
1039
+ }
1040
+
1036
1041
// ===----------------------------------------------------------------------===//
1037
1042
// ConversionPatternRewriterImpl
1038
1043
// ===----------------------------------------------------------------------===//
@@ -1184,9 +1189,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
1184
1189
// / operation was ignored.
1185
1190
SetVector<Operation *> ignoredOps;
1186
1191
1187
- // / A transaction state for each of operations that were updated in-place.
1188
- SmallVector<OperationTransactionState, 4 > rootUpdates;
1189
-
1190
1192
// / A vector of indices into `replacements` of operations that were replaced
1191
1193
// / with values with different result types than the original operation, e.g.
1192
1194
// / 1->N conversion of some kind.
@@ -1238,10 +1240,6 @@ static void detachNestedAndErase(Operation *op) {
1238
1240
}
1239
1241
1240
1242
void ConversionPatternRewriterImpl::discardRewrites () {
1241
- // Reset any operations that were updated in place.
1242
- for (auto &state : rootUpdates)
1243
- state.resetOperation ();
1244
-
1245
1243
undoRewrites ();
1246
1244
1247
1245
// Remove any newly created ops.
@@ -1316,15 +1314,10 @@ void ConversionPatternRewriterImpl::applyRewrites() {
1316
1314
RewriterState ConversionPatternRewriterImpl::getCurrentState () {
1317
1315
return RewriterState (createdOps.size (), unresolvedMaterializations.size (),
1318
1316
replacements.size (), argReplacements.size (),
1319
- rewrites.size (), ignoredOps.size (), rootUpdates. size () );
1317
+ rewrites.size (), ignoredOps.size ());
1320
1318
}
1321
1319
1322
1320
void ConversionPatternRewriterImpl::resetState (RewriterState state) {
1323
- // Reset any operations that were updated in place.
1324
- for (unsigned i = state.numRootUpdates , e = rootUpdates.size (); i != e; ++i)
1325
- rootUpdates[i].resetOperation ();
1326
- rootUpdates.resize (state.numRootUpdates );
1327
-
1328
1321
// Reset any replaced arguments.
1329
1322
for (BlockArgument replacedArg :
1330
1323
llvm::drop_begin (argReplacements, state.numArgReplacements ))
@@ -1750,7 +1743,7 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
1750
1743
#ifndef NDEBUG
1751
1744
impl->pendingRootUpdates .insert (op);
1752
1745
#endif
1753
- impl->rootUpdates . emplace_back (op);
1746
+ impl->appendRewrite <ModifyOperationRewrite> (op);
1754
1747
}
1755
1748
1756
1749
void ConversionPatternRewriter::finalizeOpModification (Operation *op) {
@@ -1769,13 +1762,15 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
1769
1762
" operation did not have a pending in-place update" );
1770
1763
#endif
1771
1764
// Erase the last update for this operation.
1772
- auto stateHasOp = [op](const auto &it) { return it.getOperation () == op; };
1773
- auto &rootUpdates = impl->rootUpdates ;
1774
- auto it = llvm::find_if (llvm::reverse (rootUpdates), stateHasOp);
1775
- assert (it != rootUpdates.rend () && " no root update started on op" );
1776
- (*it).resetOperation ();
1777
- int updateIdx = std::prev (rootUpdates.rend ()) - it;
1778
- rootUpdates.erase (rootUpdates.begin () + updateIdx);
1765
+ auto it = llvm::find_if (
1766
+ llvm::reverse (impl->rewrites ), [&](std::unique_ptr<IRRewrite> &rewrite) {
1767
+ auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get ());
1768
+ return modifyRewrite && modifyRewrite->getOperation () == op;
1769
+ });
1770
+ assert (it != impl->rewrites .rend () && " no root update started on op" );
1771
+ (*it)->rollback ();
1772
+ int updateIdx = std::prev (impl->rewrites .rend ()) - it;
1773
+ impl->rewrites .erase (impl->rewrites .begin () + updateIdx);
1779
1774
}
1780
1775
1781
1776
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl () {
@@ -2059,6 +2054,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
2059
2054
// Functor that cleans up the rewriter state after a pattern failed to match.
2060
2055
RewriterState curState = rewriterImpl.getCurrentState ();
2061
2056
auto onFailure = [&](const Pattern &pattern) {
2057
+ assert (rewriterImpl.pendingRootUpdates .empty () && " dangling root updates" );
2062
2058
LLVM_DEBUG ({
2063
2059
logFailure (rewriterImpl.logger , " pattern failed to match" );
2064
2060
if (rewriterImpl.notifyCallback ) {
@@ -2076,6 +2072,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
2076
2072
// Functor that performs additional legalization when a pattern is
2077
2073
// successfully applied.
2078
2074
auto onSuccess = [&](const Pattern &pattern) {
2075
+ assert (rewriterImpl.pendingRootUpdates .empty () && " dangling root updates" );
2079
2076
auto result = legalizePatternResult (op, pattern, rewriter, curState);
2080
2077
appliedPatterns.erase (&pattern);
2081
2078
if (failed (result))
@@ -2118,7 +2115,6 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2118
2115
2119
2116
#ifndef NDEBUG
2120
2117
assert (impl.pendingRootUpdates .empty () && " dangling root updates" );
2121
- #endif
2122
2118
2123
2119
// Check that the root was either replaced or updated in place.
2124
2120
auto replacedRoot = [&] {
@@ -2127,14 +2123,12 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2127
2123
[op](auto &it) { return it.first == op; });
2128
2124
};
2129
2125
auto updatedRootInPlace = [&] {
2130
- return llvm::any_of (
2131
- llvm::drop_begin (impl.rootUpdates , curState.numRootUpdates ),
2132
- [op](auto &state) { return state.getOperation () == op; });
2126
+ return hasRewrite<ModifyOperationRewrite>(
2127
+ llvm::drop_begin (impl.rewrites , curState.numRewrites ), op);
2133
2128
};
2134
- (void )replacedRoot;
2135
- (void )updatedRootInPlace;
2136
2129
assert ((replacedRoot () || updatedRootInPlace ()) &&
2137
2130
" expected pattern to replace the root operation" );
2131
+ #endif // NDEBUG
2138
2132
2139
2133
// Legalize each of the actions registered during application.
2140
2134
RewriterState newState = impl.getCurrentState ();
@@ -2221,8 +2215,11 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2221
2215
LogicalResult OperationLegalizer::legalizePatternRootUpdates (
2222
2216
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2223
2217
RewriterState &state, RewriterState &newState) {
2224
- for (int i = state.numRootUpdates , e = newState.numRootUpdates ; i != e; ++i) {
2225
- Operation *op = impl.rootUpdates [i].getOperation ();
2218
+ for (int i = state.numRewrites , e = newState.numRewrites ; i != e; ++i) {
2219
+ auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites [i].get ());
2220
+ if (!rewrite)
2221
+ continue ;
2222
+ Operation *op = rewrite->getOperation ();
2226
2223
if (failed (legalize (op, rewriter))) {
2227
2224
LLVM_DEBUG (logFailure (
2228
2225
impl.logger , " failed to legalize operation updated in-place '{0}'" ,
@@ -3562,7 +3559,8 @@ mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
3562
3559
// Full Conversion
3563
3560
3564
3561
LogicalResult
3565
- mlir::applyFullConversion (ArrayRef<Operation *> ops, const ConversionTarget &target,
3562
+ mlir::applyFullConversion (ArrayRef<Operation *> ops,
3563
+ const ConversionTarget &target,
3566
3564
const FrozenRewritePatternSet &patterns) {
3567
3565
OperationConverter opConverter (target, patterns, OpConversionMode::Full);
3568
3566
return opConverter.convertOperations (ops);
0 commit comments