Skip to content

Commit f3ad57d

Browse files
committed
Implement memcpy support in DXIL CBufferAccess pass
1 parent f1988f4 commit f3ad57d

File tree

2 files changed

+416
-82
lines changed

2 files changed

+416
-82
lines changed

llvm/lib/Target/DirectX/DXILCBufferAccess.cpp

Lines changed: 212 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "llvm/InitializePasses.h"
1818
#include "llvm/Pass.h"
1919
#include "llvm/Support/ErrorHandling.h"
20+
#include "llvm/Support/FormatVariadic.h"
2021
#include "llvm/Transforms/Utils/Local.h"
2122

2223
#define DEBUG_TYPE "dxil-cbuffer-access"
@@ -57,109 +58,237 @@ struct CBufferRowIntrin {
5758
}
5859
}
5960
};
60-
} // namespace
6161

62-
static size_t getOffsetForCBufferGEP(GEPOperator *GEP, GlobalVariable *Global,
63-
const DataLayout &DL) {
64-
// Since we should always have a constant offset, we should only ever have a
65-
// single GEP of indirection from the Global.
66-
assert(GEP->getPointerOperand() == Global &&
67-
"Indirect access to resource handle");
62+
// Helper for creating CBuffer handles and loading data from them
63+
struct CBufferResource {
64+
GlobalVariable *GVHandle;
65+
GlobalVariable *Member;
66+
size_t MemberOffset;
6867

69-
APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
70-
bool Success = GEP->accumulateConstantOffset(DL, ConstantOffset);
71-
(void)Success;
72-
assert(Success && "Offsets into cbuffer globals must be constant");
68+
LoadInst *Handle;
7369

74-
if (auto *ATy = dyn_cast<ArrayType>(Global->getValueType()))
75-
ConstantOffset = hlsl::translateCBufArrayOffset(DL, ConstantOffset, ATy);
70+
CBufferResource(GlobalVariable *GVHandle, GlobalVariable *Member,
71+
size_t MemberOffset)
72+
: GVHandle(GVHandle), Member(Member), MemberOffset(MemberOffset) {}
7673

77-
return ConstantOffset.getZExtValue();
78-
}
74+
const DataLayout &getDataLayout() { return GVHandle->getDataLayout(); }
75+
Type *getValueType() { return Member->getValueType(); }
76+
iterator_range<ConstantDataSequential::user_iterator> users() {
77+
return Member->users();
78+
}
7979

80-
/// Replace access via cbuffer global with a load from the cbuffer handle
81-
/// itself.
82-
static void replaceAccess(LoadInst *LI, GlobalVariable *Global,
83-
GlobalVariable *HandleGV, size_t BaseOffset,
84-
SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
85-
const DataLayout &DL = HandleGV->getDataLayout();
80+
/// Get the byte offset of a Pointer-typed Value * `Val` relative to Member.
81+
/// `Val` can either be Member itself, or a GEP of a constant offset from
82+
/// Member
83+
size_t getOffsetForCBufferGEP(Value *Val) {
84+
assert(isa<PointerType>(Val->getType()) &&
85+
"Expected a pointer-typed value");
86+
87+
if (Val == Member)
88+
return 0;
89+
90+
if (auto *GEP = dyn_cast<GEPOperator>(Val)) {
91+
// Since we should always have a constant offset, we should only ever have
92+
// a single GEP of indirection from the Global.
93+
assert(GEP->getPointerOperand() == Member &&
94+
"Indirect access to resource handle");
95+
96+
const DataLayout &DL = getDataLayout();
97+
APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
98+
bool Success = GEP->accumulateConstantOffset(DL, ConstantOffset);
99+
(void)Success;
100+
assert(Success && "Offsets into cbuffer globals must be constant");
101+
102+
if (auto *ATy = dyn_cast<ArrayType>(Member->getValueType()))
103+
ConstantOffset =
104+
hlsl::translateCBufArrayOffset(DL, ConstantOffset, ATy);
105+
106+
return ConstantOffset.getZExtValue();
107+
}
86108

87-
size_t Offset = BaseOffset;
88-
if (auto *GEP = dyn_cast<GEPOperator>(LI->getPointerOperand()))
89-
Offset += getOffsetForCBufferGEP(GEP, Global, DL);
90-
else if (LI->getPointerOperand() != Global)
91-
llvm_unreachable("Load instruction doesn't reference cbuffer global");
109+
llvm_unreachable("Invalid value passed to getOffsetFromPtr; it must be a "
110+
"GlobalVariable or GEP");
111+
}
92112

93-
IRBuilder<> Builder(LI);
94-
auto *Handle = Builder.CreateLoad(HandleGV->getValueType(), HandleGV,
95-
HandleGV->getName());
96-
97-
Type *Ty = LI->getType();
98-
CBufferRowIntrin Intrin(DL, Ty->getScalarType());
99-
// The cbuffer consists of some number of 16-byte rows.
100-
unsigned int CurrentRow = Offset / hlsl::CBufferRowSizeInBytes;
101-
unsigned int CurrentIndex =
102-
(Offset % hlsl::CBufferRowSizeInBytes) / Intrin.EltSize;
103-
104-
auto *CBufLoad = Builder.CreateIntrinsic(
105-
Intrin.RetTy, Intrin.IID,
106-
{Handle, ConstantInt::get(Builder.getInt32Ty(), CurrentRow)}, nullptr,
107-
LI->getName());
108-
auto *Elt =
109-
Builder.CreateExtractValue(CBufLoad, {CurrentIndex++}, LI->getName());
110-
111-
Value *Result = nullptr;
112-
unsigned int Remaining =
113-
((DL.getTypeSizeInBits(Ty) / 8) / Intrin.EltSize) - 1;
114-
if (Remaining == 0) {
115-
// We only have a single element, so we're done.
116-
Result = Elt;
117-
118-
// However, if we loaded a <1 x T>, then we need to adjust the type here.
119-
if (auto *VT = dyn_cast<FixedVectorType>(LI->getType())) {
120-
assert(VT->getNumElements() == 1 && "Can't have multiple elements here");
121-
Result = Builder.CreateInsertElement(PoisonValue::get(VT), Result,
122-
Builder.getInt32(0));
123-
}
124-
} else {
125-
// Walk each element and extract it, wrapping to new rows as needed.
126-
SmallVector<Value *> Extracts{Elt};
127-
while (Remaining--) {
128-
CurrentIndex %= Intrin.NumElts;
129-
130-
if (CurrentIndex == 0)
131-
CBufLoad = Builder.CreateIntrinsic(
132-
Intrin.RetTy, Intrin.IID,
133-
{Handle, ConstantInt::get(Builder.getInt32Ty(), ++CurrentRow)},
134-
nullptr, LI->getName());
135-
136-
Extracts.push_back(Builder.CreateExtractValue(CBufLoad, {CurrentIndex++},
137-
LI->getName()));
113+
/// Create a handle for this cbuffer resource using the IRBuilder `Builder`
114+
/// and sets the handle as the current one to use for subsequent calls to
115+
/// `loadValue`
116+
void createAndSetCurrentHandle(IRBuilder<> &Builder) {
117+
Handle = Builder.CreateLoad(GVHandle->getValueType(), GVHandle,
118+
GVHandle->getName());
119+
}
120+
121+
/// Load a value of type `Ty` at offset `Offset` using the handle from the
122+
/// last call to `createAndSetCurrentHandle`
123+
Value *loadValue(IRBuilder<> &Builder, Type *Ty, size_t Offset,
124+
const Twine &Name = "") {
125+
assert(Handle &&
126+
"Expected a handle for this cbuffer global resource to be created "
127+
"before loading a value from it");
128+
const DataLayout &DL = getDataLayout();
129+
130+
size_t TargetOffset = MemberOffset + Offset;
131+
CBufferRowIntrin Intrin(DL, Ty->getScalarType());
132+
// The cbuffer consists of some number of 16-byte rows.
133+
unsigned int CurrentRow = TargetOffset / hlsl::CBufferRowSizeInBytes;
134+
unsigned int CurrentIndex =
135+
(TargetOffset % hlsl::CBufferRowSizeInBytes) / Intrin.EltSize;
136+
137+
auto *CBufLoad = Builder.CreateIntrinsic(
138+
Intrin.RetTy, Intrin.IID,
139+
{Handle, ConstantInt::get(Builder.getInt32Ty(), CurrentRow)}, nullptr,
140+
Name + ".load");
141+
auto *Elt = Builder.CreateExtractValue(CBufLoad, {CurrentIndex++},
142+
Name + ".extract");
143+
144+
Value *Result = nullptr;
145+
unsigned int Remaining =
146+
((DL.getTypeSizeInBits(Ty) / 8) / Intrin.EltSize) - 1;
147+
if (Remaining == 0) {
148+
// We only have a single element, so we're done.
149+
Result = Elt;
150+
151+
// However, if we loaded a <1 x T>, then we need to adjust the type here.
152+
if (auto *VT = dyn_cast<FixedVectorType>(Ty)) {
153+
assert(VT->getNumElements() == 1 &&
154+
"Can't have multiple elements here");
155+
Result = Builder.CreateInsertElement(PoisonValue::get(VT), Result,
156+
Builder.getInt32(0), Name);
157+
}
158+
} else {
159+
// Walk each element and extract it, wrapping to new rows as needed.
160+
SmallVector<Value *> Extracts{Elt};
161+
while (Remaining--) {
162+
CurrentIndex %= Intrin.NumElts;
163+
164+
if (CurrentIndex == 0)
165+
CBufLoad = Builder.CreateIntrinsic(
166+
Intrin.RetTy, Intrin.IID,
167+
{Handle, ConstantInt::get(Builder.getInt32Ty(), ++CurrentRow)},
168+
nullptr, Name + ".load");
169+
170+
Extracts.push_back(Builder.CreateExtractValue(
171+
CBufLoad, {CurrentIndex++}, Name + ".extract"));
172+
}
173+
174+
// Finally, we build up the original loaded value.
175+
Result = PoisonValue::get(Ty);
176+
for (int I = 0, E = Extracts.size(); I < E; ++I)
177+
Result = Builder.CreateInsertElement(Result, Extracts[I],
178+
Builder.getInt32(I),
179+
Name + formatv(".upto{}", I));
138180
}
139181

140-
// Finally, we build up the original loaded value.
141-
Result = PoisonValue::get(Ty);
142-
for (int I = 0, E = Extracts.size(); I < E; ++I)
143-
Result =
144-
Builder.CreateInsertElement(Result, Extracts[I], Builder.getInt32(I));
182+
return Result;
145183
}
184+
};
146185

186+
} // namespace
187+
188+
/// Replace load via cbuffer global with a load from the cbuffer handle itself.
189+
static void replaceLoad(LoadInst *LI, CBufferResource &CBR,
190+
SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
191+
size_t Offset = CBR.getOffsetForCBufferGEP(LI->getPointerOperand());
192+
IRBuilder<> Builder(LI);
193+
CBR.createAndSetCurrentHandle(Builder);
194+
Value *Result = CBR.loadValue(Builder, LI->getType(), Offset, LI->getName());
147195
LI->replaceAllUsesWith(Result);
148196
DeadInsts.push_back(LI);
149197
}
150198

151-
static void replaceAccessesWithHandle(GlobalVariable *Global,
152-
GlobalVariable *HandleGV,
153-
size_t BaseOffset) {
199+
/// Replace memcpy from a cbuffer global with a memcpy from the cbuffer handle
200+
/// itself. Assumes the cbuffer global is an array, and the length of bytes to
201+
/// copy is divisible by array element allocation size.
202+
/// The memcpy source must also be a direct cbuffer global reference, not a GEP.
203+
static void replaceMemCpy(MemCpyInst *MCI, CBufferResource &CBR,
204+
SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
205+
206+
ArrayType *ArrTy = dyn_cast<ArrayType>(CBR.getValueType());
207+
assert(ArrTy && "MemCpy lowering is only supported for array types");
208+
209+
// This assumption vastly simplifies the implementation
210+
if (MCI->getSource() != CBR.Member)
211+
reportFatalUsageError(
212+
"Expected MemCpy source to be a cbuffer global variable");
213+
214+
const std::string Name = ("memcpy." + MCI->getDest()->getName() + "." +
215+
MCI->getSource()->getName())
216+
.str();
217+
218+
ConstantInt *Length = dyn_cast<ConstantInt>(MCI->getLength());
219+
uint64_t ByteLength = Length->getZExtValue();
220+
221+
// If length to copy is zero, no memcpy is needed
222+
if (ByteLength == 0) {
223+
DeadInsts.push_back(MCI);
224+
return;
225+
}
226+
227+
const DataLayout &DL = CBR.getDataLayout();
228+
229+
Type *ElemTy = ArrTy->getElementType();
230+
size_t ElemSize = DL.getTypeAllocSize(ElemTy);
231+
assert(ByteLength % ElemSize == 0 &&
232+
"Length of bytes to MemCpy must be divisible by allocation size of "
233+
"source/destination array elements");
234+
size_t ElemsToCpy = ByteLength / ElemSize;
235+
236+
IRBuilder<> Builder(MCI);
237+
CBR.createAndSetCurrentHandle(Builder);
238+
239+
auto CopyElemsImpl = [&Builder, &MCI, &Name, &CBR,
240+
&DL](const auto &Self, ArrayType *ArrTy,
241+
size_t ArrOffset, size_t N) -> void {
242+
Type *ElemTy = ArrTy->getElementType();
243+
size_t ElemTySize = DL.getTypeAllocSize(ElemTy);
244+
for (unsigned I = 0; I < N; ++I) {
245+
size_t Offset = ArrOffset + I * ElemTySize;
246+
247+
// Recursively copy nested arrays
248+
if (ArrayType *ElemArrTy = dyn_cast<ArrayType>(ElemTy)) {
249+
Self(Self, ElemArrTy, Offset, ElemArrTy->getNumElements());
250+
continue;
251+
}
252+
253+
// Load CBuffer value and store it in Dest
254+
APInt CBufArrayOffset(
255+
DL.getIndexTypeSizeInBits(MCI->getSource()->getType()), Offset);
256+
CBufArrayOffset =
257+
hlsl::translateCBufArrayOffset(DL, CBufArrayOffset, ArrTy);
258+
Value *CBufferVal =
259+
CBR.loadValue(Builder, ElemTy, CBufArrayOffset.getZExtValue(), Name);
260+
Value *GEP =
261+
Builder.CreateInBoundsGEP(Builder.getInt8Ty(), MCI->getDest(),
262+
{Builder.getInt32(Offset)}, Name + ".dest");
263+
Builder.CreateStore(CBufferVal, GEP, MCI->isVolatile());
264+
}
265+
};
266+
auto CopyElems = [&CopyElemsImpl](ArrayType *ArrTy, size_t N) -> void {
267+
CopyElemsImpl(CopyElemsImpl, ArrTy, 0, N);
268+
};
269+
270+
CopyElems(ArrTy, ElemsToCpy);
271+
272+
MCI->eraseFromParent();
273+
}
274+
275+
static void replaceAccessesWithHandle(CBufferResource &CBR) {
154276
SmallVector<WeakTrackingVH> DeadInsts;
155277

156-
SmallVector<User *> ToProcess{Global->users()};
278+
SmallVector<User *> ToProcess{CBR.users()};
157279
while (!ToProcess.empty()) {
158280
User *Cur = ToProcess.pop_back_val();
159281

160282
// If we have a load instruction, replace the access.
161283
if (auto *LI = dyn_cast<LoadInst>(Cur)) {
162-
replaceAccess(LI, Global, HandleGV, BaseOffset, DeadInsts);
284+
replaceLoad(LI, CBR, DeadInsts);
285+
continue;
286+
}
287+
288+
// If we have a memcpy instruction, replace it with multiple accesses and
289+
// subsequent stores to the destination
290+
if (auto *MCI = dyn_cast<MemCpyInst>(Cur)) {
291+
replaceMemCpy(MCI, CBR, DeadInsts);
163292
continue;
164293
}
165294

@@ -181,7 +310,8 @@ static bool replaceCBufferAccesses(Module &M) {
181310

182311
for (const hlsl::CBufferMapping &Mapping : *CBufMD)
183312
for (const hlsl::CBufferMember &Member : Mapping.Members) {
184-
replaceAccessesWithHandle(Member.GV, Mapping.Handle, Member.Offset);
313+
CBufferResource CBR(Mapping.Handle, Member.GV, Member.Offset);
314+
replaceAccessesWithHandle(CBR);
185315
Member.GV->removeFromParent();
186316
}
187317

0 commit comments

Comments
 (0)