@@ -310,21 +310,19 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
310
310
// / any func::CallOp.
311
311
static LogicalResult getFuncOpsOrderedByCalls (
312
312
ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
313
- SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
313
+ SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
314
+ SymbolTableCollection &symbolTables) {
314
315
// For each FuncOp, the set of functions called by it (i.e. the union of
315
316
// symbols of all nested func::CallOp).
316
317
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
317
318
// For each FuncOp, the number of func::CallOp it contains.
318
319
DenseMap<func::FuncOp, unsigned > numberCallOpsContainedInFuncOp;
319
320
320
- // TODO Avoid recomputing the symbol tables every time.
321
- mlir::SymbolTableCollection symbolTable;
322
-
323
321
for (func::FuncOp funcOp : moduleOp.getOps <func::FuncOp>()) {
324
322
// Collect function calls and populate the caller map.
325
323
numberCallOpsContainedInFuncOp[funcOp] = 0 ;
326
324
WalkResult res = funcOp.walk ([&](func::CallOp callOp) -> WalkResult {
327
- func::FuncOp calledFunction = getCalledFunction (callOp, symbolTable );
325
+ func::FuncOp calledFunction = getCalledFunction (callOp, symbolTables );
328
326
assert (calledFunction && " could not retrieved called func::FuncOp" );
329
327
// If the called function does not have any tensors in its signature, then
330
328
// it is not necessary to bufferize the callee before the caller.
@@ -458,7 +456,8 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
458
456
FuncCallerMap callerMap;
459
457
460
458
if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps,
461
- remainingFuncOps, callerMap)))
459
+ remainingFuncOps, callerMap,
460
+ funcState.symbolTables )))
462
461
return failure ();
463
462
464
463
// Analyze functions in order. Starting with functions that are not calling
@@ -534,7 +533,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
534
533
// each other recursively are bufferized in an unspecified order at the end.
535
534
// We may use unnecessarily "complex" (in terms of layout map) buffer types.
536
535
if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps,
537
- remainingFuncOps, callerMap)))
536
+ remainingFuncOps, callerMap,
537
+ state.getSymbolTables ())))
538
538
return failure ();
539
539
llvm::append_range (orderedFuncOps, remainingFuncOps);
540
540
0 commit comments