Skip to content

Commit e11ff65

Browse files
committed
[ctx_prof] Add Inlining support
1 parent aaed557 commit e11ff65

File tree

9 files changed

+363
-5
lines changed

9 files changed

+363
-5
lines changed

llvm/include/llvm/Analysis/CtxProfAnalysis.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,16 @@ class PGOContextualProfile {
6363
return getDefinedFunctionGUID(F) != 0;
6464
}
6565

66+
uint32_t getNrCounters(const Function &F) const {
67+
assert(isFunctionKnown(F));
68+
return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCounterIndex;
69+
}
70+
71+
uint32_t getNrCallsites(const Function &F) const {
72+
assert(isFunctionKnown(F));
73+
return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCallsiteIndex;
74+
}
75+
6676
uint32_t allocateNextCounterIndex(const Function &F) {
6777
assert(isFunctionKnown(F));
6878
return FuncInfo.find(getDefinedFunctionGUID(F))->second.NextCounterIndex++;
@@ -113,9 +123,7 @@ class CtxProfAnalysisPrinterPass
113123
: public PassInfoMixin<CtxProfAnalysisPrinterPass> {
114124
public:
115125
enum class PrintMode { Everything, JSON };
116-
explicit CtxProfAnalysisPrinterPass(raw_ostream &OS,
117-
PrintMode Mode = PrintMode::Everything)
118-
: OS(OS), Mode(Mode) {}
126+
explicit CtxProfAnalysisPrinterPass(raw_ostream &OS);
119127

120128
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
121129
static bool isRequired() { return true; }

llvm/include/llvm/IR/IntrinsicInst.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,6 +1516,8 @@ class InstrProfInstBase : public IntrinsicInst {
15161516
return const_cast<Value *>(getArgOperand(0))->stripPointerCasts();
15171517
}
15181518

1519+
void setNameValue(Value *V) { setArgOperand(0, V); }
1520+
15191521
// The hash of the CFG for the instrumented function.
15201522
ConstantInt *getHash() const {
15211523
return cast<ConstantInt>(const_cast<Value *>(getArgOperand(1)));

llvm/include/llvm/ProfileData/PGOCtxProfReader.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ class PGOCtxProfContext final {
7474
Iter->second.emplace(Other.guid(), std::move(Other));
7575
}
7676

77+
void ingestAllContexts(uint32_t CSId, CallTargetMapTy &&Other) {
78+
auto [_, Inserted] = callsites().try_emplace(CSId, std::move(Other));
79+
assert(Inserted);
80+
}
81+
7782
void resizeCounters(uint32_t Size) { Counters.resize(Size); }
7883

7984
bool hasCallsite(uint32_t I) const {

llvm/include/llvm/Transforms/Utils/Cloning.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/ADT/SmallVector.h"
2121
#include "llvm/ADT/Twine.h"
2222
#include "llvm/Analysis/AssumptionCache.h"
23+
#include "llvm/Analysis/CtxProfAnalysis.h"
2324
#include "llvm/Analysis/InlineCost.h"
2425
#include "llvm/IR/BasicBlock.h"
2526
#include "llvm/IR/ValueHandle.h"
@@ -270,6 +271,17 @@ InlineResult InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
270271
bool InsertLifetime = true,
271272
Function *ForwardVarArgsTo = nullptr);
272273

274+
/// Same as above, but it will update the contextual profile. If the contextual
275+
/// profile is invalid (i.e. not loaded because it is not present), it defaults
276+
/// to the behavior of the non-contextual profile updating variant above. This
277+
/// makes it easy to drop-in replace uses of the non-contextual overload.
278+
InlineResult InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
279+
CtxProfAnalysis::Result &CtxProf,
280+
bool MergeAttributes = false,
281+
AAResults *CalleeAAR = nullptr,
282+
bool InsertLifetime = true,
283+
Function *ForwardVarArgsTo = nullptr);
284+
273285
/// Clones a loop \p OrigLoop. Returns the loop and the blocks in \p
274286
/// Blocks.
275287
///

llvm/lib/Analysis/CtxProfAnalysis.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@ cl::opt<std::string>
2929
UseCtxProfile("use-ctx-profile", cl::init(""), cl::Hidden,
3030
cl::desc("Use the specified contextual profile file"));
3131

32+
static cl::opt<CtxProfAnalysisPrinterPass::PrintMode> PrintLevel(
33+
"ctx-profile-printer-level",
34+
cl::init(CtxProfAnalysisPrinterPass::PrintMode::Everything), cl::Hidden,
35+
cl::values(clEnumValN(CtxProfAnalysisPrinterPass::PrintMode::Everything,
36+
"everything", "print everything - most verbose"),
37+
clEnumValN(CtxProfAnalysisPrinterPass::PrintMode::JSON, "json",
38+
"just the json representation of the profile")),
39+
cl::desc("Verbosity level of the contextual profile printer pass."));
40+
3241
namespace llvm {
3342
namespace json {
3443
Value toJSON(const PGOCtxProfContext &P) {
@@ -150,7 +159,6 @@ PGOContextualProfile CtxProfAnalysis::run(Module &M,
150159
// If we made it this far, the Result is valid - which we mark by setting
151160
// .Profiles.
152161
// Trim first the roots that aren't in this module.
153-
DenseSet<GlobalValue::GUID> ProfiledGUIDs;
154162
for (auto &[RootGuid, _] : llvm::make_early_inc_range(*MaybeCtx))
155163
if (!Result.FuncInfo.contains(RootGuid))
156164
MaybeCtx->erase(RootGuid);
@@ -165,6 +173,10 @@ PGOContextualProfile::getDefinedFunctionGUID(const Function &F) const {
165173
return 0;
166174
}
167175

176+
CtxProfAnalysisPrinterPass::CtxProfAnalysisPrinterPass(raw_ostream &OS)
177+
: OS(OS),
178+
Mode(PrintLevel.getNumOccurrences() > 0 ? PrintLevel : PrintMode::JSON) {}
179+
168180
PreservedAnalyses CtxProfAnalysisPrinterPass::run(Module &M,
169181
ModuleAnalysisManager &MAM) {
170182
CtxProfAnalysis::Result &C = MAM.getResult<CtxProfAnalysis>(M);

llvm/lib/Transforms/IPO/ModuleInliner.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/Analysis/AliasAnalysis.h"
2121
#include "llvm/Analysis/AssumptionCache.h"
2222
#include "llvm/Analysis/BlockFrequencyInfo.h"
23+
#include "llvm/Analysis/CtxProfAnalysis.h"
2324
#include "llvm/Analysis/InlineAdvisor.h"
2425
#include "llvm/Analysis/InlineCost.h"
2526
#include "llvm/Analysis/InlineOrder.h"
@@ -113,6 +114,8 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M,
113114
return PreservedAnalyses::all();
114115
}
115116

117+
auto &CtxProf = MAM.getResult<CtxProfAnalysis>(M);
118+
116119
bool Changed = false;
117120

118121
ProfileSummaryInfo *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(M);
@@ -213,7 +216,7 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M,
213216
&FAM.getResult<BlockFrequencyAnalysis>(Callee));
214217

215218
InlineResult IR =
216-
InlineFunction(*CB, IFI, /*MergeAttributes=*/true,
219+
InlineFunction(*CB, IFI, CtxProf, /*MergeAttributes=*/true,
217220
&FAM.getResult<AAManager>(*CB->getCaller()));
218221
if (!IR.isSuccess()) {
219222
Advice->recordUnsuccessfulInlining(IR);

llvm/lib/Transforms/Utils/InlineFunction.cpp

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "llvm/Analysis/BlockFrequencyInfo.h"
2424
#include "llvm/Analysis/CallGraph.h"
2525
#include "llvm/Analysis/CaptureTracking.h"
26+
#include "llvm/Analysis/CtxProfAnalysis.h"
2627
#include "llvm/Analysis/IndirectCallVisitor.h"
2728
#include "llvm/Analysis/InstructionSimplify.h"
2829
#include "llvm/Analysis/MemoryProfileInfo.h"
@@ -46,6 +47,7 @@
4647
#include "llvm/IR/Dominators.h"
4748
#include "llvm/IR/EHPersonalities.h"
4849
#include "llvm/IR/Function.h"
50+
#include "llvm/IR/GlobalVariable.h"
4951
#include "llvm/IR/IRBuilder.h"
5052
#include "llvm/IR/InlineAsm.h"
5153
#include "llvm/IR/InstrTypes.h"
@@ -71,6 +73,7 @@
7173
#include <algorithm>
7274
#include <cassert>
7375
#include <cstdint>
76+
#include <deque>
7477
#include <iterator>
7578
#include <limits>
7679
#include <optional>
@@ -2116,6 +2119,203 @@ inlineRetainOrClaimRVCalls(CallBase &CB, objcarc::ARCInstKind RVCallKind,
21162119
}
21172120
}
21182121

2122+
// In contextual profiling, when an inline succeeds, we want to remap the
2123+
// indices of the callee in the index space of the caller. We can't just leave
2124+
// them as-is because the same callee may appear in other places in this caller
2125+
// (other callsites), and its (callee's) counters and sub-contextual profile
2126+
// tree would be potentially different.
2127+
// Not all BBs of the callee may survive the opportunistic DCE InlineFunction
2128+
// does (same goes for callsites in the callee).
2129+
// We will return a pair of vectors, one for basic block IDs and one for
2130+
// callsites. For such a vector V, V[Idx] will be -1 if the callee
2131+
// instrumentation with index Idx did not survive inlining, and a new value
2132+
// otherwise.
2133+
// This function will update the instrumentation intrinsics accordingly,
2134+
// mapping indices as described above. We also replace the "name" operand
2135+
// because we use it to distinguish between "own" instrumentation and "from
2136+
// callee" instrumentation when performing the traversal of the CFG of the
2137+
// caller. We traverse depth-first from the callsite's BB and up to the point we
2138+
// hit owned BBs.
2139+
// The return values will be then used to update the contextual
2140+
// profile. Note: we only update the "name" and "index" operands in the
2141+
// instrumentation intrinsics, we leave the hash and total nr of indices as-is,
2142+
// it's not worth updating those.
2143+
static const std::pair<std::vector<int64_t>, std::vector<int64_t>>
2144+
remapIndices(Function &Caller, BasicBlock *StartBB,
2145+
CtxProfAnalysis::Result &CtxProf, uint32_t CalleeCounters,
2146+
uint32_t CalleeCallsites) {
2147+
// We'll allocate a new ID to imported callsite counters and callsites. We're
2148+
// using -1 to indicate a counter we delete. Most likely the entry, for
2149+
// example, will be deleted - we don't want 2 IDs in the same BB, and the
2150+
// entry would have been cloned in the callsite's old BB.
2151+
std::vector<int64_t> CalleeCounterMap;
2152+
std::vector<int64_t> CalleeCallsiteMap;
2153+
CalleeCounterMap.resize(CalleeCounters, -1);
2154+
CalleeCallsiteMap.resize(CalleeCallsites, -1);
2155+
2156+
auto RewriteInstrIfNeeded = [&](InstrProfIncrementInst &Ins) -> bool {
2157+
if (Ins.getNameValue() == &Caller)
2158+
return false;
2159+
const auto OldID = static_cast<uint32_t>(Ins.getIndex()->getZExtValue());
2160+
if (CalleeCounterMap[OldID] == -1)
2161+
CalleeCounterMap[OldID] = CtxProf.allocateNextCounterIndex(Caller);
2162+
const auto NewID = static_cast<uint32_t>(CalleeCounterMap[OldID]);
2163+
2164+
Ins.setNameValue(&Caller);
2165+
Ins.setIndex(NewID);
2166+
return true;
2167+
};
2168+
2169+
auto RewriteCallsiteInsIfNeeded = [&](InstrProfCallsite &Ins) -> bool {
2170+
if (Ins.getNameValue() == &Caller)
2171+
return false;
2172+
const auto OldID = static_cast<uint32_t>(Ins.getIndex()->getZExtValue());
2173+
if (CalleeCallsiteMap[OldID] == -1)
2174+
CalleeCallsiteMap[OldID] = CtxProf.allocateNextCallsiteIndex(Caller);
2175+
const auto NewID = static_cast<uint32_t>(CalleeCallsiteMap[OldID]);
2176+
2177+
Ins.setNameValue(&Caller);
2178+
Ins.setIndex(NewID);
2179+
return true;
2180+
};
2181+
2182+
std::deque<BasicBlock *> Worklist;
2183+
DenseSet<const BasicBlock *> Seen;
2184+
// We will traverse the BBs starting from the callsite BB. The callsite BB
2185+
// will have at least a BB ID - maybe its own, and in any case the one coming
2186+
// from the cloned function's entry BB. The other BBs we'll start seeing from
2187+
// there on may or may not have BB IDs. BBs with IDs belonging to our caller
2188+
// are definitely not coming from the imported function and form a boundary
2189+
// past which we don't need to traverse anymore. BBs may have no
2190+
// instrumentation (because we originally inserted instrumentation as per
2191+
// MST), in which case we'll traverse past them. An invariant we'll keep is
2192+
// that a BB will have at most 1 BB ID. For example, in the callsite BB, we
2193+
// will delete the callee BB's instrumentation. This doesn't result in
2194+
// information loss: the entry BB of the caller will have the same count as
2195+
// the callsite's BB. At the end of this traversal, all the callee's
2196+
// instrumentation would be mapped into the caller's instrumentation index
2197+
// space. Some of the callee's counters may be deleted (as mentioned, this
2198+
// should result in no loss of information).
2199+
Worklist.push_back(StartBB);
2200+
while (!Worklist.empty()) {
2201+
auto *BB = Worklist.front();
2202+
Worklist.pop_front();
2203+
bool Changed = false;
2204+
auto *BBID = CtxProfAnalysis::getBBInstrumentation(*BB);
2205+
if (BBID) {
2206+
Changed |= RewriteInstrIfNeeded(*BBID);
2207+
// this may be the entryblock from the inlined callee, coming into a BB
2208+
// that didn't have instrumentation because of MST decisions. Let's make
2209+
// sure it's placed accordingly. This is a noop elsewhere.
2210+
BBID->moveBefore(&*BB->getFirstInsertionPt());
2211+
}
2212+
for (auto &I : llvm::make_early_inc_range(*BB)) {
2213+
if (auto *Inc = dyn_cast<InstrProfIncrementInst>(&I)) {
2214+
if (Inc != BBID) {
2215+
Inc->eraseFromParent();
2216+
Changed = true;
2217+
}
2218+
} else if (auto *CS = dyn_cast<InstrProfCallsite>(&I)) {
2219+
Changed |= RewriteCallsiteInsIfNeeded(*CS);
2220+
}
2221+
}
2222+
if (!BBID || Changed)
2223+
for (auto *Succ : successors(BB))
2224+
if (Seen.insert(Succ).second)
2225+
Worklist.push_back(Succ);
2226+
}
2227+
return {std::move(CalleeCounterMap), std::move(CalleeCallsiteMap)};
2228+
}
2229+
2230+
llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI,
2231+
CtxProfAnalysis::Result &CtxProf,
2232+
bool MergeAttributes,
2233+
AAResults *CalleeAAR,
2234+
bool InsertLifetime,
2235+
Function *ForwardVarArgsTo) {
2236+
if (!CtxProf)
2237+
return InlineFunction(CB, IFI, MergeAttributes, CalleeAAR, InsertLifetime,
2238+
ForwardVarArgsTo);
2239+
2240+
auto &Caller = *CB.getCaller();
2241+
auto &Callee = *CB.getCalledFunction();
2242+
auto *StartBB = CB.getParent();
2243+
2244+
// Get some preliminary data about the callsite before it might get inlined.
2245+
// Inlining shouldn't delete the callee, but it's cleaner (and low-cost) to
2246+
// get this data upfront and rely less on InlineFunction's behavior.
2247+
const auto CalleeGUID = AssignGUIDPass::getGUID(Callee);
2248+
auto *CallsiteIDIns = CtxProfAnalysis::getCallsiteInstrumentation(CB);
2249+
const auto CallsiteID =
2250+
static_cast<uint32_t>(CallsiteIDIns->getIndex()->getZExtValue());
2251+
2252+
const auto NrCalleeCounters = CtxProf.getNrCounters(Callee);
2253+
const auto NrCalleeCallsites = CtxProf.getNrCallsites(Callee);
2254+
2255+
auto Ret = InlineFunction(CB, IFI, MergeAttributes, CalleeAAR, InsertLifetime,
2256+
ForwardVarArgsTo);
2257+
if (!Ret.isSuccess())
2258+
return Ret;
2259+
2260+
// Inlining succeeded, we don't need the instrumentation of the inlined
2261+
// callsite.
2262+
CallsiteIDIns->eraseFromParent();
2263+
2264+
// Assinging Maps and then capturing references into it in the lambda because
2265+
// captured structured bindings are a C++20 extension. We do also need a
2266+
// capture here, though.
2267+
const auto IndicesMaps = remapIndices(Caller, StartBB, CtxProf,
2268+
NrCalleeCounters, NrCalleeCallsites);
2269+
const uint32_t NewCountersSize = CtxProf.getNrCounters(Caller);
2270+
2271+
auto Updater = [&](PGOCtxProfContext &Ctx) {
2272+
assert(Ctx.guid() == AssignGUIDPass::getGUID(Caller));
2273+
const auto &[CalleeCounterMap, CalleeCallsiteMap] = IndicesMaps;
2274+
assert(
2275+
(Ctx.counters().size() +
2276+
llvm::count_if(CalleeCounterMap, [](auto V) { return V != -1; }) ==
2277+
NewCountersSize) &&
2278+
"The caller's counters size should have grown by the number of new "
2279+
"distinct counters inherited from the inlined callee.");
2280+
Ctx.resizeCounters(NewCountersSize);
2281+
// If the callsite wasn't exercised in this context, the value of the
2282+
// counters coming from it is 0 - which it is right now, after resizing them
2283+
// - and so we're done.
2284+
auto CSIt = Ctx.callsites().find(CallsiteID);
2285+
if (CSIt == Ctx.callsites().end())
2286+
return;
2287+
auto CalleeCtxIt = CSIt->second.find(CalleeGUID);
2288+
// The callsite was exercised, but not with this callee (so presumably this
2289+
// is an indirect callsite). Again, we're done here.
2290+
if (CalleeCtxIt == CSIt->second.end())
2291+
return;
2292+
2293+
// Let's pull in the counter values and the subcontexts coming from the
2294+
// inlined callee.
2295+
auto &CalleeCtx = CalleeCtxIt->second;
2296+
assert(CalleeCtx.guid() == CalleeGUID);
2297+
2298+
for (auto I = 0U; I < CalleeCtx.counters().size(); ++I) {
2299+
const int64_t NewIndex = CalleeCounterMap[I];
2300+
if (NewIndex >= 0)
2301+
Ctx.counters()[NewIndex] = CalleeCtx.counters()[I];
2302+
}
2303+
for (auto &[I, OtherSet] : CalleeCtx.callsites()) {
2304+
const int64_t NewCSIdx = CalleeCallsiteMap[I];
2305+
if (NewCSIdx >= 0)
2306+
Ctx.ingestAllContexts(NewCSIdx, std::move(OtherSet));
2307+
}
2308+
// We know the traversal is preorder, so it wouldn't have yet looked at the
2309+
// sub-contexts of this context that it's currently visiting. Meaning, the
2310+
// erase below invalidates no iterators.
2311+
auto Deleted = Ctx.callsites().erase(CallsiteID);
2312+
assert(Deleted);
2313+
(void)Deleted;
2314+
};
2315+
CtxProf.update(Updater, &Caller);
2316+
return Ret;
2317+
}
2318+
21192319
/// This function inlines the called function into the basic block of the
21202320
/// caller. This returns false if it is not possible to inline this call.
21212321
/// The program is still in a well defined state if this occurs though.

0 commit comments

Comments
 (0)