Skip to content

Commit dcc2fae

Browse files
authored
[HLSL] Fix codegen to support classes in cbuffer (#132828)
Fixes #132309
1 parent 514f984 commit dcc2fae

File tree

2 files changed

+51
-16
lines changed

2 files changed

+51
-16
lines changed

clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,11 @@ static unsigned getScalarOrVectorSizeInBytes(llvm::Type *Ty) {
5252
namespace clang {
5353
namespace CodeGen {
5454

55-
// Creates a layout type for given struct with HLSL constant buffer layout
56-
// taking into account PackOffsets, if provided.
55+
// Creates a layout type for given struct or class with HLSL constant buffer
56+
// layout taking into account PackOffsets, if provided.
5757
// Previously created layout types are cached by CGHLSLRuntime.
5858
//
59-
// The function iterates over all fields of the StructType (including base
59+
// The function iterates over all fields of the record type (including base
6060
// classes) and calls layoutField to converts each field to its corresponding
6161
// LLVM type and to calculate its HLSL constant buffer layout. Any embedded
6262
// structs (or arrays of structs) are converted to target layout types as well.
@@ -67,12 +67,11 @@ namespace CodeGen {
6767
// -1 value instead. These elements must be placed at the end of the layout
6868
// after all of the elements with specific offset.
6969
llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
70-
const RecordType *StructType,
71-
const llvm::SmallVector<int32_t> *PackOffsets) {
70+
const RecordType *RT, const llvm::SmallVector<int32_t> *PackOffsets) {
7271

7372
// check if we already have the layout type for this struct
7473
if (llvm::TargetExtType *Ty =
75-
CGM.getHLSLRuntime().getHLSLBufferLayoutType(StructType))
74+
CGM.getHLSLRuntime().getHLSLBufferLayoutType(RT))
7675
return Ty;
7776

7877
SmallVector<unsigned> Layout;
@@ -87,7 +86,7 @@ llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
8786

8887
// iterate over all fields of the record, including fields on base classes
8988
llvm::SmallVector<const RecordType *> RecordTypes;
90-
RecordTypes.push_back(StructType);
89+
RecordTypes.push_back(RT);
9190
while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
9291
CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
9392
assert(D->getNumBases() == 1 &&
@@ -148,7 +147,7 @@ llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
148147

149148
// create the layout struct type; anonymous struct have empty name but
150149
// non-empty qualified name
151-
const CXXRecordDecl *Decl = StructType->getAsCXXRecordDecl();
150+
const CXXRecordDecl *Decl = RT->getAsCXXRecordDecl();
152151
std::string Name =
153152
Decl->getName().empty() ? "anon" : Decl->getQualifiedNameAsString();
154153
llvm::StructType *StructTy =
@@ -158,7 +157,7 @@ llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
158157
llvm::TargetExtType *NewLayoutTy = llvm::TargetExtType::get(
159158
CGM.getLLVMContext(), LayoutTypeName, {StructTy}, Layout);
160159
if (NewLayoutTy)
161-
CGM.getHLSLRuntime().addHLSLBufferLayoutType(StructType, NewLayoutTy);
160+
CGM.getHLSLRuntime().addHLSLBufferLayoutType(RT, NewLayoutTy);
162161
return NewLayoutTy;
163162
}
164163

@@ -202,9 +201,9 @@ bool HLSLBufferLayoutBuilder::layoutField(const FieldDecl *FD,
202201
}
203202
// For array of structures, create a new array with a layout type
204203
// instead of the structure type.
205-
if (Ty->isStructureType()) {
204+
if (Ty->isStructureOrClassType()) {
206205
llvm::Type *NewTy =
207-
cast<llvm::TargetExtType>(createLayoutType(Ty->getAsStructureType()));
206+
cast<llvm::TargetExtType>(createLayoutType(Ty->getAs<RecordType>()));
208207
if (!NewTy)
209208
return false;
210209
assert(isa<llvm::TargetExtType>(NewTy) && "expected target type");
@@ -220,9 +219,10 @@ bool HLSLBufferLayoutBuilder::layoutField(const FieldDecl *FD,
220219
ArrayStride = llvm::alignTo(ElemSize, CBufferRowSizeInBytes);
221220
ElemOffset = (Packoffset != -1) ? Packoffset : NextRowOffset;
222221

223-
} else if (FieldTy->isStructureType()) {
222+
} else if (FieldTy->isStructureOrClassType()) {
224223
// Create a layout type for the structure
225-
ElemLayoutTy = createLayoutType(FieldTy->getAsStructureType());
224+
ElemLayoutTy =
225+
createLayoutType(cast<RecordType>(FieldTy->getAs<RecordType>()));
226226
if (!ElemLayoutTy)
227227
return false;
228228
assert(isa<llvm::TargetExtType>(ElemLayoutTy) && "expected target type");

clang/test/CodeGenHLSL/cbuffer.hlsl

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313
// CHECK: %C = type <{ i32, target("dx.Layout", %A, 8, 0) }>
1414
// CHECK: %__cblayout_D = type <{ [2 x [3 x target("dx.Layout", %B, 14, 0, 8)]] }>
1515

16+
// CHECK: %__cblayout_CBClasses = type <{ target("dx.Layout", %K, 4, 0), target("dx.Layout", %L, 8, 0, 4),
17+
// CHECK-SAME: target("dx.Layout", %M, 68, 0), [10 x target("dx.Layout", %K, 4, 0)] }>
18+
// CHECK: %K = type <{ float }>
19+
// CHECK: %L = type <{ float, float }>
20+
// CHECK: %M = type <{ [5 x target("dx.Layout", %K, 4, 0)] }>
21+
1622
// CHECK: %__cblayout_CBMix = type <{ [2 x target("dx.Layout", %Test, 8, 0, 4)], float, [3 x [2 x <2 x float>]], float,
1723
// CHECK-SAME: target("dx.Layout", %anon, 4, 0), double, target("dx.Layout", %anon.0, 8, 0), float, <1 x double>, i16 }>
1824

@@ -133,6 +139,33 @@ cbuffer CBStructs {
133139
uint16_t3 f;
134140
};
135141

142+
143+
class K {
144+
float i;
145+
};
146+
147+
class L : K {
148+
float j;
149+
};
150+
151+
class M {
152+
K array[5];
153+
};
154+
155+
cbuffer CBClasses {
156+
K k;
157+
L l;
158+
M m;
159+
K ka[10];
160+
};
161+
162+
// CHECK: @CBClasses.cb = global target("dx.CBuffer", target("dx.Layout", %__cblayout_CBClasses,
163+
// CHECK-SAME: 260, 0, 16, 32, 112))
164+
// CHECK: @k = external addrspace(2) global target("dx.Layout", %K, 4, 0), align 4
165+
// CHECK: @l = external addrspace(2) global target("dx.Layout", %L, 8, 0, 4), align 4
166+
// CHECK: @m = external addrspace(2) global target("dx.Layout", %M, 68, 0), align 4
167+
// CHECK: @ka = external addrspace(2) global [10 x target("dx.Layout", %K, 4, 0)], align 4
168+
136169
struct Test {
137170
float a, b;
138171
};
@@ -237,16 +270,16 @@ RWBuffer<float> Buf;
237270

238271
[numthreads(4,1,1)]
239272
void main() {
240-
Buf[0] = a1 + b1.z + c1[2] + a.f1.y + f1 + B1[0].x + B10.z + D1.B2;
273+
Buf[0] = a1 + b1.z + c1[2] + a.f1.y + f1 + B1[0].x + ka[2].i + B10.z + D1.B2;
241274
}
242275

243276
// CHECK: define internal void @_GLOBAL__sub_I_cbuffer.hlsl()
244277
// CHECK-NEXT: entry:
245278
// CHECK-NEXT: call void @_init_resource_CBScalars.cb()
246279
// CHECK-NEXT: call void @_init_resource_CBArrays.cb()
247280

248-
// CHECK: !hlsl.cbs = !{![[CBSCALARS:[0-9]+]], ![[CBVECTORS:[0-9]+]], ![[CBARRAYS:[0-9]+]], ![[CBSTRUCTS:[0-9]+]], ![[CBMIX:[0-9]+]],
249-
// CHECK-SAME: ![[CB_A:[0-9]+]], ![[CB_B:[0-9]+]], ![[CB_C:[0-9]+]]}
281+
// CHECK: !hlsl.cbs = !{![[CBSCALARS:[0-9]+]], ![[CBVECTORS:[0-9]+]], ![[CBARRAYS:[0-9]+]], ![[CBSTRUCTS:[0-9]+]], ![[CBCLASSES:[0-9]+]],
282+
// CHECK-SAME: ![[CBMIX:[0-9]+]], ![[CB_A:[0-9]+]], ![[CB_B:[0-9]+]], ![[CB_C:[0-9]+]]}
250283

251284
// CHECK: ![[CBSCALARS]] = !{ptr @CBScalars.cb, ptr addrspace(2) @a1, ptr addrspace(2) @a2, ptr addrspace(2) @a3, ptr addrspace(2) @a4,
252285
// CHECK-SAME: ptr addrspace(2) @a5, ptr addrspace(2) @a6, ptr addrspace(2) @a7, ptr addrspace(2) @a8}
@@ -260,6 +293,8 @@ void main() {
260293
// CHECK: ![[CBSTRUCTS]] = !{ptr @CBStructs.cb, ptr addrspace(2) @a, ptr addrspace(2) @b, ptr addrspace(2) @c, ptr addrspace(2) @array_of_A,
261294
// CHECK-SAME: ptr addrspace(2) @d, ptr addrspace(2) @e, ptr addrspace(2) @f}
262295

296+
// CHECK: ![[CBCLASSES]] = !{ptr @CBClasses.cb, ptr addrspace(2) @k, ptr addrspace(2) @l, ptr addrspace(2) @m, ptr addrspace(2) @ka}
297+
263298
// CHECK: ![[CBMIX]] = !{ptr @CBMix.cb, ptr addrspace(2) @test, ptr addrspace(2) @f1, ptr addrspace(2) @f2, ptr addrspace(2) @f3,
264299
// CHECK-SAME: ptr addrspace(2) @f4, ptr addrspace(2) @f5, ptr addrspace(2) @f6, ptr addrspace(2) @f7, ptr addrspace(2) @f8, ptr addrspace(2) @f9}
265300

0 commit comments

Comments
 (0)