|
34 | 34 | #include "llvm/IR/Instruction.h"
|
35 | 35 | #include "llvm/IR/Instructions.h"
|
36 | 36 | #include "llvm/IR/Module.h"
|
| 37 | +#include "llvm/IR/PseudoProbe.h" |
37 | 38 | #include "llvm/ProfileData/SampleProf.h"
|
38 | 39 | #include "llvm/ProfileData/SampleProfReader.h"
|
39 | 40 | #include "llvm/Support/CommandLine.h"
|
@@ -80,6 +81,55 @@ template <> struct IRTraits<BasicBlock> {
|
80 | 81 |
|
81 | 82 | } // end namespace afdo_detail
|
82 | 83 |
|
| 84 | +// This class serves sample counts correlation for SampleProfileLoader by |
| 85 | +// analyzing pseudo probes and their function descriptors injected by |
| 86 | +// SampleProfileProber. |
| 87 | +class PseudoProbeManager { |
| 88 | + DenseMap<uint64_t, PseudoProbeDescriptor> GUIDToProbeDescMap; |
| 89 | + |
| 90 | + const PseudoProbeDescriptor *getDesc(const Function &F) const { |
| 91 | + auto I = GUIDToProbeDescMap.find( |
| 92 | + Function::getGUID(FunctionSamples::getCanonicalFnName(F))); |
| 93 | + return I == GUIDToProbeDescMap.end() ? nullptr : &I->second; |
| 94 | + } |
| 95 | + |
| 96 | +public: |
| 97 | + PseudoProbeManager(const Module &M) { |
| 98 | + if (NamedMDNode *FuncInfo = |
| 99 | + M.getNamedMetadata(PseudoProbeDescMetadataName)) { |
| 100 | + for (const auto *Operand : FuncInfo->operands()) { |
| 101 | + const auto *MD = cast<MDNode>(Operand); |
| 102 | + auto GUID = mdconst::dyn_extract<ConstantInt>(MD->getOperand(0)) |
| 103 | + ->getZExtValue(); |
| 104 | + auto Hash = mdconst::dyn_extract<ConstantInt>(MD->getOperand(1)) |
| 105 | + ->getZExtValue(); |
| 106 | + GUIDToProbeDescMap.try_emplace(GUID, PseudoProbeDescriptor(GUID, Hash)); |
| 107 | + } |
| 108 | + } |
| 109 | + } |
| 110 | + |
| 111 | + bool moduleIsProbed(const Module &M) const { |
| 112 | + return M.getNamedMetadata(PseudoProbeDescMetadataName); |
| 113 | + } |
| 114 | + |
| 115 | + bool profileIsValid(const Function &F, const FunctionSamples &Samples) const { |
| 116 | + const auto *Desc = getDesc(F); |
| 117 | + if (!Desc) { |
| 118 | + LLVM_DEBUG(dbgs() << "Probe descriptor missing for Function " |
| 119 | + << F.getName() << "\n"); |
| 120 | + return false; |
| 121 | + } |
| 122 | + if (Desc->getFunctionHash() != Samples.getFunctionHash()) { |
| 123 | + LLVM_DEBUG(dbgs() << "Hash mismatch for Function " << F.getName() |
| 124 | + << "\n"); |
| 125 | + return false; |
| 126 | + } |
| 127 | + return true; |
| 128 | + } |
| 129 | +}; |
| 130 | + |
| 131 | + |
| 132 | + |
83 | 133 | extern cl::opt<bool> SampleProfileUseProfi;
|
84 | 134 |
|
85 | 135 | template <typename BT> class SampleProfileLoaderBaseImpl {
|
@@ -137,6 +187,7 @@ template <typename BT> class SampleProfileLoaderBaseImpl {
|
137 | 187 | unsigned getFunctionLoc(FunctionT &Func);
|
138 | 188 | virtual ErrorOr<uint64_t> getInstWeight(const InstructionT &Inst);
|
139 | 189 | ErrorOr<uint64_t> getInstWeightImpl(const InstructionT &Inst);
|
| 190 | + virtual ErrorOr<uint64_t> getProbeWeight(const InstructionT &Inst); |
140 | 191 | ErrorOr<uint64_t> getBlockWeight(const BasicBlockT *BB);
|
141 | 192 | mutable DenseMap<const DILocation *, const FunctionSamples *>
|
142 | 193 | DILocation2SampleMap;
|
@@ -212,6 +263,9 @@ template <typename BT> class SampleProfileLoaderBaseImpl {
|
212 | 263 | /// Profile reader object.
|
213 | 264 | std::unique_ptr<SampleProfileReader> Reader;
|
214 | 265 |
|
| 266 | + // A pseudo probe helper to correlate the imported sample counts. |
| 267 | + std::unique_ptr<PseudoProbeManager> ProbeManager; |
| 268 | + |
215 | 269 | /// Samples collected for the body of this function.
|
216 | 270 | FunctionSamples *Samples = nullptr;
|
217 | 271 |
|
@@ -299,6 +353,8 @@ void SampleProfileLoaderBaseImpl<BT>::printBlockWeight(
|
299 | 353 | template <typename BT>
|
300 | 354 | ErrorOr<uint64_t>
|
301 | 355 | SampleProfileLoaderBaseImpl<BT>::getInstWeight(const InstructionT &Inst) {
|
| 356 | + if (FunctionSamples::ProfileIsProbeBased) |
| 357 | + return getProbeWeight(Inst); |
302 | 358 | return getInstWeightImpl(Inst);
|
303 | 359 | }
|
304 | 360 |
|
@@ -346,6 +402,65 @@ SampleProfileLoaderBaseImpl<BT>::getInstWeightImpl(const InstructionT &Inst) {
|
346 | 402 | return R;
|
347 | 403 | }
|
348 | 404 |
|
| 405 | +// Here use error_code to represent: 1) The dangling probe. 2) Ignore the weight |
| 406 | +// of non-probe instruction. So if all instructions of the BB give error_code, |
| 407 | +// tell the inference algorithm to infer the BB weight. |
| 408 | +template <typename BT> |
| 409 | +ErrorOr<uint64_t> |
| 410 | +SampleProfileLoaderBaseImpl<BT>::getProbeWeight(const InstructionT &Inst) { |
| 411 | + assert(FunctionSamples::ProfileIsProbeBased && |
| 412 | + "Profile is not pseudo probe based"); |
| 413 | + std::optional<PseudoProbe> Probe = extractProbe(Inst); |
| 414 | + // Ignore the non-probe instruction. If none of the instruction in the BB is |
| 415 | + // probe, we choose to infer the BB's weight. |
| 416 | + if (!Probe) |
| 417 | + return std::error_code(); |
| 418 | + |
| 419 | + const FunctionSamples *FS = findFunctionSamples(Inst); |
| 420 | + // If none of the instruction has FunctionSample, we choose to return zero |
| 421 | + // value sample to indicate the BB is cold. This could happen when the |
| 422 | + // instruction is from inlinee and no profile data is found. |
| 423 | + // FIXME: This should not be affected by the source drift issue as 1) if the |
| 424 | + // newly added function is top-level inliner, it won't match the CFG checksum |
| 425 | + // in the function profile or 2) if it's the inlinee, the inlinee should have |
| 426 | + // a profile, otherwise it wouldn't be inlined. For non-probe based profile, |
| 427 | + // we can improve it by adding a switch for profile-sample-block-accurate for |
| 428 | + // block level counts in the future. |
| 429 | + if (!FS) |
| 430 | + return 0; |
| 431 | + |
| 432 | + auto R = FS->findSamplesAt(Probe->Id, Probe->Discriminator); |
| 433 | + if (R) { |
| 434 | + uint64_t Samples = R.get() * Probe->Factor; |
| 435 | + bool FirstMark = CoverageTracker.markSamplesUsed(FS, Probe->Id, 0, Samples); |
| 436 | + if (FirstMark) { |
| 437 | + ORE->emit([&]() { |
| 438 | + OptRemarkAnalysisT Remark(DEBUG_TYPE, "AppliedSamples", &Inst); |
| 439 | + Remark << "Applied " << ore::NV("NumSamples", Samples); |
| 440 | + Remark << " samples from profile (ProbeId="; |
| 441 | + Remark << ore::NV("ProbeId", Probe->Id); |
| 442 | + if (Probe->Discriminator) { |
| 443 | + Remark << "."; |
| 444 | + Remark << ore::NV("Discriminator", Probe->Discriminator); |
| 445 | + } |
| 446 | + Remark << ", Factor="; |
| 447 | + Remark << ore::NV("Factor", Probe->Factor); |
| 448 | + Remark << ", OriginalSamples="; |
| 449 | + Remark << ore::NV("OriginalSamples", R.get()); |
| 450 | + Remark << ")"; |
| 451 | + return Remark; |
| 452 | + }); |
| 453 | + } |
| 454 | + LLVM_DEBUG({dbgs() << " " << Probe->Id; |
| 455 | + if (Probe->Discriminator) |
| 456 | + dbgs() << "." << Probe->Discriminator; |
| 457 | + dbgs() << ":" << Inst << " - weight: " << R.get() |
| 458 | + << " - factor: " << format("%0.2f", Probe->Factor) << ")\n";}); |
| 459 | + return Samples; |
| 460 | + } |
| 461 | + return R; |
| 462 | +} |
| 463 | + |
349 | 464 | /// Compute the weight of a basic block.
|
350 | 465 | ///
|
351 | 466 | /// The weight of basic block \p BB is the maximum weight of all the
|
|
0 commit comments