Skip to content

[Profiler] Simplify PGOMapping a bit #60621

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 18, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 65 additions & 76 deletions lib/SIL/IR/SILProfiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,21 +463,35 @@ class SourceMappingRegion {
};

/// An ASTWalker that maps ASTNodes to profiling counters.
///
/// TODO: We ought to be able to leverage the CounterExprs from the
/// CoverageMapping walker to recompute the correct counter information
/// for this walker.
struct PGOMapping : public ASTWalker {
/// The next counter value to assign.
unsigned NextCounter;
/// The counter indices for AST nodes.
const llvm::DenseMap<ASTNode, unsigned> &CounterMap;

/// The map of statements to counters.
/// The loaded counter data.
const llvm::InstrProfRecord &LoadedCounts;

/// The output map of statements to counters.
llvm::DenseMap<ASTNode, ProfileCounter> &LoadedCounterMap;
llvm::Expected<llvm::InstrProfRecord> &LoadedCounts;
llvm::DenseMap<ASTNode, ASTNode> &CondToParentMap;
llvm::DenseMap<ASTNode, unsigned> CounterMap;

PGOMapping(llvm::DenseMap<ASTNode, ProfileCounter> &LoadedCounterMap,
llvm::Expected<llvm::InstrProfRecord> &LoadedCounts,
PGOMapping(const llvm::DenseMap<ASTNode, unsigned> &CounterMap,
const llvm::InstrProfRecord &LoadedCounts,
llvm::DenseMap<ASTNode, ProfileCounter> &LoadedCounterMap,
llvm::DenseMap<ASTNode, ASTNode> &RegionCondToParentMap)
: NextCounter(0), LoadedCounterMap(LoadedCounterMap),
LoadedCounts(LoadedCounts), CondToParentMap(RegionCondToParentMap) {}
: CounterMap(CounterMap), LoadedCounts(LoadedCounts),
LoadedCounterMap(LoadedCounterMap),
CondToParentMap(RegionCondToParentMap) {}

/// Retrieve the counter index for a leaf node.
unsigned getCounterIndex(ASTNode Node) const {
auto result = CounterMap.find(Node);
assert(result != CounterMap.end() && "Unmapped node?");
return result->second;
}

unsigned getParentCounter() const {
if (Parent.isNull())
Expand Down Expand Up @@ -516,50 +530,53 @@ struct PGOMapping : public ASTWalker {
"region does not have an associated counter");

unsigned CounterIndexForFunc = CounterIt->second;
return LoadedCounts->Counts[CounterIndexForFunc];
return LoadedCounts.Counts[CounterIndexForFunc];
}

/// Record the execution count for a leaf node.
void setKnownExecutionCount(ASTNode Node) {
LoadedCounterMap[Node] = loadExecutionCount(Node);
}

/// Record a computed execution count for a node.
void setExecutionCount(ASTNode Node, ProfileCounter count) {
LoadedCounterMap[Node] = count;
}

bool walkToDeclPre(Decl *D) override {
if (isUnmapped(D))
return false;
if (auto *AFD = dyn_cast<AbstractFunctionDecl>(D)) {
return visitFunctionDecl(*this, AFD, [&] {
auto node = AFD->getBody();
CounterMap[node] = NextCounter++;
auto count = loadExecutionCount(node);
LoadedCounterMap[node] = count;
setKnownExecutionCount(AFD->getBody());
});
}
if (auto *TLCD = dyn_cast<TopLevelCodeDecl>(D)) {
auto node = TLCD->getBody();
CounterMap[node] = NextCounter++;
auto count = loadExecutionCount(node);
LoadedCounterMap[node] = count;
}
if (auto *TLCD = dyn_cast<TopLevelCodeDecl>(D))
setKnownExecutionCount(TLCD->getBody());

return true;
}

std::pair<bool, Stmt *> walkToStmtPre(Stmt *S) override {
unsigned parent = getParentCounter();
auto parentCount = LoadedCounts.Counts[parent];
if (auto *IS = dyn_cast<IfStmt>(S)) {
auto thenStmt = IS->getThenStmt();
CounterMap[thenStmt] = NextCounter++;
auto thenCount = loadExecutionCount(thenStmt);
LoadedCounterMap[thenStmt] = thenCount;
setExecutionCount(thenStmt, thenCount);
if (auto elseStmt = IS->getElseStmt()) {
CounterMap[elseStmt] = parent;
auto count = loadExecutionCount(elseStmt);
auto count = parentCount;
if (!parent) {
auto thenVal = thenCount.getValue();
for (auto pCount = NextCounter - 1; pCount > 0; --pCount) {
auto cCount = LoadedCounts->Counts[pCount];
for (auto pCount = getCounterIndex(thenStmt); pCount > 0; --pCount) {
auto cCount = LoadedCounts.Counts[pCount];
if (cCount > thenVal) {
count = cCount;
break;
}
}
}
LoadedCounterMap[elseStmt] = subtract(count, thenCount);
setExecutionCount(elseStmt, subtract(count, thenCount));
auto Cond = IS->getCond();
for (const auto &elt : Cond) {
if (elt.getKind() ==
Expand All @@ -568,47 +585,24 @@ struct PGOMapping : public ASTWalker {
}
}
}
} else if (auto *US = dyn_cast<GuardStmt>(S)) {
auto guardBody = US->getBody();
CounterMap[guardBody] = NextCounter++;
} else if (auto *GS = dyn_cast<GuardStmt>(S)) {
auto guardBody = GS->getBody();
auto guardCount = loadExecutionCount(guardBody);
LoadedCounterMap[guardBody] = guardCount;
CounterMap[US] = parent;
auto count = loadExecutionCount(US);
LoadedCounterMap[US] = subtract(count, guardCount);
setExecutionCount(guardBody, guardCount);
setExecutionCount(GS, subtract(parentCount, guardCount));
} else if (auto *WS = dyn_cast<WhileStmt>(S)) {
auto whileBody = WS->getBody();
CounterMap[whileBody] = NextCounter++;
auto whileCount = loadExecutionCount(whileBody);
LoadedCounterMap[whileBody] = whileCount;
CounterMap[WS] = parent;
auto count = loadExecutionCount(WS);
LoadedCounterMap[WS] = count;
setKnownExecutionCount(WS->getBody());
setExecutionCount(WS, parentCount);
} else if (auto *RWS = dyn_cast<RepeatWhileStmt>(S)) {
auto rwsBody = RWS->getBody();
CounterMap[rwsBody] = NextCounter++;
auto rwsBodyCount = loadExecutionCount(rwsBody);
LoadedCounterMap[rwsBody] = rwsBodyCount;
CounterMap[RWS] = parent;
auto count = loadExecutionCount(RWS);
LoadedCounterMap[RWS] = count;
setKnownExecutionCount(RWS->getBody());
setExecutionCount(RWS, parentCount);
} else if (auto *FES = dyn_cast<ForEachStmt>(S)) {
auto fesBody = FES->getBody();
CounterMap[fesBody] = NextCounter++;
auto fesCount = loadExecutionCount(fesBody);
LoadedCounterMap[fesBody] = fesCount;
CounterMap[FES] = parent;
auto count = loadExecutionCount(FES);
LoadedCounterMap[FES] = count;
setKnownExecutionCount(FES->getBody());
setExecutionCount(FES, parentCount);
} else if (auto *SS = dyn_cast<SwitchStmt>(S)) {
CounterMap[SS] = NextCounter++;
auto ssCount = loadExecutionCount(SS);
LoadedCounterMap[SS] = ssCount;
setKnownExecutionCount(SS);
} else if (auto *CS = dyn_cast<CaseStmt>(S)) {
auto stmt = getProfilerStmtForCase(CS);
CounterMap[stmt] = NextCounter++;
auto csCount = loadExecutionCount(stmt);
LoadedCounterMap[stmt] = csCount;
setKnownExecutionCount(getProfilerStmtForCase(CS));
}
return {true, S};
}
Expand All @@ -624,32 +618,27 @@ struct PGOMapping : public ASTWalker {

unsigned parent = getParentCounter();

if (Parent.isNull()) {
CounterMap[E] = NextCounter++;
auto eCount = loadExecutionCount(E);
LoadedCounterMap[E] = eCount;
}
if (Parent.isNull())
setKnownExecutionCount(E);

if (auto *IE = dyn_cast<IfExpr>(E)) {
auto thenExpr = IE->getThenExpr();
CounterMap[thenExpr] = NextCounter++;
auto thenCount = loadExecutionCount(thenExpr);
LoadedCounterMap[thenExpr] = thenCount;
setExecutionCount(thenExpr, thenCount);
auto elseExpr = IE->getElseExpr();
assert(elseExpr && "An if-expr must have an else subexpression");
CounterMap[elseExpr] = parent;
auto count = loadExecutionCount(elseExpr);
auto count = LoadedCounts.Counts[parent];
if (!parent) {
auto thenVal = thenCount.getValue();
for (auto pCount = NextCounter - 1; pCount > 0; --pCount) {
auto cCount = LoadedCounts->Counts[pCount];
for (auto pCount = getCounterIndex(thenExpr); pCount > 0; --pCount) {
auto cCount = LoadedCounts.Counts[pCount];
if (cCount > thenVal) {
count = cCount;
break;
}
}
}
LoadedCounterMap[elseExpr] = subtract(count, thenCount);
setExecutionCount(elseExpr, subtract(count, thenCount));
}
return {true, E};
}
Expand Down Expand Up @@ -1186,8 +1175,8 @@ void SILProfiler::assignRegionCounters() {
llvm::dbgs() << PGOFuncName << "\n";
return;
}
PGOMapping pgoMapper(RegionLoadedCounterMap, LoadedCounts,
RegionCondToParentMap);
PGOMapping pgoMapper(RegionCounterMap, LoadedCounts.get(),
RegionLoadedCounterMap, RegionCondToParentMap);
Root.walk(pgoMapper);
}
}
Expand Down