Commit 13808b4

[mlir][bufferization] Add support for non-unique func.return

1 parent 11089cc · commit 13808b4
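One-shot bufferization previously required each function to be terminated by a single func.return; functions with multiple return-terminated blocks were rejected during analysis. This commit lifts that restriction. As a minimal sketch (function and value names invented here, not taken from the commit), IR like the following used to fail with "cannot bufferize a FuncOp with tensors and without a unique ReturnOp" and now bufferizes:

```mlir
// A function with two func.return ops, one per successor block. With
// bufferize-function-boundaries=1, every return is rewritten to return a
// buffer of the common function result type.
func.func @two_returns(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>) -> tensor<5xf32> {
  cf.cond_br %c, ^bb1, ^bb2
^bb1:
  return %t0 : tensor<5xf32>
^bb2:
  return %t1 : tensor<5xf32>
}
```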

4 files changed: +190 −110 lines

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 29 additions & 46 deletions
```diff
@@ -41,18 +41,13 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
 #endif // NDEBUG
 }
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
-  func::ReturnOp returnOp;
-  for (Block &b : funcOp.getBody()) {
-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
+/// Return all top-level func.return ops in the given function.
+static SmallVector<func::ReturnOp> getReturnOps(FuncOp funcOp) {
+  SmallVector<func::ReturnOp> result;
+  for (Block &b : funcOp.getBody())
+    if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
+      result.push_back(returnOp);
+  return result;
 }
 
 /// Return the index-th bufferized function argument type. This assumes that the
```
```diff
@@ -372,15 +367,6 @@ struct FuncOpInterface
         getBufferType(op, value, options, invocationStack);
   }
 
-  LogicalResult verifyAnalysis(Operation *op,
-                               const AnalysisState &state) const {
-    auto funcOp = cast<func::FuncOp>(op);
-    // TODO: func.func with multiple returns are not supported.
-    if (!getAssumedUniqueReturnOp(funcOp) && !funcOp.isExternal())
-      return op->emitOpError("op without unique func.return is not supported");
-    return success();
-  }
-
   /// Rewrite function bbArgs and return values into buffer form. This function
   /// bufferizes the function signature and the ReturnOp. When the entire
   /// function body has been bufferized, function return types can be switched
```
```diff
@@ -427,41 +413,38 @@ struct FuncOpInterface
       return success();
     }
 
-    // TODO: Support functions with multiple returns.
-    func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-    assert(returnOp && "expected func with single return op");
-    assert(returnOp->getNumOperands() == retTypes.size() &&
-           "incorrect number of return values");
-    Location loc = returnOp.getLoc();
-
     // 1. Bufferize every block.
     for (Block &block : funcOp.getBody())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options)))
        return failure();
 
-    // 2. Bufferize all operands of the return op.
-    SmallVector<Value> returnValues;
-    for (auto [returnVal, bufferizedType] :
-         llvm::zip_equal(returnOp->getOperands(), retTypes)) {
-      auto tensorType = dyn_cast<TensorType>(returnVal.getType());
-      rewriter.setInsertionPoint(returnOp);
-
-      // If not a tensor type just forward it.
-      if (!tensorType) {
-        returnValues.push_back(returnVal);
-        continue;
+    // 2. Bufferize the operands of all func.return ops.
+    for (func::ReturnOp returnOp : getReturnOps(funcOp)) {
+      assert(returnOp->getNumOperands() == retTypes.size() &&
+             "incorrect number of return values");
+      SmallVector<Value> returnValues;
+      for (auto [returnVal, bufferizedType] :
+           llvm::zip_equal(returnOp->getOperands(), retTypes)) {
+        auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+        rewriter.setInsertionPoint(returnOp);
+
+        // If not a tensor type just forward it.
+        if (!tensorType) {
+          returnValues.push_back(returnVal);
+          continue;
+        }
+
+        // Note: If `inferFunctionResultLayout = true`, casts are later folded
+        // away.
+        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+            returnOp.getLoc(), bufferizedType, returnVal);
+        returnValues.push_back(toMemrefOp);
       }
 
-      // Note: If `inferFunctionResultLayout = true`, casts are later folded
-      // away.
-      Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-          loc, bufferizedType, returnVal);
-      returnValues.push_back(toMemrefOp);
+      returnOp.getOperandsMutable().assign(returnValues);
    }
 
-    returnOp.getOperandsMutable().assign(returnValues);
-
     // 3. Set the new function type.
     funcOp.setType(newFuncType);
     return success();
```
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 135 additions & 44 deletions
```diff
@@ -86,18 +86,13 @@ getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
   return state.addExtension<FuncAnalysisState>();
 }
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
-  func::ReturnOp returnOp;
-  for (Block &b : funcOp.getBody()) {
-    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
+/// Return all top-level func.return ops in the given function.
+static SmallVector<func::ReturnOp> getReturnOps(FuncOp funcOp) {
+  SmallVector<func::ReturnOp> result;
+  for (Block &b : funcOp.getBody())
+    if (auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
+      result.push_back(returnOp);
+  return result;
 }
 
 namespace {
```
```diff
@@ -146,24 +141,80 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
     return success();
   }
 
-  // Support only single return-terminated block in the function.
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  assert(returnOp && "expected func with single return op");
-
-  for (OpOperand &returnVal : returnOp->getOpOperands())
-    if (isa<RankedTensorType>(returnVal.get().getType()))
-      for (BlockArgument bbArg : funcOp.getArguments())
-        if (isa<RankedTensorType>(bbArg.getType())) {
-          int64_t returnIdx = returnVal.getOperandNumber();
-          int64_t bbArgIdx = bbArg.getArgNumber();
-          if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
-            funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
-            if (state.getOptions().testAnalysisOnly)
-              annotateEquivalentReturnBbArg(returnVal, bbArg);
+  // Find all func.return ops.
+  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  assert(!returnOps.empty() && "expected at least one ReturnOp");
+
+  // Build alias sets. Merge all aliases from all func.return ops.
+  for (BlockArgument bbArg : funcOp.getArguments()) {
+    if (isa<RankedTensorType>(bbArg.getType())) {
+      int64_t bbArgIdx = bbArg.getArgNumber();
+      // Store aliases in a set, so that we don't add the same alias twice.
+      SetVector<int64_t> aliases;
+      for (func::ReturnOp returnOp : returnOps) {
+        for (OpOperand &returnVal : returnOp->getOpOperands()) {
+          if (isa<RankedTensorType>(returnVal.get().getType())) {
+            int64_t returnIdx = returnVal.getOperandNumber();
+            if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
+              aliases.insert(returnIdx);
          }
-          if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
-            funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
        }
+      }
+      for (int64_t alias : aliases)
+        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
+    }
+  }
+
+  // Build equivalence sets.
+  // Helper function that finds an equivalent block argument index for the
+  // given OpOperand. Return std::nullopt if no equivalent block argument could
+  // be found.
+  auto findEquivalentBlockArgIdx =
+      [&](OpOperand &opOperand) -> std::optional<int64_t> {
+    Value v = opOperand.get();
+    if (!isa<TensorType>(v.getType()))
+      return std::nullopt;
+    for (BlockArgument bbArg : funcOp.getArguments()) {
+      if (isa<RankedTensorType>(bbArg.getType())) {
+        if (state.areEquivalentBufferizedValues(v, bbArg)) {
+          if (state.getOptions().testAnalysisOnly)
+            annotateEquivalentReturnBbArg(opOperand, bbArg);
+          return bbArg.getArgNumber();
+        }
+      }
+    }
+    return std::nullopt;
+  };
+
+  int64_t numResults = returnOps.front()->getNumOperands();
+  for (int64_t i = 0; i < numResults; ++i) {
+    // Find the equivalent block argument index for the i-th operand of the
+    // first func.return op.
+    std::optional<int64_t> maybeEquiv =
+        findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
+    if (!maybeEquiv.has_value())
+      continue;
+    int64_t bbArgIdx = *maybeEquiv;
+    bool allEquiv = true;
+
+    // Check if all other func.return ops have the same equivalent block
+    // argument for the i-th operand. In contrast to aliasing information,
+    // which is just "merged", equivalence information must match across all
+    // func.return ops.
+    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
+      std::optional<int64_t> maybeEquiv =
+          findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
+      if (maybeEquiv != bbArgIdx) {
+        allEquiv = false;
+        break;
+      }
+    }
+
+    // All func.return ops have the same equivalent block argument for the
+    // i-th operand.
+    if (allEquiv)
+      funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
+  }
 
   return success();
 }
```
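To make the distinction concrete, a hedged example (function name invented): aliasing information is merged, so it suffices that some func.return aliases the bbArg, and the SetVector deduplicates per-return entries; equivalence, by contrast, must hold for the same bbArg on every func.return.

```mlir
func.func @alias_but_not_equivalent(%c: i1, %t: tensor<10xf32>) -> tensor<5xf32> {
  cf.cond_br %c, ^bb1, ^bb2
^bb1:
  // Aliases %t after bufferization (a subview), but is not equivalent to it.
  %0 = tensor.extract_slice %t[0] [5] [1] : tensor<10xf32> to tensor<5xf32>
  return %0 : tensor<5xf32>
^bb2:
  // Also aliases %t, so aliasingReturnVals records result #0 for arg #0 once.
  %1 = tensor.extract_slice %t[5] [5] [1] : tensor<10xf32> to tensor<5xf32>
  return %1 : tensor<5xf32>
}
```

Since neither returned value is equivalent to the block argument, no equivalentFuncArgs entry is recorded; had both blocks returned %t itself, result #0 would be recorded as equivalent to arg #0.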
```diff
@@ -299,14 +350,6 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
   WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
-    if (!funcOp.getBody().empty()) {
-      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-      if (!returnOp)
-        return funcOp->emitError()
-               << "cannot bufferize a FuncOp with tensors and "
-                  "without a unique ReturnOp";
-    }
-
     // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
     return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
```
```diff
@@ -342,6 +385,42 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   return success();
 }
 
+/// Helper function that extracts the source from a memref.cast. If the given
+/// value is not a memref.cast result, simply returns the given value.
+static Value unpackCast(Value v) {
+  auto castOp = v.getDefiningOp<memref::CastOp>();
+  if (!castOp)
+    return v;
+  return castOp.getSource();
+}
+
+/// Helper function that returns the return types (skipping casts) of the given
+/// func.return ops. This function returns as many types as the return ops have
+/// operands. If the type of the i-th operand is not the same for all
+/// func.return ops, then the i-th returned type is an "empty" type.
+static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
+  assert(!returnOps.empty() && "expected at least one ReturnOp");
+  int numOperands = returnOps.front()->getNumOperands();
+
+  // Helper function that unpacks memref.cast ops and returns the type.
+  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
+
+  SmallVector<Type> result;
+  for (int i = 0; i < numOperands; ++i) {
+    // Get the type of the i-th operand of the first func.return op.
+    Type t = getSourceType(returnOps.front()->getOperand(i));
+
+    // Check if all other func.return ops have a matching operand type.
+    for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
+      if (getSourceType(returnOps[j]->getOperand(i)) != t)
+        t = Type();
+
+    result.push_back(t);
+  }
+
+  return result;
+}
+
 /// Fold return values that are memref casts and update function return types.
 ///
 /// During FuncOp bufferization, the exact type of the returned memrefs (if any)
```
```diff
@@ -350,21 +429,33 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
 /// entire function body, a more concise memref type can potentially be used for
 /// the return type of the function.
 static void foldMemRefCasts(func::FuncOp funcOp) {
+  // There is nothing to do for bodiless ops.
   if (funcOp.getBody().empty())
     return;
 
-  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  SmallVector<Type> resultTypes;
+  // Compute the common result types of all return ops.
+  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
+  SmallVector<Type> resultTypes = getReturnTypes(returnOps);
 
-  for (OpOperand &operand : returnOp->getOpOperands()) {
-    if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
-      operand.set(castOp.getSource());
-      resultTypes.push_back(castOp.getSource().getType());
-    } else {
-      resultTypes.push_back(operand.get().getType());
+  // Remove direct casts, but only where a common result type was found;
+  // otherwise keep the cast so that all func.return ops keep returning the
+  // declared result type.
+  for (func::ReturnOp returnOp : returnOps) {
+    for (OpOperand &operand : returnOp->getOpOperands()) {
+      if (resultTypes[operand.getOperandNumber()]) {
+        operand.set(unpackCast(operand.get()));
+      }
    }
  }
 
+  // Fill in the missing result types that were not the same among all
+  // func.return ops.
+  for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
+    if (resultTypes[i])
+      continue;
+    resultTypes[i] = funcOp.getFunctionType().getResult(i);
+  }
+
+  // Update the function type.
   auto newFuncType = FunctionType::get(
       funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
   funcOp.setType(newFuncType);
```
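Putting the two helpers together, a hedged before/after sketch (function name and types invented, mirroring the new test below): when all func.return ops cast from the same source type, the casts fold and the function type is tightened; when the source types disagree, getReturnTypes yields an empty Type(), the casts are kept, and the declared result type is reused.

```mlir
// Before foldMemRefCasts: both returns cast from the same source type.
func.func @common_source(%c: i1, %m: memref<5xf32>) -> memref<5xf32, strided<[?], offset: ?>> {
  cf.cond_br %c, ^bb1, ^bb2
^bb1:
  %0 = memref.cast %m : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
  return %0 : memref<5xf32, strided<[?], offset: ?>>
^bb2:
  %1 = memref.cast %m : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
  return %1 : memref<5xf32, strided<[?], offset: ?>>
}

// After foldMemRefCasts: the casts are gone and the result type is tightened.
func.func @common_source(%c: i1, %m: memref<5xf32>) -> memref<5xf32> {
  cf.cond_br %c, ^bb1, ^bb2
^bb1:
  return %m : memref<5xf32>
^bb2:
  return %m : memref<5xf32>
}
```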

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir

Lines changed: 2 additions & 20 deletions
```diff
@@ -1,24 +1,5 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics
 
-// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
-func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
-    -> (tensor<f32>, tensor<f32>)
-{
-  cf.cond_br %cond1, ^bb1, ^bb2
-
-^bb1:
-  %T:2 = scf.if %cond2 -> (tensor<f32>, tensor<f32>) {
-    scf.yield %t1, %t2 : tensor<f32>, tensor<f32>
-  } else {
-    scf.yield %t2, %t1 : tensor<f32>, tensor<f32>
-  }
-  return %T#0, %T#1 : tensor<f32>, tensor<f32>
-^bb2:
-  return %t2, %t1 : tensor<f32>, tensor<f32>
-}
-
-// -----
-
 // expected-error @-3 {{expected callgraph to be free of circular dependencies}}
 
 func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
```

The @swappy test is removed because functions without a unique func.return are now supported; positive coverage for the multi-return case is added in one-shot-module-bufferize.mlir below.
```diff
@@ -160,7 +141,8 @@ func.func @regression_scf_while() {
 
 // -----
 
-// expected-error @below{{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
+// expected-error @below{{could not infer buffer type of block argument}}
+// expected-error @below{{failed to bufferize op}}
 func.func @func_multiple_yields(%t: tensor<5xf32>) -> tensor<5xf32> {
   func.return %t : tensor<5xf32>
 ^bb1(%arg1 : tensor<5xf32>):
```

This function still fails to bufferize, but for a different reason: multiple returns are no longer rejected up front, and bufferization instead fails because the buffer type of the tensor block argument %arg1 cannot be inferred (presumably since ^bb1 has no predecessors).

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

Lines changed: 24 additions & 0 deletions
```diff
@@ -722,3 +722,27 @@ func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
   %0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
   return %0 : memref<5xf32>
 }
+
+// -----
+
+// The two func.return operands have different types after bufferization. Make
+// sure that memref.cast ops are inserted.
+
+// CHECK-LABEL: func @result_type_mismatch({{.*}}) -> memref<5xf32, strided<[?], offset: ?>>
+func.func @result_type_mismatch(%c: i1) -> tensor<5xf32> {
+  // CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<10xf32>
+  %t = tensor.empty() : tensor<10xf32>
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  // CHECK: %[[m0:.*]] = memref.subview %[[alloc]][0] [5] [2] : memref<10xf32> to memref<5xf32, strided<[2]>>
+  // CHECK: %[[cast0:.*]] = memref.cast %[[m0]] : memref<5xf32, strided<[2]>> to memref<5xf32, strided<[?], offset: ?>>
+  %0 = tensor.extract_slice %t[0][5][2] : tensor<10xf32> to tensor<5xf32>
+  // CHECK: return %[[cast0]] : memref<5xf32, strided<[?], offset: ?>
+  return %0 : tensor<5xf32>
+^bb2:
+  // CHECK: %[[m1:.*]] = memref.subview %[[alloc]][2] [5] [1] : memref<10xf32> to memref<5xf32, strided<[1], offset: 2>>
+  // CHECK: %[[cast1:.*]] = memref.cast %[[m1]] : memref<5xf32, strided<[1], offset: 2>> to memref<5xf32, strided<[?], offset: ?>>
+  %1 = tensor.extract_slice %t[2][5][1] : tensor<10xf32> to tensor<5xf32>
+  // CHECK: return %[[cast1]] : memref<5xf32, strided<[?], offset: ?>>
+  return %1 : tensor<5xf32>
+}
```
