Skip to content

Commit daa2a58

Browse files
authored
[TRE] Adjust function entry count when using instrumented profiles (#143987)
The entry count of a function needs to be updated after a callsite is elided by TRE: before elision, the entry count accounted for the recursive call at that callsite. After TRE, we need to remove that callsite's contribution. This patch enables the adjustment for instrumented profiling only, because in that case we know the function entry count captured entries before TRE. We cannot currently address this for sample-based profiles, because we don't know whether this function was TRE-ed in the binary that donated the samples.
1 parent 44936c8 commit daa2a58

File tree

5 files changed

+200
-11
lines changed

5 files changed

+200
-11
lines changed

llvm/include/llvm/Passes/PassBuilder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,8 @@ class PassBuilder {
773773
IntrusiveRefCntPtr<vfs::FileSystem> FS);
774774
void addPostPGOLoopRotation(ModulePassManager &MPM, OptimizationLevel Level);
775775

776+
bool isInstrumentedPGOUse() const;
777+
776778
// Extension Point callbacks
777779
SmallVector<std::function<void(FunctionPassManager &, OptimizationLevel)>, 2>
778780
PeepholeEPCallbacks;

llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,12 @@ namespace llvm {
5858

5959
class Function;
6060

61-
struct TailCallElimPass : PassInfoMixin<TailCallElimPass> {
61+
class TailCallElimPass : public PassInfoMixin<TailCallElimPass> {
62+
const bool UpdateFunctionEntryCount;
63+
64+
public:
65+
TailCallElimPass(bool UpdateFunctionEntryCount = true)
66+
: UpdateFunctionEntryCount(UpdateFunctionEntryCount) {}
6267
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
6368
};
6469
}

llvm/lib/Passes/PassBuilderPipelines.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,8 @@ PassBuilder::buildFunctionSimplificationPipeline(OptimizationLevel Level,
625625
!Level.isOptimizingForSize())
626626
FPM.addPass(PGOMemOPSizeOpt());
627627

628-
FPM.addPass(TailCallElimPass());
628+
FPM.addPass(TailCallElimPass(/*UpdateFunctionEntryCount=*/
629+
isInstrumentedPGOUse()));
629630
FPM.addPass(
630631
SimplifyCFGPass(SimplifyCFGOptions().convertSwitchRangeToICmp(true)));
631632

@@ -1578,7 +1579,8 @@ PassBuilder::buildModuleOptimizationPipeline(OptimizationLevel Level,
15781579
OptimizePM.addPass(DivRemPairsPass());
15791580

15801581
// Try to annotate calls that were created during optimization.
1581-
OptimizePM.addPass(TailCallElimPass());
1582+
OptimizePM.addPass(
1583+
TailCallElimPass(/*UpdateFunctionEntryCount=*/isInstrumentedPGOUse()));
15821584

15831585
// LoopSink (and other loop passes since the last simplifyCFG) might have
15841586
// resulted in single-entry-single-exit or empty blocks. Clean up the CFG.
@@ -2066,7 +2068,8 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,
20662068

20672069
// LTO provides additional opportunities for tailcall elimination due to
20682070
// link-time inlining, and visibility of nocapture attribute.
2069-
FPM.addPass(TailCallElimPass());
2071+
FPM.addPass(
2072+
TailCallElimPass(/*UpdateFunctionEntryCount=*/isInstrumentedPGOUse()));
20702073

20712074
// Run a few AA driver optimizations here and now to cleanup the code.
20722075
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM),
@@ -2347,3 +2350,8 @@ AAManager PassBuilder::buildDefaultAAPipeline() {
23472350

23482351
return AA;
23492352
}
2353+
2354+
bool PassBuilder::isInstrumentedPGOUse() const {
2355+
return (PGOOpt && PGOOpt->Action == PGOOptions::IRUse) ||
2356+
!UseCtxProfile.empty();
2357+
}

llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
#include "llvm/ADT/STLExtras.h"
5454
#include "llvm/ADT/SmallPtrSet.h"
5555
#include "llvm/ADT/Statistic.h"
56+
#include "llvm/Analysis/BlockFrequencyInfo.h"
5657
#include "llvm/Analysis/DomTreeUpdater.h"
5758
#include "llvm/Analysis/GlobalsModRef.h"
5859
#include "llvm/Analysis/InstructionSimplify.h"
@@ -75,10 +76,12 @@
7576
#include "llvm/IR/Module.h"
7677
#include "llvm/InitializePasses.h"
7778
#include "llvm/Pass.h"
79+
#include "llvm/Support/CommandLine.h"
7880
#include "llvm/Support/Debug.h"
7981
#include "llvm/Support/raw_ostream.h"
8082
#include "llvm/Transforms/Scalar.h"
8183
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
84+
#include <cmath>
8285
using namespace llvm;
8386

8487
#define DEBUG_TYPE "tailcallelim"
@@ -87,6 +90,11 @@ STATISTIC(NumEliminated, "Number of tail calls removed");
8790
STATISTIC(NumRetDuped, "Number of return duplicated");
8891
STATISTIC(NumAccumAdded, "Number of accumulators introduced");
8992

93+
static cl::opt<bool> ForceDisableBFI(
94+
"tre-disable-entrycount-recompute", cl::init(false), cl::Hidden,
95+
cl::desc("Force disabling recomputing of function entry count, on "
96+
"successful tail recursion elimination."));
97+
9098
/// Scan the specified function for alloca instructions.
9199
/// If it contains any dynamic allocas, returns false.
92100
static bool canTRE(Function &F) {
@@ -399,6 +407,9 @@ class TailRecursionEliminator {
399407
AliasAnalysis *AA;
400408
OptimizationRemarkEmitter *ORE;
401409
DomTreeUpdater &DTU;
410+
BlockFrequencyInfo *const BFI;
411+
const uint64_t OrigEntryBBFreq;
412+
const uint64_t OrigEntryCount;
402413

403414
// The below are shared state we want to have available when eliminating any
404415
// calls in the function. These values should be populated by
@@ -428,8 +439,19 @@ class TailRecursionEliminator {
428439

429440
TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI,
430441
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
431-
DomTreeUpdater &DTU)
432-
: F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU) {}
442+
DomTreeUpdater &DTU, BlockFrequencyInfo *BFI)
443+
: F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU), BFI(BFI),
444+
OrigEntryBBFreq(
445+
BFI ? BFI->getBlockFreq(&F.getEntryBlock()).getFrequency() : 0U),
446+
OrigEntryCount(F.getEntryCount() ? F.getEntryCount()->getCount() : 0) {
447+
if (BFI) {
448+
// The assert is meant as API documentation for the caller.
449+
assert((OrigEntryCount != 0 && OrigEntryBBFreq != 0) &&
450+
"If a BFI was provided, the function should have both an entry "
451+
"count that is non-zero and an entry basic block with a non-zero "
452+
"frequency.");
453+
}
454+
}
433455

434456
CallInst *findTRECandidate(BasicBlock *BB);
435457

@@ -450,7 +472,7 @@ class TailRecursionEliminator {
450472
public:
451473
static bool eliminate(Function &F, const TargetTransformInfo *TTI,
452474
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
453-
DomTreeUpdater &DTU);
475+
DomTreeUpdater &DTU, BlockFrequencyInfo *BFI);
454476
};
455477
} // namespace
456478

@@ -735,6 +757,28 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
735757
CI->eraseFromParent(); // Remove call.
736758
DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
737759
++NumEliminated;
760+
if (OrigEntryBBFreq) {
761+
assert(F.getEntryCount().has_value());
762+
// This pass is not expected to remove BBs, only add an entry BB. For that
763+
// reason, and because the BB here isn't the new entry BB, the BFI lookup is
764+
// expected to succeed.
765+
assert(&F.getEntryBlock() != BB);
766+
auto RelativeBBFreq =
767+
static_cast<double>(BFI->getBlockFreq(BB).getFrequency()) /
768+
static_cast<double>(OrigEntryBBFreq);
769+
auto ToSubtract =
770+
static_cast<uint64_t>(std::round(RelativeBBFreq * OrigEntryCount));
771+
auto OldEntryCount = F.getEntryCount()->getCount();
772+
if (OldEntryCount <= ToSubtract) {
773+
LLVM_DEBUG(
774+
errs() << "[TRE] The entrycount attributable to the recursive call, "
775+
<< ToSubtract
776+
<< ", should be strictly lower than the function entry count, "
777+
<< OldEntryCount << "\n");
778+
} else {
779+
F.setEntryCount(OldEntryCount - ToSubtract, F.getEntryCount()->getType());
780+
}
781+
}
738782
return true;
739783
}
740784

@@ -861,7 +905,8 @@ bool TailRecursionEliminator::eliminate(Function &F,
861905
const TargetTransformInfo *TTI,
862906
AliasAnalysis *AA,
863907
OptimizationRemarkEmitter *ORE,
864-
DomTreeUpdater &DTU) {
908+
DomTreeUpdater &DTU,
909+
BlockFrequencyInfo *BFI) {
865910
if (F.getFnAttribute("disable-tail-calls").getValueAsBool())
866911
return false;
867912

@@ -877,7 +922,7 @@ bool TailRecursionEliminator::eliminate(Function &F,
877922
return MadeChange;
878923

879924
// Change any tail recursive calls to loops.
880-
TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU);
925+
TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU, BFI);
881926

882927
for (BasicBlock &BB : F)
883928
MadeChange |= TRE.processBlock(BB);
@@ -919,7 +964,8 @@ struct TailCallElim : public FunctionPass {
919964
return TailRecursionEliminator::eliminate(
920965
F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F),
921966
&getAnalysis<AAResultsWrapperPass>().getAAResults(),
922-
&getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU);
967+
&getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU,
968+
/*BFI=*/nullptr);
923969
}
924970
};
925971
}
@@ -942,14 +988,22 @@ PreservedAnalyses TailCallElimPass::run(Function &F,
942988

943989
TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
944990
AliasAnalysis &AA = AM.getResult<AAManager>(F);
991+
// This must come first. It needs the 2 analyses, meaning, if it came after
992+
// the lines asking for the cached result, should they be nullptr (which, in
993+
// the case of the PDT, is likely), updates to the trees would be missed.
994+
auto *BFI = (!ForceDisableBFI && UpdateFunctionEntryCount &&
995+
F.getEntryCount().has_value() && F.getEntryCount()->getCount())
996+
? &AM.getResult<BlockFrequencyAnalysis>(F)
997+
: nullptr;
945998
auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
946999
auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
9471000
auto *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
9481001
// There is no noticeable performance difference here between Lazy and Eager
9491002
// UpdateStrategy based on some test results. It is feasible to switch the
9501003
// UpdateStrategy to Lazy if we find it profitable later.
9511004
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
952-
bool Changed = TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU);
1005+
bool Changed =
1006+
TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU, BFI);
9531007

9541008
if (!Changed)
9551009
return PreservedAnalyses::all();
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
2+
; RUN: opt -passes=tailcallelim -S %s -o - | FileCheck %s --check-prefixes=CHECK,ENABLED
3+
; RUN: opt -passes=tailcallelim -tre-disable-entrycount-recompute -S %s -o - | FileCheck %s --check-prefixes=CHECK,DISABLED
4+
5+
; Test that tail call elimination correctly adjusts function entry counts
6+
; when eliminating tail recursive calls.
7+
8+
; Basic test: eliminate a tail call and adjust entry count
9+
define i32 @test_basic_entry_count_adjustment(i32 %n) !prof !0 {
10+
; CHECK-LABEL: @test_basic_entry_count_adjustment(
11+
; CHECK-NEXT: entry:
12+
; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
13+
; CHECK: tailrecurse:
14+
; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
15+
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
16+
; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]], !prof [[PROF1:![0-9]+]]
17+
; CHECK: if.then:
18+
; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
19+
; CHECK-NEXT: br label [[TAILRECURSE]]
20+
; CHECK: if.else:
21+
; CHECK-NEXT: ret i32 0
22+
;
23+
entry:
24+
%cmp = icmp sgt i32 %n, 0
25+
br i1 %cmp, label %if.then, label %if.else, !prof !1
26+
27+
if.then: ; preds = %entry
28+
%sub = sub i32 %n, 1
29+
%call = tail call i32 @test_basic_entry_count_adjustment(i32 %sub)
30+
ret i32 %call
31+
32+
if.else: ; preds = %entry
33+
ret i32 0
34+
}
35+
36+
; Test multiple tail calls in different blocks with different frequencies
37+
define i32 @test_multiple_blocks_entry_count(i32 %n, i32 %flag) !prof !2 {
38+
; CHECK-LABEL: @test_multiple_blocks_entry_count(
39+
; CHECK-NEXT: entry:
40+
; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
41+
; CHECK: tailrecurse:
42+
; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB1:%.*]], [[BLOCK1:%.*]] ], [ [[SUB2:%.*]], [[BLOCK2:%.*]] ]
43+
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
44+
; CHECK-NEXT: br i1 [[CMP]], label [[CHECK_FLAG:%.*]], label [[BASE_CASE:%.*]], !prof [[PROF3:![0-9]+]]
45+
; CHECK: check.flag:
46+
; CHECK-NEXT: [[CMP_FLAG:%.*]] = icmp eq i32 [[FLAG:%.*]], 1
47+
; CHECK-NEXT: br i1 [[CMP_FLAG]], label [[BLOCK1]], label [[BLOCK2]], !prof [[PROF4:![0-9]+]]
48+
; CHECK: block1:
49+
; CHECK-NEXT: [[SUB1]] = sub i32 [[N_TR]], 1
50+
; CHECK-NEXT: br label [[TAILRECURSE]]
51+
; CHECK: block2:
52+
; CHECK-NEXT: [[SUB2]] = sub i32 [[N_TR]], 2
53+
; CHECK-NEXT: br label [[TAILRECURSE]]
54+
; CHECK: base.case:
55+
; CHECK-NEXT: ret i32 1
56+
;
57+
entry:
58+
%cmp = icmp sgt i32 %n, 0
59+
br i1 %cmp, label %check.flag, label %base.case, !prof !3
60+
check.flag:
61+
%cmp.flag = icmp eq i32 %flag, 1
62+
br i1 %cmp.flag, label %block1, label %block2, !prof !4
63+
block1: ; preds = %check.flag
64+
%sub1 = sub i32 %n, 1
65+
%call1 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub1, i32 %flag)
66+
ret i32 %call1
67+
block2: ; preds = %check.flag
68+
%sub2 = sub i32 %n, 2
69+
%call2 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub2, i32 %flag)
70+
ret i32 %call2
71+
base.case: ; preds = %entry
72+
ret i32 1
73+
}
74+
75+
define i32 @test_no_entry_count(i32 %n) {
76+
; CHECK-LABEL: @test_no_entry_count(
77+
; CHECK-NEXT: entry:
78+
; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
79+
; CHECK: tailrecurse:
80+
; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
81+
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
82+
; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]]
83+
; CHECK: if.then:
84+
; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
85+
; CHECK-NEXT: br label [[TAILRECURSE]]
86+
; CHECK: if.else:
87+
; CHECK-NEXT: ret i32 0
88+
;
89+
entry:
90+
%cmp = icmp sgt i32 %n, 0
91+
br i1 %cmp, label %if.then, label %if.else
92+
93+
if.then: ; preds = %entry
94+
%sub = sub i32 %n, 1
95+
%call = tail call i32 @test_no_entry_count(i32 %sub)
96+
ret i32 %call
97+
98+
if.else: ; preds = %entry
99+
ret i32 0
100+
}
101+
102+
; Function entry count metadata
103+
!0 = !{!"function_entry_count", i64 1000}
104+
!1 = !{!"branch_weights", i32 800, i32 200}
105+
!2 = !{!"function_entry_count", i64 2000}
106+
!3 = !{!"branch_weights", i32 3, i32 1}
107+
!4 = !{!"branch_weights", i32 100, i32 900}
108+
;.
109+
; ENABLED: [[META0:![0-9]+]] = !{!"function_entry_count", i64 200}
110+
; ENABLED: [[PROF1]] = !{!"branch_weights", i32 800, i32 200}
111+
; ENABLED: [[META2:![0-9]+]] = !{!"function_entry_count", i64 500}
112+
; ENABLED: [[PROF3]] = !{!"branch_weights", i32 3, i32 1}
113+
; ENABLED: [[PROF4]] = !{!"branch_weights", i32 100, i32 900}
114+
;.
115+
; DISABLED: [[META0:![0-9]+]] = !{!"function_entry_count", i64 1000}
116+
; DISABLED: [[PROF1]] = !{!"branch_weights", i32 800, i32 200}
117+
; DISABLED: [[META2:![0-9]+]] = !{!"function_entry_count", i64 2000}
118+
; DISABLED: [[PROF3]] = !{!"branch_weights", i32 3, i32 1}
119+
; DISABLED: [[PROF4]] = !{!"branch_weights", i32 100, i32 900}
120+
;.

0 commit comments

Comments
 (0)