Skip to content

Commit 3f7f8a7

Browse files
[mlir][bufferization] Add support for non-unique func.return
1 parent c271ba7 commit 3f7f8a7

File tree

5 files changed

+237
-110
lines changed

5 files changed

+237
-110
lines changed

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 29 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,13 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
4141
#endif // NDEBUG
4242
}
4343

44-
/// Return the unique ReturnOp that terminates `funcOp`.
45-
/// Return nullptr if there is no such unique ReturnOp.
46-
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
47-
func::ReturnOp returnOp;
48-
for (Block &b : funcOp.getBody()) {
49-
if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
50-
if (returnOp)
51-
return nullptr;
52-
returnOp = candidateOp;
53-
}
54-
}
55-
return returnOp;
44+
/// Return all top-level func.return ops in the given function.
45+
static SmallVector<func::ReturnOp> getReturnOps(FuncOp funcOp) {
46+
SmallVector<func::ReturnOp> result;
47+
for (Block &b : funcOp.getBody())
48+
if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
49+
result.push_back(returnOp);
50+
return result;
5651
}
5752

5853
/// Return the index-th bufferized function argument type. This assumes that the
@@ -391,15 +386,6 @@ struct FuncOpInterface
391386
getBufferType(op, value, options, invocationStack);
392387
}
393388

394-
LogicalResult verifyAnalysis(Operation *op,
395-
const AnalysisState &state) const {
396-
auto funcOp = cast<func::FuncOp>(op);
397-
// TODO: func.func with multiple returns are not supported.
398-
if (!getAssumedUniqueReturnOp(funcOp) && !funcOp.isExternal())
399-
return op->emitOpError("op without unique func.return is not supported");
400-
return success();
401-
}
402-
403389
/// Rewrite function bbArgs and return values into buffer form. This function
404390
/// bufferizes the function signature and the ReturnOp. When the entire
405391
/// function body has been bufferized, function return types can be switched
@@ -446,41 +432,38 @@ struct FuncOpInterface
446432
return success();
447433
}
448434

449-
// TODO: Support functions with multiple returns.
450-
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
451-
assert(returnOp && "expected func with single return op");
452-
assert(returnOp->getNumOperands() == retTypes.size() &&
453-
"incorrect number of return values");
454-
Location loc = returnOp.getLoc();
455-
456435
// 1. Bufferize every block.
457436
for (Block &block : funcOp.getBody())
458437
if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
459438
options)))
460439
return failure();
461440

462-
// 2. Bufferize all operands of the return op.
463-
SmallVector<Value> returnValues;
464-
for (auto [returnVal, bufferizedType] :
465-
llvm::zip_equal(returnOp->getOperands(), retTypes)) {
466-
auto tensorType = dyn_cast<TensorType>(returnVal.getType());
467-
rewriter.setInsertionPoint(returnOp);
468-
469-
// If not a tensor type just forward it.
470-
if (!tensorType) {
471-
returnValues.push_back(returnVal);
472-
continue;
441+
// 2. Bufferize the operands of all return ops.
442+
for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
443+
assert(returnOp->getNumOperands() == retTypes.size() &&
444+
"incorrect number of return values");
445+
SmallVector<Value> returnValues;
446+
for (auto [returnVal, bufferizedType] :
447+
llvm::zip_equal(returnOp->getOperands(), retTypes)) {
448+
auto tensorType = dyn_cast<TensorType>(returnVal.getType());
449+
rewriter.setInsertionPoint(returnOp);
450+
451+
// If not a tensor type just forward it.
452+
if (!tensorType) {
453+
returnValues.push_back(returnVal);
454+
continue;
455+
}
456+
457+
// Note: If `inferFunctionResultLayout = true`, casts are later folded
458+
// away.
459+
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
460+
returnOp.getLoc(), bufferizedType, returnVal);
461+
returnValues.push_back(toMemrefOp);
473462
}
474463

475-
// Note: If `inferFunctionResultLayout = true`, casts are later folded
476-
// away.
477-
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
478-
loc, bufferizedType, returnVal);
479-
returnValues.push_back(toMemrefOp);
464+
returnOp.getOperandsMutable().assign(returnValues);
480465
}
481466

482-
returnOp.getOperandsMutable().assign(returnValues);
483-
484467
// 3. Set the new function type.
485468
funcOp.setType(newFuncType);
486469
return success();

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 135 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,13 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
8686
return state.addExtension<FuncAnalysisState>();
8787
}
8888

89-
/// Return the unique ReturnOp that terminates `funcOp`.
90-
/// Return nullptr if there is no such unique ReturnOp.
91-
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
92-
func::ReturnOp returnOp;
93-
for (Block &b : funcOp.getBody()) {
94-
if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
95-
if (returnOp)
96-
return nullptr;
97-
returnOp = candidateOp;
98-
}
99-
}
100-
return returnOp;
89+
/// Return all top-level func.return ops in the given function.
90+
static SmallVector<func::ReturnOp> getReturnOps(FuncOp funcOp) {
91+
SmallVector<func::ReturnOp> result;
92+
for (Block &b : funcOp.getBody())
93+
if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
94+
result.push_back(returnOp);
95+
return result;
10196
}
10297

10398
namespace {
@@ -146,24 +141,80 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
146141
return success();
147142
}
148143

149-
// Support only single return-terminated block in the function.
150-
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
151-
assert(returnOp && "expected func with single return op");
152-
153-
for (OpOperand &returnVal : returnOp->getOpOperands())
154-
if (isa<RankedTensorType>(returnVal.get().getType()))
155-
for (BlockArgument bbArg : funcOp.getArguments())
156-
if (isa<RankedTensorType>(bbArg.getType())) {
157-
int64_t returnIdx = returnVal.getOperandNumber();
158-
int64_t bbArgIdx = bbArg.getArgNumber();
159-
if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
160-
funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
161-
if (state.getOptions().testAnalysisOnly)
162-
annotateEquivalentReturnBbArg(returnVal, bbArg);
144+
// Find all func.return ops.
145+
SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
146+
assert(!returnOps.empty() && "expected at least one ReturnOp");
147+
148+
// Build alias sets. Merge all aliases from all func.return ops.
149+
for (BlockArgument bbArg : funcOp.getArguments()) {
150+
if (isa<RankedTensorType>(bbArg.getType())) {
151+
int64_t bbArgIdx = bbArg.getArgNumber();
152+
// Store aliases in a set, so that we don't add the same alias twice.
153+
SetVector<int64_t> aliases;
154+
for (func::ReturnOp returnOp : returnOps) {
155+
for (OpOperand &returnVal : returnOp->getOpOperands()) {
156+
if (isa<RankedTensorType>(returnVal.get().getType())) {
157+
int64_t returnIdx = returnVal.getOperandNumber();
158+
if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
159+
aliases.insert(returnIdx);
163160
}
164-
if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
165-
funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
166161
}
162+
}
163+
for (int64_t alias : aliases)
164+
funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
165+
}
166+
}
167+
168+
// Build equivalence sets.
169+
// Helper function that finds an equivalent block argument index for the
170+
// given OpOperand. Return std::nullopt if no equivalent block argument could
171+
// be found.
172+
auto findEquivalentBlockArgIdx =
173+
[&](OpOperand &opOperand) -> std::optional<int64_t> {
174+
Value v = opOperand.get();
175+
if (!isa<TensorType>(v.getType()))
176+
return std::nullopt;
177+
for (BlockArgument bbArg : funcOp.getArguments()) {
178+
if (isa<RankedTensorType>(bbArg.getType())) {
179+
if (state.areEquivalentBufferizedValues(v, bbArg)) {
180+
if (state.getOptions().testAnalysisOnly)
181+
annotateEquivalentReturnBbArg(opOperand, bbArg);
182+
return bbArg.getArgNumber();
183+
}
184+
}
185+
}
186+
return std::nullopt;
187+
};
188+
189+
int64_t numResults = returnOps.front()->getNumOperands();
190+
for (int64_t i = 0; i < numResults; ++i) {
191+
// Find the equivalent block argument index for the i-th operand of the
192+
// first func.return op.
193+
std::optional<int64_t> maybeEquiv =
194+
findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
195+
if (!maybeEquiv.has_value())
196+
continue;
197+
int64_t bbArgIdx = *maybeEquiv;
198+
bool allEquiv = true;
199+
200+
// Check if all other func.return ops have the same equivalent block
201+
// argument for the i-th operand. In contrast to aliasing information,
202+
// which is just "merged", equivalence information must match across all
203+
// func.return ops.
204+
for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
205+
std::optional<int64_t> maybeEquiv =
206+
findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
207+
if (maybeEquiv != bbArgIdx) {
208+
allEquiv = false;
209+
break;
210+
}
211+
}
212+
213+
// All func.return ops have the same equivalent block argument for the i-th
214+
// operand.
215+
if (allEquiv)
216+
funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
217+
}
167218

168219
return success();
169220
}
@@ -302,14 +353,6 @@ static LogicalResult getFuncOpsOrderedByCalls(
302353
// For each FuncOp, the number of func::CallOp it contains.
303354
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
304355
WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
305-
if (!funcOp.getBody().empty()) {
306-
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
307-
if (!returnOp)
308-
return funcOp->emitError()
309-
<< "cannot bufferize a FuncOp with tensors and "
310-
"without a unique ReturnOp";
311-
}
312-
313356
// Collect function calls and populate the caller map.
314357
numberCallOpsContainedInFuncOp[funcOp] = 0;
315358
return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
@@ -351,6 +394,42 @@ static LogicalResult getFuncOpsOrderedByCalls(
351394
return success();
352395
}
353396

397+
/// Helper function that extracts the source from a memref.cast. If the given
398+
/// value is not a memref.cast result, simply returns the given value.
399+
static Value unpackCast(Value v) {
400+
auto castOp = v.getDefiningOp<memref::CastOp>();
401+
if (!castOp)
402+
return v;
403+
return castOp.getSource();
404+
}
405+
406+
/// Helper function that returns the return types (skipping casts) of the given
407+
/// func.return ops. This function returns as many types as the return ops have
408+
/// operands. If the i-th operand is not the same for all func.return ops, then
409+
/// the i-th returned type is an "empty" type.
410+
static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
411+
assert(!returnOps.empty() && "expected at least one ReturnOp");
412+
int numOperands = returnOps.front()->getNumOperands();
413+
414+
// Helper function that unpacks memref.cast ops and returns the type.
415+
auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
416+
417+
SmallVector<Type> result;
418+
for (int i = 0; i < numOperands; ++i) {
419+
// Get the type of the i-th operand of the first func.return op.
420+
Type t = getSourceType(returnOps.front()->getOperand(i));
421+
422+
// Check if all other func.return ops have a matching operand type.
423+
for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
424+
if (getSourceType(returnOps[j]->getOperand(i)) != t)
425+
t = Type();
426+
427+
result.push_back(t);
428+
}
429+
430+
return result;
431+
}
432+
354433
/// Fold return values that are memref casts and update function return types.
355434
///
356435
/// During FuncOp bufferization, the exact type of the returned memrefs (if any)
@@ -359,21 +438,33 @@ static LogicalResult getFuncOpsOrderedByCalls(
359438
/// entire function body, a more concise memref type can potentially be used for
360439
/// the return type of the function.
361440
static void foldMemRefCasts(func::FuncOp funcOp) {
441+
// There is nothing to do for bodiless ops.
362442
if (funcOp.getBody().empty())
363443
return;
364444

365-
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
366-
SmallVector<Type> resultTypes;
445+
// Compute the common result types of all return ops.
446+
SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
447+
SmallVector<Type> resultTypes = getReturnTypes(returnOps);
367448

368-
for (OpOperand &operand : returnOp->getOpOperands()) {
369-
if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
370-
operand.set(castOp.getSource());
371-
resultTypes.push_back(castOp.getSource().getType());
372-
} else {
373-
resultTypes.push_back(operand.get().getType());
449+
// Remove direct casts.
450+
for (func::ReturnOp returnOp : returnOps) {
451+
for (OpOperand &operand : returnOp->getOpOperands()) {
452+
// Bail if no common result type was found.
453+
if (resultTypes[operand.getOperandNumber()]) {
454+
operand.set(unpackCast(operand.get()));
455+
}
374456
}
375457
}
376458

459+
// Fill in the missing result types that were not the same among all
460+
// func.return ops.
461+
for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
462+
if (resultTypes[i])
463+
continue;
464+
resultTypes[i] = funcOp.getFunctionType().getResult(i);
465+
}
466+
467+
// Update the function type.
377468
auto newFuncType = FunctionType::get(
378469
funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
379470
funcOp.setType(newFuncType);

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,3 +1360,49 @@ func.func @recursive_function(%a: tensor<?xf32>, %b: tensor<?xf32>) -> (tensor<?
13601360
%0:2 = call @recursive_function(%a, %b) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>)
13611361
return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
13621362
}
1363+
1364+
// -----
1365+
1366+
// CHECK-ALIAS-SETS-LABEL: func @multiple_returns(
1367+
func.func @multiple_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
1368+
cf.cond_br %c, ^bb1, ^bb2
1369+
^bb1:
1370+
return %t0 : tensor<5xf32>
1371+
^bb2:
1372+
return %t1 : tensor<5xf32>
1373+
}
1374+
1375+
// CHECK-ALIAS-SETS: func @caller(
1376+
// CHECK-ALIAS-SETS-SAME: %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
1377+
func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) {
1378+
// Check that alias sets are computed correctly.
1379+
// CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_returns
1380+
// CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
1381+
// CHECK-ALIAS-SETS-SAME: __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]", "%[[t1]]"]]}
1382+
call @multiple_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
1383+
return
1384+
}
1385+
1386+
// -----
1387+
1388+
// CHECK-ALIAS-SETS-LABEL: func @multiple_equivalent_returns(
1389+
func.func @multiple_equivalent_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
1390+
cf.cond_br %c, ^bb1, ^bb2
1391+
^bb1:
1392+
return %t0 : tensor<5xf32>
1393+
^bb2:
1394+
return %t0 : tensor<5xf32>
1395+
}
1396+
1397+
// CHECK-ALIAS-SETS: func @caller(
1398+
// CHECK-ALIAS-SETS-SAME: %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "none"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
1399+
func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
1400+
// Check that equivalence sets are computed correctly.
1401+
// CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_equivalent_returns
1402+
// CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
1403+
// CHECK-ALIAS-SETS-SAME: __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]"]]}
1404+
%r = call @multiple_equivalent_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
1405+
// CHECK-ALIAS-SETS-SAME: {__equivalent_func_args__ = [1], __inplace_operands_attr__ = ["true"]} %[[result]] : tensor<5xf32>
1406+
return %r : tensor<5xf32>
1407+
}
1408+

0 commit comments

Comments
 (0)