Skip to content

Commit 88b91b4

Browse files
[mlir][Transforms] Dialect conversion: Erase materialized constants instead of rollback
1 parent 784dc16 commit 88b91b4

File tree

3 files changed

+22
-16
lines changed

3 files changed

+22
-16
lines changed

mlir/include/mlir/IR/Builders.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -564,9 +564,13 @@ class OpBuilder : public Builder {
564564

565565
/// Attempts to fold the given operation and places new results within
566566
/// `results`. Returns success if the operation was folded, failure otherwise.
567-
/// If the fold was in-place, `results` will not be filled.
567+
/// If the fold was in-place, `results` will not be filled. Optionally, newly
568+
/// materialized constant operations can be returned to the caller.
569+
///
568570
/// Note: This function does not erase the operation on a successful fold.
569-
LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
571+
LogicalResult
572+
tryFold(Operation *op, SmallVectorImpl<Value> &results,
573+
SmallVector<Operation *> *materializedConstants = nullptr);
570574

571575
/// Creates a deep copy of the specified operation, remapping any operands
572576
/// that use values outside of the operation using the map that is provided

mlir/lib/IR/Builders.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,8 +465,9 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
465465
return create(state);
466466
}
467467

468-
LogicalResult OpBuilder::tryFold(Operation *op,
469-
SmallVectorImpl<Value> &results) {
468+
LogicalResult
469+
OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
470+
SmallVector<Operation *> *materializedConstants) {
470471
assert(results.empty() && "expected empty results");
471472
ResultRange opResults = op->getResults();
472473

@@ -528,6 +529,10 @@ LogicalResult OpBuilder::tryFold(Operation *op,
528529
for (Operation *cst : generatedConstants)
529530
insert(cst);
530531

532+
// Return materialized constant operations.
533+
if (materializedConstants)
534+
*materializedConstants = std::move(generatedConstants);
535+
531536
return success();
532537
}
533538

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2090,37 +2090,34 @@ LogicalResult
20902090
OperationLegalizer::legalizeWithFold(Operation *op,
20912091
ConversionPatternRewriter &rewriter) {
20922092
auto &rewriterImpl = rewriter.getImpl();
2093-
RewriterState curState = rewriterImpl.getCurrentState();
2094-
20952093
LLVM_DEBUG({
20962094
rewriterImpl.logger.startLine() << "* Fold {\n";
20972095
rewriterImpl.logger.indent();
20982096
});
20992097

21002098
// Try to fold the operation.
21012099
SmallVector<Value, 2> replacementValues;
2100+
SmallVector<Operation *> newOps;
21022101
rewriter.setInsertionPoint(op);
2103-
if (failed(rewriter.tryFold(op, replacementValues))) {
2102+
if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
21042103
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
21052104
return failure();
21062105
}
2106+
21072107
// An empty list of replacement values indicates that the fold was in-place.
21082108
// As the operation changed, a new legalization needs to be attempted.
21092109
if (replacementValues.empty())
21102110
return legalize(op, rewriter);
21112111

21122112
// Recursively legalize any new constant operations.
2113-
for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
2114-
i != e; ++i) {
2115-
auto *createOp =
2116-
dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
2117-
if (!createOp)
2118-
continue;
2119-
if (failed(legalize(createOp->getOperation(), rewriter))) {
2113+
for (Operation *newOp : newOps) {
2114+
if (failed(legalize(newOp, rewriter))) {
21202115
LLVM_DEBUG(logFailure(rewriterImpl.logger,
21212116
"failed to legalize generated constant '{0}'",
2122-
createOp->getOperation()->getName()));
2123-
rewriterImpl.resetState(curState);
2117+
newOp->getName()));
2118+
// Legalization failed: erase all materialized constants.
2119+
for (Operation *op : newOps)
2120+
rewriter.eraseOp(op);
21242121
return failure();
21252122
}
21262123
}

0 commit comments

Comments
 (0)