-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Transforms] Dialect conversion: Erase materialized constants instead of rollback #136489
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms] Dialect conversion: Erase materialized constants instead of rollback #136489
Conversation
@llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesWhen illegal (and not legalizable) constant operations are materialized during a dialect conversion as part of op folding, these operations must be deleted again. This used to be implemented via the rollback mechanism. This commit switches the implementation to regular rewriter API usage: simply delete the materialized constants with This commit is in preparation of the One-Shot Dialect Conversion refactoring, which will disallow IR rollbacks. This commit also adds a new optional parameter to Full diff: https://github.com/llvm/llvm-project/pull/136489.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index cd8d3ee0af72b..8f13705fac96d 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -564,9 +564,13 @@ class OpBuilder : public Builder {
/// Attempts to fold the given operation and places new results within
/// `results`. Returns success if the operation was folded, failure otherwise.
- /// If the fold was in-place, `results` will not be filled.
+ /// If the fold was in-place, `results` will not be filled. Optionally, newly
+ /// materialized constant operations can be returned to the caller.
+ ///
/// Note: This function does not erase the operation on a successful fold.
- LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
+ LogicalResult
+ tryFold(Operation *op, SmallVectorImpl<Value> &results,
+ SmallVector<Operation *> *materializedConstants = nullptr);
/// Creates a deep copy of the specified operation, remapping any operands
/// that use values outside of the operation using the map that is provided
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 16bd8201ad50a..9450ef7738fa0 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -465,8 +465,9 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
return create(state);
}
-LogicalResult OpBuilder::tryFold(Operation *op,
- SmallVectorImpl<Value> &results) {
+LogicalResult
+OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
+ SmallVector<Operation *> *materializedConstants) {
assert(results.empty() && "expected empty results");
ResultRange opResults = op->getResults();
@@ -528,6 +529,10 @@ LogicalResult OpBuilder::tryFold(Operation *op,
for (Operation *cst : generatedConstants)
insert(cst);
+ // Return materialized constant operations.
+ if (materializedConstants)
+ *materializedConstants = std::move(generatedConstants);
+
return success();
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a56fca25e1697..63225c6bbee7c 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2090,8 +2090,6 @@ LogicalResult
OperationLegalizer::legalizeWithFold(Operation *op,
ConversionPatternRewriter &rewriter) {
auto &rewriterImpl = rewriter.getImpl();
- RewriterState curState = rewriterImpl.getCurrentState();
-
LLVM_DEBUG({
rewriterImpl.logger.startLine() << "* Fold {\n";
rewriterImpl.logger.indent();
@@ -2099,28 +2097,27 @@ OperationLegalizer::legalizeWithFold(Operation *op,
// Try to fold the operation.
SmallVector<Value, 2> replacementValues;
+ SmallVector<Operation *> newOps;
rewriter.setInsertionPoint(op);
- if (failed(rewriter.tryFold(op, replacementValues))) {
+ if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
return failure();
}
+
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
return legalize(op, rewriter);
// Recursively legalize any new constant operations.
- for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
- i != e; ++i) {
- auto *createOp =
- dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
- if (!createOp)
- continue;
- if (failed(legalize(createOp->getOperation(), rewriter))) {
+ for (Operation *newOp : newOps) {
+ if (failed(legalize(newOp, rewriter))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
- createOp->getOperation()->getName()));
- rewriterImpl.resetState(curState);
+ newOp->getName()));
+ // Legalization failed: erase all materialized constants.
+ for (Operation *op : newOps)
+ rewriter.eraseOp(op);
return failure();
}
}
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesWhen illegal (and not legalizable) constant operations are materialized during a dialect conversion as part of op folding, these operations must be deleted again. This used to be implemented via the rollback mechanism. This commit switches the implementation to regular rewriter API usage: simply delete the materialized constants with This commit is in preparation of the One-Shot Dialect Conversion refactoring, which will disallow IR rollbacks. This commit also adds a new optional parameter to Full diff: https://github.com/llvm/llvm-project/pull/136489.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index cd8d3ee0af72b..8f13705fac96d 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -564,9 +564,13 @@ class OpBuilder : public Builder {
/// Attempts to fold the given operation and places new results within
/// `results`. Returns success if the operation was folded, failure otherwise.
- /// If the fold was in-place, `results` will not be filled.
+ /// If the fold was in-place, `results` will not be filled. Optionally, newly
+ /// materialized constant operations can be returned to the caller.
+ ///
/// Note: This function does not erase the operation on a successful fold.
- LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
+ LogicalResult
+ tryFold(Operation *op, SmallVectorImpl<Value> &results,
+ SmallVector<Operation *> *materializedConstants = nullptr);
/// Creates a deep copy of the specified operation, remapping any operands
/// that use values outside of the operation using the map that is provided
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 16bd8201ad50a..9450ef7738fa0 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -465,8 +465,9 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
return create(state);
}
-LogicalResult OpBuilder::tryFold(Operation *op,
- SmallVectorImpl<Value> &results) {
+LogicalResult
+OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
+ SmallVector<Operation *> *materializedConstants) {
assert(results.empty() && "expected empty results");
ResultRange opResults = op->getResults();
@@ -528,6 +529,10 @@ LogicalResult OpBuilder::tryFold(Operation *op,
for (Operation *cst : generatedConstants)
insert(cst);
+ // Return materialized constant operations.
+ if (materializedConstants)
+ *materializedConstants = std::move(generatedConstants);
+
return success();
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a56fca25e1697..63225c6bbee7c 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2090,8 +2090,6 @@ LogicalResult
OperationLegalizer::legalizeWithFold(Operation *op,
ConversionPatternRewriter &rewriter) {
auto &rewriterImpl = rewriter.getImpl();
- RewriterState curState = rewriterImpl.getCurrentState();
-
LLVM_DEBUG({
rewriterImpl.logger.startLine() << "* Fold {\n";
rewriterImpl.logger.indent();
@@ -2099,28 +2097,27 @@ OperationLegalizer::legalizeWithFold(Operation *op,
// Try to fold the operation.
SmallVector<Value, 2> replacementValues;
+ SmallVector<Operation *> newOps;
rewriter.setInsertionPoint(op);
- if (failed(rewriter.tryFold(op, replacementValues))) {
+ if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
return failure();
}
+
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
return legalize(op, rewriter);
// Recursively legalize any new constant operations.
- for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
- i != e; ++i) {
- auto *createOp =
- dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
- if (!createOp)
- continue;
- if (failed(legalize(createOp->getOperation(), rewriter))) {
+ for (Operation *newOp : newOps) {
+ if (failed(legalize(newOp, rewriter))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
- createOp->getOperation()->getName()));
- rewriterImpl.resetState(curState);
+ newOp->getName()));
+ // Legalization failed: erase all materialized constants.
+ for (Operation *op : newOps)
+ rewriter.eraseOp(op);
return failure();
}
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
63e2888
to
8bf5ca0
Compare
let's start working to integrate llvm/llvm-project#136489 and llvm/llvm-project#136490 (maybe there will be no work 🤞
let's start working to integrate llvm/llvm-project#136489 and llvm/llvm-project#136490 (maybe there will be no work 🤞
…nstead of rollback (llvm#136489) When illegal (and not legalizable) constant operations are materialized during a dialect conversion as part of op folding, these operations must be deleted again. This used to be implemented via the rollback mechanism. This commit switches the implementation to regular rewriter API usage: simply delete the materialized constants with `eraseOp`. This commit is in preparation of the One-Shot Dialect Conversion refactoring, which will disallow IR rollbacks. This commit also adds a new optional parameter to `OpBuilder::tryFold` to get hold of the materialized constant ops.
…nstead of rollback (llvm#136489) When illegal (and not legalizable) constant operations are materialized during a dialect conversion as part of op folding, these operations must be deleted again. This used to be implemented via the rollback mechanism. This commit switches the implementation to regular rewriter API usage: simply delete the materialized constants with `eraseOp`. This commit is in preparation of the One-Shot Dialect Conversion refactoring, which will disallow IR rollbacks. This commit also adds a new optional parameter to `OpBuilder::tryFold` to get hold of the materialized constant ops.
…nstead of rollback (llvm#136489) When illegal (and not legalizable) constant operations are materialized during a dialect conversion as part of op folding, these operations must be deleted again. This used to be implemented via the rollback mechanism. This commit switches the implementation to regular rewriter API usage: simply delete the materialized constants with `eraseOp`. This commit is in preparation of the One-Shot Dialect Conversion refactoring, which will disallow IR rollbacks. This commit also adds a new optional parameter to `OpBuilder::tryFold` to get hold of the materialized constant ops.
When illegal (and not legalizable) constant operations are materialized during a dialect conversion as part of op folding, these operations must be deleted again. This used to be implemented via the rollback mechanism. This commit switches the implementation to regular rewriter API usage: simply delete the materialized constants with
eraseOp
.This commit is in preparation of the One-Shot Dialect Conversion refactoring, which will disallow IR rollbacks.
This commit also adds a new optional parameter to
OpBuilder::tryFold
to get hold of the materialized constant ops.