
Commit c271ba7

[mlir][bufferization] Add support for recursive function calls (#114003)
This commit adds support for recursive function calls to One-Shot Bufferize. The analysis itself does not support recursive calls: a function body can be analyzed, but no assumptions can be made about the aliasing relation between a function's results and its arguments. Similarly, when looking at a `call` op into a recursive function, we do not know whether its operands bufferize to a memory read/write. In the absence of such information, we conservatively assume that they do. This commit is in preparation for removing the deprecated `func-bufferize` pass, which can bufferize recursive functions.
1 parent 910a73f commit c271ba7
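
For context, here is a minimal sketch of the kind of IR this enables. It is an illustrative placeholder (the function name is made up), mirroring the tests added below: the call operands are conservatively treated as read and written because the analysis cannot reason across the recursive call.

func.func @recursive(%t: tensor<5xf32>) -> tensor<5xf32> {
  // The analysis cannot see through the recursive call, so %t is assumed to be
  // both read and written; bufferization may insert an out-of-place copy.
  %0 = call @recursive(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
  return %0 : tensor<5xf32>
}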

File tree: 5 files changed (+145 / -32 lines)

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 22 additions & 3 deletions
@@ -207,11 +207,18 @@ struct CallOpInterface
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

-    // The callee was already bufferized, so we can directly take the type from
+    // If the callee was already bufferized, we can directly take the type from
    // its signature.
    FunctionType funcType = funcOp.getFunctionType();
-    return cast<BaseMemRefType>(
-        funcType.getResult(cast<OpResult>(value).getResultNumber()));
+    Type resultType =
+        funcType.getResult(cast<OpResult>(value).getResultNumber());
+    if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
+      return bufferizedType;
+
+    // Otherwise, call the type converter to compute the bufferized type.
+    auto tensorType = cast<TensorType>(resultType);
+    return options.functionArgTypeConverterFn(
+        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
  }

  /// All function arguments are writable. It is the responsibility of the
@@ -261,6 +268,18 @@ struct CallOpInterface

    // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
    auto memRefType = funcType.getInput(opOperand.getOperandNumber());
+    if (!isa<BaseMemRefType>(memRefType)) {
+      // The called function was not bufferized yet. This can happen when
+      // there are cycles in the function call graph. Compute the bufferized
+      // result type.
+      FailureOr<BaseMemRefType> maybeMemRefType =
+          bufferization::getBufferType(
+              funcOp.getArgument(opOperand.getOperandNumber()), options);
+      if (failed(maybeMemRefType))
+        return failure();
+      memRefType = *maybeMemRefType;
+    }
+
    // Since we don't yet have a clear layout story, to_memref may
    // conservatively turn tensors into more dynamic memref than necessary.
    // If the memref type of the callee fails, introduce an extra memref.cast
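
To make the new fallback concrete: when the callee sits in a call cycle and has not been bufferized yet, the buffer type of a call operand or result is computed from its tensor type instead of being read off the callee's signature. The sketch below is not part of the diff; it assumes the default options, where the function-boundary type converter produces a memref with a fully dynamic layout map, matching the CHECK lines in the tests further down.

// Tensor-level call into a not-yet-bufferized (recursive) callee:
//   %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
// Buffer type computed for the call operand/result with default options:
//   tensor<5xf32>  -->  memref<5xf32, strided<[?], offset: ?>>
// so the bufferized call is typed as:
//   %r = call @foo(%m) : (memref<5xf32, strided<[?], offset: ?>>)
//                      -> memref<5xf32, strided<[?], offset: ?>>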

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 62 additions & 15 deletions
@@ -285,14 +285,17 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
-/// callee-caller order (i.e. callees without callers first).
+/// callee-caller order (i.e., callees without callers first). Store all
+/// remaining functions (i.e., the ones that call each other recursively) in
+/// `remainingFuncOps`.
+///
/// Store the map of FuncOp to all its callers in `callerMap`.
-/// Return `failure()` if a cycle of calls is detected or if we are unable to
-/// retrieve the called FuncOp from any func::CallOp.
-static LogicalResult
-getFuncOpsOrderedByCalls(ModuleOp moduleOp,
-                         SmallVectorImpl<func::FuncOp> &orderedFuncOps,
-                         FuncCallerMap &callerMap) {
+///
+/// Return `failure()` if we are unable to retrieve the called FuncOp from
+/// any func::CallOp.
+static LogicalResult getFuncOpsOrderedByCalls(
+    ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+    SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
  // For each FuncOp, the set of functions called by it (i.e. the union of
  // symbols of all nested func::CallOp).
  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
@@ -326,19 +329,25 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
  });
  if (res.wasInterrupted())
    return failure();
+
  // Iteratively remove function operations that do not call any of the
-  // functions remaining in the callCounter map and add them to the worklist.
+  // functions remaining in the callCounter map and add them to the ordered list.
  while (!numberCallOpsContainedInFuncOp.empty()) {
    auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
                            [](auto entry) { return entry.getSecond() == 0; });
    if (it == numberCallOpsContainedInFuncOp.end())
-      return moduleOp.emitOpError(
-          "expected callgraph to be free of circular dependencies.");
+      break;
    orderedFuncOps.push_back(it->getFirst());
    for (auto callee : calledBy[it->getFirst()])
      numberCallOpsContainedInFuncOp[callee]--;
    numberCallOpsContainedInFuncOp.erase(it);
  }
+
+  // Put all other functions in the list of remaining functions. These are
+  // functions that call each other circularly.
+  for (auto it : numberCallOpsContainedInFuncOp)
+    remainingFuncOps.push_back(it.first);
+
  return success();
}

@@ -378,16 +387,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
         "expected that function boundary bufferization is activated");
  FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);

-  // A list of functions in the order in which they are analyzed + bufferized.
+  // A list of non-circular functions in the order in which they are analyzed
+  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
+  // A list of all other functions, i.e., functions that call each other
+  // recursively. For these, we analyze the function body but not the function
+  // boundary.
+  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

-  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
+  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
+                                      remainingFuncOps, callerMap)))
    return failure();

-  // Analyze ops.
+  // Analyze functions in order, starting with functions that do not call any
+  // other functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;
@@ -411,6 +427,25 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
  }

+  // Analyze all other functions. All function boundary analyses are skipped.
+  for (func::FuncOp funcOp : remainingFuncOps) {
+    if (!state.getOptions().isOpAllowed(funcOp))
+      continue;
+
+    // Gather equivalence info for CallOps.
+    equivalenceAnalysis(funcOp, state, funcState);
+
+    // Analyze funcOp.
+    if (failed(analyzeOp(funcOp, state, statistics)))
+      return failure();
+
+    // TODO: We currently skip all function argument analyses for functions
+    // that call each other circularly. These analyses do not support recursive
+    // calls yet. The `BufferizableOpInterface` implementations of `func`
+    // dialect ops return conservative results in the absence of analysis
+    // information.
+  }
+
  return success();
}

@@ -429,14 +464,26 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
         "expected that function boundary bufferization is activated");
  IRRewriter rewriter(moduleOp.getContext());

-  // A list of functions in the order in which they are analyzed + bufferized.
+  // A list of non-circular functions in the order in which they are analyzed
+  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
+  // A list of all other functions, i.e., functions that call each other
+  // recursively. For these, we analyze the function body but not the function
+  // boundary.
+  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

-  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
+  // Try to bufferize functions in calling order, i.e., first bufferize
+  // functions that do not call other functions. This allows us to infer
+  // accurate buffer types for function return values. Functions that call
+  // each other recursively are bufferized in an unspecified order at the end.
+  // We may use unnecessarily "complex" (in terms of layout map) buffer types.
+  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
+                                      remainingFuncOps, callerMap)))
    return failure();
+  llvm::append_range(orderedFuncOps, remainingFuncOps);

  // Bufferize functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
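
As an informal illustration of the new ordering (placeholder function names, not taken from the commit): in a module like the one below, @leaf and @caller end up in `orderedFuncOps` (callees before callers), while the mutually recursive @ping and @pong never reach a call count of zero, are collected in `remainingFuncOps`, and are analyzed and bufferized last, in an unspecified order with conservative function-boundary handling.

func.func @leaf(%t: tensor<5xf32>) -> tensor<5xf32> {
  return %t : tensor<5xf32>
}

// Calls only @leaf, so it is ordered after @leaf.
func.func @caller(%t: tensor<5xf32>) -> tensor<5xf32> {
  %0 = call @leaf(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
  return %0 : tensor<5xf32>
}

// @ping and @pong call each other; both land in `remainingFuncOps`.
func.func @ping(%t: tensor<5xf32>) -> tensor<5xf32> {
  %0 = call @pong(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
  return %0 : tensor<5xf32>
}

func.func @pong(%t: tensor<5xf32>) -> tensor<5xf32> {
  %0 = call @ping(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
  return %0 : tensor<5xf32>
}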

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir

Lines changed: 12 additions & 0 deletions
@@ -1348,3 +1348,15 @@ func.func @private_func_aliasing(%t: tensor<?xf32>) -> f32 {
  %2 = tensor.extract %1[%c0] : tensor<6xf32>
  return %2 : f32
}
+
+// -----
+
+// CHECK-LABEL: func @recursive_function
+func.func @recursive_function(%a: tensor<?xf32>, %b: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+  // The analysis does not support recursive function calls and is conservative
+  // around them.
+  // CHECK: call @recursive_function
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false", "false"]}
+  %0:2 = call @recursive_function(%a, %b) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>)
+  return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
+}

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir

Lines changed: 0 additions & 14 deletions
@@ -19,20 +19,6 @@ func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>

// -----

-// expected-error @-3 {{expected callgraph to be free of circular dependencies}}
-
-func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
-  %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
-  return %0 : tensor<5xf32>
-}
-
-func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
-  %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
-  return %0 : tensor<5xf32>
-}
-
-// -----
-
func.func @scf_for(%A : tensor<?xf32>,
                   %B : tensor<?xf32> {bufferization.writable = true},
                   %C : tensor<4xf32>,

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

Lines changed: 49 additions & 0 deletions
@@ -722,3 +722,52 @@ func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
  %0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
  return %0 : memref<5xf32>
}
+
+// -----
+
+// A recursive function.
+
+// CHECK-LABEL: func.func @foo(
+// CHECK-SAME:     %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
+  // We are conservative around recursive functions. The analysis cannot handle
+  // them, so we have to assume the op operand of the call op bufferizes to a
+  // memory read and write. This causes a copy in this test case.
+  // CHECK: %[[copy:.*]] = memref.alloc() {alignment = 64 : i64} : memref<5xf32>
+  // CHECK: memref.copy %[[arg0]], %[[copy]]
+  // CHECK: %[[cast:.*]] = memref.cast %[[copy]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
+  // CHECK: %[[call:.*]] = call @foo(%[[cast]])
+  %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+
+  // CHECK: memref.load %[[arg0]]
+  %c0 = arith.constant 0 : index
+  %extr = tensor.extract %t[%c0] : tensor<5xf32>
+  vector.print %extr : f32
+
+  // CHECK: return %[[call]]
+  return %0 : tensor<5xf32>
+}
+
+// -----
+
+// Two functions calling each other recursively.
+
+// CHECK-LABEL: func.func @foo(
+// CHECK-SAME:     %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+// CHECK:        %[[call:.*]] = call @bar(%[[arg0]]) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
+// CHECK:        return %[[call]]
+// CHECK:      }
+func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+  return %0 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @bar(
+// CHECK-SAME:     %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+// CHECK:        %[[call:.*]] = call @foo(%[[arg0]]) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
+// CHECK:        return %[[call]]
+// CHECK:      }
+func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
+  %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+  return %0 : tensor<5xf32>
+}
