Skip to content

Commit b7d9322

Browse files
committed
[FS-AFDO] Load pseudo probe profile on MIR
This change enables loading pseudo-probe based profile on MIR. Different from the IR profile loader, callsites are excluded from MIR profile loading since they are not assinged a FS discriminator. Using zero as the discriminator is not accurate and would undo the distribution work done by the IR loader based on pseudo probe distribution factor. We reply on block probes only for FS profile loading. Some refactoring is done to the IR profile loader so that `getProbeWeight` can be shared by both loaders. Reviewed By: wenlei Differential Revision: https://reviews.llvm.org/D148584
1 parent 345fd0c commit b7d9322

File tree

10 files changed

+545
-140
lines changed

10 files changed

+545
-140
lines changed

llvm/include/llvm/IR/PseudoProbe.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
namespace llvm {
2222

2323
class Instruction;
24+
class DILocation;
2425

2526
constexpr const char *PseudoProbeDescMetadataName = "llvm.pseudo_probe_desc";
2627

@@ -78,10 +79,22 @@ struct PseudoProbeDwarfDiscriminator {
7879
constexpr static uint8_t FullDistributionFactor = 100;
7980
};
8081

82+
class PseudoProbeDescriptor {
83+
uint64_t FunctionGUID;
84+
uint64_t FunctionHash;
85+
86+
public:
87+
PseudoProbeDescriptor(uint64_t GUID, uint64_t Hash)
88+
: FunctionGUID(GUID), FunctionHash(Hash) {}
89+
uint64_t getFunctionGUID() const { return FunctionGUID; }
90+
uint64_t getFunctionHash() const { return FunctionHash; }
91+
};
92+
8193
struct PseudoProbe {
8294
uint32_t Id;
8395
uint32_t Type;
8496
uint32_t Attr;
97+
uint32_t Discriminator;
8598
// Distribution factor that estimates the portion of the real execution count.
8699
// A saturated distribution factor stands for 1.0 or 100%. A pesudo probe has
87100
// a factor with the value ranged from 0.0 to 1.0.

llvm/include/llvm/Transforms/IPO/SampleProfileProbe.h

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,6 @@ using ProbeFactorMap = std::unordered_map<std::pair<uint64_t, uint64_t>, float,
4040
pair_hash<uint64_t, uint64_t>>;
4141
using FuncProbeFactorMap = StringMap<ProbeFactorMap>;
4242

43-
class PseudoProbeDescriptor {
44-
uint64_t FunctionGUID;
45-
uint64_t FunctionHash;
46-
47-
public:
48-
PseudoProbeDescriptor(uint64_t GUID, uint64_t Hash)
49-
: FunctionGUID(GUID), FunctionHash(Hash) {}
50-
uint64_t getFunctionGUID() const { return FunctionGUID; }
51-
uint64_t getFunctionHash() const { return FunctionHash; }
52-
};
5343

5444
// A pseudo probe verifier that can be run after each IR passes to detect the
5545
// violation of updating probe factors. In principle, the sum of distribution
@@ -78,20 +68,6 @@ class PseudoProbeVerifier {
7868
const ProbeFactorMap &ProbeFactors);
7969
};
8070

81-
// This class serves sample counts correlation for SampleProfileLoader by
82-
// analyzing pseudo probes and their function descriptors injected by
83-
// SampleProfileProber.
84-
class PseudoProbeManager {
85-
DenseMap<uint64_t, PseudoProbeDescriptor> GUIDToProbeDescMap;
86-
87-
const PseudoProbeDescriptor *getDesc(const Function &F) const;
88-
89-
public:
90-
PseudoProbeManager(const Module &M);
91-
bool moduleIsProbed(const Module &M) const;
92-
bool profileIsValid(const Function &F, const FunctionSamples &Samples) const;
93-
};
94-
9571
/// Sample profile pseudo prober.
9672
///
9773
/// Insert pseudo probes for block sampling and value sampling.

llvm/include/llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "llvm/IR/Instruction.h"
3535
#include "llvm/IR/Instructions.h"
3636
#include "llvm/IR/Module.h"
37+
#include "llvm/IR/PseudoProbe.h"
3738
#include "llvm/ProfileData/SampleProf.h"
3839
#include "llvm/ProfileData/SampleProfReader.h"
3940
#include "llvm/Support/CommandLine.h"
@@ -80,6 +81,55 @@ template <> struct IRTraits<BasicBlock> {
8081

8182
} // end namespace afdo_detail
8283

84+
// This class serves sample counts correlation for SampleProfileLoader by
85+
// analyzing pseudo probes and their function descriptors injected by
86+
// SampleProfileProber.
87+
class PseudoProbeManager {
88+
DenseMap<uint64_t, PseudoProbeDescriptor> GUIDToProbeDescMap;
89+
90+
const PseudoProbeDescriptor *getDesc(const Function &F) const {
91+
auto I = GUIDToProbeDescMap.find(
92+
Function::getGUID(FunctionSamples::getCanonicalFnName(F)));
93+
return I == GUIDToProbeDescMap.end() ? nullptr : &I->second;
94+
}
95+
96+
public:
97+
PseudoProbeManager(const Module &M) {
98+
if (NamedMDNode *FuncInfo =
99+
M.getNamedMetadata(PseudoProbeDescMetadataName)) {
100+
for (const auto *Operand : FuncInfo->operands()) {
101+
const auto *MD = cast<MDNode>(Operand);
102+
auto GUID = mdconst::dyn_extract<ConstantInt>(MD->getOperand(0))
103+
->getZExtValue();
104+
auto Hash = mdconst::dyn_extract<ConstantInt>(MD->getOperand(1))
105+
->getZExtValue();
106+
GUIDToProbeDescMap.try_emplace(GUID, PseudoProbeDescriptor(GUID, Hash));
107+
}
108+
}
109+
}
110+
111+
bool moduleIsProbed(const Module &M) const {
112+
return M.getNamedMetadata(PseudoProbeDescMetadataName);
113+
}
114+
115+
bool profileIsValid(const Function &F, const FunctionSamples &Samples) const {
116+
const auto *Desc = getDesc(F);
117+
if (!Desc) {
118+
LLVM_DEBUG(dbgs() << "Probe descriptor missing for Function "
119+
<< F.getName() << "\n");
120+
return false;
121+
}
122+
if (Desc->getFunctionHash() != Samples.getFunctionHash()) {
123+
LLVM_DEBUG(dbgs() << "Hash mismatch for Function " << F.getName()
124+
<< "\n");
125+
return false;
126+
}
127+
return true;
128+
}
129+
};
130+
131+
132+
83133
extern cl::opt<bool> SampleProfileUseProfi;
84134

85135
template <typename BT> class SampleProfileLoaderBaseImpl {
@@ -137,6 +187,7 @@ template <typename BT> class SampleProfileLoaderBaseImpl {
137187
unsigned getFunctionLoc(FunctionT &Func);
138188
virtual ErrorOr<uint64_t> getInstWeight(const InstructionT &Inst);
139189
ErrorOr<uint64_t> getInstWeightImpl(const InstructionT &Inst);
190+
virtual ErrorOr<uint64_t> getProbeWeight(const InstructionT &Inst);
140191
ErrorOr<uint64_t> getBlockWeight(const BasicBlockT *BB);
141192
mutable DenseMap<const DILocation *, const FunctionSamples *>
142193
DILocation2SampleMap;
@@ -212,6 +263,9 @@ template <typename BT> class SampleProfileLoaderBaseImpl {
212263
/// Profile reader object.
213264
std::unique_ptr<SampleProfileReader> Reader;
214265

266+
// A pseudo probe helper to correlate the imported sample counts.
267+
std::unique_ptr<PseudoProbeManager> ProbeManager;
268+
215269
/// Samples collected for the body of this function.
216270
FunctionSamples *Samples = nullptr;
217271

@@ -299,6 +353,8 @@ void SampleProfileLoaderBaseImpl<BT>::printBlockWeight(
299353
template <typename BT>
300354
ErrorOr<uint64_t>
301355
SampleProfileLoaderBaseImpl<BT>::getInstWeight(const InstructionT &Inst) {
356+
if (FunctionSamples::ProfileIsProbeBased)
357+
return getProbeWeight(Inst);
302358
return getInstWeightImpl(Inst);
303359
}
304360

@@ -346,6 +402,65 @@ SampleProfileLoaderBaseImpl<BT>::getInstWeightImpl(const InstructionT &Inst) {
346402
return R;
347403
}
348404

405+
// Here use error_code to represent: 1) The dangling probe. 2) Ignore the weight
406+
// of non-probe instruction. So if all instructions of the BB give error_code,
407+
// tell the inference algorithm to infer the BB weight.
408+
template <typename BT>
409+
ErrorOr<uint64_t>
410+
SampleProfileLoaderBaseImpl<BT>::getProbeWeight(const InstructionT &Inst) {
411+
assert(FunctionSamples::ProfileIsProbeBased &&
412+
"Profile is not pseudo probe based");
413+
std::optional<PseudoProbe> Probe = extractProbe(Inst);
414+
// Ignore the non-probe instruction. If none of the instruction in the BB is
415+
// probe, we choose to infer the BB's weight.
416+
if (!Probe)
417+
return std::error_code();
418+
419+
const FunctionSamples *FS = findFunctionSamples(Inst);
420+
// If none of the instruction has FunctionSample, we choose to return zero
421+
// value sample to indicate the BB is cold. This could happen when the
422+
// instruction is from inlinee and no profile data is found.
423+
// FIXME: This should not be affected by the source drift issue as 1) if the
424+
// newly added function is top-level inliner, it won't match the CFG checksum
425+
// in the function profile or 2) if it's the inlinee, the inlinee should have
426+
// a profile, otherwise it wouldn't be inlined. For non-probe based profile,
427+
// we can improve it by adding a switch for profile-sample-block-accurate for
428+
// block level counts in the future.
429+
if (!FS)
430+
return 0;
431+
432+
auto R = FS->findSamplesAt(Probe->Id, Probe->Discriminator);
433+
if (R) {
434+
uint64_t Samples = R.get() * Probe->Factor;
435+
bool FirstMark = CoverageTracker.markSamplesUsed(FS, Probe->Id, 0, Samples);
436+
if (FirstMark) {
437+
ORE->emit([&]() {
438+
OptRemarkAnalysisT Remark(DEBUG_TYPE, "AppliedSamples", &Inst);
439+
Remark << "Applied " << ore::NV("NumSamples", Samples);
440+
Remark << " samples from profile (ProbeId=";
441+
Remark << ore::NV("ProbeId", Probe->Id);
442+
if (Probe->Discriminator) {
443+
Remark << ".";
444+
Remark << ore::NV("Discriminator", Probe->Discriminator);
445+
}
446+
Remark << ", Factor=";
447+
Remark << ore::NV("Factor", Probe->Factor);
448+
Remark << ", OriginalSamples=";
449+
Remark << ore::NV("OriginalSamples", R.get());
450+
Remark << ")";
451+
return Remark;
452+
});
453+
}
454+
LLVM_DEBUG({dbgs() << " " << Probe->Id;
455+
if (Probe->Discriminator)
456+
dbgs() << "." << Probe->Discriminator;
457+
dbgs() << ":" << Inst << " - weight: " << R.get()
458+
<< " - factor: " << format("%0.2f", Probe->Factor) << ")\n";});
459+
return Samples;
460+
}
461+
return R;
462+
}
463+
349464
/// Compute the weight of a basic block.
350465
///
351466
/// The weight of basic block \p BB is the maximum weight of all the

llvm/lib/CodeGen/MIRSampleProfile.cpp

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,21 @@
1818
#include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
1919
#include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
2020
#include "llvm/CodeGen/MachineDominators.h"
21+
#include "llvm/CodeGen/MachineInstr.h"
2122
#include "llvm/CodeGen/MachineLoopInfo.h"
2223
#include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
2324
#include "llvm/CodeGen/MachinePostDominators.h"
2425
#include "llvm/CodeGen/Passes.h"
2526
#include "llvm/IR/Function.h"
27+
#include "llvm/IR/PseudoProbe.h"
2628
#include "llvm/InitializePasses.h"
2729
#include "llvm/Support/CommandLine.h"
2830
#include "llvm/Support/Debug.h"
2931
#include "llvm/Support/VirtualFileSystem.h"
3032
#include "llvm/Support/raw_ostream.h"
3133
#include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
3234
#include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
35+
#include <optional>
3336

3437
using namespace llvm;
3538
using namespace sampleprof;
@@ -92,6 +95,22 @@ extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI;
9295
// Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name=
9396
extern cl::opt<std::string> ViewBlockFreqFuncName;
9497

98+
std::optional<PseudoProbe> extractProbe(const MachineInstr &MI) {
99+
if (MI.isPseudoProbe()) {
100+
PseudoProbe Probe;
101+
Probe.Id = MI.getOperand(1).getImm();
102+
Probe.Type = MI.getOperand(2).getImm();
103+
Probe.Attr = MI.getOperand(3).getImm();
104+
Probe.Factor = 1;
105+
DILocation *DebugLoc = MI.getDebugLoc();
106+
Probe.Discriminator = DebugLoc ? DebugLoc->getDiscriminator() : 0;
107+
return Probe;
108+
}
109+
110+
// Ignore callsite probes since they do not have FS discriminators.
111+
return std::nullopt;
112+
}
113+
95114
namespace afdo_detail {
96115
template <> struct IRTraits<MachineBasicBlock> {
97116
using InstructionT = MachineInstr;
@@ -167,6 +186,8 @@ class MIRProfileLoader final
167186

168187
bool ProfileIsValid = true;
169188
ErrorOr<uint64_t> getInstWeight(const MachineInstr &MI) override {
189+
if (FunctionSamples::ProfileIsProbeBased)
190+
return getProbeWeight(MI);
170191
if (ImprovedFSDiscriminator && MI.isMetaInstruction())
171192
return std::error_code();
172193
return getInstWeightImpl(MI);
@@ -275,6 +296,14 @@ bool MIRProfileLoader::doInitialization(Module &M) {
275296
Reader->setModule(&M);
276297
ProfileIsValid = (Reader->read() == sampleprof_error::success);
277298

299+
// Load pseudo probe descriptors for probe-based function samples.
300+
if (Reader->profileIsProbeBased()) {
301+
ProbeManager = std::make_unique<PseudoProbeManager>(M);
302+
if (!ProbeManager->moduleIsProbed(M)) {
303+
return false;
304+
}
305+
}
306+
278307
return true;
279308
}
280309

@@ -285,8 +314,13 @@ bool MIRProfileLoader::runOnFunction(MachineFunction &MF) {
285314
if (!Samples || Samples->empty())
286315
return false;
287316

288-
if (getFunctionLoc(MF) == 0)
289-
return false;
317+
if (FunctionSamples::ProfileIsProbeBased) {
318+
if (!ProbeManager->profileIsValid(MF.getFunction(), *Samples))
319+
return false;
320+
} else {
321+
if (getFunctionLoc(MF) == 0)
322+
return false;
323+
}
290324

291325
DenseSet<GlobalValue::GUID> InlinedGUIDs;
292326
bool Changed = computeAndPropagateWeights(MF, InlinedGUIDs);

llvm/lib/IR/PseudoProbe.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,8 @@ using namespace llvm;
2222
namespace llvm {
2323

2424
std::optional<PseudoProbe>
25-
extractProbeFromDiscriminator(const Instruction &Inst) {
26-
assert(isa<CallBase>(&Inst) && !isa<IntrinsicInst>(&Inst) &&
27-
"Only call instructions should have pseudo probe encodes as their "
28-
"Dwarf discriminators");
29-
if (const DebugLoc &DLoc = Inst.getDebugLoc()) {
30-
const DILocation *DIL = DLoc;
25+
extractProbeFromDiscriminator(const DILocation *DIL) {
26+
if (DIL) {
3127
auto Discriminator = DIL->getDiscriminator();
3228
if (DILocation::isPseudoProbeDiscriminator(Discriminator)) {
3329
PseudoProbe Probe;
@@ -40,12 +36,23 @@ extractProbeFromDiscriminator(const Instruction &Inst) {
4036
Probe.Factor =
4137
PseudoProbeDwarfDiscriminator::extractProbeFactor(Discriminator) /
4238
(float)PseudoProbeDwarfDiscriminator::FullDistributionFactor;
39+
Probe.Discriminator = 0;
4340
return Probe;
4441
}
4542
}
4643
return std::nullopt;
4744
}
4845

46+
std::optional<PseudoProbe>
47+
extractProbeFromDiscriminator(const Instruction &Inst) {
48+
assert(isa<CallBase>(&Inst) && !isa<IntrinsicInst>(&Inst) &&
49+
"Only call instructions should have pseudo probe encodes as their "
50+
"Dwarf discriminators");
51+
if (const DebugLoc &DLoc = Inst.getDebugLoc())
52+
return extractProbeFromDiscriminator(DLoc);
53+
return std::nullopt;
54+
}
55+
4956
std::optional<PseudoProbe> extractProbe(const Instruction &Inst) {
5057
if (const auto *II = dyn_cast<PseudoProbeInst>(&Inst)) {
5158
PseudoProbe Probe;
@@ -54,6 +61,11 @@ std::optional<PseudoProbe> extractProbe(const Instruction &Inst) {
5461
Probe.Attr = II->getAttributes()->getZExtValue();
5562
Probe.Factor = II->getFactor()->getZExtValue() /
5663
(float)PseudoProbeFullDistributionFactor;
64+
Probe.Discriminator = 0;
65+
if (const DebugLoc &DLoc = Inst.getDebugLoc())
66+
Probe.Discriminator = DLoc->getDiscriminator();
67+
assert(Probe.Discriminator == 0 &&
68+
"Unexpected non-zero FS-discriminator for IR pseudo probes");
5769
return Probe;
5870
}
5971

0 commit comments

Comments
 (0)