Skip to content

Commit cfb6758

Browse files
Add a pass to convert jump tables to switches
1 parent 7e54ae2 commit cfb6758

16 files changed

+449
-0
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//===- JumpTableToSwitch.h - ------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_TRANSFORMS_SCALAR_JUMP_TABLE_TO_SWITCH_H
10+
#define LLVM_TRANSFORMS_SCALAR_JUMP_TABLE_TO_SWITCH_H
11+
12+
#include "llvm/IR/PassManager.h"
13+
14+
namespace llvm {
15+
16+
class Function;
17+
18+
/// New-pass-manager function pass that converts indirect calls made through
/// small jump tables into a switch over direct calls.
class JumpTableToSwitchPass : public PassInfoMixin<JumpTableToSwitchPass> {
public:
  /// Run the pass over the function \p F, using analyses from \p AM.
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};
22+
} // end namespace llvm
23+
24+
#endif // LLVM_TRANSFORMS_SCALAR_JUMP_TABLE_TO_SWITCH_H

llvm/lib/Passes/PassBuilder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@
192192
#include "llvm/Transforms/Scalar/InferAddressSpaces.h"
193193
#include "llvm/Transforms/Scalar/InferAlignment.h"
194194
#include "llvm/Transforms/Scalar/InstSimplifyPass.h"
195+
#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
195196
#include "llvm/Transforms/Scalar/JumpThreading.h"
196197
#include "llvm/Transforms/Scalar/LICM.h"
197198
#include "llvm/Transforms/Scalar/LoopAccessAnalysisPrinter.h"

llvm/lib/Passes/PassBuilderPipelines.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
#include "llvm/Transforms/Scalar/IndVarSimplify.h"
9292
#include "llvm/Transforms/Scalar/InferAlignment.h"
9393
#include "llvm/Transforms/Scalar/InstSimplifyPass.h"
94+
#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
9495
#include "llvm/Transforms/Scalar/JumpThreading.h"
9596
#include "llvm/Transforms/Scalar/LICM.h"
9697
#include "llvm/Transforms/Scalar/LoopDeletion.h"
@@ -557,6 +558,7 @@ PassBuilder::buildFunctionSimplificationPipeline(OptimizationLevel Level,
557558

558559
// Optimize based on known information about branches, and cleanup afterward.
559560
FPM.addPass(JumpThreadingPass());
561+
FPM.addPass(JumpTableToSwitchPass());
560562
FPM.addPass(CorrelatedValuePropagationPass());
561563

562564
FPM.addPass(
@@ -695,6 +697,7 @@ PassBuilder::buildFunctionSimplificationPipeline(OptimizationLevel Level,
695697
FPM.addPass(DFAJumpThreadingPass());
696698

697699
FPM.addPass(JumpThreadingPass());
700+
698701
FPM.addPass(CorrelatedValuePropagationPass());
699702

700703
// Finally, do an expensive DCE pass to catch all the dead code exposed by
@@ -1926,6 +1929,7 @@ PassBuilder::buildLTODefaultPipeline(OptimizationLevel Level,
19261929

19271930
invokePeepholeEPCallbacks(MainFPM, Level);
19281931
MainFPM.addPass(JumpThreadingPass());
1932+
19291933
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(MainFPM),
19301934
PTO.EagerlyInvalidateAnalyses));
19311935

llvm/lib/Passes/PassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ FUNCTION_PASS("interleaved-load-combine", InterleavedLoadCombinePass(TM))
330330
FUNCTION_PASS("invalidate<all>", InvalidateAllAnalysesPass())
331331
FUNCTION_PASS("irce", IRCEPass())
332332
FUNCTION_PASS("jump-threading", JumpThreadingPass())
333+
FUNCTION_PASS("jump-table-to-switch", JumpTableToSwitchPass())
333334
FUNCTION_PASS("kcfi", KCFIPass())
334335
FUNCTION_PASS("lcssa", LCSSAPass())
335336
FUNCTION_PASS("libcalls-shrinkwrap", LibCallsShrinkWrapPass())

llvm/lib/Transforms/Scalar/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ add_llvm_component_library(LLVMScalarOpts
2525
InferAlignment.cpp
2626
InstSimplifyPass.cpp
2727
JumpThreading.cpp
28+
JumpTableToSwitch.cpp
2829
LICM.cpp
2930
LoopAccessAnalysisPrinter.cpp
3031
LoopBoundSplit.cpp
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
//===- JumpTableToSwitch.cpp ----------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
10+
#include "llvm/ADT/DenseMap.h"
11+
#include "llvm/ADT/SmallSet.h"
12+
#include "llvm/Analysis/DomTreeUpdater.h"
13+
#include "llvm/Analysis/PostDominators.h"
14+
#include "llvm/Analysis/TargetLibraryInfo.h"
15+
#include "llvm/Analysis/TargetTransformInfo.h"
16+
#include "llvm/Analysis/ValueTracking.h"
17+
#include "llvm/IR/IRBuilder.h"
18+
#include "llvm/IR/IntrinsicInst.h"
19+
#include "llvm/IR/PatternMatch.h"
20+
#include "llvm/Support/CommandLine.h"
21+
#include "llvm/Support/Debug.h"
22+
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
23+
#include "llvm/Transforms/Utils/Cloning.h"
24+
#include "llvm/Transforms/Utils/Local.h"
25+
26+
using namespace llvm;
27+
using namespace PatternMatch;
28+
29+
// Hidden command-line knob (default 10) bounding the number of array elements
// a jump table may have for this pass to consider it; parseJumpTable bails out
// on arrays with more elements than this value.
static cl::opt<unsigned>
30+
JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden,
31+
cl::desc("Only split jump tables with size less or "
32+
"equal than JumpTableSizeThreshold."),
33+
cl::init(10));
34+
35+
#define DEBUG_TYPE "jump-table-to-switch"
36+
37+
/// Walk into the constant \p C looking for a sub-element of type \p Ty that
/// starts exactly \p Offset bytes into it.
///
/// Only nested constant structs are descended into; the struct layout of the
/// data layout of \p M decides which field contains the offset.
///
/// \returns the matching element, or nullptr if the offset lies outside the
/// struct or an element that is not a constant struct is reached before a
/// match is found.
static Constant *getElementWithGivenTypeAtOffset(Constant *C, const Type *Ty,
                                                 uint64_t Offset, Module &M) {
  const DataLayout &DL = M.getDataLayout();
  Constant *Current = C;
  while (true) {
    // Found an element of the requested type starting right here.
    if (Offset == 0 && Current->getType() == Ty)
      return Current;

    // TODO: add support for arrays.
    auto *Aggregate = dyn_cast<ConstantStruct>(Current);
    if (!Aggregate)
      return nullptr;

    const StructLayout *Layout = DL.getStructLayout(Aggregate->getType());
    if (Offset >= Layout->getSizeInBytes())
      return nullptr;

    // Step into the field covering the offset and rebase the offset so it is
    // relative to the start of that field.
    const unsigned FieldNo = Layout->getElementContainingOffset(Offset);
    Offset = Offset - Layout->getElementOffset(FieldNo);
    Current = cast<Constant>(Aggregate->getOperand(FieldNo));
  }
}
54+
55+
namespace {
56+
struct JumpTableTy {
57+
Value *Index;
58+
SmallVector<Function *, 5> Funcs;
59+
};
60+
} // anonymous namespace
61+
62+
/// Try to interpret \p GEP as the address computation of a jump table lookup.
///
/// The pattern accepted is an inbounds GEP whose source element type is an
/// array with at most JumpTableSizeThreshold elements, whose base pointer is
/// a constant global variable with a definitive initializer, and whose offset
/// consists of exactly one variable index (scaled by the array element size)
/// and no constant part. Every array element must be a function definition.
///
/// \returns the table index and the functions in element order, or
/// std::nullopt if \p GEP does not match the pattern.
static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP) {
  if (!GEP || !GEP->isInBounds())
    return std::nullopt;
  auto *ArrayTy = dyn_cast<ArrayType>(GEP->getSourceElementType());
  if (!ArrayTy || ArrayTy->getNumElements() > JumpTableSizeThreshold)
    return std::nullopt;

  // The table contents can only be trusted if the global is never written to
  // and its initializer cannot be replaced at link time or at startup.
  auto *GV = dyn_cast<GlobalVariable>(GEP->getPointerOperand());
  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
    return std::nullopt;

  Function &F = *GEP->getParent()->getParent();
  Module &M = *F.getParent();
  const DataLayout &DL = M.getDataLayout();
  const unsigned BitWidth =
      DL.getIndexSizeInBits(GEP->getPointerAddressSpace());
  MapVector<Value *, APInt> VariableOffsets;
  APInt ConstantOffset(BitWidth, 0);
  if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset))
    return std::nullopt;
  if (VariableOffsets.size() != 1)
    return std::nullopt;

  // TODO: support more general patterns
  // (see also TODO in getElementWithGivenTypeAtOffset).
  // Test the full-width APInt: narrowing it to a 32-bit integer first would
  // let a non-zero offset that is a multiple of 2^32 slip through.
  if (!ConstantOffset.isZero())
    return std::nullopt;

  // The variable index must step through the array one element at a time;
  // any other scale means it is not a plain index into the table.
  const TypeSize ElementSize = DL.getTypeAllocSize(ArrayTy->getElementType());
  if (ElementSize.isScalable() ||
      VariableOffsets.front().second != ElementSize.getFixedValue())
    return std::nullopt;

  // NOTE(review): when the array is the leading member of a larger struct,
  // an inbounds index may still address bytes past the array; confirm that
  // callers are fine treating such indices as unreachable.
  Constant *ConstArray = getElementWithGivenTypeAtOffset(
      GV->getInitializer(), ArrayTy, /*Offset=*/0, M);
  if (!ConstArray)
    return std::nullopt;

  JumpTableTy JumpTable;
  JumpTable.Index = VariableOffsets.front().first;

  const uint64_t N = ArrayTy->getNumElements();
  JumpTable.Funcs.reserve(N);
  for (uint64_t Index = 0; Index < N; ++Index) {
    // Only tables made entirely of function definitions are accepted.
    auto *Func =
        dyn_cast_or_null<Function>(ConstArray->getAggregateElement(Index));
    if (!Func || Func->isDeclaration())
      return std::nullopt;
    JumpTable.Funcs.push_back(Func);
  }
  return JumpTable;
}
110+
111+
/// Replace the indirect call \p CB with a switch on the jump table index that
/// dispatches to one direct call per entry of \p JT.
///
/// The block containing \p CB is split right before the call. The head block
/// ends in the new switch, whose default destination is a block holding only
/// an unreachable instruction. Each case block holds a clone of \p CB
/// retargeted to the corresponding function and branches to the tail block.
/// For a non-void call the results are merged by a PHI at the top of the
/// tail, which replaces all uses of \p CB. \p CB itself is erased.
///
/// \param DTU optional updater; when non-null it is passed to the block split
/// and receives the CFG edge updates made here.
/// \returns the tail block, i.e. the code that followed \p CB.
static BasicBlock *split(CallBase *CB, const JumpTableTy &JT,
                         DomTreeUpdater *DTU) {
  Type *const ResultTy = CB->getType();
  SmallVector<DominatorTree::UpdateType, 8> PendingUpdates;

  // Cut the block in front of the call and drop the fall-through branch that
  // the split inserted; the switch built below takes its place.
  BasicBlock *Head = CB->getParent();
  BasicBlock *Tail = SplitBlock(Head, CB, DTU, nullptr, nullptr,
                                Head->getName() + Twine(".tail"));
  PendingUpdates.push_back({DominatorTree::Delete, Head, Tail});
  Head->getTerminator()->eraseFromParent();

  Function &Parent = *Head->getParent();
  auto &Ctx = Parent.getContext();

  // Default destination of the switch.
  BasicBlock *DefaultBB = BasicBlock::Create(
      Ctx, "default.switch.case.unreachable", &Parent, Tail);
  IRBuilder<> DefaultBuilder(DefaultBB);
  DefaultBuilder.CreateUnreachable();

  IRBuilder<> HeadBuilder(Head);
  SwitchInst *Dispatch = HeadBuilder.CreateSwitch(JT.Index, DefaultBB);
  PendingUpdates.push_back({DominatorTree::Insert, Head, DefaultBB});

  // A call producing a value needs a PHI in the tail to merge the results of
  // the per-case calls; it is placed right before the original call.
  PHINode *Merged = nullptr;
  if (!ResultTy->isVoidTy()) {
    IRBuilder<> TailBuilder(CB);
    Merged = TailBuilder.CreatePHI(ResultTy, JT.Funcs.size());
  }

  for (size_t Slot = 0, NumSlots = JT.Funcs.size(); Slot != NumSlots; ++Slot) {
    Function *Callee = JT.Funcs[Slot];
    BasicBlock *CaseBB =
        BasicBlock::Create(Ctx, "call." + Twine(Slot), &Parent, Tail);
    PendingUpdates.push_back({DominatorTree::Insert, Head, CaseBB});
    PendingUpdates.push_back({DominatorTree::Insert, CaseBB, Tail});

    // Same call as the original, but with a statically known callee.
    auto *DirectCall = cast<CallBase>(CB->clone());
    DirectCall->setCalledFunction(Callee);
    DirectCall->insertInto(CaseBB, CaseBB->end());
    Dispatch->addCase(
        cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Slot)),
        CaseBB);
    BranchInst::Create(Tail, CaseBB);
    if (Merged)
      Merged->addIncoming(DirectCall, CaseBB);
  }

  if (DTU)
    DTU->applyUpdates(PendingUpdates);
  if (Merged)
    CB->replaceAllUsesWith(Merged);
  CB->eraseFromParent();
  return Tail;
}
158+
159+
/// Scan \p F for indirect calls whose callee is loaded from a jump table
/// recognized by parseJumpTable and rewrite each of them with split().
///
/// Cached dominator and post-dominator trees are updated through a lazy
/// DomTreeUpdater and reported as preserved; no analysis is computed on
/// demand.
///
/// \returns PreservedAnalyses::all() if nothing changed, otherwise a set
/// preserving only the (post-)dominator trees that were cached on entry.
PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
                                             FunctionAnalysisManager &AM) {
  DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
  std::unique_ptr<DomTreeUpdater> DTU;
  bool Changed = false;
  for (BasicBlock &BB : make_early_inc_range(F)) {
    BasicBlock *CurrentBB = &BB;
    while (CurrentBB) {
      BasicBlock *SplittedOutTail = nullptr;
      for (Instruction &I : make_early_inc_range(*CurrentBB)) {
        // Only plain indirect calls are candidates.
        auto *CB = dyn_cast<CallBase>(&I);
        if (!CB || isa<IntrinsicInst>(CB) || CB->getCalledFunction() ||
            isa<InvokeInst>(CB) || CB->isMustTailCall())
          continue;

        // The callee must come from an ordinary load; volatile and atomic
        // loads are left alone since the rewrite stops using their result.
        auto *Load = dyn_cast<LoadInst>(CB->getCalledOperand());
        if (!Load || !Load->isSimple())
          continue;
        auto *GEP = dyn_cast<GetElementPtrInst>(Load->getPointerOperand());
        if (!GEP)
          continue;

        std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP);
        if (!JumpTable)
          continue;
        // Create the updater on first use, and only if there is a tree to
        // keep up to date.
        if ((DT || PDT) && !DTU)
          DTU = std::make_unique<DomTreeUpdater>(
              DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
        SplittedOutTail = split(CB, *JumpTable, DTU.get());
        Changed = true;
        // The rest of this block now lives in the tail; continue there.
        break;
      }
      CurrentBB = SplittedOutTail;
    }
  }

  if (!Changed)
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  if (DT)
    PA.preserve<DominatorTreeAnalysis>();
  if (PDT)
    PA.preserve<PostDominatorTreeAnalysis>();
  return PA;
}

llvm/test/Other/new-pm-defaults.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@
149149
; CHECK-O23SZ-NEXT: Running pass: SpeculativeExecutionPass
150150
; CHECK-O23SZ-NEXT: Running pass: JumpThreadingPass
151151
; CHECK-O23SZ-NEXT: Running analysis: LazyValueAnalysis
152+
; CHECK-O23SZ-NEXT: Running pass: JumpTableToSwitchPass
152153
; CHECK-O23SZ-NEXT: Running pass: CorrelatedValuePropagationPass
153154
; CHECK-O23SZ-NEXT: Invalidating analysis: LazyValueAnalysis
154155
; CHECK-O-NEXT: Running pass: SimplifyCFGPass

llvm/test/Other/new-pm-thinlto-postlink-defaults.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
; CHECK-O23SZ-NEXT: Running pass: SpeculativeExecutionPass
8989
; CHECK-O23SZ-NEXT: Running pass: JumpThreadingPass
9090
; CHECK-O23SZ-NEXT: Running analysis: LazyValueAnalysis
91+
; CHECK-O23SZ-NEXT: Running pass: JumpTableToSwitchPass
9192
; CHECK-O23SZ-NEXT: Running pass: CorrelatedValuePropagationPass
9293
; CHECK-O23SZ-NEXT: Invalidating analysis: LazyValueAnalysis
9394
; CHECK-O-NEXT: Running pass: SimplifyCFGPass

llvm/test/Other/new-pm-thinlto-postlink-pgo-defaults.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
; CHECK-O23SZ-NEXT: Running pass: SpeculativeExecutionPass
7777
; CHECK-O23SZ-NEXT: Running pass: JumpThreadingPass
7878
; CHECK-O23SZ-NEXT: Running analysis: LazyValueAnalysis
79+
; CHECK-O23SZ-NEXT: Running pass: JumpTableToSwitchPass
7980
; CHECK-O23SZ-NEXT: Running pass: CorrelatedValuePropagationPass
8081
; CHECK-O23SZ-NEXT: Invalidating analysis: LazyValueAnalysis
8182
; CHECK-O-NEXT: Running pass: SimplifyCFGPass

llvm/test/Other/new-pm-thinlto-postlink-samplepgo-defaults.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
; CHECK-O23SZ-NEXT: Running pass: SpeculativeExecutionPass
8585
; CHECK-O23SZ-NEXT: Running pass: JumpThreadingPass
8686
; CHECK-O23SZ-NEXT: Running analysis: LazyValueAnalysis
87+
; CHECK-O23SZ-NEXT: Running pass: JumpTableToSwitchPass
8788
; CHECK-O23SZ-NEXT: Running pass: CorrelatedValuePropagationPass
8889
; CHECK-O23SZ-NEXT: Invalidating analysis: LazyValueAnalysis
8990
; CHECK-O-NEXT: Running pass: SimplifyCFGPass

llvm/test/Other/new-pm-thinlto-prelink-defaults.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@
119119
; CHECK-O23SZ-NEXT: Running pass: SpeculativeExecutionPass
120120
; CHECK-O23SZ-NEXT: Running pass: JumpThreadingPass
121121
; CHECK-O23SZ-NEXT: Running analysis: LazyValueAnalysis
122+
; CHECK-O23SZ-NEXT: Running pass: JumpTableToSwitchPass
122123
; CHECK-O23SZ-NEXT: Running pass: CorrelatedValuePropagationPass
123124
; CHECK-O23SZ-NEXT: Invalidating analysis: LazyValueAnalysis
124125
; CHECK-O-NEXT: Running pass: SimplifyCFGPass

llvm/test/Other/new-pm-thinlto-prelink-pgo-defaults.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
; CHECK-O23SZ-NEXT: Running pass: SpeculativeExecutionPass
117117
; CHECK-O23SZ-NEXT: Running pass: JumpThreadingPass
118118
; CHECK-O23SZ-NEXT: Running analysis: LazyValueAnalysis
119+
; CHECK-O23SZ-NEXT: Running pass: JumpTableToSwitchPass
119120
; CHECK-O23SZ-NEXT: Running pass: CorrelatedValuePropagationPass
120121
; CHECK-O23SZ-NEXT: Invalidating analysis: LazyValueAnalysis
121122
; CHECK-O-NEXT: Running pass: SimplifyCFGPass

llvm/test/Other/new-pm-thinlto-prelink-samplepgo-defaults.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
; CHECK-O23SZ-NEXT: Running pass: SpeculativeExecutionPass
8989
; CHECK-O23SZ-NEXT: Running pass: JumpThreadingPass
9090
; CHECK-O23SZ-NEXT: Running analysis: LazyValueAnalysis
91+
; CHECK-O23SZ-NEXT: Running pass: JumpTableToSwitchPass
9192
; CHECK-O23SZ-NEXT: Running pass: CorrelatedValuePropagationPass
9293
; CHECK-O23SZ-NEXT: Invalidating analysis: LazyValueAnalysis
9394
; CHECK-O-NEXT: Running pass: SimplifyCFGPass

0 commit comments

Comments
 (0)