Skip to content

Commit 1a2960b

Browse files
authored
[pgo][nfc] Model Count as a std::optional in PGOUseBBInfo (#83364)
Simpler code, compared to tracking state of 2 variables and the ambiguity of "0" CountValue (is it 0 or is it invalid?)
1 parent 3be05d8 commit 1a2960b

File tree

1 file changed

+39
-47
lines changed

1 file changed

+39
-47
lines changed

llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp

Lines changed: 39 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -983,27 +983,22 @@ using DirectEdges = SmallVector<PGOUseEdge *, 2>;
983983

984984
// This class stores the auxiliary information for each BB.
985985
struct PGOUseBBInfo : public PGOBBInfo {
986-
uint64_t CountValue = 0;
987-
bool CountValid;
986+
std::optional<uint64_t> Count;
988987
int32_t UnknownCountInEdge = 0;
989988
int32_t UnknownCountOutEdge = 0;
990989
DirectEdges InEdges;
991990
DirectEdges OutEdges;
992991

993-
PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX), CountValid(false) {}
992+
PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX) {}
994993

995994
// Set the profile count value for this BB.
996-
void setBBInfoCount(uint64_t Value) {
997-
CountValue = Value;
998-
CountValid = true;
999-
}
995+
void setBBInfoCount(uint64_t Value) { Count = Value; }
1000996

1001997
// Return the information string of this object.
1002998
std::string infoString() const {
1003-
if (!CountValid)
999+
if (!Count)
10041000
return PGOBBInfo::infoString();
1005-
return (Twine(PGOBBInfo::infoString()) + " Count=" + Twine(CountValue))
1006-
.str();
1001+
return (Twine(PGOBBInfo::infoString()) + " Count=" + Twine(*Count)).str();
10071002
}
10081003

10091004
// Add an OutEdge and update the edge count.
@@ -1216,15 +1211,15 @@ bool PGOUseFunc::setInstrumentedCounts(
12161211

12171212
// If only one out-edge, the edge profile count should be the same as BB
12181213
// profile count.
1219-
if (SrcInfo.CountValid && SrcInfo.OutEdges.size() == 1)
1220-
setEdgeCount(E.get(), SrcInfo.CountValue);
1214+
if (SrcInfo.Count && SrcInfo.OutEdges.size() == 1)
1215+
setEdgeCount(E.get(), *SrcInfo.Count);
12211216
else {
12221217
const BasicBlock *DestBB = E->DestBB;
12231218
PGOUseBBInfo &DestInfo = getBBInfo(DestBB);
12241219
// If only one in-edge, the edge profile count should be the same as BB
12251220
// profile count.
1226-
if (DestInfo.CountValid && DestInfo.InEdges.size() == 1)
1227-
setEdgeCount(E.get(), DestInfo.CountValue);
1221+
if (DestInfo.Count && DestInfo.InEdges.size() == 1)
1222+
setEdgeCount(E.get(), *DestInfo.Count);
12281223
}
12291224
if (E->CountValid)
12301225
continue;
@@ -1481,38 +1476,36 @@ void PGOUseFunc::populateCounters() {
14811476
// For efficient traversal, it's better to start from the end as most
14821477
// of the instrumented edges are at the end.
14831478
for (auto &BB : reverse(F)) {
1484-
PGOUseBBInfo *Count = findBBInfo(&BB);
1485-
if (Count == nullptr)
1479+
PGOUseBBInfo *UseBBInfo = findBBInfo(&BB);
1480+
if (UseBBInfo == nullptr)
14861481
continue;
1487-
if (!Count->CountValid) {
1488-
if (Count->UnknownCountOutEdge == 0) {
1489-
Count->CountValue = sumEdgeCount(Count->OutEdges);
1490-
Count->CountValid = true;
1482+
if (!UseBBInfo->Count) {
1483+
if (UseBBInfo->UnknownCountOutEdge == 0) {
1484+
UseBBInfo->Count = sumEdgeCount(UseBBInfo->OutEdges);
14911485
Changes = true;
1492-
} else if (Count->UnknownCountInEdge == 0) {
1493-
Count->CountValue = sumEdgeCount(Count->InEdges);
1494-
Count->CountValid = true;
1486+
} else if (UseBBInfo->UnknownCountInEdge == 0) {
1487+
UseBBInfo->Count = sumEdgeCount(UseBBInfo->InEdges);
14951488
Changes = true;
14961489
}
14971490
}
1498-
if (Count->CountValid) {
1499-
if (Count->UnknownCountOutEdge == 1) {
1491+
if (UseBBInfo->Count) {
1492+
if (UseBBInfo->UnknownCountOutEdge == 1) {
15001493
uint64_t Total = 0;
1501-
uint64_t OutSum = sumEdgeCount(Count->OutEdges);
1494+
uint64_t OutSum = sumEdgeCount(UseBBInfo->OutEdges);
15021495
// If the one of the successor block can early terminate (no-return),
15031496
// we can end up with situation where out edge sum count is larger as
15041497
// the source BB's count is collected by a post-dominated block.
1505-
if (Count->CountValue > OutSum)
1506-
Total = Count->CountValue - OutSum;
1507-
setEdgeCount(Count->OutEdges, Total);
1498+
if (*UseBBInfo->Count > OutSum)
1499+
Total = *UseBBInfo->Count - OutSum;
1500+
setEdgeCount(UseBBInfo->OutEdges, Total);
15081501
Changes = true;
15091502
}
1510-
if (Count->UnknownCountInEdge == 1) {
1503+
if (UseBBInfo->UnknownCountInEdge == 1) {
15111504
uint64_t Total = 0;
1512-
uint64_t InSum = sumEdgeCount(Count->InEdges);
1513-
if (Count->CountValue > InSum)
1514-
Total = Count->CountValue - InSum;
1515-
setEdgeCount(Count->InEdges, Total);
1505+
uint64_t InSum = sumEdgeCount(UseBBInfo->InEdges);
1506+
if (*UseBBInfo->Count > InSum)
1507+
Total = *UseBBInfo->Count - InSum;
1508+
setEdgeCount(UseBBInfo->InEdges, Total);
15161509
Changes = true;
15171510
}
15181511
}
@@ -1527,16 +1520,16 @@ void PGOUseFunc::populateCounters() {
15271520
auto BI = findBBInfo(&BB);
15281521
if (BI == nullptr)
15291522
continue;
1530-
assert(BI->CountValid && "BB count is not valid");
1523+
assert(BI->Count && "BB count is not valid");
15311524
}
15321525
#endif
1533-
uint64_t FuncEntryCount = getBBInfo(&*F.begin()).CountValue;
1526+
uint64_t FuncEntryCount = *getBBInfo(&*F.begin()).Count;
15341527
uint64_t FuncMaxCount = FuncEntryCount;
15351528
for (auto &BB : F) {
15361529
auto BI = findBBInfo(&BB);
15371530
if (BI == nullptr)
15381531
continue;
1539-
FuncMaxCount = std::max(FuncMaxCount, BI->CountValue);
1532+
FuncMaxCount = std::max(FuncMaxCount, *BI->Count);
15401533
}
15411534

15421535
// Fix the obviously inconsistent entry count.
@@ -1566,11 +1559,11 @@ void PGOUseFunc::setBranchWeights() {
15661559
isa<CallBrInst>(TI)))
15671560
continue;
15681561

1569-
if (getBBInfo(&BB).CountValue == 0)
1562+
const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
1563+
if (!*BBCountInfo.Count)
15701564
continue;
15711565

15721566
// We have a non-zero Branch BB.
1573-
const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
15741567
unsigned Size = BBCountInfo.OutEdges.size();
15751568
SmallVector<uint64_t, 2> EdgeCounts(Size, 0);
15761569
uint64_t MaxCount = 0;
@@ -1622,7 +1615,7 @@ void PGOUseFunc::annotateIrrLoopHeaderWeights() {
16221615
if (BFI->isIrrLoopHeader(&BB) || isIndirectBrTarget(&BB)) {
16231616
Instruction *TI = BB.getTerminator();
16241617
const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
1625-
setIrrLoopHeaderMetadata(M, TI, BBCountInfo.CountValue);
1618+
setIrrLoopHeaderMetadata(M, TI, *BBCountInfo.Count);
16261619
}
16271620
}
16281621
}
@@ -1649,7 +1642,7 @@ void SelectInstVisitor::annotateOneSelectInst(SelectInst &SI) {
16491642
uint64_t TotalCount = 0;
16501643
auto BI = UseFunc->findBBInfo(SI.getParent());
16511644
if (BI != nullptr)
1652-
TotalCount = BI->CountValue;
1645+
TotalCount = *BI->Count;
16531646
// False Count
16541647
SCounts[1] = (TotalCount > SCounts[0] ? TotalCount - SCounts[0] : 0);
16551648
uint64_t MaxCount = std::max(SCounts[0], SCounts[1]);
@@ -1850,7 +1843,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI,
18501843
if (!Func.findBBInfo(&BBI))
18511844
continue;
18521845
auto BFICount = NBFI.getBlockProfileCount(&BBI);
1853-
CountValue = Func.getBBInfo(&BBI).CountValue;
1846+
CountValue = *Func.getBBInfo(&BBI).Count;
18541847
BFICountValue = *BFICount;
18551848
SumCount.add(APFloat(CountValue * 1.0), APFloat::rmNearestTiesToEven);
18561849
SumBFICount.add(APFloat(BFICountValue * 1.0), APFloat::rmNearestTiesToEven);
@@ -1866,7 +1859,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI,
18661859
if (Scale < 1.001 && Scale > 0.999)
18671860
return;
18681861

1869-
uint64_t FuncEntryCount = Func.getBBInfo(&*F.begin()).CountValue;
1862+
uint64_t FuncEntryCount = *Func.getBBInfo(&*F.begin()).Count;
18701863
uint64_t NewEntryCount = 0.5 + FuncEntryCount * Scale;
18711864
if (NewEntryCount == 0)
18721865
NewEntryCount = 1;
@@ -1896,8 +1889,7 @@ static void verifyFuncBFI(PGOUseFunc &Func, LoopInfo &LI,
18961889
uint64_t CountValue = 0;
18971890
uint64_t BFICountValue = 0;
18981891

1899-
if (Func.getBBInfo(&BBI).CountValid)
1900-
CountValue = Func.getBBInfo(&BBI).CountValue;
1892+
CountValue = Func.getBBInfo(&BBI).Count.value_or(CountValue);
19011893

19021894
BBNum++;
19031895
if (CountValue)
@@ -2279,8 +2271,8 @@ template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits {
22792271
OS << getSimpleNodeName(Node) << ":\\l";
22802272
PGOUseBBInfo *BI = Graph->findBBInfo(Node);
22812273
OS << "Count : ";
2282-
if (BI && BI->CountValid)
2283-
OS << BI->CountValue << "\\l";
2274+
if (BI && BI->Count)
2275+
OS << *BI->Count << "\\l";
22842276
else
22852277
OS << "Unknown\\l";
22862278

0 commit comments

Comments
 (0)