Skip to content

Commit 776cdda

Browse files
authored
[HLSL] Implement default constant buffer $Globals (#125807)
All variable declarations in the global scope that are not resources, static, or empty are implicitly added to the implicit constant buffer `$Globals`. They are created in the `hlsl_constant` address space and collected in an implicit `HLSLBufferDecl` node that is added to the AST at the end of the translation unit. Codegen is the same as for explicit constant buffers. Fixes #123801
1 parent 9fa77c1 commit 776cdda

File tree

10 files changed

+240
-41
lines changed

10 files changed

+240
-41
lines changed

clang/include/clang/AST/Decl.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5045,6 +5045,11 @@ class HLSLBufferDecl final : public NamedDecl, public DeclContext {
50455045
// LayoutStruct - Layout struct for the buffer
50465046
CXXRecordDecl *LayoutStruct;
50475047

5048+
// For the default (implicit) constant buffer, a list of references to global
5049+
// decls that belong to the buffer. The decls are already parented by the
5050+
// translation unit context.
5051+
SmallVector<Decl *> DefaultBufferDecls;
5052+
50485053
HLSLBufferDecl(DeclContext *DC, bool CBuffer, SourceLocation KwLoc,
50495054
IdentifierInfo *ID, SourceLocation IDLoc,
50505055
SourceLocation LBrace);
@@ -5054,6 +5059,8 @@ class HLSLBufferDecl final : public NamedDecl, public DeclContext {
50545059
bool CBuffer, SourceLocation KwLoc,
50555060
IdentifierInfo *ID, SourceLocation IDLoc,
50565061
SourceLocation LBrace);
5062+
static HLSLBufferDecl *CreateDefaultCBuffer(ASTContext &C,
5063+
DeclContext *LexicalParent);
50575064
static HLSLBufferDecl *CreateDeserialized(ASTContext &C, GlobalDeclID ID);
50585065

50595066
SourceRange getSourceRange() const override LLVM_READONLY {
@@ -5068,6 +5075,7 @@ class HLSLBufferDecl final : public NamedDecl, public DeclContext {
50685075
bool hasValidPackoffset() const { return HasValidPackoffset; }
50695076
const CXXRecordDecl *getLayoutStruct() const { return LayoutStruct; }
50705077
void addLayoutStruct(CXXRecordDecl *LS);
5078+
void addDefaultBufferDecl(Decl *D);
50715079

50725080
// Implement isa/cast/dyncast/etc.
50735081
static bool classof(const Decl *D) { return classofKind(D->getKind()); }
@@ -5079,6 +5087,28 @@ class HLSLBufferDecl final : public NamedDecl, public DeclContext {
50795087
return static_cast<HLSLBufferDecl *>(const_cast<DeclContext *>(DC));
50805088
}
50815089

5090+
// Iterator for the buffer decls. For constant buffers explicitly declared
5091+
// with the `cbuffer` keyword this will be the list of decls parented by this
5092+
// HLSLBufferDecl (equal to `decls()`).
5093+
// For implicit $Globals buffer this will be the list of default buffer
5094+
// declarations stored in DefaultBufferDecls plus the implicit layout
5095+
// struct (the only child of HLSLBufferDecl in this case).
5096+
//
5097+
// The iterator uses llvm::concat_iterator to concatenate the lists
5098+
// `DefaultBufferDecls` and `decls()`, in that order. For non-default buffers
5099+
// `DefaultBufferDecls` is always empty.
5100+
using buffer_decl_iterator =
5101+
llvm::concat_iterator<Decl *const, SmallVector<Decl *>::const_iterator,
5102+
decl_iterator>;
5103+
using buffer_decl_range = llvm::iterator_range<buffer_decl_iterator>;
5104+
5105+
buffer_decl_range buffer_decls() const {
5106+
return buffer_decl_range(buffer_decls_begin(), buffer_decls_end());
5107+
}
5108+
buffer_decl_iterator buffer_decls_begin() const;
5109+
buffer_decl_iterator buffer_decls_end() const;
5110+
bool buffer_decls_empty();
5111+
50825112
friend class ASTDeclReader;
50835113
friend class ASTDeclWriter;
50845114
};

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,13 @@ class SemaHLSL : public SemaBase {
105105
HLSLParamModifierAttr::Spelling Spelling);
106106
void ActOnTopLevelFunction(FunctionDecl *FD);
107107
void ActOnVariableDeclarator(VarDecl *VD);
108+
void ActOnEndOfTranslationUnit(TranslationUnitDecl *TU);
108109
void CheckEntryPoint(FunctionDecl *FD);
109110
void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param,
110111
const HLSLAnnotationAttr *AnnotationAttr);
111112
void DiagnoseAttrStageMismatch(
112113
const Attr *A, llvm::Triple::EnvironmentType Stage,
113114
std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
114-
void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
115115

116116
QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
117117
QualType LHSType, QualType RHSType,
@@ -168,11 +168,16 @@ class SemaHLSL : public SemaBase {
168168
// List of all resource bindings
169169
ResourceBindings Bindings;
170170

171+
// default constant buffer $Globals
172+
HLSLBufferDecl *DefaultCBuffer;
173+
171174
private:
172175
void collectResourceBindingsOnVarDecl(VarDecl *D);
173176
void collectResourceBindingsOnUserRecordDecl(const VarDecl *VD,
174177
const RecordType *RT);
175178
void processExplicitBindingsOnDecl(VarDecl *D);
179+
180+
void diagnoseAvailabilityViolations(TranslationUnitDecl *TU);
176181
};
177182

178183
} // namespace clang

clang/lib/AST/Decl.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
#include "llvm/ADT/SmallVector.h"
5858
#include "llvm/ADT/StringRef.h"
5959
#include "llvm/ADT/StringSwitch.h"
60+
#include "llvm/ADT/iterator_range.h"
6061
#include "llvm/Support/Casting.h"
6162
#include "llvm/Support/ErrorHandling.h"
6263
#include "llvm/Support/raw_ostream.h"
@@ -5745,6 +5746,17 @@ HLSLBufferDecl *HLSLBufferDecl::Create(ASTContext &C,
57455746
return Result;
57465747
}
57475748

5749+
HLSLBufferDecl *
5750+
HLSLBufferDecl::CreateDefaultCBuffer(ASTContext &C,
5751+
DeclContext *LexicalParent) {
5752+
DeclContext *DC = LexicalParent;
5753+
IdentifierInfo *II = &C.Idents.get("$Globals", tok::TokenKind::identifier);
5754+
HLSLBufferDecl *Result = new (C, DC) HLSLBufferDecl(
5755+
DC, true, SourceLocation(), II, SourceLocation(), SourceLocation());
5756+
Result->setImplicit(true);
5757+
return Result;
5758+
}
5759+
57485760
HLSLBufferDecl *HLSLBufferDecl::CreateDeserialized(ASTContext &C,
57495761
GlobalDeclID ID) {
57505762
return new (C, ID) HLSLBufferDecl(nullptr, false, SourceLocation(), nullptr,
@@ -5757,6 +5769,30 @@ void HLSLBufferDecl::addLayoutStruct(CXXRecordDecl *LS) {
57575769
addDecl(LS);
57585770
}
57595771

5772+
void HLSLBufferDecl::addDefaultBufferDecl(Decl *D) {
5773+
assert(isImplicit() &&
5774+
"default decls can only be added to the implicit/default constant "
5775+
"buffer $Globals");
5776+
DefaultBufferDecls.push_back(D);
5777+
}
5778+
5779+
HLSLBufferDecl::buffer_decl_iterator
5780+
HLSLBufferDecl::buffer_decls_begin() const {
5781+
return buffer_decl_iterator(llvm::iterator_range(DefaultBufferDecls.begin(),
5782+
DefaultBufferDecls.end()),
5783+
decl_range(decls_begin(), decls_end()));
5784+
}
5785+
5786+
HLSLBufferDecl::buffer_decl_iterator HLSLBufferDecl::buffer_decls_end() const {
5787+
return buffer_decl_iterator(
5788+
llvm::iterator_range(DefaultBufferDecls.end(), DefaultBufferDecls.end()),
5789+
decl_range(decls_end(), decls_end()));
5790+
}
5791+
5792+
bool HLSLBufferDecl::buffer_decls_empty() {
5793+
return DefaultBufferDecls.empty() && decls_empty();
5794+
}
5795+
57605796
//===----------------------------------------------------------------------===//
57615797
// ImportDecl Implementation
57625798
//===----------------------------------------------------------------------===//

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ void CGHLSLRuntime::emitBufferGlobalsAndMetadata(const HLSLBufferDecl *BufDecl,
116116
BufGlobals.push_back(ValueAsMetadata::get(BufGV));
117117

118118
const auto *ElemIt = LayoutStruct->element_begin();
119-
for (Decl *D : BufDecl->decls()) {
119+
for (Decl *D : BufDecl->buffer_decls()) {
120120
if (isa<CXXRecordDecl, EmptyDecl>(D))
121121
// Nothing to do for this declaration.
122122
continue;

clang/lib/CodeGen/CodeGenModule.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5513,6 +5513,11 @@ void CodeGenModule::EmitGlobalVarDefinition(const VarDecl *D,
55135513
if (getLangOpts().OpenCL && ASTTy->isSamplerT())
55145514
return;
55155515

5516+
// HLSL default buffer constants will be emitted during HLSLBufferDecl codegen
5517+
if (getLangOpts().HLSL &&
5518+
D->getType().getAddressSpace() == LangAS::hlsl_constant)
5519+
return;
5520+
55165521
// If this is OpenMP device, check if it is legal to emit this global
55175522
// normally.
55185523
if (LangOpts.OpenMPIsTargetDevice && OpenMPRuntime &&

clang/lib/Sema/Sema.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,8 +1417,7 @@ void Sema::ActOnEndOfTranslationUnit() {
14171417
}
14181418

14191419
if (LangOpts.HLSL)
1420-
HLSL().DiagnoseAvailabilityViolations(
1421-
getASTContext().getTranslationUnitDecl());
1420+
HLSL().ActOnEndOfTranslationUnit(getASTContext().getTranslationUnitDecl());
14221421

14231422
// If there were errors, disable 'unused' warnings since they will mostly be
14241423
// noise. Don't warn for a use from a module: either we should warn on all

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include "clang/Sema/SemaHLSL.h"
12+
#include "clang/AST/ASTConsumer.h"
1213
#include "clang/AST/ASTContext.h"
1314
#include "clang/AST/Attr.h"
1415
#include "clang/AST/Attrs.inc"
@@ -147,7 +148,7 @@ bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) const {
147148
return DeclToBindingListIndex.contains(VD);
148149
}
149150

150-
SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {}
151+
SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S), DefaultCBuffer(nullptr) {}
151152

152153
Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer,
153154
SourceLocation KwLoc, IdentifierInfo *Ident,
@@ -225,7 +226,7 @@ static void validatePackoffset(Sema &S, HLSLBufferDecl *BufDecl) {
225226
// or on none.
226227
bool HasPackOffset = false;
227228
bool HasNonPackOffset = false;
228-
for (auto *Field : BufDecl->decls()) {
229+
for (auto *Field : BufDecl->buffer_decls()) {
229230
VarDecl *Var = dyn_cast<VarDecl>(Field);
230231
if (!Var)
231232
continue;
@@ -492,7 +493,7 @@ void createHostLayoutStructForBuffer(Sema &S, HLSLBufferDecl *BufDecl) {
492493
LS->setImplicit(true);
493494
LS->startDefinition();
494495

495-
for (Decl *D : BufDecl->decls()) {
496+
for (Decl *D : BufDecl->buffer_decls()) {
496497
VarDecl *VD = dyn_cast<VarDecl>(D);
497498
if (!VD || VD->getStorageClass() == SC_Static ||
498499
VD->getType().getAddressSpace() == LangAS::hlsl_groupshared)
@@ -1928,7 +1929,19 @@ void DiagnoseHLSLAvailability::CheckDeclAvailability(NamedDecl *D,
19281929

19291930
} // namespace
19301931

1931-
void SemaHLSL::DiagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
1932+
void SemaHLSL::ActOnEndOfTranslationUnit(TranslationUnitDecl *TU) {
1933+
// process default CBuffer - create buffer layout struct and invoke codegen
1934+
if (DefaultCBuffer) {
1935+
SemaRef.getCurLexicalContext()->addDecl(DefaultCBuffer);
1936+
createHostLayoutStructForBuffer(SemaRef, DefaultCBuffer);
1937+
1938+
DeclGroupRef DG(DefaultCBuffer);
1939+
SemaRef.Consumer.HandleTopLevelDecl(DG);
1940+
}
1941+
diagnoseAvailabilityViolations(TU);
1942+
}
1943+
1944+
void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
19321945
// Skip running the diagnostics scan if the diagnostic mode is
19331946
// strict (-fhlsl-strict-availability) and the target shader stage is known
19341947
// because all relevant diagnostics were already emitted in the
@@ -2991,6 +3004,14 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) {
29913004
return Ty;
29923005
}
29933006

3007+
static bool IsDefaultBufferConstantDecl(VarDecl *VD) {
3008+
QualType QT = VD->getType();
3009+
return VD->getDeclContext()->isTranslationUnit() &&
3010+
QT.getAddressSpace() == LangAS::Default &&
3011+
VD->getStorageClass() != SC_Static &&
3012+
!isInvalidConstantBufferLeafElementType(QT.getTypePtr());
3013+
}
3014+
29943015
void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
29953016
if (VD->hasGlobalStorage()) {
29963017
// make sure the declaration has a complete type
@@ -3002,7 +3023,21 @@ void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
30023023
return;
30033024
}
30043025

3005-
// find all resources on decl
3026+
// Global variables outside a cbuffer block that are not a resource, static,
3027+
// groupshared, or an empty array or struct belong to the default constant
3028+
// buffer $Globals
3029+
if (IsDefaultBufferConstantDecl(VD)) {
3030+
if (DefaultCBuffer == nullptr)
3031+
DefaultCBuffer = HLSLBufferDecl::CreateDefaultCBuffer(
3032+
SemaRef.getASTContext(), SemaRef.getCurLexicalContext());
3033+
// update address space to hlsl_constant
3034+
QualType NewTy = getASTContext().getAddrSpaceQualType(
3035+
VD->getType(), LangAS::hlsl_constant);
3036+
VD->setType(NewTy);
3037+
DefaultCBuffer->addDefaultBufferDecl(VD);
3038+
}
3039+
3040+
// find all resource bindings on decl
30063041
if (VD->getType()->isHLSLIntangibleType())
30073042
collectResourceBindingsOnVarDecl(VD);
30083043

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -ast-dump -o - %s | FileCheck %s
2+
3+
struct EmptyStruct {
4+
};
5+
6+
struct S {
7+
RWBuffer<float> buf;
8+
EmptyStruct es;
9+
float ea[0];
10+
float b;
11+
};
12+
13+
// CHECK: VarDecl {{.*}} used a 'hlsl_constant float'
14+
float a;
15+
16+
// CHECK: VarDecl {{.*}} b 'RWBuffer<float>':'hlsl::RWBuffer<float>'
17+
RWBuffer<float> b;
18+
19+
// CHECK: VarDecl {{.*}} c 'EmptyStruct'
20+
EmptyStruct c;
21+
22+
// CHECK: VarDecl {{.*}} d 'float[0]'
23+
float d[0];
24+
25+
// CHECK: VarDecl {{.*}} e 'RWBuffer<float>[2]'
26+
RWBuffer<float> e[2];
27+
28+
// CHECK: VarDecl {{.*}} f 'groupshared float'
29+
groupshared float f;
30+
31+
// CHECK: VarDecl {{.*}} g 'hlsl_constant float'
32+
float g;
33+
34+
// CHECK: VarDecl {{.*}} h 'hlsl_constant S'
35+
S h;
36+
37+
// CHECK: HLSLBufferDecl {{.*}} implicit cbuffer $Globals
38+
// CHECK: CXXRecordDecl {{.*}} implicit struct __cblayout_$Globals definition
39+
// CHECK: PackedAttr
40+
// CHECK-NEXT: FieldDecl {{.*}} a 'float'
41+
// CHECK-NEXT: FieldDecl {{.*}} g 'float'
42+
// CHECK-NEXT: FieldDecl {{.*}} h '__cblayout_S'
43+
44+
// CHECK: CXXRecordDecl {{.*}} implicit struct __cblayout_S definition
45+
// CHECK: PackedAttr {{.*}} Implicit
46+
// CHECK-NEXT: FieldDecl {{.*}} b 'float'
47+
48+
export float foo() {
49+
return a;
50+
}

clang/test/CodeGenHLSL/basic_types.hlsl

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,38 @@
66
// RUN: -emit-llvm -disable-llvm-passes -o - -DNAMESPACED| FileCheck %s
77

88

9-
// CHECK: @uint16_t_Val = global i16 0, align 2
10-
// CHECK: @int16_t_Val = global i16 0, align 2
11-
// CHECK: @uint_Val = global i32 0, align 4
12-
// CHECK: @uint64_t_Val = global i64 0, align 8
13-
// CHECK: @int64_t_Val = global i64 0, align 8
14-
// CHECK: @int16_t2_Val = global <2 x i16> zeroinitializer, align 4
15-
// CHECK: @int16_t3_Val = global <3 x i16> zeroinitializer, align 8
16-
// CHECK: @int16_t4_Val = global <4 x i16> zeroinitializer, align 8
17-
// CHECK: @uint16_t2_Val = global <2 x i16> zeroinitializer, align 4
18-
// CHECK: @uint16_t3_Val = global <3 x i16> zeroinitializer, align 8
19-
// CHECK: @uint16_t4_Val = global <4 x i16> zeroinitializer, align 8
20-
// CHECK: @int2_Val = global <2 x i32> zeroinitializer, align 8
21-
// CHECK: @int3_Val = global <3 x i32> zeroinitializer, align 16
22-
// CHECK: @int4_Val = global <4 x i32> zeroinitializer, align 16
23-
// CHECK: @uint2_Val = global <2 x i32> zeroinitializer, align 8
24-
// CHECK: @uint3_Val = global <3 x i32> zeroinitializer, align 16
25-
// CHECK: @uint4_Val = global <4 x i32> zeroinitializer, align 16
26-
// CHECK: @int64_t2_Val = global <2 x i64> zeroinitializer, align 16
27-
// CHECK: @int64_t3_Val = global <3 x i64> zeroinitializer, align 32
28-
// CHECK: @int64_t4_Val = global <4 x i64> zeroinitializer, align 32
29-
// CHECK: @uint64_t2_Val = global <2 x i64> zeroinitializer, align 16
30-
// CHECK: @uint64_t3_Val = global <3 x i64> zeroinitializer, align 32
31-
// CHECK: @uint64_t4_Val = global <4 x i64> zeroinitializer, align 32
32-
// CHECK: @half2_Val = global <2 x half> zeroinitializer, align 4
33-
// CHECK: @half3_Val = global <3 x half> zeroinitializer, align 8
34-
// CHECK: @half4_Val = global <4 x half> zeroinitializer, align 8
35-
// CHECK: @float2_Val = global <2 x float> zeroinitializer, align 8
36-
// CHECK: @float3_Val = global <3 x float> zeroinitializer, align 16
37-
// CHECK: @float4_Val = global <4 x float> zeroinitializer, align 16
38-
// CHECK: @double2_Val = global <2 x double> zeroinitializer, align 16
39-
// CHECK: @double3_Val = global <3 x double> zeroinitializer, align 32
40-
// CHECK: @double4_Val = global <4 x double> zeroinitializer, align 32
9+
// CHECK: @uint16_t_Val = external addrspace(2) global i16, align 2
10+
// CHECK: @int16_t_Val = external addrspace(2) global i16, align 2
11+
// CHECK: @uint_Val = external addrspace(2) global i32, align 4
12+
// CHECK: @uint64_t_Val = external addrspace(2) global i64, align 8
13+
// CHECK: @int64_t_Val = external addrspace(2) global i64, align 8
14+
// CHECK: @int16_t2_Val = external addrspace(2) global <2 x i16>, align 4
15+
// CHECK: @int16_t3_Val = external addrspace(2) global <3 x i16>, align 8
16+
// CHECK: @int16_t4_Val = external addrspace(2) global <4 x i16>, align 8
17+
// CHECK: @uint16_t2_Val = external addrspace(2) global <2 x i16>, align 4
18+
// CHECK: @uint16_t3_Val = external addrspace(2) global <3 x i16>, align 8
19+
// CHECK: @uint16_t4_Val = external addrspace(2) global <4 x i16>, align 8
20+
// CHECK: @int2_Val = external addrspace(2) global <2 x i32>, align 8
21+
// CHECK: @int3_Val = external addrspace(2) global <3 x i32>, align 16
22+
// CHECK: @int4_Val = external addrspace(2) global <4 x i32>, align 16
23+
// CHECK: @uint2_Val = external addrspace(2) global <2 x i32>, align 8
24+
// CHECK: @uint3_Val = external addrspace(2) global <3 x i32>, align 16
25+
// CHECK: @uint4_Val = external addrspace(2) global <4 x i32>, align 16
26+
// CHECK: @int64_t2_Val = external addrspace(2) global <2 x i64>, align 16
27+
// CHECK: @int64_t3_Val = external addrspace(2) global <3 x i64>, align 32
28+
// CHECK: @int64_t4_Val = external addrspace(2) global <4 x i64>, align 32
29+
// CHECK: @uint64_t2_Val = external addrspace(2) global <2 x i64>, align 16
30+
// CHECK: @uint64_t3_Val = external addrspace(2) global <3 x i64>, align 32
31+
// CHECK: @uint64_t4_Val = external addrspace(2) global <4 x i64>, align 32
32+
// CHECK: @half2_Val = external addrspace(2) global <2 x half>, align 4
33+
// CHECK: @half3_Val = external addrspace(2) global <3 x half>, align 8
34+
// CHECK: @half4_Val = external addrspace(2) global <4 x half>, align 8
35+
// CHECK: @float2_Val = external addrspace(2) global <2 x float>, align 8
36+
// CHECK: @float3_Val = external addrspace(2) global <3 x float>, align 16
37+
// CHECK: @float4_Val = external addrspace(2) global <4 x float>, align 16
38+
// CHECK: @double2_Val = external addrspace(2) global <2 x double>, align 16
39+
// CHECK: @double3_Val = external addrspace(2) global <3 x double>, align 32
40+
// CHECK: @double4_Val = external addrspace(2) global <4 x double>, align 32
4141

4242
#ifdef NAMESPACED
4343
#define TYPE_DECL(T) hlsl::T T##_Val

0 commit comments

Comments
 (0)