Skip to content

[pgo][nfc] Model Count as a std::optional in PGOUseBBInfo #83364

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 39 additions & 47 deletions llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -983,27 +983,22 @@ using DirectEdges = SmallVector<PGOUseEdge *, 2>;

// This class stores the auxiliary information for each BB.
struct PGOUseBBInfo : public PGOBBInfo {
uint64_t CountValue = 0;
bool CountValid;
std::optional<uint64_t> Count;
int32_t UnknownCountInEdge = 0;
int32_t UnknownCountOutEdge = 0;
DirectEdges InEdges;
DirectEdges OutEdges;

PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX), CountValid(false) {}
PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX) {}

// Set the profile count value for this BB.
void setBBInfoCount(uint64_t Value) {
CountValue = Value;
CountValid = true;
}
void setBBInfoCount(uint64_t Value) { Count = Value; }

// Return the information string of this object.
std::string infoString() const {
if (!CountValid)
if (!Count)
return PGOBBInfo::infoString();
return (Twine(PGOBBInfo::infoString()) + " Count=" + Twine(CountValue))
.str();
return (Twine(PGOBBInfo::infoString()) + " Count=" + Twine(*Count)).str();
}

// Add an OutEdge and update the edge count.
Expand Down Expand Up @@ -1216,15 +1211,15 @@ bool PGOUseFunc::setInstrumentedCounts(

// If only one out-edge, the edge profile count should be the same as BB
// profile count.
if (SrcInfo.CountValid && SrcInfo.OutEdges.size() == 1)
setEdgeCount(E.get(), SrcInfo.CountValue);
if (SrcInfo.Count && SrcInfo.OutEdges.size() == 1)
setEdgeCount(E.get(), *SrcInfo.Count);
else {
const BasicBlock *DestBB = E->DestBB;
PGOUseBBInfo &DestInfo = getBBInfo(DestBB);
// If only one in-edge, the edge profile count should be the same as BB
// profile count.
if (DestInfo.CountValid && DestInfo.InEdges.size() == 1)
setEdgeCount(E.get(), DestInfo.CountValue);
if (DestInfo.Count && DestInfo.InEdges.size() == 1)
setEdgeCount(E.get(), *DestInfo.Count);
}
if (E->CountValid)
continue;
Expand Down Expand Up @@ -1481,38 +1476,36 @@ void PGOUseFunc::populateCounters() {
// For efficient traversal, it's better to start from the end as most
// of the instrumented edges are at the end.
for (auto &BB : reverse(F)) {
PGOUseBBInfo *Count = findBBInfo(&BB);
if (Count == nullptr)
PGOUseBBInfo *UseBBInfo = findBBInfo(&BB);
if (UseBBInfo == nullptr)
continue;
if (!Count->CountValid) {
if (Count->UnknownCountOutEdge == 0) {
Count->CountValue = sumEdgeCount(Count->OutEdges);
Count->CountValid = true;
if (!UseBBInfo->Count) {
if (UseBBInfo->UnknownCountOutEdge == 0) {
UseBBInfo->Count = sumEdgeCount(UseBBInfo->OutEdges);
Changes = true;
} else if (Count->UnknownCountInEdge == 0) {
Count->CountValue = sumEdgeCount(Count->InEdges);
Count->CountValid = true;
} else if (UseBBInfo->UnknownCountInEdge == 0) {
UseBBInfo->Count = sumEdgeCount(UseBBInfo->InEdges);
Changes = true;
}
}
if (Count->CountValid) {
if (Count->UnknownCountOutEdge == 1) {
if (UseBBInfo->Count) {
if (UseBBInfo->UnknownCountOutEdge == 1) {
uint64_t Total = 0;
uint64_t OutSum = sumEdgeCount(Count->OutEdges);
uint64_t OutSum = sumEdgeCount(UseBBInfo->OutEdges);
// If the one of the successor block can early terminate (no-return),
// we can end up with situation where out edge sum count is larger as
// the source BB's count is collected by a post-dominated block.
if (Count->CountValue > OutSum)
Total = Count->CountValue - OutSum;
setEdgeCount(Count->OutEdges, Total);
if (*UseBBInfo->Count > OutSum)
Total = *UseBBInfo->Count - OutSum;
setEdgeCount(UseBBInfo->OutEdges, Total);
Changes = true;
}
if (Count->UnknownCountInEdge == 1) {
if (UseBBInfo->UnknownCountInEdge == 1) {
uint64_t Total = 0;
uint64_t InSum = sumEdgeCount(Count->InEdges);
if (Count->CountValue > InSum)
Total = Count->CountValue - InSum;
setEdgeCount(Count->InEdges, Total);
uint64_t InSum = sumEdgeCount(UseBBInfo->InEdges);
if (*UseBBInfo->Count > InSum)
Total = *UseBBInfo->Count - InSum;
setEdgeCount(UseBBInfo->InEdges, Total);
Changes = true;
}
}
Expand All @@ -1527,16 +1520,16 @@ void PGOUseFunc::populateCounters() {
auto BI = findBBInfo(&BB);
if (BI == nullptr)
continue;
assert(BI->CountValid && "BB count is not valid");
assert(BI->Count && "BB count is not valid");
}
#endif
uint64_t FuncEntryCount = getBBInfo(&*F.begin()).CountValue;
uint64_t FuncEntryCount = *getBBInfo(&*F.begin()).Count;
uint64_t FuncMaxCount = FuncEntryCount;
for (auto &BB : F) {
auto BI = findBBInfo(&BB);
if (BI == nullptr)
continue;
FuncMaxCount = std::max(FuncMaxCount, BI->CountValue);
FuncMaxCount = std::max(FuncMaxCount, *BI->Count);
}

// Fix the obviously inconsistent entry count.
Expand Down Expand Up @@ -1566,11 +1559,11 @@ void PGOUseFunc::setBranchWeights() {
isa<CallBrInst>(TI)))
continue;

if (getBBInfo(&BB).CountValue == 0)
const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
if (!*BBCountInfo.Count)
continue;

// We have a non-zero Branch BB.
const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
unsigned Size = BBCountInfo.OutEdges.size();
SmallVector<uint64_t, 2> EdgeCounts(Size, 0);
uint64_t MaxCount = 0;
Expand Down Expand Up @@ -1622,7 +1615,7 @@ void PGOUseFunc::annotateIrrLoopHeaderWeights() {
if (BFI->isIrrLoopHeader(&BB) || isIndirectBrTarget(&BB)) {
Instruction *TI = BB.getTerminator();
const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
setIrrLoopHeaderMetadata(M, TI, BBCountInfo.CountValue);
setIrrLoopHeaderMetadata(M, TI, *BBCountInfo.Count);
}
}
}
Expand All @@ -1649,7 +1642,7 @@ void SelectInstVisitor::annotateOneSelectInst(SelectInst &SI) {
uint64_t TotalCount = 0;
auto BI = UseFunc->findBBInfo(SI.getParent());
if (BI != nullptr)
TotalCount = BI->CountValue;
TotalCount = *BI->Count;
// False Count
SCounts[1] = (TotalCount > SCounts[0] ? TotalCount - SCounts[0] : 0);
uint64_t MaxCount = std::max(SCounts[0], SCounts[1]);
Expand Down Expand Up @@ -1850,7 +1843,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI,
if (!Func.findBBInfo(&BBI))
continue;
auto BFICount = NBFI.getBlockProfileCount(&BBI);
CountValue = Func.getBBInfo(&BBI).CountValue;
CountValue = *Func.getBBInfo(&BBI).Count;
BFICountValue = *BFICount;
SumCount.add(APFloat(CountValue * 1.0), APFloat::rmNearestTiesToEven);
SumBFICount.add(APFloat(BFICountValue * 1.0), APFloat::rmNearestTiesToEven);
Expand All @@ -1866,7 +1859,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI,
if (Scale < 1.001 && Scale > 0.999)
return;

uint64_t FuncEntryCount = Func.getBBInfo(&*F.begin()).CountValue;
uint64_t FuncEntryCount = *Func.getBBInfo(&*F.begin()).Count;
uint64_t NewEntryCount = 0.5 + FuncEntryCount * Scale;
if (NewEntryCount == 0)
NewEntryCount = 1;
Expand Down Expand Up @@ -1896,8 +1889,7 @@ static void verifyFuncBFI(PGOUseFunc &Func, LoopInfo &LI,
uint64_t CountValue = 0;
uint64_t BFICountValue = 0;

if (Func.getBBInfo(&BBI).CountValid)
CountValue = Func.getBBInfo(&BBI).CountValue;
CountValue = Func.getBBInfo(&BBI).Count.value_or(CountValue);

BBNum++;
if (CountValue)
Expand Down Expand Up @@ -2279,8 +2271,8 @@ template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits {
OS << getSimpleNodeName(Node) << ":\\l";
PGOUseBBInfo *BI = Graph->findBBInfo(Node);
OS << "Count : ";
if (BI && BI->CountValid)
OS << BI->CountValue << "\\l";
if (BI && BI->Count)
OS << *BI->Count << "\\l";
else
OS << "Unknown\\l";

Expand Down