Skip to content

Commit f986d6c

Browse files
author
git apple-llvm automerger
committed
Merge commit '32a9d8bddbf4' from apple/main into swift/next
2 parents 0e1a3ca + 32a9d8b commit f986d6c

File tree

6 files changed

+91
-46
lines changed

6 files changed

+91
-46
lines changed

mlir/include/mlir/Analysis/CallGraph.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class CallOpInterface;
2727
struct CallInterfaceCallable;
2828
class Operation;
2929
class Region;
30+
class SymbolTableCollection;
3031

3132
//===----------------------------------------------------------------------===//
3233
// CallGraphNode
@@ -189,8 +190,11 @@ class CallGraph {
189190
}
190191

191192
/// Resolve the callable for given callee to a node in the callgraph, or the
192-
/// external node if a valid node was not resolved.
193-
CallGraphNode *resolveCallable(CallOpInterface call) const;
193+
/// external node if a valid node was not resolved. The provided symbol table
194+
/// is used when resolving calls that reference callables via a symbol
195+
/// reference.
196+
CallGraphNode *resolveCallable(CallOpInterface call,
197+
SymbolTableCollection &symbolTable) const;
194198

195199
/// Erase the given node from the callgraph.
196200
void eraseNode(CallGraphNode *node);

mlir/include/mlir/Interfaces/CallInterfaces.td

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,16 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
4646
}],
4747
"Operation::operand_range", "getArgOperands"
4848
>,
49-
InterfaceMethod<[{
50-
Resolve the callable operation for given callee to a
51-
CallableOpInterface, or nullptr if a valid callable was not resolved.
52-
}],
53-
"Operation *", "resolveCallable", (ins), [{
54-
// If the callable isn't a value, lookup the symbol reference.
55-
CallInterfaceCallable callable = $_op.getCallableForCallee();
56-
if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
57-
return SymbolTable::lookupNearestSymbolFrom($_op, symbolRef);
58-
return callable.get<Value>().getDefiningOp();
59-
}]
60-
>,
6149
];
50+
51+
let extraClassDeclaration = [{
52+
/// Resolve the callable operation for given callee to a
53+
/// CallableOpInterface, or nullptr if a valid callable was not resolved.
54+
/// `symbolTable` is an optional parameter that will allow for using a
55+
/// cached symbol table for symbol lookups instead of performing an O(N)
56+
/// scan.
57+
Operation *resolveCallable(SymbolTableCollection *symbolTable = nullptr);
58+
}];
6259
}
6360

6461
/// Interface for callable operations.

mlir/lib/Analysis/CallGraph.cpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,14 @@ void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
6868
/// Recursively compute the callgraph edges for the given operation. Computed
6969
/// edges are placed into the given callgraph object.
7070
static void computeCallGraph(Operation *op, CallGraph &cg,
71+
SymbolTableCollection &symbolTable,
7172
CallGraphNode *parentNode, bool resolveCalls) {
7273
if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
7374
// If there is no parent node, we ignore this operation. Even if this
7475
// operation was a call, there would be no callgraph node to attribute it
7576
// to.
7677
if (resolveCalls && parentNode)
77-
parentNode->addCallEdge(cg.resolveCallable(call));
78+
parentNode->addCallEdge(cg.resolveCallable(call, symbolTable));
7879
return;
7980
}
8081

@@ -88,15 +89,18 @@ static void computeCallGraph(Operation *op, CallGraph &cg,
8889

8990
for (Region &region : op->getRegions())
9091
for (Operation &nested : region.getOps())
91-
computeCallGraph(&nested, cg, parentNode, resolveCalls);
92+
computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls);
9293
}
9394

9495
CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
9596
// Make two passes over the graph, one to compute the callables and one to
9697
// resolve the calls. We split these up as we may have nested callable objects
9798
// that need to be reserved before the calls.
98-
computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/false);
99-
computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/true);
99+
SymbolTableCollection symbolTable;
100+
computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
101+
/*resolveCalls=*/false);
102+
computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
103+
/*resolveCalls=*/true);
100104
}
101105

102106
/// Get or add a call graph node for the given region.
@@ -109,16 +113,17 @@ CallGraphNode *CallGraph::getOrAddNode(Region *region,
109113
node.reset(new CallGraphNode(region));
110114

111115
// Add this node to the given parent node if necessary.
112-
if (parentNode)
116+
if (parentNode) {
113117
parentNode->addChildEdge(node.get());
114-
else
118+
} else {
115119
// Otherwise, connect all callable nodes to the external node, this allows
116120
// for conservatively including all callable nodes within the graph.
117-
// FIXME(riverriddle) This isn't correct, this is only necessary for
118-
// callable nodes that *could* be called from external sources. This
119-
// requires extending the interface for callables to check if they may be
120-
// referenced externally.
121+
// FIXME This isn't correct, this is only necessary for callable nodes
122+
// that *could* be called from external sources. This requires extending
123+
// the interface for callables to check if they may be referenced
124+
// externally.
121125
externalNode.addAbstractEdge(node.get());
126+
}
122127
}
123128
return node.get();
124129
}
@@ -132,8 +137,10 @@ CallGraphNode *CallGraph::lookupNode(Region *region) const {
132137

133138
/// Resolve the callable for given callee to a node in the callgraph, or the
134139
/// external node if a valid node was not resolved.
135-
CallGraphNode *CallGraph::resolveCallable(CallOpInterface call) const {
136-
Operation *callable = call.resolveCallable();
140+
CallGraphNode *
141+
CallGraph::resolveCallable(CallOpInterface call,
142+
SymbolTableCollection &symbolTable) const {
143+
Operation *callable = call.resolveCallable(&symbolTable);
137144
if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
138145
if (auto *node = lookupNode(callableOp.getCallableRegion()))
139146
return node;

mlir/lib/Interfaces/CallInterfaces.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,27 @@
1010

1111
using namespace mlir;
1212

13+
//===----------------------------------------------------------------------===//
14+
// CallOpInterface
15+
//===----------------------------------------------------------------------===//
16+
17+
/// Resolve the callable operation for given callee to a CallableOpInterface, or
18+
/// nullptr if a valid callable was not resolved. `symbolTable` is an optional
19+
/// parameter that will allow for using a cached symbol table for symbol lookups
20+
/// instead of performing an O(N) scan.
21+
Operation *
22+
CallOpInterface::resolveCallable(SymbolTableCollection *symbolTable) {
23+
CallInterfaceCallable callable = getCallableForCallee();
24+
if (auto symbolVal = callable.dyn_cast<Value>())
25+
return symbolVal.getDefiningOp();
26+
27+
// If the callable isn't a value, lookup the symbol reference.
28+
auto symbolRef = callable.get<SymbolRefAttr>();
29+
if (symbolTable)
30+
return symbolTable->lookupNearestSymbolFrom(getOperation(), symbolRef);
31+
return SymbolTable::lookupNearestSymbolFrom(getOperation(), symbolRef);
32+
}
33+
1334
//===----------------------------------------------------------------------===//
1435
// CallInterfaces
1536
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/Inliner.cpp

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ using namespace mlir;
3333

3434
/// Walk all of the used symbol callgraph nodes referenced with the given op.
3535
static void walkReferencedSymbolNodes(
36-
Operation *op, CallGraph &cg,
36+
Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
3737
DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
3838
function_ref<void(CallGraphNode *, Operation *)> callback) {
3939
auto symbolUses = SymbolTable::getSymbolUses(op);
@@ -47,8 +47,8 @@ static void walkReferencedSymbolNodes(
4747
// If this is the first instance of this reference, try to resolve a
4848
// callgraph node for it.
4949
if (refIt.second) {
50-
auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(symbolTableOp,
51-
use.getSymbolRef());
50+
auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
51+
use.getSymbolRef());
5252
auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
5353
if (!callableOp)
5454
continue;
@@ -80,7 +80,7 @@ struct CGUseList {
8080
DenseMap<CallGraphNode *, int> innerUses;
8181
};
8282

83-
CGUseList(Operation *op, CallGraph &cg);
83+
CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
8484

8585
/// Drop uses of nodes referred to by the given call operation that resides
8686
/// within 'userNode'.
@@ -110,13 +110,19 @@ struct CGUseList {
110110
/// A mapping between a discardable callgraph node (that is a symbol) and the
111111
/// number of uses for this node.
112112
DenseMap<CallGraphNode *, int> discardableSymNodeUses;
113+
113114
/// A mapping between a callgraph node and the symbol callgraph nodes that it
114115
/// uses.
115116
DenseMap<CallGraphNode *, CGUser> nodeUses;
117+
118+
/// A symbol table to use when resolving call lookups.
119+
SymbolTableCollection &symbolTable;
116120
};
117121
} // end anonymous namespace
118122

119-
CGUseList::CGUseList(Operation *op, CallGraph &cg) {
123+
CGUseList::CGUseList(Operation *op, CallGraph &cg,
124+
SymbolTableCollection &symbolTable)
125+
: symbolTable(symbolTable) {
120126
/// A set of callgraph nodes that are always known to be live during inlining.
121127
DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
122128

@@ -135,7 +141,7 @@ CGUseList::CGUseList(Operation *op, CallGraph &cg) {
135141
}
136142
}
137143
// Otherwise, check for any referenced nodes. These will be always-live.
138-
walkReferencedSymbolNodes(&op, cg, alwaysLiveNodes,
144+
walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
139145
[](CallGraphNode *, Operation *) {});
140146
}
141147
};
@@ -162,7 +168,7 @@ void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
162168
--discardableSymNodeUses[node];
163169
};
164170
DenseMap<Attribute, CallGraphNode *> resolvedRefs;
165-
walkReferencedSymbolNodes(callOp, cg, resolvedRefs, walkFn);
171+
walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
166172
}
167173

168174
void CGUseList::eraseNode(CallGraphNode *node) {
@@ -220,7 +226,7 @@ void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
220226
return;
221227
++discardSymIt->second;
222228
};
223-
walkReferencedSymbolNodes(parentOp, cg, resolvedRefs, walkFn);
229+
walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
224230
}
225231

226232
void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
@@ -305,6 +311,7 @@ struct ResolvedCall {
305311
/// inside of nested callgraph nodes.
306312
static void collectCallOps(iterator_range<Region::iterator> blocks,
307313
CallGraphNode *sourceNode, CallGraph &cg,
314+
SymbolTableCollection &symbolTable,
308315
SmallVectorImpl<ResolvedCall> &calls,
309316
bool traverseNestedCGNodes) {
310317
SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
@@ -328,7 +335,7 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
328335
continue;
329336
}
330337

331-
CallGraphNode *targetNode = cg.resolveCallable(call);
338+
CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
332339
if (!targetNode->isExternal())
333340
calls.emplace_back(call, sourceNode, targetNode);
334341
continue;
@@ -352,8 +359,9 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
352359
namespace {
353360
/// This class provides a specialization of the main inlining interface.
354361
struct Inliner : public InlinerInterface {
355-
Inliner(MLIRContext *context, CallGraph &cg)
356-
: InlinerInterface(context), cg(cg) {}
362+
Inliner(MLIRContext *context, CallGraph &cg,
363+
SymbolTableCollection &symbolTable)
364+
: InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
357365

358366
/// Process a set of blocks that have been inlined. This callback is invoked
359367
/// *before* inlined terminator operations have been processed.
@@ -367,7 +375,7 @@ struct Inliner : public InlinerInterface {
367375
assert(region && "expected valid parent node");
368376
}
369377

370-
collectCallOps(inlinedBlocks, node, cg, calls,
378+
collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
371379
/*traverseNestedCGNodes=*/true);
372380
}
373381

@@ -389,6 +397,9 @@ struct Inliner : public InlinerInterface {
389397

390398
/// The callgraph being operated on.
391399
CallGraph &cg;
400+
401+
/// A symbol table to use when resolving call lookups.
402+
SymbolTableCollection &symbolTable;
392403
};
393404
} // namespace
394405

@@ -427,11 +438,12 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
427438
continue;
428439

429440
// Don't collect calls if the node is already dead.
430-
if (useList.isDead(node))
441+
if (useList.isDead(node)) {
431442
deadNodes.push_back(node);
432-
else
433-
collectCallOps(*node->getCallableRegion(), node, cg, calls,
434-
/*traverseNestedCGNodes=*/false);
443+
} else {
444+
collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
445+
calls, /*traverseNestedCGNodes=*/false);
446+
}
435447
}
436448

437449
// Try to inline each of the call operations. Don't cache the end iterator
@@ -585,8 +597,9 @@ void InlinerPass::runOnOperation() {
585597
op->getCanonicalizationPatterns(canonPatterns, context);
586598

587599
// Run the inline transform in post-order over the SCCs in the callgraph.
588-
Inliner inliner(context, cg);
589-
CGUseList useList(getOperation(), cg);
600+
SymbolTableCollection symbolTable;
601+
Inliner inliner(context, cg, symbolTable);
602+
CGUseList useList(getOperation(), cg, symbolTable);
590603
runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
591604
inlineSCC(inliner, useList, scc, context, canonPatterns);
592605
});

mlir/lib/Transforms/SCCP.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,9 @@ class SCCPSolver {
304304
/// avoids re-resolving symbol references during propagation. Value based
305305
/// callables are trivial to resolve, so they can be done in-place.
306306
DenseMap<Operation *, Operation *> callToSymbolCallable;
307+
308+
/// A symbol table used for O(1) symbol lookups during simplification.
309+
SymbolTableCollection symbolTable;
307310
};
308311
} // end anonymous namespace
309312

@@ -425,7 +428,7 @@ void SCCPSolver::initializeSymbolCallables(Operation *op) {
425428
// If the use is a call, track it to avoid the need to recompute the
426429
// reference later.
427430
if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) {
428-
Operation *symCallable = callOp.resolveCallable();
431+
Operation *symCallable = callOp.resolveCallable(&symbolTable);
429432
auto callableLatticeIt = callableLatticeState.find(symCallable);
430433
if (callableLatticeIt != callableLatticeState.end()) {
431434
callToSymbolCallable.try_emplace(callOp, symCallable);
@@ -438,7 +441,7 @@ void SCCPSolver::initializeSymbolCallables(Operation *op) {
438441
continue;
439442
}
440443
// This use isn't a call, so don't we know all of the callers.
441-
auto *symbol = SymbolTable::lookupSymbolIn(op, use.getSymbolRef());
444+
auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef());
442445
auto it = callableLatticeState.find(symbol);
443446
if (it != callableLatticeState.end())
444447
markAllOverdefined(it->second.getCallableArguments());

0 commit comments

Comments
 (0)