Skip to content

Commit d26b43f

Browse files
Add JumpTableToSwitch pass (llvm#77709)
Add a pass to convert jump tables to switches. The new pass replaces an indirect call with a switch + direct calls if all the functions in the jump table are smaller than the provided threshold. The pass is currently disabled by default and can be enabled by -enable-jump-table-to-switch. Test plan: ninja check-all
1 parent 9308d66 commit d26b43f

File tree

13 files changed

+732
-0
lines changed

13 files changed

+732
-0
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//===- JumpTableToSwitch.h - ------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_TRANSFORMS_SCALAR_JUMP_TABLE_TO_SWITCH_H
10+
#define LLVM_TRANSFORMS_SCALAR_JUMP_TABLE_TO_SWITCH_H
11+
12+
#include "llvm/IR/PassManager.h"
13+
14+
namespace llvm {
15+
16+
class Function;
17+
18+
struct JumpTableToSwitchPass : PassInfoMixin<JumpTableToSwitchPass> {
19+
/// Run the pass over the function.
20+
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
21+
};
22+
} // end namespace llvm
23+
24+
#endif // LLVM_TRANSFORMS_SCALAR_JUMP_TABLE_TO_SWITCH_H

llvm/lib/Passes/PassBuilder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@
201201
#include "llvm/Transforms/Scalar/InferAddressSpaces.h"
202202
#include "llvm/Transforms/Scalar/InferAlignment.h"
203203
#include "llvm/Transforms/Scalar/InstSimplifyPass.h"
204+
#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
204205
#include "llvm/Transforms/Scalar/JumpThreading.h"
205206
#include "llvm/Transforms/Scalar/LICM.h"
206207
#include "llvm/Transforms/Scalar/LoopAccessAnalysisPrinter.h"

llvm/lib/Passes/PassBuilderPipelines.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
#include "llvm/Transforms/Scalar/IndVarSimplify.h"
9292
#include "llvm/Transforms/Scalar/InferAlignment.h"
9393
#include "llvm/Transforms/Scalar/InstSimplifyPass.h"
94+
#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
9495
#include "llvm/Transforms/Scalar/JumpThreading.h"
9596
#include "llvm/Transforms/Scalar/LICM.h"
9697
#include "llvm/Transforms/Scalar/LoopDeletion.h"
@@ -237,6 +238,10 @@ static cl::opt<bool>
237238
EnableGVNSink("enable-gvn-sink",
238239
cl::desc("Enable the GVN sinking pass (default = off)"));
239240

241+
static cl::opt<bool> EnableJumpTableToSwitch(
242+
"enable-jump-table-to-switch",
243+
cl::desc("Enable JumpTableToSwitch pass (default = off)"));
244+
240245
// This option is used in simplifying testing SampleFDO optimizations for
241246
// profile loading.
242247
static cl::opt<bool>
@@ -559,6 +564,10 @@ PassBuilder::buildFunctionSimplificationPipeline(OptimizationLevel Level,
559564
FPM.addPass(JumpThreadingPass());
560565
FPM.addPass(CorrelatedValuePropagationPass());
561566

567+
// Jump table to switch conversion.
568+
if (EnableJumpTableToSwitch)
569+
FPM.addPass(JumpTableToSwitchPass());
570+
562571
FPM.addPass(
563572
SimplifyCFGPass(SimplifyCFGOptions().convertSwitchRangeToICmp(true)));
564573
FPM.addPass(InstCombinePass());

llvm/lib/Passes/PassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ FUNCTION_PASS("interleaved-load-combine", InterleavedLoadCombinePass(TM))
348348
FUNCTION_PASS("invalidate<all>", InvalidateAllAnalysesPass())
349349
FUNCTION_PASS("irce", IRCEPass())
350350
FUNCTION_PASS("jump-threading", JumpThreadingPass())
351+
FUNCTION_PASS("jump-table-to-switch", JumpTableToSwitchPass());
351352
FUNCTION_PASS("kcfi", KCFIPass())
352353
FUNCTION_PASS("lcssa", LCSSAPass())
353354
FUNCTION_PASS("libcalls-shrinkwrap", LibCallsShrinkWrapPass())

llvm/lib/Transforms/Scalar/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ add_llvm_component_library(LLVMScalarOpts
2525
InferAlignment.cpp
2626
InstSimplifyPass.cpp
2727
JumpThreading.cpp
28+
JumpTableToSwitch.cpp
2829
LICM.cpp
2930
LoopAccessAnalysisPrinter.cpp
3031
LoopBoundSplit.cpp
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
//===- JumpTableToSwitch.cpp ----------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
10+
#include "llvm/ADT/SmallVector.h"
11+
#include "llvm/Analysis/ConstantFolding.h"
12+
#include "llvm/Analysis/DomTreeUpdater.h"
13+
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
14+
#include "llvm/Analysis/PostDominators.h"
15+
#include "llvm/IR/IRBuilder.h"
16+
#include "llvm/Support/CommandLine.h"
17+
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
18+
19+
using namespace llvm;
20+
21+
static cl::opt<unsigned>
22+
JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden,
23+
cl::desc("Only split jump tables with size less or "
24+
"equal than JumpTableSizeThreshold."),
25+
cl::init(10));
26+
27+
// TODO: Consider adding a cost model for profitability analysis of this
28+
// transformation. Currently we replace a jump table with a switch if all the
29+
// functions in the jump table are smaller than the provided threshold.
30+
static cl::opt<unsigned> FunctionSizeThreshold(
31+
"jump-table-to-switch-function-size-threshold", cl::Hidden,
32+
cl::desc("Only split jump tables containing functions whose sizes are less "
33+
"or equal than this threshold."),
34+
cl::init(50));
35+
36+
#define DEBUG_TYPE "jump-table-to-switch"
37+
38+
namespace {
39+
struct JumpTableTy {
40+
Value *Index;
41+
SmallVector<Function *, 10> Funcs;
42+
};
43+
} // anonymous namespace
44+
45+
static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
46+
PointerType *PtrTy) {
47+
Constant *Ptr = dyn_cast<Constant>(GEP->getPointerOperand());
48+
if (!Ptr)
49+
return std::nullopt;
50+
51+
GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr);
52+
if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
53+
return std::nullopt;
54+
55+
Function &F = *GEP->getParent()->getParent();
56+
const DataLayout &DL = F.getParent()->getDataLayout();
57+
const unsigned BitWidth =
58+
DL.getIndexSizeInBits(GEP->getPointerAddressSpace());
59+
MapVector<Value *, APInt> VariableOffsets;
60+
APInt ConstantOffset(BitWidth, 0);
61+
if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset))
62+
return std::nullopt;
63+
if (VariableOffsets.size() != 1)
64+
return std::nullopt;
65+
// TODO: consider supporting more general patterns
66+
if (!ConstantOffset.isZero())
67+
return std::nullopt;
68+
APInt StrideBytes = VariableOffsets.front().second;
69+
const uint64_t JumpTableSizeBytes = DL.getTypeAllocSize(GV->getValueType());
70+
if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0)
71+
return std::nullopt;
72+
const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue();
73+
if (N > JumpTableSizeThreshold)
74+
return std::nullopt;
75+
76+
JumpTableTy JumpTable;
77+
JumpTable.Index = VariableOffsets.front().first;
78+
JumpTable.Funcs.reserve(N);
79+
for (uint64_t Index = 0; Index < N; ++Index) {
80+
// ConstantOffset is zero.
81+
APInt Offset = Index * StrideBytes;
82+
Constant *C =
83+
ConstantFoldLoadFromConst(GV->getInitializer(), PtrTy, Offset, DL);
84+
auto *Func = dyn_cast_or_null<Function>(C);
85+
if (!Func || Func->isDeclaration() ||
86+
Func->getInstructionCount() > FunctionSizeThreshold)
87+
return std::nullopt;
88+
JumpTable.Funcs.push_back(Func);
89+
}
90+
return JumpTable;
91+
}
92+
93+
static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
94+
DomTreeUpdater &DTU,
95+
OptimizationRemarkEmitter &ORE) {
96+
const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext());
97+
98+
SmallVector<DominatorTree::UpdateType, 8> DTUpdates;
99+
BasicBlock *BB = CB->getParent();
100+
BasicBlock *Tail = SplitBlock(BB, CB, &DTU, nullptr, nullptr,
101+
BB->getName() + Twine(".tail"));
102+
DTUpdates.push_back({DominatorTree::Delete, BB, Tail});
103+
BB->getTerminator()->eraseFromParent();
104+
105+
Function &F = *BB->getParent();
106+
BasicBlock *BBUnreachable = BasicBlock::Create(
107+
F.getContext(), "default.switch.case.unreachable", &F, Tail);
108+
IRBuilder<> BuilderUnreachable(BBUnreachable);
109+
BuilderUnreachable.CreateUnreachable();
110+
111+
IRBuilder<> Builder(BB);
112+
SwitchInst *Switch = Builder.CreateSwitch(JT.Index, BBUnreachable);
113+
DTUpdates.push_back({DominatorTree::Insert, BB, BBUnreachable});
114+
115+
IRBuilder<> BuilderTail(CB);
116+
PHINode *PHI =
117+
IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size());
118+
119+
for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) {
120+
BasicBlock *B = BasicBlock::Create(Func->getContext(),
121+
"call." + Twine(Index), &F, Tail);
122+
DTUpdates.push_back({DominatorTree::Insert, BB, B});
123+
DTUpdates.push_back({DominatorTree::Insert, B, Tail});
124+
125+
CallBase *Call = cast<CallBase>(CB->clone());
126+
Call->setCalledFunction(Func);
127+
Call->insertInto(B, B->end());
128+
Switch->addCase(
129+
cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B);
130+
BranchInst::Create(Tail, B);
131+
if (PHI)
132+
PHI->addIncoming(Call, B);
133+
}
134+
DTU.applyUpdates(DTUpdates);
135+
ORE.emit([&]() {
136+
return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)
137+
<< "expanded indirect call into switch";
138+
});
139+
if (PHI)
140+
CB->replaceAllUsesWith(PHI);
141+
CB->eraseFromParent();
142+
return Tail;
143+
}
144+
145+
PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
146+
FunctionAnalysisManager &AM) {
147+
OptimizationRemarkEmitter &ORE =
148+
AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
149+
DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
150+
PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
151+
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
152+
bool Changed = false;
153+
for (BasicBlock &BB : make_early_inc_range(F)) {
154+
BasicBlock *CurrentBB = &BB;
155+
while (CurrentBB) {
156+
BasicBlock *SplittedOutTail = nullptr;
157+
for (Instruction &I : make_early_inc_range(*CurrentBB)) {
158+
auto *Call = dyn_cast<CallInst>(&I);
159+
if (!Call || Call->getCalledFunction() || Call->isMustTailCall())
160+
continue;
161+
auto *L = dyn_cast<LoadInst>(Call->getCalledOperand());
162+
// Skip atomic or volatile loads.
163+
if (!L || !L->isSimple())
164+
continue;
165+
auto *GEP = dyn_cast<GetElementPtrInst>(L->getPointerOperand());
166+
if (!GEP)
167+
continue;
168+
auto *PtrTy = dyn_cast<PointerType>(L->getType());
169+
assert(PtrTy && "call operand must be a pointer");
170+
std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
171+
if (!JumpTable)
172+
continue;
173+
SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE);
174+
Changed = true;
175+
break;
176+
}
177+
CurrentBB = SplittedOutTail ? SplittedOutTail : nullptr;
178+
}
179+
}
180+
181+
if (!Changed)
182+
return PreservedAnalyses::all();
183+
184+
PreservedAnalyses PA;
185+
if (DT)
186+
PA.preserve<DominatorTreeAnalysis>();
187+
if (PDT)
188+
PA.preserve<PostDominatorTreeAnalysis>();
189+
return PA;
190+
}

llvm/test/Other/new-pm-defaults.ll

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@
7171
; RUN: -passes='default<O3>' -S %s 2>&1 \
7272
; RUN: | FileCheck %s --check-prefixes=CHECK-O,CHECK-DEFAULT,CHECK-O3,%llvmcheckext,CHECK-EP-OPTIMIZER-LAST,CHECK-O23SZ
7373

74+
; RUN: opt -disable-verify -verify-analysis-invalidation=0 -eagerly-invalidate-analyses=0 -debug-pass-manager \
75+
; RUN: -passes='default<O3>' -enable-jump-table-to-switch -S %s 2>&1 \
76+
; RUN: | FileCheck %s --check-prefixes=CHECK-O,CHECK-DEFAULT,CHECK-O3,CHECK-JUMP-TABLE-TO-SWITCH,CHECK-O23SZ,%llvmcheckext
77+
7478
; RUN: opt -disable-verify -verify-analysis-invalidation=0 -eagerly-invalidate-analyses=0 -debug-pass-manager \
7579
; RUN: -passes='default<O3>' -enable-matrix -S %s 2>&1 \
7680
; RUN: | FileCheck %s --check-prefixes=CHECK-O,CHECK-DEFAULT,CHECK-O3,CHECK-O23SZ,%llvmcheckext,CHECK-MATRIX
@@ -151,6 +155,7 @@
151155
; CHECK-O23SZ-NEXT: Running analysis: LazyValueAnalysis
152156
; CHECK-O23SZ-NEXT: Running pass: CorrelatedValuePropagationPass
153157
; CHECK-O23SZ-NEXT: Invalidating analysis: LazyValueAnalysis
158+
; CHECK-JUMP-TABLE-TO-SWITCH-NEXT: Running pass: JumpTableToSwitchPass
154159
; CHECK-O-NEXT: Running pass: SimplifyCFGPass
155160
; CHECK-O-NEXT: Running pass: InstCombinePass
156161
; CHECK-O23SZ-NEXT: Running pass: AggressiveInstCombinePass

0 commit comments

Comments
 (0)