Skip to content

[HLSL] Fix codegen to support classes in cbuffer #132828

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ static unsigned getScalarOrVectorSizeInBytes(llvm::Type *Ty) {
namespace clang {
namespace CodeGen {

// Creates a layout type for given struct with HLSL constant buffer layout
// taking into account PackOffsets, if provided.
// Creates a layout type for given struct or class with HLSL constant buffer
// layout taking into account PackOffsets, if provided.
// Previously created layout types are cached by CGHLSLRuntime.
//
// The function iterates over all fields of the StructType (including base
// The function iterates over all fields of the record type (including base
// classes) and calls layoutField to converts each field to its corresponding
// LLVM type and to calculate its HLSL constant buffer layout. Any embedded
// structs (or arrays of structs) are converted to target layout types as well.
Expand All @@ -67,12 +67,11 @@ namespace CodeGen {
// -1 value instead. These elements must be placed at the end of the layout
// after all of the elements with specific offset.
llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
const RecordType *StructType,
const llvm::SmallVector<int32_t> *PackOffsets) {
const RecordType *RT, const llvm::SmallVector<int32_t> *PackOffsets) {

// check if we already have the layout type for this struct
if (llvm::TargetExtType *Ty =
CGM.getHLSLRuntime().getHLSLBufferLayoutType(StructType))
CGM.getHLSLRuntime().getHLSLBufferLayoutType(RT))
return Ty;

SmallVector<unsigned> Layout;
Expand All @@ -87,7 +86,7 @@ llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(

// iterate over all fields of the record, including fields on base classes
llvm::SmallVector<const RecordType *> RecordTypes;
RecordTypes.push_back(StructType);
RecordTypes.push_back(RT);
while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
assert(D->getNumBases() == 1 &&
Expand Down Expand Up @@ -148,7 +147,7 @@ llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(

// create the layout struct type; anonymous struct have empty name but
// non-empty qualified name
const CXXRecordDecl *Decl = StructType->getAsCXXRecordDecl();
const CXXRecordDecl *Decl = RT->getAsCXXRecordDecl();
std::string Name =
Decl->getName().empty() ? "anon" : Decl->getQualifiedNameAsString();
llvm::StructType *StructTy =
Expand All @@ -158,7 +157,7 @@ llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
llvm::TargetExtType *NewLayoutTy = llvm::TargetExtType::get(
CGM.getLLVMContext(), LayoutTypeName, {StructTy}, Layout);
if (NewLayoutTy)
CGM.getHLSLRuntime().addHLSLBufferLayoutType(StructType, NewLayoutTy);
CGM.getHLSLRuntime().addHLSLBufferLayoutType(RT, NewLayoutTy);
return NewLayoutTy;
}

Expand Down Expand Up @@ -202,9 +201,9 @@ bool HLSLBufferLayoutBuilder::layoutField(const FieldDecl *FD,
}
// For array of structures, create a new array with a layout type
// instead of the structure type.
if (Ty->isStructureType()) {
if (Ty->isStructureOrClassType()) {
llvm::Type *NewTy =
cast<llvm::TargetExtType>(createLayoutType(Ty->getAsStructureType()));
cast<llvm::TargetExtType>(createLayoutType(Ty->getAs<RecordType>()));
if (!NewTy)
return false;
assert(isa<llvm::TargetExtType>(NewTy) && "expected target type");
Expand All @@ -220,9 +219,10 @@ bool HLSLBufferLayoutBuilder::layoutField(const FieldDecl *FD,
ArrayStride = llvm::alignTo(ElemSize, CBufferRowSizeInBytes);
ElemOffset = (Packoffset != -1) ? Packoffset : NextRowOffset;

} else if (FieldTy->isStructureType()) {
} else if (FieldTy->isStructureOrClassType()) {
// Create a layout type for the structure
ElemLayoutTy = createLayoutType(FieldTy->getAsStructureType());
ElemLayoutTy =
createLayoutType(cast<RecordType>(FieldTy->getAs<RecordType>()));
if (!ElemLayoutTy)
return false;
assert(isa<llvm::TargetExtType>(ElemLayoutTy) && "expected target type");
Expand Down
41 changes: 38 additions & 3 deletions clang/test/CodeGenHLSL/cbuffer.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@
// CHECK: %C = type <{ i32, target("dx.Layout", %A, 8, 0) }>
// CHECK: %__cblayout_D = type <{ [2 x [3 x target("dx.Layout", %B, 14, 0, 8)]] }>

// CHECK: %__cblayout_CBClasses = type <{ target("dx.Layout", %K, 4, 0), target("dx.Layout", %L, 8, 0, 4),
// CHECK-SAME: target("dx.Layout", %M, 68, 0), [10 x target("dx.Layout", %K, 4, 0)] }>
// CHECK: %K = type <{ float }>
// CHECK: %L = type <{ float, float }>
// CHECK: %M = type <{ [5 x target("dx.Layout", %K, 4, 0)] }>

// CHECK: %__cblayout_CBMix = type <{ [2 x target("dx.Layout", %Test, 8, 0, 4)], float, [3 x [2 x <2 x float>]], float,
// CHECK-SAME: target("dx.Layout", %anon, 4, 0), double, target("dx.Layout", %anon.0, 8, 0), float, <1 x double>, i16 }>

Expand Down Expand Up @@ -133,6 +139,33 @@ cbuffer CBStructs {
uint16_t3 f;
};


class K {
float i;
};

class L : K {
float j;
};

class M {
K array[5];
};

cbuffer CBClasses {
K k;
L l;
M m;
K ka[10];
};

// CHECK: @CBClasses.cb = global target("dx.CBuffer", target("dx.Layout", %__cblayout_CBClasses,
// CHECK-SAME: 260, 0, 16, 32, 112))
// CHECK: @k = external addrspace(2) global target("dx.Layout", %K, 4, 0), align 4
// CHECK: @l = external addrspace(2) global target("dx.Layout", %L, 8, 0, 4), align 4
// CHECK: @m = external addrspace(2) global target("dx.Layout", %M, 68, 0), align 4
// CHECK: @ka = external addrspace(2) global [10 x target("dx.Layout", %K, 4, 0)], align 4

struct Test {
float a, b;
};
Expand Down Expand Up @@ -237,16 +270,16 @@ RWBuffer<float> Buf;

[numthreads(4,1,1)]
void main() {
Buf[0] = a1 + b1.z + c1[2] + a.f1.y + f1 + B1[0].x + B10.z + D1.B2;
Buf[0] = a1 + b1.z + c1[2] + a.f1.y + f1 + B1[0].x + ka[2].i + B10.z + D1.B2;
}

// CHECK: define internal void @_GLOBAL__sub_I_cbuffer.hlsl()
// CHECK-NEXT: entry:
// CHECK-NEXT: call void @_init_resource_CBScalars.cb()
// CHECK-NEXT: call void @_init_resource_CBArrays.cb()

// CHECK: !hlsl.cbs = !{![[CBSCALARS:[0-9]+]], ![[CBVECTORS:[0-9]+]], ![[CBARRAYS:[0-9]+]], ![[CBSTRUCTS:[0-9]+]], ![[CBMIX:[0-9]+]],
// CHECK-SAME: ![[CB_A:[0-9]+]], ![[CB_B:[0-9]+]], ![[CB_C:[0-9]+]]}
// CHECK: !hlsl.cbs = !{![[CBSCALARS:[0-9]+]], ![[CBVECTORS:[0-9]+]], ![[CBARRAYS:[0-9]+]], ![[CBSTRUCTS:[0-9]+]], ![[CBCLASSES:[0-9]+]],
// CHECK-SAME: ![[CBMIX:[0-9]+]], ![[CB_A:[0-9]+]], ![[CB_B:[0-9]+]], ![[CB_C:[0-9]+]]}

// CHECK: ![[CBSCALARS]] = !{ptr @CBScalars.cb, ptr addrspace(2) @a1, ptr addrspace(2) @a2, ptr addrspace(2) @a3, ptr addrspace(2) @a4,
// CHECK-SAME: ptr addrspace(2) @a5, ptr addrspace(2) @a6, ptr addrspace(2) @a7, ptr addrspace(2) @a8}
Expand All @@ -260,6 +293,8 @@ void main() {
// CHECK: ![[CBSTRUCTS]] = !{ptr @CBStructs.cb, ptr addrspace(2) @a, ptr addrspace(2) @b, ptr addrspace(2) @c, ptr addrspace(2) @array_of_A,
// CHECK-SAME: ptr addrspace(2) @d, ptr addrspace(2) @e, ptr addrspace(2) @f}

// CHECK: ![[CBCLASSES]] = !{ptr @CBClasses.cb, ptr addrspace(2) @k, ptr addrspace(2) @l, ptr addrspace(2) @m, ptr addrspace(2) @ka}

// CHECK: ![[CBMIX]] = !{ptr @CBMix.cb, ptr addrspace(2) @test, ptr addrspace(2) @f1, ptr addrspace(2) @f2, ptr addrspace(2) @f3,
// CHECK-SAME: ptr addrspace(2) @f4, ptr addrspace(2) @f5, ptr addrspace(2) @f6, ptr addrspace(2) @f7, ptr addrspace(2) @f8, ptr addrspace(2) @f9}

Expand Down