@@ -33,7 +33,7 @@ using namespace mlir;
33
33
34
34
// / Walk all of the used symbol callgraph nodes referenced with the given op.
35
35
static void walkReferencedSymbolNodes (
36
- Operation *op, CallGraph &cg,
36
+ Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
37
37
DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
38
38
function_ref<void (CallGraphNode *, Operation *)> callback) {
39
39
auto symbolUses = SymbolTable::getSymbolUses (op);
@@ -47,8 +47,8 @@ static void walkReferencedSymbolNodes(
47
47
// If this is the first instance of this reference, try to resolve a
48
48
// callgraph node for it.
49
49
if (refIt.second ) {
50
- auto *symbolOp = SymbolTable:: lookupNearestSymbolFrom (symbolTableOp,
51
- use.getSymbolRef ());
50
+ auto *symbolOp = symbolTable. lookupNearestSymbolFrom (symbolTableOp,
51
+ use.getSymbolRef ());
52
52
auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
53
53
if (!callableOp)
54
54
continue ;
@@ -80,7 +80,7 @@ struct CGUseList {
80
80
DenseMap<CallGraphNode *, int > innerUses;
81
81
};
82
82
83
- CGUseList (Operation *op, CallGraph &cg);
83
+ CGUseList (Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable );
84
84
85
85
// / Drop uses of nodes referred to by the given call operation that resides
86
86
// / within 'userNode'.
@@ -110,13 +110,19 @@ struct CGUseList {
110
110
// / A mapping between a discardable callgraph node (that is a symbol) and the
111
111
// / number of uses for this node.
112
112
DenseMap<CallGraphNode *, int > discardableSymNodeUses;
113
+
113
114
// / A mapping between a callgraph node and the symbol callgraph nodes that it
114
115
// / uses.
115
116
DenseMap<CallGraphNode *, CGUser> nodeUses;
117
+
118
+ // / A symbol table to use when resolving call lookups.
119
+ SymbolTableCollection &symbolTable;
116
120
};
117
121
} // end anonymous namespace
118
122
119
- CGUseList::CGUseList (Operation *op, CallGraph &cg) {
123
+ CGUseList::CGUseList (Operation *op, CallGraph &cg,
124
+ SymbolTableCollection &symbolTable)
125
+ : symbolTable(symbolTable) {
120
126
// / A set of callgraph nodes that are always known to be live during inlining.
121
127
DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
122
128
@@ -135,7 +141,7 @@ CGUseList::CGUseList(Operation *op, CallGraph &cg) {
135
141
}
136
142
}
137
143
// Otherwise, check for any referenced nodes. These will be always-live.
138
- walkReferencedSymbolNodes (&op, cg, alwaysLiveNodes,
144
+ walkReferencedSymbolNodes (&op, cg, symbolTable, alwaysLiveNodes,
139
145
[](CallGraphNode *, Operation *) {});
140
146
}
141
147
};
@@ -162,7 +168,7 @@ void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
162
168
--discardableSymNodeUses[node];
163
169
};
164
170
DenseMap<Attribute, CallGraphNode *> resolvedRefs;
165
- walkReferencedSymbolNodes (callOp, cg, resolvedRefs, walkFn);
171
+ walkReferencedSymbolNodes (callOp, cg, symbolTable, resolvedRefs, walkFn);
166
172
}
167
173
168
174
void CGUseList::eraseNode (CallGraphNode *node) {
@@ -220,7 +226,7 @@ void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
220
226
return ;
221
227
++discardSymIt->second ;
222
228
};
223
- walkReferencedSymbolNodes (parentOp, cg, resolvedRefs, walkFn);
229
+ walkReferencedSymbolNodes (parentOp, cg, symbolTable, resolvedRefs, walkFn);
224
230
}
225
231
226
232
void CGUseList::mergeUsesAfterInlining (CallGraphNode *lhs, CallGraphNode *rhs) {
@@ -305,6 +311,7 @@ struct ResolvedCall {
305
311
// / inside of nested callgraph nodes.
306
312
static void collectCallOps (iterator_range<Region::iterator> blocks,
307
313
CallGraphNode *sourceNode, CallGraph &cg,
314
+ SymbolTableCollection &symbolTable,
308
315
SmallVectorImpl<ResolvedCall> &calls,
309
316
bool traverseNestedCGNodes) {
310
317
SmallVector<std::pair<Block *, CallGraphNode *>, 8 > worklist;
@@ -328,7 +335,7 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
328
335
continue ;
329
336
}
330
337
331
- CallGraphNode *targetNode = cg.resolveCallable (call);
338
+ CallGraphNode *targetNode = cg.resolveCallable (call, symbolTable );
332
339
if (!targetNode->isExternal ())
333
340
calls.emplace_back (call, sourceNode, targetNode);
334
341
continue ;
@@ -352,8 +359,9 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
352
359
namespace {
353
360
// / This class provides a specialization of the main inlining interface.
354
361
struct Inliner : public InlinerInterface {
355
- Inliner (MLIRContext *context, CallGraph &cg)
356
- : InlinerInterface(context), cg(cg) {}
362
+ Inliner (MLIRContext *context, CallGraph &cg,
363
+ SymbolTableCollection &symbolTable)
364
+ : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
357
365
358
366
// / Process a set of blocks that have been inlined. This callback is invoked
359
367
// / *before* inlined terminator operations have been processed.
@@ -367,7 +375,7 @@ struct Inliner : public InlinerInterface {
367
375
assert (region && " expected valid parent node" );
368
376
}
369
377
370
- collectCallOps (inlinedBlocks, node, cg, calls,
378
+ collectCallOps (inlinedBlocks, node, cg, symbolTable, calls,
371
379
/* traverseNestedCGNodes=*/ true );
372
380
}
373
381
@@ -389,6 +397,9 @@ struct Inliner : public InlinerInterface {
389
397
390
398
// / The callgraph being operated on.
391
399
CallGraph &cg;
400
+
401
+ // / A symbol table to use when resolving call lookups.
402
+ SymbolTableCollection &symbolTable;
392
403
};
393
404
} // namespace
394
405
@@ -427,11 +438,12 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
427
438
continue ;
428
439
429
440
// Don't collect calls if the node is already dead.
430
- if (useList.isDead (node))
441
+ if (useList.isDead (node)) {
431
442
deadNodes.push_back (node);
432
- else
433
- collectCallOps (*node->getCallableRegion (), node, cg, calls,
434
- /* traverseNestedCGNodes=*/ false );
443
+ } else {
444
+ collectCallOps (*node->getCallableRegion (), node, cg, inliner.symbolTable ,
445
+ calls, /* traverseNestedCGNodes=*/ false );
446
+ }
435
447
}
436
448
437
449
// Try to inline each of the call operations. Don't cache the end iterator
@@ -585,8 +597,9 @@ void InlinerPass::runOnOperation() {
585
597
op->getCanonicalizationPatterns (canonPatterns, context);
586
598
587
599
// Run the inline transform in post-order over the SCCs in the callgraph.
588
- Inliner inliner (context, cg);
589
- CGUseList useList (getOperation (), cg);
600
+ SymbolTableCollection symbolTable;
601
+ Inliner inliner (context, cg, symbolTable);
602
+ CGUseList useList (getOperation (), cg, symbolTable);
590
603
runTransformOnCGSCCs (cg, [&](CallGraphSCC &scc) {
591
604
inlineSCC (inliner, useList, scc, context, canonPatterns);
592
605
});
0 commit comments