Skip to content

Commit aa58b7b

Browse files
committed
[CSSPGO][llvm-profgen] Reimplement computeSummaryAndThreshold using context trie
Follow-up patch to https://reviews.llvm.org/D125246, support `computeSummaryAndThreshold` based on context trie. Reviewed By: hoy, wenlei Differential Revision: https://reviews.llvm.org/D127026
1 parent eba5749 commit aa58b7b

File tree

5 files changed

+128
-45
lines changed

5 files changed

+128
-45
lines changed

llvm/include/llvm/Transforms/IPO/SampleContextTracker.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ class SampleContextTracker {
142142
ContextTrieNode &getRootContext();
143143
void promoteMergeContextSamplesTree(const Instruction &Inst,
144144
StringRef CalleeName);
145+
146+
// Create a merged conext-less profile map.
147+
void createContextLessProfileMap(SampleProfileMap &ContextLessProfiles);
145148
// Dump the internal context profile trie.
146149
void dump();
147150

@@ -158,7 +161,6 @@ class SampleContextTracker {
158161
promoteMergeContextSamplesTree(ContextTrieNode &FromNode,
159162
ContextTrieNode &ToNodeParent,
160163
uint32_t ContextFramesToRemove);
161-
162164
// Map from function name to context profiles (excluding base profile)
163165
StringMap<ContextSamplesTy> FuncToCtxtProfiles;
164166

llvm/lib/Transforms/IPO/SampleContextTracker.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,4 +595,24 @@ ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree(
595595

596596
return *ToNode;
597597
}
598+
599+
void SampleContextTracker::createContextLessProfileMap(
600+
SampleProfileMap &ContextLessProfiles) {
601+
std::queue<ContextTrieNode *> NodeQueue;
602+
NodeQueue.push(&RootContext);
603+
604+
while (!NodeQueue.empty()) {
605+
ContextTrieNode *Node = NodeQueue.front();
606+
FunctionSamples *FProfile = Node->getFunctionSamples();
607+
NodeQueue.pop();
608+
609+
if (FProfile) {
610+
// Profile's context can be empty, use ContextNode's func name.
611+
ContextLessProfiles[Node->getFuncName()].merge(*FProfile);
612+
}
613+
614+
for (auto &It : Node->getAllChildContext())
615+
NodeQueue.push(&It.second);
616+
}
617+
}
598618
} // namespace llvm

llvm/tools/llvm-profgen/ProfileGenerator.cpp

Lines changed: 85 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ static cl::opt<bool> UpdateTotalSamples(
9191
llvm::cl::Optional);
9292

9393
extern cl::opt<int> ProfileSummaryCutoffHot;
94+
extern cl::opt<bool> UseContextLessSummary;
9495

9596
static cl::opt<bool> GenCSNestedProfile(
9697
"gen-cs-nested-profile", cl::Hidden, cl::init(true),
@@ -128,14 +129,13 @@ ProfileGeneratorBase::create(ProfiledBinary *Binary,
128129
}
129130

130131
std::unique_ptr<ProfileGeneratorBase>
131-
ProfileGeneratorBase::create(ProfiledBinary *Binary,
132-
const SampleProfileMap &&Profiles,
132+
ProfileGeneratorBase::create(ProfiledBinary *Binary, SampleProfileMap &Profiles,
133133
bool ProfileIsCS) {
134134
std::unique_ptr<ProfileGeneratorBase> Generator;
135135
if (ProfileIsCS) {
136136
if (Binary->useFSDiscriminator())
137137
exitWithError("FS discriminator is not supported in CS profile.");
138-
Generator.reset(new CSProfileGenerator(Binary, std::move(Profiles)));
138+
Generator.reset(new CSProfileGenerator(Binary, Profiles));
139139
} else {
140140
Generator.reset(new ProfileGenerator(Binary, std::move(Profiles)));
141141
}
@@ -403,43 +403,73 @@ void ProfileGeneratorBase::updateFunctionSamples() {
403403

404404
void ProfileGeneratorBase::collectProfiledFunctions() {
405405
std::unordered_set<const BinaryFunction *> ProfiledFunctions;
406-
if (SampleCounters) {
407-
// Go through all the stacks, ranges and branches in sample counters, use
408-
// the start of the range to look up the function it belongs and record the
409-
// function.
410-
for (const auto &CI : *SampleCounters) {
411-
if (const auto *CtxKey = dyn_cast<AddrBasedCtxKey>(CI.first.getPtr())) {
412-
for (auto Addr : CtxKey->Context) {
413-
if (FuncRange *FRange = Binary->findFuncRangeForOffset(
414-
Binary->virtualAddrToOffset(Addr)))
415-
ProfiledFunctions.insert(FRange->Func);
416-
}
417-
}
406+
if (collectFunctionsFromRawProfile(ProfiledFunctions))
407+
Binary->setProfiledFunctions(ProfiledFunctions);
408+
else if (collectFunctionsFromLLVMProfile(ProfiledFunctions))
409+
Binary->setProfiledFunctions(ProfiledFunctions);
410+
else
411+
llvm_unreachable("Unsupported input profile");
412+
}
418413

419-
for (auto Item : CI.second.RangeCounter) {
420-
uint64_t StartOffset = Item.first.first;
421-
if (FuncRange *FRange = Binary->findFuncRangeForOffset(StartOffset))
414+
bool ProfileGeneratorBase::collectFunctionsFromRawProfile(
415+
std::unordered_set<const BinaryFunction *> &ProfiledFunctions) {
416+
if (!SampleCounters)
417+
return false;
418+
// Go through all the stacks, ranges and branches in sample counters, use
419+
// the start of the range to look up the function it belongs and record the
420+
// function.
421+
for (const auto &CI : *SampleCounters) {
422+
if (const auto *CtxKey = dyn_cast<AddrBasedCtxKey>(CI.first.getPtr())) {
423+
for (auto Addr : CtxKey->Context) {
424+
if (FuncRange *FRange = Binary->findFuncRangeForOffset(
425+
Binary->virtualAddrToOffset(Addr)))
422426
ProfiledFunctions.insert(FRange->Func);
423427
}
428+
}
424429

425-
for (auto Item : CI.second.BranchCounter) {
426-
uint64_t SourceOffset = Item.first.first;
427-
uint64_t TargetOffset = Item.first.first;
428-
if (FuncRange *FRange = Binary->findFuncRangeForOffset(SourceOffset))
429-
ProfiledFunctions.insert(FRange->Func);
430-
if (FuncRange *FRange = Binary->findFuncRangeForOffset(TargetOffset))
431-
ProfiledFunctions.insert(FRange->Func);
432-
}
430+
for (auto Item : CI.second.RangeCounter) {
431+
uint64_t StartOffset = Item.first.first;
432+
if (FuncRange *FRange = Binary->findFuncRangeForOffset(StartOffset))
433+
ProfiledFunctions.insert(FRange->Func);
433434
}
434-
} else {
435-
// This is for the case the input is a llvm sample profile.
436-
for (const auto &FS : ProfileMap) {
437-
if (auto *Func = Binary->getBinaryFunction(FS.first.getName()))
438-
ProfiledFunctions.insert(Func);
435+
436+
for (auto Item : CI.second.BranchCounter) {
437+
uint64_t SourceOffset = Item.first.first;
438+
uint64_t TargetOffset = Item.first.first;
439+
if (FuncRange *FRange = Binary->findFuncRangeForOffset(SourceOffset))
440+
ProfiledFunctions.insert(FRange->Func);
441+
if (FuncRange *FRange = Binary->findFuncRangeForOffset(TargetOffset))
442+
ProfiledFunctions.insert(FRange->Func);
439443
}
440444
}
445+
return true;
446+
}
447+
448+
bool ProfileGenerator::collectFunctionsFromLLVMProfile(
449+
std::unordered_set<const BinaryFunction *> &ProfiledFunctions) {
450+
for (const auto &FS : ProfileMap) {
451+
if (auto *Func = Binary->getBinaryFunction(FS.first.getName()))
452+
ProfiledFunctions.insert(Func);
453+
}
454+
return true;
455+
}
441456

442-
Binary->setProfiledFunctions(ProfiledFunctions);
457+
bool CSProfileGenerator::collectFunctionsFromLLVMProfile(
458+
std::unordered_set<const BinaryFunction *> &ProfiledFunctions) {
459+
std::queue<ContextTrieNode *> NodeQueue;
460+
NodeQueue.push(&getRootContext());
461+
while (!NodeQueue.empty()) {
462+
ContextTrieNode *Node = NodeQueue.front();
463+
NodeQueue.pop();
464+
465+
if (!Node->getFuncName().empty())
466+
if (auto *Func = Binary->getBinaryFunction(Node->getFuncName()))
467+
ProfiledFunctions.insert(Func);
468+
469+
for (auto &It : Node->getAllChildContext())
470+
NodeQueue.push(&It.second);
471+
}
472+
return true;
443473
}
444474

445475
FunctionSamples &
@@ -471,7 +501,7 @@ void ProfileGenerator::generateProfile() {
471501
}
472502

473503
void ProfileGenerator::postProcessProfiles() {
474-
computeSummaryAndThreshold();
504+
computeSummaryAndThreshold(ProfileMap);
475505
trimColdProfiles(ProfileMap, ColdCountThreshold);
476506
calculateAndShowDensity(ProfileMap);
477507
}
@@ -965,13 +995,12 @@ void CSProfileGenerator::convertToProfileMap() {
965995
}
966996

967997
void CSProfileGenerator::postProcessProfiles() {
968-
if (SampleCounters)
969-
convertToProfileMap();
970-
971998
// Compute hot/cold threshold based on profile. This will be used for cold
972999
// context profile merging/trimming.
9731000
computeSummaryAndThreshold();
9741001

1002+
convertToProfileMap();
1003+
9751004
// Run global pre-inliner to adjust/merge context profile based on estimated
9761005
// inline decisions.
9771006
if (EnableCSPreInliner) {
@@ -1003,15 +1032,33 @@ void CSProfileGenerator::postProcessProfiles() {
10031032
}
10041033
}
10051034

1006-
void ProfileGeneratorBase::computeSummaryAndThreshold() {
1035+
void ProfileGeneratorBase::computeSummaryAndThreshold(
1036+
SampleProfileMap &Profiles) {
10071037
SampleProfileSummaryBuilder Builder(ProfileSummaryBuilder::DefaultCutoffs);
1008-
Summary = Builder.computeSummaryForProfiles(ProfileMap);
1038+
Summary = Builder.computeSummaryForProfiles(Profiles);
10091039
HotCountThreshold = ProfileSummaryBuilder::getHotCountThreshold(
10101040
(Summary->getDetailedSummary()));
10111041
ColdCountThreshold = ProfileSummaryBuilder::getColdCountThreshold(
10121042
(Summary->getDetailedSummary()));
10131043
}
10141044

1045+
void CSProfileGenerator::computeSummaryAndThreshold() {
1046+
// Always merge and use context-less profile map to compute summary.
1047+
SampleProfileMap ContextLessProfiles;
1048+
ContextTracker.createContextLessProfileMap(ContextLessProfiles);
1049+
1050+
// Set the flag below to avoid merging the profile again in
1051+
// computeSummaryAndThreshold
1052+
FunctionSamples::ProfileIsCS = false;
1053+
assert(
1054+
(!UseContextLessSummary.getNumOccurrences() || UseContextLessSummary) &&
1055+
"Don't set --profile-summary-contextless to false for profile "
1056+
"generation");
1057+
ProfileGeneratorBase::computeSummaryAndThreshold(ContextLessProfiles);
1058+
// Recover the old value.
1059+
FunctionSamples::ProfileIsCS = true;
1060+
}
1061+
10151062
void ProfileGeneratorBase::extractProbesFromRange(
10161063
const RangeSample &RangeCounter, ProbeCounterMap &ProbeCounter,
10171064
bool FindDisjointRanges) {

llvm/tools/llvm-profgen/ProfileGenerator.h

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ using ProbeCounterMap =
3232
class ProfileGeneratorBase {
3333

3434
public:
35+
ProfileGeneratorBase(ProfiledBinary *Binary) : Binary(Binary){};
3536
ProfileGeneratorBase(ProfiledBinary *Binary,
3637
const ContextSampleCounterMap *Counters)
3738
: Binary(Binary), SampleCounters(Counters){};
@@ -44,7 +45,7 @@ class ProfileGeneratorBase {
4445
create(ProfiledBinary *Binary, const ContextSampleCounterMap *Counters,
4546
bool profileIsCS);
4647
static std::unique_ptr<ProfileGeneratorBase>
47-
create(ProfiledBinary *Binary, const SampleProfileMap &&ProfileMap,
48+
create(ProfiledBinary *Binary, SampleProfileMap &ProfileMap,
4849
bool profileIsCS);
4950
virtual void generateProfile() = 0;
5051
void write();
@@ -109,7 +110,7 @@ class ProfileGeneratorBase {
109110

110111
StringRef getCalleeNameForOffset(uint64_t TargetOffset);
111112

112-
void computeSummaryAndThreshold();
113+
void computeSummaryAndThreshold(SampleProfileMap &ProfileMap);
113114

114115
void calculateAndShowDensity(const SampleProfileMap &Profiles);
115116

@@ -120,6 +121,13 @@ class ProfileGeneratorBase {
120121

121122
void collectProfiledFunctions();
122123

124+
bool collectFunctionsFromRawProfile(
125+
std::unordered_set<const BinaryFunction *> &ProfiledFunctions);
126+
127+
// Collect profiled Functions for llvm sample profile input.
128+
virtual bool collectFunctionsFromLLVMProfile(
129+
std::unordered_set<const BinaryFunction *> &ProfiledFunctions) = 0;
130+
123131
// Thresholds from profile summary to answer isHotCount/isColdCount queries.
124132
uint64_t HotCountThreshold;
125133

@@ -166,15 +174,17 @@ class ProfileGenerator : public ProfileGeneratorBase {
166174
void postProcessProfiles();
167175
void trimColdProfiles(const SampleProfileMap &Profiles,
168176
uint64_t ColdCntThreshold);
177+
bool collectFunctionsFromLLVMProfile(
178+
std::unordered_set<const BinaryFunction *> &ProfiledFunctions) override;
169179
};
170180

171181
class CSProfileGenerator : public ProfileGeneratorBase {
172182
public:
173183
CSProfileGenerator(ProfiledBinary *Binary,
174184
const ContextSampleCounterMap *Counters)
175185
: ProfileGeneratorBase(Binary, Counters){};
176-
CSProfileGenerator(ProfiledBinary *Binary, const SampleProfileMap &&Profiles)
177-
: ProfileGeneratorBase(Binary, std::move(Profiles)){};
186+
CSProfileGenerator(ProfiledBinary *Binary, SampleProfileMap &Profiles)
187+
: ProfileGeneratorBase(Binary), ContextTracker(Profiles, nullptr){};
178188
void generateProfile() override;
179189

180190
// Trim the context stack at a given depth.
@@ -343,6 +353,11 @@ class CSProfileGenerator : public ProfileGeneratorBase {
343353

344354
void convertToProfileMap();
345355

356+
void computeSummaryAndThreshold();
357+
358+
bool collectFunctionsFromLLVMProfile(
359+
std::unordered_set<const BinaryFunction *> &ProfiledFunctions) override;
360+
346361
ContextTrieNode &getRootContext() { return ContextTracker.getRootContext(); };
347362

348363
// The container for holding the FunctionSamples used by context trie.

llvm/tools/llvm-profgen/llvm-profgen.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,7 @@ int main(int argc, const char *argv[]) {
164164
std::move(ReaderOrErr.get());
165165
Reader->read();
166166
std::unique_ptr<ProfileGeneratorBase> Generator =
167-
ProfileGeneratorBase::create(Binary.get(),
168-
std::move(Reader->getProfiles()),
167+
ProfileGeneratorBase::create(Binary.get(), Reader->getProfiles(),
169168
Reader->profileIsCS());
170169
Generator->generateProfile();
171170
Generator->write();

0 commit comments

Comments
 (0)