Skip to content

[MLIR] Harmonize the behavior of the folding API functions #88508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 23, 2024

Conversation

Dinistro
Copy link
Contributor

This commit changes OpBuilder::tryFold to behave more similarly to Operation::fold. Concretely, this ensures that even an in-place fold returns success.
This is necessary to fix a bug in the dialect conversion that occurred when an in-place folding made an operation legal. The dialect conversion infrastructure did not check if the result of an in-place folding legalized the operation and just went ahead and tried to apply pattern anyways.

The added test contains a simplified version of a breakage we observed downstream.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:llvm mlir labels Apr 12, 2024
@llvmbot
Copy link
Member

llvmbot commented Apr 12, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-llvm

Author: Christian Ulmann (Dinistro)

Changes

This commit changes OpBuilder::tryFold to behave more similarly to Operation::fold. Concretely, this ensures that even an in-place fold returns success.
This is necessary to fix a bug in the dialect conversion that occurred when an in-place folding made an operation legal. The dialect conversion infrastructure did not check if the result of an in-place folding legalized the operation and just went ahead and tried to apply pattern anyways.

The added test contains a simplified version of a breakage we observed downstream.


Full diff: https://github.com/llvm/llvm-project/pull/88508.diff

8 Files Affected:

  • (modified) mlir/include/mlir/IR/Builders.h (+12-4)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+2-1)
  • (modified) mlir/lib/IR/Builders.cpp (+11-9)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+4)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (+10)
  • (modified) mlir/test/lib/Dialect/Test/TestDialect.cpp (+9)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+6)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+6)
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 3beade017d1ab9..e74505e5dbfdf4 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -517,7 +517,7 @@ class OpBuilder : public Builder {
 
   /// Create an operation of specific op type at the current insertion point,
   /// and immediately try to fold it. This functions populates 'results' with
-  /// the results after folding the operation.
+  /// the results of the operation.
   template <typename OpTy, typename... Args>
   void createOrFold(SmallVectorImpl<Value> &results, Location location,
                     Args &&...args) {
@@ -530,10 +530,17 @@ class OpBuilder : public Builder {
     if (block)
       block->getOperations().insert(insertPoint, op);
 
-    // Fold the operation. If successful erase it, otherwise notify.
-    if (succeeded(tryFold(op, results)))
+    // Attempt to fold the operation.
+    if (succeeded(tryFold(op, results)) && !results.empty()) {
+      // Erase the operation, if the fold removed the need for this operation.
+      // Note: The fold already populated the results in this case.
       op->erase();
-    else if (block && listener)
+      return;
+    }
+
+    ResultRange opResults = op->getResults();
+    results.assign(opResults.begin(), opResults.end());
+    if (block && listener)
       listener->notifyOperationInserted(op, /*previous=*/{});
   }
 
@@ -561,6 +568,7 @@ class OpBuilder : public Builder {
 
   /// Attempts to fold the given operation and places new results within
   /// 'results'. Returns success if the operation was folded, failure otherwise.
+  /// If the fold was in-place, `results` will not be filled.
   /// Note: This function does not erase the operation on a successful fold.
   LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f90240a67dcc5f..0fff06df39c1e7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2763,7 +2763,8 @@ LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
 
 /// Folds a cast op that can be chained.
 template <typename T>
-static Value foldChainableCast(T castOp, typename T::FoldAdaptor adaptor) {
+static OpFoldResult foldChainableCast(T castOp,
+                                      typename T::FoldAdaptor adaptor) {
   // cast(x : T0, T0) -> x
   if (castOp.getArg().getType() == castOp.getType())
     return castOp.getArg();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 18ca3c332e0204..36e17609eab609 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -476,16 +476,14 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
   return create(state);
 }
 
-/// Attempts to fold the given operation and places new results within
-/// 'results'. Returns success if the operation was folded, failure otherwise.
-/// Note: This function does not erase the operation on a successful fold.
 LogicalResult OpBuilder::tryFold(Operation *op,
                                  SmallVectorImpl<Value> &results) {
+  assert(results.empty());
   ResultRange opResults = op->getResults();
 
   results.reserve(opResults.size());
   auto cleanupFailure = [&] {
-    results.assign(opResults.begin(), opResults.end());
+    results.clear();
     return failure();
   };
 
@@ -495,20 +493,24 @@ LogicalResult OpBuilder::tryFold(Operation *op,
 
   // Try to fold the operation.
   SmallVector<OpFoldResult, 4> foldResults;
-  if (failed(op->fold(foldResults)) || foldResults.empty())
+  if (failed(op->fold(foldResults)))
     return cleanupFailure();
 
+  // An in-place fold does not require generation of any constants.
+  if (foldResults.empty())
+    return success();
+
   // A temporary builder used for creating constants during folding.
   OpBuilder cstBuilder(context);
   SmallVector<Operation *, 1> generatedConstants;
 
   // Populate the results with the folded results.
   Dialect *dialect = op->getDialect();
-  for (auto it : llvm::zip_equal(foldResults, opResults.getTypes())) {
-    Type expectedType = std::get<1>(it);
+  for (auto [foldResult, expectedType] :
+       llvm::zip_equal(foldResults, opResults.getTypes())) {
 
     // Normal values get pushed back directly.
-    if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
+    if (auto value = llvm::dyn_cast_if_present<Value>(foldResult)) {
       results.push_back(value);
       continue;
     }
@@ -518,7 +520,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
       return cleanupFailure();
 
     // Ask the dialect to materialize a constant operation for this value.
-    Attribute attr = std::get<0>(it).get<Attribute>();
+    Attribute attr = foldResult.get<Attribute>();
     auto *constOp = dialect->materializeConstant(cstBuilder, attr, expectedType,
                                                  op->getLoc());
     if (!constOp) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8671c1008902a0..18d6f7daa4bea7 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2072,6 +2072,10 @@ OperationLegalizer::legalizeWithFold(Operation *op,
     LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
     return failure();
   }
+  // An empty list of replacement values indicates that the fold was in-place.
+  // As the operation changed, a new legalization needs to be attempted.
+  if (replacementValues.empty())
+    return legalize(op, rewriter);
 
   // Insert a replacement for 'op' with the folded replacement values.
   rewriter.replaceOp(op, replacementValues);
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d552f0346644b3..7530b300d57b8b 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -427,3 +427,13 @@ func.func @use_of_replaced_bbarg(%arg0: i64) {
   }) : (i64) -> (i64)
   "test.invalid"(%0) : (i64) -> ()
 }
+
+// -----
+
+// CHECK-LABEL: @fold_legalization
+func.func @fold_legalization() -> i32 {
+  // CHECK: op_in_place_self_fold
+  // CHECK-SAME: folded = true
+  %1 = "test.op_in_place_self_fold"() : () -> (i32)
+  "test.return"(%1) : (i32) -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 380c74a47e509a..becd0d68bf1d3e 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -568,6 +568,15 @@ OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) {
+  if (!getProperties().folded) {
+    // The folder adds the "folded" if not present.
+    getProperties().folded = BoolAttr::get(getContext(), true);
+    return getResult();
+  }
+  return {};
+}
+
 OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
   int64_t sum = 0;
   if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e6c3601d08dad0..663064d51f1bbe 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1351,6 +1351,12 @@ def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> {
   let hasFolder = 1;
 }
 
+def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> {
+  let arguments = (ins OptionalAttr<BoolAttr>:$folded);
+  let results = (outs I32);
+  let hasFolder = 1;
+}
+
 // Test op that simply returns success.
 def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
   let results = (outs Variadic<I1>);
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 76dc825fe44515..285e39dd9016e1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1167,6 +1167,12 @@ struct TestLegalizePatternDriver
     target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
         [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; });
 
+    // Create a dynamically legal rule that can only be legalized by folding it.
+    target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
+        [](TestOpInPlaceSelfFold op) {
+          return op.getProperties().folded != nullptr;
+        });
+
     // Handle a partial conversion.
     if (mode == ConversionMode::Partial) {
       DenseSet<Operation *> unlegalizedOps;

@Dinistro
Copy link
Contributor Author

@matthias-springer friendly ping 🙂

LogicalResult OpBuilder::tryFold(Operation *op,
SmallVectorImpl<Value> &results) {
assert(results.empty());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: && "expected empty results"

@@ -561,6 +568,7 @@ class OpBuilder : public Builder {

/// Attempts to fold the given operation and places new results within
/// 'results'. Returns success if the operation was folded, failure otherwise.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use backticks

@@ -1351,6 +1351,12 @@ def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> {
let hasFolder = 1;
}

def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> {
let arguments = (ins OptionalAttr<BoolAttr>:$folded);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about turning this into a UnitAttr?

// Create a dynamically legal rule that can only be legalized by folding it.
target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
[](TestOpInPlaceSelfFold op) {
return op.getProperties().folded != nullptr;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When this is a unit attr, you can write return op.getFolded();.

OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) {
if (!getProperties().folded) {
// The folder adds the "folded" if not present.
getProperties().folded = BoolAttr::get(getContext(), true);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use setFolded here?

This commit changes `OpBuilder::tryFold` to behave more similarly to
`Operation::fold`. Concretely, this ensures that even an in-place fold
returns `success`. This is necessary to fix a bug in the dialect
conversion that occurred when an in-place folding made an operation
legal. The dialect conversion infrastructure did not check if the result
of an in-place folding legalized the operation and just went ahead and
tried to apply pattern anyways.
@Dinistro Dinistro force-pushed the users/dinistro/fix-dialect-conversion-folding branch from 50edf67 to c4e1f79 Compare April 23, 2024 05:47
@Dinistro Dinistro merged commit 4513050 into main Apr 23, 2024
@Dinistro Dinistro deleted the users/dinistro/fix-dialect-conversion-folding branch April 23, 2024 06:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:llvm mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants