Skip to content

[TRE] Adjust function entry count when using instrumented profiles #143987

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm/include/llvm/Passes/PassBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,8 @@ class PassBuilder {
IntrusiveRefCntPtr<vfs::FileSystem> FS);
void addPostPGOLoopRotation(ModulePassManager &MPM, OptimizationLevel Level);

bool isInstrumentedPGOUse() const;

// Extension Point callbacks
SmallVector<std::function<void(FunctionPassManager &, OptimizationLevel)>, 2>
PeepholeEPCallbacks;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,12 @@ namespace llvm {

class Function;

struct TailCallElimPass : PassInfoMixin<TailCallElimPass> {
class TailCallElimPass : public PassInfoMixin<TailCallElimPass> {
const bool UpdateFunctionEntryCount;

public:
TailCallElimPass(bool UpdateFunctionEntryCount = true)
: UpdateFunctionEntryCount(UpdateFunctionEntryCount) {}
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};
}
Expand Down
14 changes: 11 additions & 3 deletions llvm/lib/Passes/PassBuilderPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,8 @@ PassBuilder::buildFunctionSimplificationPipeline(OptimizationLevel Level,
!Level.isOptimizingForSize())
FPM.addPass(PGOMemOPSizeOpt());

FPM.addPass(TailCallElimPass());
FPM.addPass(TailCallElimPass(/*UpdateFunctionEntryCount=*/
isInstrumentedPGOUse()));
FPM.addPass(
SimplifyCFGPass(SimplifyCFGOptions().convertSwitchRangeToICmp(true)));

Expand Down Expand Up @@ -1581,7 +1582,8 @@ PassBuilder::buildModuleOptimizationPipeline(OptimizationLevel Level,
OptimizePM.addPass(DivRemPairsPass());

// Try to annotate calls that were created during optimization.
OptimizePM.addPass(TailCallElimPass());
OptimizePM.addPass(
TailCallElimPass(/*UpdateFunctionEntryCount=*/isInstrumentedPGOUse()));

// LoopSink (and other loop passes since the last simplifyCFG) might have
// resulted in single-entry-single-exit or empty blocks. Clean up the CFG.
Expand Down Expand Up @@ -2069,7 +2071,8 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,

// LTO provides additional opportunities for tailcall elimination due to
// link-time inlining, and visibility of nocapture attribute.
FPM.addPass(TailCallElimPass());
FPM.addPass(
TailCallElimPass(/*UpdateFunctionEntryCount=*/isInstrumentedPGOUse()));

// Run a few AA driver optimizations here and now to cleanup the code.
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM),
Expand Down Expand Up @@ -2350,3 +2353,8 @@ AAManager PassBuilder::buildDefaultAAPipeline() {

return AA;
}

bool PassBuilder::isInstrumentedPGOUse() const {
return (PGOOpt && PGOOpt->Action == PGOOptions::IRUse) ||
!UseCtxProfile.empty();
}
68 changes: 61 additions & 7 deletions llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/InstructionSimplify.h"
Expand All @@ -75,10 +76,12 @@
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <cmath>
using namespace llvm;

#define DEBUG_TYPE "tailcallelim"
Expand All @@ -87,6 +90,11 @@ STATISTIC(NumEliminated, "Number of tail calls removed");
STATISTIC(NumRetDuped, "Number of return duplicated");
STATISTIC(NumAccumAdded, "Number of accumulators introduced");

static cl::opt<bool> ForceDisableBFI(
"tre-disable-entrycount-recompute", cl::init(false), cl::Hidden,
cl::desc("Force disabling recomputing of function entry count, on "
"successful tail recursion elimination."));

/// Scan the specified function for alloca instructions.
/// If it contains any dynamic allocas, returns false.
static bool canTRE(Function &F) {
Expand Down Expand Up @@ -399,6 +407,9 @@ class TailRecursionEliminator {
AliasAnalysis *AA;
OptimizationRemarkEmitter *ORE;
DomTreeUpdater &DTU;
BlockFrequencyInfo *const BFI;
const uint64_t OrigEntryBBFreq;
const uint64_t OrigEntryCount;

// The below are shared state we want to have available when eliminating any
// calls in the function. There values should be populated by
Expand Down Expand Up @@ -428,8 +439,19 @@ class TailRecursionEliminator {

TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI,
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
DomTreeUpdater &DTU)
: F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU) {}
DomTreeUpdater &DTU, BlockFrequencyInfo *BFI)
: F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU), BFI(BFI),
OrigEntryBBFreq(
BFI ? BFI->getBlockFreq(&F.getEntryBlock()).getFrequency() : 0U),
OrigEntryCount(F.getEntryCount() ? F.getEntryCount()->getCount() : 0) {
if (BFI) {
// The assert is meant as API documentation for the caller.
assert((OrigEntryCount != 0 && OrigEntryBBFreq != 0) &&
"If a BFI was provided, the function should have both an entry "
"count that is non-zero and an entry basic block with a non-zero "
"frequency.");
}
}

CallInst *findTRECandidate(BasicBlock *BB);

Expand All @@ -450,7 +472,7 @@ class TailRecursionEliminator {
public:
static bool eliminate(Function &F, const TargetTransformInfo *TTI,
AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
DomTreeUpdater &DTU);
DomTreeUpdater &DTU, BlockFrequencyInfo *BFI);
};
} // namespace

Expand Down Expand Up @@ -735,6 +757,28 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
CI->eraseFromParent(); // Remove call.
DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
++NumEliminated;
if (OrigEntryBBFreq) {
assert(F.getEntryCount().has_value());
// This pass is not expected to remove BBs, only add an entry BB. For that
// reason, and because the BB here isn't the new entry BB, the BFI lookup is
// expected to succeed.
assert(&F.getEntryBlock() != BB);
auto RelativeBBFreq =
static_cast<double>(BFI->getBlockFreq(BB).getFrequency()) /
static_cast<double>(OrigEntryBBFreq);
auto ToSubtract =
static_cast<uint64_t>(std::round(RelativeBBFreq * OrigEntryCount));
auto OldEntryCount = F.getEntryCount()->getCount();
if (OldEntryCount <= ToSubtract) {
LLVM_DEBUG(
errs() << "[TRE] The entrycount attributable to the recursive call, "
<< ToSubtract
<< ", should be strictly lower than the function entry count, "
<< OldEntryCount << "\n");
} else {
F.setEntryCount(OldEntryCount - ToSubtract, F.getEntryCount()->getType());
}
}
return true;
}

Expand Down Expand Up @@ -861,7 +905,8 @@ bool TailRecursionEliminator::eliminate(Function &F,
const TargetTransformInfo *TTI,
AliasAnalysis *AA,
OptimizationRemarkEmitter *ORE,
DomTreeUpdater &DTU) {
DomTreeUpdater &DTU,
BlockFrequencyInfo *BFI) {
if (F.getFnAttribute("disable-tail-calls").getValueAsBool())
return false;

Expand All @@ -877,7 +922,7 @@ bool TailRecursionEliminator::eliminate(Function &F,
return MadeChange;

// Change any tail recursive calls to loops.
TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU);
TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU, BFI);

for (BasicBlock &BB : F)
MadeChange |= TRE.processBlock(BB);
Expand Down Expand Up @@ -919,7 +964,8 @@ struct TailCallElim : public FunctionPass {
return TailRecursionEliminator::eliminate(
F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F),
&getAnalysis<AAResultsWrapperPass>().getAAResults(),
&getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU);
&getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU,
/*BFI=*/nullptr);
}
};
}
Expand All @@ -942,14 +988,22 @@ PreservedAnalyses TailCallElimPass::run(Function &F,

TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
AliasAnalysis &AA = AM.getResult<AAManager>(F);
// This must come first. It needs the 2 analyses, meaning, if it came after
// the lines asking for the cached result, should they be nullptr (which, in
// the case of the PDT, is likely), updates to the trees would be missed.
auto *BFI = (!ForceDisableBFI && UpdateFunctionEntryCount &&
F.getEntryCount().has_value() && F.getEntryCount()->getCount())
? &AM.getResult<BlockFrequencyAnalysis>(F)
: nullptr;
auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
auto *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
// There is no noticable performance difference here between Lazy and Eager
// UpdateStrategy based on some test results. It is feasible to switch the
// UpdateStrategy to Lazy if we find it profitable later.
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
bool Changed = TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU);
bool Changed =
TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU, BFI);

if (!Changed)
return PreservedAnalyses::all();
Expand Down
120 changes: 120 additions & 0 deletions llvm/test/Transforms/TailCallElim/entry-count-adjustment.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
; RUN: opt -passes=tailcallelim -S %s -o - | FileCheck %s --check-prefixes=CHECK,ENABLED
; RUN: opt -passes=tailcallelim -tre-disable-entrycount-recompute -S %s -o - | FileCheck %s --check-prefixes=CHECK,DISABLED

; Test that tail call elimination correctly adjusts function entry counts
; when eliminating tail recursive calls.

; Basic test: eliminate a tail call and adjust entry count
define i32 @test_basic_entry_count_adjustment(i32 %n) !prof !0 {
; CHECK-LABEL: @test_basic_entry_count_adjustment(
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
; CHECK: tailrecurse:
; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]], !prof [[PROF1:![0-9]+]]
; CHECK: if.then:
; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
; CHECK-NEXT: br label [[TAILRECURSE]]
; CHECK: if.else:
; CHECK-NEXT: ret i32 0
;
entry:
%cmp = icmp sgt i32 %n, 0
br i1 %cmp, label %if.then, label %if.else, !prof !1

if.then: ; preds = %entry
%sub = sub i32 %n, 1
%call = tail call i32 @test_basic_entry_count_adjustment(i32 %sub)
ret i32 %call

if.else: ; preds = %entry
ret i32 0
}

; Test multiple tail calls in different blocks with different frequencies
define i32 @test_multiple_blocks_entry_count(i32 %n, i32 %flag) !prof !2 {
; CHECK-LABEL: @test_multiple_blocks_entry_count(
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
; CHECK: tailrecurse:
; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB1:%.*]], [[BLOCK1:%.*]] ], [ [[SUB2:%.*]], [[BLOCK2:%.*]] ]
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
; CHECK-NEXT: br i1 [[CMP]], label [[CHECK_FLAG:%.*]], label [[BASE_CASE:%.*]], !prof [[PROF3:![0-9]+]]
; CHECK: check.flag:
; CHECK-NEXT: [[CMP_FLAG:%.*]] = icmp eq i32 [[FLAG:%.*]], 1
; CHECK-NEXT: br i1 [[CMP_FLAG]], label [[BLOCK1]], label [[BLOCK2]], !prof [[PROF4:![0-9]+]]
; CHECK: block1:
; CHECK-NEXT: [[SUB1]] = sub i32 [[N_TR]], 1
; CHECK-NEXT: br label [[TAILRECURSE]]
; CHECK: block2:
; CHECK-NEXT: [[SUB2]] = sub i32 [[N_TR]], 2
; CHECK-NEXT: br label [[TAILRECURSE]]
; CHECK: base.case:
; CHECK-NEXT: ret i32 1
;
entry:
%cmp = icmp sgt i32 %n, 0
br i1 %cmp, label %check.flag, label %base.case, !prof !3
check.flag:
%cmp.flag = icmp eq i32 %flag, 1
br i1 %cmp.flag, label %block1, label %block2, !prof !4
block1: ; preds = %check.flag
%sub1 = sub i32 %n, 1
%call1 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub1, i32 %flag)
ret i32 %call1
block2: ; preds = %check.flag
%sub2 = sub i32 %n, 2
%call2 = tail call i32 @test_multiple_blocks_entry_count(i32 %sub2, i32 %flag)
ret i32 %call2
base.case: ; preds = %entry
ret i32 1
}

define i32 @test_no_entry_count(i32 %n) {
; CHECK-LABEL: @test_no_entry_count(
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[TAILRECURSE:%.*]]
; CHECK: tailrecurse:
; CHECK-NEXT: [[N_TR:%.*]] = phi i32 [ [[N:%.*]], [[ENTRY:%.*]] ], [ [[SUB:%.*]], [[IF_THEN:%.*]] ]
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N_TR]], 0
; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN]], label [[IF_ELSE:%.*]]
; CHECK: if.then:
; CHECK-NEXT: [[SUB]] = sub i32 [[N_TR]], 1
; CHECK-NEXT: br label [[TAILRECURSE]]
; CHECK: if.else:
; CHECK-NEXT: ret i32 0
;
entry:
%cmp = icmp sgt i32 %n, 0
br i1 %cmp, label %if.then, label %if.else

if.then: ; preds = %entry
%sub = sub i32 %n, 1
%call = tail call i32 @test_no_entry_count(i32 %sub)
ret i32 %call

if.else: ; preds = %entry
ret i32 0
}

; Function entry count metadata
!0 = !{!"function_entry_count", i64 1000}
!1 = !{!"branch_weights", i32 800, i32 200}
!2 = !{!"function_entry_count", i64 2000}
!3 = !{!"branch_weights", i32 3, i32 1}
!4 = !{!"branch_weights", i32 100, i32 900}
;.
; ENABLED: [[META0:![0-9]+]] = !{!"function_entry_count", i64 200}
; ENABLED: [[PROF1]] = !{!"branch_weights", i32 800, i32 200}
; ENABLED: [[META2:![0-9]+]] = !{!"function_entry_count", i64 500}
; ENABLED: [[PROF3]] = !{!"branch_weights", i32 3, i32 1}
; ENABLED: [[PROF4]] = !{!"branch_weights", i32 100, i32 900}
;.
; DISABLED: [[META0:![0-9]+]] = !{!"function_entry_count", i64 1000}
; DISABLED: [[PROF1]] = !{!"branch_weights", i32 800, i32 200}
; DISABLED: [[META2:![0-9]+]] = !{!"function_entry_count", i64 2000}
; DISABLED: [[PROF3]] = !{!"branch_weights", i32 3, i32 1}
; DISABLED: [[PROF4]] = !{!"branch_weights", i32 100, i32 900}
;.
Loading