Skip to content

Commit 3de88fe

Browse files
authored
[DirectX] Implement the DXILCBufferAccess pass (#134571)
This introduces a pass that walks accesses to globals in cbuffers and replaces them with accesses via the cbuffer handle itself. The logic to interpret the cbuffer metadata is kept in `lib/Frontend/HLSL` so that it can be reused by other consumers of that metadata. Fixes #124630.
1 parent 2a02404 commit 3de88fe

File tree

16 files changed

+822
-0
lines changed

16 files changed

+822
-0
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//===- CBuffer.h - HLSL constant buffer handling ----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
///
9+
/// \file This file contains utilities to work with constant buffers in HLSL.
10+
///
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef LLVM_FRONTEND_HLSL_CBUFFER_H
14+
#define LLVM_FRONTEND_HLSL_CBUFFER_H
15+
16+
#include "llvm/ADT/SmallVector.h"
17+
#include "llvm/IR/DataLayout.h"
18+
#include "llvm/IR/DerivedTypes.h"
19+
#include <optional>
20+
21+
namespace llvm {
22+
class Module;
23+
class GlobalVariable;
24+
class NamedMDNode;
25+
26+
namespace hlsl {
27+
28+
struct CBufferMember {
29+
GlobalVariable *GV;
30+
size_t Offset;
31+
32+
CBufferMember(GlobalVariable *GV, size_t Offset) : GV(GV), Offset(Offset) {}
33+
};
34+
35+
struct CBufferMapping {
36+
GlobalVariable *Handle;
37+
SmallVector<CBufferMember> Members;
38+
39+
CBufferMapping(GlobalVariable *Handle) : Handle(Handle) {}
40+
};
41+
42+
class CBufferMetadata {
43+
NamedMDNode *MD;
44+
SmallVector<CBufferMapping> Mappings;
45+
46+
CBufferMetadata(NamedMDNode *MD) : MD(MD) {}
47+
48+
public:
49+
static std::optional<CBufferMetadata> get(Module &M);
50+
51+
using iterator = SmallVector<CBufferMapping>::iterator;
52+
iterator begin() { return Mappings.begin(); }
53+
iterator end() { return Mappings.end(); }
54+
55+
void eraseFromModule();
56+
};
57+
58+
/// Translate \p Offset, a byte offset into the array type \p Ty as laid out by
/// \p DL, into the corresponding offset in cbuffer layout, where each array
/// element is padded up to a multiple of the cbuffer row size.
APInt translateCBufArrayOffset(const DataLayout &DL, APInt Offset,
                               ArrayType *Ty);
60+
61+
} // namespace hlsl
62+
} // namespace llvm
63+
64+
#endif // LLVM_FRONTEND_HLSL_CBUFFER_H

llvm/lib/Frontend/HLSL/CBuffer.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
//===- CBuffer.cpp - HLSL constant buffer handling ------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/Frontend/HLSL/CBuffer.h"
10+
#include "llvm/Frontend/HLSL/HLSLResource.h"
11+
#include "llvm/IR/DerivedTypes.h"
12+
#include "llvm/IR/Metadata.h"
13+
#include "llvm/IR/Module.h"
14+
15+
using namespace llvm;
16+
using namespace llvm::hlsl;
17+
18+
/// Read the offset of the member at position \p Index from the layout type
/// attached to the cbuffer handle global \p Handle.
///
/// The handle's value type is expected to be a "*.CBuffer" target extension
/// type whose single type parameter is a "*.Layout" target extension type.
static size_t getMemberOffset(GlobalVariable *Handle, size_t Index) {
  auto *CBufTy = cast<TargetExtType>(Handle->getValueType());
  assert(CBufTy->getName().ends_with(".CBuffer") && "Not a cbuffer type");
  assert(CBufTy->getNumTypeParameters() == 1 && "Expected layout type");

  auto *Layout = cast<TargetExtType>(CBufTy->getTypeParameter(0));
  assert(Layout->getName().ends_with(".Layout") && "Not a layout type");

  // The first integer parameter is the "size", so member offsets start one
  // slot later.
  const size_t Slot = Index + 1;
  assert(Layout->getNumIntParameters() > Slot && "Not enough parameters");

  return Layout->getIntParameter(Slot);
}
33+
34+
/// Build the cbuffer mappings for \p M from its "hlsl.cbs" named metadata.
///
/// Each operand of the named metadata describes one cbuffer: operand 0 is the
/// handle global and the remaining operands are the member globals. Returns
/// std::nullopt if the module has no "hlsl.cbs" metadata.
std::optional<CBufferMetadata> CBufferMetadata::get(Module &M) {
  NamedMDNode *NamedMD = M.getNamedMetadata("hlsl.cbs");
  if (NamedMD == nullptr)
    return std::nullopt;

  std::optional<CBufferMetadata> Result({NamedMD});

  for (const MDNode *Node : NamedMD->operands()) {
    assert(Node->getNumOperands() && "Invalid cbuffer metadata");

    auto *HandleMD = cast<ValueAsMetadata>(Node->getOperand(0));
    auto *Handle = cast<GlobalVariable>(HandleMD->getValue());
    CBufferMapping &Mapping = Result->Mappings.emplace_back(Handle);

    // Operand I describes member I - 1; operand 0 was the handle.
    for (int I = 1, E = Node->getNumOperands(); I < E; ++I) {
      // Some members may be null if they've been optimized out.
      if (Metadata *MemberMD = Node->getOperand(I)) {
        auto *MemberGV =
            cast<GlobalVariable>(cast<ValueAsMetadata>(MemberMD)->getValue());
        Mapping.Members.emplace_back(MemberGV, getMemberOffset(Handle, I - 1));
      }
    }
  }

  return Result;
}
60+
61+
/// Remove the cbs named metadata from its module.
///
/// Safe to call more than once: after the first call the stored node pointer
/// is cleared and later calls do nothing.
void CBufferMetadata::eraseFromModule() {
  if (!MD)
    return;

  // Remove the cbs named metadata
  MD->eraseFromParent();
  // The node no longer belongs to the module, so drop our pointer to it
  // instead of leaving it dangling.
  MD = nullptr;
}
65+
66+
/// Convert \p Offset, a byte offset into an array of type \p Ty under \p DL,
/// into the matching offset in cbuffer layout.
///
/// The element index is recovered by dividing by the element size, then scaled
/// by the element size rounded up to a multiple of CBufferRowSizeInBytes.
// NOTE(review): any remainder of Offset within an element is discarded by the
// division; presumably callers only pass offsets at element boundaries —
// confirm.
APInt hlsl::translateCBufArrayOffset(const DataLayout &DL, APInt Offset,
                                     ArrayType *Ty) {
  const int64_t PackedEltBytes =
      DL.getTypeSizeInBits(Ty->getElementType()) / 8;
  const int64_t PaddedEltBytes =
      alignTo(PackedEltBytes, Align(CBufferRowSizeInBytes));
  return Offset.udiv(PackedEltBytes) * PaddedEltBytes;
}

llvm/lib/Frontend/HLSL/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_llvm_component_library(LLVMFrontendHLSL
2+
CBuffer.cpp
23
HLSLResource.cpp
34

45
ADDITIONAL_HEADER_DIRS

llvm/lib/Target/DirectX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ add_llvm_target(DirectXCodeGen
2020
DirectXTargetMachine.cpp
2121
DirectXTargetTransformInfo.cpp
2222
DXContainerGlobals.cpp
23+
DXILCBufferAccess.cpp
2324
DXILDataScalarization.cpp
2425
DXILFinalizeLinkage.cpp
2526
DXILFlattenArrays.cpp
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
//===- DXILCBufferAccess.cpp - Translate CBuffer Loads --------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "DXILCBufferAccess.h"
10+
#include "DirectX.h"
11+
#include "llvm/Frontend/HLSL/CBuffer.h"
12+
#include "llvm/Frontend/HLSL/HLSLResource.h"
13+
#include "llvm/IR/IRBuilder.h"
14+
#include "llvm/IR/IntrinsicsDirectX.h"
15+
#include "llvm/InitializePasses.h"
16+
#include "llvm/Pass.h"
17+
#include "llvm/Transforms/Utils/Local.h"
18+
19+
#define DEBUG_TYPE "dxil-cbuffer-access"
20+
using namespace llvm;
21+
22+
namespace {
23+
/// Helper for building a `load.cbufferrow` intrinsic given a simple type.
24+
struct CBufferRowIntrin {
25+
Intrinsic::ID IID;
26+
Type *RetTy;
27+
unsigned int EltSize;
28+
unsigned int NumElts;
29+
30+
CBufferRowIntrin(const DataLayout &DL, Type *Ty) {
31+
assert(Ty == Ty->getScalarType() && "Expected scalar type");
32+
33+
switch (DL.getTypeSizeInBits(Ty)) {
34+
case 16:
35+
IID = Intrinsic::dx_resource_load_cbufferrow_8;
36+
RetTy = StructType::get(Ty, Ty, Ty, Ty, Ty, Ty, Ty, Ty);
37+
EltSize = 2;
38+
NumElts = 8;
39+
break;
40+
case 32:
41+
IID = Intrinsic::dx_resource_load_cbufferrow_4;
42+
RetTy = StructType::get(Ty, Ty, Ty, Ty);
43+
EltSize = 4;
44+
NumElts = 4;
45+
break;
46+
case 64:
47+
IID = Intrinsic::dx_resource_load_cbufferrow_2;
48+
RetTy = StructType::get(Ty, Ty);
49+
EltSize = 8;
50+
NumElts = 2;
51+
break;
52+
default:
53+
llvm_unreachable("Only 16, 32, and 64 bit types supported");
54+
}
55+
}
56+
};
57+
} // namespace
58+
59+
/// Compute the cbuffer-layout byte offset addressed by \p GEP relative to
/// \p Global.
///
/// \p GEP must use \p Global directly as its pointer operand and have an
/// all-constant offset. When \p Global is an array, the accumulated offset is
/// translated with hlsl::translateCBufArrayOffset.
static size_t getOffsetForCBufferGEP(GEPOperator *GEP, GlobalVariable *Global,
                                     const DataLayout &DL) {
  // Offsets are always constant, so there is never more than one GEP between
  // the global and the access.
  assert(GEP->getPointerOperand() == Global &&
         "Indirect access to resource handle");

  const unsigned IndexBits = DL.getIndexTypeSizeInBits(GEP->getType());
  APInt ByteOffset(IndexBits, 0);
  [[maybe_unused]] const bool IsConstant =
      GEP->accumulateConstantOffset(DL, ByteOffset);
  assert(IsConstant && "Offsets into cbuffer globals must be constant");

  auto *ArrTy = dyn_cast<ArrayType>(Global->getValueType());
  if (!ArrTy)
    return ByteOffset.getZExtValue();

  ByteOffset = hlsl::translateCBufArrayOffset(DL, ByteOffset, ArrTy);
  return ByteOffset.getZExtValue();
}
76+
77+
/// Replace access via cbuffer global with a load from the cbuffer handle
78+
/// itself.
79+
static void replaceAccess(LoadInst *LI, GlobalVariable *Global,
80+
GlobalVariable *HandleGV, size_t BaseOffset,
81+
SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
82+
const DataLayout &DL = HandleGV->getDataLayout();
83+
84+
size_t Offset = BaseOffset;
85+
if (auto *GEP = dyn_cast<GEPOperator>(LI->getPointerOperand()))
86+
Offset += getOffsetForCBufferGEP(GEP, Global, DL);
87+
else if (LI->getPointerOperand() != Global)
88+
llvm_unreachable("Load instruction doesn't reference cbuffer global");
89+
90+
IRBuilder<> Builder(LI);
91+
auto *Handle = Builder.CreateLoad(HandleGV->getValueType(), HandleGV,
92+
HandleGV->getName());
93+
94+
Type *Ty = LI->getType();
95+
CBufferRowIntrin Intrin(DL, Ty->getScalarType());
96+
// The cbuffer consists of some number of 16-byte rows.
97+
unsigned int CurrentRow = Offset / hlsl::CBufferRowSizeInBytes;
98+
unsigned int CurrentIndex =
99+
(Offset % hlsl::CBufferRowSizeInBytes) / Intrin.EltSize;
100+
101+
auto *CBufLoad = Builder.CreateIntrinsic(
102+
Intrin.RetTy, Intrin.IID,
103+
{Handle, ConstantInt::get(Builder.getInt32Ty(), CurrentRow)}, nullptr,
104+
LI->getName());
105+
auto *Elt =
106+
Builder.CreateExtractValue(CBufLoad, {CurrentIndex++}, LI->getName());
107+
108+
Value *Result = nullptr;
109+
unsigned int Remaining =
110+
((DL.getTypeSizeInBits(Ty) / 8) / Intrin.EltSize) - 1;
111+
if (Remaining == 0) {
112+
// We only have a single element, so we're done.
113+
Result = Elt;
114+
115+
// However, if we loaded a <1 x T>, then we need to adjust the type here.
116+
if (auto *VT = dyn_cast<FixedVectorType>(LI->getType())) {
117+
assert(VT->getNumElements() == 1 && "Can't have multiple elements here");
118+
Result = Builder.CreateInsertElement(PoisonValue::get(VT), Result,
119+
Builder.getInt32(0));
120+
}
121+
} else {
122+
// Walk each element and extract it, wrapping to new rows as needed.
123+
SmallVector<Value *> Extracts{Elt};
124+
while (Remaining--) {
125+
CurrentIndex %= Intrin.NumElts;
126+
127+
if (CurrentIndex == 0)
128+
CBufLoad = Builder.CreateIntrinsic(
129+
Intrin.RetTy, Intrin.IID,
130+
{Handle, ConstantInt::get(Builder.getInt32Ty(), ++CurrentRow)},
131+
nullptr, LI->getName());
132+
133+
Extracts.push_back(Builder.CreateExtractValue(CBufLoad, {CurrentIndex++},
134+
LI->getName()));
135+
}
136+
137+
// Finally, we build up the original loaded value.
138+
Result = PoisonValue::get(Ty);
139+
for (int I = 0, E = Extracts.size(); I < E; ++I)
140+
Result =
141+
Builder.CreateInsertElement(Result, Extracts[I], Builder.getInt32(I));
142+
}
143+
144+
LI->replaceAllUsesWith(Result);
145+
DeadInsts.push_back(LI);
146+
}
147+
148+
/// Find every load reachable through the users of \p Global and rewrite it to
/// go through the cbuffer handle \p HandleGV, with \p BaseOffset as the
/// member's offset within the cbuffer. The replaced loads, and anything left
/// trivially dead by their removal, are deleted afterwards.
static void replaceAccessesWithHandle(GlobalVariable *Global,
                                      GlobalVariable *HandleGV,
                                      size_t BaseOffset) {
  SmallVector<WeakTrackingVH> DeadInsts;

  SmallVector<User *> Worklist{Global->users()};
  while (!Worklist.empty()) {
    User *U = Worklist.pop_back_val();

    auto *LI = dyn_cast<LoadInst>(U);
    if (!LI) {
      // Not a load: keep walking down the use chain looking for one.
      Worklist.append(U->user_begin(), U->user_end());
      continue;
    }

    replaceAccess(LI, Global, HandleGV, BaseOffset, DeadInsts);
  }

  RecursivelyDeleteTriviallyDeadInstructions(DeadInsts);
}
168+
169+
/// Rewrite all accesses to cbuffer member globals in \p M into accesses through
/// the corresponding cbuffer handles, then drop the member globals and the
/// cbuffer metadata.
///
/// \returns true if the module had cbuffer metadata (and was therefore
/// modified), false otherwise.
static bool replaceCBufferAccesses(Module &M) {
  std::optional<hlsl::CBufferMetadata> CBufMD = hlsl::CBufferMetadata::get(M);
  if (!CBufMD)
    return false;

  for (const hlsl::CBufferMapping &Mapping : *CBufMD)
    for (const hlsl::CBufferMember &Member : Mapping.Members) {
      replaceAccessesWithHandle(Member.GV, Mapping.Handle, Member.Offset);

      // The rewritten loads are gone, but constant expressions (e.g. GEPs)
      // that pointed at the global may linger; clear out the dead ones so the
      // global can actually be destroyed.
      Member.GV->removeDeadConstantUsers();
      if (Member.GV->use_empty())
        // Delete the global rather than merely unlinking it, which would
        // leave it allocated with no owner.
        Member.GV->eraseFromParent();
      else
        // Something other than a rewritten load still refers to the global;
        // only unlink it from the module, as before.
        Member.GV->removeFromParent();
    }

  CBufMD->eraseFromModule();
  return true;
}
183+
184+
/// New pass manager entry point. Reports every analysis as preserved when the
/// module was left untouched, and none otherwise.
PreservedAnalyses DXILCBufferAccess::run(Module &M, ModuleAnalysisManager &AM) {
  if (replaceCBufferAccesses(M))
    return PreservedAnalyses::none();
  return PreservedAnalyses::all();
}
192+
193+
namespace {
194+
class DXILCBufferAccessLegacy : public ModulePass {
195+
public:
196+
bool runOnModule(Module &M) override { return replaceCBufferAccesses(M); }
197+
StringRef getPassName() const override { return "DXIL CBuffer Access"; }
198+
DXILCBufferAccessLegacy() : ModulePass(ID) {}
199+
200+
static char ID; // Pass identification.
201+
};
202+
char DXILCBufferAccessLegacy::ID = 0;
203+
} // end anonymous namespace
204+
205+
INITIALIZE_PASS(DXILCBufferAccessLegacy, DEBUG_TYPE, "DXIL CBuffer Access",
206+
false, false)
207+
208+
/// Create a new instance of the legacy DXIL cbuffer access pass. The returned
/// object is heap-allocated with `new`.
ModulePass *llvm::createDXILCBufferAccessLegacyPass() {
  ModulePass *Pass = new DXILCBufferAccessLegacy();
  return Pass;
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- DXILCBufferAccess.h - Translate CBuffer Loads ------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// \file Pass for replacing loads from cbuffers in the cbuffer address space with
10+
// cbuffer load intrinsics.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef LLVM_LIB_TARGET_DIRECTX_DXILCBUFFERACCESS_H
15+
#define LLVM_LIB_TARGET_DIRECTX_DXILCBUFFERACCESS_H
16+
17+
#include "llvm/IR/PassManager.h"
18+
19+
namespace llvm {
20+
21+
/// New pass manager pass that replaces loads from cbuffers in the cbuffer
/// address space with cbuffer load intrinsics.
class DXILCBufferAccess : public PassInfoMixin<DXILCBufferAccess> {
public:
  /// Run the pass over module \p M.
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
};
25+
26+
} // namespace llvm
27+
28+
#endif // LLVM_LIB_TARGET_DIRECTX_DXILCBUFFERACCESS_H

0 commit comments

Comments
 (0)