@@ -7334,89 +7334,147 @@ static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
7334
7334
D->addAttr (NewAttr);
7335
7335
}
7336
7336
7337
- static void DiagnoseHLSLResourceRegType (Sema &S, SourceLocation &ArgLoc,
7338
- Decl *D, StringRef &Slot) {
7339
- // Samplers, UAVs, and SRVs are VarDecl types
7340
- VarDecl *SamplerUAVOrSRV = dyn_cast<VarDecl>(D) ;
7341
- // Cbuffers and Tbuffers are HLSLBufferDecl types
7342
- HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
7343
- if (!SamplerUAVOrSRV && !CBufferOrTBuffer)
7344
- return ;
7345
-
7346
- llvm::hlsl::ResourceClass DeclResourceClass ;
7347
- StringRef VarTy = " " ;
7348
- if (SamplerUAVOrSRV) {
7349
- const Type *Ty = SamplerUAVOrSRV-> getType ()-> getPointeeOrArrayElementType () ;
7350
- if (!Ty)
7351
- llvm_unreachable ( " Resource class must have an element type. " ) ;
7337
+ struct register_binding_flags {
7338
+ bool resource = false ;
7339
+ bool udt = false ;
7340
+ bool other = false ;
7341
+ bool basic = false ;
7342
+
7343
+ bool srv;
7344
+ bool uav ;
7345
+ bool cbv;
7346
+ bool sampler ;
7347
+
7348
+ bool contains_numeric = false ;
7349
+ bool default_globals = false ;
7350
+ bool is_member = false ;
7351
+ } ;
7352
7352
7353
- if (const BuiltinType *BTy = dyn_cast<BuiltinType>(Ty)) {
7354
- QualType QT = SamplerUAVOrSRV->getType ();
7355
- PrintingPolicy PP = S.getPrintingPolicy ();
7356
- std::string typestr = QualType::getAsString (QT.split (), PP);
7353
+ bool isDeclaredWithinCOrTBuffer (const Decl *decl) {
7354
+ if (!decl)
7355
+ return false ;
7357
7356
7358
- S.Diag (ArgLoc, diag::err_hlsl_unsupported_register_resource_type)
7359
- << typestr;
7360
- return ;
7357
+ // Traverse up the parent contexts
7358
+ const DeclContext *context = decl->getDeclContext ();
7359
+ while (context) {
7360
+ if (isa<HLSLBufferDecl>(context)) {
7361
+ return true ;
7361
7362
}
7363
+ context = context->getParent ();
7364
+ }
7362
7365
7363
- const CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl ();
7364
- if (!TheRecordDecl)
7365
- llvm_unreachable (
7366
- " Resource class should have a resource type declaration." );
7366
+ return false ;
7367
+ }
7367
7368
7368
- if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl))
7369
- TheRecordDecl = TDecl->getSpecializedTemplate ()->getTemplatedDecl ();
7370
- TheRecordDecl = TheRecordDecl->getCanonicalDecl ();
7371
- const auto *Attr = TheRecordDecl->getAttr <HLSLResourceAttr>();
7372
- if (!Attr)
7373
- llvm_unreachable (" Resource class should have a resource attribute." );
7374
7369
7375
- DeclResourceClass = Attr->getResourceClass ();
7376
- VarTy = TheRecordDecl->getName ();
7377
- } else {
7378
- HLSLResourceAttr *Attr = CBufferOrTBuffer->getAttr <HLSLResourceAttr>();
7379
- DeclResourceClass = Attr->getResourceClass ();
7380
- if (CBufferOrTBuffer->isCBuffer ()) {
7381
- VarTy = " cbuffer" ;
7382
- } else {
7383
- VarTy = " tbuffer" ;
7384
- }
7385
- }
7386
- switch (DeclResourceClass) {
7387
- case llvm::hlsl::ResourceClass::SRV: {
7388
- if (Slot[0 ] == ' t' )
7389
- return ;
7390
- break ;
7391
- }
7392
- case llvm::hlsl::ResourceClass::UAV: {
7393
- if (Slot[0 ] == ' u' )
7394
- return ;
7395
- break ;
7370
+ const HLSLResourceAttr *getHLSLResourceAttrFromVarDecl (VarDecl *SamplerUAVOrSRV) {
7371
+ const Type *Ty = SamplerUAVOrSRV->getType ()->getPointeeOrArrayElementType ();
7372
+ if (!Ty)
7373
+ llvm_unreachable (" Resource class must have an element type." );
7374
+
7375
+ if (const BuiltinType *BTy = dyn_cast<BuiltinType>(Ty)) {
7376
+ /* QualType QT = SamplerUAVOrSRV->getType();
7377
+ PrintingPolicy PP = S.getPrintingPolicy();
7378
+ std::string typestr = QualType::getAsString(QT.split(), PP);
7379
+
7380
+ S.Diag(ArgLoc, diag::err_hlsl_unsupported_register_resource_type)
7381
+ << typestr;
7382
+ return; */
7383
+ return nullptr ;
7396
7384
}
7397
- case llvm::hlsl::ResourceClass::CBuffer: {
7398
- // could be CBuffer or TBuffer
7399
- if (VarTy == " cbuffer" ) {
7400
- if (Slot[0 ] == ' b' )
7401
- return ;
7402
- } else if (VarTy == " tbuffer" ) {
7403
- if (Slot[0 ] == ' t' )
7404
- return ;
7385
+
7386
+ const CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl ();
7387
+ if (!TheRecordDecl)
7388
+ llvm_unreachable (" Resource class should have a resource type declaration." );
7389
+
7390
+ if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl))
7391
+ TheRecordDecl = TDecl->getSpecializedTemplate ()->getTemplatedDecl ();
7392
+ TheRecordDecl = TheRecordDecl->getCanonicalDecl ();
7393
+ const auto *Attr = TheRecordDecl->getAttr <HLSLResourceAttr>();
7394
+ return Attr;
7395
+ }
7396
+
7397
+ register_binding_flags HLSLFillRegisterBindingFlags (Sema &S, Decl *D) {
7398
+ register_binding_flags r;
7399
+ if (!isDeclaredWithinCOrTBuffer (D)) {
7400
+ // make sure the type is a basic / numeric type
7401
+ if (VarDecl *v = dyn_cast<VarDecl>(D)) {
7402
+ QualType t = v->getType ();
7403
+ // a numeric variable will inevitably end up in $Globals buffer
7404
+ if (t->isIntegralType (S.getASTContext ()) || t->isFloatingType ())
7405
+ r.default_globals = true ;
7405
7406
}
7406
- break ;
7407
- }
7408
- case llvm::hlsl::ResourceClass::Sampler: {
7409
- if (Slot[0 ] == ' s' )
7410
- return ;
7411
- break ;
7412
7407
}
7413
- case llvm::hlsl::ResourceClass::Invalid: {
7414
- llvm_unreachable (" Resource class should be valid." );
7415
- break ;
7408
+ // Cbuffers and Tbuffers are HLSLBufferDecl types
7409
+ HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
7410
+ // Samplers, UAVs, and SRVs are VarDecl types
7411
+ VarDecl *SamplerUAVOrSRV = dyn_cast<VarDecl>(D);
7412
+
7413
+ if (CBufferOrTBuffer) {
7414
+ r.resource = true ;
7415
+ if (CBufferOrTBuffer->isCBuffer ())
7416
+ r.cbv = true ;
7417
+ else
7418
+ r.srv = true ;
7419
+ }
7420
+ else if (SamplerUAVOrSRV) {
7421
+ const HLSLResourceAttr *res_attr =
7422
+ getHLSLResourceAttrFromVarDecl (SamplerUAVOrSRV);
7423
+ if (res_attr) {
7424
+ llvm::hlsl::ResourceClass DeclResourceClass =
7425
+ res_attr->getResourceClass ();
7426
+ r.resource = true ;
7427
+ switch (DeclResourceClass) {
7428
+ case llvm::hlsl::ResourceClass::SRV: {
7429
+ r.srv = true ;
7430
+ break ;
7431
+ }
7432
+ case llvm::hlsl::ResourceClass::UAV: {
7433
+ r.uav = true ;
7434
+ break ;
7435
+ }
7436
+ case llvm::hlsl::ResourceClass::CBuffer: {
7437
+ r.cbv = true ;
7438
+ break ;
7439
+ }
7440
+ case llvm::hlsl::ResourceClass::Sampler: {
7441
+ r.sampler = true ;
7442
+ break ;
7443
+ }
7444
+ case llvm::hlsl::ResourceClass::Invalid: {
7445
+ llvm_unreachable (" Resource class should be valid." );
7446
+ break ;
7447
+ }
7448
+ }
7449
+ }
7450
+ else {
7451
+ if (SamplerUAVOrSRV->getType ()->isBuiltinType ())
7452
+ r.basic = true ;
7453
+ else if (SamplerUAVOrSRV->getType ()->isAggregateType ()) {
7454
+ r.udt = true ;
7455
+ // recurse through members, set appropriate resource class flags.
7456
+ }
7457
+ else
7458
+ r.other = true ;
7459
+ }
7416
7460
}
7461
+ else {
7462
+ llvm_unreachable (" unknown decl type" );
7417
7463
}
7418
- S.Diag (ArgLoc, diag::err_hlsl_mismatching_register_resource_type_and_name)
7419
- << Slot.substr (0 , 1 ) << VarTy << (unsigned )DeclResourceClass;
7464
+ return r;
7465
+ }
7466
+
7467
+
7468
+ static void DiagnoseHLSLResourceRegType (Sema &S, SourceLocation &ArgLoc,
7469
+ Decl *D, StringRef &Slot) {
7470
+
7471
+ register_binding_flags f = HLSLFillRegisterBindingFlags (S, D);
7472
+ // Samplers, UAVs, and SRVs are VarDecl types
7473
+ VarDecl *SamplerUAVOrSRV = dyn_cast<VarDecl>(D);
7474
+ // Cbuffers and Tbuffers are HLSLBufferDecl types
7475
+ HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
7476
+ if (!SamplerUAVOrSRV && !CBufferOrTBuffer)
7477
+ return ;
7420
7478
}
7421
7479
7422
7480
static void handleHLSLResourceBindingAttr (Sema &S, Decl *D,
@@ -7453,10 +7511,12 @@ static void handleHLSLResourceBindingAttr(Sema &S, Decl *D,
7453
7511
// Validate.
7454
7512
if (!Slot.empty ()) {
7455
7513
switch (Slot[0 ]) {
7514
+ case ' t' :
7456
7515
case ' u' :
7457
7516
case ' b' :
7458
- case ' s' :
7459
- case ' t' :
7517
+ case ' s' :
7518
+ case ' c' :
7519
+ case ' i' :
7460
7520
break ;
7461
7521
default :
7462
7522
S.Diag (ArgLoc, diag::err_hlsl_unsupported_register_prefix)
0 commit comments