|
| 1 | +//===-- AMDGPUCloneModuleLDSPass.cpp ------------------------------*- C++ -*-=// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// The purpose of this pass is to ensure that the combined module contains |
| 10 | +// as many LDS global variables as there are kernels that (indirectly) access |
| 11 | +// them. As LDS variables behave like C++ static variables, it is important that |
| 12 | +// each partition contains a unique copy of the variable on a per kernel basis. |
| 13 | +// This representation also prepares the combined module to eliminate |
| 14 | +// cross-module dependencies of LDS variables. |
| 15 | +// |
| 16 | +// This pass operates as follows: |
| 17 | +// 1. Firstly, traverse the call graph from each kernel to determine the number |
| 18 | +// of kernels calling each device function. |
| 19 | +// 2. For each LDS global variable GV, determine the function F that defines it. |
| 20 | +// Collect it's caller functions. Clone F and GV, and finally insert a |
| 21 | +// call/invoke instruction in each caller function. |
| 22 | +// |
| 23 | +//===----------------------------------------------------------------------===// |
| 24 | + |
| 25 | +#include "AMDGPU.h" |
| 26 | +#include "llvm/ADT/DepthFirstIterator.h" |
| 27 | +#include "llvm/ADT/Twine.h" |
| 28 | +#include "llvm/Analysis/CallGraph.h" |
| 29 | +#include "llvm/IR/InstrTypes.h" |
| 30 | +#include "llvm/IR/Instructions.h" |
| 31 | +#include "llvm/Passes/PassBuilder.h" |
| 32 | +#include "llvm/Support/ScopedPrinter.h" |
| 33 | +#include "llvm/Transforms/Utils/Cloning.h" |
| 34 | + |
| 35 | +using namespace llvm; |
| 36 | + |
| 37 | +#define DEBUG_TYPE "amdgpu-clone-module-lds" |
| 38 | + |
| 39 | +static cl::opt<unsigned int> MaxCountForClonedFunctions( |
| 40 | + "clone-lds-functions-max-count", cl::init(16), cl::Hidden, |
| 41 | + cl::desc("Specify a limit to the number of clones of a function")); |
| 42 | + |
| 43 | +/// Return the function that defines \p GV |
| 44 | +/// \param GV The global variable in question |
| 45 | +/// \return The function defining \p GV |
| 46 | +static Function *getFunctionDefiningGV(GlobalVariable &GV) { |
| 47 | + SmallVector<User *> Worklist(GV.users()); |
| 48 | + while (!Worklist.empty()) { |
| 49 | + User *U = Worklist.pop_back_val(); |
| 50 | + if (auto *Inst = dyn_cast<Instruction>(U)) |
| 51 | + return Inst->getFunction(); |
| 52 | + if (auto *Op = dyn_cast<Operator>(U)) |
| 53 | + append_range(Worklist, Op->users()); |
| 54 | + } |
| 55 | + return nullptr; |
| 56 | +}; |
| 57 | + |
| 58 | +PreservedAnalyses AMDGPUCloneModuleLDSPass::run(Module &M, |
| 59 | + ModuleAnalysisManager &AM) { |
| 60 | + if (MaxCountForClonedFunctions.getValue() == 1) |
| 61 | + return PreservedAnalyses::all(); |
| 62 | + |
| 63 | + bool Changed = false; |
| 64 | + auto &CG = AM.getResult<CallGraphAnalysis>(M); |
| 65 | + |
| 66 | + // For each function in the call graph, determine the number |
| 67 | + // of ancestor-caller kernels. |
| 68 | + DenseMap<Function *, unsigned int> KernelRefsToFuncs; |
| 69 | + for (auto &Fn : M) { |
| 70 | + if (Fn.getCallingConv() != CallingConv::AMDGPU_KERNEL) |
| 71 | + continue; |
| 72 | + for (auto I = df_begin(&CG), E = df_end(&CG); I != E; ++I) |
| 73 | + if (auto *F = I->getFunction()) |
| 74 | + KernelRefsToFuncs[F]++; |
| 75 | + } |
| 76 | + |
| 77 | + DenseMap<GlobalVariable *, Function *> GVToFnMap; |
| 78 | + for (auto &GV : M.globals()) { |
| 79 | + if (GVToFnMap.contains(&GV) || |
| 80 | + GV.getAddressSpace() != AMDGPUAS::LOCAL_ADDRESS || |
| 81 | + !GV.hasInitializer()) |
| 82 | + continue; |
| 83 | + |
| 84 | + auto *OldF = getFunctionDefiningGV(GV); |
| 85 | + GVToFnMap.insert({&GV, OldF}); |
| 86 | + LLVM_DEBUG(dbgs() << "Found LDS " << GV.getName() << " used in function " |
| 87 | + << OldF->getName() << '\n'); |
| 88 | + |
| 89 | + // Collect all call instructions to OldF |
| 90 | + SmallVector<Instruction *> InstsCallingOldF; |
| 91 | + for (auto &I : OldF->uses()) |
| 92 | + if (auto *CI = dyn_cast<CallBase>(I.getUser())) |
| 93 | + InstsCallingOldF.push_back(CI); |
| 94 | + |
| 95 | + // Create as many clones of the function containing LDS global as |
| 96 | + // there are kernels calling the function (including the function |
| 97 | + // already defining the LDS global). Respectively, clone the |
| 98 | + // LDS global and the call instructions to the function. |
| 99 | + LLVM_DEBUG(dbgs() << "\tFunction is referenced by " |
| 100 | + << KernelRefsToFuncs[OldF] << " kernels.\n"); |
| 101 | + for (unsigned int ID = 0; |
| 102 | + ID + 1 < std::min(KernelRefsToFuncs[OldF], |
| 103 | + MaxCountForClonedFunctions.getValue()); |
| 104 | + ++ID) { |
| 105 | + // Clone LDS global variable |
| 106 | + auto *NewGV = new GlobalVariable( |
| 107 | + M, GV.getValueType(), GV.isConstant(), GlobalValue::InternalLinkage, |
| 108 | + PoisonValue::get(GV.getValueType()), |
| 109 | + GV.getName() + ".clone." + Twine(ID), &GV, |
| 110 | + GlobalValue::NotThreadLocal, AMDGPUAS::LOCAL_ADDRESS, false); |
| 111 | + NewGV->copyAttributesFrom(&GV); |
| 112 | + NewGV->copyMetadata(&GV, 0); |
| 113 | + NewGV->setComdat(GV.getComdat()); |
| 114 | + LLVM_DEBUG(dbgs() << "Inserting LDS clone with name " << NewGV->getName() |
| 115 | + << '\n'); |
| 116 | + |
| 117 | + // Clone function |
| 118 | + ValueToValueMapTy VMap; |
| 119 | + VMap[&GV] = NewGV; |
| 120 | + auto *NewF = CloneFunction(OldF, VMap); |
| 121 | + NewF->setName(OldF->getName() + ".clone." + Twine(ID)); |
| 122 | + LLVM_DEBUG(dbgs() << "Inserting function clone with name " |
| 123 | + << NewF->getName() << '\n'); |
| 124 | + |
| 125 | + |
| 126 | + // Create a new CallInst to call the cloned function |
| 127 | + for (auto *Inst : InstsCallingOldF) { |
| 128 | + Instruction *I = Inst->clone(); |
| 129 | + I->setName(Inst->getName() + ".clone." + Twine(ID)); |
| 130 | + if (auto *CI = dyn_cast<CallBase>(I)) |
| 131 | + CI->setCalledOperand(NewF); |
| 132 | + I->insertAfter(Inst); |
| 133 | + LLVM_DEBUG(dbgs() << "Inserting inst: " << *I << '\n'); |
| 134 | + } |
| 135 | + Changed = true; |
| 136 | + } |
| 137 | + } |
| 138 | + return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); |
| 139 | +} |
0 commit comments