Skip to content

Implement resource binding type prefix mismatch errors #87578

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -4366,7 +4366,7 @@ def HLSLSV_GroupIndex: HLSLAnnotationAttr {

def HLSLResourceBinding: InheritableAttr {
let Spellings = [HLSLSemantic<"register">];
let Subjects = SubjectList<[HLSLBufferObj, ExternalGlobalVar]>;
let Subjects = SubjectList<[HLSLBufferObj, ExternalGlobalVar], ErrorDiag>;
let LangOpts = [HLSL];
let Args = [StringArgument<"Slot">, StringArgument<"Space", 1>];
let Documentation = [HLSLResourceBindingDocs];
Expand Down
5 changes: 4 additions & 1 deletion clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12154,7 +12154,10 @@ def err_hlsl_missing_semantic_annotation : Error<
def err_hlsl_init_priority_unsupported : Error<
"initializer priorities are not supported in HLSL">;

def err_hlsl_unsupported_register_type : Error<"invalid resource class specifier '%0' used; expected 'b', 's', 't', or 'u'">;
def err_hlsl_mismatching_register_resource_type_and_name: Error<"invalid register name prefix '%0' for register resource type '%1' (expected %select{'t'|'u'|'b'|'s'}2)">;
def err_hlsl_mismatching_register_builtin_type_and_name: Error<"invalid register name prefix '%0' for '%1' (expected %2)">;
def err_hlsl_unsupported_register_prefix : Error<"invalid resource class specifier '%0' used; expected 'b', 's', 't', or 'u'">;
def err_hlsl_unsupported_register_resource_type : Error<"invalid resource '%0' used">;
def err_hlsl_unsupported_register_number : Error<"register number should be an integer">;
def err_hlsl_expected_space : Error<"invalid space specifier '%0' used; expected 'space' followed by an integer, like space1">;
def err_hlsl_pointers_unsupported : Error<
Expand Down
2 changes: 1 addition & 1 deletion clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace clang {
class SemaHLSL : public SemaBase {
public:
SemaHLSL(Sema &S);

HLSLResourceAttr *mergeHLSLResourceAttr(bool CBuffer);
Decl *ActOnStartBuffer(Scope *BufferScope, bool CBuffer, SourceLocation KwLoc,
IdentifierInfo *Ident, SourceLocation IdentLoc,
SourceLocation LBrace);
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4859,6 +4859,8 @@ void Parser::ParseStructDeclaration(

// If attributes exist after the declarator, parse them.
MaybeParseGNUAttributes(DeclaratorInfo.D);
if (getLangOpts().HLSL)
MaybeParseHLSLSemantics(DeclaratorInfo.D);

// We're done with this declarator; invoke the callback.
FieldsCallback(DeclaratorInfo);
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Parse/ParseDeclCXX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2649,6 +2649,9 @@ bool Parser::ParseCXXMemberDeclaratorBeforeInitializer(
else
DeclaratorInfo.SetIdentifier(nullptr, Tok.getLocation());

if (getLangOpts().HLSL)
MaybeParseHLSLSemantics(DeclaratorInfo);

if (!DeclaratorInfo.isFunctionDeclarator() && TryConsumeToken(tok::colon)) {
assert(DeclaratorInfo.isPastIdentifier() &&
"don't know where identifier would go yet?");
Expand Down
26 changes: 14 additions & 12 deletions clang/lib/Sema/HLSLExternalSemaSource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -477,34 +477,36 @@ void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
}

/// Set up common members and attributes for buffer types
static BuiltinTypeDeclBuilder setupBufferType(CXXRecordDecl *Decl, Sema &S,
ResourceClass RC, ResourceKind RK,
bool IsROV) {
static BuiltinTypeDeclBuilder setupBufferHandle(CXXRecordDecl *Decl, Sema &S,
ResourceClass RC) {
return BuiltinTypeDeclBuilder(Decl)
.addHandleMember()
.addDefaultHandleConstructor(S, RC)
.annotateResourceClass(RC, RK, IsROV);
.addDefaultHandleConstructor(S, RC);
}

void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
CXXRecordDecl *Decl;
Decl = BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWBuffer")
.addSimpleTemplateParams({"element_type"})
.Record;
Decl =
BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RWBuffer")
.addSimpleTemplateParams({"element_type"})
.annotateResourceClass(ResourceClass::UAV, ResourceKind::TypedBuffer,
/*IsROV=*/false)
.Record;

onCompletion(Decl, [this](CXXRecordDecl *Decl) {
setupBufferType(Decl, *SemaPtr, ResourceClass::UAV,
ResourceKind::TypedBuffer, /*IsROV=*/false)
setupBufferHandle(Decl, *SemaPtr, ResourceClass::UAV)
.addArraySubscriptOperators()
.completeDefinition();
});

Decl =
BuiltinTypeDeclBuilder(*SemaPtr, HLSLNamespace, "RasterizerOrderedBuffer")
.addSimpleTemplateParams({"element_type"})
.annotateResourceClass(ResourceClass::UAV, ResourceKind::TypedBuffer,
/*IsROV=*/true)
.Record;
onCompletion(Decl, [this](CXXRecordDecl *Decl) {
setupBufferType(Decl, *SemaPtr, ResourceClass::UAV,
ResourceKind::TypedBuffer, /*IsROV=*/true)
setupBufferHandle(Decl, *SemaPtr, ResourceClass::UAV)
.addArraySubscriptOperators()
.completeDefinition();
});
Expand Down
219 changes: 214 additions & 5 deletions clang/lib/Sema/SemaDeclAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7334,8 +7334,215 @@ static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
D->addAttr(NewAttr);
}

struct register_binding_flags {
bool resource = false;
bool udt = false;
bool other = false;
bool basic = false;

bool srv = false;
bool uav = false;
bool cbv = false;
bool sampler = false;

bool contains_numeric = false;
bool default_globals = false;
};

bool isDeclaredWithinCOrTBuffer(const Decl *decl) {
if (!decl)
return false;

// Traverse up the parent contexts
const DeclContext *context = decl->getDeclContext();
while (context) {
if (isa<HLSLBufferDecl>(context)) {
return true;
}
context = context->getParent();
}

return false;
}


const HLSLResourceAttr *getHLSLResourceAttrFromVarDecl(VarDecl *SamplerUAVOrSRV) {
const Type *Ty = SamplerUAVOrSRV->getType()->getPointeeOrArrayElementType();
if (!Ty)
llvm_unreachable("Resource class must have an element type.");

if (const BuiltinType *BTy = dyn_cast<BuiltinType>(Ty)) {
/* QualType QT = SamplerUAVOrSRV->getType();
PrintingPolicy PP = S.getPrintingPolicy();
std::string typestr = QualType::getAsString(QT.split(), PP);

S.Diag(ArgLoc, diag::err_hlsl_unsupported_register_resource_type)
<< typestr;
return; */
return nullptr;
}

const CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
if (!TheRecordDecl)
llvm_unreachable("Resource class should have a resource type declaration.");

if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl))
TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl();
TheRecordDecl = TheRecordDecl->getCanonicalDecl();
const auto *Attr = TheRecordDecl->getAttr<HLSLResourceAttr>();
return Attr;
}

void traverseType(QualType T, register_binding_flags &r) {
if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
r.contains_numeric = true;
return;
} else if (const RecordType *RT = T->getAs<RecordType>()) {
RecordDecl *SubRD = RT->getDecl();
if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(SubRD)) {
auto TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl();
TheRecordDecl = TheRecordDecl->getCanonicalDecl();
const auto *Attr = TheRecordDecl->getAttr<HLSLResourceAttr>();
llvm::hlsl::ResourceClass DeclResourceClass = Attr->getResourceClass();
switch (DeclResourceClass) {
case llvm::hlsl::ResourceClass::SRV: {
r.srv = true;
break;
}
case llvm::hlsl::ResourceClass::UAV: {
r.uav = true;
break;
}
case llvm::hlsl::ResourceClass::CBuffer: {
r.cbv = true;
break;
}
case llvm::hlsl::ResourceClass::Sampler: {
r.sampler = true;
break;
}
case llvm::hlsl::ResourceClass::Invalid: {
llvm_unreachable("Resource class should be valid.");
break;
}
}
}

else if (SubRD->isCompleteDefinition()) {
for (auto Field : SubRD->fields()) {
QualType T = Field->getType();
traverseType(T, r);
}
}
}
}

void setResourceClassFlagsFromRecordDecl(register_binding_flags &r,
const RecordDecl *RD) {
if (!RD)
return;

if (RD->isCompleteDefinition()) {
for (auto Field : RD->fields()) {
QualType T = Field->getType();
traverseType(T, r);
}
}
}

register_binding_flags HLSLFillRegisterBindingFlags(Sema &S, Decl *D) {
register_binding_flags r;
if (!isDeclaredWithinCOrTBuffer(D)) {
// make sure the type is a basic / numeric type
if (VarDecl *v = dyn_cast<VarDecl>(D)) {
QualType t = v->getType();
// a numeric variable will inevitably end up in $Globals buffer
if (t->isIntegralType(S.getASTContext()) || t->isFloatingType())
r.default_globals = true;
}
}
// Cbuffers and Tbuffers are HLSLBufferDecl types
HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
// Samplers, UAVs, and SRVs are VarDecl types
VarDecl *SamplerUAVOrSRV = dyn_cast<VarDecl>(D);

if (CBufferOrTBuffer) {
r.resource = true;
if (CBufferOrTBuffer->isCBuffer())
r.cbv = true;
else
r.srv = true;
}
else if (SamplerUAVOrSRV) {
const HLSLResourceAttr *res_attr =
getHLSLResourceAttrFromVarDecl(SamplerUAVOrSRV);
if (res_attr) {
llvm::hlsl::ResourceClass DeclResourceClass =
res_attr->getResourceClass();
r.resource = true;
switch (DeclResourceClass) {
case llvm::hlsl::ResourceClass::SRV: {
r.srv = true;
break;
}
case llvm::hlsl::ResourceClass::UAV: {
r.uav = true;
break;
}
case llvm::hlsl::ResourceClass::CBuffer: {
r.cbv = true;
break;
}
case llvm::hlsl::ResourceClass::Sampler: {
r.sampler = true;
break;
}
case llvm::hlsl::ResourceClass::Invalid: {
llvm_unreachable("Resource class should be valid.");
break;
}
}
}
else {
if (SamplerUAVOrSRV->getType()->isBuiltinType())
r.basic = true;
else if (SamplerUAVOrSRV->getType()->isAggregateType()) {
r.udt = true;
QualType VarType = SamplerUAVOrSRV->getType();
if (const RecordType *RT = VarType->getAs<RecordType>()) {
const RecordDecl *RD = RT->getDecl();
// recurse through members, set appropriate resource class flags.
setResourceClassFlagsFromRecordDecl(r, RD);
}
}
else
r.other = true;
}
}
else {
llvm_unreachable("unknown decl type");
}
return r;
}


static void DiagnoseHLSLResourceRegType(Sema &S, SourceLocation &ArgLoc,
Decl *D, StringRef &Slot) {

register_binding_flags f = HLSLFillRegisterBindingFlags(S, D);
// Samplers, UAVs, and SRVs are VarDecl types
VarDecl *SamplerUAVOrSRV = dyn_cast<VarDecl>(D);
// Cbuffers and Tbuffers are HLSLBufferDecl types
HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
if (!SamplerUAVOrSRV && !CBufferOrTBuffer)
return;
}

static void handleHLSLResourceBindingAttr(Sema &S, Decl *D,
const ParsedAttr &AL) {
if (S.RequireCompleteType(D->getBeginLoc(), cast<ValueDecl>(D)->getType(),
diag::err_incomplete_type))
return;
StringRef Space = "space0";
StringRef Slot = "";

Expand Down Expand Up @@ -7368,13 +7575,15 @@ static void handleHLSLResourceBindingAttr(Sema &S, Decl *D,
// Validate.
if (!Slot.empty()) {
switch (Slot[0]) {
case 't':
case 'u':
case 'b':
case 's':
case 't':
case 's':
case 'c':
case 'i':
break;
default:
S.Diag(ArgLoc, diag::err_hlsl_unsupported_register_type)
S.Diag(ArgLoc, diag::err_hlsl_unsupported_register_prefix)
<< Slot.substr(0, 1);
return;
}
Expand All @@ -7398,8 +7607,8 @@ static void handleHLSLResourceBindingAttr(Sema &S, Decl *D,
return;
}

// FIXME: check reg type match decl. Issue
// https://github.com/llvm/llvm-project/issues/57886.
DiagnoseHLSLResourceRegType(S, ArgLoc, D, Slot);

HLSLResourceBindingAttr *NewAttr =
HLSLResourceBindingAttr::Create(S.getASTContext(), Slot, Space, AL);
if (NewAttr)
Expand Down
23 changes: 22 additions & 1 deletion clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,25 @@ using namespace clang;

SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}

HLSLResourceAttr *SemaHLSL::mergeHLSLResourceAttr(bool CBuffer) {
// cbuffer case
if (CBuffer) {
HLSLResourceAttr *attr = HLSLResourceAttr::CreateImplicit(
getASTContext(), llvm::hlsl::ResourceClass::CBuffer,
llvm::hlsl::ResourceKind::CBuffer,
/*IsROV=*/false);
return attr;
}
// tbuffer case
else {
HLSLResourceAttr *attr = HLSLResourceAttr::CreateImplicit(
getASTContext(), llvm::hlsl::ResourceClass::SRV,
llvm::hlsl::ResourceKind::TBuffer,
/*IsROV=*/false);
return attr;
}
}

Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
SourceLocation KwLoc, IdentifierInfo *Ident,
SourceLocation IdentLoc,
Expand All @@ -32,7 +51,9 @@ Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
DeclContext *LexicalParent = SemaRef.getCurLexicalContext();
HLSLBufferDecl *Result = HLSLBufferDecl::Create(
getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace);

HLSLResourceAttr *NewAttr = mergeHLSLResourceAttr(CBuffer);
if (NewAttr)
Result->addAttr(NewAttr);
SemaRef.PushOnScopeChains(Result, BufferScope);
SemaRef.PushDeclContext(BufferScope, Result);

Expand Down
4 changes: 2 additions & 2 deletions clang/lib/Serialization/ASTWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5014,7 +5014,7 @@ void ASTWriter::WriteSpecialDeclRecords(Sema &SemaRef) {

if (!ModularCodegenDecls.empty())
Stream.EmitRecord(MODULAR_CODEGEN_DECLS, ModularCodegenDecls);

// Write the record containing tentative definitions.
RecordData TentativeDefinitions;
AddLazyVectorEmiitedDecls(*this, SemaRef.TentativeDefinitions,
Expand Down Expand Up @@ -5135,7 +5135,7 @@ void ASTWriter::WriteSpecialDeclRecords(Sema &SemaRef) {
}
if (!UndefinedButUsed.empty())
Stream.EmitRecord(UNDEFINED_BUT_USED, UndefinedButUsed);

// Write all delete-expressions that we would like to
// analyze later in AST.
RecordData DeleteExprsToAnalyze;
Expand Down
Loading