|
| 1 | +//===- JumpTableToSwitch.cpp ----------------------------------------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | + |
| 9 | +#include "llvm/Transforms/Scalar/JumpTableToSwitch.h" |
| 10 | +#include "llvm/ADT/DenseMap.h" |
| 11 | +#include "llvm/ADT/SmallSet.h" |
| 12 | +#include "llvm/Analysis/DomTreeUpdater.h" |
| 13 | +#include "llvm/Analysis/PostDominators.h" |
| 14 | +#include "llvm/Analysis/TargetLibraryInfo.h" |
| 15 | +#include "llvm/Analysis/TargetTransformInfo.h" |
| 16 | +#include "llvm/Analysis/ValueTracking.h" |
| 17 | +#include "llvm/IR/IRBuilder.h" |
| 18 | +#include "llvm/IR/IntrinsicInst.h" |
| 19 | +#include "llvm/IR/PatternMatch.h" |
| 20 | +#include "llvm/Support/CommandLine.h" |
| 21 | +#include "llvm/Support/Debug.h" |
| 22 | +#include "llvm/Transforms/Utils/BasicBlockUtils.h" |
| 23 | +#include "llvm/Transforms/Utils/Cloning.h" |
| 24 | +#include "llvm/Transforms/Utils/Local.h" |
| 25 | + |
| 26 | +using namespace llvm; |
| 27 | +using namespace PatternMatch; |
| 28 | + |
| 29 | +static cl::opt<unsigned> |
| 30 | + JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden, |
| 31 | + cl::desc("Only split jump tables with size less or " |
| 32 | + "equal than JumpTableSizeThreshold."), |
| 33 | + cl::init(10)); |
| 34 | + |
| 35 | +#define DEBUG_TYPE "jump-table-to-switch" |
| 36 | + |
| 37 | +static Constant *getElementWithGivenTypeAtOffset(Constant *C, const Type *Ty, |
| 38 | + uint64_t Offset, Module &M) { |
| 39 | + if (Offset == 0 && C->getType() == Ty) |
| 40 | + return C; |
| 41 | + if (auto *CS = dyn_cast<ConstantStruct>(C)) { |
| 42 | + const DataLayout &DL = M.getDataLayout(); |
| 43 | + const StructLayout *SL = DL.getStructLayout(CS->getType()); |
| 44 | + if (Offset >= SL->getSizeInBytes()) |
| 45 | + return nullptr; |
| 46 | + const unsigned Op = SL->getElementContainingOffset(Offset); |
| 47 | + const uint64_t AdjustedOffset = Offset - SL->getElementOffset(Op); |
| 48 | + Constant *Element = cast<Constant>(CS->getOperand(Op)); |
| 49 | + return getElementWithGivenTypeAtOffset(Element, Ty, AdjustedOffset, M); |
| 50 | + } |
| 51 | + // TODO: add support for arrays. |
| 52 | + return nullptr; |
| 53 | +} |
| 54 | + |
| 55 | +namespace { |
| 56 | +struct JumpTableTy { |
| 57 | + Value *Index; |
| 58 | + SmallVector<Function *, 5> Funcs; |
| 59 | +}; |
| 60 | +} // anonymous namespace |
| 61 | + |
| 62 | +static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP) { |
| 63 | + if (!GEP || !GEP->isInBounds()) |
| 64 | + return std::nullopt; |
| 65 | + ArrayType *ArrayTy = dyn_cast<ArrayType>(GEP->getSourceElementType()); |
| 66 | + if (!ArrayTy || ArrayTy->getArrayNumElements() > JumpTableSizeThreshold) |
| 67 | + return std::nullopt; |
| 68 | + Constant *Ptr = dyn_cast<Constant>(GEP->getPointerOperand()); |
| 69 | + if (!Ptr) |
| 70 | + return std::nullopt; |
| 71 | + |
| 72 | + Function &F = *GEP->getParent()->getParent(); |
| 73 | + const DataLayout &DL = F.getParent()->getDataLayout(); |
| 74 | + const unsigned BitWidth = |
| 75 | + DL.getIndexSizeInBits(GEP->getPointerAddressSpace()); |
| 76 | + MapVector<Value *, APInt> VariableOffsets; |
| 77 | + APInt ConstantOffset(BitWidth, 0); |
| 78 | + if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset)) |
| 79 | + return std::nullopt; |
| 80 | + if (VariableOffsets.empty() || VariableOffsets.size() > 1) |
| 81 | + return std::nullopt; |
| 82 | + |
| 83 | + Module &M = *F.getParent(); |
| 84 | + unsigned Offset = ConstantOffset.getZExtValue(); |
| 85 | + // TODO: support more general patterns |
| 86 | + // (see also TODO in getElementWithGivenTypeAtOffset). |
| 87 | + if (Offset != 0) |
| 88 | + return std::nullopt; |
| 89 | + if (!Ptr->getNumOperands()) |
| 90 | + return std::nullopt; |
| 91 | + Constant *ConstArray = getElementWithGivenTypeAtOffset( |
| 92 | + cast<Constant>(Ptr->getOperand(0)), ArrayTy, Offset, M); |
| 93 | + if (!ConstArray) |
| 94 | + return std::nullopt; |
| 95 | + |
| 96 | + JumpTableTy JumpTable; |
| 97 | + JumpTable.Index = VariableOffsets.front().first; |
| 98 | + |
| 99 | + const uint64_t N = ArrayTy->getArrayNumElements(); |
| 100 | + JumpTable.Funcs.assign(N, nullptr); |
| 101 | + for (uint64_t Index = 0; Index < N; ++Index) { |
| 102 | + auto *Func = |
| 103 | + dyn_cast_or_null<Function>(ConstArray->getAggregateElement(Index)); |
| 104 | + if (!Func || Func->isDeclaration()) |
| 105 | + return std::nullopt; |
| 106 | + JumpTable.Funcs[Index] = Func; |
| 107 | + } |
| 108 | + return JumpTable; |
| 109 | +} |
| 110 | + |
| 111 | +static BasicBlock *split(CallBase *CB, const JumpTableTy &JT, |
| 112 | + DomTreeUpdater *DTU) { |
| 113 | + const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext()); |
| 114 | + |
| 115 | + SmallVector<DominatorTree::UpdateType, 8> DTUpdates; |
| 116 | + BasicBlock *BB = CB->getParent(); |
| 117 | + BasicBlock *Tail = |
| 118 | + SplitBlock(BB, CB, DTU, nullptr, nullptr, BB->getName() + Twine(".tail")); |
| 119 | + DTUpdates.push_back({DominatorTree::Delete, BB, Tail}); |
| 120 | + BB->getTerminator()->eraseFromParent(); |
| 121 | + |
| 122 | + Function &F = *BB->getParent(); |
| 123 | + BasicBlock *BBUnreachable = BasicBlock::Create( |
| 124 | + F.getContext(), "default.switch.case.unreachable", &F, Tail); |
| 125 | + IRBuilder<> BuilderUnreachable(BBUnreachable); |
| 126 | + BuilderUnreachable.CreateUnreachable(); |
| 127 | + |
| 128 | + IRBuilder<> Builder(BB); |
| 129 | + SwitchInst *Switch = Builder.CreateSwitch(JT.Index, BBUnreachable); |
| 130 | + DTUpdates.push_back({DominatorTree::Insert, BB, BBUnreachable}); |
| 131 | + |
| 132 | + IRBuilder<> BuilderTail(CB); |
| 133 | + PHINode *PHI = |
| 134 | + IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size()); |
| 135 | + |
| 136 | + for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) { |
| 137 | + BasicBlock *B = BasicBlock::Create(Func->getContext(), |
| 138 | + "call." + Twine(Index), &F, Tail); |
| 139 | + DTUpdates.push_back({DominatorTree::Insert, BB, B}); |
| 140 | + DTUpdates.push_back({DominatorTree::Insert, B, Tail}); |
| 141 | + |
| 142 | + CallBase *Call = cast<CallBase>(CB->clone()); |
| 143 | + Call->setCalledFunction(Func); |
| 144 | + Call->insertInto(B, B->end()); |
| 145 | + Switch->addCase( |
| 146 | + cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B); |
| 147 | + BranchInst::Create(Tail, B); |
| 148 | + if (PHI) |
| 149 | + PHI->addIncoming(Call, B); |
| 150 | + } |
| 151 | + if (DTU) |
| 152 | + DTU->applyUpdates(DTUpdates); |
| 153 | + if (PHI) |
| 154 | + CB->replaceAllUsesWith(PHI); |
| 155 | + CB->eraseFromParent(); |
| 156 | + return Tail; |
| 157 | +} |
| 158 | + |
| 159 | +PreservedAnalyses JumpTableToSwitchPass::run(Function &F, |
| 160 | + FunctionAnalysisManager &AM) { |
| 161 | + DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(F); |
| 162 | + PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F); |
| 163 | + std::unique_ptr<DomTreeUpdater> DTU; |
| 164 | + bool Changed = false; |
| 165 | + for (BasicBlock &BB : make_early_inc_range(F)) { |
| 166 | + BasicBlock *CurrentBB = &BB; |
| 167 | + while (CurrentBB) { |
| 168 | + BasicBlock *SplittedOutTail = nullptr; |
| 169 | + for (Instruction &I : make_early_inc_range(*CurrentBB)) { |
| 170 | + CallBase *CB = dyn_cast<CallBase>(&I); |
| 171 | + if (!CB || isa<IntrinsicInst>(CB) || CB->getCalledFunction() || |
| 172 | + isa<InvokeInst>(CB) || CB->isMustTailCall()) |
| 173 | + continue; |
| 174 | + |
| 175 | + Value *V; |
| 176 | + if (!match(CB->getCalledOperand(), m_Load(m_Value(V)))) |
| 177 | + continue; |
| 178 | + auto *GEP = dyn_cast<GetElementPtrInst>(V); |
| 179 | + if (!GEP) |
| 180 | + continue; |
| 181 | + |
| 182 | + std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP); |
| 183 | + if (!JumpTable) |
| 184 | + continue; |
| 185 | + if ((DT || PDT) && !DTU) |
| 186 | + DTU = std::make_unique<DomTreeUpdater>( |
| 187 | + DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy); |
| 188 | + SplittedOutTail = split(CB, *JumpTable, DTU.get()); |
| 189 | + Changed = true; |
| 190 | + break; |
| 191 | + } |
| 192 | + CurrentBB = SplittedOutTail ? SplittedOutTail : nullptr; |
| 193 | + } |
| 194 | + } |
| 195 | + |
| 196 | + if (!Changed) |
| 197 | + return PreservedAnalyses::all(); |
| 198 | + |
| 199 | + PreservedAnalyses PA; |
| 200 | + if (DT) |
| 201 | + PA.preserve<DominatorTreeAnalysis>(); |
| 202 | + if (PDT) |
| 203 | + PA.preserve<PostDominatorTreeAnalysis>(); |
| 204 | + return PA; |
| 205 | +} |
0 commit comments