[mlir][bufferization] Add support for recursive function calls #114003

Merged · 2 commits · Nov 5, 2024
@@ -207,11 +207,18 @@ struct CallOpInterface
FuncOp funcOp = getCalledFunction(callOp);
assert(funcOp && "expected CallOp to a FuncOp");

- // The callee was already bufferized, so we can directly take the type from
// If the callee was already bufferized, we can directly take the type from
// its signature.
FunctionType funcType = funcOp.getFunctionType();
- return cast<BaseMemRefType>(
- funcType.getResult(cast<OpResult>(value).getResultNumber()));
Type resultType =
funcType.getResult(cast<OpResult>(value).getResultNumber());
if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
return bufferizedType;

// Otherwise, call the type converter to compute the bufferized type.
auto tensorType = cast<TensorType>(resultType);
return options.functionArgTypeConverterFn(
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
}

/// All function arguments are writable. It is the responsibility of the
@@ -261,6 +268,18 @@ struct CallOpInterface

// Caller / callee type mismatch is handled with castOrReallocMemRefValue.
auto memRefType = funcType.getInput(opOperand.getOperandNumber());
if (!isa<BaseMemRefType>(memRefType)) {
// The called function was not bufferized yet. This can happen when
// there are cycles in the function call graph. Compute the bufferized
// result type.
FailureOr<BaseMemRefType> maybeMemRefType =
bufferization::getBufferType(
funcOp.getArgument(opOperand.getOperandNumber()), options);
if (failed(maybeMemRefType))
return failure();
memRefType = *maybeMemRefType;
}

// Since we don't yet have a clear layout story, to_memref may
// conservatively turn tensors into more dynamic memref than necessary.
// If the memref type of the callee fails, introduce an extra memref.cast
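To make the case handled by the two new branches above concrete, consider a hypothetical pair of mutually recursive functions (illustrative only, not taken from this patch). While the call inside @even is being bufferized, its callee @odd has not been bufferized yet because the two functions form a cycle, so the callee signature still uses tensor types and the buffer types have to be computed via bufferization::getBufferType / the function argument type converter instead of being read off the callee signature.

// Hypothetical mutually recursive functions. When the call in @even is
// bufferized, @odd still has a tensor-based signature.
func.func @even(%t: tensor<8xf32>) -> tensor<8xf32> {
  %0 = call @odd(%t) : (tensor<8xf32>) -> tensor<8xf32>
  return %0 : tensor<8xf32>
}

func.func @odd(%t: tensor<8xf32>) -> tensor<8xf32> {
  %0 = call @even(%t) : (tensor<8xf32>) -> tensor<8xf32>
  return %0 : tensor<8xf32>
}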
@@ -285,14 +285,17 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
- /// callee-caller order (i.e. callees without callers first).
/// callee-caller order (i.e., callees without callers first). Store all
/// remaining functions (i.e., the ones that call each other recursively) in
/// `remainingFuncOps`.
///
/// Store the map of FuncOp to all its callers in `callerMap`.
- /// Return `failure()` if a cycle of calls is detected or if we are unable to
- /// retrieve the called FuncOp from any func::CallOp.
- static LogicalResult
- getFuncOpsOrderedByCalls(ModuleOp moduleOp,
- SmallVectorImpl<func::FuncOp> &orderedFuncOps,
- FuncCallerMap &callerMap) {
///
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
static LogicalResult getFuncOpsOrderedByCalls(
ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
// For each FuncOp, the set of functions called by it (i.e. the union of
// symbols of all nested func::CallOp).
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
@@ -326,19 +329,25 @@ getFuncOpsOrderedByCalls,
});
if (res.wasInterrupted())
return failure();

// Iteratively remove function operations that do not call any of the
- // functions remaining in the callCounter map and add them to the worklist.
// functions remaining in the callCounter map and add them to the ordered list.
while (!numberCallOpsContainedInFuncOp.empty()) {
auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
[](auto entry) { return entry.getSecond() == 0; });
if (it == numberCallOpsContainedInFuncOp.end())
- return moduleOp.emitOpError(
- "expected callgraph to be free of circular dependencies.");
break;
orderedFuncOps.push_back(it->getFirst());
for (auto callee : calledBy[it->getFirst()])
numberCallOpsContainedInFuncOp[callee]--;
numberCallOpsContainedInFuncOp.erase(it);
}

// Put all other functions in the list of remaining functions. These are
// functions that call each other circularly.
for (auto it : numberCallOpsContainedInFuncOp)
remainingFuncOps.push_back(it.first);

return success();
}
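A small worked example of the resulting split (hypothetical function names, used only for illustration): in the module below, @leaf contains no calls and @caller only calls @leaf, so the loop above orders them as [@leaf, @caller] in `orderedFuncOps`; @ping and @pong call each other, so neither call count ever drops to zero and both end up in `remainingFuncOps` in unspecified order.

func.func @leaf(%t: tensor<4xf32>) -> tensor<4xf32> {
  return %t : tensor<4xf32>
}

func.func @caller(%t: tensor<4xf32>) -> tensor<4xf32> {
  // Ordered after @leaf, its only callee.
  %0 = call @leaf(%t) : (tensor<4xf32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

func.func @ping(%t: tensor<4xf32>) -> tensor<4xf32> {
  // Part of a call cycle; goes into `remainingFuncOps`.
  %0 = call @pong(%t) : (tensor<4xf32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

func.func @pong(%t: tensor<4xf32>) -> tensor<4xf32> {
  %0 = call @ping(%t) : (tensor<4xf32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}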

@@ -378,16 +387,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
"expected that function boundary bufferization is activated");
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);

- // A list of functions in the order in which they are analyzed + bufferized.
// A list of non-circular functions in the order in which they are analyzed
// and bufferized.
SmallVector<func::FuncOp> orderedFuncOps;
// A list of all other functions, i.e., functions that call each other
// recursively. For these, we analyze the function body but not the function
// boundary.
SmallVector<func::FuncOp> remainingFuncOps;

// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;

- if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
remainingFuncOps, callerMap)))
return failure();

- // Analyze ops.
// Analyze functions in order, starting with functions that do not call any
// other functions.
for (func::FuncOp funcOp : orderedFuncOps) {
if (!state.getOptions().isOpAllowed(funcOp))
continue;
@@ -411,6 +427,25 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
}

// Analyze all other functions. All function boundary analyses are skipped.
for (func::FuncOp funcOp : remainingFuncOps) {
if (!state.getOptions().isOpAllowed(funcOp))
continue;

// Gather equivalence info for CallOps.
equivalenceAnalysis(funcOp, state, funcState);

// Analyze funcOp.
if (failed(analyzeOp(funcOp, state, statistics)))
return failure();

// TODO: We currently skip all function argument analyses for functions
// that call each other circularly. These analyses do not support recursive
// calls yet. The `BufferizableOpInterface` implementations of `func`
// dialect ops return conservative results in the absence of analysis
// information.
}

return success();
}

@@ -429,14 +464,26 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
"expected that function boundary bufferization is activated");
IRRewriter rewriter(moduleOp.getContext());

- // A list of functions in the order in which they are analyzed + bufferized.
// A list of non-circular functions in the order in which they are analyzed
// and bufferized.
SmallVector<func::FuncOp> orderedFuncOps;
// A list of all other functions, i.e., functions that call each other
// recursively. For these, we analyze the function body but not the function
// boundary.
SmallVector<func::FuncOp> remainingFuncOps;

// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;

- if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
// Try to bufferize functions in calling order, i.e., bufferize functions
// that do not call other functions first. This allows us to infer accurate
// buffer types for function return values. Functions that call each other
// recursively are bufferized in an unspecified order at the end; for those,
// we may end up with unnecessarily "complex" (in terms of layout map) buffer
// types.
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
remainingFuncOps, callerMap)))
return failure();
llvm::append_range(orderedFuncOps, remainingFuncOps);

// Bufferize functions.
for (func::FuncOp funcOp : orderedFuncOps) {
@@ -1348,3 +1348,15 @@ func.func @private_func_aliasing(%t: tensor<?xf32>) -> f32 {
%2 = tensor.extract %1[%c0] : tensor<6xf32>
return %2 : f32
}

// -----

// CHECK-LABEL: func @recursive_function
func.func @recursive_function(%a: tensor<?xf32>, %b: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
// The analysis does not support recursive function calls and is conservative
// around them.
// CHECK: call @recursive_function
// CHECK-SAME: {__inplace_operands_attr__ = ["false", "false"]}
%0:2 = call @recursive_function(%a, %b) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>)
return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
}
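For context, the `__inplace_operands_attr__` annotations checked above are produced by running One-Shot Bufferize in analysis-only mode, which annotates operations with the analysis results instead of rewriting them; an operand marked "false" bufferizes out-of-place, i.e., a buffer copy is inserted for it. A hypothetical invocation is sketched below; the exact RUN line and option set of this test file may differ.

// Hypothetical invocation (the options shown are an assumption, not copied
// from the test file):
// RUN: mlir-opt %s -split-input-file \
// RUN:   -one-shot-bufferize="bufferize-function-boundaries test-analysis-only" \
// RUN:   | FileCheck %s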
@@ -19,20 +19,6 @@ func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>

// -----

- // expected-error @-3 {{expected callgraph to be free of circular dependencies}}
-
- func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
- %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
- return %0 : tensor<5xf32>
- }
-
- func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
- %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
- return %0 : tensor<5xf32>
- }
-
// -----

func.func @scf_for(%A : tensor<?xf32>,
%B : tensor<?xf32> {bufferization.writable = true},
%C : tensor<4xf32>,
@@ -722,3 +722,52 @@ func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
%0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
return %0 : memref<5xf32>
}

// -----

// A recursive function.

// CHECK-LABEL: func.func @foo(
// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
// We are conservative around recursive functions. The analysis cannot handle
// them, so we have to assume the op operand of the call op bufferizes to a
// memory read and write. This causes a copy in this test case.
// CHECK: %[[copy:.*]] = memref.alloc() {alignment = 64 : i64} : memref<5xf32>
// CHECK: memref.copy %[[arg0]], %[[copy]]
// CHECK: %[[cast:.*]] = memref.cast %[[copy]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
// CHECK: %[[call:.*]] = call @foo(%[[cast]])
%0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)

// CHECK: memref.load %[[arg0]]
%c0 = arith.constant 0 : index
%extr = tensor.extract %t[%c0] : tensor<5xf32>
vector.print %extr : f32

// CHECK: return %[[call]]
return %0 : tensor<5xf32>
}

// -----

// Two functions calling each other recursively.

// CHECK-LABEL: func.func @foo(
// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
// CHECK: %[[call:.*]] = call @bar(%[[arg0]]) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
// CHECK: return %[[call]]
// CHECK: }
func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
%0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
return %0 : tensor<5xf32>
}

// CHECK-LABEL: func.func @bar(
// CHECK-SAME: %[[arg0:.*]]: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
// CHECK: %[[call:.*]] = call @foo(%[[arg0]]) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
// CHECK: return %[[call]]
// CHECK: }
func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
%0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
return %0 : tensor<5xf32>
}
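Putting the CHECK lines of this last test case together, the expected bufferized form of the two mutually recursive functions is roughly the following (reconstructed from the checks above for readability; illustrative only). Because no analysis information is available across the call cycle, both function arguments keep the conservative fully dynamic strided layout.

func.func @foo(%arg0: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
  %0 = call @bar(%arg0) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
  return %0 : memref<5xf32, strided<[?], offset: ?>>
}

func.func @bar(%arg0: memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>> {
  %0 = call @foo(%arg0) : (memref<5xf32, strided<[?], offset: ?>>) -> memref<5xf32, strided<[?], offset: ?>>
  return %0 : memref<5xf32, strided<[?], offset: ?>>
}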