Skip to content

Commit 5e22095

Browse files
committed
[TailRecursionElim] Adjust function entry count
1 parent 64155a3 commit 5e22095

File tree

5 files changed

+196
-11
lines changed

5 files changed

+196
-11
lines changed

llvm/include/llvm/Passes/PassBuilder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,8 @@ class PassBuilder {
773773
IntrusiveRefCntPtr<vfs::FileSystem> FS);
774774
void addPostPGOLoopRotation(ModulePassManager &MPM, OptimizationLevel Level);
775775

776+
bool isInstrumentedPGOUse() const;
777+
776778
// Extension Point callbacks
777779
SmallVector<std::function<void(FunctionPassManager &, OptimizationLevel)>, 2>
778780
PeepholeEPCallbacks;

llvm/include/llvm/Transforms/Scalar/TailRecursionElimination.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,12 @@ namespace llvm {
5858

5959
class Function;
6060

61-
struct TailCallElimPass : PassInfoMixin<TailCallElimPass> {
61+
class TailCallElimPass : public PassInfoMixin<TailCallElimPass> {
62+
const bool UpdateFunctionEntryCount;
63+
64+
public:
65+
TailCallElimPass(bool UpdateFunctionEntryCount = true)
66+
: UpdateFunctionEntryCount(UpdateFunctionEntryCount) {}
6267
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
6368
};
6469
}

llvm/lib/Passes/PassBuilderPipelines.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,8 @@ PassBuilder::buildFunctionSimplificationPipeline(OptimizationLevel Level,
628628
!Level.isOptimizingForSize())
629629
FPM.addPass(PGOMemOPSizeOpt());
630630

631-
FPM.addPass(TailCallElimPass());
631+
FPM.addPass(TailCallElimPass(/*UpdateFunctionEntryCount=*/
632+
isInstrumentedPGOUse()));
632633
FPM.addPass(
633634
SimplifyCFGPass(SimplifyCFGOptions().convertSwitchRangeToICmp(true)));
634635

@@ -1581,7 +1582,8 @@ PassBuilder::buildModuleOptimizationPipeline(OptimizationLevel Level,
15811582
OptimizePM.addPass(DivRemPairsPass());
15821583

15831584
// Try to annotate calls that were created during optimization.
1584-
OptimizePM.addPass(TailCallElimPass());
1585+
OptimizePM.addPass(
1586+
TailCallElimPass(/*UpdateFunctionEntryCount=*/isInstrumentedPGOUse()));
15851587

15861588
// LoopSink (and other loop passes since the last simplifyCFG) might have
15871589
// resulted in single-entry-single-exit or empty blocks. Clean up the CFG.
@@ -2069,7 +2071,8 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,
20692071

20702072
// LTO provides additional opportunities for tailcall elimination due to
20712073
// link-time inlining, and visibility of nocapture attribute.
2072-
FPM.addPass(TailCallElimPass());
2074+
FPM.addPass(
2075+
TailCallElimPass(/*UpdateFunctionEntryCount=*/isInstrumentedPGOUse()));
20732076

20742077
// Run a few AA driver optimizations here and now to cleanup the code.
20752078
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM),
@@ -2350,3 +2353,8 @@ AAManager PassBuilder::buildDefaultAAPipeline() {
23502353

23512354
return AA;
23522355
}
2356+
2357+
bool PassBuilder::isInstrumentedPGOUse() const {
2358+
return (PGOOpt && PGOOpt->Action == PGOOptions::IRUse) ||
2359+
!UseCtxProfile.empty();
2360+
}

llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
#include "llvm/ADT/STLExtras.h"
5454
#include "llvm/ADT/SmallPtrSet.h"
5555
#include "llvm/ADT/Statistic.h"
56+
#include "llvm/Analysis/BlockFrequencyInfo.h"
5657
#include "llvm/Analysis/DomTreeUpdater.h"
5758
#include "llvm/Analysis/GlobalsModRef.h"
5859
#include "llvm/Analysis/InstructionSimplify.h"
@@ -75,10 +76,12 @@
7576
#include "llvm/IR/Module.h"
7677
#include "llvm/InitializePasses.h"
7778
#include "llvm/Pass.h"
79+
#include "llvm/Support/CommandLine.h"
7880
#include "llvm/Support/Debug.h"
7981
#include "llvm/Support/raw_ostream.h"
8082
#include "llvm/Transforms/Scalar.h"
8183
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
84+
#include <cmath>
8285
using namespace llvm;
8386

8487
#define DEBUG_TYPE "tailcallelim"
@@ -87,6 +90,11 @@ STATISTIC(NumEliminated, "Number of tail calls removed");
8790
STATISTIC(NumRetDuped, "Number of return duplicated");
8891
STATISTIC(NumAccumAdded, "Number of accumulators introduced");
8992

93+
static cl::opt<bool> ForceDisableBFI(
94+
"tre-disable-entrycount-recompute", cl::init(false), cl::Hidden,
95+
cl::desc("Force disabling recomputing of function entry count, on "
96+
"successful tail recursion elimination."));
97+
9098
/// Scan the specified function for alloca instructions.
9199
/// If it contains any dynamic allocas, returns false.
92100
static bool canTRE(Function &F) {
@@ -399,6 +407,8 @@ class TailRecursionEliminator {
399407
AliasAnalysis *AA;
400408
OptimizationRemarkEmitter *ORE;
401409
DomTreeUpdater &DTU;
410+
BlockFrequencyInfo *const BFI;
411+
const uint64_t OrigEntryBBFreq;
402412

403413
// The below are shared state we want to have available when eliminating any
404414
// calls in the function. There values should be populated by
@@ -428,8 +438,20 @@ class TailRecursionEliminator {
428438

429439
TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI,
430440
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
431-
DomTreeUpdater &DTU)
432-
: F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU) {}
441+
DomTreeUpdater &DTU, BlockFrequencyInfo *BFI)
442+
: F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU), BFI(BFI),
443+
OrigEntryBBFreq(
444+
BFI ? BFI->getBlockFreq(&F.getEntryBlock()).getFrequency() : 0U) {
445+
if (BFI) {
446+
auto EC = F.getEntryCount();
447+
(void)EC;
448+
assert(
449+
(EC.has_value() && EC->getCount() != 0 && OrigEntryBBFreq) &&
450+
"If the function has an entry count, its entry basic block should "
451+
"have a non-zero frequency. Pass a nullptr BFI if the function has "
452+
"no entry count");
453+
}
454+
}
433455

434456
CallInst *findTRECandidate(BasicBlock *BB);
435457

@@ -450,7 +472,7 @@ class TailRecursionEliminator {
450472
public:
451473
static bool eliminate(Function &F, const TargetTransformInfo *TTI,
452474
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
453-
DomTreeUpdater &DTU);
475+
DomTreeUpdater &DTU, BlockFrequencyInfo *BFI);
454476
};
455477
} // namespace
456478

@@ -735,6 +757,28 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
735757
CI->eraseFromParent(); // Remove call.
736758
DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
737759
++NumEliminated;
760+
if (OrigEntryBBFreq) {
761+
assert(F.getEntryCount().has_value());
762+
// This pass is not expected to remove BBs, only add an entry BB. For that
763+
// reason, and because the BB here isn't the new entry BB, the BFI lookup is
764+
// expected to succeed.
765+
assert(&F.getEntryBlock() != BB);
766+
auto RelativeBBFreq =
767+
static_cast<double>(BFI->getBlockFreq(BB).getFrequency()) /
768+
static_cast<double>(OrigEntryBBFreq);
769+
auto OldEntryCount = F.getEntryCount()->getCount();
770+
auto ToSubtract =
771+
static_cast<uint64_t>(std::round(RelativeBBFreq * OldEntryCount));
772+
if (OldEntryCount <= ToSubtract) {
773+
LLVM_DEBUG(
774+
errs() << "[TRE] The entrycount attributable to the recursive call, "
775+
<< ToSubtract
776+
<< ", should be strictly lower than the original function "
777+
"entry count, "
778+
<< OldEntryCount << "\n");
779+
}
780+
F.setEntryCount(OldEntryCount - ToSubtract, F.getEntryCount()->getType());
781+
}
738782
return true;
739783
}
740784

@@ -861,7 +905,8 @@ bool TailRecursionEliminator::eliminate(Function &F,
861905
const TargetTransformInfo *TTI,
862906
AliasAnalysis *AA,
863907
OptimizationRemarkEmitter *ORE,
864-
DomTreeUpdater &DTU) {
908+
DomTreeUpdater &DTU,
909+
BlockFrequencyInfo *BFI) {
865910
if (F.getFnAttribute("disable-tail-calls").getValueAsBool())
866911
return false;
867912

@@ -877,7 +922,7 @@ bool TailRecursionEliminator::eliminate(Function &F,
877922
return MadeChange;
878923

879924
// Change any tail recursive calls to loops.
880-
TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU);
925+
TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU, BFI);
881926

882927
for (BasicBlock &BB : F)
883928
MadeChange |= TRE.processBlock(BB);
@@ -919,7 +964,8 @@ struct TailCallElim : public FunctionPass {
919964
return TailRecursionEliminator::eliminate(
920965
F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F),
921966
&getAnalysis<AAResultsWrapperPass>().getAAResults(),
922-
&getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU);
967+
&getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU,
968+
nullptr);
923969
}
924970
};
925971
}
@@ -942,14 +988,22 @@ PreservedAnalyses TailCallElimPass::run(Function &F,
942988

943989
TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
944990
AliasAnalysis &AA = AM.getResult<AAManager>(F);
991+
// This must come first. It needs the 2 analyses, meaning, if it came after
992+
// the lines asking for the cached result, should they be nullptr (which, in
993+
// the case of the PDT, is likely), updates to the trees would be missed.
994+
auto *BFI = (!ForceDisableBFI && UpdateFunctionEntryCount &&
995+
F.getEntryCount().has_value() && F.getEntryCount()->getCount())
996+
? &AM.getResult<BlockFrequencyAnalysis>(F)
997+
: nullptr;
945998
auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
946999
auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
9471000
auto *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
9481001
// There is no noticable performance difference here between Lazy and Eager
9491002
// UpdateStrategy based on some test results. It is feasible to switch the
9501003
// UpdateStrategy to Lazy if we find it profitable later.
9511004
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
952-
bool Changed = TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU);
1005+
bool Changed =
1006+
TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU, BFI);
9531007

9541008
if (!Changed)
9551009
return PreservedAnalyses::all();
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
2+
; RUN: opt -passes=tailcallelim -S %s -o - | FileCheck %s
3+
4+
; Test that tail call elimination correctly adjusts function entry counts
5+
; when eliminating tail recursive calls.
6+
7+
; Basic test: eliminate a tail call and adjust entry count
8+
define i32 @test_basic_entry_count_adjustment(i32 %n) !prof !0 {
9+
; CHECK-LABEL: @test_basic_entry_count_adjustment(
10+
; CHECK-NEXT: entry:
11+
; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
12+
; CHECK: tailrecurse:
13+
; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
14+
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
15+
; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]], !prof [[PROF1:![0-9]+]]
16+
; CHECK: if.then:
17+
; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
18+
; CHECK-NEXT: br label [[TAILRECURSE]]
19+
; CHECK: if.else:
20+
; CHECK-NEXT: ret i32 0
21+
;
22+
entry:
23+
%cmp = icmp sgt i32 %n, 0
24+
br i1 %cmp, label %if.then, label %if.else, !prof !1
25+
26+
if.then: ; preds = %entry
27+
%sub = sub i32 %n, 1
28+
%call = tail call i32 @test_basic_entry_count_adjustment(i32 %sub)
29+
ret i32 %call
30+
31+
if.else: ; preds = %entry
32+
ret i32 0
33+
}
34+
35+
; Test multiple tail calls in different blocks with different frequencies
36+
define i32 @test_multiple_blocks_entry_count(i32 %n, i32 %flag) !prof !2 {
37+
; CHECK-LABEL: @test_multiple_blocks_entry_count(
38+
; CHECK-NEXT: entry:
39+
; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
40+
; CHECK: tailrecurse:
41+
; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB1:%.*]], [[BLOCK1:%.*]] ], [ [[SUB2:%.*]], [[BLOCK2:%.*]] ]
42+
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
43+
; CHECK-NEXT: br i1 [[CMP]], label [[CHECK_FLAG:%.*]], label [[BASE_CASE:%.*]]
44+
; CHECK: check.flag:
45+
; CHECK-NEXT: [[CMP_FLAG:%.*]] = icmp eq i32 [[FLAG:%.*]], 1
46+
; CHECK-NEXT: br i1 [[CMP_FLAG]], label [[BLOCK1]], label [[BLOCK2]], !prof [[PROF3:![0-9]+]]
47+
; CHECK: block1:
48+
; CHECK-NEXT: [[SUB1]] = sub i32 [[N_TR]], 1
49+
; CHECK-NEXT: br label [[TAILRECURSE]]
50+
; CHECK: block2:
51+
; CHECK-NEXT: [[SUB2]] = sub i32 [[N_TR]], 2
52+
; CHECK-NEXT: br label [[TAILRECURSE]]
53+
; CHECK: base.case:
54+
; CHECK-NEXT: ret i32 1
55+
;
56+
entry:
57+
%cmp = icmp sgt i32 %n, 0
58+
br i1 %cmp, label %check.flag, label %base.case
59+
60+
check.flag:
61+
%cmp.flag = icmp eq i32 %flag, 1
62+
br i1 %cmp.flag, label %block1, label %block2, !prof !3
63+
64+
block1: ; preds = %check.flag
65+
%sub1 = sub i32 %n, 1
66+
%call1 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub1, i32 %flag)
67+
ret i32 %call1
68+
69+
block2: ; preds = %check.flag
70+
%sub2 = sub i32 %n, 2
71+
%call2 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub2, i32 %flag)
72+
ret i32 %call2
73+
74+
base.case: ; preds = %entry
75+
ret i32 1
76+
}
77+
78+
; Test function without entry count (should not crash)
79+
define i32 @test_no_entry_count(i32 %n) {
80+
; CHECK-LABEL: @test_no_entry_count(
81+
; CHECK-NEXT: entry:
82+
; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
83+
; CHECK: tailrecurse:
84+
; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
85+
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
86+
; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]]
87+
; CHECK: if.then:
88+
; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
89+
; CHECK-NEXT: br label [[TAILRECURSE]]
90+
; CHECK: if.else:
91+
; CHECK-NEXT: ret i32 0
92+
;
93+
entry:
94+
%cmp = icmp sgt i32 %n, 0
95+
br i1 %cmp, label %if.then, label %if.else
96+
97+
if.then: ; preds = %entry
98+
%sub = sub i32 %n, 1
99+
%call = tail call i32 @test_no_entry_count(i32 %sub)
100+
ret i32 %call
101+
102+
if.else: ; preds = %entry
103+
ret i32 0
104+
}
105+
106+
; Function entry count metadata
107+
!0 = !{!"function_entry_count", i64 1000}
108+
!1 = !{!"branch_weights", i32 800, i32 200}
109+
!2 = !{!"function_entry_count", i64 2000}
110+
!3 = !{!"branch_weights", i32 100, i32 500}
111+
;.
112+
; CHECK: [[META0:![0-9]+]] = !{!"function_entry_count", i64 200}
113+
; CHECK: [[PROF1]] = !{!"branch_weights", i32 800, i32 200}
114+
; CHECK: [[META2:![0-9]+]] = !{!"function_entry_count", i64 859}
115+
; CHECK: [[PROF3]] = !{!"branch_weights", i32 100, i32 500}
116+
;.

0 commit comments

Comments
 (0)