-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[pgo][nfc] Model Count
as a std::optional
in PGOUseBBInfo
#83364
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Simpler code, compared to tracking state of 2 variables and the ambiguity of "0" CountValue (is it 0 or is it invalid?)
I'll do the same for |
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-pgo Author: Mircea Trofin (mtrofin) ChangesSimpler code, compared to tracking state of 2 variables and the ambiguity of "0" CountValue (is it 0 or is it invalid?) Full diff: https://github.com/llvm/llvm-project/pull/83364.diff 1 Files Affected:
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index c20fc942eaf0d5..0c042e73ba0836 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -983,27 +983,22 @@ using DirectEdges = SmallVector<PGOUseEdge *, 2>;
// This class stores the auxiliary information for each BB.
struct PGOUseBBInfo : public PGOBBInfo {
- uint64_t CountValue = 0;
- bool CountValid;
+ std::optional<uint64_t> Count;
int32_t UnknownCountInEdge = 0;
int32_t UnknownCountOutEdge = 0;
DirectEdges InEdges;
DirectEdges OutEdges;
- PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX), CountValid(false) {}
+ PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX) {}
// Set the profile count value for this BB.
- void setBBInfoCount(uint64_t Value) {
- CountValue = Value;
- CountValid = true;
- }
+ void setBBInfoCount(uint64_t Value) { Count = Value; }
// Return the information string of this object.
std::string infoString() const {
- if (!CountValid)
+ if (!Count)
return PGOBBInfo::infoString();
- return (Twine(PGOBBInfo::infoString()) + " Count=" + Twine(CountValue))
- .str();
+ return (Twine(PGOBBInfo::infoString()) + " Count=" + Twine(*Count)).str();
}
// Add an OutEdge and update the edge count.
@@ -1216,15 +1211,15 @@ bool PGOUseFunc::setInstrumentedCounts(
// If only one out-edge, the edge profile count should be the same as BB
// profile count.
- if (SrcInfo.CountValid && SrcInfo.OutEdges.size() == 1)
- setEdgeCount(E.get(), SrcInfo.CountValue);
+ if (SrcInfo.Count && SrcInfo.OutEdges.size() == 1)
+ setEdgeCount(E.get(), *SrcInfo.Count);
else {
const BasicBlock *DestBB = E->DestBB;
PGOUseBBInfo &DestInfo = getBBInfo(DestBB);
// If only one in-edge, the edge profile count should be the same as BB
// profile count.
- if (DestInfo.CountValid && DestInfo.InEdges.size() == 1)
- setEdgeCount(E.get(), DestInfo.CountValue);
+ if (DestInfo.Count && DestInfo.InEdges.size() == 1)
+ setEdgeCount(E.get(), *DestInfo.Count);
}
if (E->CountValid)
continue;
@@ -1481,38 +1476,36 @@ void PGOUseFunc::populateCounters() {
// For efficient traversal, it's better to start from the end as most
// of the instrumented edges are at the end.
for (auto &BB : reverse(F)) {
- PGOUseBBInfo *Count = findBBInfo(&BB);
- if (Count == nullptr)
+ PGOUseBBInfo *UseBBInfo = findBBInfo(&BB);
+ if (UseBBInfo == nullptr)
continue;
- if (!Count->CountValid) {
- if (Count->UnknownCountOutEdge == 0) {
- Count->CountValue = sumEdgeCount(Count->OutEdges);
- Count->CountValid = true;
+ if (!UseBBInfo->Count) {
+ if (UseBBInfo->UnknownCountOutEdge == 0) {
+ UseBBInfo->Count = sumEdgeCount(UseBBInfo->OutEdges);
Changes = true;
- } else if (Count->UnknownCountInEdge == 0) {
- Count->CountValue = sumEdgeCount(Count->InEdges);
- Count->CountValid = true;
+ } else if (UseBBInfo->UnknownCountInEdge == 0) {
+ UseBBInfo->Count = sumEdgeCount(UseBBInfo->InEdges);
Changes = true;
}
}
- if (Count->CountValid) {
- if (Count->UnknownCountOutEdge == 1) {
+ if (UseBBInfo->Count) {
+ if (UseBBInfo->UnknownCountOutEdge == 1) {
uint64_t Total = 0;
- uint64_t OutSum = sumEdgeCount(Count->OutEdges);
+ uint64_t OutSum = sumEdgeCount(UseBBInfo->OutEdges);
// If the one of the successor block can early terminate (no-return),
// we can end up with situation where out edge sum count is larger as
// the source BB's count is collected by a post-dominated block.
- if (Count->CountValue > OutSum)
- Total = Count->CountValue - OutSum;
- setEdgeCount(Count->OutEdges, Total);
+ if (*UseBBInfo->Count > OutSum)
+ Total = *UseBBInfo->Count - OutSum;
+ setEdgeCount(UseBBInfo->OutEdges, Total);
Changes = true;
}
- if (Count->UnknownCountInEdge == 1) {
+ if (UseBBInfo->UnknownCountInEdge == 1) {
uint64_t Total = 0;
- uint64_t InSum = sumEdgeCount(Count->InEdges);
- if (Count->CountValue > InSum)
- Total = Count->CountValue - InSum;
- setEdgeCount(Count->InEdges, Total);
+ uint64_t InSum = sumEdgeCount(UseBBInfo->InEdges);
+ if (*UseBBInfo->Count > InSum)
+ Total = *UseBBInfo->Count - InSum;
+ setEdgeCount(UseBBInfo->InEdges, Total);
Changes = true;
}
}
@@ -1527,16 +1520,16 @@ void PGOUseFunc::populateCounters() {
auto BI = findBBInfo(&BB);
if (BI == nullptr)
continue;
- assert(BI->CountValid && "BB count is not valid");
+ assert(BI->Count && "BB count is not valid");
}
#endif
- uint64_t FuncEntryCount = getBBInfo(&*F.begin()).CountValue;
+ uint64_t FuncEntryCount = *getBBInfo(&*F.begin()).Count;
uint64_t FuncMaxCount = FuncEntryCount;
for (auto &BB : F) {
auto BI = findBBInfo(&BB);
if (BI == nullptr)
continue;
- FuncMaxCount = std::max(FuncMaxCount, BI->CountValue);
+ FuncMaxCount = std::max(FuncMaxCount, *BI->Count);
}
// Fix the obviously inconsistent entry count.
@@ -1566,11 +1559,11 @@ void PGOUseFunc::setBranchWeights() {
isa<CallBrInst>(TI)))
continue;
- if (getBBInfo(&BB).CountValue == 0)
+ const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
+ if (!*BBCountInfo.Count)
continue;
// We have a non-zero Branch BB.
- const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
unsigned Size = BBCountInfo.OutEdges.size();
SmallVector<uint64_t, 2> EdgeCounts(Size, 0);
uint64_t MaxCount = 0;
@@ -1622,7 +1615,7 @@ void PGOUseFunc::annotateIrrLoopHeaderWeights() {
if (BFI->isIrrLoopHeader(&BB) || isIndirectBrTarget(&BB)) {
Instruction *TI = BB.getTerminator();
const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
- setIrrLoopHeaderMetadata(M, TI, BBCountInfo.CountValue);
+ setIrrLoopHeaderMetadata(M, TI, *BBCountInfo.Count);
}
}
}
@@ -1649,7 +1642,7 @@ void SelectInstVisitor::annotateOneSelectInst(SelectInst &SI) {
uint64_t TotalCount = 0;
auto BI = UseFunc->findBBInfo(SI.getParent());
if (BI != nullptr)
- TotalCount = BI->CountValue;
+ TotalCount = *BI->Count;
// False Count
SCounts[1] = (TotalCount > SCounts[0] ? TotalCount - SCounts[0] : 0);
uint64_t MaxCount = std::max(SCounts[0], SCounts[1]);
@@ -1850,7 +1843,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI,
if (!Func.findBBInfo(&BBI))
continue;
auto BFICount = NBFI.getBlockProfileCount(&BBI);
- CountValue = Func.getBBInfo(&BBI).CountValue;
+ CountValue = *Func.getBBInfo(&BBI).Count;
BFICountValue = *BFICount;
SumCount.add(APFloat(CountValue * 1.0), APFloat::rmNearestTiesToEven);
SumBFICount.add(APFloat(BFICountValue * 1.0), APFloat::rmNearestTiesToEven);
@@ -1866,7 +1859,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI,
if (Scale < 1.001 && Scale > 0.999)
return;
- uint64_t FuncEntryCount = Func.getBBInfo(&*F.begin()).CountValue;
+ uint64_t FuncEntryCount = *Func.getBBInfo(&*F.begin()).Count;
uint64_t NewEntryCount = 0.5 + FuncEntryCount * Scale;
if (NewEntryCount == 0)
NewEntryCount = 1;
@@ -1896,8 +1889,7 @@ static void verifyFuncBFI(PGOUseFunc &Func, LoopInfo &LI,
uint64_t CountValue = 0;
uint64_t BFICountValue = 0;
- if (Func.getBBInfo(&BBI).CountValid)
- CountValue = Func.getBBInfo(&BBI).CountValue;
+ CountValue = Func.getBBInfo(&BBI).Count.value_or(CountValue);
BBNum++;
if (CountValue)
@@ -2279,8 +2271,8 @@ template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits {
OS << getSimpleNodeName(Node) << ":\\l";
PGOUseBBInfo *BI = Graph->findBBInfo(Node);
OS << "Count : ";
- if (BI && BI->CountValid)
- OS << BI->CountValue << "\\l";
+ if (BI && BI->Count)
+ OS << *BI->Count << "\\l";
else
OS << "Unknown\\l";
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for improving this. LGTM!
…#83364) Simpler code, compared to tracking state of 2 variables and the ambiguity of "0" CountValue (is it 0 or is it invalid?)
…#83364) Simpler code, compared to tracking state of 2 variables and the ambiguity of "0" CountValue (is it 0 or is it invalid?)
Simpler code, compared to tracking state of 2 variables and the ambiguity of "0" CountValue (is it 0 or is it invalid?)