Skip to content

[mlir][Transforms] Dialect conversion: Erase materialized constants instead of rollback #136489

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

When illegal (and not legalizable) constant operations are materialized during a dialect conversion as part of op folding, these operations must be deleted again. This used to be implemented via the rollback mechanism. This commit switches the implementation to regular rewriter API usage: simply delete the materialized constants with eraseOp.

This commit is in preparation of the One-Shot Dialect Conversion refactoring, which will disallow IR rollbacks.

This commit also adds a new optional parameter to OpBuilder::tryFold to get hold of the materialized constant ops.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Apr 20, 2025
@llvmbot
Copy link
Member

llvmbot commented Apr 20, 2025

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

When illegal (and not legalizable) constant operations are materialized during a dialect conversion as part of op folding, these operations must be deleted again. This used to be implemented via the rollback mechanism. This commit switches the implementation to regular rewriter API usage: simply delete the materialized constants with eraseOp.

This commit is in preparation of the One-Shot Dialect Conversion refactoring, which will disallow IR rollbacks.

This commit also adds a new optional parameter to OpBuilder::tryFold to get hold of the materialized constant ops.


Full diff: https://github.com/llvm/llvm-project/pull/136489.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/Builders.h (+6-2)
  • (modified) mlir/lib/IR/Builders.cpp (+7-2)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+9-12)
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index cd8d3ee0af72b..8f13705fac96d 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -564,9 +564,13 @@ class OpBuilder : public Builder {
 
   /// Attempts to fold the given operation and places new results within
   /// `results`. Returns success if the operation was folded, failure otherwise.
-  /// If the fold was in-place, `results` will not be filled.
+  /// If the fold was in-place, `results` will not be filled. Optionally, newly
+  /// materialized constant operations can be returned to the caller.
+  ///
   /// Note: This function does not erase the operation on a successful fold.
-  LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
+  LogicalResult
+  tryFold(Operation *op, SmallVectorImpl<Value> &results,
+          SmallVector<Operation *> *materializedConstants = nullptr);
 
   /// Creates a deep copy of the specified operation, remapping any operands
   /// that use values outside of the operation using the map that is provided
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 16bd8201ad50a..9450ef7738fa0 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -465,8 +465,9 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
   return create(state);
 }
 
-LogicalResult OpBuilder::tryFold(Operation *op,
-                                 SmallVectorImpl<Value> &results) {
+LogicalResult
+OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
+                   SmallVector<Operation *> *materializedConstants) {
   assert(results.empty() && "expected empty results");
   ResultRange opResults = op->getResults();
 
@@ -528,6 +529,10 @@ LogicalResult OpBuilder::tryFold(Operation *op,
   for (Operation *cst : generatedConstants)
     insert(cst);
 
+  // Return materialized constant operations.
+  if (materializedConstants)
+    *materializedConstants = std::move(generatedConstants);
+
   return success();
 }
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a56fca25e1697..63225c6bbee7c 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2090,8 +2090,6 @@ LogicalResult
 OperationLegalizer::legalizeWithFold(Operation *op,
                                      ConversionPatternRewriter &rewriter) {
   auto &rewriterImpl = rewriter.getImpl();
-  RewriterState curState = rewriterImpl.getCurrentState();
-
   LLVM_DEBUG({
     rewriterImpl.logger.startLine() << "* Fold {\n";
     rewriterImpl.logger.indent();
@@ -2099,28 +2097,27 @@ OperationLegalizer::legalizeWithFold(Operation *op,
 
   // Try to fold the operation.
   SmallVector<Value, 2> replacementValues;
+  SmallVector<Operation *> newOps;
   rewriter.setInsertionPoint(op);
-  if (failed(rewriter.tryFold(op, replacementValues))) {
+  if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
     LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
     return failure();
   }
+
   // An empty list of replacement values indicates that the fold was in-place.
   // As the operation changed, a new legalization needs to be attempted.
   if (replacementValues.empty())
     return legalize(op, rewriter);
 
   // Recursively legalize any new constant operations.
-  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
-       i != e; ++i) {
-    auto *createOp =
-        dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
-    if (!createOp)
-      continue;
-    if (failed(legalize(createOp->getOperation(), rewriter))) {
+  for (Operation *newOp : newOps) {
+    if (failed(legalize(newOp, rewriter))) {
       LLVM_DEBUG(logFailure(rewriterImpl.logger,
                             "failed to legalize generated constant '{0}'",
-                            createOp->getOperation()->getName()));
-      rewriterImpl.resetState(curState);
+                            newOp->getName()));
+      // Legalization failed: erase all materialized constants.
+      for (Operation *op : newOps)
+        rewriter.eraseOp(op);
       return failure();
     }
   }

@llvmbot
Copy link
Member

llvmbot commented Apr 20, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

When illegal (and not legalizable) constant operations are materialized during a dialect conversion as part of op folding, these operations must be deleted again. This used to be implemented via the rollback mechanism. This commit switches the implementation to regular rewriter API usage: simply delete the materialized constants with eraseOp.

This commit is in preparation of the One-Shot Dialect Conversion refactoring, which will disallow IR rollbacks.

This commit also adds a new optional parameter to OpBuilder::tryFold to get hold of the materialized constant ops.


Full diff: https://github.com/llvm/llvm-project/pull/136489.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/Builders.h (+6-2)
  • (modified) mlir/lib/IR/Builders.cpp (+7-2)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+9-12)
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index cd8d3ee0af72b..8f13705fac96d 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -564,9 +564,13 @@ class OpBuilder : public Builder {
 
   /// Attempts to fold the given operation and places new results within
   /// `results`. Returns success if the operation was folded, failure otherwise.
-  /// If the fold was in-place, `results` will not be filled.
+  /// If the fold was in-place, `results` will not be filled. Optionally, newly
+  /// materialized constant operations can be returned to the caller.
+  ///
   /// Note: This function does not erase the operation on a successful fold.
-  LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
+  LogicalResult
+  tryFold(Operation *op, SmallVectorImpl<Value> &results,
+          SmallVector<Operation *> *materializedConstants = nullptr);
 
   /// Creates a deep copy of the specified operation, remapping any operands
   /// that use values outside of the operation using the map that is provided
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 16bd8201ad50a..9450ef7738fa0 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -465,8 +465,9 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
   return create(state);
 }
 
-LogicalResult OpBuilder::tryFold(Operation *op,
-                                 SmallVectorImpl<Value> &results) {
+LogicalResult
+OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
+                   SmallVector<Operation *> *materializedConstants) {
   assert(results.empty() && "expected empty results");
   ResultRange opResults = op->getResults();
 
@@ -528,6 +529,10 @@ LogicalResult OpBuilder::tryFold(Operation *op,
   for (Operation *cst : generatedConstants)
     insert(cst);
 
+  // Return materialized constant operations.
+  if (materializedConstants)
+    *materializedConstants = std::move(generatedConstants);
+
   return success();
 }
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a56fca25e1697..63225c6bbee7c 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2090,8 +2090,6 @@ LogicalResult
 OperationLegalizer::legalizeWithFold(Operation *op,
                                      ConversionPatternRewriter &rewriter) {
   auto &rewriterImpl = rewriter.getImpl();
-  RewriterState curState = rewriterImpl.getCurrentState();
-
   LLVM_DEBUG({
     rewriterImpl.logger.startLine() << "* Fold {\n";
     rewriterImpl.logger.indent();
@@ -2099,28 +2097,27 @@ OperationLegalizer::legalizeWithFold(Operation *op,
 
   // Try to fold the operation.
   SmallVector<Value, 2> replacementValues;
+  SmallVector<Operation *> newOps;
   rewriter.setInsertionPoint(op);
-  if (failed(rewriter.tryFold(op, replacementValues))) {
+  if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
     LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
     return failure();
   }
+
   // An empty list of replacement values indicates that the fold was in-place.
   // As the operation changed, a new legalization needs to be attempted.
   if (replacementValues.empty())
     return legalize(op, rewriter);
 
   // Recursively legalize any new constant operations.
-  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
-       i != e; ++i) {
-    auto *createOp =
-        dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
-    if (!createOp)
-      continue;
-    if (failed(legalize(createOp->getOperation(), rewriter))) {
+  for (Operation *newOp : newOps) {
+    if (failed(legalize(newOp, rewriter))) {
       LLVM_DEBUG(logFailure(rewriterImpl.logger,
                             "failed to legalize generated constant '{0}'",
-                            createOp->getOperation()->getName()));
-      rewriterImpl.resetState(curState);
+                            newOp->getName()));
+      // Legalization failed: erase all materialized constants.
+      for (Operation *op : newOps)
+        rewriter.eraseOp(op);
       return failure();
     }
   }

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@matthias-springer matthias-springer force-pushed the users/matthias-springer/dialect_conv_rollback_fold branch from 63e2888 to 8bf5ca0 Compare April 22, 2025 06:55
@matthias-springer matthias-springer merged commit 2d3bbb6 into main Apr 22, 2025
7 of 10 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/dialect_conv_rollback_fold branch April 22, 2025 07:12
makslevental added a commit to makslevental/triton that referenced this pull request Apr 23, 2025
let's start working to integrate 

llvm/llvm-project#136489

and 

llvm/llvm-project#136490

(maybe there will be no work 🤞
antiagainst pushed a commit to triton-lang/triton that referenced this pull request Apr 23, 2025
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…nstead of rollback (llvm#136489)

When illegal (and not legalizable) constant operations are materialized
during a dialect conversion as part of op folding, these operations must
be deleted again. This used to be implemented via the rollback
mechanism. This commit switches the implementation to regular rewriter
API usage: simply delete the materialized constants with `eraseOp`.

This commit is in preparation of the One-Shot Dialect Conversion
refactoring, which will disallow IR rollbacks.

This commit also adds a new optional parameter to `OpBuilder::tryFold`
to get hold of the materialized constant ops.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…nstead of rollback (llvm#136489)

When illegal (and not legalizable) constant operations are materialized
during a dialect conversion as part of op folding, these operations must
be deleted again. This used to be implemented via the rollback
mechanism. This commit switches the implementation to regular rewriter
API usage: simply delete the materialized constants with `eraseOp`.

This commit is in preparation of the One-Shot Dialect Conversion
refactoring, which will disallow IR rollbacks.

This commit also adds a new optional parameter to `OpBuilder::tryFold`
to get hold of the materialized constant ops.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…nstead of rollback (llvm#136489)

When illegal (and not legalizable) constant operations are materialized
during a dialect conversion as part of op folding, these operations must
be deleted again. This used to be implemented via the rollback
mechanism. This commit switches the implementation to regular rewriter
API usage: simply delete the materialized constants with `eraseOp`.

This commit is in preparation of the One-Shot Dialect Conversion
refactoring, which will disallow IR rollbacks.

This commit also adds a new optional parameter to `OpBuilder::tryFold`
to get hold of the materialized constant ops.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants