Skip to content

Commit 67060cb

Browse files
committed
[nfc] Improve testability of PGOInstrumentationGen
1 parent 2adc012 commit 67060cb

File tree

6 files changed

+83
-52
lines changed

6 files changed

+83
-52
lines changed

llvm/include/llvm/Transforms/Instrumentation/PGOInstrumentation.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,18 @@ class PGOInstrumentationGenCreateVar
5353
bool ProfileSampling;
5454
};
5555

56+
enum class PGOInstrumentationType { Invalid = 0, FDO, CSFDO, CTXPROF };
5657
/// The instrumentation (profile-instr-gen) pass for IR based PGO.
5758
class PGOInstrumentationGen : public PassInfoMixin<PGOInstrumentationGen> {
5859
public:
59-
PGOInstrumentationGen(bool IsCS = false) : IsCS(IsCS) {}
60+
PGOInstrumentationGen(
61+
PGOInstrumentationType InstrumentationType = PGOInstrumentationType ::FDO)
62+
: InstrumentationType(InstrumentationType) {}
6063
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
6164

6265
private:
6366
// The kind of instrumentation to perform (FDO, CSFDO or CTXPROF).
64-
bool IsCS;
67+
const PGOInstrumentationType InstrumentationType;
6568
};
6669

6770
/// The profile annotation (profile-instr-use) pass for IR based PGO.

llvm/lib/Passes/PassBuilderPipelines.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,8 @@ void PassBuilder::addPGOInstrPasses(ModulePassManager &MPM,
844844
}
845845

846846
// Perform PGO instrumentation.
847-
MPM.addPass(PGOInstrumentationGen(IsCS));
847+
MPM.addPass(PGOInstrumentationGen(IsCS ? PGOInstrumentationType::CSFDO
848+
: PGOInstrumentationType::FDO));
848849

849850
addPostPGOLoopRotation(MPM, Level);
850851
// Add the profile lowering pass.
@@ -879,7 +880,8 @@ void PassBuilder::addPGOInstrPassesForO0(
879880
}
880881

881882
// Perform PGO instrumentation.
882-
MPM.addPass(PGOInstrumentationGen(IsCS));
883+
MPM.addPass(PGOInstrumentationGen(IsCS ? PGOInstrumentationType::CSFDO
884+
: PGOInstrumentationType::FDO));
883885
// Add the profile lowering pass.
884886
InstrProfOptions Options;
885887
if (!ProfileFile.empty())
@@ -1193,7 +1195,7 @@ PassBuilder::buildModuleSimplificationPipeline(OptimizationLevel Level,
11931195
PGOOpt->ProfileFile, PGOOpt->ProfileRemappingFile,
11941196
PGOOpt->FS);
11951197
} else if (IsCtxProfGen || IsCtxProfUse) {
1196-
MPM.addPass(PGOInstrumentationGen(false));
1198+
MPM.addPass(PGOInstrumentationGen(PGOInstrumentationType::CTXPROF));
11971199
// In pre-link, we just want the instrumented IR. We use the contextual
11981200
// profile in the post-thinlink phase.
11991201
// The instrumentation will be removed in post-thinlink after IPO.

llvm/lib/Passes/PassRegistry.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ MODULE_PASS("constmerge", ConstantMergePass())
5656
MODULE_PASS("coro-cleanup", CoroCleanupPass())
5757
MODULE_PASS("coro-early", CoroEarlyPass())
5858
MODULE_PASS("cross-dso-cfi", CrossDSOCFIPass())
59+
MODULE_PASS("ctx-instr-gen",
60+
PGOInstrumentationGen(PGOInstrumentationType::CTXPROF))
5961
MODULE_PASS("deadargelim", DeadArgumentEliminationPass())
6062
MODULE_PASS("debugify", NewPMDebugifyPass())
6163
MODULE_PASS("dfsan", DataFlowSanitizerPass())

llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp

Lines changed: 68 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@
110110
#include "llvm/Transforms/Instrumentation.h"
111111
#include "llvm/Transforms/Instrumentation/BlockCoverageInference.h"
112112
#include "llvm/Transforms/Instrumentation/CFGMST.h"
113-
#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
114113
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
115114
#include "llvm/Transforms/Utils/MisExpect.h"
116115
#include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -321,7 +320,6 @@ static cl::opt<unsigned> PGOFunctionCriticalEdgeThreshold(
321320
" greater than this threshold."));
322321

323322
extern cl::opt<unsigned> MaxNumVTableAnnotations;
324-
extern cl::opt<std::string> UseCtxProfile;
325323

326324
namespace llvm {
327325
// Command line option to turn on CFG dot dump after profile annotation.
@@ -339,21 +337,43 @@ extern cl::opt<bool> EnableVTableProfileUse;
339337
extern cl::opt<InstrProfCorrelator::ProfCorrelatorKind> ProfileCorrelate;
340338
} // namespace llvm
341339

342-
bool shouldInstrumentForCtxProf() {
343-
return PGOCtxProfLoweringPass::isCtxIRPGOInstrEnabled() ||
344-
!UseCtxProfile.empty();
345-
}
346-
bool shouldInstrumentEntryBB() {
347-
return PGOInstrumentEntry || shouldInstrumentForCtxProf();
348-
}
340+
namespace {
341+
class FunctionInstrumenter final {
342+
Module &M;
343+
Function &F;
344+
TargetLibraryInfo &TLI;
345+
std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers;
346+
BranchProbabilityInfo *const BPI;
347+
BlockFrequencyInfo *const BFI;
349348

350-
// FIXME(mtrofin): re-enable this for ctx profiling, for non-indirect calls. Ctx
351-
// profiling implicitly captures indirect call cases, but not other values.
352-
// Supporting other values is relatively straight-forward - just another counter
353-
// range within the context.
354-
bool isValueProfilingDisabled() {
355-
return DisableValueProfiling || shouldInstrumentForCtxProf();
356-
}
349+
const PGOInstrumentationType InstrumentationType;
350+
351+
// FIXME(mtrofin): re-enable this for ctx profiling, for non-indirect calls.
352+
// Ctx profiling implicitly captures indirect call cases, but not other
353+
// values. Supporting other values is relatively straight-forward - just
354+
// another counter range within the context.
355+
bool isValueProfilingDisabled() const {
356+
return DisableValueProfiling ||
357+
InstrumentationType == PGOInstrumentationType::CTXPROF;
358+
}
359+
360+
bool shouldInstrumentEntryBB() const {
361+
return PGOInstrumentEntry ||
362+
InstrumentationType == PGOInstrumentationType::CTXPROF;
363+
}
364+
365+
public:
366+
FunctionInstrumenter(
367+
Module &M, Function &F, TargetLibraryInfo &TLI,
368+
std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
369+
BranchProbabilityInfo *BPI = nullptr, BlockFrequencyInfo *BFI = nullptr,
370+
PGOInstrumentationType InstrumentationType = PGOInstrumentationType::FDO)
371+
: M(M), F(F), TLI(TLI), ComdatMembers(ComdatMembers), BPI(BPI), BFI(BFI),
372+
InstrumentationType(InstrumentationType) {}
373+
374+
void instrument();
375+
};
376+
} // namespace
357377

358378
// Return a string describing the branch condition that can be
359379
// used in static branch probability heuristics:
@@ -395,13 +415,16 @@ static const char *ValueProfKindDescr[] = {
395415

396416
// Create a COMDAT variable INSTR_PROF_RAW_VERSION_VAR to make the runtime
397417
// aware this is an ir_level profile so it can set the version flag.
398-
static GlobalVariable *createIRLevelProfileFlagVar(Module &M, bool IsCS) {
418+
static GlobalVariable *
419+
createIRLevelProfileFlagVar(Module &M,
420+
PGOInstrumentationType InstrumentationType) {
399421
const StringRef VarName(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
400422
Type *IntTy64 = Type::getInt64Ty(M.getContext());
401423
uint64_t ProfileVersion = (INSTR_PROF_RAW_VERSION | VARIANT_MASK_IR_PROF);
402-
if (IsCS)
424+
if (InstrumentationType == PGOInstrumentationType::CSFDO)
403425
ProfileVersion |= VARIANT_MASK_CSIR_PROF;
404-
if (shouldInstrumentEntryBB())
426+
if (PGOInstrumentEntry ||
427+
InstrumentationType == PGOInstrumentationType::CTXPROF)
405428
ProfileVersion |= VARIANT_MASK_INSTR_ENTRY;
406429
if (DebugInfoCorrelate || ProfileCorrelate == InstrProfCorrelator::DEBUG_INFO)
407430
ProfileVersion |= VARIANT_MASK_DBG_CORRELATE;
@@ -871,31 +894,28 @@ populateEHOperandBundle(VPCandidateInfo &Cand,
871894

872895
// Visit all edges and instrument the edges not in MST, and do value profiling.
873896
// Critical edges will be split.
874-
static void instrumentOneFunc(
875-
Function &F, Module *M, TargetLibraryInfo &TLI, BranchProbabilityInfo *BPI,
876-
BlockFrequencyInfo *BFI,
877-
std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
878-
bool IsCS) {
897+
void FunctionInstrumenter::instrument() {
879898
if (!PGOBlockCoverage) {
880899
// Split indirectbr critical edges here before computing the MST rather than
881900
// later in getInstrBB() to avoid invalidating it.
882901
SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI, BFI);
883902
}
884903

885904
FuncPGOInstrumentation<PGOEdge, PGOBBInfo> FuncInfo(
886-
F, TLI, ComdatMembers, true, BPI, BFI, IsCS, shouldInstrumentEntryBB(),
887-
PGOBlockCoverage);
905+
F, TLI, ComdatMembers, true, BPI, BFI,
906+
InstrumentationType == PGOInstrumentationType::CSFDO,
907+
shouldInstrumentEntryBB(), PGOBlockCoverage);
888908

889909
auto Name = FuncInfo.FuncNameVar;
890-
auto CFGHash = ConstantInt::get(Type::getInt64Ty(M->getContext()),
891-
FuncInfo.FunctionHash);
910+
auto CFGHash =
911+
ConstantInt::get(Type::getInt64Ty(M.getContext()), FuncInfo.FunctionHash);
892912
if (PGOFunctionEntryCoverage) {
893913
auto &EntryBB = F.getEntryBlock();
894914
IRBuilder<> Builder(&EntryBB, EntryBB.getFirstInsertionPt());
895915
// llvm.instrprof.cover(i8* <name>, i64 <hash>, i32 <num-counters>,
896916
// i32 <index>)
897917
Builder.CreateCall(
898-
Intrinsic::getDeclaration(M, Intrinsic::instrprof_cover),
918+
Intrinsic::getDeclaration(&M, Intrinsic::instrprof_cover),
899919
{Name, CFGHash, Builder.getInt32(1), Builder.getInt32(0)});
900920
return;
901921
}
@@ -905,9 +925,9 @@ static void instrumentOneFunc(
905925
unsigned NumCounters =
906926
InstrumentBBs.size() + FuncInfo.SIVisitor.getNumOfSelectInsts();
907927

908-
if (shouldInstrumentForCtxProf()) {
928+
if (InstrumentationType == PGOInstrumentationType::CTXPROF) {
909929
auto *CSIntrinsic =
910-
Intrinsic::getDeclaration(M, Intrinsic::instrprof_callsite);
930+
Intrinsic::getDeclaration(&M, Intrinsic::instrprof_callsite);
911931
// We want to count the instrumentable callsites, then instrument them. This
912932
// is because the llvm.instrprof.callsite intrinsic has an argument (like
913933
// the other instrprof intrinsics) capturing the total number of
@@ -950,7 +970,7 @@ static void instrumentOneFunc(
950970
// llvm.instrprof.timestamp(i8* <name>, i64 <hash>, i32 <num-counters>,
951971
// i32 <index>)
952972
Builder.CreateCall(
953-
Intrinsic::getDeclaration(M, Intrinsic::instrprof_timestamp),
973+
Intrinsic::getDeclaration(&M, Intrinsic::instrprof_timestamp),
954974
{Name, CFGHash, Builder.getInt32(NumCounters), Builder.getInt32(I)});
955975
I += PGOBlockCoverage ? 8 : 1;
956976
}
@@ -962,9 +982,9 @@ static void instrumentOneFunc(
962982
// llvm.instrprof.increment(i8* <name>, i64 <hash>, i32 <num-counters>,
963983
// i32 <index>)
964984
Builder.CreateCall(
965-
Intrinsic::getDeclaration(M, PGOBlockCoverage
966-
? Intrinsic::instrprof_cover
967-
: Intrinsic::instrprof_increment),
985+
Intrinsic::getDeclaration(&M, PGOBlockCoverage
986+
? Intrinsic::instrprof_cover
987+
: Intrinsic::instrprof_increment),
968988
{Name, CFGHash, Builder.getInt32(NumCounters), Builder.getInt32(I++)});
969989
}
970990

@@ -1011,7 +1031,7 @@ static void instrumentOneFunc(
10111031
SmallVector<OperandBundleDef, 1> OpBundles;
10121032
populateEHOperandBundle(Cand, BlockColors, OpBundles);
10131033
Builder.CreateCall(
1014-
Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile),
1034+
Intrinsic::getDeclaration(&M, Intrinsic::instrprof_value_profile),
10151035
{FuncInfo.FuncNameVar, Builder.getInt64(FuncInfo.FunctionHash),
10161036
ToProfile, Builder.getInt32(Kind), Builder.getInt32(SiteIndex++)},
10171037
OpBundles);
@@ -1746,7 +1766,7 @@ static uint32_t getMaxNumAnnotations(InstrProfValueKind ValueProfKind) {
17461766

17471767
// Traverse all value sites and annotate the instructions for all value kinds.
17481768
void PGOUseFunc::annotateValueSites() {
1749-
if (isValueProfilingDisabled())
1769+
if (DisableValueProfiling)
17501770
return;
17511771

17521772
// Create the PGOFuncName meta data.
@@ -1861,11 +1881,12 @@ static bool skipPGOGen(const Function &F) {
18611881
static bool InstrumentAllFunctions(
18621882
Module &M, function_ref<TargetLibraryInfo &(Function &)> LookupTLI,
18631883
function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
1864-
function_ref<BlockFrequencyInfo *(Function &)> LookupBFI, bool IsCS) {
1884+
function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
1885+
PGOInstrumentationType InstrumentationType) {
18651886
// For the context-sensitive instrumentation, we should have a separate pass
18661887
// (before LTO/ThinLTO linking) to create these variables.
1867-
if (!IsCS && !shouldInstrumentForCtxProf())
1868-
createIRLevelProfileFlagVar(M, /*IsCS=*/false);
1888+
if (InstrumentationType == PGOInstrumentationType::FDO)
1889+
createIRLevelProfileFlagVar(M, InstrumentationType);
18691890

18701891
Triple TT(M.getTargetTriple());
18711892
LLVMContext &Ctx = M.getContext();
@@ -1884,7 +1905,9 @@ static bool InstrumentAllFunctions(
18841905
auto &TLI = LookupTLI(F);
18851906
auto *BPI = LookupBPI(F);
18861907
auto *BFI = LookupBFI(F);
1887-
instrumentOneFunc(F, &M, TLI, BPI, BFI, ComdatMembers, IsCS);
1908+
FunctionInstrumenter FI(M, F, TLI, ComdatMembers, BPI, BFI,
1909+
InstrumentationType);
1910+
FI.instrument();
18881911
}
18891912
return true;
18901913
}
@@ -1894,7 +1917,8 @@ PGOInstrumentationGenCreateVar::run(Module &M, ModuleAnalysisManager &MAM) {
18941917
createProfileFileNameVar(M, CSInstrName);
18951918
// The variable in a comdat may be discarded by LTO. Ensure the declaration
18961919
// will be retained.
1897-
appendToCompilerUsed(M, createIRLevelProfileFlagVar(M, /*IsCS=*/true));
1920+
appendToCompilerUsed(
1921+
M, createIRLevelProfileFlagVar(M, PGOInstrumentationType::CSFDO));
18981922
if (ProfileSampling)
18991923
createProfileSamplingVar(M);
19001924
PreservedAnalyses PA;
@@ -1916,7 +1940,8 @@ PreservedAnalyses PGOInstrumentationGen::run(Module &M,
19161940
return &FAM.getResult<BlockFrequencyAnalysis>(F);
19171941
};
19181942

1919-
if (!InstrumentAllFunctions(M, LookupTLI, LookupBPI, LookupBFI, IsCS))
1943+
if (!InstrumentAllFunctions(M, LookupTLI, LookupBPI, LookupBFI,
1944+
InstrumentationType))
19201945
return PreservedAnalyses::all();
19211946

19221947
return PreservedAnalyses::none();
@@ -2115,7 +2140,6 @@ static bool annotateAllFunctions(
21152140
bool InstrumentFuncEntry = PGOReader->instrEntryBBEnabled();
21162141
if (PGOInstrumentEntry.getNumOccurrences() > 0)
21172142
InstrumentFuncEntry = PGOInstrumentEntry;
2118-
InstrumentFuncEntry |= shouldInstrumentForCtxProf();
21192143

21202144
bool HasSingleByteCoverage = PGOReader->hasSingleByteCoverage();
21212145
for (auto &F : M) {

llvm/test/Transforms/PGOProfile/ctx-instrumentation-invalid-roots.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: not opt -passes=pgo-instr-gen,ctx-instr-lower -profile-context-root=good \
1+
; RUN: not opt -passes=ctx-instr-gen,ctx-instr-lower -profile-context-root=good \
22
; RUN: -profile-context-root=bad \
33
; RUN: -S < %s 2>&1 | FileCheck %s
44

llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals all --version 4
2-
; RUN: opt -passes=pgo-instr-gen -profile-context-root=an_entrypoint \
2+
; RUN: opt -passes=ctx-instr-gen -profile-context-root=an_entrypoint \
33
; RUN: -S < %s | FileCheck --check-prefix=INSTRUMENT %s
4-
; RUN: opt -passes=pgo-instr-gen,assign-guid,ctx-instr-lower -profile-context-root=an_entrypoint \
4+
; RUN: opt -passes=ctx-instr-gen,assign-guid,ctx-instr-lower -profile-context-root=an_entrypoint \
55
; RUN: -profile-context-root=another_entrypoint_no_callees \
66
; RUN: -S < %s | FileCheck --check-prefix=LOWERING %s
77

0 commit comments

Comments
 (0)