Skip to content

[HLSL] Adjust resource binding diagnostic flags code #106657

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 64 additions & 112 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,9 @@ struct RegisterBindingFlags {

bool ContainsNumeric = false;
bool DefaultGlobals = false;

// used only when Resource == true
std::optional<llvm::dxil::ResourceClass> ResourceClass;
};

static bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) {
Expand Down Expand Up @@ -545,65 +548,38 @@ static const T *getSpecifiedHLSLAttrFromVarDecl(VarDecl *VD) {
return getSpecifiedHLSLAttrFromRecordDecl<T>(TheRecordDecl);
}

static void updateFlagsFromType(QualType TheQualTy,
RegisterBindingFlags &Flags);

static void updateResourceClassFlagsFromRecordDecl(RegisterBindingFlags &Flags,
const RecordDecl *RD) {
if (!RD)
return;

if (RD->isCompleteDefinition()) {
for (auto Field : RD->fields()) {
QualType T = Field->getType();
updateFlagsFromType(T, Flags);
static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags,
const RecordType *RT) {
llvm::SmallVector<const Type *> TypesToScan;
TypesToScan.emplace_back(RT);

while (!TypesToScan.empty()) {
const Type *T = TypesToScan.pop_back_val();
while (T->isArrayType())
T = T->getArrayElementTypeNoTypeQual();
if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
Flags.ContainsNumeric = true;
continue;
}
}
}

static void updateFlagsFromType(QualType TheQualTy,
RegisterBindingFlags &Flags) {
// if the member's type is a numeric type, set the ContainsNumeric flag
if (TheQualTy->isIntegralOrEnumerationType() || TheQualTy->isFloatingType()) {
Flags.ContainsNumeric = true;
return;
}

const clang::Type *TheBaseType = TheQualTy.getTypePtr();
while (TheBaseType->isArrayType())
TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
// otherwise, if the member's base type is not a record type, return
const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
if (!TheRecordTy)
return;

RecordDecl *SubRecordDecl = TheRecordTy->getDecl();
const HLSLResourceClassAttr *Attr =
getSpecifiedHLSLAttrFromRecordDecl<HLSLResourceClassAttr>(SubRecordDecl);
// find the attr if it's on the member, or on any of the member's fields
if (Attr) {
llvm::hlsl::ResourceClass DeclResourceClass = Attr->getResourceClass();
updateResourceClassFlagsFromDeclResourceClass(Flags, DeclResourceClass);
}
const RecordType *RT = T->getAs<RecordType>();
if (!RT)
continue;

// otherwise, dig deeper and recurse into the member
else {
updateResourceClassFlagsFromRecordDecl(Flags, SubRecordDecl);
const RecordDecl *RD = RT->getDecl();
for (FieldDecl *FD : RD->fields()) {
if (HLSLResourceClassAttr *RCAttr =
FD->getAttr<HLSLResourceClassAttr>()) {
updateResourceClassFlagsFromDeclResourceClass(
Flags, RCAttr->getResourceClass());
continue;
}
TypesToScan.emplace_back(FD->getType().getTypePtr());
}
}
}

static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
Decl *TheDecl) {

// Cbuffers and Tbuffers are HLSLBufferDecl types
HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl);
// Samplers, UAVs, and SRVs are VarDecl types
VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl);

assert(((TheVarDecl && !CBufferOrTBuffer) ||
(!TheVarDecl && CBufferOrTBuffer)) &&
"either TheVarDecl or CBufferOrTBuffer should be set");

RegisterBindingFlags Flags;

// check if the decl type is groupshared
Expand All @@ -612,58 +588,60 @@ static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
return Flags;
}

if (!isDeclaredWithinCOrTBuffer(TheDecl)) {
// make sure the type is a basic / numeric type
if (TheVarDecl) {
QualType TheQualTy = TheVarDecl->getType();
// a numeric variable or an array of numeric variables
// will inevitably end up in $Globals buffer
const clang::Type *TheBaseType = TheQualTy.getTypePtr();
while (TheBaseType->isArrayType())
TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
if (TheBaseType->isIntegralType(S.getASTContext()) ||
TheBaseType->isFloatingType())
Flags.DefaultGlobals = true;
}
}

if (CBufferOrTBuffer) {
// Cbuffers and Tbuffers are HLSLBufferDecl types
if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) {
Flags.Resource = true;
if (CBufferOrTBuffer->isCBuffer())
Flags.CBV = true;
else
Flags.SRV = true;
} else if (TheVarDecl) {
Flags.ResourceClass = CBufferOrTBuffer->isCBuffer()
? llvm::dxil::ResourceClass::CBuffer
: llvm::dxil::ResourceClass::SRV;
}
// Samplers, UAVs, and SRVs are VarDecl types
else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) {
const HLSLResourceClassAttr *resClassAttr =
getSpecifiedHLSLAttrFromVarDecl<HLSLResourceClassAttr>(TheVarDecl);

if (resClassAttr) {
llvm::hlsl::ResourceClass DeclResourceClass =
resClassAttr->getResourceClass();
Flags.Resource = true;
updateResourceClassFlagsFromDeclResourceClass(Flags, DeclResourceClass);
Flags.ResourceClass = resClassAttr->getResourceClass();
} else {
const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
while (TheBaseType->isArrayType())
TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
if (TheBaseType->isArithmeticType())

if (TheBaseType->isArithmeticType()) {
Flags.Basic = true;
else if (TheBaseType->isRecordType()) {
if (!isDeclaredWithinCOrTBuffer(TheDecl) &&
(TheBaseType->isIntegralType(S.getASTContext()) ||
TheBaseType->isFloatingType()))
Flags.DefaultGlobals = true;
} else if (TheBaseType->isRecordType()) {
Flags.UDT = true;
const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
assert(TheRecordTy && "The Qual Type should be Record Type");
const RecordDecl *TheRecordDecl = TheRecordTy->getDecl();
// recurse through members, set appropriate resource class flags.
updateResourceClassFlagsFromRecordDecl(Flags, TheRecordDecl);
updateResourceClassFlagsFromRecordType(Flags, TheRecordTy);
} else
Flags.Other = true;
}
} else {
llvm_unreachable("expected be VarDecl or HLSLBufferDecl");
}
return Flags;
}

enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid };

static RegisterType getRegisterType(llvm::dxil::ResourceClass RC) {
switch (RC) {
case llvm::dxil::ResourceClass::SRV:
return RegisterType::SRV;
case llvm::dxil::ResourceClass::UAV:
return RegisterType::UAV;
case llvm::dxil::ResourceClass::CBuffer:
return RegisterType::CBuffer;
case llvm::dxil::ResourceClass::Sampler:
return RegisterType::Sampler;
}
llvm_unreachable("unexpected ResourceClass value");
}

static RegisterType getRegisterType(StringRef Slot) {
switch (Slot[0]) {
case 't':
Expand Down Expand Up @@ -754,34 +732,8 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
// next, if resource is set, make sure the register type in the register
// annotation is compatible with the variable's resource type.
if (Flags.Resource) {
const HLSLResourceClassAttr *resClassAttr = nullptr;
if (CBufferOrTBuffer) {
resClassAttr = CBufferOrTBuffer->getAttr<HLSLResourceClassAttr>();
} else if (TheVarDecl) {
resClassAttr =
getSpecifiedHLSLAttrFromVarDecl<HLSLResourceClassAttr>(TheVarDecl);
}

assert(resClassAttr &&
"any decl that set the resource flag on analysis should "
"have a resource class attribute attached.");
const llvm::hlsl::ResourceClass DeclResourceClass =
resClassAttr->getResourceClass();

// confirm that the register type is bound to its expected resource class
static RegisterType ExpectedRegisterTypesForResourceClass[] = {
RegisterType::SRV,
RegisterType::UAV,
RegisterType::CBuffer,
RegisterType::Sampler,
};
assert((size_t)DeclResourceClass <
std::size(ExpectedRegisterTypesForResourceClass) &&
"DeclResourceClass has unexpected value");

RegisterType ExpectedRegisterType =
ExpectedRegisterTypesForResourceClass[(int)DeclResourceClass];
if (regType != ExpectedRegisterType) {
RegisterType expRegType = getRegisterType(Flags.ResourceClass.value());
if (regType != expRegType) {
S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch)
<< regTypeNum;
}
Expand Down Expand Up @@ -823,7 +775,7 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
}

void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) {
if (dyn_cast<VarDecl>(TheDecl)) {
if (isa<VarDecl>(TheDecl)) {
if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(),
cast<ValueDecl>(TheDecl)->getType(),
diag::err_incomplete_type))
Expand Down
7 changes: 7 additions & 0 deletions clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,10 @@ struct Eg14{
};
// expected-warning@+1{{binding type 't' only applies to types containing SRV resources}}
Eg14 e14 : register(t9);

struct Eg15 {
float f[4];
};
// expected no error
Eg15 e15 : register(c0);

Loading