Skip to content

Commit 51cf14a

Browse files
committed
[TailRecursionElim] Adjust function entry count
1 parent 64155a3 commit 51cf14a

File tree

5 files changed

+246
-11
lines changed

5 files changed

+246
-11
lines changed

llvm/include/llvm/Passes/PassBuilder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,8 @@ class PassBuilder {
773773
IntrusiveRefCntPtr<vfs::FileSystem> FS);
774774
void addPostPGOLoopRotation(ModulePassManager &MPM, OptimizationLevel Level);
775775

776+
bool isInstrumentedPGOUse() const;
777+
776778
// Extension Point callbacks
777779
SmallVector<std::function<void(FunctionPassManager &, OptimizationLevel)>, 2>
778780
PeepholeEPCallbacks;

llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,12 @@ namespace llvm {
5858

5959
class Function;
6060

61-
struct TailCallElimPass : PassInfoMixin<TailCallElimPass> {
61+
class TailCallElimPass : public PassInfoMixin<TailCallElimPass> {
62+
const bool UpdateFunctionEntryCount;
63+
64+
public:
65+
TailCallElimPass(bool UpdateFunctionEntryCount = true)
66+
: UpdateFunctionEntryCount(UpdateFunctionEntryCount) {}
6267
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
6368
};
6469
}

llvm/lib/Passes/PassBuilderPipelines.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,8 @@ PassBuilder::buildFunctionSimplificationPipeline(OptimizationLevel Level,
628628
!Level.isOptimizingForSize())
629629
FPM.addPass(PGOMemOPSizeOpt());
630630

631-
FPM.addPass(TailCallElimPass());
631+
FPM.addPass(TailCallElimPass(/*UpdateFunctionEntryCount=*/
632+
isInstrumentedPGOUse()));
632633
FPM.addPass(
633634
SimplifyCFGPass(SimplifyCFGOptions().convertSwitchRangeToICmp(true)));
634635

@@ -1581,7 +1582,8 @@ PassBuilder::buildModuleOptimizationPipeline(OptimizationLevel Level,
15811582
OptimizePM.addPass(DivRemPairsPass());
15821583

15831584
// Try to annotate calls that were created during optimization.
1584-
OptimizePM.addPass(TailCallElimPass());
1585+
OptimizePM.addPass(
1586+
TailCallElimPass(/*UpdateFunctionEntryCount=*/isInstrumentedPGOUse()));
15851587

15861588
// LoopSink (and other loop passes since the last simplifyCFG) might have
15871589
// resulted in single-entry-single-exit or empty blocks. Clean up the CFG.
@@ -2069,7 +2071,8 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,
20692071

20702072
// LTO provides additional opportunities for tailcall elimination due to
20712073
// link-time inlining, and visibility of nocapture attribute.
2072-
FPM.addPass(TailCallElimPass());
2074+
FPM.addPass(
2075+
TailCallElimPass(/*UpdateFunctionEntryCount=*/isInstrumentedPGOUse()));
20732076

20742077
// Run a few AA driver optimizations here and now to cleanup the code.
20752078
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM),
@@ -2350,3 +2353,8 @@ AAManager PassBuilder::buildDefaultAAPipeline() {
23502353

23512354
return AA;
23522355
}
2356+
2357+
bool PassBuilder::isInstrumentedPGOUse() const {
2358+
return (PGOOpt && PGOOpt->Action == PGOOptions::IRUse) ||
2359+
!UseCtxProfile.empty();
2360+
}

llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
#include "llvm/ADT/STLExtras.h"
5454
#include "llvm/ADT/SmallPtrSet.h"
5555
#include "llvm/ADT/Statistic.h"
56+
#include "llvm/Analysis/BlockFrequencyInfo.h"
5657
#include "llvm/Analysis/DomTreeUpdater.h"
5758
#include "llvm/Analysis/GlobalsModRef.h"
5859
#include "llvm/Analysis/InstructionSimplify.h"
@@ -75,10 +76,12 @@
7576
#include "llvm/IR/Module.h"
7677
#include "llvm/InitializePasses.h"
7778
#include "llvm/Pass.h"
79+
#include "llvm/Support/CommandLine.h"
7880
#include "llvm/Support/Debug.h"
7981
#include "llvm/Support/raw_ostream.h"
8082
#include "llvm/Transforms/Scalar.h"
8183
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
84+
#include <cmath>
8285
using namespace llvm;
8386

8487
#define DEBUG_TYPE "tailcallelim"
@@ -87,6 +90,11 @@ STATISTIC(NumEliminated, "Number of tail calls removed");
8790
STATISTIC(NumRetDuped, "Number of return duplicated");
8891
STATISTIC(NumAccumAdded, "Number of accumulators introduced");
8992

93+
static cl::opt<bool> ForceDisableBFI(
94+
"tre-disable-entrycount-recompute", cl::init(false), cl::Hidden,
95+
cl::desc("Force disabling recomputing of function entry count, on "
96+
"successful tail recursion elimination."));
97+
9098
/// Scan the specified function for alloca instructions.
9199
/// If it contains any dynamic allocas, returns false.
92100
static bool canTRE(Function &F) {
@@ -399,6 +407,8 @@ class TailRecursionEliminator {
399407
AliasAnalysis *AA;
400408
OptimizationRemarkEmitter *ORE;
401409
DomTreeUpdater &DTU;
410+
BlockFrequencyInfo *const BFI;
411+
const uint64_t OrigEntryBBFreq;
402412

403413
// The below are shared state we want to have available when eliminating any
404414
// calls in the function. There values should be populated by
@@ -428,8 +438,19 @@ class TailRecursionEliminator {
428438

429439
TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI,
430440
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
431-
DomTreeUpdater &DTU)
432-
: F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU) {}
441+
DomTreeUpdater &DTU, BlockFrequencyInfo *BFI)
442+
: F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU), BFI(BFI),
443+
OrigEntryBBFreq(
444+
BFI ? BFI->getBlockFreq(&F.getEntryBlock()).getFrequency() : 0U) {
445+
if (BFI) {
446+
auto EC = F.getEntryCount();
447+
(void)EC;
448+
assert((EC.has_value() && EC->getCount() != 0 && OrigEntryBBFreq) &&
449+
"If a BFI was provided, the function should have both an entry "
450+
"count that is non-zero and an entry basic block with a non-zero "
451+
"frequency.");
452+
}
453+
}
433454

434455
CallInst *findTRECandidate(BasicBlock *BB);
435456

@@ -450,7 +471,7 @@ class TailRecursionEliminator {
450471
public:
451472
static bool eliminate(Function &F, const TargetTransformInfo *TTI,
452473
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
453-
DomTreeUpdater &DTU);
474+
DomTreeUpdater &DTU, BlockFrequencyInfo *BFI);
454475
};
455476
} // namespace
456477

@@ -735,6 +756,28 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
735756
CI->eraseFromParent(); // Remove call.
736757
DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
737758
++NumEliminated;
759+
if (OrigEntryBBFreq) {
760+
assert(F.getEntryCount().has_value());
761+
// This pass is not expected to remove BBs, only add an entry BB. For that
762+
// reason, and because the BB here isn't the new entry BB, the BFI lookup is
763+
// expected to succeed.
764+
assert(&F.getEntryBlock() != BB);
765+
auto RelativeBBFreq =
766+
static_cast<double>(BFI->getBlockFreq(BB).getFrequency()) /
767+
static_cast<double>(OrigEntryBBFreq);
768+
auto OldEntryCount = F.getEntryCount()->getCount();
769+
auto ToSubtract =
770+
static_cast<uint64_t>(std::round(RelativeBBFreq * OldEntryCount));
771+
if (OldEntryCount <= ToSubtract) {
772+
LLVM_DEBUG(
773+
errs() << "[TRE] The entrycount attributable to the recursive call, "
774+
<< ToSubtract
775+
<< ", should be strictly lower than the original function "
776+
"entry count, "
777+
<< OldEntryCount << "\n");
778+
}
779+
F.setEntryCount(OldEntryCount - ToSubtract, F.getEntryCount()->getType());
780+
}
738781
return true;
739782
}
740783

@@ -861,7 +904,8 @@ bool TailRecursionEliminator::eliminate(Function &F,
861904
const TargetTransformInfo *TTI,
862905
AliasAnalysis *AA,
863906
OptimizationRemarkEmitter *ORE,
864-
DomTreeUpdater &DTU) {
907+
DomTreeUpdater &DTU,
908+
BlockFrequencyInfo *BFI) {
865909
if (F.getFnAttribute("disable-tail-calls").getValueAsBool())
866910
return false;
867911

@@ -877,7 +921,7 @@ bool TailRecursionEliminator::eliminate(Function &F,
877921
return MadeChange;
878922

879923
// Change any tail recursive calls to loops.
880-
TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU);
924+
TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU, BFI);
881925

882926
for (BasicBlock &BB : F)
883927
MadeChange |= TRE.processBlock(BB);
@@ -919,7 +963,8 @@ struct TailCallElim : public FunctionPass {
919963
return TailRecursionEliminator::eliminate(
920964
F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F),
921965
&getAnalysis<AAResultsWrapperPass>().getAAResults(),
922-
&getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU);
966+
&getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU,
967+
nullptr);
923968
}
924969
};
925970
}
@@ -942,14 +987,22 @@ PreservedAnalyses TailCallElimPass::run(Function &F,
942987

943988
TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
944989
AliasAnalysis &AA = AM.getResult<AAManager>(F);
990+
// This must come first. It needs the 2 analyses, meaning, if it came after
991+
// the lines asking for the cached result, should they be nullptr (which, in
992+
// the case of the PDT, is likely), updates to the trees would be missed.
993+
auto *BFI = (!ForceDisableBFI && UpdateFunctionEntryCount &&
994+
F.getEntryCount().has_value() && F.getEntryCount()->getCount())
995+
? &AM.getResult<BlockFrequencyAnalysis>(F)
996+
: nullptr;
945997
auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
946998
auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
947999
auto *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
9481000
// There is no noticable performance difference here between Lazy and Eager
9491001
// UpdateStrategy based on some test results. It is feasible to switch the
9501002
// UpdateStrategy to Lazy if we find it profitable later.
9511003
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
952-
bool Changed = TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU);
1004+
bool Changed =
1005+
TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU, BFI);
9531006

9541007
if (!Changed)
9551008
return PreservedAnalyses::all();
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
2+
; RUN: opt -passes=tailcallelim -S %s -o - | FileCheck %s
3+
; RUN: opt -passes=tailcallelim -tre-disable-entrycount-recompute -S %s -o - | FileCheck %s --check-prefix=DISABLED
4+
5+
; Test that tail call elimination correctly adjusts function entry counts
6+
; when eliminating tail recursive calls.
7+
8+
; Basic test: eliminate a tail call and adjust entry count
9+
define i32 @test_basic_entry_count_adjustment(i32 %n) !prof !0 {
10+
; CHECK-LABEL: @test_basic_entry_count_adjustment(
11+
; CHECK-NEXT: entry:
12+
; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
13+
; CHECK: tailrecurse:
14+
; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
15+
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
16+
; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]], !prof [[PROF1:![0-9]+]]
17+
; CHECK: if.then:
18+
; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
19+
; CHECK-NEXT: br label [[TAILRECURSE]]
20+
; CHECK: if.else:
21+
; CHECK-NEXT: ret i32 0
22+
;
23+
; DISABLED-LABEL: @test_basic_entry_count_adjustment(
24+
; DISABLED-NEXT: entry:
25+
; DISABLED-NEXT: br label [[TAILRECURSE:%.*]]
26+
; DISABLED: tailrecurse:
27+
; DISABLED-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
28+
; DISABLED-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
29+
; DISABLED-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]], !prof [[PROF1:![0-9]+]]
30+
; DISABLED: if.then:
31+
; DISABLED-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
32+
; DISABLED-NEXT: br label [[TAILRECURSE]]
33+
; DISABLED: if.else:
34+
; DISABLED-NEXT: ret i32 0
35+
;
36+
entry:
37+
%cmp = icmp sgt i32 %n, 0
38+
br i1 %cmp, label %if.then, label %if.else, !prof !1
39+
40+
if.then: ; preds = %entry
41+
%sub = sub i32 %n, 1
42+
%call = tail call i32 @test_basic_entry_count_adjustment(i32 %sub)
43+
ret i32 %call
44+
45+
if.else: ; preds = %entry
46+
ret i32 0
47+
}
48+
49+
; Test multiple tail calls in different blocks with different frequencies
50+
define i32 @test_multiple_blocks_entry_count(i32 %n, i32 %flag) !prof !2 {
51+
; CHECK-LABEL: @test_multiple_blocks_entry_count(
52+
; CHECK-NEXT: entry:
53+
; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
54+
; CHECK: tailrecurse:
55+
; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB1:%.*]], [[BLOCK1:%.*]] ], [ [[SUB2:%.*]], [[BLOCK2:%.*]] ]
56+
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
57+
; CHECK-NEXT: br i1 [[CMP]], label [[CHECK_FLAG:%.*]], label [[BASE_CASE:%.*]]
58+
; CHECK: check.flag:
59+
; CHECK-NEXT: [[CMP_FLAG:%.*]] = icmp eq i32 [[FLAG:%.*]], 1
60+
; CHECK-NEXT: br i1 [[CMP_FLAG]], label [[BLOCK1]], label [[BLOCK2]], !prof [[PROF3:![0-9]+]]
61+
; CHECK: block1:
62+
; CHECK-NEXT: [[SUB1]] = sub i32 [[N_TR]], 1
63+
; CHECK-NEXT: br label [[TAILRECURSE]]
64+
; CHECK: block2:
65+
; CHECK-NEXT: [[SUB2]] = sub i32 [[N_TR]], 2
66+
; CHECK-NEXT: br label [[TAILRECURSE]]
67+
; CHECK: base.case:
68+
; CHECK-NEXT: ret i32 1
69+
;
70+
; DISABLED-LABEL: @test_multiple_blocks_entry_count(
71+
; DISABLED-NEXT: entry:
72+
; DISABLED-NEXT: br label [[TAILRECURSE:%.*]]
73+
; DISABLED: tailrecurse:
74+
; DISABLED-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB1:%.*]], [[BLOCK1:%.*]] ], [ [[SUB2:%.*]], [[BLOCK2:%.*]] ]
75+
; DISABLED-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
76+
; DISABLED-NEXT: br i1 [[CMP]], label [[CHECK_FLAG:%.*]], label [[BASE_CASE:%.*]]
77+
; DISABLED: check.flag:
78+
; DISABLED-NEXT: [[CMP_FLAG:%.*]] = icmp eq i32 [[FLAG:%.*]], 1
79+
; DISABLED-NEXT: br i1 [[CMP_FLAG]], label [[BLOCK1]], label [[BLOCK2]], !prof [[PROF3:![0-9]+]]
80+
; DISABLED: block1:
81+
; DISABLED-NEXT: [[SUB1]] = sub i32 [[N_TR]], 1
82+
; DISABLED-NEXT: br label [[TAILRECURSE]]
83+
; DISABLED: block2:
84+
; DISABLED-NEXT: [[SUB2]] = sub i32 [[N_TR]], 2
85+
; DISABLED-NEXT: br label [[TAILRECURSE]]
86+
; DISABLED: base.case:
87+
; DISABLED-NEXT: ret i32 1
88+
;
89+
entry:
90+
%cmp = icmp sgt i32 %n, 0
91+
br i1 %cmp, label %check.flag, label %base.case
92+
93+
check.flag:
94+
%cmp.flag = icmp eq i32 %flag, 1
95+
br i1 %cmp.flag, label %block1, label %block2, !prof !3
96+
97+
block1: ; preds = %check.flag
98+
%sub1 = sub i32 %n, 1
99+
%call1 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub1, i32 %flag)
100+
ret i32 %call1
101+
102+
block2: ; preds = %check.flag
103+
%sub2 = sub i32 %n, 2
104+
%call2 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub2, i32 %flag)
105+
ret i32 %call2
106+
107+
base.case: ; preds = %entry
108+
ret i32 1
109+
}
110+
111+
; Test function without entry count (should not crash)
112+
define i32 @test_no_entry_count(i32 %n) {
113+
; CHECK-LABEL: @test_no_entry_count(
114+
; CHECK-NEXT: entry:
115+
; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
116+
; CHECK: tailrecurse:
117+
; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
118+
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
119+
; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]]
120+
; CHECK: if.then:
121+
; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
122+
; CHECK-NEXT: br label [[TAILRECURSE]]
123+
; CHECK: if.else:
124+
; CHECK-NEXT: ret i32 0
125+
;
126+
; DISABLED-LABEL: @test_no_entry_count(
127+
; DISABLED-NEXT: entry:
128+
; DISABLED-NEXT: br label [[TAILRECURSE:%.*]]
129+
; DISABLED: tailrecurse:
130+
; DISABLED-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
131+
; DISABLED-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
132+
; DISABLED-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]]
133+
; DISABLED: if.then:
134+
; DISABLED-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
135+
; DISABLED-NEXT: br label [[TAILRECURSE]]
136+
; DISABLED: if.else:
137+
; DISABLED-NEXT: ret i32 0
138+
;
139+
entry:
140+
%cmp = icmp sgt i32 %n, 0
141+
br i1 %cmp, label %if.then, label %if.else
142+
143+
if.then: ; preds = %entry
144+
%sub = sub i32 %n, 1
145+
%call = tail call i32 @test_no_entry_count(i32 %sub)
146+
ret i32 %call
147+
148+
if.else: ; preds = %entry
149+
ret i32 0
150+
}
151+
152+
; Function entry count metadata
153+
!0 = !{!"function_entry_count", i64 1000}
154+
!1 = !{!"branch_weights", i32 800, i32 200}
155+
!2 = !{!"function_entry_count", i64 2000}
156+
!3 = !{!"branch_weights", i32 100, i32 500}
157+
;.
158+
; CHECK: [[META0:![0-9]+]] = !{!"function_entry_count", i64 200}
159+
; CHECK: [[PROF1]] = !{!"branch_weights", i32 800, i32 200}
160+
; CHECK: [[META2:![0-9]+]] = !{!"function_entry_count", i64 859}
161+
; CHECK: [[PROF3]] = !{!"branch_weights", i32 100, i32 500}
162+
;.
163+
; DISABLED: [[META0:![0-9]+]] = !{!"function_entry_count", i64 1000}
164+
; DISABLED: [[PROF1]] = !{!"branch_weights", i32 800, i32 200}
165+
; DISABLED: [[META2:![0-9]+]] = !{!"function_entry_count", i64 2000}
166+
; DISABLED: [[PROF3]] = !{!"branch_weights", i32 100, i32 500}
167+
;.

0 commit comments

Comments
 (0)