Skip to content

Commit 7f4ca90

Browse files
committed
[DirectX] Implement the DXILCBufferAccess pass
This introduces a pass that walks accesses to globals in cbuffers and replaces them with accesses via the cbuffer handle itself. The logic to interpret the cbuffer metadata is kept in `lib/Frontend/HLSL` so that it can be reused by other consumers of that metadata. Fixes #124630.
1 parent 6ce0fd7 commit 7f4ca90

File tree

15 files changed

+781
-0
lines changed

15 files changed

+781
-0
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//===- CBuffer.h - HLSL constant buffer handling ----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
///
9+
/// \file This file contains utilities to work with constant buffers in HLSL.
10+
///
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef LLVM_FRONTEND_HLSL_CBUFFER_H
14+
#define LLVM_FRONTEND_HLSL_CBUFFER_H
15+
16+
#include "llvm/ADT/SmallVector.h"
17+
#include "llvm/IR/DataLayout.h"
18+
#include "llvm/IR/DerivedTypes.h"
19+
#include <optional>
20+
21+
namespace llvm {
22+
class Module;
23+
class GlobalVariable;
24+
class NamedMDNode;
25+
26+
namespace hlsl {
27+
28+
struct CBufferMember {
29+
CBufferMember(GlobalVariable *GV, size_t Offset) : GV(GV), Offset(Offset) {}
30+
31+
GlobalVariable *GV;
32+
size_t Offset;
33+
};
34+
35+
struct CBufferMapping {
36+
CBufferMapping(GlobalVariable *Handle) : Handle(Handle) {}
37+
38+
GlobalVariable *Handle;
39+
SmallVector<CBufferMember> Members;
40+
};
41+
42+
class CBufferMetadata {
43+
NamedMDNode *MD;
44+
SmallVector<CBufferMapping> Mappings;
45+
46+
CBufferMetadata(NamedMDNode *MD) : MD(MD) {}
47+
48+
public:
49+
static std::optional<CBufferMetadata> get(Module &M);
50+
51+
using iterator = SmallVector<CBufferMapping>::iterator;
52+
iterator begin() { return Mappings.begin(); }
53+
iterator end() { return Mappings.end(); }
54+
55+
void eraseFromModule();
56+
};
57+
58+
/// Convert \p Offset, a byte offset into the array type \p Ty computed with
/// the data layout \p DL, into the matching byte offset in the cbuffer, where
/// every array element starts on a 16-byte boundary.
APInt translateCBufArrayOffset(const DataLayout &DL, APInt Offset,
                               ArrayType *Ty);
60+
61+
} // namespace hlsl
62+
} // namespace llvm
63+
64+
#endif // LLVM_FRONTEND_HLSL_CBUFFER_H

llvm/lib/Frontend/HLSL/CBuffer.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
//===- CBuffer.cpp - HLSL constant buffer handling ------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/Frontend/HLSL/CBuffer.h"
10+
#include "llvm/IR/DerivedTypes.h"
11+
#include "llvm/IR/Metadata.h"
12+
#include "llvm/IR/Module.h"
13+
14+
using namespace llvm;
15+
using namespace llvm::hlsl;
16+
17+
/// Look up the offset of the member at position \p Index in the layout type
/// carried by the cbuffer handle \p Handle.
///
/// The handle's value type is expected to be a "*.CBuffer" target extension
/// type whose single type parameter is a "*.Layout" target extension type.
/// The layout's integer parameters hold the size first, then one offset per
/// member.
static size_t getMemberOffset(GlobalVariable *Handle, size_t Index) {
  auto *CBufTy = cast<TargetExtType>(Handle->getValueType());
  assert(CBufTy->getName().ends_with(".CBuffer") && "Not a cbuffer type");
  assert(CBufTy->getNumTypeParameters() == 1 && "Expected layout type");

  auto *Layout = cast<TargetExtType>(CBufTy->getTypeParameter(0));
  assert(Layout->getName().ends_with(".Layout") && "Not a layout type");

  // Integer parameter 0 is the "size", so member offsets start at 1.
  const size_t OffsetParam = Index + 1;
  assert(Layout->getNumIntParameters() > OffsetParam &&
         "Not enough parameters");

  return Layout->getIntParameter(OffsetParam);
}
32+
33+
/// Read the "hlsl.cbs" named metadata of \p M into a CBufferMetadata.
///
/// Each operand node lists the cbuffer handle global first, followed by one
/// entry per member. Returns std::nullopt if the module has no such metadata.
std::optional<CBufferMetadata> CBufferMetadata::get(Module &M) {
  NamedMDNode *NamedMD = M.getNamedMetadata("hlsl.cbs");
  if (NamedMD == nullptr)
    return std::nullopt;

  std::optional<CBufferMetadata> Result({NamedMD});

  for (const MDNode *Node : NamedMD->operands()) {
    const unsigned NumOps = Node->getNumOperands();
    assert(NumOps && "Invalid cbuffer metadata");

    // Operand 0 is the handle; the remaining operands are the members.
    Value *HandleVal = cast<ValueAsMetadata>(Node->getOperand(0))->getValue();
    auto *Handle = cast<GlobalVariable>(HandleVal);
    CBufferMapping &Mapping = Result->Mappings.emplace_back(Handle);

    for (unsigned OpIdx = 1; OpIdx < NumOps; ++OpIdx) {
      Metadata *MemberMD = Node->getOperand(OpIdx);
      // A null operand marks a member that has been optimized out; it still
      // occupies a slot in the layout, so only the entry is skipped.
      if (MemberMD) {
        auto *MemberGV =
            cast<GlobalVariable>(cast<ValueAsMetadata>(MemberMD)->getValue());
        Mapping.Members.emplace_back(MemberGV,
                                     getMemberOffset(Handle, OpIdx - 1));
      }
    }
  }

  return Result;
}
59+
60+
61+
/// Remove the cbuffer named metadata node this object was created from.
void CBufferMetadata::eraseFromModule() {
  // NOTE(review): MD is presumably invalid after this call, so the object
  // should not be asked to erase again - confirm against NamedMDNode.
  MD->eraseFromParent();
}
65+
66+
/// Translate a byte offset into an array, computed with the data layout's
/// element size, into the corresponding byte offset in the cbuffer.
///
/// In the cbuffer each array element starts on a 16-byte boundary, so the
/// element index is rescaled from the data layout element size to that
/// padded stride. The position inside the element is carried over unchanged,
/// so an offset that points into the middle of an element (for example the
/// second component of a vector element) does not collapse onto the
/// element's first byte.
///
/// \param DL     Data layout used to size the element type.
/// \param Offset Byte offset from the start of the array, per \p DL.
/// \param Ty     The array type being indexed.
/// \returns The byte offset from the start of the array in cbuffer layout.
APInt hlsl::translateCBufArrayOffset(const DataLayout &DL, APInt Offset,
                                     ArrayType *Ty) {
  const uint64_t EltBytes = DL.getTypeSizeInBits(Ty->getElementType()) / 8;
  // Dividing by a zero-sized element would be meaningless; fail loudly here.
  assert(EltBytes != 0 && "Cannot index a cbuffer array of empty elements");

  // Stride between consecutive elements once each is padded to a full row.
  const uint64_t EltStride = alignTo(EltBytes, Align(16));

  // Index * padded stride, plus whatever part of the offset fell inside the
  // element itself.
  return Offset.udiv(EltBytes) * EltStride + Offset.urem(EltBytes);
}

llvm/lib/Frontend/HLSL/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_llvm_component_library(LLVMFrontendHLSL
2+
CBuffer.cpp
23
HLSLResource.cpp
34

45
ADDITIONAL_HEADER_DIRS

llvm/lib/Target/DirectX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ add_llvm_target(DirectXCodeGen
2020
DirectXTargetMachine.cpp
2121
DirectXTargetTransformInfo.cpp
2222
DXContainerGlobals.cpp
23+
DXILCBufferAccess.cpp
2324
DXILDataScalarization.cpp
2425
DXILFinalizeLinkage.cpp
2526
DXILFlattenArrays.cpp
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
//===- DXILCBufferAccess.cpp - Translate CBuffer Loads --------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "DXILCBufferAccess.h"
10+
#include "DirectX.h"
11+
#include "llvm/Frontend/HLSL/CBuffer.h"
12+
#include "llvm/IR/IRBuilder.h"
13+
#include "llvm/IR/IntrinsicsDirectX.h"
14+
#include "llvm/InitializePasses.h"
15+
#include "llvm/Pass.h"
16+
#include "llvm/Transforms/Utils/Local.h"
17+
18+
#define DEBUG_TYPE "dxil-cbuffer-access"
19+
using namespace llvm;
20+
21+
namespace {
22+
/// Helper for building a `load.cbufferrow` intrinsic given a simple type.
23+
struct CBufferRowIntrin {
24+
Intrinsic::ID IID;
25+
Type *RetTy;
26+
unsigned int EltSize;
27+
unsigned int NumElts;
28+
29+
CBufferRowIntrin(const DataLayout &DL, Type *Ty) {
30+
assert(Ty == Ty->getScalarType() && "Expected scalar type");
31+
32+
switch (DL.getTypeSizeInBits(Ty)) {
33+
case 16:
34+
IID = Intrinsic::dx_resource_load_cbufferrow_8;
35+
RetTy = StructType::get(Ty, Ty, Ty, Ty, Ty, Ty, Ty, Ty);
36+
EltSize = 2;
37+
NumElts = 8;
38+
break;
39+
case 32:
40+
IID = Intrinsic::dx_resource_load_cbufferrow_4;
41+
RetTy = StructType::get(Ty, Ty, Ty, Ty);
42+
EltSize = 4;
43+
NumElts = 4;
44+
break;
45+
case 64:
46+
IID = Intrinsic::dx_resource_load_cbufferrow_2;
47+
RetTy = StructType::get(Ty, Ty);
48+
EltSize = 8;
49+
NumElts = 2;
50+
break;
51+
default:
52+
llvm_unreachable("Only 16, 32, and 64 bit types supported");
53+
}
54+
}
55+
};
56+
} // namespace
57+
58+
/// Compute the cbuffer-relative byte offset addressed by \p GEP, which must
/// index directly off \p Global with constant indices.
///
/// When the global is an array, the data layout offset is translated to the
/// cbuffer's array layout via hlsl::translateCBufArrayOffset.
static size_t getOffsetForCBufferGEP(GEPOperator *GEP, GlobalVariable *Global,
                                     const DataLayout &DL) {
  // Offsets are always constant, so there is never more than one GEP between
  // the access and the global.
  assert(GEP->getPointerOperand() == Global &&
         "Indirect access to resource handle");

  const unsigned IndexBits = DL.getIndexTypeSizeInBits(GEP->getType());
  APInt ByteOffset(IndexBits, 0);
  bool IsConstant = GEP->accumulateConstantOffset(DL, ByteOffset);
  (void)IsConstant;
  assert(IsConstant && "Offsets into cbuffer globals must be constant");

  auto *ATy = dyn_cast<ArrayType>(Global->getValueType());
  if (!ATy)
    return ByteOffset.getZExtValue();

  return hlsl::translateCBufArrayOffset(DL, ByteOffset, ATy).getZExtValue();
}
75+
76+
/// Replace access via cbuffer global with a load from the cbuffer handle
77+
/// itself.
78+
static void replaceAccess(LoadInst *LI, GlobalVariable *Global,
79+
GlobalVariable *HandleGV, size_t BaseOffset,
80+
SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
81+
const DataLayout &DL = HandleGV->getDataLayout();
82+
83+
size_t Offset = BaseOffset;
84+
if (auto *GEP = dyn_cast<GEPOperator>(LI->getPointerOperand()))
85+
Offset += getOffsetForCBufferGEP(GEP, Global, DL);
86+
else if (LI->getPointerOperand() != Global)
87+
llvm_unreachable("Load instruction doesn't reference cbuffer global");
88+
89+
IRBuilder<> Builder(LI);
90+
auto *Handle = Builder.CreateLoad(HandleGV->getValueType(), HandleGV,
91+
HandleGV->getName());
92+
93+
Type *Ty = LI->getType();
94+
CBufferRowIntrin Intrin(DL, Ty->getScalarType());
95+
// The cbuffer consists of some number of 16-byte rows.
96+
unsigned int CurrentRow = Offset / 16;
97+
unsigned int CurrentIndex = (Offset % 16) / Intrin.EltSize;
98+
99+
auto *CBufLoad = Builder.CreateIntrinsic(
100+
Intrin.RetTy, Intrin.IID,
101+
{Handle, ConstantInt::get(Builder.getInt32Ty(), CurrentRow)}, nullptr,
102+
LI->getName());
103+
auto *Elt =
104+
Builder.CreateExtractValue(CBufLoad, {CurrentIndex++}, LI->getName());
105+
106+
Value *Result = nullptr;
107+
unsigned int Remaining =
108+
((DL.getTypeSizeInBits(Ty) / 8) / Intrin.EltSize) - 1;
109+
if (Remaining == 0) {
110+
// We only have a single element, so we're done.
111+
Result = Elt;
112+
113+
// However, if we loaded a <1 x T>, then we need to adjust the type here.
114+
if (auto *VT = dyn_cast<FixedVectorType>(LI->getType()))
115+
if (VT->getNumElements() == 1)
116+
Result = Builder.CreateInsertElement(PoisonValue::get(VT), Result,
117+
Builder.getInt32(0));
118+
} else {
119+
// Walk each element and extract it, wrapping to new rows as needed.
120+
SmallVector<Value *> Extracts{Elt};
121+
while (Remaining--) {
122+
CurrentIndex %= Intrin.NumElts;
123+
124+
if (CurrentIndex == 0)
125+
CBufLoad = Builder.CreateIntrinsic(
126+
Intrin.RetTy, Intrin.IID,
127+
{Handle, ConstantInt::get(Builder.getInt32Ty(), ++CurrentRow)},
128+
nullptr, LI->getName());
129+
130+
Extracts.push_back(Builder.CreateExtractValue(CBufLoad, {CurrentIndex++},
131+
LI->getName()));
132+
}
133+
134+
// Finally, we build up the original loaded value.
135+
Result = PoisonValue::get(Ty);
136+
for (int I = 0, E = Extracts.size(); I < E; ++I)
137+
Result =
138+
Builder.CreateInsertElement(Result, Extracts[I], Builder.getInt32(I));
139+
}
140+
141+
LI->replaceAllUsesWith(Result);
142+
DeadInsts.push_back(LI);
143+
}
144+
145+
/// Find every load reachable through the users of \p Global and rewrite it
/// to read from the cbuffer handle \p HandleGV at \p BaseOffset, then delete
/// the instructions that became dead.
static void replaceAccessesWithHandle(GlobalVariable *Global,
                                      GlobalVariable *HandleGV,
                                      size_t BaseOffset) {
  SmallVector<WeakTrackingVH> DeadInsts;

  SmallVector<User *> Worklist{Global->users()};
  while (!Worklist.empty()) {
    User *U = Worklist.pop_back_val();

    if (auto *LI = dyn_cast<LoadInst>(U))
      // Loads are the accesses we are after.
      replaceAccess(LI, Global, HandleGV, BaseOffset, DeadInsts);
    else
      // Anything else is looked through in search of loads.
      Worklist.append(U->user_begin(), U->user_end());
  }

  RecursivelyDeleteTriviallyDeadInstructions(DeadInsts);
}
165+
166+
/// Rewrite all accesses to cbuffer member globals in \p M into loads through
/// the owning cbuffer handle, then drop the member globals and the cbuffer
/// metadata.
///
/// \returns true if the module had cbuffer metadata (and was therefore
/// modified), false otherwise.
static bool replaceCBufferAccesses(Module &M) {
  std::optional<hlsl::CBufferMetadata> CBufMD = hlsl::CBufferMetadata::get(M);
  if (!CBufMD)
    return false;

  for (const hlsl::CBufferMapping &Mapping : *CBufMD)
    for (const hlsl::CBufferMember &Member : Mapping.Members) {
      GlobalVariable *GV = Member.GV;
      replaceAccessesWithHandle(GV, Mapping.Handle, Member.Offset);

      // The rewritten loads are gone, but constant expressions they went
      // through (e.g. constant GEPs) may still be attached to the global.
      GV->removeDeadConstantUsers();

      // Delete the global outright when nothing refers to it any more;
      // merely unlinking it would leave it allocated forever. If some use
      // survived, keep the previous behaviour of only detaching it.
      if (GV->use_empty())
        GV->eraseFromParent();
      else
        GV->removeFromParent();
    }

  CBufMD->eraseFromModule();
  return true;
}
180+
181+
/// New pass manager entry point: rewrites cbuffer accesses in \p M and
/// reports that no analyses are preserved if anything changed.
PreservedAnalyses DXILCBufferAccess::run(Module &M, ModuleAnalysisManager &AM) {
  if (replaceCBufferAccesses(M))
    return PreservedAnalyses::none();
  return PreservedAnalyses::all();
}
189+
190+
namespace {
191+
class DXILCBufferAccessLegacy : public ModulePass {
192+
public:
193+
bool runOnModule(Module &M) override {
194+
return replaceCBufferAccesses(M);
195+
}
196+
StringRef getPassName() const override { return "DXIL CBuffer Access"; }
197+
DXILCBufferAccessLegacy() : ModulePass(ID) {}
198+
199+
static char ID; // Pass identification.
200+
};
201+
char DXILCBufferAccessLegacy::ID = 0;
202+
} // end anonymous namespace
203+
204+
// Register the legacy pass under the DEBUG_TYPE command-line name.
INITIALIZE_PASS(DXILCBufferAccessLegacy, DEBUG_TYPE, "DXIL CBuffer Access",
                false, false)

/// Create a new instance of the legacy cbuffer access pass. The returned
/// object is heap-allocated with `new`.
ModulePass *llvm::createDXILCBufferAccessLegacyPass() {
  ModulePass *Pass = new DXILCBufferAccessLegacy();
  return Pass;
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- DXILCBufferAccess.h - Translate CBuffer Loads ------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// \file Pass for replacing loads from cbuffers in the cbuffer address space to
10+
// cbuffer load intrinsics.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef LLVM_LIB_TARGET_DIRECTX_DXILCBUFFERACCESS_H
15+
#define LLVM_LIB_TARGET_DIRECTX_DXILCBUFFERACCESS_H
16+
17+
#include "llvm/IR/PassManager.h"
18+
19+
namespace llvm {
20+
21+
/// New pass manager pass that replaces loads from cbuffer member globals
/// with cbuffer load intrinsics on the cbuffer handle.
class DXILCBufferAccess : public PassInfoMixin<DXILCBufferAccess> {
public:
  /// Run the rewrite over \p M.
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
};
25+
26+
} // namespace llvm
27+
28+
#endif // LLVM_LIB_TARGET_DIRECTX_DXILCBUFFERACCESS_H

0 commit comments

Comments
 (0)