Skip to content

Commit 01d648a

Browse files
authored
[HLSL][SPIRV] Reapply "[HLSL][SPIRV] Add vk::constant_id attribute." (#144902)
- **Reapply "[HLSL][SPIRV] Add vk::constant_id attribute." (#144812)** - **Fix memory leak.**
1 parent f4db142 commit 01d648a

File tree

15 files changed

+637
-2
lines changed

15 files changed

+637
-2
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5023,6 +5023,14 @@ def HLSLVkExtBuiltinInput : InheritableAttr {
50235023
let Documentation = [HLSLVkExtBuiltinInputDocs];
50245024
}
50255025

5026+
def HLSLVkConstantId : InheritableAttr {
5027+
let Spellings = [CXX11<"vk", "constant_id">];
5028+
let Args = [IntArgument<"Id">];
5029+
let Subjects = SubjectList<[ExternalGlobalVar]>;
5030+
let LangOpts = [HLSL];
5031+
let Documentation = [VkConstantIdDocs];
5032+
}
5033+
50265034
def RandomizeLayout : InheritableAttr {
50275035
let Spellings = [GCC<"randomize_layout">];
50285036
let Subjects = SubjectList<[Record]>;

clang/include/clang/Basic/AttrDocs.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8252,6 +8252,21 @@ and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
82528252
}];
82538253
}
82548254

8255+
def VkConstantIdDocs : Documentation {
8256+
let Category = DocCatFunction;
8257+
let Content = [{
8258+
The ``vk::constant_id`` attribute specifies the id for a SPIR-V specialization
8259+
constant. The attribute applies to const global scalar variables. The variable must be initialized with a C++11 constexpr.
8260+
In SPIR-V, the
8261+
variable will be replaced with an `OpSpecConstant` with the given id.
8262+
The syntax is:
8263+
8264+
.. code-block:: text
8265+
8266+
``[[vk::constant_id(<Id>)]] const T Name = <Init>``
8267+
}];
8268+
}
8269+
82558270
def RootSignatureDocs : Documentation {
82568271
let Category = DocCatFunction;
82578272
let Content = [{

clang/include/clang/Basic/Builtins.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5065,6 +5065,19 @@ def HLSLGroupMemoryBarrierWithGroupSync: LangBuiltin<"HLSL_LANG"> {
50655065
let Prototype = "void()";
50665066
}
50675067

5068+
class HLSLScalarTemplate
5069+
: Template<["bool", "char", "short", "int", "long long int",
5070+
"unsigned short", "unsigned int", "unsigned long long int",
5071+
"__fp16", "float", "double"],
5072+
["_bool", "_char", "_short", "_int", "_longlong", "_ushort",
5073+
"_uint", "_ulonglong", "_half", "_float", "_double"]>;
5074+
5075+
def HLSLGetSpirvSpecConstant : LangBuiltin<"HLSL_LANG">, HLSLScalarTemplate {
5076+
let Spellings = ["__builtin_get_spirv_spec_constant"];
5077+
let Attributes = [NoThrow, Const, Pure];
5078+
let Prototype = "T(unsigned int, T)";
5079+
}
5080+
50685081
// Builtins for XRay.
50695082
def XRayCustomEvent : Builtin {
50705083
let Spellings = ["__xray_customevent"];

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12927,6 +12927,10 @@ def err_spirv_enum_not_int : Error<
1292712927
def err_spirv_enum_not_valid : Error<
1292812928
"invalid value for %select{storage class}0 argument">;
1292912929

12930+
def err_specialization_const
12931+
: Error<"variable with 'vk::constant_id' attribute must be a const "
12932+
"int/float/enum/bool and be initialized with a literal">;
12933+
1293012934
// errors of expect.with.probability
1293112935
def err_probability_not_constant_float : Error<
1293212936
"probability argument to __builtin_expect_with_probability must be constant "

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ class SemaHLSL : public SemaBase {
9898
HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
9999
int Min, int Max, int Preferred,
100100
int SpelledArgsCount);
101+
HLSLVkConstantIdAttr *
102+
mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id);
101103
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
102104
llvm::Triple::EnvironmentType ShaderType);
103105
HLSLParamModifierAttr *
@@ -135,6 +137,7 @@ class SemaHLSL : public SemaBase {
135137
void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL);
136138
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
137139
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
140+
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
138141
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
139142
void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
140143
void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
@@ -171,7 +174,7 @@ class SemaHLSL : public SemaBase {
171174
QualType getInoutParameterType(QualType Ty);
172175

173176
bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init);
174-
177+
bool handleInitialization(VarDecl *VDecl, Expr *&Init);
175178
void deduceAddressSpace(VarDecl *Decl);
176179

177180
private:

clang/lib/CodeGen/CGHLSLBuiltins.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "CGBuiltin.h"
1414
#include "CGHLSLRuntime.h"
15+
#include "CodeGenFunction.h"
1516

1617
using namespace clang;
1718
using namespace CodeGen;
@@ -214,6 +215,44 @@ static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
214215
}
215216
}
216217

218+
// Returns the mangled name for a builtin function that the SPIR-V backend
219+
// will expand into a spec Constant.
220+
static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,
221+
ASTContext &Context) {
222+
// The parameter types for our conceptual intrinsic function.
223+
QualType ClangParamTypes[] = {Context.IntTy, SpecConstantType};
224+
225+
// Create a temporary FunctionDecl for the builtin fuction. It won't be
226+
// added to the AST.
227+
FunctionProtoType::ExtProtoInfo EPI;
228+
QualType FnType =
229+
Context.getFunctionType(SpecConstantType, ClangParamTypes, EPI);
230+
DeclarationName FuncName = &Context.Idents.get("__spirv_SpecConstant");
231+
FunctionDecl *FnDeclForMangling = FunctionDecl::Create(
232+
Context, Context.getTranslationUnitDecl(), SourceLocation(),
233+
SourceLocation(), FuncName, FnType, /*TSI=*/nullptr, SC_Extern);
234+
235+
// Attach the created parameter declarations to the function declaration.
236+
SmallVector<ParmVarDecl *, 2> ParamDecls;
237+
for (QualType ParamType : ClangParamTypes) {
238+
ParmVarDecl *PD = ParmVarDecl::Create(
239+
Context, FnDeclForMangling, SourceLocation(), SourceLocation(),
240+
/*IdentifierInfo*/ nullptr, ParamType, /*TSI*/ nullptr, SC_None,
241+
/*DefaultArg*/ nullptr);
242+
ParamDecls.push_back(PD);
243+
}
244+
FnDeclForMangling->setParams(ParamDecls);
245+
246+
// Get the mangled name.
247+
std::string Name;
248+
llvm::raw_string_ostream MangledNameStream(Name);
249+
std::unique_ptr<MangleContext> Mangler(Context.createMangleContext());
250+
Mangler->mangleName(FnDeclForMangling, MangledNameStream);
251+
MangledNameStream.flush();
252+
253+
return Name;
254+
}
255+
217256
Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
218257
const CallExpr *E,
219258
ReturnValueSlot ReturnValue) {
@@ -773,6 +812,42 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
773812
return EmitRuntimeCall(
774813
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
775814
}
815+
case Builtin::BI__builtin_get_spirv_spec_constant_bool:
816+
case Builtin::BI__builtin_get_spirv_spec_constant_short:
817+
case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
818+
case Builtin::BI__builtin_get_spirv_spec_constant_int:
819+
case Builtin::BI__builtin_get_spirv_spec_constant_uint:
820+
case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
821+
case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
822+
case Builtin::BI__builtin_get_spirv_spec_constant_half:
823+
case Builtin::BI__builtin_get_spirv_spec_constant_float:
824+
case Builtin::BI__builtin_get_spirv_spec_constant_double: {
825+
llvm::Function *SpecConstantFn = getSpecConstantFunction(E->getType());
826+
llvm::Value *SpecId = EmitScalarExpr(E->getArg(0));
827+
llvm::Value *DefaultVal = EmitScalarExpr(E->getArg(1));
828+
llvm::Value *Args[] = {SpecId, DefaultVal};
829+
return Builder.CreateCall(SpecConstantFn, Args);
830+
}
776831
}
777832
return nullptr;
778833
}
834+
835+
llvm::Function *clang::CodeGen::CodeGenFunction::getSpecConstantFunction(
836+
const clang::QualType &SpecConstantType) {
837+
838+
// Find or create the declaration for the function.
839+
llvm::Module *M = &CGM.getModule();
840+
std::string MangledName =
841+
getSpecConstantFunctionName(SpecConstantType, getContext());
842+
llvm::Function *SpecConstantFn = M->getFunction(MangledName);
843+
844+
if (!SpecConstantFn) {
845+
llvm::Type *IntType = ConvertType(getContext().IntTy);
846+
llvm::Type *RetTy = ConvertType(SpecConstantType);
847+
llvm::Type *ArgTypes[] = {IntType, RetTy};
848+
llvm::FunctionType *FnTy = llvm::FunctionType::get(RetTy, ArgTypes, false);
849+
SpecConstantFn = llvm::Function::Create(
850+
FnTy, llvm::GlobalValue::ExternalLinkage, MangledName, M);
851+
}
852+
return SpecConstantFn;
853+
}

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4850,6 +4850,12 @@ class CodeGenFunction : public CodeGenTypeCache {
48504850
llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
48514851
llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
48524852
ReturnValueSlot ReturnValue);
4853+
4854+
// Returns a builtin function that the SPIR-V backend will expand into a spec
4855+
// constant.
4856+
llvm::Function *
4857+
getSpecConstantFunction(const clang::QualType &SpecConstantType);
4858+
48534859
llvm::Value *EmitDirectXBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
48544860
llvm::Value *EmitSPIRVBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
48554861
llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,

clang/lib/Sema/SemaDecl.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2890,6 +2890,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
28902890
NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
28912891
WS->getPreferred(),
28922892
WS->getSpelledArgsCount());
2893+
else if (const auto *CI = dyn_cast<HLSLVkConstantIdAttr>(Attr))
2894+
NewAttr = S.HLSL().mergeVkConstantIdAttr(D, *CI, CI->getId());
28932895
else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
28942896
NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
28952897
else if (isa<SuppressAttr>(Attr))
@@ -13757,6 +13759,10 @@ void Sema::AddInitializerToDecl(Decl *RealDecl, Expr *Init, bool DirectInit) {
1375713759
return;
1375813760
}
1375913761

13762+
if (getLangOpts().HLSL)
13763+
if (!HLSL().handleInitialization(VDecl, Init))
13764+
return;
13765+
1376013766
// Get the decls type and save a reference for later, since
1376113767
// CheckInitializerTypes may change it.
1376213768
QualType DclT = VDecl->getType(), SavT = DclT;
@@ -14179,6 +14185,13 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
1417914185
}
1418014186
}
1418114187

14188+
// HLSL variable with the `vk::constant_id` attribute must be initialized.
14189+
if (!Var->isInvalidDecl() && Var->hasAttr<HLSLVkConstantIdAttr>()) {
14190+
Diag(Var->getLocation(), diag::err_specialization_const);
14191+
Var->setInvalidDecl();
14192+
return;
14193+
}
14194+
1418214195
if (!Var->isInvalidDecl() && RealDecl->hasAttr<LoaderUninitializedAttr>()) {
1418314196
if (Var->getStorageClass() == SC_Extern) {
1418414197
Diag(Var->getLocation(), diag::err_loader_uninitialized_extern_decl)

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7590,6 +7590,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
75907590
case ParsedAttr::AT_HLSLVkExtBuiltinInput:
75917591
S.HLSL().handleVkExtBuiltinInputAttr(D, AL);
75927592
break;
7593+
case ParsedAttr::AT_HLSLVkConstantId:
7594+
S.HLSL().handleVkConstantIdAttr(D, AL);
7595+
break;
75937596
case ParsedAttr::AT_HLSLSV_GroupThreadID:
75947597
S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
75957598
break;

0 commit comments

Comments
 (0)