Skip to content

Commit 5c159b9

Browse files
committed
[mlir] Add a utility method on CallOpInterface for resolving the callable.
Summary: This is the most common operation performed on a CallOpInterface. This just moves the existing functionality from the CallGraph so that other users can access it. Differential Revision: https://reviews.llvm.org/D74250
1 parent d4fbf83 commit 5c159b9

File tree

5 files changed

+25
-31
lines changed

5 files changed

+25
-31
lines changed

mlir/include/mlir/Analysis/CallGraph.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "llvm/ADT/SetVector.h"
2424

2525
namespace mlir {
26+
class CallOpInterface;
2627
struct CallInterfaceCallable;
2728
class Operation;
2829
class Region;
@@ -188,11 +189,8 @@ class CallGraph {
188189
}
189190

190191
/// Resolve the callable for given callee to a node in the callgraph, or the
191-
/// external node if a valid node was not resolved. 'from' provides an anchor
192-
/// for symbol table lookups, and is only required if the callable is a symbol
193-
/// reference.
194-
CallGraphNode *resolveCallable(CallInterfaceCallable callable,
195-
Operation *from = nullptr) const;
192+
/// external node if a valid node was not resolved.
193+
CallGraphNode *resolveCallable(CallOpInterface call) const;
196194

197195
/// An iterator over the nodes of the graph.
198196
using iterator = NodeIterator;

mlir/include/mlir/Analysis/CallInterfaces.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,10 @@
1414
#ifndef MLIR_ANALYSIS_CALLINTERFACES_H
1515
#define MLIR_ANALYSIS_CALLINTERFACES_H
1616

17-
#include "mlir/IR/OpDefinition.h"
17+
#include "mlir/IR/SymbolTable.h"
1818
#include "llvm/ADT/PointerUnion.h"
1919

2020
namespace mlir {
21-
2221
/// A callable is either a symbol, or an SSA value, that is referenced by a
2322
/// call-like operation. This represents the destination of the call.
2423
struct CallInterfaceCallable : public PointerUnion<SymbolRefAttr, Value> {

mlir/include/mlir/Analysis/CallInterfaces.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
4444
}],
4545
"Operation::operand_range", "getArgOperands"
4646
>,
47+
InterfaceMethod<[{
48+
Resolve the callable operation for given callee to a
49+
CallableOpInterface, or nullptr if a valid callable was not resolved.
50+
}],
51+
"Operation *", "resolveCallable", (ins), [{
52+
// If the callable isn't a value, lookup the symbol reference.
53+
CallInterfaceCallable callable = op.getCallableForCallee();
54+
if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
55+
return SymbolTable::lookupNearestSymbolFrom(op, symbolRef);
56+
return callable.get<Value>().getDefiningOp();
57+
}]
58+
>,
4759
];
4860
}
4961

mlir/lib/Analysis/CallGraph.cpp

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,8 @@ static void computeCallGraph(Operation *op, CallGraph &cg,
7979
// If there is no parent node, we ignore this operation. Even if this
8080
// operation was a call, there would be no callgraph node to attribute it
8181
// to.
82-
if (!resolveCalls || !parentNode)
83-
return;
84-
parentNode->addCallEdge(
85-
cg.resolveCallable(call.getCallableForCallee(), op));
82+
if (resolveCalls && parentNode)
83+
parentNode->addCallEdge(cg.resolveCallable(call));
8684
return;
8785
}
8886

@@ -141,23 +139,11 @@ CallGraphNode *CallGraph::lookupNode(Region *region) const {
141139

142140
/// Resolve the callable for given callee to a node in the callgraph, or the
143141
/// external node if a valid node was not resolved.
144-
CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
145-
Operation *from) const {
146-
// Get the callee operation from the callable.
147-
Operation *callee;
148-
if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
149-
callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef);
150-
else
151-
callee = callable.get<Value>().getDefiningOp();
152-
153-
// If the callee is non-null and is a valid callable object, try to get the
154-
// called region from it.
155-
if (callee && callee->getNumRegions()) {
156-
if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callee)) {
157-
if (auto *node = lookupNode(callableOp.getCallableRegion()))
158-
return node;
159-
}
160-
}
142+
CallGraphNode *CallGraph::resolveCallable(CallOpInterface call) const {
143+
Operation *callable = call.resolveCallable();
144+
if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
145+
if (auto *node = lookupNode(callableOp.getCallableRegion()))
146+
return node;
161147

162148
// If we don't have a valid direct region, this is an external call.
163149
return getExternalNode();

mlir/lib/Transforms/Inliner.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,14 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
8686
while (!worklist.empty()) {
8787
for (Operation &op : *worklist.pop_back_val()) {
8888
if (auto call = dyn_cast<CallOpInterface>(op)) {
89-
CallInterfaceCallable callable = call.getCallableForCallee();
90-
9189
// TODO(riverriddle) Support inlining nested call references.
90+
CallInterfaceCallable callable = call.getCallableForCallee();
9291
if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
9392
if (!symRef.isa<FlatSymbolRefAttr>())
9493
continue;
9594
}
9695

97-
CallGraphNode *node = cg.resolveCallable(callable, &op);
96+
CallGraphNode *node = cg.resolveCallable(call);
9897
if (!node->isExternal())
9998
calls.emplace_back(call, node);
10099
continue;

0 commit comments

Comments
 (0)