Skip to content

Commit 9289604

Browse files
authored
[MLIR] Use cached symbol tables in getFuncOpsOrderedByCalls (#141967)
Address TODO regarding the recomputation of symbol tables. The signature of the `getFuncOpsOrderedByCalls` function is modified to receive the collection of cached symbol tables.
1 parent 3374263 commit 9289604

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -310,21 +310,19 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
310310
/// any func::CallOp.
311311
static LogicalResult getFuncOpsOrderedByCalls(
312312
ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
313-
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
313+
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
314+
SymbolTableCollection &symbolTables) {
314315
// For each FuncOp, the set of functions called by it (i.e. the union of
315316
// symbols of all nested func::CallOp).
316317
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
317318
// For each FuncOp, the number of func::CallOp it contains.
318319
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
319320

320-
// TODO Avoid recomputing the symbol tables every time.
321-
mlir::SymbolTableCollection symbolTable;
322-
323321
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
324322
// Collect function calls and populate the caller map.
325323
numberCallOpsContainedInFuncOp[funcOp] = 0;
326324
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
327-
func::FuncOp calledFunction = getCalledFunction(callOp, symbolTable);
325+
func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
328326
assert(calledFunction && "could not retrieved called func::FuncOp");
329327
// If the called function does not have any tensors in its signature, then
330328
// it is not necessary to bufferize the callee before the caller.
@@ -458,7 +456,8 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
458456
FuncCallerMap callerMap;
459457

460458
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
461-
remainingFuncOps, callerMap)))
459+
remainingFuncOps, callerMap,
460+
funcState.symbolTables)))
462461
return failure();
463462

464463
// Analyze functions in order. Starting with functions that are not calling
@@ -534,7 +533,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
534533
// each other recursively are bufferized in an unspecified order at the end.
535534
// We may use unnecessarily "complex" (in terms of layout map) buffer types.
536535
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
537-
remainingFuncOps, callerMap)))
536+
remainingFuncOps, callerMap,
537+
state.getSymbolTables())))
538538
return failure();
539539
llvm::append_range(orderedFuncOps, remainingFuncOps);
540540

0 commit comments

Comments
 (0)