Skip to content

[pgo][nfc] Model Count as a std::optional in PGOUseBBInfo #83364

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 29, 2024

Conversation

mtrofin
Copy link
Member

@mtrofin mtrofin commented Feb 29, 2024

Simpler code, compared to tracking state of 2 variables and the ambiguity of "0" CountValue (is it 0 or is it invalid?)

Simpler code, compared to tracking state of 2 variables and the
ambiguity of "0" CountValue (is it 0 or is it invalid?)
@llvmbot llvmbot added PGO Profile Guided Optimizations llvm:transforms labels Feb 29, 2024
@mtrofin
Copy link
Member Author

mtrofin commented Feb 29, 2024

I'll do the same for PGOUseEdge next.

@llvmbot
Copy link
Member

llvmbot commented Feb 29, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-pgo

Author: Mircea Trofin (mtrofin)

Changes

Simpler code, compared to tracking state of 2 variables and the ambiguity of "0" CountValue (is it 0 or is it invalid?)


Full diff: https://github.com/llvm/llvm-project/pull/83364.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp (+39-47)
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index c20fc942eaf0d5..0c042e73ba0836 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -983,27 +983,22 @@ using DirectEdges = SmallVector<PGOUseEdge *, 2>;
 
 // This class stores the auxiliary information for each BB.
 struct PGOUseBBInfo : public PGOBBInfo {
-  uint64_t CountValue = 0;
-  bool CountValid;
+  std::optional<uint64_t> Count;
   int32_t UnknownCountInEdge = 0;
   int32_t UnknownCountOutEdge = 0;
   DirectEdges InEdges;
   DirectEdges OutEdges;
 
-  PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX), CountValid(false) {}
+  PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX) {}
 
   // Set the profile count value for this BB.
-  void setBBInfoCount(uint64_t Value) {
-    CountValue = Value;
-    CountValid = true;
-  }
+  void setBBInfoCount(uint64_t Value) { Count = Value; }
 
   // Return the information string of this object.
   std::string infoString() const {
-    if (!CountValid)
+    if (!Count)
       return PGOBBInfo::infoString();
-    return (Twine(PGOBBInfo::infoString()) + "  Count=" + Twine(CountValue))
-        .str();
+    return (Twine(PGOBBInfo::infoString()) + "  Count=" + Twine(*Count)).str();
   }
 
   // Add an OutEdge and update the edge count.
@@ -1216,15 +1211,15 @@ bool PGOUseFunc::setInstrumentedCounts(
 
     // If only one out-edge, the edge profile count should be the same as BB
     // profile count.
-    if (SrcInfo.CountValid && SrcInfo.OutEdges.size() == 1)
-      setEdgeCount(E.get(), SrcInfo.CountValue);
+    if (SrcInfo.Count && SrcInfo.OutEdges.size() == 1)
+      setEdgeCount(E.get(), *SrcInfo.Count);
     else {
       const BasicBlock *DestBB = E->DestBB;
       PGOUseBBInfo &DestInfo = getBBInfo(DestBB);
       // If only one in-edge, the edge profile count should be the same as BB
       // profile count.
-      if (DestInfo.CountValid && DestInfo.InEdges.size() == 1)
-        setEdgeCount(E.get(), DestInfo.CountValue);
+      if (DestInfo.Count && DestInfo.InEdges.size() == 1)
+        setEdgeCount(E.get(), *DestInfo.Count);
     }
     if (E->CountValid)
       continue;
@@ -1481,38 +1476,36 @@ void PGOUseFunc::populateCounters() {
     // For efficient traversal, it's better to start from the end as most
     // of the instrumented edges are at the end.
     for (auto &BB : reverse(F)) {
-      PGOUseBBInfo *Count = findBBInfo(&BB);
-      if (Count == nullptr)
+      PGOUseBBInfo *UseBBInfo = findBBInfo(&BB);
+      if (UseBBInfo == nullptr)
         continue;
-      if (!Count->CountValid) {
-        if (Count->UnknownCountOutEdge == 0) {
-          Count->CountValue = sumEdgeCount(Count->OutEdges);
-          Count->CountValid = true;
+      if (!UseBBInfo->Count) {
+        if (UseBBInfo->UnknownCountOutEdge == 0) {
+          UseBBInfo->Count = sumEdgeCount(UseBBInfo->OutEdges);
           Changes = true;
-        } else if (Count->UnknownCountInEdge == 0) {
-          Count->CountValue = sumEdgeCount(Count->InEdges);
-          Count->CountValid = true;
+        } else if (UseBBInfo->UnknownCountInEdge == 0) {
+          UseBBInfo->Count = sumEdgeCount(UseBBInfo->InEdges);
           Changes = true;
         }
       }
-      if (Count->CountValid) {
-        if (Count->UnknownCountOutEdge == 1) {
+      if (UseBBInfo->Count) {
+        if (UseBBInfo->UnknownCountOutEdge == 1) {
           uint64_t Total = 0;
-          uint64_t OutSum = sumEdgeCount(Count->OutEdges);
+          uint64_t OutSum = sumEdgeCount(UseBBInfo->OutEdges);
           // If the one of the successor block can early terminate (no-return),
           // we can end up with situation where out edge sum count is larger as
           // the source BB's count is collected by a post-dominated block.
-          if (Count->CountValue > OutSum)
-            Total = Count->CountValue - OutSum;
-          setEdgeCount(Count->OutEdges, Total);
+          if (*UseBBInfo->Count > OutSum)
+            Total = *UseBBInfo->Count - OutSum;
+          setEdgeCount(UseBBInfo->OutEdges, Total);
           Changes = true;
         }
-        if (Count->UnknownCountInEdge == 1) {
+        if (UseBBInfo->UnknownCountInEdge == 1) {
           uint64_t Total = 0;
-          uint64_t InSum = sumEdgeCount(Count->InEdges);
-          if (Count->CountValue > InSum)
-            Total = Count->CountValue - InSum;
-          setEdgeCount(Count->InEdges, Total);
+          uint64_t InSum = sumEdgeCount(UseBBInfo->InEdges);
+          if (*UseBBInfo->Count > InSum)
+            Total = *UseBBInfo->Count - InSum;
+          setEdgeCount(UseBBInfo->InEdges, Total);
           Changes = true;
         }
       }
@@ -1527,16 +1520,16 @@ void PGOUseFunc::populateCounters() {
     auto BI = findBBInfo(&BB);
     if (BI == nullptr)
       continue;
-    assert(BI->CountValid && "BB count is not valid");
+    assert(BI->Count && "BB count is not valid");
   }
 #endif
-  uint64_t FuncEntryCount = getBBInfo(&*F.begin()).CountValue;
+  uint64_t FuncEntryCount = *getBBInfo(&*F.begin()).Count;
   uint64_t FuncMaxCount = FuncEntryCount;
   for (auto &BB : F) {
     auto BI = findBBInfo(&BB);
     if (BI == nullptr)
       continue;
-    FuncMaxCount = std::max(FuncMaxCount, BI->CountValue);
+    FuncMaxCount = std::max(FuncMaxCount, *BI->Count);
   }
 
   // Fix the obviously inconsistent entry count.
@@ -1566,11 +1559,11 @@ void PGOUseFunc::setBranchWeights() {
           isa<CallBrInst>(TI)))
       continue;
 
-    if (getBBInfo(&BB).CountValue == 0)
+    const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
+    if (!*BBCountInfo.Count)
       continue;
 
     // We have a non-zero Branch BB.
-    const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
     unsigned Size = BBCountInfo.OutEdges.size();
     SmallVector<uint64_t, 2> EdgeCounts(Size, 0);
     uint64_t MaxCount = 0;
@@ -1622,7 +1615,7 @@ void PGOUseFunc::annotateIrrLoopHeaderWeights() {
     if (BFI->isIrrLoopHeader(&BB) || isIndirectBrTarget(&BB)) {
       Instruction *TI = BB.getTerminator();
       const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB);
-      setIrrLoopHeaderMetadata(M, TI, BBCountInfo.CountValue);
+      setIrrLoopHeaderMetadata(M, TI, *BBCountInfo.Count);
     }
   }
 }
@@ -1649,7 +1642,7 @@ void SelectInstVisitor::annotateOneSelectInst(SelectInst &SI) {
   uint64_t TotalCount = 0;
   auto BI = UseFunc->findBBInfo(SI.getParent());
   if (BI != nullptr)
-    TotalCount = BI->CountValue;
+    TotalCount = *BI->Count;
   // False Count
   SCounts[1] = (TotalCount > SCounts[0] ? TotalCount - SCounts[0] : 0);
   uint64_t MaxCount = std::max(SCounts[0], SCounts[1]);
@@ -1850,7 +1843,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI,
     if (!Func.findBBInfo(&BBI))
       continue;
     auto BFICount = NBFI.getBlockProfileCount(&BBI);
-    CountValue = Func.getBBInfo(&BBI).CountValue;
+    CountValue = *Func.getBBInfo(&BBI).Count;
     BFICountValue = *BFICount;
     SumCount.add(APFloat(CountValue * 1.0), APFloat::rmNearestTiesToEven);
     SumBFICount.add(APFloat(BFICountValue * 1.0), APFloat::rmNearestTiesToEven);
@@ -1866,7 +1859,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI,
   if (Scale < 1.001 && Scale > 0.999)
     return;
 
-  uint64_t FuncEntryCount = Func.getBBInfo(&*F.begin()).CountValue;
+  uint64_t FuncEntryCount = *Func.getBBInfo(&*F.begin()).Count;
   uint64_t NewEntryCount = 0.5 + FuncEntryCount * Scale;
   if (NewEntryCount == 0)
     NewEntryCount = 1;
@@ -1896,8 +1889,7 @@ static void verifyFuncBFI(PGOUseFunc &Func, LoopInfo &LI,
     uint64_t CountValue = 0;
     uint64_t BFICountValue = 0;
 
-    if (Func.getBBInfo(&BBI).CountValid)
-      CountValue = Func.getBBInfo(&BBI).CountValue;
+    CountValue = Func.getBBInfo(&BBI).Count.value_or(CountValue);
 
     BBNum++;
     if (CountValue)
@@ -2279,8 +2271,8 @@ template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits {
     OS << getSimpleNodeName(Node) << ":\\l";
     PGOUseBBInfo *BI = Graph->findBBInfo(Node);
     OS << "Count : ";
-    if (BI && BI->CountValid)
-      OS << BI->CountValue << "\\l";
+    if (BI && BI->Count)
+      OS << *BI->Count << "\\l";
     else
       OS << "Unknown\\l";
 

Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for improving this. LGTM!

@mtrofin mtrofin merged commit 1a2960b into llvm:main Feb 29, 2024
@mtrofin mtrofin deleted the pgouseinfo branch February 29, 2024 23:12
mtrofin added a commit to mtrofin/llvm-project that referenced this pull request Feb 29, 2024
mtrofin added a commit to mtrofin/llvm-project that referenced this pull request Mar 1, 2024
…#83364)

Simpler code, compared to tracking state of 2 variables and the ambiguity of "0" CountValue (is it 0 or is it invalid?)
mtrofin added a commit that referenced this pull request Mar 1, 2024
mtrofin added a commit to mtrofin/llvm-project that referenced this pull request Mar 1, 2024
…#83364)

Simpler code, compared to tracking state of 2 variables and the ambiguity of "0" CountValue (is it 0 or is it invalid?)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:transforms PGO Profile Guided Optimizations
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants