
Commit 684ac4a

[mlir][bufferization] Add support for non-unique func.return
1 parent 11089cc commit 684ac4a
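Before this change, One-Shot Bufferize rejected any function whose body did not end in a single, unique func.return. With this commit, functions whose bodies contain several blocks, each terminated by its own func.return, can be bufferized: aliasing information from the individual returns is merged, while equivalence information is only recorded when it agrees across all returns. A minimal sketch of the newly supported shape, adapted from the test added in this commit:

func.func @multiple_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>) -> tensor<5xf32> {
  cf.cond_br %c, ^bb1, ^bb2
^bb1:
  return %t0 : tensor<5xf32>
^bb2:
  return %t1 : tensor<5xf32>
}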

5 files changed, +235 -110 lines changed

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 29 additions & 46 deletions
@@ -41,18 +41,13 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
 #endif // NDEBUG
 }
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
-  func::ReturnOp returnOp;
-  for (Block &b : funcOp.getBody()) {
-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
+/// Return all top-level func.return ops in the given function.
+static SmallVector<func::ReturnOp> getReturnOps(FuncOp funcOp) {
+  SmallVector<func::ReturnOp> result;
+  for (Block &b : funcOp.getBody())
+    if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
+      result.push_back(returnOp);
+  return result;
 }
 
 /// Return the index-th bufferized function argument type. This assumes that the
@@ -372,15 +367,6 @@ struct FuncOpInterface
         getBufferType(op, value, options, invocationStack);
   }
 
-  LogicalResult verifyAnalysis(Operation *op,
-                               const AnalysisState &state) const {
-    auto funcOp = cast<func::FuncOp>(op);
-    // TODO: func.func with multiple returns are not supported.
-    if (!getAssumedUniqueReturnOp(funcOp) && !funcOp.isExternal())
-      return op->emitOpError("op without unique func.return is not supported");
-    return success();
-  }
-
   /// Rewrite function bbArgs and return values into buffer form. This function
   /// bufferizes the function signature and the ReturnOp. When the entire
   /// function body has been bufferized, function return types can be switched
@@ -427,41 +413,38 @@ struct FuncOpInterface
       return success();
    }
 
-    // TODO: Support functions with multiple returns.
-    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-    assert(returnOp && "expected func with single return op");
-    assert(returnOp->getNumOperands() == retTypes.size() &&
-           "incorrect number of return values");
-    Location loc = returnOp.getLoc();
-
    // 1. Bufferize every block.
    for (Block &block : funcOp.getBody())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options)))
        return failure();
 
-    // 2. Bufferize all operands of the return op.
-    SmallVector<Value> returnValues;
-    for (auto [returnVal, bufferizedType] :
-         llvm::zip_equal(returnOp->getOperands(), retTypes)) {
-      auto tensorType = dyn_cast<TensorType>(returnVal.getType());
-      rewriter.setInsertionPoint(returnOp);
-
-      // If not a tensor type just forward it.
-      if (!tensorType) {
-        returnValues.push_back(returnVal);
-        continue;
-      }
-
-      // Note: If `inferFunctionResultLayout = true`, casts are later folded
-      // away.
-      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-          loc, bufferizedType, returnVal);
-      returnValues.push_back(toMemrefOp);
-    }
-
-    returnOp.getOperandsMutable().assign(returnValues);
-
+    // 2. Bufferize the operands of all return ops.
+    for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
+      assert(returnOp->getNumOperands() == retTypes.size() &&
+             "incorrect number of return values");
+      SmallVector<Value> returnValues;
+      for (auto [returnVal, bufferizedType] :
+           llvm::zip_equal(returnOp->getOperands(), retTypes)) {
+        auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+        rewriter.setInsertionPoint(returnOp);
+
+        // If not a tensor type just forward it.
+        if (!tensorType) {
+          returnValues.push_back(returnVal);
+          continue;
+        }
+
+        // Note: If `inferFunctionResultLayout = true`, casts are later folded
+        // away.
+        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+            returnOp.getLoc(), bufferizedType, returnVal);
+        returnValues.push_back(toMemrefOp);
+      }
+
+      returnOp.getOperandsMutable().assign(returnValues);
+    }
+
    // 3. Set the new function type.
    funcOp.setType(newFuncType);
    return success();

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 135 additions & 44 deletions
@@ -86,18 +86,13 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
   return state.addExtension<FuncAnalysisState>();
 }
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
-  func::ReturnOp returnOp;
-  for (Block &b : funcOp.getBody()) {
-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
+/// Return all top-level func.return ops in the given function.
+static SmallVector<func::ReturnOp> getReturnOps(FuncOp funcOp) {
+  SmallVector<func::ReturnOp> result;
+  for (Block &b : funcOp.getBody())
+    if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
+      result.push_back(returnOp);
+  return result;
 }
 
 namespace {
@@ -146,24 +141,80 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
     return success();
   }
 
-  // Support only single return-terminated block in the function.
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  assert(returnOp && "expected func with single return op");
-
-  for (OpOperand &returnVal : returnOp->getOpOperands())
-    if (isa<RankedTensorType>(returnVal.get().getType()))
-      for (BlockArgument bbArg : funcOp.getArguments())
-        if (isa<RankedTensorType>(bbArg.getType())) {
-          int64_t returnIdx = returnVal.getOperandNumber();
-          int64_t bbArgIdx = bbArg.getArgNumber();
-          if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
-            funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
-            if (state.getOptions().testAnalysisOnly)
-              annotateEquivalentReturnBbArg(returnVal, bbArg);
-          }
-          if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
-            funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
-        }
+  // Find all func.return ops.
+  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  assert(!returnOps.empty() && "expected at least one ReturnOp");
+
+  // Build alias sets. Merge all aliases from all func.return ops.
+  for (BlockArgument bbArg : funcOp.getArguments()) {
+    if (isa<RankedTensorType>(bbArg.getType())) {
+      int64_t bbArgIdx = bbArg.getArgNumber();
+      // Store aliases in a set, so that we don't add the same alias twice.
+      SetVector<int64_t> aliases;
+      for (func::ReturnOp returnOp : returnOps) {
+        for (OpOperand &returnVal : returnOp->getOpOperands()) {
+          if (isa<RankedTensorType>(returnVal.get().getType())) {
+            int64_t returnIdx = returnVal.getOperandNumber();
+            if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
+              aliases.insert(returnIdx);
+          }
+        }
+      }
+      for (int64_t alias : aliases)
+        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
+    }
+  }
+
+  // Build equivalence sets.
+  // Helper function that finds an equivalent block argument index for the
+  // given OpOperand. Return std::nullopt if no equivalent block argument could
+  // be found.
+  auto findEquivalentBlockArgIdx =
+      [&](OpOperand &opOperand) -> std::optional<int64_t> {
+    Value v = opOperand.get();
+    if (!isa<TensorType>(v.getType()))
+      return std::nullopt;
+    for (BlockArgument bbArg : funcOp.getArguments()) {
+      if (isa<RankedTensorType>(bbArg.getType())) {
+        if (state.areEquivalentBufferizedValues(v, bbArg)) {
+          if (state.getOptions().testAnalysisOnly)
+            annotateEquivalentReturnBbArg(opOperand, bbArg);
+          return bbArg.getArgNumber();
+        }
+      }
+    }
+    return std::nullopt;
+  };
+
+  int64_t numResults = returnOps.front()->getNumOperands();
+  for (int64_t i = 0; i < numResults; ++i) {
+    // Find the equivalent block argument index for the i-th operand of the
+    // first func.return op.
+    std::optional<int64_t> maybeEquiv =
+        findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
+    if (!maybeEquiv.has_value())
+      continue;
+    int64_t bbArgIdx = *maybeEquiv;
+    bool allEquiv = true;
+
+    // Check if all other func.return ops have the same equivalent block
+    // argument for the i-th operand. In contrast to aliasing information,
+    // which is just "merged", equivalence information must match across all
+    // func.return ops.
+    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
+      std::optional<int64_t> maybeEquiv =
+          findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
+      if (maybeEquiv != bbArgIdx) {
+        allEquiv = false;
+        break;
+      }
+    }
+
+    // All func.return ops have the same equivalent block argument for the i-th
+    // operand.
+    if (allEquiv)
+      funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
+  }
 
   return success();
 }
@@ -299,14 +350,6 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
   WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
-    if (!funcOp.getBody().empty()) {
-      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-      if (!returnOp)
-        return funcOp->emitError()
-               << "cannot bufferize a FuncOp with tensors and "
-                  "without a unique ReturnOp";
-    }
-
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
     return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
@@ -342,6 +385,42 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   return success();
 }
 
+/// Helper function that extracts the source from a memref.cast. If the given
+/// value is not a memref.cast result, simply returns the given value.
+static Value unpackCast(Value v) {
+  auto castOp = v.getDefiningOp<memref::CastOp>();
+  if (!castOp)
+    return v;
+  return castOp.getSource();
+}
+
+/// Helper function that returns the return types (skipping casts) of the given
+/// func.return ops. This function returns as many types as the return ops have
+/// operands. If the i-th operand is not the same for all func.return ops, then
+/// the i-th returned type is an "empty" type.
+static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
+  assert(!returnOps.empty() && "expected at least one ReturnOp");
+  int numOperands = returnOps.front()->getNumOperands();
+
+  // Helper function that unpacks memref.cast ops and returns the type.
+  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
+
+  SmallVector<Type> result;
+  for (int i = 0; i < numOperands; ++i) {
+    // Get the type of the i-th operand of the first func.return op.
+    Type t = getSourceType(returnOps.front()->getOperand(i));
+
+    // Check if all other func.return ops have a matching operand type.
+    for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
+      if (getSourceType(returnOps[j]->getOperand(i)) != t)
+        t = Type();
+
+    result.push_back(t);
+  }
+
+  return result;
+}
+
 /// Fold return values that are memref casts and update function return types.
 ///
 /// During FuncOp bufferization, the exact type of the returned memrefs (if any)
@@ -350,21 +429,33 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
 /// entire function body, a more concise memref type can potentially be used for
 /// the return type of the function.
 static void foldMemRefCasts(func::FuncOp funcOp) {
+  // There is nothing to do for bodiless ops.
   if (funcOp.getBody().empty())
     return;
 
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  SmallVector<Type> resultTypes;
+  // Compute the common result types of all return ops.
+  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  SmallVector<Type> resultTypes = getReturnTypes(returnOps);
 
-  for (OpOperand &operand : returnOp->getOpOperands()) {
-    if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
-      operand.set(castOp.getSource());
-      resultTypes.push_back(castOp.getSource().getType());
-    } else {
-      resultTypes.push_back(operand.get().getType());
+  // Remove direct casts.
+  for (func::ReturnOp returnOp : returnOps) {
+    for (OpOperand &operand : returnOp->getOpOperands()) {
+      // Bail if no common result type was found.
+      if (resultTypes[operand.getOperandNumber()]) {
+        operand.set(unpackCast(operand.get()));
+      }
     }
   }
 
+  // Fill in the missing result types that were not the same among all
+  // func.return ops.
+  for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
+    if (resultTypes[i])
+      continue;
+    resultTypes[i] = funcOp.getFunctionType().getResult(i);
+  }
+
+  // Update the function type.
   auto newFuncType = FunctionType::get(
       funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
   funcOp.setType(newFuncType);
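The unpackCast/getReturnTypes helpers above let foldMemRefCasts handle several func.return ops: a cast is folded away only for result positions where every return yields the same unpacked type, while positions that disagree keep the function's existing result type. A hand-written sketch, not taken from the commit (function name, values, and the strided layout are illustrative assumptions):

// Before folding: both returns cast away the static layout.
func.func @f(%c: i1, %m: memref<5xf32>) -> memref<5xf32, strided<[?], offset: ?>> {
  cf.cond_br %c, ^bb1, ^bb2
^bb1:
  %0 = memref.cast %m : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
  return %0 : memref<5xf32, strided<[?], offset: ?>>
^bb2:
  %1 = memref.cast %m : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
  return %1 : memref<5xf32, strided<[?], offset: ?>>
}

After foldMemRefCasts runs, both casts are dropped and the result type is tightened to memref<5xf32>, because the unpacked operand type matches across all returns.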

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir

Lines changed: 45 additions & 0 deletions
@@ -1348,3 +1348,48 @@ func.func @private_func_aliasing(%t: tensor<?xf32>) -> f32 {
   %2 = tensor.extract %1[%c0] : tensor<6xf32>
   return %2 : f32
 }
+
+// -----
+
+// CHECK-ALIAS-SETS-LABEL: func @multiple_returns(
+func.func @multiple_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  return %t0 : tensor<5xf32>
+^bb2:
+  return %t1 : tensor<5xf32>
+}
+
+// CHECK-ALIAS-SETS: func @caller(
+// CHECK-ALIAS-SETS-SAME: %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
+func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) {
+  // Check that alias sets are computed correctly.
+  // CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_returns
+  // CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
+  // CHECK-ALIAS-SETS-SAME: __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]", "%[[t1]]"]]}
+  call @multiple_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
+  return
+}
+
+// -----
+
+// CHECK-ALIAS-SETS-LABEL: func @multiple_equivalent_returns(
+func.func @multiple_equivalent_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  return %t0 : tensor<5xf32>
+^bb2:
+  return %t0 : tensor<5xf32>
+}
+
+// CHECK-ALIAS-SETS: func @caller(
+// CHECK-ALIAS-SETS-SAME: %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {bufferization.access = "read"}, %[[t1:.*]]: tensor<5xf32> {bufferization.access = "none"}, %[[t2:.*]]: tensor<5xf32> {bufferization.access = "none"})
+func.func @caller(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>, %t2: tensor<5xf32>) -> tensor<5xf32> {
+  // Check that equivalence sets are computed correctly.
+  // CHECK-ALIAS-SETS: %[[result:.*]] = call @multiple_equivalent_returns
+  // CHECK-ALIAS-SETS-SAME: {__inplace_operands_attr__ = ["none", "true", "true", "true"],
+  // CHECK-ALIAS-SETS-SAME: __opresult_alias_set_attr__ = [{{\[}}"%[[result]]", "%[[t0]]"]]}
+  %r = call @multiple_equivalent_returns(%c, %t0, %t1, %t2) : (i1, tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) -> (tensor<5xf32>)
+  // CHECK-ALIAS-SETS-SAME: {__equivalent_func_args__ = [1], __inplace_operands_attr__ = ["true"]} %[[result]] : tensor<5xf32>
+  return %r : tensor<5xf32>
+}
