Skip to content

Commit a6362f8

Browse files
committed
[TailRecursionElim] Adjust function entry count
1 parent d659046 commit a6362f8

File tree

2 files changed

+160
-7
lines changed

2 files changed

+160
-7
lines changed

llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp

Lines changed: 44 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@
5353
#include "llvm/ADT/STLExtras.h"
5454
#include "llvm/ADT/SmallPtrSet.h"
5555
#include "llvm/ADT/Statistic.h"
56+
#include "llvm/Analysis/BlockFrequencyInfo.h"
5657
#include "llvm/Analysis/DomTreeUpdater.h"
5758
#include "llvm/Analysis/GlobalsModRef.h"
5859
#include "llvm/Analysis/InstructionSimplify.h"
@@ -409,6 +410,8 @@ class TailRecursionEliminator {
409410
AliasAnalysis *AA;
410411
OptimizationRemarkEmitter *ORE;
411412
DomTreeUpdater &DTU;
413+
const uint64_t OrigEntryBBFreq;
414+
DenseMap<const BasicBlock *, uint64_t> OriginalBBFreqs;
412415

413416
// The below are shared state we want to have available when eliminating any
414417
// calls in the function. Their values should be populated by
@@ -438,8 +441,18 @@ class TailRecursionEliminator {
438441

439442
TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI,
440443
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
441-
DomTreeUpdater &DTU)
442-
: F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU) {}
444+
DomTreeUpdater &DTU, BlockFrequencyInfo *BFI)
445+
: F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU),
446+
OrigEntryBBFreq(
447+
BFI ? BFI->getBlockFreq(&F.getEntryBlock()).getFrequency() : 0U) {
448+
assert(((BFI != nullptr) == (OrigEntryBBFreq != 0)) &&
449+
"If the function has an entry count, its entry basic block should "
450+
"have a non-zero frequency. Pass a nullptr BFI if the function has "
451+
"no entry count");
452+
if (BFI)
453+
for (const auto &BB : F)
454+
OriginalBBFreqs.insert({&BB, BFI->getBlockFreq(&BB).getFrequency()});
455+
}
443456

444457
CallInst *findTRECandidate(BasicBlock *BB);
445458

@@ -460,7 +473,7 @@ class TailRecursionEliminator {
460473
public:
461474
static bool eliminate(Function &F, const TargetTransformInfo *TTI,
462475
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
463-
DomTreeUpdater &DTU);
476+
DomTreeUpdater &DTU, BlockFrequencyInfo *BFI);
464477
};
465478
} // namespace
466479

@@ -746,6 +759,17 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
746759
CI->eraseFromParent(); // Remove call.
747760
DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
748761
++NumEliminated;
762+
if (auto EC = F.getEntryCount()) {
763+
assert(OrigEntryBBFreq);
764+
auto It = OriginalBBFreqs.find(BB);
765+
assert(It != OriginalBBFreqs.end());
766+
auto RelativeBBFreq =
767+
static_cast<double>(It->second) / static_cast<double>(OrigEntryBBFreq);
768+
auto OldEntryCount = EC.value().getCount();
769+
auto ToSubtract = static_cast<uint64_t>(RelativeBBFreq * OldEntryCount);
770+
assert(OldEntryCount > ToSubtract);
771+
F.setEntryCount(OldEntryCount - ToSubtract, EC->getType());
772+
}
749773
return true;
750774
}
751775

@@ -872,7 +896,8 @@ bool TailRecursionEliminator::eliminate(Function &F,
872896
const TargetTransformInfo *TTI,
873897
AliasAnalysis *AA,
874898
OptimizationRemarkEmitter *ORE,
875-
DomTreeUpdater &DTU) {
899+
DomTreeUpdater &DTU,
900+
BlockFrequencyInfo *BFI) {
876901
if (F.getFnAttribute("disable-tail-calls").getValueAsBool())
877902
return false;
878903

@@ -888,7 +913,7 @@ bool TailRecursionEliminator::eliminate(Function &F,
888913
return MadeChange;
889914

890915
// Change any tail recursive calls to loops.
891-
TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU);
916+
TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU, BFI);
892917

893918
for (BasicBlock &BB : F)
894919
MadeChange |= TRE.processBlock(BB);
@@ -909,6 +934,7 @@ struct TailCallElim : public FunctionPass {
909934
AU.addRequired<TargetTransformInfoWrapperPass>();
910935
AU.addRequired<AAResultsWrapperPass>();
911936
AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
937+
AU.addRequired<BlockFrequencyInfoWrapperPass>();
912938
AU.addPreserved<GlobalsAAWrapperPass>();
913939
AU.addPreserved<DominatorTreeWrapperPass>();
914940
AU.addPreserved<PostDominatorTreeWrapperPass>();
@@ -918,6 +944,9 @@ struct TailCallElim : public FunctionPass {
918944
if (skipFunction(F))
919945
return false;
920946

947+
auto *BFI = F.getEntryCount().has_value()
948+
? &getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI()
949+
: nullptr;
921950
auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
922951
auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
923952
auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>();
@@ -930,7 +959,8 @@ struct TailCallElim : public FunctionPass {
930959
return TailRecursionEliminator::eliminate(
931960
F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F),
932961
&getAnalysis<AAResultsWrapperPass>().getAAResults(),
933-
&getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU);
962+
&getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU,
963+
BFI);
934964
}
935965
};
936966
}
@@ -953,14 +983,21 @@ PreservedAnalyses TailCallElimPass::run(Function &F,
953983

954984
TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
955985
AliasAnalysis &AA = AM.getResult<AAManager>(F);
986+
// This must come first. Computing BFI requires both the DT and the PDT; if it
987+
// ran after the getCachedResult calls below and those had returned nullptr
988+
// (likely for the PDT), the DomTreeUpdater would miss updating the trees.
989+
auto *BFI = F.getEntryCount().has_value()
990+
? &AM.getResult<BlockFrequencyAnalysis>(F)
991+
: nullptr;
956992
auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
957993
auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
958994
auto *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
959995
// There is no noticeable performance difference here between Lazy and Eager
960996
// UpdateStrategy based on some test results. It is feasible to switch the
961997
// UpdateStrategy to Lazy if we find it profitable later.
962998
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
963-
bool Changed = TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU);
999+
bool Changed =
1000+
TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU, BFI);
9641001

9651002
if (!Changed)
9661003
return PreservedAnalyses::all();
Lines changed: 116 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,116 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
2+
; RUN: opt -passes=tailcallelim -S %s -o - | FileCheck %s
3+
4+
; Test that tail call elimination correctly adjusts function entry counts
5+
; when eliminating tail recursive calls.
6+
7+
; Basic test: eliminate a tail call and adjust entry count
8+
define i32 @test_basic_entry_count_adjustment(i32 %n) !prof !0 {
9+
; CHECK-LABEL: @test_basic_entry_count_adjustment(
10+
; CHECK-NEXT: entry:
11+
; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
12+
; CHECK: tailrecurse:
13+
; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
14+
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
15+
; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]], !prof [[PROF1:![0-9]+]]
16+
; CHECK: if.then:
17+
; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
18+
; CHECK-NEXT: br label [[TAILRECURSE]]
19+
; CHECK: if.else:
20+
; CHECK-NEXT: ret i32 0
21+
;
22+
entry:
23+
%cmp = icmp sgt i32 %n, 0
24+
br i1 %cmp, label %if.then, label %if.else, !prof !1
25+
26+
if.then: ; preds = %entry
27+
%sub = sub i32 %n, 1
28+
%call = tail call i32 @test_basic_entry_count_adjustment(i32 %sub)
29+
ret i32 %call
30+
31+
if.else: ; preds = %entry
32+
ret i32 0
33+
}
34+
35+
; Test multiple tail calls in different blocks with different frequencies
36+
define i32 @test_multiple_blocks_entry_count(i32 %n, i32 %flag) !prof !2 {
37+
; CHECK-LABEL: @test_multiple_blocks_entry_count(
38+
; CHECK-NEXT: entry:
39+
; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
40+
; CHECK: tailrecurse:
41+
; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB1:%.*]], [[BLOCK1:%.*]] ], [ [[SUB2:%.*]], [[BLOCK2:%.*]] ]
42+
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
43+
; CHECK-NEXT: br i1 [[CMP]], label [[CHECK_FLAG:%.*]], label [[BASE_CASE:%.*]]
44+
; CHECK: check.flag:
45+
; CHECK-NEXT: [[CMP_FLAG:%.*]] = icmp eq i32 [[FLAG:%.*]], 1
46+
; CHECK-NEXT: br i1 [[CMP_FLAG]], label [[BLOCK1]], label [[BLOCK2]], !prof [[PROF3:![0-9]+]]
47+
; CHECK: block1:
48+
; CHECK-NEXT: [[SUB1]] = sub i32 [[N_TR]], 1
49+
; CHECK-NEXT: br label [[TAILRECURSE]]
50+
; CHECK: block2:
51+
; CHECK-NEXT: [[SUB2]] = sub i32 [[N_TR]], 2
52+
; CHECK-NEXT: br label [[TAILRECURSE]]
53+
; CHECK: base.case:
54+
; CHECK-NEXT: ret i32 1
55+
;
56+
entry:
57+
%cmp = icmp sgt i32 %n, 0
58+
br i1 %cmp, label %check.flag, label %base.case
59+
60+
check.flag:
61+
%cmp.flag = icmp eq i32 %flag, 1
62+
br i1 %cmp.flag, label %block1, label %block2, !prof !3
63+
64+
block1: ; preds = %check.flag
65+
%sub1 = sub i32 %n, 1
66+
%call1 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub1, i32 %flag)
67+
ret i32 %call1
68+
69+
block2: ; preds = %check.flag
70+
%sub2 = sub i32 %n, 2
71+
%call2 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub2, i32 %flag)
72+
ret i32 %call2
73+
74+
base.case: ; preds = %entry
75+
ret i32 1
76+
}
77+
78+
; Test function without entry count (should not crash)
79+
define i32 @test_no_entry_count(i32 %n) {
80+
; CHECK-LABEL: @test_no_entry_count(
81+
; CHECK-NEXT: entry:
82+
; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
83+
; CHECK: tailrecurse:
84+
; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
85+
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
86+
; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]]
87+
; CHECK: if.then:
88+
; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
89+
; CHECK-NEXT: br label [[TAILRECURSE]]
90+
; CHECK: if.else:
91+
; CHECK-NEXT: ret i32 0
92+
;
93+
entry:
94+
%cmp = icmp sgt i32 %n, 0
95+
br i1 %cmp, label %if.then, label %if.else
96+
97+
if.then: ; preds = %entry
98+
%sub = sub i32 %n, 1
99+
%call = tail call i32 @test_no_entry_count(i32 %sub)
100+
ret i32 %call
101+
102+
if.else: ; preds = %entry
103+
ret i32 0
104+
}
105+
106+
; Function entry count metadata
107+
!0 = !{!"function_entry_count", i64 1000}
108+
!1 = !{!"branch_weights", i32 800, i32 200}
109+
!2 = !{!"function_entry_count", i64 2000}
110+
!3 = !{!"branch_weights", i32 100, i32 500}
111+
;.
112+
; CHECK: [[META0:![0-9]+]] = !{!"function_entry_count", i64 201}
113+
; CHECK: [[PROF1]] = !{!"branch_weights", i32 800, i32 200}
114+
; CHECK: [[META2:![0-9]+]] = !{!"function_entry_count", i64 859}
115+
; CHECK: [[PROF3]] = !{!"branch_weights", i32 100, i32 500}
116+
;.

0 commit comments

Comments
 (0)