Skip to content

Commit 66a0b34

Browse files
committed
[Attributor] AAFunctionReachability, Handle CallBase Reachability.
This patch makes it possible to query callbase reachability (Can a callbase reach a function Fn transitively). The patch moves the reachability query handling logic to a member class, this class will have more users within the AA once we add other function reachability queries. Reviewed By: jdoerfert Differential Revision: https://reviews.llvm.org/D106402
1 parent 2cc6f7c commit 66a0b34

File tree

3 files changed

+157
-69
lines changed

3 files changed

+157
-69
lines changed

llvm/include/llvm/Transforms/IPO/Attributor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4454,6 +4454,9 @@ struct AAFunctionReachability
44544454
/// If the function represented by this possition can reach \p Fn.
44554455
virtual bool canReach(Attributor &A, Function *Fn) const = 0;
44564456

4457+
/// Can \p CB reach \p Fn
4458+
virtual bool canReach(Attributor &A, CallBase &CB, Function *Fn) const = 0;
4459+
44574460
/// Create an abstract attribute view for the position \p IRP.
44584461
static AAFunctionReachability &createForPosition(const IRPosition &IRP,
44594462
Attributor &A);

llvm/lib/Transforms/IPO/AttributorAttributes.cpp

Lines changed: 134 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -9496,115 +9496,180 @@ struct AACallEdgesFunction : public AACallEdgesImpl {
94969496
};
94979497

94989498
struct AAFunctionReachabilityFunction : public AAFunctionReachability {
9499-
AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A)
9500-
: AAFunctionReachability(IRP, A) {}
9499+
private:
9500+
struct QuerySet {
9501+
void markReachable(Function *Fn) {
9502+
Reachable.insert(Fn);
9503+
Unreachable.erase(Fn);
9504+
}
9505+
9506+
ChangeStatus update(Attributor &A, const AAFunctionReachability &AA,
9507+
ArrayRef<const AACallEdges *> AAEdgesList) {
9508+
ChangeStatus Change = ChangeStatus::UNCHANGED;
9509+
9510+
for (auto *AAEdges : AAEdgesList) {
9511+
if (AAEdges->hasUnknownCallee()) {
9512+
if (!CanReachUnknownCallee)
9513+
Change = ChangeStatus::CHANGED;
9514+
CanReachUnknownCallee = true;
9515+
return Change;
9516+
}
9517+
}
95019518

9502-
bool canReach(Attributor &A, Function *Fn) const override {
9503-
// Assume that we can reach any function if we can reach a call with
9504-
// unknown callee.
9505-
if (CanReachUnknownCallee)
9506-
return true;
9519+
for (Function *Fn : make_early_inc_range(Unreachable)) {
9520+
if (checkIfReachable(A, AA, AAEdgesList, Fn)) {
9521+
Change = ChangeStatus::CHANGED;
9522+
markReachable(Fn);
9523+
}
9524+
}
9525+
return Change;
9526+
}
95079527

9508-
if (ReachableQueries.count(Fn))
9509-
return true;
9528+
bool isReachable(Attributor &A, const AAFunctionReachability &AA,
9529+
ArrayRef<const AACallEdges *> AAEdgesList, Function *Fn) {
9530+
// Assume that we can reach the function.
9531+
// TODO: Be more specific with the unknown callee.
9532+
if (CanReachUnknownCallee)
9533+
return true;
9534+
9535+
if (Reachable.count(Fn))
9536+
return true;
9537+
9538+
if (Unreachable.count(Fn))
9539+
return false;
9540+
9541+
// We need to assume that this function can't reach Fn to prevent
9542+
// an infinite loop if this function is recursive.
9543+
Unreachable.insert(Fn);
9544+
9545+
bool Result = checkIfReachable(A, AA, AAEdgesList, Fn);
9546+
if (Result)
9547+
markReachable(Fn);
9548+
return Result;
9549+
}
9550+
9551+
bool checkIfReachable(Attributor &A, const AAFunctionReachability &AA,
9552+
ArrayRef<const AACallEdges *> AAEdgesList,
9553+
Function *Fn) const {
9554+
9555+
// Handle the most trivial case first.
9556+
for (auto *AAEdges : AAEdgesList) {
9557+
const SetVector<Function *> &Edges = AAEdges->getOptimisticEdges();
9558+
9559+
if (Edges.count(Fn))
9560+
return true;
9561+
}
9562+
9563+
SmallVector<const AAFunctionReachability *, 8> Deps;
9564+
for (auto &AAEdges : AAEdgesList) {
9565+
const SetVector<Function *> &Edges = AAEdges->getOptimisticEdges();
9566+
9567+
for (Function *Edge : Edges) {
9568+
// We don't need a dependency if the result is reachable.
9569+
const AAFunctionReachability &EdgeReachability =
9570+
A.getAAFor<AAFunctionReachability>(
9571+
AA, IRPosition::function(*Edge), DepClassTy::NONE);
9572+
Deps.push_back(&EdgeReachability);
9573+
9574+
if (EdgeReachability.canReach(A, Fn))
9575+
return true;
9576+
}
9577+
}
9578+
9579+
// The result is false for now, set dependencies and leave.
9580+
for (auto Dep : Deps)
9581+
A.recordDependence(AA, *Dep, DepClassTy::REQUIRED);
95109582

9511-
if (UnreachableQueries.count(Fn))
95129583
return false;
9584+
}
9585+
9586+
/// Set of functions that we know for sure is reachable.
9587+
DenseSet<Function *> Reachable;
9588+
9589+
/// Set of functions that are unreachable, but might become reachable.
9590+
DenseSet<Function *> Unreachable;
95139591

9592+
/// If we can reach a function with a call to a unknown function we assume
9593+
/// that we can reach any function.
9594+
bool CanReachUnknownCallee = false;
9595+
};
9596+
9597+
public:
9598+
AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A)
9599+
: AAFunctionReachability(IRP, A) {}
9600+
9601+
bool canReach(Attributor &A, Function *Fn) const override {
95149602
const AACallEdges &AAEdges =
95159603
A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::REQUIRED);
95169604

9517-
const SetVector<Function *> &Edges = AAEdges.getOptimisticEdges();
9518-
bool Result = checkIfReachable(A, Edges, Fn);
9605+
// Attributor returns attributes as const, so this function has to be
9606+
// const for users of this attribute to use it without having to do
9607+
// a const_cast.
9608+
// This is a hack for us to be able to cache queries.
9609+
auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
9610+
bool Result =
9611+
NonConstThis->WholeFunction.isReachable(A, *this, {&AAEdges}, Fn);
9612+
9613+
return Result;
9614+
}
9615+
9616+
/// Can \p CB reach \p Fn
9617+
bool canReach(Attributor &A, CallBase &CB, Function *Fn) const override {
9618+
const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
9619+
*this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
95199620

95209621
// Attributor returns attributes as const, so this function has to be
95219622
// const for users of this attribute to use it without having to do
95229623
// a const_cast.
95239624
// This is a hack for us to be able to cache queries.
95249625
auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
9626+
QuerySet &CBQuery = NonConstThis->CBQueries[&CB];
95259627

9526-
if (Result)
9527-
NonConstThis->ReachableQueries.insert(Fn);
9528-
else
9529-
NonConstThis->UnreachableQueries.insert(Fn);
9628+
bool Result = CBQuery.isReachable(A, *this, {&AAEdges}, Fn);
95309629

95319630
return Result;
95329631
}
95339632

95349633
/// See AbstractAttribute::updateImpl(...).
95359634
ChangeStatus updateImpl(Attributor &A) override {
9536-
if (CanReachUnknownCallee)
9537-
return ChangeStatus::UNCHANGED;
9538-
95399635
const AACallEdges &AAEdges =
95409636
A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::REQUIRED);
9541-
const SetVector<Function *> &Edges = AAEdges.getOptimisticEdges();
95429637
ChangeStatus Change = ChangeStatus::UNCHANGED;
95439638

9544-
if (AAEdges.hasUnknownCallee()) {
9545-
bool OldCanReachUnknown = CanReachUnknownCallee;
9546-
CanReachUnknownCallee = true;
9547-
return OldCanReachUnknown ? ChangeStatus::UNCHANGED
9548-
: ChangeStatus::CHANGED;
9549-
}
9639+
Change |= WholeFunction.update(A, *this, {&AAEdges});
95509640

9551-
// Check if any of the unreachable functions become reachable.
9552-
for (auto Current = UnreachableQueries.begin();
9553-
Current != UnreachableQueries.end();) {
9554-
if (!checkIfReachable(A, Edges, *Current)) {
9555-
Current++;
9556-
continue;
9557-
}
9558-
ReachableQueries.insert(*Current);
9559-
UnreachableQueries.erase(*Current++);
9560-
Change = ChangeStatus::CHANGED;
9641+
for (auto CBPair : CBQueries) {
9642+
const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
9643+
*this, IRPosition::callsite_function(*CBPair.first),
9644+
DepClassTy::REQUIRED);
9645+
9646+
Change |= CBPair.second.update(A, *this, {&AAEdges});
95619647
}
95629648

95639649
return Change;
95649650
}
95659651

95669652
const std::string getAsStr() const override {
9567-
size_t QueryCount = ReachableQueries.size() + UnreachableQueries.size();
9653+
size_t QueryCount =
9654+
WholeFunction.Reachable.size() + WholeFunction.Unreachable.size();
95689655

9569-
return "FunctionReachability [" + std::to_string(ReachableQueries.size()) +
9570-
"," + std::to_string(QueryCount) + "]";
9656+
return "FunctionReachability [" +
9657+
std::to_string(WholeFunction.Reachable.size()) + "," +
9658+
std::to_string(QueryCount) + "]";
95719659
}
95729660

95739661
void trackStatistics() const override {}
9574-
95759662
private:
9576-
bool canReachUnknownCallee() const override { return CanReachUnknownCallee; }
9577-
9578-
bool checkIfReachable(Attributor &A, const SetVector<Function *> &Edges,
9579-
Function *Fn) const {
9580-
if (Edges.count(Fn))
9581-
return true;
9582-
9583-
for (Function *Edge : Edges) {
9584-
// We don't need a dependency if the result is reachable.
9585-
const AAFunctionReachability &EdgeReachability =
9586-
A.getAAFor<AAFunctionReachability>(*this, IRPosition::function(*Edge),
9587-
DepClassTy::NONE);
9588-
9589-
if (EdgeReachability.canReach(A, Fn))
9590-
return true;
9591-
}
9592-
for (Function *Fn : Edges)
9593-
A.getAAFor<AAFunctionReachability>(*this, IRPosition::function(*Fn),
9594-
DepClassTy::REQUIRED);
9595-
9596-
return false;
9663+
bool canReachUnknownCallee() const override {
9664+
return WholeFunction.CanReachUnknownCallee;
95979665
}
95989666

9599-
/// Set of functions that we know for sure is reachable.
9600-
SmallPtrSet<Function *, 8> ReachableQueries;
9601-
9602-
/// Set of functions that are unreachable, but might become reachable.
9603-
SmallPtrSet<Function *, 8> UnreachableQueries;
9667+
/// Used to answer if a the whole function can reacha a specific function.
9668+
QuerySet WholeFunction;
96049669

9605-
/// If we can reach a function with a call to a unknown function we assume
9606-
/// that we can reach any function.
9607-
bool CanReachUnknownCallee = false;
9670+
/// Used to answer if a call base inside this function can reach a specific
9671+
/// function.
9672+
DenseMap<CallBase *, QuerySet> CBQueries;
96089673
};
96099674

96109675
} // namespace

llvm/unittests/Transforms/IPO/AttributorTest.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,13 @@ TEST_F(AttributorTestBase, AAReachabilityTest) {
109109
call void @func5(void ()* @func3)
110110
ret void
111111
}
112+
113+
define void @func7() {
114+
entry:
115+
call void @func2()
116+
call void @func4()
117+
ret void
118+
}
112119
)";
113120

114121
Module &M = parseModule(ModuleString);
@@ -127,22 +134,35 @@ TEST_F(AttributorTestBase, AAReachabilityTest) {
127134
Function *F3 = M.getFunction("func3");
128135
Function *F4 = M.getFunction("func4");
129136
Function *F6 = M.getFunction("func6");
137+
Function *F7 = M.getFunction("func7");
138+
139+
// call void @func2()
140+
CallBase &F7FirstCB =
141+
*static_cast<CallBase *>(F7->getEntryBlock().getFirstNonPHI());
130142

131143
const AAFunctionReachability &F1AA =
132144
A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F1));
133145

134146
const AAFunctionReachability &F6AA =
135147
A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F6));
136148

149+
const AAFunctionReachability &F7AA =
150+
A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F7));
151+
137152
F1AA.canReach(A, F3);
138153
F1AA.canReach(A, F4);
139154
F6AA.canReach(A, F4);
155+
F7AA.canReach(A, F7FirstCB, F3);
156+
F7AA.canReach(A, F7FirstCB, F4);
140157

141158
A.run();
142159

143160
ASSERT_TRUE(F1AA.canReach(A, F3));
144161
ASSERT_FALSE(F1AA.canReach(A, F4));
145162

163+
ASSERT_TRUE(F7AA.canReach(A, F7FirstCB, F3));
164+
ASSERT_FALSE(F7AA.canReach(A, F7FirstCB, F4));
165+
146166
// Assumed to be reacahable, since F6 can reach a function with
147167
// a unknown callee.
148168
ASSERT_TRUE(F6AA.canReach(A, F4));

0 commit comments

Comments
 (0)