@@ -285,14 +285,17 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
285
285
}
286
286
287
287
// / Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
288
- // / callee-caller order (i.e. callees without callers first).
288
+ // / callee-caller order (i.e., callees without callers first). Store all
289
+ // / remaining functions (i.e., the ones that call each other recursively) in
290
+ // / `remainingFuncOps`.
291
+ // /
289
292
// / Store the map of FuncOp to all its callers in `callerMap`.
290
- // / Return `failure()` if a cycle of calls is detected or if we are unable to
291
- // / retrieve the called FuncOp from any func::CallOp.
292
- static LogicalResult
293
- getFuncOpsOrderedByCalls (ModuleOp moduleOp,
294
- SmallVectorImpl<func::FuncOp> &orderedFuncOps,
295
- FuncCallerMap &callerMap) {
293
+ // /
294
+ // / Return `failure()` if we are unable to retrieve the called FuncOp from
295
+ // / any func::CallOp.
296
+ static LogicalResult getFuncOpsOrderedByCalls (
297
+ ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
298
+ SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
296
299
// For each FuncOp, the set of functions called by it (i.e. the union of
297
300
// symbols of all nested func::CallOp).
298
301
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
@@ -326,19 +329,25 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
326
329
});
327
330
if (res.wasInterrupted ())
328
331
return failure ();
332
+
329
333
// Iteratively remove function operations that do not call any of the
330
- // functions remaining in the callCounter map and add them to the worklist .
334
+ // functions remaining in the callCounter map and add them to ordered list .
331
335
while (!numberCallOpsContainedInFuncOp.empty ()) {
332
336
auto it = llvm::find_if (numberCallOpsContainedInFuncOp,
333
337
[](auto entry) { return entry.getSecond () == 0 ; });
334
338
if (it == numberCallOpsContainedInFuncOp.end ())
335
- return moduleOp.emitOpError (
336
- " expected callgraph to be free of circular dependencies." );
339
+ break ;
337
340
orderedFuncOps.push_back (it->getFirst ());
338
341
for (auto callee : calledBy[it->getFirst ()])
339
342
numberCallOpsContainedInFuncOp[callee]--;
340
343
numberCallOpsContainedInFuncOp.erase (it);
341
344
}
345
+
346
+ // Put all other functions in the list of remaining functions. These are
347
+ // functions that call each other circularly.
348
+ for (auto it : numberCallOpsContainedInFuncOp)
349
+ remainingFuncOps.push_back (it.first );
350
+
342
351
return success ();
343
352
}
344
353
@@ -378,16 +387,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
378
387
" expected that function boundary bufferization is activated" );
379
388
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState (state);
380
389
381
- // A list of functions in the order in which they are analyzed + bufferized.
390
+ // A list of non-circular functions in the order in which they are analyzed
391
+ // and bufferized.
382
392
SmallVector<func::FuncOp> orderedFuncOps;
393
+ // A list of all other functions. I.e., functions that call each other
394
+ // recursively. For these, we analyze the function body but not the function
395
+ // boundary.
396
+ SmallVector<func::FuncOp> remainingFuncOps;
383
397
384
398
// A mapping of FuncOps to their callers.
385
399
FuncCallerMap callerMap;
386
400
387
- if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps, callerMap)))
401
+ if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps,
402
+ remainingFuncOps, callerMap)))
388
403
return failure ();
389
404
390
- // Analyze ops.
405
+ // Analyze functions in order. Starting with functions that are not calling
406
+ // any other functions.
391
407
for (func::FuncOp funcOp : orderedFuncOps) {
392
408
if (!state.getOptions ().isOpAllowed (funcOp))
393
409
continue ;
@@ -411,6 +427,25 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
411
427
funcState.analyzedFuncOps [funcOp] = FuncOpAnalysisState::Analyzed;
412
428
}
413
429
430
+ // Analyze all other functions. All function boundary analyses are skipped.
431
+ for (func::FuncOp funcOp : remainingFuncOps) {
432
+ if (!state.getOptions ().isOpAllowed (funcOp))
433
+ continue ;
434
+
435
+ // Gather equivalence info for CallOps.
436
+ equivalenceAnalysis (funcOp, state, funcState);
437
+
438
+ // Analyze funcOp.
439
+ if (failed (analyzeOp (funcOp, state, statistics)))
440
+ return failure ();
441
+
442
+ // TODO: We currently skip all function argument analyses for functions
443
+ // that call each other circularly. These analyses do not support recursive
444
+ // calls yet. The `BufferizableOpInterface` implementations of `func`
445
+ // dialect ops return conservative results in the absence of analysis
446
+ // information.
447
+ }
448
+
414
449
return success ();
415
450
}
416
451
@@ -429,14 +464,26 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
429
464
" expected that function boundary bufferization is activated" );
430
465
IRRewriter rewriter (moduleOp.getContext ());
431
466
432
- // A list of functions in the order in which they are analyzed + bufferized.
467
+ // A list of non-circular functions in the order in which they are analyzed
468
+ // and bufferized.
433
469
SmallVector<func::FuncOp> orderedFuncOps;
470
+ // A list of all other functions. I.e., functions that call each other
471
+ // recursively. For these, we analyze the function body but not the function
472
+ // boundary.
473
+ SmallVector<func::FuncOp> remainingFuncOps;
434
474
435
475
// A mapping of FuncOps to their callers.
436
476
FuncCallerMap callerMap;
437
477
438
- if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps, callerMap)))
478
+ // Try to bufferize functions in calling order. I.e., first bufferize
479
+ // functions that do not call other functions. This allows us to infer
480
+ // accurate buffer types for function return values. Functions that call
481
+ // each other recursively are bufferized in an unspecified order at the end.
482
+ // We may use unnecessarily "complex" (in terms of layout map) buffer types.
483
+ if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps,
484
+ remainingFuncOps, callerMap)))
439
485
return failure ();
486
+ llvm::append_range (orderedFuncOps, remainingFuncOps);
440
487
441
488
// Bufferize functions.
442
489
for (func::FuncOp funcOp : orderedFuncOps) {
0 commit comments