Skip to content

Commit 6888ea8

Browse files
committed
write code for flag set step, except for recursive udt step
1 parent 8baae27 commit 6888ea8

File tree

1 file changed

+134
-74
lines changed

1 file changed

+134
-74
lines changed

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 134 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -7334,89 +7334,147 @@ static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
73347334
D->addAttr(NewAttr);
73357335
}
73367336

7337-
static void DiagnoseHLSLResourceRegType(Sema &S, SourceLocation &ArgLoc,
7338-
Decl *D, StringRef &Slot) {
7339-
// Samplers, UAVs, and SRVs are VarDecl types
7340-
VarDecl *SamplerUAVOrSRV = dyn_cast<VarDecl>(D);
7341-
// Cbuffers and Tbuffers are HLSLBufferDecl types
7342-
HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
7343-
if (!SamplerUAVOrSRV && !CBufferOrTBuffer)
7344-
return;
7345-
7346-
llvm::hlsl::ResourceClass DeclResourceClass;
7347-
StringRef VarTy = "";
7348-
if (SamplerUAVOrSRV) {
7349-
const Type *Ty = SamplerUAVOrSRV->getType()->getPointeeOrArrayElementType();
7350-
if (!Ty)
7351-
llvm_unreachable("Resource class must have an element type.");
7337+
struct register_binding_flags {
7338+
bool resource = false;
7339+
bool udt = false;
7340+
bool other = false;
7341+
bool basic = false;
7342+
7343+
bool srv;
7344+
bool uav;
7345+
bool cbv;
7346+
bool sampler;
7347+
7348+
bool contains_numeric = false;
7349+
bool default_globals = false;
7350+
bool is_member = false;
7351+
};
73527352

7353-
if (const BuiltinType *BTy = dyn_cast<BuiltinType>(Ty)) {
7354-
QualType QT = SamplerUAVOrSRV->getType();
7355-
PrintingPolicy PP = S.getPrintingPolicy();
7356-
std::string typestr = QualType::getAsString(QT.split(), PP);
7353+
bool isDeclaredWithinCOrTBuffer(const Decl *decl) {
7354+
if (!decl)
7355+
return false;
73577356

7358-
S.Diag(ArgLoc, diag::err_hlsl_unsupported_register_resource_type)
7359-
<< typestr;
7360-
return;
7357+
// Traverse up the parent contexts
7358+
const DeclContext *context = decl->getDeclContext();
7359+
while (context) {
7360+
if (isa<HLSLBufferDecl>(context)) {
7361+
return true;
73617362
}
7363+
context = context->getParent();
7364+
}
73627365

7363-
const CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
7364-
if (!TheRecordDecl)
7365-
llvm_unreachable(
7366-
"Resource class should have a resource type declaration.");
7366+
return false;
7367+
}
73677368

7368-
if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl))
7369-
TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl();
7370-
TheRecordDecl = TheRecordDecl->getCanonicalDecl();
7371-
const auto *Attr = TheRecordDecl->getAttr<HLSLResourceAttr>();
7372-
if (!Attr)
7373-
llvm_unreachable("Resource class should have a resource attribute.");
73747369

7375-
DeclResourceClass = Attr->getResourceClass();
7376-
VarTy = TheRecordDecl->getName();
7377-
} else {
7378-
HLSLResourceAttr *Attr = CBufferOrTBuffer->getAttr<HLSLResourceAttr>();
7379-
DeclResourceClass = Attr->getResourceClass();
7380-
if (CBufferOrTBuffer->isCBuffer()) {
7381-
VarTy = "cbuffer";
7382-
} else {
7383-
VarTy = "tbuffer";
7384-
}
7385-
}
7386-
switch (DeclResourceClass) {
7387-
case llvm::hlsl::ResourceClass::SRV: {
7388-
if (Slot[0] == 't')
7389-
return;
7390-
break;
7391-
}
7392-
case llvm::hlsl::ResourceClass::UAV: {
7393-
if (Slot[0] == 'u')
7394-
return;
7395-
break;
7370+
const HLSLResourceAttr *getHLSLResourceAttrFromVarDecl(VarDecl *SamplerUAVOrSRV) {
7371+
const Type *Ty = SamplerUAVOrSRV->getType()->getPointeeOrArrayElementType();
7372+
if (!Ty)
7373+
llvm_unreachable("Resource class must have an element type.");
7374+
7375+
if (const BuiltinType *BTy = dyn_cast<BuiltinType>(Ty)) {
7376+
/* QualType QT = SamplerUAVOrSRV->getType();
7377+
PrintingPolicy PP = S.getPrintingPolicy();
7378+
std::string typestr = QualType::getAsString(QT.split(), PP);
7379+
7380+
S.Diag(ArgLoc, diag::err_hlsl_unsupported_register_resource_type)
7381+
<< typestr;
7382+
return; */
7383+
return nullptr;
73967384
}
7397-
case llvm::hlsl::ResourceClass::CBuffer: {
7398-
// could be CBuffer or TBuffer
7399-
if (VarTy == "cbuffer") {
7400-
if (Slot[0] == 'b')
7401-
return;
7402-
} else if (VarTy == "tbuffer") {
7403-
if (Slot[0] == 't')
7404-
return;
7385+
7386+
const CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
7387+
if (!TheRecordDecl)
7388+
llvm_unreachable("Resource class should have a resource type declaration.");
7389+
7390+
if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl))
7391+
TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl();
7392+
TheRecordDecl = TheRecordDecl->getCanonicalDecl();
7393+
const auto *Attr = TheRecordDecl->getAttr<HLSLResourceAttr>();
7394+
return Attr;
7395+
}
7396+
7397+
register_binding_flags HLSLFillRegisterBindingFlags(Sema &S, Decl *D) {
7398+
register_binding_flags r;
7399+
if (!isDeclaredWithinCOrTBuffer(D)) {
7400+
// make sure the type is a basic / numeric type
7401+
if (VarDecl *v = dyn_cast<VarDecl>(D)) {
7402+
QualType t = v->getType();
7403+
// a numeric variable will inevitably end up in $Globals buffer
7404+
if (t->isIntegralType(S.getASTContext()) || t->isFloatingType())
7405+
r.default_globals = true;
74057406
}
7406-
break;
7407-
}
7408-
case llvm::hlsl::ResourceClass::Sampler: {
7409-
if (Slot[0] == 's')
7410-
return;
7411-
break;
74127407
}
7413-
case llvm::hlsl::ResourceClass::Invalid: {
7414-
llvm_unreachable("Resource class should be valid.");
7415-
break;
7408+
// Cbuffers and Tbuffers are HLSLBufferDecl types
7409+
HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
7410+
// Samplers, UAVs, and SRVs are VarDecl types
7411+
VarDecl *SamplerUAVOrSRV = dyn_cast<VarDecl>(D);
7412+
7413+
if (CBufferOrTBuffer) {
7414+
r.resource = true;
7415+
if (CBufferOrTBuffer->isCBuffer())
7416+
r.cbv = true;
7417+
else
7418+
r.srv = true;
7419+
}
7420+
else if (SamplerUAVOrSRV) {
7421+
const HLSLResourceAttr *res_attr =
7422+
getHLSLResourceAttrFromVarDecl(SamplerUAVOrSRV);
7423+
if (res_attr) {
7424+
llvm::hlsl::ResourceClass DeclResourceClass =
7425+
res_attr->getResourceClass();
7426+
r.resource = true;
7427+
switch (DeclResourceClass) {
7428+
case llvm::hlsl::ResourceClass::SRV: {
7429+
r.srv = true;
7430+
break;
7431+
}
7432+
case llvm::hlsl::ResourceClass::UAV: {
7433+
r.uav = true;
7434+
break;
7435+
}
7436+
case llvm::hlsl::ResourceClass::CBuffer: {
7437+
r.cbv = true;
7438+
break;
7439+
}
7440+
case llvm::hlsl::ResourceClass::Sampler: {
7441+
r.sampler = true;
7442+
break;
7443+
}
7444+
case llvm::hlsl::ResourceClass::Invalid: {
7445+
llvm_unreachable("Resource class should be valid.");
7446+
break;
7447+
}
7448+
}
7449+
}
7450+
else {
7451+
if (SamplerUAVOrSRV->getType()->isBuiltinType())
7452+
r.basic = true;
7453+
else if (SamplerUAVOrSRV->getType()->isAggregateType()) {
7454+
r.udt = true;
7455+
// recurse through members, set appropriate resource class flags.
7456+
}
7457+
else
7458+
r.other = true;
7459+
}
74167460
}
7461+
else {
7462+
llvm_unreachable("unknown decl type");
74177463
}
7418-
S.Diag(ArgLoc, diag::err_hlsl_mismatching_register_resource_type_and_name)
7419-
<< Slot.substr(0, 1) << VarTy << (unsigned)DeclResourceClass;
7464+
return r;
7465+
}
7466+
7467+
7468+
static void DiagnoseHLSLResourceRegType(Sema &S, SourceLocation &ArgLoc,
7469+
Decl *D, StringRef &Slot) {
7470+
7471+
register_binding_flags f = HLSLFillRegisterBindingFlags(S, D);
7472+
// Samplers, UAVs, and SRVs are VarDecl types
7473+
VarDecl *SamplerUAVOrSRV = dyn_cast<VarDecl>(D);
7474+
// Cbuffers and Tbuffers are HLSLBufferDecl types
7475+
HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
7476+
if (!SamplerUAVOrSRV && !CBufferOrTBuffer)
7477+
return;
74207478
}
74217479

74227480
static void handleHLSLResourceBindingAttr(Sema &S, Decl *D,
@@ -7453,10 +7511,12 @@ static void handleHLSLResourceBindingAttr(Sema &S, Decl *D,
74537511
// Validate.
74547512
if (!Slot.empty()) {
74557513
switch (Slot[0]) {
7514+
case 't':
74567515
case 'u':
74577516
case 'b':
7458-
case 's':
7459-
case 't':
7517+
case 's':
7518+
case 'c':
7519+
case 'i':
74607520
break;
74617521
default:
74627522
S.Diag(ArgLoc, diag::err_hlsl_unsupported_register_prefix)

0 commit comments

Comments
 (0)