|
| 1 | +//===-- AMDGPUCloneModuleLDSPass.cpp ------------------------------*- C++ -*-=// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// The purpose of this pass is to ensure that the combined module contains |
| 10 | +// as many LDS global variables as there are kernels that (indirectly) access |
| 11 | +// them. As LDS variables behave like C++ static variables, it is important that |
| 12 | +// each partition contains a unique copy of the variable on a per kernel basis. |
| 13 | +// This representation also prepares the combined module to eliminate |
| 14 | +// cross-module dependencies of LDS variables. |
| 15 | +// |
| 16 | +// This pass operates as follows: |
| 17 | +// 1. Firstly, traverse the call graph from each kernel to determine the number |
| 18 | +// of kernels calling each device function. |
| 19 | +// 2. For each LDS global variable GV, determine the function F that defines it. |
| 20 | +// Collect its caller functions. Clone F and GV, and finally insert a |
| 21 | +// call/invoke instruction in each caller function. |
| 22 | +// |
| 23 | +//===----------------------------------------------------------------------===// |
| 24 | + |
| 25 | +#include "AMDGPU.h" |
| 26 | +#include "llvm/ADT/DepthFirstIterator.h" |
| 27 | +#include "llvm/Analysis/CallGraph.h" |
| 28 | +#include "llvm/Passes/PassBuilder.h" |
| 29 | +#include "llvm/Support/ScopedPrinter.h" |
| 30 | +#include "llvm/Transforms/Utils/Cloning.h" |
| 31 | + |
| 32 | +using namespace llvm; |
| 33 | + |
| 34 | +#define DEBUG_TYPE "amdgpu-clone-module-lds" |
| 35 | + |
| 36 | +static cl::opt<unsigned int> MaxCountForClonedFunctions( |
| 37 | + "clone-lds-functions-max-count", cl::init(16), cl::Hidden, |
| 38 | + cl::desc("Specify a limit to the number of clones of a function")); |
| 39 | + |
| 40 | +/// Return the function that defines \p GV |
| 41 | +/// \param GV The global variable in question |
| 42 | +/// \return The function defining \p GV |
| 43 | +static Function *getFunctionDefiningGV(GlobalVariable &GV) { |
| 44 | + SmallVector<User *> Worklist(GV.users()); |
| 45 | + while (!Worklist.empty()) { |
| 46 | + User *U = Worklist.pop_back_val(); |
| 47 | + if (auto *Inst = dyn_cast<Instruction>(U)) |
| 48 | + return Inst->getFunction(); |
| 49 | + if (auto *Op = dyn_cast<Operator>(U)) |
| 50 | + append_range(Worklist, Op->users()); |
| 51 | + } |
| 52 | + llvm_unreachable("GV must be used in a function."); |
| 53 | +}; |
| 54 | + |
| 55 | +/// Replace all references to \p OldGV in \p NewF with \p NewGV |
| 56 | +/// \param OldGV The global variable to be replaced |
| 57 | +/// \param NewGV The global variable taking the place of \p OldGV |
| 58 | +/// \param NewF The function in which the replacement occurs |
| 59 | +static void replaceUsesOfWith(GlobalVariable *OldGV, GlobalVariable *NewGV, |
| 60 | + Function *NewF) { |
| 61 | + // ReplaceOperatorUses takes in an instruction Inst, which is assumed to |
| 62 | + // contain OldGV as an operator, inserts an instruction correponding the |
| 63 | + // OldGV-operand and update Inst accordingly. |
| 64 | + auto ReplaceOperatorUses = [&OldGV, &NewGV](Instruction *Inst) { |
| 65 | + Inst->replaceUsesOfWith(OldGV, NewGV); |
| 66 | + SmallVector<Value *, 8> Worklist(Inst->operands()); |
| 67 | + while (!Worklist.empty()) { |
| 68 | + auto *V = Worklist.pop_back_val(); |
| 69 | + if (auto *I = dyn_cast<AddrSpaceCastOperator>(V)) { |
| 70 | + auto *Cast = new AddrSpaceCastInst(NewGV, I->getType(), "", Inst); |
| 71 | + Inst->replaceUsesOfWith(I, Cast); |
| 72 | + } else if (auto *I = dyn_cast<GEPOperator>(V)) { |
| 73 | + SmallVector<Value *, 8> Indices(I->indices()); |
| 74 | + auto *GEP = GetElementPtrInst::Create(NewGV->getValueType(), NewGV, |
| 75 | + Indices, "", Inst); |
| 76 | + Inst->replaceUsesOfWith(I, GEP); |
| 77 | + } |
| 78 | + } |
| 79 | + }; |
| 80 | + |
| 81 | + // Collect all user instructions of OldGV using a Worklist algorithm. |
| 82 | + // If a user is an operator, collect the instruction wrapping |
| 83 | + // the operator. |
| 84 | + SmallVector<Instruction *, 8> InstsToReplace; |
| 85 | + SmallVector<User *, 8> UsersWorklist(OldGV->users()); |
| 86 | + while (!UsersWorklist.empty()) { |
| 87 | + auto *U = UsersWorklist.pop_back_val(); |
| 88 | + if (auto *Inst = dyn_cast<Instruction>(U)) { |
| 89 | + InstsToReplace.push_back(Inst); |
| 90 | + } else if (auto *Op = dyn_cast<Operator>(U)) { |
| 91 | + append_range(UsersWorklist, Op->users()); |
| 92 | + } |
| 93 | + } |
| 94 | + |
| 95 | + // Replace all occurences of OldGV in NewF |
| 96 | + DenseSet<Instruction *> ReplacedInsts; |
| 97 | + while (!InstsToReplace.empty()) { |
| 98 | + auto *Inst = InstsToReplace.pop_back_val(); |
| 99 | + if (Inst->getFunction() != NewF || ReplacedInsts.contains(Inst)) |
| 100 | + continue; |
| 101 | + ReplaceOperatorUses(Inst); |
| 102 | + ReplacedInsts.insert(Inst); |
| 103 | + } |
| 104 | +}; |
| 105 | + |
| 106 | +PreservedAnalyses AMDGPUCloneModuleLDSPass::run(Module &M, |
| 107 | + ModuleAnalysisManager &AM) { |
| 108 | + if (MaxCountForClonedFunctions.getValue() == 1) |
| 109 | + return PreservedAnalyses::all(); |
| 110 | + |
| 111 | + bool Changed = false; |
| 112 | + auto &CG = AM.getResult<CallGraphAnalysis>(M); |
| 113 | + |
| 114 | + // For each function in the call graph, determine the number |
| 115 | + // of ancestor-caller kernels. |
| 116 | + DenseMap<Function *, unsigned int> KernelRefsToFuncs; |
| 117 | + for (auto &Fn : M) { |
| 118 | + if (Fn.getCallingConv() != CallingConv::AMDGPU_KERNEL) |
| 119 | + continue; |
| 120 | + for (auto I = df_begin(&CG), E = df_end(&CG); I != E; ++I) |
| 121 | + if (auto *F = I->getFunction()) |
| 122 | + KernelRefsToFuncs[F]++; |
| 123 | + } |
| 124 | + |
| 125 | + DenseMap<GlobalVariable *, Function *> GVToFnMap; |
| 126 | + LLVMContext &Ctx = M.getContext(); |
| 127 | + IRBuilder<> IRB(Ctx); |
| 128 | + for (auto &GV : M.globals()) { |
| 129 | + if (GVToFnMap.contains(&GV) || |
| 130 | + GV.getType()->getPointerAddressSpace() != AMDGPUAS::LOCAL_ADDRESS || |
| 131 | + !GV.hasInitializer()) |
| 132 | + continue; |
| 133 | + |
| 134 | + auto *OldF = getFunctionDefiningGV(GV); |
| 135 | + GVToFnMap.insert({&GV, OldF}); |
| 136 | + LLVM_DEBUG(dbgs() << "Found LDS " << GV.getName() << " used in function " |
| 137 | + << OldF->getName() << '\n'); |
| 138 | + |
| 139 | + // Collect all caller functions of OldF. Each of them must call it's |
| 140 | + // corresponding clone of OldF. |
| 141 | + SmallVector<std::pair<Instruction *, SmallVector<Value *>>> |
| 142 | + InstsCallingOldF; |
| 143 | + for (auto &I : OldF->uses()) { |
| 144 | + User *U = I.getUser(); |
| 145 | + SmallVector<Value *> Args; |
| 146 | + if (auto *CI = dyn_cast<CallInst>(U)) { |
| 147 | + append_range(Args, CI->args()); |
| 148 | + InstsCallingOldF.push_back({CI, Args}); |
| 149 | + } else if (auto *II = dyn_cast<InvokeInst>(U)) { |
| 150 | + append_range(Args, II->args()); |
| 151 | + InstsCallingOldF.push_back({II, Args}); |
| 152 | + } |
| 153 | + } |
| 154 | + |
| 155 | + // Create as many clones of the function containing LDS global as |
| 156 | + // there are kernels calling the function (including the function |
| 157 | + // already defining the LDS global). Respectively, clone the |
| 158 | + // LDS global and the call instructions to the function. |
| 159 | + LLVM_DEBUG(dbgs() << "\tFunction is referenced by " |
| 160 | + << KernelRefsToFuncs[OldF] << " kernels.\n"); |
| 161 | + for (unsigned int ID = 0; |
| 162 | + ID + 1 < std::min(KernelRefsToFuncs[OldF], |
| 163 | + MaxCountForClonedFunctions.getValue()); |
| 164 | + ++ID) { |
| 165 | + // Clone function |
| 166 | + ValueToValueMapTy VMap; |
| 167 | + auto *NewF = CloneFunction(OldF, VMap); |
| 168 | + NewF->setName(OldF->getName() + ".clone." + to_string(ID)); |
| 169 | + LLVM_DEBUG(dbgs() << "Inserting function clone with name " |
| 170 | + << NewF->getName() << '\n'); |
| 171 | + |
| 172 | + // Clone LDS global variable |
| 173 | + auto *NewGV = new GlobalVariable( |
| 174 | + M, GV.getValueType(), GV.isConstant(), GlobalValue::InternalLinkage, |
| 175 | + UndefValue::get(GV.getValueType()), |
| 176 | + GV.getName() + ".clone." + to_string(ID), &GV, |
| 177 | + GlobalValue::NotThreadLocal, AMDGPUAS::LOCAL_ADDRESS, false); |
| 178 | + NewGV->copyAttributesFrom(&GV); |
| 179 | + NewGV->copyMetadata(&GV, 0); |
| 180 | + NewGV->setComdat(GV.getComdat()); |
| 181 | + replaceUsesOfWith(&GV, NewGV, NewF); |
| 182 | + LLVM_DEBUG(dbgs() << "Inserting LDS clone with name " << NewGV->getName() |
| 183 | + << "\n"); |
| 184 | + |
| 185 | + // Create a new CallInst to call the cloned function |
| 186 | + for (auto [Inst, Args] : InstsCallingOldF) { |
| 187 | + IRB.SetInsertPoint(Inst); |
| 188 | + Instruction *I; |
| 189 | + if (isa<CallInst>(Inst)) |
| 190 | + I = IRB.CreateCall(NewF, Args, |
| 191 | + Inst->getName() + ".clone." + to_string(ID)); |
| 192 | + else if (auto *II = dyn_cast<InvokeInst>(Inst)) |
| 193 | + I = IRB.CreateInvoke(NewF, II->getNormalDest(), II->getUnwindDest(), |
| 194 | + Args, II->getName() + ".clone" + to_string(ID)); |
| 195 | + LLVM_DEBUG(dbgs() << "Inserting inst: " << *I << '\n'); |
| 196 | + } |
| 197 | + Changed = true; |
| 198 | + } |
| 199 | + } |
| 200 | + return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); |
| 201 | +} |
0 commit comments