Commit 1122ffe

[mlir][bufferization] Add support for non-unique func.return
1 parent 38b0e1c commit 1122ffe
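
Before this change, One-Shot Bufferize rejected any function whose body contains more than one func.return op ("op without unique func.return is not supported"). With this commit, functions that return from several blocks can be bufferized across function boundaries. As an illustration (a hypothetical example, not taken from the commit's test files; the function name and the use of cf.cond_br are made up), a function like the following is now accepted when running one-shot-bufferize with bufferize-function-boundaries enabled:

  // Hypothetical input: one function, two blocks, each with its own func.return.
  func.func @select_tensor(%cond: i1, %a: tensor<5xf32>, %b: tensor<5xf32>) -> tensor<5xf32> {
    cf.cond_br %cond, ^bb1, ^bb2
  ^bb1:
    func.return %a : tensor<5xf32>
  ^bb2:
    func.return %b : tensor<5xf32>
  }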

6 files changed: +236 -114 lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h

Lines changed: 4 additions & 0 deletions

@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "llvm/ADT/SmallVector.h"

 namespace mlir {
 class DialectRegistry;
@@ -21,6 +22,9 @@ class FuncOp;
 } // namespace func

 namespace bufferization {
+/// Helper function that returns all func.return ops in the given function.
+SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp);
+
 namespace func_ext {
 /// The state of analysis of a FuncOp.
 enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed };

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 31 additions & 48 deletions

@@ -19,6 +19,15 @@
 #include <optional>

 namespace mlir {
+/// Return all func.return ops in the given function.
+SmallVector<func::ReturnOp> bufferization::getReturnOps(func::FuncOp funcOp) {
+  SmallVector<func::ReturnOp> result;
+  for (Block &b : funcOp.getBody())
+    if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
+      result.push_back(returnOp);
+  return result;
+}
+
 namespace bufferization {
 namespace func_ext {

@@ -41,20 +50,6 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
 #endif // NDEBUG
 }

-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
-  func::ReturnOp returnOp;
-  for (Block &b : funcOp.getBody()) {
-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
-}
-
 /// Return the index-th bufferized function argument type. This assumes that the
 /// specified argument is a tensor. If the tensor is ranked, a layout map may be
 /// specified by the user (as per `options.functionArgTypeConverterFn`).
@@ -391,15 +386,6 @@ struct FuncOpInterface
     getBufferType(op, value, options, invocationStack);
   }

-  LogicalResult verifyAnalysis(Operation *op,
-                               const AnalysisState &state) const {
-    auto funcOp = cast<func::FuncOp>(op);
-    // TODO: func.func with multiple returns are not supported.
-    if (!getAssumedUniqueReturnOp(funcOp) && !funcOp.isExternal())
-      return op->emitOpError("op without unique func.return is not supported");
-    return success();
-  }
-
   /// Rewrite function bbArgs and return values into buffer form. This function
   /// bufferizes the function signature and the ReturnOp. When the entire
   /// function body has been bufferized, function return types can be switched
@@ -446,41 +432,38 @@ struct FuncOpInterface
       return success();
     }

-    // TODO: Support functions with multiple returns.
-    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-    assert(returnOp && "expected func with single return op");
-    assert(returnOp->getNumOperands() == retTypes.size() &&
-           "incorrect number of return values");
-    Location loc = returnOp.getLoc();
-
     // 1. Bufferize every block.
     for (Block &block : funcOp.getBody())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options)))
        return failure();

-    // 2. Bufferize all operands of the return op.
-    SmallVector<Value> returnValues;
-    for (auto [returnVal, bufferizedType] :
-         llvm::zip_equal(returnOp->getOperands(), retTypes)) {
-      auto tensorType = dyn_cast<TensorType>(returnVal.getType());
-      rewriter.setInsertionPoint(returnOp);
-
-      // If not a tensor type just forward it.
-      if (!tensorType) {
-        returnValues.push_back(returnVal);
-        continue;
+    // 2. Bufferize the operands of all return ops.
+    for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
+      assert(returnOp->getNumOperands() == retTypes.size() &&
+             "incorrect number of return values");
+      SmallVector<Value> returnValues;
+      for (auto [returnVal, bufferizedType] :
+           llvm::zip_equal(returnOp->getOperands(), retTypes)) {
+        auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+        rewriter.setInsertionPoint(returnOp);
+
+        // If not a tensor type just forward it.
+        if (!tensorType) {
+          returnValues.push_back(returnVal);
+          continue;
+        }
+
+        // Note: If `inferFunctionResultLayout = true`, casts are later folded
+        // away.
+        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+            returnOp.getLoc(), bufferizedType, returnVal);
+        returnValues.push_back(toMemrefOp);
       }

-      // Note: If `inferFunctionResultLayout = true`, casts are later folded
-      // away.
-      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-          loc, bufferizedType, returnVal);
-      returnValues.push_back(toMemrefOp);
+      returnOp.getOperandsMutable().assign(returnValues);
     }

-    returnOp.getOperandsMutable().assign(returnValues);
-
     // 3. Set the new function type.
     funcOp.setType(newFuncType);
     return success();
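
The rewritten step 2 above bufferizes each func.return individually: for every return op, tensor operands are materialized as memrefs (via bufferization::ToMemrefOp) using the bufferized function result types, and non-tensor operands are forwarded unchanged. For the two-return example from the commit header, the bufferized function would look roughly like this (a sketch assuming identity-layout-map function boundary types; with the default options the memref types carry fully dynamic layouts instead, and intermediate conversion ops are folded away later):

  func.func @select_tensor(%cond: i1, %a: memref<5xf32>, %b: memref<5xf32>) -> memref<5xf32> {
    cf.cond_br %cond, ^bb1, ^bb2
  ^bb1:
    func.return %a : memref<5xf32>
  ^bb2:
    func.return %b : memref<5xf32>
  }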

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 128 additions & 46 deletions

@@ -86,20 +86,6 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
   return state.addExtension<FuncAnalysisState>();
 }

-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
-  func::ReturnOp returnOp;
-  for (Block &b : funcOp.getBody()) {
-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
-}
-
 namespace {

 /// Annotate IR with the results of the analysis. For testing purposes only.
@@ -146,24 +132,80 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
     return success();
   }

-  // Support only single return-terminated block in the function.
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  assert(returnOp && "expected func with single return op");
-
-  for (OpOperand &returnVal : returnOp->getOpOperands())
-    if (isa<RankedTensorType>(returnVal.get().getType()))
-      for (BlockArgument bbArg : funcOp.getArguments())
-        if (isa<RankedTensorType>(bbArg.getType())) {
-          int64_t returnIdx = returnVal.getOperandNumber();
-          int64_t bbArgIdx = bbArg.getArgNumber();
-          if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
-            funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
-            if (state.getOptions().testAnalysisOnly)
-              annotateEquivalentReturnBbArg(returnVal, bbArg);
+  // Find all func.return ops.
+  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  assert(!returnOps.empty() && "expected at least one ReturnOp");
+
+  // Build alias sets. Merge all aliases from all func.return ops.
+  for (BlockArgument bbArg : funcOp.getArguments()) {
+    if (isa<RankedTensorType>(bbArg.getType())) {
+      int64_t bbArgIdx = bbArg.getArgNumber();
+      // Store aliases in a set, so that we don't add the same alias twice.
+      SetVector<int64_t> aliases;
+      for (func::ReturnOp returnOp : returnOps) {
+        for (OpOperand &returnVal : returnOp->getOpOperands()) {
+          if (isa<RankedTensorType>(returnVal.get().getType())) {
+            int64_t returnIdx = returnVal.getOperandNumber();
+            if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
+              aliases.insert(returnIdx);
           }
-          if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
-            funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
         }
+      }
+      for (int64_t alias : aliases)
+        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
+    }
+  }
+
+  // Build equivalence sets.
+  // Helper function that finds an equivalent block argument index for the
+  // given OpOperand. Return std::nullopt if no equivalent block argument could
+  // be found.
+  auto findEquivalentBlockArgIdx =
+      [&](OpOperand &opOperand) -> std::optional<int64_t> {
+    Value v = opOperand.get();
+    if (!isa<TensorType>(v.getType()))
+      return std::nullopt;
+    for (BlockArgument bbArg : funcOp.getArguments()) {
+      if (isa<RankedTensorType>(bbArg.getType())) {
+        if (state.areEquivalentBufferizedValues(v, bbArg)) {
+          if (state.getOptions().testAnalysisOnly)
+            annotateEquivalentReturnBbArg(opOperand, bbArg);
+          return bbArg.getArgNumber();
+        }
+      }
+    }
+    return std::nullopt;
+  };
+
+  int64_t numResults = returnOps.front()->getNumOperands();
+  for (int64_t i = 0; i < numResults; ++i) {
+    // Find the equivalent block argument index for the i-th operand of the
+    // first func.return op.
+    std::optional<int64_t> maybeEquiv =
+        findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
+    if (!maybeEquiv.has_value())
+      continue;
+    int64_t bbArgIdx = *maybeEquiv;
+    bool allEquiv = true;
+
+    // Check if all other func.return ops have the same equivalent block
+    // argument for the i-th operand. In contrast to aliasing information,
+    // which is just "merged", equivalence information must match across all
+    // func.return ops.
+    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
+      std::optional<int64_t> maybeEquiv =
+          findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
+      if (maybeEquiv != bbArgIdx) {
+        allEquiv = false;
+        break;
+      }
+    }
+
+    // All func.return ops have the same equivalent block argument for the i-th
+    // operand.
+    if (allEquiv)
+      funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
+  }

   return success();
 }
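
The analysis above treats aliasing and equivalence differently: aliasing return values are merged (unioned) across all func.return ops, whereas a function result is only recorded as equivalent to a block argument if every func.return op returns an equivalent value at that position. A hypothetical example (made-up IR, not from the commit's tests):

  func.func @f(%cond: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
    %empty = tensor.empty() : tensor<5xf32>
    cf.cond_br %cond, ^bb1, ^bb2
  ^bb1:
    func.return %t, %t : tensor<5xf32>, tensor<5xf32>
  ^bb2:
    func.return %t, %empty : tensor<5xf32>, tensor<5xf32>
  }

Result #0 is %t on both paths, so it is recorded as both aliasing and equivalent to the %t block argument. Result #1 aliases %t (via ^bb1), so that alias is recorded, but the result is not marked equivalent because ^bb2 returns a freshly created tensor at that position.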
@@ -302,14 +344,6 @@ static LogicalResult getFuncOpsOrderedByCalls(
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
   WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
-    if (!funcOp.getBody().empty()) {
-      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-      if (!returnOp)
-        return funcOp->emitError()
-               << "cannot bufferize a FuncOp with tensors and "
-                  "without a unique ReturnOp";
-    }
-
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
     return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
@@ -351,6 +385,42 @@ static LogicalResult getFuncOpsOrderedByCalls(
   return success();
 }

+/// Helper function that extracts the source from a memref.cast. If the given
+/// value is not a memref.cast result, simply returns the given value.
+static Value unpackCast(Value v) {
+  auto castOp = v.getDefiningOp<memref::CastOp>();
+  if (!castOp)
+    return v;
+  return castOp.getSource();
+}
+
+/// Helper function that returns the return types (skipping casts) of the given
+/// func.return ops. This function returns as many types as the return ops have
+/// operands. If the type of the i-th operand is not the same for all
+/// func.return ops, then the i-th returned type is an "empty" type.
+static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
+  assert(!returnOps.empty() && "expected at least one ReturnOp");
+  int numOperands = returnOps.front()->getNumOperands();
+
+  // Helper function that unpacks memref.cast ops and returns the type.
+  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
+
+  SmallVector<Type> result;
+  for (int i = 0; i < numOperands; ++i) {
+    // Get the type of the i-th operand of the first func.return op.
+    Type t = getSourceType(returnOps.front()->getOperand(i));
+
+    // Check if all other func.return ops have a matching operand type.
+    for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
+      if (getSourceType(returnOps[j]->getOperand(i)) != t)
+        t = Type();
+
+    result.push_back(t);
+  }
+
+  return result;
+}
+
 /// Fold return values that are memref casts and update function return types.
 ///
 /// During FuncOp bufferization, the exact type of the returned memrefs (if any)
@@ -359,21 +429,33 @@ static LogicalResult getFuncOpsOrderedByCalls(
 /// entire function body, a more concise memref type can potentially be used for
 /// the return type of the function.
 static void foldMemRefCasts(func::FuncOp funcOp) {
+  // There is nothing to do for bodiless ops.
   if (funcOp.getBody().empty())
     return;

-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  SmallVector<Type> resultTypes;
+  // Compute the common result types of all return ops.
+  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  SmallVector<Type> resultTypes = getReturnTypes(returnOps);

-  for (OpOperand &operand : returnOp->getOpOperands()) {
-    if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
-      operand.set(castOp.getSource());
-      resultTypes.push_back(castOp.getSource().getType());
-    } else {
-      resultTypes.push_back(operand.get().getType());
+  // Remove direct casts.
+  for (func::ReturnOp returnOp : returnOps) {
+    for (OpOperand &operand : returnOp->getOpOperands()) {
+      // Bail if no common result type was found.
+      if (resultTypes[operand.getOperandNumber()]) {
+        operand.set(unpackCast(operand.get()));
+      }
     }
   }

+  // Fill in the missing result types that were not the same among all
+  // func.return ops.
+  for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
+    if (resultTypes[i])
+      continue;
+    resultTypes[i] = funcOp.getFunctionType().getResult(i);
+  }
+
+  // Update the function type.
   auto newFuncType = FunctionType::get(
       funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
   funcOp.setType(newFuncType);