[mlir][bufferization] Add support for non-unique func.return #114017


Merged

Conversation


@matthias-springer matthias-springer commented Oct 29, 2024

Multiple func.return ops inside a func.func op are now supported during bufferization. This PR extends the code base in three places:

  • When inferring function return types, memref.cast ops are folded away only if all func.return ops have matching buffer types. (E.g., we don't fold if two return ops have operands with different layout maps.)
  • The alias sets of all func.return ops are merged. That's because aliasing is a "may be" property.
  • The equivalence sets of all func.return ops are taken only if they match. If different func.return ops have different equivalence sets for their operands, the equivalence information is dropped. That's because equivalence is a "must be" property.
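The difference between the two properties can be illustrated with a small standalone C++ sketch (plain STL containers standing in for the analysis state; the function names are hypothetical, not part of the patch): alias information is unioned across all return ops, while equivalence information survives only when every return op agrees.

```cpp
#include <cassert>
#include <optional>
#include <set>
#include <vector>

// Aliasing is a "may be" property: a return value may alias a bbArg if it
// does so on *any* func.return. Merging is therefore a set union.
std::set<int> mergeAliasSets(const std::vector<std::set<int>> &perReturnOp) {
  std::set<int> merged;
  for (const auto &s : perReturnOp)
    merged.insert(s.begin(), s.end());
  return merged;
}

// Equivalence is a "must be" property: a return operand is equivalent to a
// bbArg only if *every* func.return agrees. Any disagreement drops the info.
std::optional<int>
matchEquivalence(const std::vector<std::optional<int>> &perReturnOp) {
  if (perReturnOp.empty())
    return std::nullopt;
  std::optional<int> first = perReturnOp.front();
  for (const auto &e : perReturnOp)
    if (e != first)
      return std::nullopt; // disagreement: equivalence information dropped
  return first;
}
```

Dropping equivalence on any disagreement is conservative but safe: a "must be" fact that only holds on some control-flow paths is not a fact.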

This commit is in preparation for removing the deprecated func-bufferize pass, which can already bufferize functions with multiple return ops.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/bufferization_bodiless branch from a9f607f to 11089cc Compare October 29, 2024 22:16
@matthias-springer matthias-springer force-pushed the users/matthias-springer/bufferization_unique_returns branch 2 times, most recently from 13808b4 to 3522311 Compare October 29, 2024 22:51
@matthias-springer matthias-springer marked this pull request as ready for review October 29, 2024 22:56
@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure labels Oct 29, 2024

llvmbot commented Oct 29, 2024

@llvm/pr-subscribers-mlir-bufferization

Author: Matthias Springer (matthias-springer)


Depends on #113999.


Patch is 20.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114017.diff

5 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+29-46)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+135-44)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir (+45)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir (+2-20)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir (+24)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 6e91d3b89a7c79..195b17fcf902a2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -41,18 +41,13 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
 #endif // NDEBUG
 }
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
-  func::ReturnOp returnOp;
-  for (Block &b : funcOp.getBody()) {
-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
+/// Return all top-level func.return ops in the given function.
+static SmallVector<func::ReturnOp> getReturnOps(FuncOp funcOp) {
+  SmallVector<func::ReturnOp> result;
+  for (Block &b : funcOp.getBody())
+    if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
+      result.push_back(returnOp);
+  return result;
 }
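For illustration, here is a function in the shape of the tests added by this PR, with two blocks each terminated by a `func.return`; `getReturnOps` collects both terminators:

```mlir
func.func @multiple_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>) -> tensor<5xf32> {
  cf.cond_br %c, ^bb1, ^bb2
^bb1:
  return %t0 : tensor<5xf32>  // collected by getReturnOps
^bb2:
  return %t1 : tensor<5xf32>  // collected by getReturnOps
}
```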
 
 /// Return the index-th bufferized function argument type. This assumes that the
@@ -372,15 +367,6 @@ struct FuncOpInterface
         getBufferType(op, value, options, invocationStack);
   }
 
-  LogicalResult verifyAnalysis(Operation *op,
-                               const AnalysisState &state) const {
-    auto funcOp = cast<func::FuncOp>(op);
-    // TODO: func.func with multiple returns are not supported.
-    if (!getAssumedUniqueReturnOp(funcOp) && !funcOp.isExternal())
-      return op->emitOpError("op without unique func.return is not supported");
-    return success();
-  }
-
   /// Rewrite function bbArgs and return values into buffer form. This function
   /// bufferizes the function signature and the ReturnOp. When the entire
   /// function body has been bufferized, function return types can be switched
@@ -427,41 +413,38 @@ struct FuncOpInterface
       return success();
     }
 
-    // TODO: Support functions with multiple returns.
-    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-    assert(returnOp && "expected func with single return op");
-    assert(returnOp->getNumOperands() == retTypes.size() &&
-           "incorrect number of return values");
-    Location loc = returnOp.getLoc();
-
     // 1. Bufferize every block.
     for (Block &block : funcOp.getBody())
       if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                         options)))
         return failure();
 
-    // 2. Bufferize all operands of the return op.
-    SmallVector<Value> returnValues;
-    for (auto [returnVal, bufferizedType] :
-         llvm::zip_equal(returnOp->getOperands(), retTypes)) {
-      auto tensorType = dyn_cast<TensorType>(returnVal.getType());
-      rewriter.setInsertionPoint(returnOp);
-
-      // If not a tensor type just forward it.
-      if (!tensorType) {
-        returnValues.push_back(returnVal);
-        continue;
+    // 2. Bufferize the operands of all return ops.
+    for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
+      assert(returnOp->getNumOperands() == retTypes.size() &&
+             "incorrect number of return values");
+      SmallVector<Value> returnValues;
+      for (auto [returnVal, bufferizedType] :
+           llvm::zip_equal(returnOp->getOperands(), retTypes)) {
+        auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+        rewriter.setInsertionPoint(returnOp);
+
+        // If not a tensor type just forward it.
+        if (!tensorType) {
+          returnValues.push_back(returnVal);
+          continue;
+        }
+
+        // Note: If `inferFunctionResultLayout = true`, casts are later folded
+        // away.
+        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+            returnOp.getLoc(), bufferizedType, returnVal);
+        returnValues.push_back(toMemrefOp);
       }
 
-      // Note: If `inferFunctionResultLayout = true`, casts are later folded
-      // away.
-      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-          loc, bufferizedType, returnVal);
-      returnValues.push_back(toMemrefOp);
+      returnOp.getOperandsMutable().assign(returnValues);
     }
 
-    returnOp.getOperandsMutable().assign(returnValues);
-
     // 3. Set the new function type.
     funcOp.setType(newFuncType);
     return success();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 0a4072605c265f..e4635ebd78d8f8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -86,18 +86,13 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
   return state.addExtension<FuncAnalysisState>();
 }
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
-  func::ReturnOp returnOp;
-  for (Block &b : funcOp.getBody()) {
-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
+/// Return all top-level func.return ops in the given function.
+static SmallVector<func::ReturnOp> getReturnOps(FuncOp funcOp) {
+  SmallVector<func::ReturnOp> result;
+  for (Block &b : funcOp.getBody())
+    if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
+      result.push_back(returnOp);
+  return result;
 }
 
 namespace {
@@ -146,24 +141,80 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
     return success();
   }
 
-  // Support only single return-terminated block in the function.
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  assert(returnOp && "expected func with single return op");
-
-  for (OpOperand &returnVal : returnOp->getOpOperands())
-    if (isa<RankedTensorType>(returnVal.get().getType()))
-      for (BlockArgument bbArg : funcOp.getArguments())
-        if (isa<RankedTensorType>(bbArg.getType())) {
-          int64_t returnIdx = returnVal.getOperandNumber();
-          int64_t bbArgIdx = bbArg.getArgNumber();
-          if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
-            funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
-            if (state.getOptions().testAnalysisOnly)
-              annotateEquivalentReturnBbArg(returnVal, bbArg);
+  // Find all func.return ops.
+  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  assert(!returnOps.empty() && "expected at least one ReturnOp");
+
+  // Build alias sets. Merge all aliases from all func.return ops.
+  for (BlockArgument bbArg : funcOp.getArguments()) {
+    if (isa<RankedTensorType>(bbArg.getType())) {
+      int64_t bbArgIdx = bbArg.getArgNumber();
+      // Store aliases in a set, so that we don't add the same alias twice.
+      SetVector<int64_t> aliases;
+      for (func::ReturnOp returnOp : returnOps) {
+        for (OpOperand &returnVal : returnOp->getOpOperands()) {
+          if (isa<RankedTensorType>(returnVal.get().getType())) {
+            int64_t returnIdx = returnVal.getOperandNumber();
+            if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
+              aliases.insert(returnIdx);
           }
-          if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
-            funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
         }
+      }
+      for (int64_t alias : aliases)
+        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
+    }
+  }
+
+  // Build equivalence sets.
+  // Helper function that finds an equivalent block argument index for the
+  // given OpOperand. Return std::nullopt if no equivalent block argument could
+  // be found.
+  auto findEquivalentBlockArgIdx =
+      [&](OpOperand &opOperand) -> std::optional<int64_t> {
+    Value v = opOperand.get();
+    if (!isa<TensorType>(v.getType()))
+      return std::nullopt;
+    for (BlockArgument bbArg : funcOp.getArguments()) {
+      if (isa<RankedTensorType>(bbArg.getType())) {
+        if (state.areEquivalentBufferizedValues(v, bbArg)) {
+          if (state.getOptions().testAnalysisOnly)
+            annotateEquivalentReturnBbArg(opOperand, bbArg);
+          return bbArg.getArgNumber();
+        }
+      }
+    }
+    return std::nullopt;
+  };
+
+  int64_t numResults = returnOps.front()->getNumOperands();
+  for (int64_t i = 0; i < numResults; ++i) {
+    // Find the equivalent block argument index for the i-th operand of the
+    // first func.return op.
+    std::optional<int64_t> maybeEquiv =
+        findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
+    if (!maybeEquiv.has_value())
+      continue;
+    int64_t bbArgIdx = *maybeEquiv;
+    bool allEquiv = true;
+
+    // Check if all other func.return ops have the same equivalent block
+    // argument for the i-th operand. In contrast to aliasing information,
+    // which is just "merged", equivalence information must match across all
+    // func.return ops.
+    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
+      std::optional<int64_t> maybeEquiv =
+          findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
+      if (maybeEquiv != bbArgIdx) {
+        allEquiv = false;
+        break;
+      }
+    }
+
+    // All func.return ops have the same equivalent block argument for the i-th
+    // operand.
+    if (allEquiv)
+      funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
+  }
 
   return success();
 }
@@ -299,14 +350,6 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
   WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
-    if (!funcOp.getBody().empty()) {
-      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-      if (!returnOp)
-        return funcOp->emitError()
-               << "cannot bufferize a FuncOp with tensors and "
-                  "without a unique ReturnOp";
-    }
-
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
     return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
@@ -342,6 +385,42 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   return success();
 }
 
+/// Helper function that extracts the source from a memref.cast. If the given
+/// value is not a memref.cast result, simply returns the given value.
+static Value unpackCast(Value v) {
+  auto castOp = v.getDefiningOp<memref::CastOp>();
+  if (!castOp)
+    return v;
+  return castOp.getSource();
+}
+
+/// Helper function that returns the return types (skipping casts) of the given
+/// func.return ops. This function returns as many types as the return ops have
+/// operands. If the i-th operand type is not the same for all func.return ops,
+/// the i-th returned type is an "empty" type.
+static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
+  assert(!returnOps.empty() && "expected at least one ReturnOp");
+  int numOperands = returnOps.front()->getNumOperands();
+
+  // Helper function that unpacks memref.cast ops and returns the type.
+  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
+
+  SmallVector<Type> result;
+  for (int i = 0; i < numOperands; ++i) {
+    // Get the type of the i-th operand of the first func.return ops.
+    Type t = getSourceType(returnOps.front()->getOperand(i));
+
+    // Check if all other func.return ops have a matching operand type.
+    for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
+      if (getSourceType(returnOps[j]->getOperand(i)) != t)
+        t = Type();
+
+    result.push_back(t);
+  }
+
+  return result;
+}
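The per-operand agreement check in `getReturnTypes` can be sketched standalone, with strings standing in for `mlir::Type` (a hypothetical model, not the actual code):

```cpp
#include <cassert>
#include <string>
#include <vector>

// For each result index, keep the type only if all return ops agree on it;
// otherwise leave an empty placeholder (mirroring a null mlir::Type).
std::vector<std::string>
commonReturnTypes(const std::vector<std::vector<std::string>> &returnOpTypes) {
  std::vector<std::string> result;
  if (returnOpTypes.empty())
    return result;
  size_t numOperands = returnOpTypes.front().size();
  for (size_t i = 0; i < numOperands; ++i) {
    std::string t = returnOpTypes.front()[i];
    for (size_t j = 1; j < returnOpTypes.size(); ++j)
      if (returnOpTypes[j][i] != t)
        t.clear(); // no common type for this operand
    result.push_back(t);
  }
  return result;
}
```

The empty string plays the role of a null `Type`: for those indices, the caller later falls back to the function's existing result type.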
+
 /// Fold return values that are memref casts and update function return types.
 ///
 /// During FuncOp bufferization, the exact type of the returned memrefs (if any)
@@ -350,21 +429,33 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
 /// entire function body, a more concise memref type can potentially be used for
 /// the return type of the function.
 static void foldMemRefCasts(func::FuncOp funcOp) {
+  // There is nothing to do for bodiless ops.
   if (funcOp.getBody().empty())
     return;
 
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  SmallVector<Type> resultTypes;
+  // Compute the common result types of all return ops.
+  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  SmallVector<Type> resultTypes = getReturnTypes(returnOps);
 
-  for (OpOperand &operand : returnOp->getOpOperands()) {
-    if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
-      operand.set(castOp.getSource());
-      resultTypes.push_back(castOp.getSource().getType());
-    } else {
-      resultTypes.push_back(operand.get().getType());
+  // Remove direct casts.
+  for (func::ReturnOp returnOp : returnOps) {
+    for (OpOperand &operand : returnOp->getOpOperands()) {
+      // Bail if no common result type was found.
+      if (resultTypes[operand.getOperandNumber()]) {
+        operand.set(unpackCast(operand.get()));
+      }
     }
   }
 
+  // Fill in the missing result types that were not the same among all
+  // func.return ops.
+  for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
+    if (resultTypes[i])
+      continue;
+    resultTypes[i] = funcOp.getFunctionType().getResult(i);
+  }
+
+  // Update the function type.
   auto newFuncType = FunctionType::get(
       funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
   funcOp.setType(newFuncType);
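A hypothetical before/after sketch of `foldMemRefCasts` on a function where both return ops cast from the same source type, so a common, more concise result type can be inferred:

```mlir
// Before folding: both return ops cast from the same concrete type.
func.func @f(%c: i1, %m: memref<5xf32>) -> memref<5xf32, strided<[?], offset: ?>> {
  cf.cond_br %c, ^bb1, ^bb2
^bb1:
  %0 = memref.cast %m : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
  return %0 : memref<5xf32, strided<[?], offset: ?>>
^bb2:
  %1 = memref.cast %m : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
  return %1 : memref<5xf32, strided<[?], offset: ?>>
}

// After foldMemRefCasts: the casts are dropped and the result type tightened.
func.func @f(%c: i1, %m: memref<5xf32>) -> memref<5xf32> {
  cf.cond_br %c, ^bb1, ^bb2
^bb1:
  return %m : memref<5xf32>
^bb2:
  return %m : memref<5xf32>
}
```

If the two casts had different source types, no common type would be found for that operand and the function's existing result type would be kept.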
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
index 42d9cc00d3ff5a..ab6de70dfb161b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
@@ -1348,3 +1348,48 @@ func.func @private_func_aliasing(%t: tensor<?xf32>) -> f32 {
   %2 = tensor.extract %1[%c0] : tensor<6xf32>
   return %2 : f32
 }
+
+// -----
+
+// CHECK-ALIAS-SETS-LABEL: func @multiple_returns(
+func.func @multiple_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  return %t0 : tensor<5xf32>
+^bb2:
+  return %t1 : tensor<5xf32>
+}
+
+//       CHECK-ALIAS-SETS: func @caller(
+//  CHECK-ALIAS-SETS-SAME:     %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
+func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) {
+  // Check that alias sets are computed correctly.
+  //      CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_returns
+  // CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
+  // CHECK-ALIAS-SETS-SAME:  __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]", "%[[t1]]"]]}
+  call @multiple_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
+  return
+}
+
+// -----
+
+// CHECK-ALIAS-SETS-LABEL: func @multiple_equivalent_returns(
+func.func @multiple_equivalent_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  return %t0 : tensor<5xf32>
+^bb2:
+  return %t0 : tensor<5xf32>
+}
+
+//       CHECK-ALIAS-SETS: func @caller(
+//  CHECK-ALIAS-SETS-SAME:     %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "none"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
+func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+  // Check that equivalence sets are computed correctly.
+  //      CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_equivalent_returns
+  // CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
+  // CHECK-ALIAS-SETS-SAME:  __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]"]]}
+  %r = call @multiple_equivalent_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
+  // CHECK: {__equivalent_func_args__ = [1], __inplace_operands_attr__ = ["true"]} %[[result]] : tensor<5xf32>
+  return %r : tensor<5xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index 2829eafb7c1c59..f3da82cc0064d4 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -1,24 +1,5 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics
 
-// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
-func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
-    -> (tensor<f32>, tensor<f32>)
-{
-  cf.cond_br %cond1, ^bb1, ^bb2
-
-  ^bb1:
-    %T:2 = scf.if %cond2 -> (tensor<f32>, tensor<f32>) {
-      scf.yield %t1, %t2 : tensor<f32>, tensor<f32>
-    } else {
-      scf.yield %t2, %t1 : tensor<f32>, tensor<f32>
-    }
-    return %T#0, %T#1 : tensor<f32>, tensor<f32>
-  ^bb2:
-    return %t2, %t1 : tensor<f32>, tensor<f32>
-}
-
-// -----
-
 // expected-error @-3 {{expected callgraph to be free of circular dependencies}}
 
 func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
@@ -160,7 +141,8 @@ func.func @regression_scf_while() {
 
 // -----
 
-// expected-error @below{{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
+// expected-error @below{{could not infer buffer type of block argument}}
+// expected-error @below{{failed to bufferize op}}
 func.func @func_multiple_yields(%t: tensor<5xf32>) -> tensor<5xf32> {
   func.return %t : tensor<5xf32>
 ^bb1(%arg1 : tensor<5xf32>):
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index d31b43477beb9f..4f10ffea561aa8 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -722,3 +722,27 @@ func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
   %0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
   return %0 : memref<5xf32>
 }
+
+// -----
+
+// The two func.return operands have different types after bufferization. Make
+// sure that memref.cast ops are inserted.
+
+// CHECK-LABEL: func @result_type_mismatch({{.*}}) -> memref<5xf32, strided<[?], offset: ?>>
+func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> {
+  // CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<10xf32>
+  %t = tensor.empty() : tensor<10xf32>
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  // CHECK: %[[m0:.*]] = memref.subview %[[alloc]][0] [5] [2] : memref<10xf32> to memref<5xf32, strided<[2]>>
+  // CHECK: %[[cast0:.*]] = memref.cast %[[m0]] : memref<5xf32, strided<[2]>> to memref<5xf32, strided<[?], offset: ?>>
+  %0 = tensor.extract_slice %t[0][5][2] : tensor<10xf32> to tensor<5xf32>
+  // CHECK: return %[[cast0]] : memref<5xf32, strided<[?], offset: ?>
+  return %0 : tensor<5xf32>
+^bb2:
+  // CHECK: %[[m1:.*]] = memref.subview %[[alloc]][2] [5] [1] : memref<10xf32> to memref<5xf32, strided<[1], offset: 2>>
+  // CHECK: %[[cast1:.*]] = memref.cast %[[m1]] : memref<5xf32, strided<[1], offset: 2>> to memref<5xf32, strided<[?], offset: ?>>
+  %1 = tensor.extract_slice %t[2][5][1] : tensor<10xf32> to tensor<5xf32>
+  // CHECK: return ...
[truncated]

@llvmbot commented Oct 29, 2024

@llvm/pr-subscribers-mlir
+          if (state.getOptions().testAnalysisOnly)
+            annotateEquivalentReturnBbArg(opOperand, bbArg);
+          return bbArg.getArgNumber();
+        }
+      }
+    }
+    return std::nullopt;
+  };
+
+  int64_t numResults = returnOps.front()->getNumOperands();
+  for (int64_t i = 0; i < numResults; ++i) {
+    // Find the equivalent block argument index for the i-th operand of the
+    // first func.return op.
+    std::optional<int64_t> maybeEquiv =
+        findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
+    if (!maybeEquiv.has_value())
+      continue;
+    int64_t bbArgIdx = *maybeEquiv;
+    bool allEquiv = true;
+
+    // Check if all other func.return ops have the same equivalent block
+    // argument for the i-th operand. In contrast to aliasing information,
+    // which is just "merged", equivalence information must match across all
+    // func.return ops.
+    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
+      std::optional<int64_t> maybeEquiv =
+          findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
+      if (maybeEquiv != bbArgIdx) {
+        allEquiv = false;
+        break;
+      }
+    }
+
+    // All func.return ops have the same equivalent block argument for the i-th
+    // operand.
+    if (allEquiv)
+      funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
+  }
 
   return success();
 }
@@ -299,14 +350,6 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
   WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
-    if (!funcOp.getBody().empty()) {
-      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-      if (!returnOp)
-        return funcOp->emitError()
-               << "cannot bufferize a FuncOp with tensors and "
-                  "without a unique ReturnOp";
-    }
-
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
     return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
@@ -342,6 +385,42 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   return success();
 }
 
+/// Helper function that extracts the source from a memref.cast. If the given
+/// value is not a memref.cast result, simply returns the given value.
+static Value unpackCast(Value v) {
+  auto castOp = v.getDefiningOp<memref::CastOp>();
+  if (!castOp)
+    return v;
+  return castOp.getSource();
+}
+
+/// Helper function that returns the return types (skipping casts) of the given
+/// func.return ops. This function returns as many types as the return ops have
+/// operands. If the i-th operand type is not the same for all func.return
+/// ops, then the i-th returned type is an "empty" type.
+static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
+  assert(!returnOps.empty() && "expected at least one ReturnOp");
+  int numOperands = returnOps.front()->getNumOperands();
+
+  // Helper function that unpacks memref.cast ops and returns the type.
+  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
+
+  SmallVector<Type> result;
+  for (int i = 0; i < numOperands; ++i) {
+    // Get the type of the i-th operand of the first func.return op.
+    Type t = getSourceType(returnOps.front()->getOperand(i));
+
+    // Check if all other func.return ops have a matching operand type.
+    for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
+      if (getSourceType(returnOps[j]->getOperand(i)) != t)
+        t = Type();
+
+    result.push_back(t);
+  }
+
+  return result;
+}
+
 /// Fold return values that are memref casts and update function return types.
 ///
 /// During FuncOp bufferization, the exact type of the returned memrefs (if any)
@@ -350,21 +429,33 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
 /// entire function body, a more concise memref type can potentially be used for
 /// the return type of the function.
 static void foldMemRefCasts(func::FuncOp funcOp) {
+  // There is nothing to do for bodiless ops.
   if (funcOp.getBody().empty())
     return;
 
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  SmallVector<Type> resultTypes;
+  // Compute the common result types of all return ops.
+  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  SmallVector<Type> resultTypes = getReturnTypes(returnOps);
 
-  for (OpOperand &operand : returnOp->getOpOperands()) {
-    if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
-      operand.set(castOp.getSource());
-      resultTypes.push_back(castOp.getSource().getType());
-    } else {
-      resultTypes.push_back(operand.get().getType());
+  // Remove direct casts.
+  for (func::ReturnOp returnOp : returnOps) {
+    for (OpOperand &operand : returnOp->getOpOperands()) {
+      // Only fold the cast if a common result type was found.
+      if (resultTypes[operand.getOperandNumber()]) {
+        operand.set(unpackCast(operand.get()));
+      }
     }
   }
 
+  // Fill in the missing result types that were not the same among all
+  // func.return ops.
+  for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
+    if (resultTypes[i])
+      continue;
+    resultTypes[i] = funcOp.getFunctionType().getResult(i);
+  }
+
+  // Update the function type.
   auto newFuncType = FunctionType::get(
       funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
   funcOp.setType(newFuncType);
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
index 42d9cc00d3ff5a..ab6de70dfb161b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
@@ -1348,3 +1348,48 @@ func.func @private_func_aliasing(%t: tensor<?xf32>) -> f32 {
   %2 = tensor.extract %1[%c0] : tensor<6xf32>
   return %2 : f32
 }
+
+// -----
+
+// CHECK-ALIAS-SETS-LABEL: func @multiple_returns(
+func.func @multiple_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  return %t0 : tensor<5xf32>
+^bb2:
+  return %t1 : tensor<5xf32>
+}
+
+//       CHECK-ALIAS-SETS: func @caller(
+//  CHECK-ALIAS-SETS-SAME:     %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
+func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) {
+  // Check that alias sets are computed correctly.
+  //      CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_returns
+  // CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
+  // CHECK-ALIAS-SETS-SAME:  __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]", "%[[t1]]"]]}
+  call @multiple_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
+  return
+}
+
+// -----
+
+// CHECK-ALIAS-SETS-LABEL: func @multiple_equivalent_returns(
+func.func @multiple_equivalent_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  return %t0 : tensor<5xf32>
+^bb2:
+  return %t0 : tensor<5xf32>
+}
+
+//       CHECK-ALIAS-SETS: func @caller(
+//  CHECK-ALIAS-SETS-SAME:     %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "none"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
+func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+  // Check that equivalence sets are computed correctly.
+  //      CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_equivalent_returns
+  // CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
+  // CHECK-ALIAS-SETS-SAME:  __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]"]]}
+  %r = call @multiple_equivalent_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
+  // CHECK: {__equivalent_func_args__ = [1], __inplace_operands_attr__ = ["true"]} %[[result]] : tensor<5xf32>
+  return %r : tensor<5xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index 2829eafb7c1c59..f3da82cc0064d4 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -1,24 +1,5 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics
 
-// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
-func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
-    -> (tensor<f32>, tensor<f32>)
-{
-  cf.cond_br %cond1, ^bb1, ^bb2
-
-  ^bb1:
-    %T:2 = scf.if %cond2 -> (tensor<f32>, tensor<f32>) {
-      scf.yield %t1, %t2 : tensor<f32>, tensor<f32>
-    } else {
-      scf.yield %t2, %t1 : tensor<f32>, tensor<f32>
-    }
-    return %T#0, %T#1 : tensor<f32>, tensor<f32>
-  ^bb2:
-    return %t2, %t1 : tensor<f32>, tensor<f32>
-}
-
-// -----
-
 // expected-error @-3 {{expected callgraph to be free of circular dependencies}}
 
 func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
@@ -160,7 +141,8 @@ func.func @regression_scf_while() {
 
 // -----
 
-// expected-error @below{{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
+// expected-error @below{{could not infer buffer type of block argument}}
+// expected-error @below{{failed to bufferize op}}
 func.func @func_multiple_yields(%t: tensor<5xf32>) -> tensor<5xf32> {
   func.return %t : tensor<5xf32>
 ^bb1(%arg1 : tensor<5xf32>):
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index d31b43477beb9f..4f10ffea561aa8 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -722,3 +722,27 @@ func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
   %0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
   return %0 : memref<5xf32>
 }
+
+// -----
+
+// The two func.return operands have different types after bufferization. Make
+// sure that memref.cast ops are inserted.
+
+// CHECK-LABEL: func @result_type_mismatch({{.*}}) -> memref<5xf32, strided<[?], offset: ?>>
+func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> {
+  // CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<10xf32>
+  %t = tensor.empty() : tensor<10xf32>
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  // CHECK: %[[m0:.*]] = memref.subview %[[alloc]][0] [5] [2] : memref<10xf32> to memref<5xf32, strided<[2]>>
+  // CHECK: %[[cast0:.*]] = memref.cast %[[m0]] : memref<5xf32, strided<[2]>> to memref<5xf32, strided<[?], offset: ?>>
+  %0 = tensor.extract_slice %t[0][5][2] : tensor<10xf32> to tensor<5xf32>
+  // CHECK: return %[[cast0]] : memref<5xf32, strided<[?], offset: ?>>
+  return %0 : tensor<5xf32>
+^bb2:
+  // CHECK: %[[m1:.*]] = memref.subview %[[alloc]][2] [5] [1] : memref<10xf32> to memref<5xf32, strided<[1], offset: 2>>
+  // CHECK: %[[cast1:.*]] = memref.cast %[[m1]] : memref<5xf32, strided<[1], offset: 2>> to memref<5xf32, strided<[?], offset: ?>>
+  %1 = tensor.extract_slice %t[2][5][1] : tensor<10xf32> to tensor<5xf32>
+  // CHECK: return ...
[truncated]

@matthias-springer matthias-springer force-pushed the users/matthias-springer/bufferization_unique_returns branch from 3522311 to 684ac4a Compare October 29, 2024 23:19
Base automatically changed from users/matthias-springer/bufferization_bodiless to main October 30, 2024 12:49
@matthias-springer matthias-springer force-pushed the users/matthias-springer/bufferization_unique_returns branch 2 times, most recently from 40c5042 to 3f7f8a7 Compare November 5, 2024 01:22
%1 = tensor.extract_slice %t[2][5][1] : tensor<10xf32> to tensor<5xf32>
// CHECK: return %[[cast1]] : memref<5xf32, strided<[?], offset: ?>>
return %1 : tensor<5xf32>
}

very interesting test. thanks.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/bufferization_unique_returns branch from 3f7f8a7 to 52d1a20 Compare November 10, 2024 04:40
@matthias-springer matthias-springer force-pushed the users/matthias-springer/bufferization_unique_returns branch from 52d1a20 to 1122ffe Compare November 12, 2024 12:59
@javedabsar1 javedabsar1 left a comment


LGTM. Thanks.

@matthias-springer matthias-springer merged commit b0a4e95 into main Nov 12, 2024
8 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/bufferization_unique_returns branch November 12, 2024 23:51