-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Transforms][NFC] Turn in-place op modification into IRRewrite
#81245
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
matthias-springer
merged 1 commit into
main
from
users/matthias-springer/dialect_conversion_modify_op_inplace
Feb 21, 2024
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -154,14 +154,12 @@ namespace { | |
struct RewriterState { | ||
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations, | ||
unsigned numReplacements, unsigned numArgReplacements, | ||
unsigned numRewrites, unsigned numIgnoredOperations, | ||
unsigned numRootUpdates) | ||
unsigned numRewrites, unsigned numIgnoredOperations) | ||
: numCreatedOps(numCreatedOps), | ||
numUnresolvedMaterializations(numUnresolvedMaterializations), | ||
numReplacements(numReplacements), | ||
numArgReplacements(numArgReplacements), numRewrites(numRewrites), | ||
numIgnoredOperations(numIgnoredOperations), | ||
numRootUpdates(numRootUpdates) {} | ||
numIgnoredOperations(numIgnoredOperations) {} | ||
|
||
/// The current number of created operations. | ||
unsigned numCreatedOps; | ||
|
@@ -180,44 +178,6 @@ struct RewriterState { | |
|
||
/// The current number of ignored operations. | ||
unsigned numIgnoredOperations; | ||
|
||
/// The current number of operations that were updated in place. | ||
unsigned numRootUpdates; | ||
}; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// OperationTransactionState | ||
|
||
/// The state of an operation that was updated by a pattern in-place. This | ||
/// contains all of the necessary information to reconstruct an operation that | ||
/// was updated in place. | ||
class OperationTransactionState { | ||
public: | ||
OperationTransactionState() = default; | ||
OperationTransactionState(Operation *op) | ||
: op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()), | ||
operands(op->operand_begin(), op->operand_end()), | ||
successors(op->successor_begin(), op->successor_end()) {} | ||
|
||
/// Discard the transaction state and reset the state of the original | ||
/// operation. | ||
void resetOperation() const { | ||
op->setLoc(loc); | ||
op->setAttrs(attrs); | ||
op->setOperands(operands); | ||
for (const auto &it : llvm::enumerate(successors)) | ||
op->setSuccessor(it.value(), it.index()); | ||
} | ||
|
||
/// Return the original operation of this state. | ||
Operation *getOperation() const { return op; } | ||
|
||
private: | ||
Operation *op; | ||
LocationAttr loc; | ||
DictionaryAttr attrs; | ||
SmallVector<Value, 8> operands; | ||
SmallVector<Block *, 2> successors; | ||
}; | ||
|
||
//===----------------------------------------------------------------------===// | ||
|
@@ -754,14 +714,19 @@ namespace { | |
class IRRewrite { | ||
public: | ||
/// The kind of the rewrite. Rewrites can be undone if the conversion fails. | ||
/// Enum values are ordered, so that they can be used in `classof`: first all | ||
/// block rewrites, then all operation rewrites. | ||
enum class Kind { | ||
// Block rewrites | ||
CreateBlock, | ||
EraseBlock, | ||
InlineBlock, | ||
MoveBlock, | ||
SplitBlock, | ||
BlockTypeConversion, | ||
MoveOperation | ||
// Operation rewrites | ||
MoveOperation, | ||
ModifyOperation | ||
}; | ||
|
||
virtual ~IRRewrite() = default; | ||
|
@@ -992,7 +957,7 @@ class OperationRewrite : public IRRewrite { | |
|
||
static bool classof(const IRRewrite *rewrite) { | ||
return rewrite->getKind() >= Kind::MoveOperation && | ||
rewrite->getKind() <= Kind::MoveOperation; | ||
rewrite->getKind() <= Kind::ModifyOperation; | ||
} | ||
|
||
protected: | ||
|
@@ -1031,8 +996,48 @@ class MoveOperationRewrite : public OperationRewrite { | |
// this operation was the only operation in the region. | ||
Operation *insertBeforeOp; | ||
}; | ||
|
||
/// In-place modification of an op. This rewrite is immediately reflected in | ||
/// the IR. The previous state of the operation is stored in this object. | ||
class ModifyOperationRewrite : public OperationRewrite { | ||
public: | ||
ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, | ||
Operation *op) | ||
: OperationRewrite(Kind::ModifyOperation, rewriterImpl, op), | ||
loc(op->getLoc()), attrs(op->getAttrDictionary()), | ||
operands(op->operand_begin(), op->operand_end()), | ||
successors(op->successor_begin(), op->successor_end()) {} | ||
|
||
static bool classof(const IRRewrite *rewrite) { | ||
return rewrite->getKind() == Kind::ModifyOperation; | ||
} | ||
|
||
void rollback() override { | ||
op->setLoc(loc); | ||
op->setAttrs(attrs); | ||
op->setOperands(operands); | ||
for (const auto &it : llvm::enumerate(successors)) | ||
op->setSuccessor(it.value(), it.index()); | ||
} | ||
|
||
private: | ||
LocationAttr loc; | ||
DictionaryAttr attrs; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is properties needed too? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I made it a separate PR (#82474), so that this PR can stay NFC. |
||
SmallVector<Value, 8> operands; | ||
SmallVector<Block *, 2> successors; | ||
}; | ||
} // namespace | ||
|
||
/// Return "true" if there is an operation rewrite that matches the specified | ||
/// rewrite type and operation among the given rewrites. | ||
template <typename RewriteTy, typename R> | ||
static bool hasRewrite(R &&rewrites, Operation *op) { | ||
return any_of(std::move(rewrites), [&](auto &rewrite) { | ||
auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get()); | ||
return rewriteTy && rewriteTy->getOperation() == op; | ||
}); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ConversionPatternRewriterImpl | ||
//===----------------------------------------------------------------------===// | ||
|
@@ -1184,9 +1189,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { | |
/// operation was ignored. | ||
SetVector<Operation *> ignoredOps; | ||
|
||
/// A transaction state for each of operations that were updated in-place. | ||
SmallVector<OperationTransactionState, 4> rootUpdates; | ||
|
||
/// A vector of indices into `replacements` of operations that were replaced | ||
/// with values with different result types than the original operation, e.g. | ||
/// 1->N conversion of some kind. | ||
|
@@ -1238,10 +1240,6 @@ static void detachNestedAndErase(Operation *op) { | |
} | ||
|
||
void ConversionPatternRewriterImpl::discardRewrites() { | ||
// Reset any operations that were updated in place. | ||
for (auto &state : rootUpdates) | ||
state.resetOperation(); | ||
|
||
undoRewrites(); | ||
|
||
// Remove any newly created ops. | ||
|
@@ -1316,15 +1314,10 @@ void ConversionPatternRewriterImpl::applyRewrites() { | |
RewriterState ConversionPatternRewriterImpl::getCurrentState() { | ||
return RewriterState(createdOps.size(), unresolvedMaterializations.size(), | ||
replacements.size(), argReplacements.size(), | ||
rewrites.size(), ignoredOps.size(), rootUpdates.size()); | ||
rewrites.size(), ignoredOps.size()); | ||
} | ||
|
||
void ConversionPatternRewriterImpl::resetState(RewriterState state) { | ||
// Reset any operations that were updated in place. | ||
for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i) | ||
rootUpdates[i].resetOperation(); | ||
rootUpdates.resize(state.numRootUpdates); | ||
|
||
// Reset any replaced arguments. | ||
for (BlockArgument replacedArg : | ||
llvm::drop_begin(argReplacements, state.numArgReplacements)) | ||
|
@@ -1750,7 +1743,7 @@ void ConversionPatternRewriter::startOpModification(Operation *op) { | |
#ifndef NDEBUG | ||
impl->pendingRootUpdates.insert(op); | ||
#endif | ||
impl->rootUpdates.emplace_back(op); | ||
impl->appendRewrite<ModifyOperationRewrite>(op); | ||
} | ||
|
||
void ConversionPatternRewriter::finalizeOpModification(Operation *op) { | ||
|
@@ -1769,13 +1762,15 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) { | |
"operation did not have a pending in-place update"); | ||
#endif | ||
// Erase the last update for this operation. | ||
auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; }; | ||
auto &rootUpdates = impl->rootUpdates; | ||
auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp); | ||
assert(it != rootUpdates.rend() && "no root update started on op"); | ||
(*it).resetOperation(); | ||
int updateIdx = std::prev(rootUpdates.rend()) - it; | ||
rootUpdates.erase(rootUpdates.begin() + updateIdx); | ||
auto it = llvm::find_if( | ||
llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) { | ||
auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get()); | ||
return modifyRewrite && modifyRewrite->getOperation() == op; | ||
}); | ||
assert(it != impl->rewrites.rend() && "no root update started on op"); | ||
(*it)->rollback(); | ||
int updateIdx = std::prev(impl->rewrites.rend()) - it; | ||
impl->rewrites.erase(impl->rewrites.begin() + updateIdx); | ||
} | ||
|
||
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { | ||
|
@@ -2059,6 +2054,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op, | |
// Functor that cleans up the rewriter state after a pattern failed to match. | ||
RewriterState curState = rewriterImpl.getCurrentState(); | ||
auto onFailure = [&](const Pattern &pattern) { | ||
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); | ||
LLVM_DEBUG({ | ||
logFailure(rewriterImpl.logger, "pattern failed to match"); | ||
if (rewriterImpl.notifyCallback) { | ||
|
@@ -2076,6 +2072,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op, | |
// Functor that performs additional legalization when a pattern is | ||
// successfully applied. | ||
auto onSuccess = [&](const Pattern &pattern) { | ||
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); | ||
auto result = legalizePatternResult(op, pattern, rewriter, curState); | ||
appliedPatterns.erase(&pattern); | ||
if (failed(result)) | ||
|
@@ -2118,7 +2115,6 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, | |
|
||
#ifndef NDEBUG | ||
assert(impl.pendingRootUpdates.empty() && "dangling root updates"); | ||
#endif | ||
|
||
// Check that the root was either replaced or updated in place. | ||
auto replacedRoot = [&] { | ||
|
@@ -2127,14 +2123,12 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, | |
[op](auto &it) { return it.first == op; }); | ||
}; | ||
auto updatedRootInPlace = [&] { | ||
return llvm::any_of( | ||
llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates), | ||
[op](auto &state) { return state.getOperation() == op; }); | ||
return hasRewrite<ModifyOperationRewrite>( | ||
llvm::drop_begin(impl.rewrites, curState.numRewrites), op); | ||
}; | ||
(void)replacedRoot; | ||
(void)updatedRootInPlace; | ||
assert((replacedRoot() || updatedRootInPlace()) && | ||
"expected pattern to replace the root operation"); | ||
#endif // NDEBUG | ||
|
||
// Legalize each of the actions registered during application. | ||
RewriterState newState = impl.getCurrentState(); | ||
|
@@ -2221,8 +2215,11 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations( | |
LogicalResult OperationLegalizer::legalizePatternRootUpdates( | ||
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, | ||
RewriterState &state, RewriterState &newState) { | ||
for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) { | ||
Operation *op = impl.rootUpdates[i].getOperation(); | ||
for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { | ||
auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get()); | ||
if (!rewrite) | ||
continue; | ||
Operation *op = rewrite->getOperation(); | ||
if (failed(legalize(op, rewriter))) { | ||
LLVM_DEBUG(logFailure( | ||
impl.logger, "failed to legalize operation updated in-place '{0}'", | ||
|
@@ -3562,7 +3559,8 @@ mlir::applyPartialConversion(Operation *op, const ConversionTarget &target, | |
// Full Conversion | ||
|
||
LogicalResult | ||
mlir::applyFullConversion(ArrayRef<Operation *> ops, const ConversionTarget &target, | ||
mlir::applyFullConversion(ArrayRef<Operation *> ops, | ||
const ConversionTarget &target, | ||
const FrozenRewritePatternSet &patterns) { | ||
OperationConverter opConverter(target, patterns, OpConversionMode::Full); | ||
return opConverter.convertOperations(ops); | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.