
Commit 89f8a74

[mlir][bufferization] Add support for recursive function calls
Parent: 00ca207

5 files changed: 137 additions, 33 deletions

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 24 additions & 4 deletions
@@ -82,7 +82,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,

 /// Return the FuncOp called by `callOp`.
 static FuncOp getCalledFunction(CallOpInterface callOp) {
-  SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
+  SymbolRefAttr sym =
+      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
   if (!sym)
     return nullptr;
   return dyn_cast_or_null<FuncOp>(
@@ -206,11 +207,18 @@ struct CallOpInterface
     FuncOp funcOp = getCalledFunction(callOp);
     assert(funcOp && "expected CallOp to a FuncOp");

-    // The callee was already bufferized, so we can directly take the type from
+    // If the callee was already bufferized, we can directly take the type from
     // its signature.
     FunctionType funcType = funcOp.getFunctionType();
-    return cast<BaseMemRefType>(
-        funcType.getResult(cast<OpResult>(value).getResultNumber()));
+    Type resultType =
+        funcType.getResult(cast<OpResult>(value).getResultNumber());
+    if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
+      return bufferizedType;
+
+    // Otherwise, call the type converter to compute the bufferized type.
+    auto tensorType = cast<TensorType>(resultType);
+    return options.functionArgTypeConverterFn(
+        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
   }

   /// All function arguments are writable. It is the responsibility of the
@@ -260,6 +268,18 @@ struct CallOpInterface

       // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
       auto memRefType = funcType.getInput(opOperand.getOperandNumber());
+      if (!isa<BaseMemRefType>(memRefType)) {
+        // The called function was not bufferized yet. This can happen when
+        // there are cycles in the function call graph. Compute the bufferized
+        // result type.
+        FailureOr<BaseMemRefType> maybeMemRefType =
+            bufferization::getBufferType(
+                funcOp.getArgument(opOperand.getOperandNumber()), options);
+        if (failed(maybeMemRefType))
+          return failure();
+        memRefType = *maybeMemRefType;
+      }
+
       // Since we don't yet have a clear layout story, to_memref may
       // conservatively turn tensors into more dynamic memref than necessary.
       // If the memref type of the callee fails, introduce an extra memref.cast
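To make the new fallback concrete: the added logic uses the callee's signature type when it is already a memref type, and otherwise derives a buffer type from the tensor type via the type converter. Below is a minimal, self-contained C++ sketch of that decision; IllustrativeType, TypeConverterFn, and resolveCallResultType are made-up names, not MLIR API, and the sketch only illustrates the control flow added in this hunk.

#include <functional>
#include <string>

// Illustrative stand-in for an MLIR type; not part of the MLIR API.
struct IllustrativeType {
  std::string spelling;
  bool isBufferType; // true if this is already a memref-like type
};

// Hypothetical converter callback, playing the role of
// options.functionArgTypeConverterFn in the hunk above.
using TypeConverterFn =
    std::function<IllustrativeType(const IllustrativeType &tensorType)>;

// Resolve the buffer type of one call result:
// 1. If the callee was already bufferized, its signature carries a buffer
//    type, so return it directly.
// 2. Otherwise (e.g., a recursive call whose callee has not been bufferized
//    yet), derive the buffer type from the tensor type via the converter.
IllustrativeType resolveCallResultType(const IllustrativeType &signatureType,
                                       const TypeConverterFn &convert) {
  if (signatureType.isBufferType)
    return signatureType;
  return convert(signatureType);
}

int main() {
  TypeConverterFn convert = [](const IllustrativeType &t) {
    return IllustrativeType{"memref-for-" + t.spelling, /*isBufferType=*/true};
  };
  // Callee not bufferized yet: the fallback path computes the buffer type.
  IllustrativeType pending{"tensor<5xf32>", /*isBufferType=*/false};
  IllustrativeType resolved = resolveCallResultType(pending, convert);
  // resolved.spelling == "memref-for-tensor<5xf32>"
  return 0;
}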

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 52 additions & 15 deletions
@@ -285,14 +285,17 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
 }

 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
-/// callee-caller order (i.e. callees without callers first).
+/// callee-caller order (i.e., callees without callers first). Store all
+/// remaining functions (i.e., the ones that call each other recursively) in
+/// `remainingFuncOps`.
+///
 /// Store the map of FuncOp to all its callers in `callerMap`.
-/// Return `failure()` if a cycle of calls is detected or if we are unable to
-/// retrieve the called FuncOp from any func::CallOp.
-static LogicalResult
-getFuncOpsOrderedByCalls(ModuleOp moduleOp,
-                         SmallVectorImpl<func::FuncOp> &orderedFuncOps,
-                         FuncCallerMap &callerMap) {
+///
+/// Return `failure()` if we are unable to retrieve the called FuncOp from
+/// any func::CallOp.
+static LogicalResult getFuncOpsOrderedByCalls(
+    ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+    SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
   // For each FuncOp, the set of functions called by it (i.e. the union of
   // symbols of all nested func::CallOp).
   DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
@@ -326,19 +329,25 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   });
   if (res.wasInterrupted())
     return failure();
+
   // Iteratively remove function operations that do not call any of the
-  // functions remaining in the callCounter map and add them to the worklist.
+  // functions remaining in the callCounter map and add them to the ordered list.
   while (!numberCallOpsContainedInFuncOp.empty()) {
     auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
                             [](auto entry) { return entry.getSecond() == 0; });
     if (it == numberCallOpsContainedInFuncOp.end())
-      return moduleOp.emitOpError(
-          "expected callgraph to be free of circular dependencies.");
+      break;
     orderedFuncOps.push_back(it->getFirst());
     for (auto callee : calledBy[it->getFirst()])
       numberCallOpsContainedInFuncOp[callee]--;
     numberCallOpsContainedInFuncOp.erase(it);
   }
+
+  // Put all other functions in the list of remaining functions. These are
+  // functions that call each other circularly.
+  for (auto it : numberCallOpsContainedInFuncOp)
+    remainingFuncOps.push_back(it.first);
+
   return success();
 }

@@ -379,15 +388,17 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
   FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);

   // A list of functions in the order in which they are analyzed + bufferized.
-  SmallVector<func::FuncOp> orderedFuncOps;
+  SmallVector<func::FuncOp> orderedFuncOps, remainingFuncOps;

   // A mapping of FuncOps to their callers.
   FuncCallerMap callerMap;

-  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
+  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
+                                      remainingFuncOps, callerMap)))
     return failure();

-  // Analyze ops.
+  // Analyze ops in order, starting with functions that do not call any
+  // other functions.
   for (func::FuncOp funcOp : orderedFuncOps) {
     if (!state.getOptions().isOpAllowed(funcOp))
       continue;
@@ -411,6 +422,25 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
     funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
   }

+  // Analyze all other ops.
+  for (func::FuncOp funcOp : remainingFuncOps) {
+    if (!state.getOptions().isOpAllowed(funcOp))
+      continue;
+
+    // Gather equivalence info for CallOps.
+    equivalenceAnalysis(funcOp, state, funcState);
+
+    // Analyze funcOp.
+    if (failed(analyzeOp(funcOp, state, statistics)))
+      return failure();
+
+    // TODO: We currently skip all function argument analyses for functions
+    // that call each other circularly. These analyses do not support recursive
+    // calls yet. The `BufferizableOpInterface` implementations of `func`
+    // dialect ops return conservative results in the absence of analysis
+    // information.
+  }
+
   return success();
 }

@@ -430,13 +460,20 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   IRRewriter rewriter(moduleOp.getContext());

   // A list of functions in the order in which they are analyzed + bufferized.
-  SmallVector<func::FuncOp> orderedFuncOps;
+  SmallVector<func::FuncOp> orderedFuncOps, remainingFuncOps;

   // A mapping of FuncOps to their callers.
   FuncCallerMap callerMap;

-  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
+  // Try to bufferize functions in calling order. I.e., first bufferize
+  // functions that do not call other functions. This allows us to infer
+  // accurate buffer types for function return values. Functions that call
+  // each other recursively are bufferized in an unspecified order at the end.
+  // We may use unnecessarily "complex" (in terms of layout map) buffer types.
+  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
+                                      remainingFuncOps, callerMap)))
     return failure();
+  llvm::append_range(orderedFuncOps, remainingFuncOps);

   // Bufferize functions.
   for (func::FuncOp funcOp : orderedFuncOps) {
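The doc comment and loop in getFuncOpsOrderedByCalls above describe a Kahn-style peeling of the call graph: repeatedly pick a function whose callees have all been handled, and collect whatever cannot be peeled (functions that call each other, directly or transitively, including self-recursion) as the remaining functions. As a rough illustration, here is a self-contained C++ sketch of the same idea using standard containers and made-up names (orderByCalls and its parameters are hypothetical, not the LLVM/MLIR data structures); the sketch assumes every callee also appears as a key in `calls`.

#include <algorithm>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

// Callee-before-caller ordering over a call graph (names are illustrative).
// `calls` maps each function name to the set of functions it calls. Functions
// involved in call cycles (including self-recursion) never reach a pending
// count of zero and end up in `remaining`, mirroring `remainingFuncOps`.
void orderByCalls(const std::map<std::string, std::set<std::string>> &calls,
                  std::vector<std::string> &ordered,
                  std::vector<std::string> &remaining) {
  // For each function, the number of its callees that have not been ordered.
  std::map<std::string, int> pending;
  // Reverse edges: for each function, the set of its callers.
  std::map<std::string, std::set<std::string>> callers;
  for (const auto &[caller, callees] : calls) {
    pending[caller] = static_cast<int>(callees.size());
    for (const auto &callee : callees)
      callers[callee].insert(caller);
  }

  // Repeatedly peel off a function that has no pending callees.
  while (!pending.empty()) {
    auto it = std::find_if(pending.begin(), pending.end(),
                           [](const auto &entry) { return entry.second == 0; });
    if (it == pending.end())
      break; // Only functions that call into a cycle are left.
    std::string fn = it->first;
    ordered.push_back(fn);
    for (const auto &caller : callers[fn]) {
      auto p = pending.find(caller);
      if (p != pending.end())
        --p->second;
    }
    pending.erase(fn);
  }

  // Whatever remains participates in (or depends on) a call cycle; it is
  // processed afterwards in an unspecified order.
  for (const auto &entry : pending)
    remaining.push_back(entry.first);
}

int main() {
  // Example call graph: leaf() calls nothing, helper() calls leaf(),
  // foo() and bar() call each other (mutual recursion).
  std::map<std::string, std::set<std::string>> calls = {
      {"leaf", {}}, {"helper", {"leaf"}}, {"foo", {"bar"}}, {"bar", {"foo"}}};
  std::vector<std::string> ordered, remaining;
  orderByCalls(calls, ordered, remaining);
  for (const auto &fn : ordered)
    std::cout << "ordered: " << fn << "\n";   // leaf, helper
  for (const auto &fn : remaining)
    std::cout << "remaining: " << fn << "\n"; // bar, foo
  return 0;
}

In this toy graph, leaf and helper can be handled in callee-before-caller order, while the mutually recursive foo/bar pair ends up in the remaining list and is processed afterwards, which is exactly the situation the new remainingFuncOps handling in this commit addresses.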

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir

Lines changed: 12 additions & 0 deletions
@@ -1348,3 +1348,15 @@ func.func @private_func_aliasing(%t: tensor<?xf32>) -> f32 {
   %2 = tensor.extract %1[%c0] : tensor<6xf32>
   return %2 : f32
 }
+
+// -----
+
+// CHECK-LABEL: func @recursive_function
+func.func @recursive_function(%a: tensor<?xf32>, %b: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+  // The analysis does not support recursive function calls and is conservative
+  // around them.
+  // CHECK: call @recursive_function
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false", "false"]}
+  %0:2 = call @recursive_function(%a, %b) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>)
+  return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
+}

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir

Lines changed: 0 additions & 14 deletions
@@ -25,20 +25,6 @@ func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>

 // -----

-// expected-error @-3 {{expected callgraph to be free of circular dependencies}}
-
-func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
-  %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
-  return %0 : tensor<5xf32>
-}
-
-func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
-  %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
-  return %0 : tensor<5xf32>
-}
-
-// -----
-
 func.func @scf_for(%A : tensor<?xf32>,
                    %B : tensor<?xf32> {bufferization.writable = true},
                    %C : tensor<4xf32>,

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

Lines changed: 49 additions & 0 deletions
@@ -707,3 +707,52 @@ func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
   %0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
   return %0 : memref<5xf32>
 }
+
+// -----
+
+// A recursive function.
+
+// CHECK-LABEL: func.func @foo(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
+  // We are conservative around recursive functions. The analysis cannot handle
+  // them, so we have to assume the op operand of the call op bufferizes to a
+  // memory read and write. This causes a copy in this test case.
+  // CHECK: %[[copy:.*]] = memref.alloc() {alignment = 64 : i64} : memref<5xf32>
+  // CHECK: memref.copy %[[arg0]], %[[copy]]
+  // CHECK: %[[cast:.*]] = memref.cast %[[copy]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
+  // CHECK: %[[call:.*]] = call @foo(%[[cast]])
+  %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+
+  // CHECK: memref.load %[[arg0]]
+  %c0 = arith.constant 0 : index
+  %extr = tensor.extract %t[%c0] : tensor<5xf32>
+  vector.print %extr : f32
+
+  // CHECK: return %[[call]]
+  return %0 : tensor<5xf32>
+}
+
+// -----
+
+// Two functions calling each other recursively.
+
+// CHECK-LABEL: func.func @foo(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+// CHECK: %[[call:.*]] = call @bar(%[[arg0]]) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
+// CHECK: return %[[call]]
+// CHECK: }
+func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+  return %0 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @bar(
+// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
+// CHECK: %[[call:.*]] = call @foo(%[[arg0]]) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
+// CHECK: return %[[call]]
+// CHECK: }
+func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
+  %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+  return %0 : tensor<5xf32>
+}
