Skip to content

[HLSL] Implement default constant buffer $Globals (2nd attempt) #128589

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions clang/include/clang/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -5045,15 +5045,26 @@ class HLSLBufferDecl final : public NamedDecl, public DeclContext {
// LayoutStruct - Layout struct for the buffer
CXXRecordDecl *LayoutStruct;

// For default (implicit) constant buffer, an array of references to global
// decls that belong to the buffer. The decls are already parented by the
// translation unit context. The array is allocated by the ASTContext
// allocator in HLSLBufferDecl::CreateDefaultCBuffer.
ArrayRef<Decl *> DefaultBufferDecls;

HLSLBufferDecl(DeclContext *DC, bool CBuffer, SourceLocation KwLoc,
IdentifierInfo *ID, SourceLocation IDLoc,
SourceLocation LBrace);

void setDefaultBufferDecls(ArrayRef<Decl *> Decls);

public:
static HLSLBufferDecl *Create(ASTContext &C, DeclContext *LexicalParent,
bool CBuffer, SourceLocation KwLoc,
IdentifierInfo *ID, SourceLocation IDLoc,
SourceLocation LBrace);
static HLSLBufferDecl *
CreateDefaultCBuffer(ASTContext &C, DeclContext *LexicalParent,
ArrayRef<Decl *> DefaultCBufferDecls);
static HLSLBufferDecl *CreateDeserialized(ASTContext &C, GlobalDeclID ID);

SourceRange getSourceRange() const override LLVM_READONLY {
Expand All @@ -5079,6 +5090,28 @@ class HLSLBufferDecl final : public NamedDecl, public DeclContext {
return static_cast<HLSLBufferDecl *>(const_cast<DeclContext *>(DC));
}

// Iterator for the buffer decls. For constant buffers explicitly declared
// with `cbuffer` keyword this will be the list of decls parented by this
// HLSLBufferDecl (equal to `decls()`).
// For implicit $Globals buffer this will be the list of default buffer
// declarations stored in DefaultBufferDecls plus the implicit layout
// struct (the only child of HLSLBufferDecl in this case).
//
// The iterator uses llvm::concat_iterator to concatenate the lists
// `decls()` and `DefaultBufferDecls`. For non-default buffers
// `DefaultBufferDecls` is always empty.
using buffer_decl_iterator =
llvm::concat_iterator<Decl *const, SmallVector<Decl *>::const_iterator,
decl_iterator>;
using buffer_decl_range = llvm::iterator_range<buffer_decl_iterator>;

buffer_decl_range buffer_decls() const {
return buffer_decl_range(buffer_decls_begin(), buffer_decls_end());
}
buffer_decl_iterator buffer_decls_begin() const;
buffer_decl_iterator buffer_decls_end() const;
bool buffer_decls_empty();

friend class ASTDeclReader;
friend class ASTDeclWriter;
};
Expand Down
8 changes: 7 additions & 1 deletion clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ class SemaHLSL : public SemaBase {
HLSLParamModifierAttr::Spelling Spelling);
void ActOnTopLevelFunction(FunctionDecl *FD);
void ActOnVariableDeclarator(VarDecl *VD);
void ActOnEndOfTranslationUnit(TranslationUnitDecl *TU);
void CheckEntryPoint(FunctionDecl *FD);
void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
const HLSLAnnotationAttr *AnnotationAttr);
void DiagnoseAttrStageMismatch(
const Attr *A, llvm::Triple::EnvironmentType Stage,
std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);

QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
QualType LHSType, QualType RHSType,
Expand Down Expand Up @@ -168,11 +168,17 @@ class SemaHLSL : public SemaBase {
// List of all resource bindings
ResourceBindings Bindings;

// Global declarations collected for the $Globals default constant
// buffer which will be created at the end of the translation unit.
llvm::SmallVector<Decl *> DefaultCBufferDecls;

private:
void collectResourceBindingsOnVarDecl(VarDecl *D);
void collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
const RecordType *RT);
void processExplicitBindingsOnDecl(VarDecl *D);

void diagnoseAvailabilityViolations(TranslationUnitDecl *TU);
};

} // namespace clang
Expand Down
43 changes: 43 additions & 0 deletions clang/lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -5745,6 +5746,18 @@ HLSLBufferDecl *HLSLBufferDecl::Create(ASTContext &C,
return Result;
}

/// Builds the implicit `$Globals` constant buffer declaration.
///
/// \param C ASTContext supplying the identifier table and the allocator.
/// \param LexicalParent context handed to the HLSLBufferDecl constructor.
/// \param DefaultCBufferDecls declarations recorded on the new buffer via
///        setDefaultBufferDecls.
/// \returns the new buffer decl, marked implicit, with default-constructed
///          source locations.
HLSLBufferDecl *
HLSLBufferDecl::CreateDefaultCBuffer(ASTContext &C, DeclContext *LexicalParent,
                                     ArrayRef<Decl *> DefaultCBufferDecls) {
  // The implicit buffer is not spelled in the source, so the keyword,
  // identifier and brace locations are all default-constructed.
  const SourceLocation NoLoc;
  IdentifierInfo *GlobalsId =
      &C.Idents.get("$Globals", tok::TokenKind::identifier);

  auto *DefaultBuffer = new (C, LexicalParent) HLSLBufferDecl(
      LexicalParent, /*CBuffer=*/true, NoLoc, GlobalsId, NoLoc, NoLoc);

  // Mark the buffer implicit before attaching the decls:
  // setDefaultBufferDecls asserts isImplicit().
  DefaultBuffer->setImplicit(true);
  DefaultBuffer->setDefaultBufferDecls(DefaultCBufferDecls);
  return DefaultBuffer;
}

HLSLBufferDecl *HLSLBufferDecl::CreateDeserialized(ASTContext &C,
GlobalDeclID ID) {
return new (C, ID) HLSLBufferDecl(nullptr, false, SourceLocation(), nullptr,
Expand All @@ -5757,6 +5770,36 @@ void HLSLBufferDecl::addLayoutStruct(CXXRecordDecl *LS) {
addDecl(LS);
}

/// Records the declarations that make up the default constant buffer.
///
/// The pointers in \p Decls are copied into an array obtained from the
/// ASTContext allocator, and DefaultBufferDecls is pointed at that copy, so
/// the storage behind \p Decls is not referenced after this call returns.
///
/// Preconditions (checked by assertions): \p Decls is non-empty, no default
/// decls have been set on this buffer yet, and the buffer is implicit.
void HLSLBufferDecl::setDefaultBufferDecls(ArrayRef<Decl *> Decls) {
  assert(!Decls.empty() &&
         "default constant buffer must contain at least one decl");
  assert(DefaultBufferDecls.empty() && "default decls are already set");
  assert(isImplicit() &&
         "default decls can only be added to the implicit/default constant "
         "buffer $Globals");

  // Allocate the array for the default decls with the ASTContext allocator.
  const size_t NumDecls = Decls.size();
  Decl **DeclsArray = new (getASTContext()) Decl *[NumDecls];
  std::copy(Decls.begin(), Decls.end(), DeclsArray);
  DefaultBufferDecls = ArrayRef<Decl *>(DeclsArray, NumDecls);
}

/// Returns the begin iterator over the buffer declarations. The iterator
/// concatenates two ranges: the entries of DefaultBufferDecls first, then the
/// declarations in [decls_begin(), decls_end()).
HLSLBufferDecl::buffer_decl_iterator
HLSLBufferDecl::buffer_decls_begin() const {
  auto DefaultDecls = llvm::iterator_range(DefaultBufferDecls.begin(),
                                           DefaultBufferDecls.end());
  auto ChildDecls = decl_range(decls_begin(), decls_end());
  return buffer_decl_iterator(DefaultDecls, ChildDecls);
}

/// Returns the end iterator over the buffer declarations. It is built from
/// two empty ranges, each starting and stopping at the end of its underlying
/// sequence (DefaultBufferDecls and the child decls respectively).
HLSLBufferDecl::buffer_decl_iterator HLSLBufferDecl::buffer_decls_end() const {
  auto DefaultDeclsEnd =
      llvm::iterator_range(DefaultBufferDecls.end(), DefaultBufferDecls.end());
  auto ChildDeclsEnd = decl_range(decls_end(), decls_end());
  return buffer_decl_iterator(DefaultDeclsEnd, ChildDeclsEnd);
}

/// Returns true when the buffer has neither default buffer decls nor child
/// decls.
// NOTE(review): this query does not modify the object but is declared
// non-const; making it const requires changing the in-class declaration too.
bool HLSLBufferDecl::buffer_decls_empty() {
  if (!DefaultBufferDecls.empty())
    return false;
  return decls_empty();
}

//===----------------------------------------------------------------------===//
// ImportDecl Implementation
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CodeGen/CGHLSLRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ void CGHLSLRuntime::emitBufferGlobalsAndMetadata(const HLSLBufferDecl *BufDecl,
BufGlobals.push_back(ValueAsMetadata::get(BufGV));

const auto *ElemIt = LayoutStruct->element_begin();
for (Decl *D : BufDecl->decls()) {
for (Decl *D : BufDecl->buffer_decls()) {
if (isa<CXXRecordDecl, EmptyDecl>(D))
// Nothing to do for this declaration.
continue;
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/CodeGen/CodeGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5513,6 +5513,11 @@ void CodeGenModule::EmitGlobalVarDefinition(const VarDecl *D,
if (getLangOpts().OpenCL && ASTTy->isSamplerT())
return;

// HLSL default buffer constants will be emitted during HLSLBufferDecl codegen
if (getLangOpts().HLSL &&
D->getType().getAddressSpace() == LangAS::hlsl_constant)
return;

// If this is OpenMP device, check if it is legal to emit this global
// normally.
if (LangOpts.OpenMPIsTargetDevice && OpenMPRuntime &&
Expand Down
3 changes: 1 addition & 2 deletions clang/lib/Sema/Sema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1417,8 +1417,7 @@ void Sema::ActOnEndOfTranslationUnit() {
}

if (LangOpts.HLSL)
HLSL().DiagnoseAvailabilityViolations(
getASTContext().getTranslationUnitDecl());
HLSL().ActOnEndOfTranslationUnit(getASTContext().getTranslationUnitDecl());

// If there were errors, disable 'unused' warnings since they will mostly be
// noise. Don't warn for a use from a module: either we should warn on all
Expand Down
43 changes: 39 additions & 4 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//===----------------------------------------------------------------------===//

#include "clang/Sema/SemaHLSL.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Attr.h"
#include "clang/AST/Attrs.inc"
Expand Down Expand Up @@ -225,7 +226,7 @@ static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {
// or on none.
bool HasPackOffset = false;
bool HasNonPackOffset = false;
for (auto *Field : BufDecl->decls()) {
for (auto *Field : BufDecl->buffer_decls()) {
VarDecl *Var = dyn_cast<VarDecl>(Field);
if (!Var)
continue;
Expand Down Expand Up @@ -492,7 +493,7 @@ void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) {
LS->setImplicit(true);
LS->startDefinition();

for (Decl *D : BufDecl->decls()) {
for (Decl *D : BufDecl->buffer_decls()) {
VarDecl *VD = dyn_cast<VarDecl>(D);
if (!VD || VD->getStorageClass() == SC_Static ||
VD->getType().getAddressSpace() == LangAS::hlsl_groupshared)
Expand Down Expand Up @@ -1928,7 +1929,22 @@ void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,

} // namespace

void SemaHLSL::DiagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
/// End-of-translation-unit hook for HLSL.
///
/// When declarations were collected in DefaultCBufferDecls, this creates the
/// implicit default constant buffer for them, adds it to the current lexical
/// context, builds its host layout struct, and passes the buffer to the AST
/// consumer as a top-level declaration. Availability diagnostics for \p TU
/// run afterwards in either case.
void SemaHLSL::ActOnEndOfTranslationUnit(TranslationUnitDecl *TU) {
  const bool HasDefaultCBufferDecls = !DefaultCBufferDecls.empty();
  if (HasDefaultCBufferDecls) {
    DeclContext *LexicalCtx = SemaRef.getCurLexicalContext();
    HLSLBufferDecl *GlobalsBuffer = HLSLBufferDecl::CreateDefaultCBuffer(
        SemaRef.getASTContext(), LexicalCtx, DefaultCBufferDecls);
    LexicalCtx->addDecl(GlobalsBuffer);
    createHostLayoutStructForBuffer(SemaRef, GlobalsBuffer);

    // Hand the finished buffer to the consumer as a top-level decl group.
    SemaRef.Consumer.HandleTopLevelDecl(DeclGroupRef(GlobalsBuffer));
  }

  diagnoseAvailabilityViolations(TU);
}

void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
// Skip running the diagnostics scan if the diagnostic mode is
// strict (-fhlsl-strict-availability) and the target shader stage is known
// because all relevant diagnostics were already emitted in the
Expand Down Expand Up @@ -2991,6 +3007,14 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) {
return Ty;
}

/// Returns true when \p VD qualifies for the default constant buffer: it is
/// declared in the translation unit context, its type is in the default
/// address space, it does not have static storage class, and
/// isInvalidConstantBufferLeafElementType does not reject its type.
static bool IsDefaultBufferConstantDecl(VarDecl *VD) {
  if (!VD->getDeclContext()->isTranslationUnit())
    return false;

  QualType VarTy = VD->getType();
  if (VarTy.getAddressSpace() != LangAS::Default)
    return false;
  if (VD->getStorageClass() == SC_Static)
    return false;

  return !isInvalidConstantBufferLeafElementType(VarTy.getTypePtr());
}

void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
if (VD->hasGlobalStorage()) {
// make sure the declaration has a complete type
Expand All @@ -3002,7 +3026,18 @@ void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
return;
}

// find all resources on decl
// Global variables outside a cbuffer block that are not a resource, static,
// groupshared, or an empty array or struct belong to the default constant
// buffer $Globals (to be created at the end of the translation unit).
if (IsDefaultBufferConstantDecl(VD)) {
// update address space to hlsl_constant
QualType NewTy = getASTContext().getAddrSpaceQualType(
VD->getType(), LangAS::hlsl_constant);
VD->setType(NewTy);
DefaultCBufferDecls.push_back(VD);
}

// find all resource bindings on decl
if (VD->getType()->isHLSLIntangibleType())
collectResourceBindingsOnVarDecl(VD);

Expand Down
50 changes: 50 additions & 0 deletions clang/test/AST/HLSL/default_cbuffer.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -ast-dump -o - %s | FileCheck %s

// Struct with no members. Used as the type of the global `c` and of the
// member `S::es` in this test.
struct EmptyStruct {
};

// Struct mixing a resource member, an empty-struct member, a zero-sized array
// and a plain float. The expectations for __cblayout_S at the end of this
// file list `b` as the first field after the PackedAttr, i.e. the three
// members declared before it do not appear in the layout struct.
struct S {
RWBuffer<float> buf;
EmptyStruct es;
float ea[0];
float b;
};

// CHECK: VarDecl {{.*}} used a 'hlsl_constant float'
float a;

// CHECK: VarDecl {{.*}} b 'RWBuffer<float>':'hlsl::RWBuffer<float>'
RWBuffer<float> b;

// CHECK: VarDecl {{.*}} c 'EmptyStruct'
EmptyStruct c;

// CHECK: VarDecl {{.*}} d 'float[0]'
float d[0];

// CHECK: VarDecl {{.*}} e 'RWBuffer<float>[2]'
RWBuffer<float> e[2];

// CHECK: VarDecl {{.*}} f 'groupshared float'
groupshared float f;

// CHECK: VarDecl {{.*}} g 'hlsl_constant float'
float g;

// CHECK: VarDecl {{.*}} h 'hlsl_constant S'
S h;

// CHECK: HLSLBufferDecl {{.*}} implicit cbuffer $Globals
// CHECK: CXXRecordDecl {{.*}} implicit struct __cblayout_$Globals definition
// CHECK: PackedAttr
// CHECK-NEXT: FieldDecl {{.*}} a 'float'
// CHECK-NEXT: FieldDecl {{.*}} g 'float'
// CHECK-NEXT: FieldDecl {{.*}} h '__cblayout_S'

// CHECK: CXXRecordDecl {{.*}} implicit struct __cblayout_S definition
// CHECK: PackedAttr {{.*}} Implicit
// CHECK-NEXT: FieldDecl {{.*}} b 'float'

// Exported function that reads the global `a`; the expectation for `a` near
// the top of this file matches its VarDecl as `used`.
export float foo() {
return a;
}
64 changes: 32 additions & 32 deletions clang/test/CodeGenHLSL/basic_types.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,38 @@
// RUN: -emit-llvm -disable-llvm-passes -o - -DNAMESPACED| FileCheck %s


// CHECK: @uint16_t_Val = global i16 0, align 2
// CHECK: @int16_t_Val = global i16 0, align 2
// CHECK: @uint_Val = global i32 0, align 4
// CHECK: @uint64_t_Val = global i64 0, align 8
// CHECK: @int64_t_Val = global i64 0, align 8
// CHECK: @int16_t2_Val = global <2 x i16> zeroinitializer, align 4
// CHECK: @int16_t3_Val = global <3 x i16> zeroinitializer, align 8
// CHECK: @int16_t4_Val = global <4 x i16> zeroinitializer, align 8
// CHECK: @uint16_t2_Val = global <2 x i16> zeroinitializer, align 4
// CHECK: @uint16_t3_Val = global <3 x i16> zeroinitializer, align 8
// CHECK: @uint16_t4_Val = global <4 x i16> zeroinitializer, align 8
// CHECK: @int2_Val = global <2 x i32> zeroinitializer, align 8
// CHECK: @int3_Val = global <3 x i32> zeroinitializer, align 16
// CHECK: @int4_Val = global <4 x i32> zeroinitializer, align 16
// CHECK: @uint2_Val = global <2 x i32> zeroinitializer, align 8
// CHECK: @uint3_Val = global <3 x i32> zeroinitializer, align 16
// CHECK: @uint4_Val = global <4 x i32> zeroinitializer, align 16
// CHECK: @int64_t2_Val = global <2 x i64> zeroinitializer, align 16
// CHECK: @int64_t3_Val = global <3 x i64> zeroinitializer, align 32
// CHECK: @int64_t4_Val = global <4 x i64> zeroinitializer, align 32
// CHECK: @uint64_t2_Val = global <2 x i64> zeroinitializer, align 16
// CHECK: @uint64_t3_Val = global <3 x i64> zeroinitializer, align 32
// CHECK: @uint64_t4_Val = global <4 x i64> zeroinitializer, align 32
// CHECK: @half2_Val = global <2 x half> zeroinitializer, align 4
// CHECK: @half3_Val = global <3 x half> zeroinitializer, align 8
// CHECK: @half4_Val = global <4 x half> zeroinitializer, align 8
// CHECK: @float2_Val = global <2 x float> zeroinitializer, align 8
// CHECK: @float3_Val = global <3 x float> zeroinitializer, align 16
// CHECK: @float4_Val = global <4 x float> zeroinitializer, align 16
// CHECK: @double2_Val = global <2 x double> zeroinitializer, align 16
// CHECK: @double3_Val = global <3 x double> zeroinitializer, align 32
// CHECK: @double4_Val = global <4 x double> zeroinitializer, align 32
// CHECK: @uint16_t_Val = external addrspace(2) global i16, align 2
// CHECK: @int16_t_Val = external addrspace(2) global i16, align 2
// CHECK: @uint_Val = external addrspace(2) global i32, align 4
// CHECK: @uint64_t_Val = external addrspace(2) global i64, align 8
// CHECK: @int64_t_Val = external addrspace(2) global i64, align 8
// CHECK: @int16_t2_Val = external addrspace(2) global <2 x i16>, align 4
// CHECK: @int16_t3_Val = external addrspace(2) global <3 x i16>, align 8
// CHECK: @int16_t4_Val = external addrspace(2) global <4 x i16>, align 8
// CHECK: @uint16_t2_Val = external addrspace(2) global <2 x i16>, align 4
// CHECK: @uint16_t3_Val = external addrspace(2) global <3 x i16>, align 8
// CHECK: @uint16_t4_Val = external addrspace(2) global <4 x i16>, align 8
// CHECK: @int2_Val = external addrspace(2) global <2 x i32>, align 8
// CHECK: @int3_Val = external addrspace(2) global <3 x i32>, align 16
// CHECK: @int4_Val = external addrspace(2) global <4 x i32>, align 16
// CHECK: @uint2_Val = external addrspace(2) global <2 x i32>, align 8
// CHECK: @uint3_Val = external addrspace(2) global <3 x i32>, align 16
// CHECK: @uint4_Val = external addrspace(2) global <4 x i32>, align 16
// CHECK: @int64_t2_Val = external addrspace(2) global <2 x i64>, align 16
// CHECK: @int64_t3_Val = external addrspace(2) global <3 x i64>, align 32
// CHECK: @int64_t4_Val = external addrspace(2) global <4 x i64>, align 32
// CHECK: @uint64_t2_Val = external addrspace(2) global <2 x i64>, align 16
// CHECK: @uint64_t3_Val = external addrspace(2) global <3 x i64>, align 32
// CHECK: @uint64_t4_Val = external addrspace(2) global <4 x i64>, align 32
// CHECK: @half2_Val = external addrspace(2) global <2 x half>, align 4
// CHECK: @half3_Val = external addrspace(2) global <3 x half>, align 8
// CHECK: @half4_Val = external addrspace(2) global <4 x half>, align 8
// CHECK: @float2_Val = external addrspace(2) global <2 x float>, align 8
// CHECK: @float3_Val = external addrspace(2) global <3 x float>, align 16
// CHECK: @float4_Val = external addrspace(2) global <4 x float>, align 16
// CHECK: @double2_Val = external addrspace(2) global <2 x double>, align 16
// CHECK: @double3_Val = external addrspace(2) global <3 x double>, align 32
// CHECK: @double4_Val = external addrspace(2) global <4 x double>, align 32

#ifdef NAMESPACED
#define TYPE_DECL(T) hlsl::T T##_Val
Expand Down
Loading