@@ -285,14 +285,17 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
285
285
}
286
286
287
287
// / Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
288
- // / callee-caller order (i.e. callees without callers first).
288
+ // / callee-caller order (i.e., callees without callers first). Store all
289
+ // / remaining functions (i.e., the ones that call each other recursively) in
290
+ // / `remainingFuncOps`.
291
+ // /
289
292
// / Store the map of FuncOp to all its callers in `callerMap`.
290
- // / Return `failure()` if a cycle of calls is detected or if we are unable to
291
- // / retrieve the called FuncOp from any func::CallOp.
292
- static LogicalResult
293
- getFuncOpsOrderedByCalls (ModuleOp moduleOp,
294
- SmallVectorImpl<func::FuncOp> &orderedFuncOps,
295
- FuncCallerMap &callerMap) {
293
+ // /
294
+ // / Return `failure()` if we are unable to retrieve the called FuncOp from
295
+ // / any func::CallOp.
296
+ static LogicalResult getFuncOpsOrderedByCalls (
297
+ ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
298
+ SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
296
299
// For each FuncOp, the set of functions called by it (i.e. the union of
297
300
// symbols of all nested func::CallOp).
298
301
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
@@ -326,19 +329,25 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
326
329
});
327
330
if (res.wasInterrupted ())
328
331
return failure ();
332
+
329
333
// Iteratively remove function operations that do not call any of the
330
- // functions remaining in the callCounter map and add them to the worklist .
334
+ // functions remaining in the callCounter map and add them to ordered list .
331
335
while (!numberCallOpsContainedInFuncOp.empty ()) {
332
336
auto it = llvm::find_if (numberCallOpsContainedInFuncOp,
333
337
[](auto entry) { return entry.getSecond () == 0 ; });
334
338
if (it == numberCallOpsContainedInFuncOp.end ())
335
- return moduleOp.emitOpError (
336
- " expected callgraph to be free of circular dependencies." );
339
+ break ;
337
340
orderedFuncOps.push_back (it->getFirst ());
338
341
for (auto callee : calledBy[it->getFirst ()])
339
342
numberCallOpsContainedInFuncOp[callee]--;
340
343
numberCallOpsContainedInFuncOp.erase (it);
341
344
}
345
+
346
+ // Put all other functions in the list of remaining functions. These are
347
+ // functions that call each each circularly.
348
+ for (auto it : numberCallOpsContainedInFuncOp)
349
+ remainingFuncOps.push_back (it.first );
350
+
342
351
return success ();
343
352
}
344
353
@@ -379,15 +388,17 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
379
388
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState (state);
380
389
381
390
// A list of functions in the order in which they are analyzed + bufferized.
382
- SmallVector<func::FuncOp> orderedFuncOps;
391
+ SmallVector<func::FuncOp> orderedFuncOps, remainingFuncOps ;
383
392
384
393
// A mapping of FuncOps to their callers.
385
394
FuncCallerMap callerMap;
386
395
387
- if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps, callerMap)))
396
+ if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps,
397
+ remainingFuncOps, callerMap)))
388
398
return failure ();
389
399
390
- // Analyze ops.
400
+ // Analyze ops in order. Starting with functions that are not calling any
401
+ // other functions.
391
402
for (func::FuncOp funcOp : orderedFuncOps) {
392
403
if (!state.getOptions ().isOpAllowed (funcOp))
393
404
continue ;
@@ -411,6 +422,25 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
411
422
funcState.analyzedFuncOps [funcOp] = FuncOpAnalysisState::Analyzed;
412
423
}
413
424
425
+ // Analyze all other ops.
426
+ for (func::FuncOp funcOp : remainingFuncOps) {
427
+ if (!state.getOptions ().isOpAllowed (funcOp))
428
+ continue ;
429
+
430
+ // Gather equivalence info for CallOps.
431
+ equivalenceAnalysis (funcOp, state, funcState);
432
+
433
+ // Analyze funcOp.
434
+ if (failed (analyzeOp (funcOp, state, statistics)))
435
+ return failure ();
436
+
437
+ // TODO: We currently skip all function argument analyses for functions
438
+ // that call each other circularly. These analyses do not support recursive
439
+ // calls yet. The `BufferizableOpInterface` implementations of `func`
440
+ // dialect ops return conservative results in the absence of analysis
441
+ // information.
442
+ }
443
+
414
444
return success ();
415
445
}
416
446
@@ -430,13 +460,20 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
430
460
IRRewriter rewriter (moduleOp.getContext ());
431
461
432
462
// A list of functions in the order in which they are analyzed + bufferized.
433
- SmallVector<func::FuncOp> orderedFuncOps;
463
+ SmallVector<func::FuncOp> orderedFuncOps, remainingFuncOps ;
434
464
435
465
// A mapping of FuncOps to their callers.
436
466
FuncCallerMap callerMap;
437
467
438
- if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps, callerMap)))
468
+ // Try to bufferize functions in calling order. I.e., first bufferize
469
+ // functions that do not call other functions. This allows us to infer
470
+ // accurate buffer types for function return values. Functions that call
471
+ // each other recursively are bufferized in an unspecified order at the end.
472
+ // We may use unnecessarily "complex" (in terms of layout map) buffer types.
473
+ if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps,
474
+ remainingFuncOps, callerMap)))
439
475
return failure ();
476
+ llvm::append_range (orderedFuncOps, remainingFuncOps);
440
477
441
478
// Bufferize functions.
442
479
for (func::FuncOp funcOp : orderedFuncOps) {
0 commit comments