@@ -7340,14 +7340,13 @@ struct register_binding_flags {
7340
7340
bool other = false ;
7341
7341
bool basic = false ;
7342
7342
7343
- bool srv;
7344
- bool uav;
7345
- bool cbv;
7346
- bool sampler;
7343
+ bool srv = false ;
7344
+ bool uav = false ;
7345
+ bool cbv = false ;
7346
+ bool sampler = false ;
7347
7347
7348
7348
bool contains_numeric = false ;
7349
7349
bool default_globals = false ;
7350
- bool is_member = false ;
7351
7350
};
7352
7351
7353
7352
bool isDeclaredWithinCOrTBuffer (const Decl *decl) {
@@ -7394,6 +7393,63 @@ const HLSLResourceAttr *getHLSLResourceAttrFromVarDecl(VarDecl *SamplerUAVOrSRV)
7394
7393
return Attr;
7395
7394
}
7396
7395
7396
+ void traverseType (QualType T, register_binding_flags &r) {
7397
+ if (T->isIntegralOrEnumerationType () || T->isFloatingType ()) {
7398
+ r.contains_numeric = true ;
7399
+ return ;
7400
+ } else if (const RecordType *RT = T->getAs <RecordType>()) {
7401
+ RecordDecl *SubRD = RT->getDecl ();
7402
+ if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(SubRD)) {
7403
+ auto TheRecordDecl = TDecl->getSpecializedTemplate ()->getTemplatedDecl ();
7404
+ TheRecordDecl = TheRecordDecl->getCanonicalDecl ();
7405
+ const auto *Attr = TheRecordDecl->getAttr <HLSLResourceAttr>();
7406
+ llvm::hlsl::ResourceClass DeclResourceClass = Attr->getResourceClass ();
7407
+ switch (DeclResourceClass) {
7408
+ case llvm::hlsl::ResourceClass::SRV: {
7409
+ r.srv = true ;
7410
+ break ;
7411
+ }
7412
+ case llvm::hlsl::ResourceClass::UAV: {
7413
+ r.uav = true ;
7414
+ break ;
7415
+ }
7416
+ case llvm::hlsl::ResourceClass::CBuffer: {
7417
+ r.cbv = true ;
7418
+ break ;
7419
+ }
7420
+ case llvm::hlsl::ResourceClass::Sampler: {
7421
+ r.sampler = true ;
7422
+ break ;
7423
+ }
7424
+ case llvm::hlsl::ResourceClass::Invalid: {
7425
+ llvm_unreachable (" Resource class should be valid." );
7426
+ break ;
7427
+ }
7428
+ }
7429
+ }
7430
+
7431
+ else if (SubRD->isCompleteDefinition ()) {
7432
+ for (auto Field : SubRD->fields ()) {
7433
+ QualType T = Field->getType ();
7434
+ traverseType (T, r);
7435
+ }
7436
+ }
7437
+ }
7438
+ }
7439
+
7440
+ void setResourceClassFlagsFromRecordDecl (register_binding_flags &r,
7441
+ const RecordDecl *RD) {
7442
+ if (!RD)
7443
+ return ;
7444
+
7445
+ if (RD->isCompleteDefinition ()) {
7446
+ for (auto Field : RD->fields ()) {
7447
+ QualType T = Field->getType ();
7448
+ traverseType (T, r);
7449
+ }
7450
+ }
7451
+ }
7452
+
7397
7453
register_binding_flags HLSLFillRegisterBindingFlags (Sema &S, Decl *D) {
7398
7454
register_binding_flags r;
7399
7455
if (!isDeclaredWithinCOrTBuffer (D)) {
@@ -7452,7 +7508,12 @@ register_binding_flags HLSLFillRegisterBindingFlags(Sema &S, Decl *D) {
7452
7508
r.basic = true ;
7453
7509
else if (SamplerUAVOrSRV->getType ()->isAggregateType ()) {
7454
7510
r.udt = true ;
7455
- // recurse through members, set appropriate resource class flags.
7511
+ QualType VarType = SamplerUAVOrSRV->getType ();
7512
+ if (const RecordType *RT = VarType->getAs <RecordType>()) {
7513
+ const RecordDecl *RD = RT->getDecl ();
7514
+ // recurse through members, set appropriate resource class flags.
7515
+ setResourceClassFlagsFromRecordDecl (r, RD);
7516
+ }
7456
7517
}
7457
7518
else
7458
7519
r.other = true ;
@@ -7479,6 +7540,9 @@ static void DiagnoseHLSLResourceRegType(Sema &S, SourceLocation &ArgLoc,
7479
7540
7480
7541
static void handleHLSLResourceBindingAttr (Sema &S, Decl *D,
7481
7542
const ParsedAttr &AL) {
7543
+ if (S.RequireCompleteType (D->getBeginLoc (), cast<ValueDecl>(D)->getType (),
7544
+ diag::err_incomplete_type))
7545
+ return ;
7482
7546
StringRef Space = " space0" ;
7483
7547
StringRef Slot = " " ;
7484
7548
0 commit comments