-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[HLSL] Fix codegen to support classes in cbuffer
#132828
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-clang-codegen @llvm/pr-subscribers-hlsl Author: Helena Kotas (hekota) ChangesFixes #132309 Full diff: https://github.com/llvm/llvm-project/pull/132828.diff 2 Files Affected:
diff --git a/clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp b/clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp
index e0f5b0f59ef40..b546b6dd574ff 100644
--- a/clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp
+++ b/clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp
@@ -52,11 +52,11 @@ static unsigned getScalarOrVectorSizeInBytes(llvm::Type *Ty) {
namespace clang {
namespace CodeGen {
-// Creates a layout type for given struct with HLSL constant buffer layout
-// taking into account PackOffsets, if provided.
+// Creates a layout type for given struct or class with HLSL constant buffer
+// layout taking into account PackOffsets, if provided.
// Previously created layout types are cached by CGHLSLRuntime.
//
-// The function iterates over all fields of the StructType (including base
+// The function iterates over all fields of the record type (including base
// classes) and calls layoutField to converts each field to its corresponding
// LLVM type and to calculate its HLSL constant buffer layout. Any embedded
// structs (or arrays of structs) are converted to target layout types as well.
@@ -67,12 +67,11 @@ namespace CodeGen {
// -1 value instead. These elements must be placed at the end of the layout
// after all of the elements with specific offset.
llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
- const RecordType *StructType,
- const llvm::SmallVector<int32_t> *PackOffsets) {
+ const RecordType *RT, const llvm::SmallVector<int32_t> *PackOffsets) {
// check if we already have the layout type for this struct
if (llvm::TargetExtType *Ty =
- CGM.getHLSLRuntime().getHLSLBufferLayoutType(StructType))
+ CGM.getHLSLRuntime().getHLSLBufferLayoutType(RT))
return Ty;
SmallVector<unsigned> Layout;
@@ -87,7 +86,7 @@ llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
// iterate over all fields of the record, including fields on base classes
llvm::SmallVector<const RecordType *> RecordTypes;
- RecordTypes.push_back(StructType);
+ RecordTypes.push_back(RT);
while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
assert(D->getNumBases() == 1 &&
@@ -148,7 +147,7 @@ llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
// create the layout struct type; anonymous struct have empty name but
// non-empty qualified name
- const CXXRecordDecl *Decl = StructType->getAsCXXRecordDecl();
+ const CXXRecordDecl *Decl = RT->getAsCXXRecordDecl();
std::string Name =
Decl->getName().empty() ? "anon" : Decl->getQualifiedNameAsString();
llvm::StructType *StructTy =
@@ -158,7 +157,7 @@ llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
llvm::TargetExtType *NewLayoutTy = llvm::TargetExtType::get(
CGM.getLLVMContext(), LayoutTypeName, {StructTy}, Layout);
if (NewLayoutTy)
- CGM.getHLSLRuntime().addHLSLBufferLayoutType(StructType, NewLayoutTy);
+ CGM.getHLSLRuntime().addHLSLBufferLayoutType(RT, NewLayoutTy);
return NewLayoutTy;
}
@@ -202,9 +201,9 @@ bool HLSLBufferLayoutBuilder::layoutField(const FieldDecl *FD,
}
// For array of structures, create a new array with a layout type
// instead of the structure type.
- if (Ty->isStructureType()) {
+ if (Ty->isStructureOrClassType()) {
llvm::Type *NewTy =
- cast<llvm::TargetExtType>(createLayoutType(Ty->getAsStructureType()));
+ cast<llvm::TargetExtType>(createLayoutType(Ty->getAs<RecordType>()));
if (!NewTy)
return false;
assert(isa<llvm::TargetExtType>(NewTy) && "expected target type");
@@ -220,9 +219,10 @@ bool HLSLBufferLayoutBuilder::layoutField(const FieldDecl *FD,
ArrayStride = llvm::alignTo(ElemSize, CBufferRowSizeInBytes);
ElemOffset = (Packoffset != -1) ? Packoffset : NextRowOffset;
- } else if (FieldTy->isStructureType()) {
+ } else if (FieldTy->isStructureOrClassType()) {
// Create a layout type for the structure
- ElemLayoutTy = createLayoutType(FieldTy->getAsStructureType());
+ ElemLayoutTy =
+ createLayoutType(cast<RecordType>(FieldTy->getAs<RecordType>()));
if (!ElemLayoutTy)
return false;
assert(isa<llvm::TargetExtType>(ElemLayoutTy) && "expected target type");
diff --git a/clang/test/CodeGenHLSL/cbuffer.hlsl b/clang/test/CodeGenHLSL/cbuffer.hlsl
index 98948ea6811e3..db06cea808b62 100644
--- a/clang/test/CodeGenHLSL/cbuffer.hlsl
+++ b/clang/test/CodeGenHLSL/cbuffer.hlsl
@@ -13,6 +13,12 @@
// CHECK: %C = type <{ i32, target("dx.Layout", %A, 8, 0) }>
// CHECK: %__cblayout_D = type <{ [2 x [3 x target("dx.Layout", %B, 14, 0, 8)]] }>
+// CHECK: %__cblayout_CBClasses = type <{ target("dx.Layout", %K, 4, 0), target("dx.Layout", %L, 8, 0, 4),
+// CHECK-SAME: target("dx.Layout", %M, 68, 0), [10 x target("dx.Layout", %K, 4, 0)] }>
+// CHECK: %K = type <{ float }>
+// CHECK: %L = type <{ float, float }>
+// CHECK: %M = type <{ [5 x target("dx.Layout", %K, 4, 0)] }>
+
// CHECK: %__cblayout_CBMix = type <{ [2 x target("dx.Layout", %Test, 8, 0, 4)], float, [3 x [2 x <2 x float>]], float,
// CHECK-SAME: target("dx.Layout", %anon, 4, 0), double, target("dx.Layout", %anon.0, 8, 0), float, <1 x double>, i16 }>
@@ -133,6 +139,33 @@ cbuffer CBStructs {
uint16_t3 f;
};
+
+class K {
+ float i;
+};
+
+class L : K {
+ float j;
+};
+
+class M {
+ K array[5];
+};
+
+cbuffer CBClasses {
+ K k;
+ L l;
+ M m;
+ K ka[10];
+};
+
+// CHECK: @CBClasses.cb = global target("dx.CBuffer", target("dx.Layout", %__cblayout_CBClasses,
+// CHECK-SAME: 260, 0, 16, 32, 112))
+// CHECK: @k = external addrspace(2) global target("dx.Layout", %K, 4, 0), align 4
+// CHECK: @l = external addrspace(2) global target("dx.Layout", %L, 8, 0, 4), align 4
+// CHECK: @m = external addrspace(2) global target("dx.Layout", %M, 68, 0), align 4
+// CHECK: @ka = external addrspace(2) global [10 x target("dx.Layout", %K, 4, 0)], align 4
+
struct Test {
float a, b;
};
@@ -237,7 +270,7 @@ RWBuffer<float> Buf;
[numthreads(4,1,1)]
void main() {
- Buf[0] = a1 + b1.z + c1[2] + a.f1.y + f1 + B1[0].x + B10.z + D1.B2;
+ Buf[0] = a1 + b1.z + c1[2] + a.f1.y + f1 + B1[0].x + ka[2].i + B10.z + D1.B2;
}
// CHECK: define internal void @_GLOBAL__sub_I_cbuffer.hlsl()
@@ -245,8 +278,8 @@ void main() {
// CHECK-NEXT: call void @_init_resource_CBScalars.cb()
// CHECK-NEXT: call void @_init_resource_CBArrays.cb()
-// CHECK: !hlsl.cbs = !{![[CBSCALARS:[0-9]+]], ![[CBVECTORS:[0-9]+]], ![[CBARRAYS:[0-9]+]], ![[CBSTRUCTS:[0-9]+]], ![[CBMIX:[0-9]+]],
-// CHECK-SAME: ![[CB_A:[0-9]+]], ![[CB_B:[0-9]+]], ![[CB_C:[0-9]+]]}
+// CHECK: !hlsl.cbs = !{![[CBSCALARS:[0-9]+]], ![[CBVECTORS:[0-9]+]], ![[CBARRAYS:[0-9]+]], ![[CBSTRUCTS:[0-9]+]], ![[CBCLASSES:[0-9]+]],
+// CHECK-SAME: ![[CBMIX:[0-9]+]], ![[CB_A:[0-9]+]], ![[CB_B:[0-9]+]], ![[CB_C:[0-9]+]]}
// CHECK: ![[CBSCALARS]] = !{ptr @CBScalars.cb, ptr addrspace(2) @a1, ptr addrspace(2) @a2, ptr addrspace(2) @a3, ptr addrspace(2) @a4,
// CHECK-SAME: ptr addrspace(2) @a5, ptr addrspace(2) @a6, ptr addrspace(2) @a7, ptr addrspace(2) @a8}
@@ -260,6 +293,8 @@ void main() {
// CHECK: ![[CBSTRUCTS]] = !{ptr @CBStructs.cb, ptr addrspace(2) @a, ptr addrspace(2) @b, ptr addrspace(2) @c, ptr addrspace(2) @array_of_A,
// CHECK-SAME: ptr addrspace(2) @d, ptr addrspace(2) @e, ptr addrspace(2) @f}
+// CHECK: ![[CBCLASSES]] = !{ptr @CBClasses.cb, ptr addrspace(2) @k, ptr addrspace(2) @l, ptr addrspace(2) @m, ptr addrspace(2) @ka}
+
// CHECK: ![[CBMIX]] = !{ptr @CBMix.cb, ptr addrspace(2) @test, ptr addrspace(2) @f1, ptr addrspace(2) @f2, ptr addrspace(2) @f3,
// CHECK-SAME: ptr addrspace(2) @f4, ptr addrspace(2) @f5, ptr addrspace(2) @f6, ptr addrspace(2) @f7, ptr addrspace(2) @f8, ptr addrspace(2) @f9}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, but probably need someone with more area expertise than me to approve it.
Fixes #132309