-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][bufferization] Add support for non-unique func.return
#114017
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][bufferization] Add support for non-unique func.return
#114017
Conversation
a9f607f
to
11089cc
Compare
13808b4
to
3522311
Compare
@llvm/pr-subscribers-mlir-bufferization Author: Matthias Springer (matthias-springer) ChangesMultiple
Depends on #113999. Patch is 20.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114017.diff 5 Files Affected:
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 6e91d3b89a7c79..195b17fcf902a2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -41,18 +41,13 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
#endif // NDEBUG
}
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
- func::ReturnOp returnOp;
- for (Block &b : funcOp.getBody()) {
- if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
- if (returnOp)
- return nullptr;
- returnOp = candidateOp;
- }
- }
- return returnOp;
+/// Return all top-level func.return ops in the given function.
+static SmallVector<func::ReturnOp> getReturnOps(FuncOp funcOp) {
+ SmallVector<func::ReturnOp> result;
+ for (Block &b : funcOp.getBody())
+ if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
+ result.push_back(returnOp);
+ return result;
}
/// Return the index-th bufferized function argument type. This assumes that the
@@ -372,15 +367,6 @@ struct FuncOpInterface
getBufferType(op, value, options, invocationStack);
}
- LogicalResult verifyAnalysis(Operation *op,
- const AnalysisState &state) const {
- auto funcOp = cast<func::FuncOp>(op);
- // TODO: func.func with multiple returns are not supported.
- if (!getAssumedUniqueReturnOp(funcOp) && !funcOp.isExternal())
- return op->emitOpError("op without unique func.return is not supported");
- return success();
- }
-
/// Rewrite function bbArgs and return values into buffer form. This function
/// bufferizes the function signature and the ReturnOp. When the entire
/// function body has been bufferized, function return types can be switched
@@ -427,41 +413,38 @@ struct FuncOpInterface
return success();
}
- // TODO: Support functions with multiple returns.
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- assert(returnOp && "expected func with single return op");
- assert(returnOp->getNumOperands() == retTypes.size() &&
- "incorrect number of return values");
- Location loc = returnOp.getLoc();
-
// 1. Bufferize every block.
for (Block &block : funcOp.getBody())
if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
options)))
return failure();
- // 2. Bufferize all operands of the return op.
- SmallVector<Value> returnValues;
- for (auto [returnVal, bufferizedType] :
- llvm::zip_equal(returnOp->getOperands(), retTypes)) {
- auto tensorType = dyn_cast<TensorType>(returnVal.getType());
- rewriter.setInsertionPoint(returnOp);
-
- // If not a tensor type just forward it.
- if (!tensorType) {
- returnValues.push_back(returnVal);
- continue;
+ // 2. Bufferize the operands of the all return op.
+ for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
+ assert(returnOp->getNumOperands() == retTypes.size() &&
+ "incorrect number of return values");
+ SmallVector<Value> returnValues;
+ for (auto [returnVal, bufferizedType] :
+ llvm::zip_equal(returnOp->getOperands(), retTypes)) {
+ auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+ rewriter.setInsertionPoint(returnOp);
+
+ // If not a tensor type just forward it.
+ if (!tensorType) {
+ returnValues.push_back(returnVal);
+ continue;
+ }
+
+ // Note: If `inferFunctionResultLayout = true`, casts are later folded
+ // away.
+ Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+ returnOp.getLoc(), bufferizedType, returnVal);
+ returnValues.push_back(toMemrefOp);
}
- // Note: If `inferFunctionResultLayout = true`, casts are later folded
- // away.
- Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
- loc, bufferizedType, returnVal);
- returnValues.push_back(toMemrefOp);
+ returnOp.getOperandsMutable().assign(returnValues);
}
- returnOp.getOperandsMutable().assign(returnValues);
-
// 3. Set the new function type.
funcOp.setType(newFuncType);
return success();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 0a4072605c265f..e4635ebd78d8f8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -86,18 +86,13 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
return state.addExtension<FuncAnalysisState>();
}
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
- func::ReturnOp returnOp;
- for (Block &b : funcOp.getBody()) {
- if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
- if (returnOp)
- return nullptr;
- returnOp = candidateOp;
- }
- }
- return returnOp;
+/// Return all top-level func.return ops in the given function.
+static SmallVector<func::ReturnOp> getReturnOps(FuncOp funcOp) {
+ SmallVector<func::ReturnOp> result;
+ for (Block &b : funcOp.getBody())
+ if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
+ result.push_back(returnOp);
+ return result;
}
namespace {
@@ -146,24 +141,80 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
return success();
}
- // Support only single return-terminated block in the function.
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- assert(returnOp && "expected func with single return op");
-
- for (OpOperand &returnVal : returnOp->getOpOperands())
- if (isa<RankedTensorType>(returnVal.get().getType()))
- for (BlockArgument bbArg : funcOp.getArguments())
- if (isa<RankedTensorType>(bbArg.getType())) {
- int64_t returnIdx = returnVal.getOperandNumber();
- int64_t bbArgIdx = bbArg.getArgNumber();
- if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
- funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
- if (state.getOptions().testAnalysisOnly)
- annotateEquivalentReturnBbArg(returnVal, bbArg);
+ // Find all func.return ops.
+ SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ assert(!returnOps.empty() && "expected at least one ReturnOp");
+
+ // Build alias sets. Merge all aliases from all func.return ops.
+ for (BlockArgument bbArg : funcOp.getArguments()) {
+ if (isa<RankedTensorType>(bbArg.getType())) {
+ int64_t bbArgIdx = bbArg.getArgNumber();
+ // Store aliases in a set, so that we don't add the same alias twice.
+ SetVector<int64_t> aliases;
+ for (func::ReturnOp returnOp : returnOps) {
+ for (OpOperand &returnVal : returnOp->getOpOperands()) {
+ if (isa<RankedTensorType>(returnVal.get().getType())) {
+ int64_t returnIdx = returnVal.getOperandNumber();
+ if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
+ aliases.insert(returnIdx);
}
- if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
- funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
}
+ }
+ for (int64_t alias : aliases)
+ funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
+ }
+ }
+
+ // Build equivalence sets.
+ // Helper function that finds an equivalent block argument index for the
+ // given OpOperand. Return std::nullopt if no equivalent block argument could
+ // be found.
+ auto findEquivalentBlockArgIdx =
+ [&](OpOperand &opOperand) -> std::optional<int64_t> {
+ Value v = opOperand.get();
+ if (!isa<TensorType>(v.getType()))
+ return std::nullopt;
+ for (BlockArgument bbArg : funcOp.getArguments()) {
+ if (isa<RankedTensorType>(bbArg.getType())) {
+ if (state.areEquivalentBufferizedValues(v, bbArg)) {
+ if (state.getOptions().testAnalysisOnly)
+ annotateEquivalentReturnBbArg(opOperand, bbArg);
+ return bbArg.getArgNumber();
+ }
+ }
+ }
+ return std::nullopt;
+ };
+
+ int64_t numResults = returnOps.front()->getNumOperands();
+ for (int64_t i = 0; i < numResults; ++i) {
+ // Find the equivalent block argument index for the i-th operand of the
+ // first func.return op.
+ std::optional<int64_t> maybeEquiv =
+ findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
+ if (!maybeEquiv.has_value())
+ continue;
+ int64_t bbArgIdx = *maybeEquiv;
+ bool allEquiv = true;
+
+ // Check if all other func.return ops have the same equivalent block
+ // argument for the i-th operand. In contrast to aliasing information,
+ // which is just "merged", equivalence information must match across all
+ // func.return ops.
+ for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
+ std::optional<int64_t> maybeEquiv =
+ findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
+ if (maybeEquiv != bbArgIdx) {
+ allEquiv = false;
+ break;
+ }
+ }
+
+ // All func.return ops have the same equivalent block argument for the i-th
+ // operand.
+ if (allEquiv)
+ funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
+ }
return success();
}
@@ -299,14 +350,6 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
- if (!funcOp.getBody().empty()) {
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- if (!returnOp)
- return funcOp->emitError()
- << "cannot bufferize a FuncOp with tensors and "
- "without a unique ReturnOp";
- }
-
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
@@ -342,6 +385,42 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
return success();
}
+/// Helper function that extracts the source from a memref.cast. If the given
+/// value is not a memref.cast result, simply returns the given value.
+static Value unpackCast(Value v) {
+ auto castOp = v.getDefiningOp<memref::CastOp>();
+ if (!castOp)
+ return v;
+ return castOp.getSource();
+}
+
+/// Helper function that returns the return types (skipping casts) of the given
+/// func.return ops. This function returns as many types as the return ops have
+/// operands. If the i-th operand is not the same for all func.return ops, then
+/// the i-th returned type is an "empty" type.
+static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
+ assert(!returnOps.empty() && "expected at least one ReturnOp");
+ int numOperands = returnOps.front()->getNumOperands();
+
+ // Helper function that unpacks memref.cast ops and returns the type.
+ auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
+
+ SmallVector<Type> result;
+ for (int i = 0; i < numOperands; ++i) {
+ // Get the type of the i-th operand of the first func.return ops.
+ Type t = getSourceType(returnOps.front()->getOperand(i));
+
+ // Check if all other func.return ops have a matching operand type.
+ for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
+ if (getSourceType(returnOps[j]->getOperand(i)) != t)
+ t = Type();
+
+ result.push_back(t);
+ }
+
+ return result;
+}
+
/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if any)
@@ -350,21 +429,33 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
/// entire function body, a more concise memref type can potentially be used for
/// the return type of the function.
static void foldMemRefCasts(func::FuncOp funcOp) {
+ // There is nothing to do for bodiless ops.
if (funcOp.getBody().empty())
return;
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- SmallVector<Type> resultTypes;
+ // Compute the common result types of all return ops.
+ SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ SmallVector<Type> resultTypes = getReturnTypes(returnOps);
- for (OpOperand &operand : returnOp->getOpOperands()) {
- if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
- operand.set(castOp.getSource());
- resultTypes.push_back(castOp.getSource().getType());
- } else {
- resultTypes.push_back(operand.get().getType());
+ // Remove direct casts.
+ for (func::ReturnOp returnOp : returnOps) {
+ for (OpOperand &operand : returnOp->getOpOperands()) {
+ // Bail if no common result type was found.
+ if (resultTypes[operand.getOperandNumber()]) {
+ operand.set(unpackCast(operand.get()));
+ }
}
}
+ // Fill in the missing result types that were not the same among all
+ // func.return ops.
+ for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
+ if (resultTypes[i])
+ continue;
+ resultTypes[i] = funcOp.getFunctionType().getResult(i);
+ }
+
+ // Update the function type.
auto newFuncType = FunctionType::get(
funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
funcOp.setType(newFuncType);
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
index 42d9cc00d3ff5a..ab6de70dfb161b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
@@ -1348,3 +1348,48 @@ func.func @private_func_aliasing(%t: tensor<?xf32>) -> f32 {
%2 = tensor.extract %1[%c0] : tensor<6xf32>
return %2 : f32
}
+
+// -----
+
+// CHECK-ALIAS-SETS-LABEL: func @multiple_returns(
+func.func @multiple_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ return %t0 : tensor<5xf32>
+^bb2:
+ return %t1 : tensor<5xf32>
+}
+
+// CHECK-ALIAS-SETS: func @caller(
+// CHECK-ALIAS-SETS-SAME: %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
+func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) {
+ // Check that alias sets are computed correctly.
+ // CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_returns
+ // CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
+ // CHECK-ALIAS-SETS-SAME: __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]", "%[[t1]]"]]}
+ call @multiple_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
+ return
+}
+
+// -----
+
+// CHECK-ALIAS-SETS-LABEL: func @multiple_equivalent_returns(
+func.func @multiple_equivalent_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ return %t0 : tensor<5xf32>
+^bb2:
+ return %t0 : tensor<5xf32>
+}
+
+// CHECK-ALIAS-SETS: func @caller(
+// CHECK-ALIAS-SETS-SAME: %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "none"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
+func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+ // Check that equivalence sets are computed correctly.
+ // CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_equivalent_returns
+ // CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
+ // CHECK-ALIAS-SETS-SAME: __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]"]]}
+ %r = call @multiple_equivalent_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
+ // CHECK: {__equivalent_func_args__ = [1], __inplace_operands_attr__ = ["true"]} %[[result]] : tensor<5xf32>
+ return %r : tensor<5xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index 2829eafb7c1c59..f3da82cc0064d4 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -1,24 +1,5 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics
-// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
-func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
- -> (tensor<f32>, tensor<f32>)
-{
- cf.cond_br %cond1, ^bb1, ^bb2
-
- ^bb1:
- %T:2 = scf.if %cond2 -> (tensor<f32>, tensor<f32>) {
- scf.yield %t1, %t2 : tensor<f32>, tensor<f32>
- } else {
- scf.yield %t2, %t1 : tensor<f32>, tensor<f32>
- }
- return %T#0, %T#1 : tensor<f32>, tensor<f32>
- ^bb2:
- return %t2, %t1 : tensor<f32>, tensor<f32>
-}
-
-// -----
-
// expected-error @-3 {{expected callgraph to be free of circular dependencies}}
func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
@@ -160,7 +141,8 @@ func.func @regression_scf_while() {
// -----
-// expected-error @below{{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
+// expected-error @below{{could not infer buffer type of block argument}}
+// expected-error @below{{failed to bufferize op}}
func.func @func_multiple_yields(%t: tensor<5xf32>) -> tensor<5xf32> {
func.return %t : tensor<5xf32>
^bb1(%arg1 : tensor<5xf32>):
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index d31b43477beb9f..4f10ffea561aa8 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -722,3 +722,27 @@ func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
%0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
return %0 : memref<5xf32>
}
+
+// -----
+
+// The two func.return operands have different types after bufferization. Make
+// sure that memref.cast ops are inserted.
+
+// CHECK-LABEL: func @result_type_mismatch({{.*}}) -> memref<5xf32, strided<[?], offset: ?>>
+func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> {
+ // CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<10xf32>
+ %t = tensor.empty() : tensor<10xf32>
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ // CHECK: %[[m0:.*]] = memref.subview %[[alloc]][0] [5] [2] : memref<10xf32> to memref<5xf32, strided<[2]>>
+ // CHECK: %[[cast0:.*]] = memref.cast %[[m0]] : memref<5xf32, strided<[2]>> to memref<5xf32, strided<[?], offset: ?>>
+ %0 = tensor.extract_slice %t[0][5][2] : tensor<10xf32> to tensor<5xf32>
+ // CHECK: return %[[cast0]] : memref<5xf32, strided<[?], offset: ?>
+ return %0 : tensor<5xf32>
+^bb2:
+ // CHECK: %[[m1:.*]] = memref.subview %[[alloc]][2] [5] [1] : memref<10xf32> to memref<5xf32, strided<[1], offset: 2>>
+ // CHECK: %[[cast1:.*]] = memref.cast %[[m1]] : memref<5xf32, strided<[1], offset: 2>> to memref<5xf32, strided<[?], offset: ?>>
+ %1 = tensor.extract_slice %t[2][5][1] : tensor<10xf32> to tensor<5xf32>
+ // CHECK: return ...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesMultiple
Depends on #113999. Patch is 20.08 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114017.diff 5 Files Affected:
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 6e91d3b89a7c79..195b17fcf902a2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -41,18 +41,13 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
#endif // NDEBUG
}
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
- func::ReturnOp returnOp;
- for (Block &b : funcOp.getBody()) {
- if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
- if (returnOp)
- return nullptr;
- returnOp = candidateOp;
- }
- }
- return returnOp;
+/// Return all top-level func.return ops in the given function.
+static SmallVector<func::ReturnOp> getReturnOps(FuncOp funcOp) {
+ SmallVector<func::ReturnOp> result;
+ for (Block &b : funcOp.getBody())
+ if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
+ result.push_back(returnOp);
+ return result;
}
/// Return the index-th bufferized function argument type. This assumes that the
@@ -372,15 +367,6 @@ struct FuncOpInterface
getBufferType(op, value, options, invocationStack);
}
- LogicalResult verifyAnalysis(Operation *op,
- const AnalysisState &state) const {
- auto funcOp = cast<func::FuncOp>(op);
- // TODO: func.func with multiple returns are not supported.
- if (!getAssumedUniqueReturnOp(funcOp) && !funcOp.isExternal())
- return op->emitOpError("op without unique func.return is not supported");
- return success();
- }
-
/// Rewrite function bbArgs and return values into buffer form. This function
/// bufferizes the function signature and the ReturnOp. When the entire
/// function body has been bufferized, function return types can be switched
@@ -427,41 +413,38 @@ struct FuncOpInterface
return success();
}
- // TODO: Support functions with multiple returns.
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- assert(returnOp && "expected func with single return op");
- assert(returnOp->getNumOperands() == retTypes.size() &&
- "incorrect number of return values");
- Location loc = returnOp.getLoc();
-
// 1. Bufferize every block.
for (Block &block : funcOp.getBody())
if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
options)))
return failure();
- // 2. Bufferize all operands of the return op.
- SmallVector<Value> returnValues;
- for (auto [returnVal, bufferizedType] :
- llvm::zip_equal(returnOp->getOperands(), retTypes)) {
- auto tensorType = dyn_cast<TensorType>(returnVal.getType());
- rewriter.setInsertionPoint(returnOp);
-
- // If not a tensor type just forward it.
- if (!tensorType) {
- returnValues.push_back(returnVal);
- continue;
+ // 2. Bufferize the operands of the all return op.
+ for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
+ assert(returnOp->getNumOperands() == retTypes.size() &&
+ "incorrect number of return values");
+ SmallVector<Value> returnValues;
+ for (auto [returnVal, bufferizedType] :
+ llvm::zip_equal(returnOp->getOperands(), retTypes)) {
+ auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+ rewriter.setInsertionPoint(returnOp);
+
+ // If not a tensor type just forward it.
+ if (!tensorType) {
+ returnValues.push_back(returnVal);
+ continue;
+ }
+
+ // Note: If `inferFunctionResultLayout = true`, casts are later folded
+ // away.
+ Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+ returnOp.getLoc(), bufferizedType, returnVal);
+ returnValues.push_back(toMemrefOp);
}
- // Note: If `inferFunctionResultLayout = true`, casts are later folded
- // away.
- Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
- loc, bufferizedType, returnVal);
- returnValues.push_back(toMemrefOp);
+ returnOp.getOperandsMutable().assign(returnValues);
}
- returnOp.getOperandsMutable().assign(returnValues);
-
// 3. Set the new function type.
funcOp.setType(newFuncType);
return success();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 0a4072605c265f..e4635ebd78d8f8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -86,18 +86,13 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
return state.addExtension<FuncAnalysisState>();
}
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
- func::ReturnOp returnOp;
- for (Block &b : funcOp.getBody()) {
- if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
- if (returnOp)
- return nullptr;
- returnOp = candidateOp;
- }
- }
- return returnOp;
+/// Return all top-level func.return ops in the given function.
+static SmallVector<func::ReturnOp> getReturnOps(FuncOp funcOp) {
+ SmallVector<func::ReturnOp> result;
+ for (Block &b : funcOp.getBody())
+ if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
+ result.push_back(returnOp);
+ return result;
}
namespace {
@@ -146,24 +141,80 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
return success();
}
- // Support only single return-terminated block in the function.
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- assert(returnOp && "expected func with single return op");
-
- for (OpOperand &returnVal : returnOp->getOpOperands())
- if (isa<RankedTensorType>(returnVal.get().getType()))
- for (BlockArgument bbArg : funcOp.getArguments())
- if (isa<RankedTensorType>(bbArg.getType())) {
- int64_t returnIdx = returnVal.getOperandNumber();
- int64_t bbArgIdx = bbArg.getArgNumber();
- if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
- funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
- if (state.getOptions().testAnalysisOnly)
- annotateEquivalentReturnBbArg(returnVal, bbArg);
+ // Find all func.return ops.
+ SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ assert(!returnOps.empty() && "expected at least one ReturnOp");
+
+ // Build alias sets. Merge all aliases from all func.return ops.
+ for (BlockArgument bbArg : funcOp.getArguments()) {
+ if (isa<RankedTensorType>(bbArg.getType())) {
+ int64_t bbArgIdx = bbArg.getArgNumber();
+ // Store aliases in a set, so that we don't add the same alias twice.
+ SetVector<int64_t> aliases;
+ for (func::ReturnOp returnOp : returnOps) {
+ for (OpOperand &returnVal : returnOp->getOpOperands()) {
+ if (isa<RankedTensorType>(returnVal.get().getType())) {
+ int64_t returnIdx = returnVal.getOperandNumber();
+ if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
+ aliases.insert(returnIdx);
}
- if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
- funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
}
+ }
+ for (int64_t alias : aliases)
+ funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
+ }
+ }
+
+ // Build equivalence sets.
+ // Helper function that finds an equivalent block argument index for the
+ // given OpOperand. Return std::nullopt if no equivalent block argument could
+ // be found.
+ auto findEquivalentBlockArgIdx =
+ [&](OpOperand &opOperand) -> std::optional<int64_t> {
+ Value v = opOperand.get();
+ if (!isa<TensorType>(v.getType()))
+ return std::nullopt;
+ for (BlockArgument bbArg : funcOp.getArguments()) {
+ if (isa<RankedTensorType>(bbArg.getType())) {
+ if (state.areEquivalentBufferizedValues(v, bbArg)) {
+ if (state.getOptions().testAnalysisOnly)
+ annotateEquivalentReturnBbArg(opOperand, bbArg);
+ return bbArg.getArgNumber();
+ }
+ }
+ }
+ return std::nullopt;
+ };
+
+ int64_t numResults = returnOps.front()->getNumOperands();
+ for (int64_t i = 0; i < numResults; ++i) {
+ // Find the equivalent block argument index for the i-th operand of the
+ // first func.return op.
+ std::optional<int64_t> maybeEquiv =
+ findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
+ if (!maybeEquiv.has_value())
+ continue;
+ int64_t bbArgIdx = *maybeEquiv;
+ bool allEquiv = true;
+
+ // Check if all other func.return ops have the same equivalent block
+ // argument for the i-th operand. In contrast to aliasing information,
+ // which is just "merged", equivalence information must match across all
+ // func.return ops.
+ for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
+ std::optional<int64_t> maybeEquiv =
+ findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
+ if (maybeEquiv != bbArgIdx) {
+ allEquiv = false;
+ break;
+ }
+ }
+
+ // All func.return ops have the same equivalent block argument for the i-th
+ // operand.
+ if (allEquiv)
+ funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
+ }
return success();
}
@@ -299,14 +350,6 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
- if (!funcOp.getBody().empty()) {
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- if (!returnOp)
- return funcOp->emitError()
- << "cannot bufferize a FuncOp with tensors and "
- "without a unique ReturnOp";
- }
-
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
@@ -342,6 +385,42 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
return success();
}
+/// Helper function that extracts the source from a memref.cast. If the given
+/// value is not a memref.cast result, simply returns the given value.
+static Value unpackCast(Value v) {
+ auto castOp = v.getDefiningOp<memref::CastOp>();
+ if (!castOp)
+ return v;
+ return castOp.getSource();
+}
+
+/// Helper function that returns the return types (skipping casts) of the given
+/// func.return ops. This function returns as many types as the return ops have
+/// operands. If the i-th operand is not the same for all func.return ops, then
+/// the i-th returned type is an "empty" type.
+static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
+ assert(!returnOps.empty() && "expected at least one ReturnOp");
+ int numOperands = returnOps.front()->getNumOperands();
+
+ // Helper function that unpacks memref.cast ops and returns the type.
+ auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
+
+ SmallVector<Type> result;
+ for (int i = 0; i < numOperands; ++i) {
+ // Get the type of the i-th operand of the first func.return ops.
+ Type t = getSourceType(returnOps.front()->getOperand(i));
+
+ // Check if all other func.return ops have a matching operand type.
+ for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
+ if (getSourceType(returnOps[j]->getOperand(i)) != t)
+ t = Type();
+
+ result.push_back(t);
+ }
+
+ return result;
+}
+
/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if any)
@@ -350,21 +429,33 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
/// entire function body, a more concise memref type can potentially be used for
/// the return type of the function.
static void foldMemRefCasts(func::FuncOp funcOp) {
+ // There is nothing to do for bodiless ops.
if (funcOp.getBody().empty())
return;
- func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
- SmallVector<Type> resultTypes;
+ // Compute the common result types of all return ops.
+ SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+ SmallVector<Type> resultTypes = getReturnTypes(returnOps);
- for (OpOperand &operand : returnOp->getOpOperands()) {
- if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
- operand.set(castOp.getSource());
- resultTypes.push_back(castOp.getSource().getType());
- } else {
- resultTypes.push_back(operand.get().getType());
+ // Remove direct casts.
+ for (func::ReturnOp returnOp : returnOps) {
+ for (OpOperand &operand : returnOp->getOpOperands()) {
+ // Bail if no common result type was found.
+ if (resultTypes[operand.getOperandNumber()]) {
+ operand.set(unpackCast(operand.get()));
+ }
}
}
+ // Fill in the missing result types that were not the same among all
+ // func.return ops.
+ for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
+ if (resultTypes[i])
+ continue;
+ resultTypes[i] = funcOp.getFunctionType().getResult(i);
+ }
+
+ // Update the function type.
auto newFuncType = FunctionType::get(
funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
funcOp.setType(newFuncType);
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
index 42d9cc00d3ff5a..ab6de70dfb161b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
@@ -1348,3 +1348,48 @@ func.func @private_func_aliasing(%t: tensor<?xf32>) -> f32 {
%2 = tensor.extract %1[%c0] : tensor<6xf32>
return %2 : f32
}
+
+// -----
+
+// CHECK-ALIAS-SETS-LABEL: func @multiple_returns(
+func.func @multiple_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ return %t0 : tensor<5xf32>
+^bb2:
+ return %t1 : tensor<5xf32>
+}
+
+// CHECK-ALIAS-SETS: func @caller(
+// CHECK-ALIAS-SETS-SAME: %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
+func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) {
+ // Check that alias sets are computed correctly.
+ // CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_returns
+ // CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
+ // CHECK-ALIAS-SETS-SAME: __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]", "%[[t1]]"]]}
+ call @multiple_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
+ return
+}
+
+// -----
+
+// CHECK-ALIAS-SETS-LABEL: func @multiple_equivalent_returns(
+func.func @multiple_equivalent_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ return %t0 : tensor<5xf32>
+^bb2:
+ return %t0 : tensor<5xf32>
+}
+
+// CHECK-ALIAS-SETS: func @caller(
+// CHECK-ALIAS-SETS-SAME: %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "none"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
+func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+ // Check that equivalence sets are computed correctly.
+ // CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_equivalent_returns
+ // CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
+ // CHECK-ALIAS-SETS-SAME: __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]"]]}
+ %r = call @multiple_equivalent_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
+ // CHECK: {__equivalent_func_args__ = [1], __inplace_operands_attr__ = ["true"]} %[[result]] : tensor<5xf32>
+ return %r : tensor<5xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index 2829eafb7c1c59..f3da82cc0064d4 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -1,24 +1,5 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics
-// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
-func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
- -> (tensor<f32>, tensor<f32>)
-{
- cf.cond_br %cond1, ^bb1, ^bb2
-
- ^bb1:
- %T:2 = scf.if %cond2 -> (tensor<f32>, tensor<f32>) {
- scf.yield %t1, %t2 : tensor<f32>, tensor<f32>
- } else {
- scf.yield %t2, %t1 : tensor<f32>, tensor<f32>
- }
- return %T#0, %T#1 : tensor<f32>, tensor<f32>
- ^bb2:
- return %t2, %t1 : tensor<f32>, tensor<f32>
-}
-
-// -----
-
// expected-error @-3 {{expected callgraph to be free of circular dependencies}}
func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
@@ -160,7 +141,8 @@ func.func @regression_scf_while() {
// -----
-// expected-error @below{{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
+// expected-error @below{{could not infer buffer type of block argument}}
+// expected-error @below{{failed to bufferize op}}
func.func @func_multiple_yields(%t: tensor<5xf32>) -> tensor<5xf32> {
func.return %t : tensor<5xf32>
^bb1(%arg1 : tensor<5xf32>):
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index d31b43477beb9f..4f10ffea561aa8 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -722,3 +722,27 @@ func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
%0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
return %0 : memref<5xf32>
}
+
+// -----
+
+// The two func.return operands have different types after bufferization. Make
+// sure that memref.cast ops are inserted.
+
+// CHECK-LABEL: func @result_type_mismatch({{.*}}) -> memref<5xf32, strided<[?], offset: ?>>
+func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> {
+ // CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<10xf32>
+ %t = tensor.empty() : tensor<10xf32>
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ // CHECK: %[[m0:.*]] = memref.subview %[[alloc]][0] [5] [2] : memref<10xf32> to memref<5xf32, strided<[2]>>
+ // CHECK: %[[cast0:.*]] = memref.cast %[[m0]] : memref<5xf32, strided<[2]>> to memref<5xf32, strided<[?], offset: ?>>
+ %0 = tensor.extract_slice %t[0][5][2] : tensor<10xf32> to tensor<5xf32>
+ // CHECK: return %[[cast0]] : memref<5xf32, strided<[?], offset: ?>
+ return %0 : tensor<5xf32>
+^bb2:
+ // CHECK: %[[m1:.*]] = memref.subview %[[alloc]][2] [5] [1] : memref<10xf32> to memref<5xf32, strided<[1], offset: 2>>
+ // CHECK: %[[cast1:.*]] = memref.cast %[[m1]] : memref<5xf32, strided<[1], offset: 2>> to memref<5xf32, strided<[?], offset: ?>>
+ %1 = tensor.extract_slice %t[2][5][1] : tensor<10xf32> to tensor<5xf32>
+ // CHECK: return ...
[truncated]
|
3522311
to
684ac4a
Compare
40c5042
to
3f7f8a7
Compare
%1 = tensor.extract_slice %t[2][5][1] : tensor<10xf32> to tensor<5xf32> | ||
// CHECK: return %[[cast1]] : memref<5xf32, strided<[?], offset: ?>> | ||
return %1 : tensor<5xf32> | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very interesting test. thanks.
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
Outdated
Show resolved
Hide resolved
3f7f8a7
to
52d1a20
Compare
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
Outdated
Show resolved
Hide resolved
52d1a20
to
1122ffe
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM . Thanks.
Multiple
func.return
ops inside of afunc.func
op are now supported during bufferization. This PR extends the code base in 3 places:memref.cast
ops are folded away only if allfunc.return
ops have matching buffer types. (E.g., we don't fold if tworeturn
ops have operands with different layout maps.)func.return
ops are merged. That's because aliasing is a "may be" property.func.return
ops are taken only if they match. If differentfunc.return
ops have different equivalence sets for their operands, the equivalence information is dropped. That's because equivalence is a "must be" property.This commit is in preparation of removing the deprecated
func-bufferize
pass. That pass can bufferize functions with multiplereturn
ops.