Skip to content

[mlir][Transforms][NFC] Turn op creation into IRRewrite #81759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 64 additions & 38 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,17 +152,12 @@ namespace {
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
unsigned numRewrites, unsigned numIgnoredOperations,
unsigned numErased)
: numCreatedOps(numCreatedOps),
numUnresolvedMaterializations(numUnresolvedMaterializations),
RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites,
unsigned numIgnoredOperations, unsigned numErased)
: numUnresolvedMaterializations(numUnresolvedMaterializations),
numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
numErased(numErased) {}

/// The current number of created operations.
unsigned numCreatedOps;

/// The current number of unresolved materializations.
unsigned numUnresolvedMaterializations;

Expand Down Expand Up @@ -303,7 +298,8 @@ class IRRewrite {
// Operation rewrites
MoveOperation,
ModifyOperation,
ReplaceOperation
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mmm, come to think of it: would it make sense to have "marker" types here so that you wouldn't need to change below if you add types & also where one adds entries here is self-documented due to markers?

ReplaceOperation,
CreateOperation
};

virtual ~IRRewrite() = default;
Expand Down Expand Up @@ -376,7 +372,10 @@ class CreateBlockRewrite : public BlockRewrite {
auto &blockOps = block->getOperations();
while (!blockOps.empty())
blockOps.remove(blockOps.begin());
eraseBlock(block);
if (block->getParent())
eraseBlock(block);
else
delete block;
}
};

Expand Down Expand Up @@ -606,7 +605,7 @@ class OperationRewrite : public IRRewrite {

static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::MoveOperation &&
rewrite->getKind() <= Kind::ReplaceOperation;
rewrite->getKind() <= Kind::CreateOperation;
}

protected:
Expand Down Expand Up @@ -740,6 +739,19 @@ class ReplaceOperationRewrite : public OperationRewrite {
/// A boolean flag that indicates whether result types have changed or not.
bool changedResults;
};

class CreateOperationRewrite : public OperationRewrite {
public:
CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Operation *op)
: OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}

static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::CreateOperation;
}

void rollback() override;
};
} // namespace

/// Return "true" if there is an operation rewrite that matches the specified
Expand Down Expand Up @@ -957,9 +969,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// replacing a value with one of a different type.
ConversionValueMapping mapping;

/// Ordered vector of all of the newly created operations during conversion.
SmallVector<Operation *> createdOps;

/// Ordered vector of all unresolved type conversion materializations during
/// conversion.
SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
Expand Down Expand Up @@ -1144,6 +1153,15 @@ void ReplaceOperationRewrite::rollback() {

void ReplaceOperationRewrite::cleanup() { eraseOp(op); }

void CreateOperationRewrite::rollback() {
for (Region &region : op->getRegions()) {
while (!region.getBlocks().empty())
region.getBlocks().remove(region.getBlocks().begin());
}
op->dropAllUses();
eraseOp(op);
}

void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
for (Region &region : op->getRegions()) {
for (Block &block : region.getBlocks()) {
Expand All @@ -1161,8 +1179,6 @@ void ConversionPatternRewriterImpl::discardRewrites() {
// Remove any newly created ops.
for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
detachNestedAndErase(materialization.getOp());
for (auto *op : llvm::reverse(createdOps))
detachNestedAndErase(op);
}

void ConversionPatternRewriterImpl::applyRewrites() {
Expand All @@ -1182,9 +1198,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
// State Management

RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
rewrites.size(), ignoredOps.size(),
eraseRewriter.erased.size());
return RewriterState(unresolvedMaterializations.size(), rewrites.size(),
ignoredOps.size(), eraseRewriter.erased.size());
}

void ConversionPatternRewriterImpl::resetState(RewriterState state) {
Expand All @@ -1205,12 +1220,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
detachNestedAndErase(op);
}

// Pop all of the newly created operations.
while (createdOps.size() != state.numCreatedOps) {
detachNestedAndErase(createdOps.back());
createdOps.pop_back();
}

// Pop all of the recorded ignored operations that are no longer valid.
while (ignoredOps.size() != state.numIgnoredOperations)
ignoredOps.pop_back();
Expand Down Expand Up @@ -1478,7 +1487,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
});
if (!previous.isSet()) {
// This is a newly created op.
createdOps.push_back(op);
appendRewrite<CreateOperationRewrite>(op);
return;
}
Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
Expand Down Expand Up @@ -1979,13 +1988,16 @@ OperationLegalizer::legalizeWithFold(Operation *op,
rewriter.replaceOp(op, replacementValues);

// Recursively legalize any new constant operations.
for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
i != e; ++i) {
Operation *cstOp = rewriterImpl.createdOps[i];
if (failed(legalize(cstOp, rewriter))) {
auto *createOp =
dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
if (!createOp)
continue;
if (failed(legalize(createOp->getOperation(), rewriter))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
cstOp->getName()));
createOp->getOperation()->getName()));
rewriterImpl.resetState(curState);
return failure();
}
Expand Down Expand Up @@ -2132,9 +2144,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// blocks in regions created by this pattern will already be legalized later
// on. If we haven't built the set yet, build it now.
if (operationsToIgnore.empty()) {
auto createdOps = ArrayRef<Operation *>(impl.createdOps)
.drop_front(state.numCreatedOps);
operationsToIgnore.insert(createdOps.begin(), createdOps.end());
for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
++i) {
auto *createOp =
dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
if (!createOp)
continue;
operationsToIgnore.insert(createOp->getOperation());
}
}

// If this operation should be considered for re-legalization, try it.
Expand All @@ -2152,8 +2169,11 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
RewriterState &state, RewriterState &newState) {
for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
Operation *op = impl.createdOps[i];
for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
if (!createOp)
continue;
Operation *op = createOp->getOperation();
if (failed(legalize(op, rewriter))) {
LLVM_DEBUG(logFailure(impl.logger,
"failed to legalize generated operation '{0}'({1})",
Expand Down Expand Up @@ -2583,10 +2603,16 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
});
return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
};
for (auto &r : rewriterImpl.rewrites)
if (auto *rewrite = dyn_cast<BlockTypeConversionRewrite>(r.get()))
if (failed(rewrite->materializeLiveConversions(findLiveUser)))
// Note: `rewrites` may be reallocated as the loop is running.
for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
++i) {
auto &rewrite = rewriterImpl.rewrites[i];
if (auto *blockTypeConversionRewrite =
dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
if (failed(blockTypeConversionRewrite->materializeLiveConversions(
findLiveUser)))
return failure();
}
return success();
}

Expand Down