|
| 1 | +//===- DXILCBufferAccess.cpp - Translate CBuffer Loads --------------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | + |
| 9 | +#include "DXILCBufferAccess.h" |
| 10 | +#include "DirectX.h" |
| 11 | +#include "llvm/Frontend/HLSL/CBuffer.h" |
| 12 | +#include "llvm/Frontend/HLSL/HLSLResource.h" |
| 13 | +#include "llvm/IR/IRBuilder.h" |
| 14 | +#include "llvm/IR/IntrinsicsDirectX.h" |
| 15 | +#include "llvm/InitializePasses.h" |
| 16 | +#include "llvm/Pass.h" |
| 17 | +#include "llvm/Transforms/Utils/Local.h" |
| 18 | + |
| 19 | +#define DEBUG_TYPE "dxil-cbuffer-access" |
| 20 | +using namespace llvm; |
| 21 | + |
| 22 | +namespace { |
| 23 | +/// Helper for building a `load.cbufferrow` intrinsic given a simple type. |
| 24 | +struct CBufferRowIntrin { |
| 25 | + Intrinsic::ID IID; |
| 26 | + Type *RetTy; |
| 27 | + unsigned int EltSize; |
| 28 | + unsigned int NumElts; |
| 29 | + |
| 30 | + CBufferRowIntrin(const DataLayout &DL, Type *Ty) { |
| 31 | + assert(Ty == Ty->getScalarType() && "Expected scalar type"); |
| 32 | + |
| 33 | + switch (DL.getTypeSizeInBits(Ty)) { |
| 34 | + case 16: |
| 35 | + IID = Intrinsic::dx_resource_load_cbufferrow_8; |
| 36 | + RetTy = StructType::get(Ty, Ty, Ty, Ty, Ty, Ty, Ty, Ty); |
| 37 | + EltSize = 2; |
| 38 | + NumElts = 8; |
| 39 | + break; |
| 40 | + case 32: |
| 41 | + IID = Intrinsic::dx_resource_load_cbufferrow_4; |
| 42 | + RetTy = StructType::get(Ty, Ty, Ty, Ty); |
| 43 | + EltSize = 4; |
| 44 | + NumElts = 4; |
| 45 | + break; |
| 46 | + case 64: |
| 47 | + IID = Intrinsic::dx_resource_load_cbufferrow_2; |
| 48 | + RetTy = StructType::get(Ty, Ty); |
| 49 | + EltSize = 8; |
| 50 | + NumElts = 2; |
| 51 | + break; |
| 52 | + default: |
| 53 | + llvm_unreachable("Only 16, 32, and 64 bit types supported"); |
| 54 | + } |
| 55 | + } |
| 56 | +}; |
| 57 | +} // namespace |
| 58 | + |
| 59 | +static size_t getOffsetForCBufferGEP(GEPOperator *GEP, GlobalVariable *Global, |
| 60 | + const DataLayout &DL) { |
| 61 | + // Since we should always have a constant offset, we should only ever have a |
| 62 | + // single GEP of indirection from the Global. |
| 63 | + assert(GEP->getPointerOperand() == Global && |
| 64 | + "Indirect access to resource handle"); |
| 65 | + |
| 66 | + APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0); |
| 67 | + bool Success = GEP->accumulateConstantOffset(DL, ConstantOffset); |
| 68 | + (void)Success; |
| 69 | + assert(Success && "Offsets into cbuffer globals must be constant"); |
| 70 | + |
| 71 | + if (auto *ATy = dyn_cast<ArrayType>(Global->getValueType())) |
| 72 | + ConstantOffset = hlsl::translateCBufArrayOffset(DL, ConstantOffset, ATy); |
| 73 | + |
| 74 | + return ConstantOffset.getZExtValue(); |
| 75 | +} |
| 76 | + |
| 77 | +/// Replace access via cbuffer global with a load from the cbuffer handle |
| 78 | +/// itself. |
| 79 | +static void replaceAccess(LoadInst *LI, GlobalVariable *Global, |
| 80 | + GlobalVariable *HandleGV, size_t BaseOffset, |
| 81 | + SmallVectorImpl<WeakTrackingVH> &DeadInsts) { |
| 82 | + const DataLayout &DL = HandleGV->getDataLayout(); |
| 83 | + |
| 84 | + size_t Offset = BaseOffset; |
| 85 | + if (auto *GEP = dyn_cast<GEPOperator>(LI->getPointerOperand())) |
| 86 | + Offset += getOffsetForCBufferGEP(GEP, Global, DL); |
| 87 | + else if (LI->getPointerOperand() != Global) |
| 88 | + llvm_unreachable("Load instruction doesn't reference cbuffer global"); |
| 89 | + |
| 90 | + IRBuilder<> Builder(LI); |
| 91 | + auto *Handle = Builder.CreateLoad(HandleGV->getValueType(), HandleGV, |
| 92 | + HandleGV->getName()); |
| 93 | + |
| 94 | + Type *Ty = LI->getType(); |
| 95 | + CBufferRowIntrin Intrin(DL, Ty->getScalarType()); |
| 96 | + // The cbuffer consists of some number of 16-byte rows. |
| 97 | + unsigned int CurrentRow = Offset / hlsl::CBufferRowSizeInBytes; |
| 98 | + unsigned int CurrentIndex = |
| 99 | + (Offset % hlsl::CBufferRowSizeInBytes) / Intrin.EltSize; |
| 100 | + |
| 101 | + auto *CBufLoad = Builder.CreateIntrinsic( |
| 102 | + Intrin.RetTy, Intrin.IID, |
| 103 | + {Handle, ConstantInt::get(Builder.getInt32Ty(), CurrentRow)}, nullptr, |
| 104 | + LI->getName()); |
| 105 | + auto *Elt = |
| 106 | + Builder.CreateExtractValue(CBufLoad, {CurrentIndex++}, LI->getName()); |
| 107 | + |
| 108 | + Value *Result = nullptr; |
| 109 | + unsigned int Remaining = |
| 110 | + ((DL.getTypeSizeInBits(Ty) / 8) / Intrin.EltSize) - 1; |
| 111 | + if (Remaining == 0) { |
| 112 | + // We only have a single element, so we're done. |
| 113 | + Result = Elt; |
| 114 | + |
| 115 | + // However, if we loaded a <1 x T>, then we need to adjust the type here. |
| 116 | + if (auto *VT = dyn_cast<FixedVectorType>(LI->getType())) { |
| 117 | + assert(VT->getNumElements() == 1 && "Can't have multiple elements here"); |
| 118 | + Result = Builder.CreateInsertElement(PoisonValue::get(VT), Result, |
| 119 | + Builder.getInt32(0)); |
| 120 | + } |
| 121 | + } else { |
| 122 | + // Walk each element and extract it, wrapping to new rows as needed. |
| 123 | + SmallVector<Value *> Extracts{Elt}; |
| 124 | + while (Remaining--) { |
| 125 | + CurrentIndex %= Intrin.NumElts; |
| 126 | + |
| 127 | + if (CurrentIndex == 0) |
| 128 | + CBufLoad = Builder.CreateIntrinsic( |
| 129 | + Intrin.RetTy, Intrin.IID, |
| 130 | + {Handle, ConstantInt::get(Builder.getInt32Ty(), ++CurrentRow)}, |
| 131 | + nullptr, LI->getName()); |
| 132 | + |
| 133 | + Extracts.push_back(Builder.CreateExtractValue(CBufLoad, {CurrentIndex++}, |
| 134 | + LI->getName())); |
| 135 | + } |
| 136 | + |
| 137 | + // Finally, we build up the original loaded value. |
| 138 | + Result = PoisonValue::get(Ty); |
| 139 | + for (int I = 0, E = Extracts.size(); I < E; ++I) |
| 140 | + Result = |
| 141 | + Builder.CreateInsertElement(Result, Extracts[I], Builder.getInt32(I)); |
| 142 | + } |
| 143 | + |
| 144 | + LI->replaceAllUsesWith(Result); |
| 145 | + DeadInsts.push_back(LI); |
| 146 | +} |
| 147 | + |
| 148 | +static void replaceAccessesWithHandle(GlobalVariable *Global, |
| 149 | + GlobalVariable *HandleGV, |
| 150 | + size_t BaseOffset) { |
| 151 | + SmallVector<WeakTrackingVH> DeadInsts; |
| 152 | + |
| 153 | + SmallVector<User *> ToProcess{Global->users()}; |
| 154 | + while (!ToProcess.empty()) { |
| 155 | + User *Cur = ToProcess.pop_back_val(); |
| 156 | + |
| 157 | + // If we have a load instruction, replace the access. |
| 158 | + if (auto *LI = dyn_cast<LoadInst>(Cur)) { |
| 159 | + replaceAccess(LI, Global, HandleGV, BaseOffset, DeadInsts); |
| 160 | + continue; |
| 161 | + } |
| 162 | + |
| 163 | + // Otherwise, walk users looking for a load... |
| 164 | + ToProcess.append(Cur->user_begin(), Cur->user_end()); |
| 165 | + } |
| 166 | + RecursivelyDeleteTriviallyDeadInstructions(DeadInsts); |
| 167 | +} |
| 168 | + |
| 169 | +static bool replaceCBufferAccesses(Module &M) { |
| 170 | + std::optional<hlsl::CBufferMetadata> CBufMD = hlsl::CBufferMetadata::get(M); |
| 171 | + if (!CBufMD) |
| 172 | + return false; |
| 173 | + |
| 174 | + for (const hlsl::CBufferMapping &Mapping : *CBufMD) |
| 175 | + for (const hlsl::CBufferMember &Member : Mapping.Members) { |
| 176 | + replaceAccessesWithHandle(Member.GV, Mapping.Handle, Member.Offset); |
| 177 | + Member.GV->removeFromParent(); |
| 178 | + } |
| 179 | + |
| 180 | + CBufMD->eraseFromModule(); |
| 181 | + return true; |
| 182 | +} |
| 183 | + |
| 184 | +PreservedAnalyses DXILCBufferAccess::run(Module &M, ModuleAnalysisManager &AM) { |
| 185 | + PreservedAnalyses PA; |
| 186 | + bool Changed = replaceCBufferAccesses(M); |
| 187 | + |
| 188 | + if (!Changed) |
| 189 | + return PreservedAnalyses::all(); |
| 190 | + return PA; |
| 191 | +} |
| 192 | + |
| 193 | +namespace { |
| 194 | +class DXILCBufferAccessLegacy : public ModulePass { |
| 195 | +public: |
| 196 | + bool runOnModule(Module &M) override { return replaceCBufferAccesses(M); } |
| 197 | + StringRef getPassName() const override { return "DXIL CBuffer Access"; } |
| 198 | + DXILCBufferAccessLegacy() : ModulePass(ID) {} |
| 199 | + |
| 200 | + static char ID; // Pass identification. |
| 201 | +}; |
| 202 | +char DXILCBufferAccessLegacy::ID = 0; |
| 203 | +} // end anonymous namespace |
| 204 | + |
| 205 | +INITIALIZE_PASS(DXILCBufferAccessLegacy, DEBUG_TYPE, "DXIL CBuffer Access", |
| 206 | + false, false) |
| 207 | + |
| 208 | +ModulePass *llvm::createDXILCBufferAccessLegacyPass() { |
| 209 | + return new DXILCBufferAccessLegacy(); |
| 210 | +} |
0 commit comments