Skip to content

[HLSL][SPIRV] Reapply "[HLSL][SPIRV] Add vk::constant_id attribute." #144902

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 19, 2025

Conversation

s-perron
Copy link
Contributor

@s-perron s-perron requested a review from Keenuts June 19, 2025 14:41
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:codegen IR generation bugs: mangling, exceptions, etc. HLSL HLSL Language Support labels Jun 19, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-clang-codegen

Author: Steven Perron (s-perron)

Changes
  • Reapply "[HLSL][SPIRV] Add vk::constant_id attribute." (#144812)
  • Fix memory leak.

Patch is 38.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144902.diff

15 Files Affected:

  • (modified) clang/include/clang/Basic/Attr.td (+8)
  • (modified) clang/include/clang/Basic/AttrDocs.td (+15)
  • (modified) clang/include/clang/Basic/Builtins.td (+13)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+4)
  • (modified) clang/include/clang/Sema/SemaHLSL.h (+4-1)
  • (modified) clang/lib/CodeGen/CGHLSLBuiltins.cpp (+75)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+6)
  • (modified) clang/lib/Sema/SemaDecl.cpp (+13)
  • (modified) clang/lib/Sema/SemaDeclAttr.cpp (+3)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+119-1)
  • (added) clang/test/AST/HLSL/vk.spec-constant.usage.hlsl (+130)
  • (renamed) clang/test/CodeGenHLSL/vk-features/SpirvType.alignment.hlsl ()
  • (renamed) clang/test/CodeGenHLSL/vk-features/SpirvType.hlsl ()
  • (added) clang/test/CodeGenHLSL/vk-features/vk.spec-constant.hlsl (+210)
  • (added) clang/test/SemaHLSL/vk.spec-constant.error.hlsl (+37)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index f113cd2ba2fbf..27fea7dea0a5e 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -5023,6 +5023,14 @@ def HLSLVkExtBuiltinInput : InheritableAttr {
   let Documentation = [HLSLVkExtBuiltinInputDocs];
 }
 
+def HLSLVkConstantId : InheritableAttr {
+  let Spellings = [CXX11<"vk", "constant_id">];
+  let Args = [IntArgument<"Id">];
+  let Subjects = SubjectList<[ExternalGlobalVar]>;
+  let LangOpts = [HLSL];
+  let Documentation = [VkConstantIdDocs];
+}
+
 def RandomizeLayout : InheritableAttr {
   let Spellings = [GCC<"randomize_layout">];
   let Subjects = SubjectList<[Record]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 6051e1fc45111..43442f177ab7b 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -8252,6 +8252,21 @@ and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
   }];
 }
 
+def VkConstantIdDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+The ``vk::constant_id`` attribute specifies the id for a SPIR-V specialization
+constant. The attribute applies to const global scalar variables. The variable must be initialized with a C++11 constexpr.
+In SPIR-V, the
+variable will be replaced with an `OpSpecConstant` with the given id.
+The syntax is:
+
+.. code-block:: text
+
+  ``[[vk::constant_id(<Id>)]] const T Name = <Init>``
+}];
+}
+
 def RootSignatureDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 68cd3d790e78a..d65b3a5d2f447 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5065,6 +5065,19 @@ def HLSLGroupMemoryBarrierWithGroupSync: LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void()";
 }
 
+class HLSLScalarTemplate
+    : Template<["bool", "char", "short", "int", "long long int",
+                "unsigned short", "unsigned int", "unsigned long long int",
+                "__fp16", "float", "double"],
+               ["_bool", "_char", "_short", "_int", "_longlong", "_ushort",
+                "_uint", "_ulonglong", "_half", "_float", "_double"]>;
+
+def HLSLGetSpirvSpecConstant : LangBuiltin<"HLSL_LANG">, HLSLScalarTemplate {
+  let Spellings = ["__builtin_get_spirv_spec_constant"];
+  let Attributes = [NoThrow, Const, Pure];
+  let Prototype = "T(unsigned int, T)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 979ff60b73b75..34b798a09c216 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12927,6 +12927,10 @@ def err_spirv_enum_not_int : Error<
 def err_spirv_enum_not_valid : Error<
    "invalid value for %select{storage class}0 argument">;
 
+def err_specialization_const
+    : Error<"variable with 'vk::constant_id' attribute must be a const "
+            "int/float/enum/bool and be initialized with a literal">;
+
 // errors of expect.with.probability
 def err_probability_not_constant_float : Error<
    "probability argument to __builtin_expect_with_probability must be constant "
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 33c4b8d1568bf..97091792ba236 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -98,6 +98,8 @@ class SemaHLSL : public SemaBase {
   HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
                                       int Min, int Max, int Preferred,
                                       int SpelledArgsCount);
+  HLSLVkConstantIdAttr *
+  mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id);
   HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                                   llvm::Triple::EnvironmentType ShaderType);
   HLSLParamModifierAttr *
@@ -135,6 +137,7 @@ class SemaHLSL : public SemaBase {
   void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL);
   void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
   void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
+  void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
@@ -171,7 +174,7 @@ class SemaHLSL : public SemaBase {
   QualType getInoutParameterType(QualType Ty);
 
   bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init);
-
+  bool handleInitialization(VarDecl *VDecl, Expr *&Init);
   void deduceAddressSpace(VarDecl *Decl);
 
 private:
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index ccf45c0c6ff1d..2a60a0909c93e 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -12,6 +12,7 @@
 
 #include "CGBuiltin.h"
 #include "CGHLSLRuntime.h"
+#include "CodeGenFunction.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -214,6 +215,44 @@ static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
   }
 }
 
+// Returns the mangled name for a builtin function that the SPIR-V backend
+// will expand into a spec Constant.
+static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,
+                                               ASTContext &Context) {
+  // The parameter types for our conceptual intrinsic function.
+  QualType ClangParamTypes[] = {Context.IntTy, SpecConstantType};
+
+  // Create a temporary FunctionDecl for the builtin fuction. It won't be
+  // added to the AST.
+  FunctionProtoType::ExtProtoInfo EPI;
+  QualType FnType =
+      Context.getFunctionType(SpecConstantType, ClangParamTypes, EPI);
+  DeclarationName FuncName = &Context.Idents.get("__spirv_SpecConstant");
+  FunctionDecl *FnDeclForMangling = FunctionDecl::Create(
+      Context, Context.getTranslationUnitDecl(), SourceLocation(),
+      SourceLocation(), FuncName, FnType, /*TSI=*/nullptr, SC_Extern);
+
+  // Attach the created parameter declarations to the function declaration.
+  SmallVector<ParmVarDecl *, 2> ParamDecls;
+  for (QualType ParamType : ClangParamTypes) {
+    ParmVarDecl *PD = ParmVarDecl::Create(
+        Context, FnDeclForMangling, SourceLocation(), SourceLocation(),
+        /*IdentifierInfo*/ nullptr, ParamType, /*TSI*/ nullptr, SC_None,
+        /*DefaultArg*/ nullptr);
+    ParamDecls.push_back(PD);
+  }
+  FnDeclForMangling->setParams(ParamDecls);
+
+  // Get the mangled name.
+  std::string Name;
+  llvm::raw_string_ostream MangledNameStream(Name);
+  std::unique_ptr<MangleContext> Mangler(Context.createMangleContext());
+  Mangler->mangleName(FnDeclForMangling, MangledNameStream);
+  MangledNameStream.flush();
+
+  return Name;
+}
+
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
                                             const CallExpr *E,
                                             ReturnValueSlot ReturnValue) {
@@ -773,6 +812,42 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     return EmitRuntimeCall(
         Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
   }
+  case Builtin::BI__builtin_get_spirv_spec_constant_bool:
+  case Builtin::BI__builtin_get_spirv_spec_constant_short:
+  case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
+  case Builtin::BI__builtin_get_spirv_spec_constant_int:
+  case Builtin::BI__builtin_get_spirv_spec_constant_uint:
+  case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
+  case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
+  case Builtin::BI__builtin_get_spirv_spec_constant_half:
+  case Builtin::BI__builtin_get_spirv_spec_constant_float:
+  case Builtin::BI__builtin_get_spirv_spec_constant_double: {
+    llvm::Function *SpecConstantFn = getSpecConstantFunction(E->getType());
+    llvm::Value *SpecId = EmitScalarExpr(E->getArg(0));
+    llvm::Value *DefaultVal = EmitScalarExpr(E->getArg(1));
+    llvm::Value *Args[] = {SpecId, DefaultVal};
+    return Builder.CreateCall(SpecConstantFn, Args);
+  }
   }
   return nullptr;
 }
+
+llvm::Function *clang::CodeGen::CodeGenFunction::getSpecConstantFunction(
+    const clang::QualType &SpecConstantType) {
+
+  // Find or create the declaration for the function.
+  llvm::Module *M = &CGM.getModule();
+  std::string MangledName =
+      getSpecConstantFunctionName(SpecConstantType, getContext());
+  llvm::Function *SpecConstantFn = M->getFunction(MangledName);
+
+  if (!SpecConstantFn) {
+    llvm::Type *IntType = ConvertType(getContext().IntTy);
+    llvm::Type *RetTy = ConvertType(SpecConstantType);
+    llvm::Type *ArgTypes[] = {IntType, RetTy};
+    llvm::FunctionType *FnTy = llvm::FunctionType::get(RetTy, ArgTypes, false);
+    SpecConstantFn = llvm::Function::Create(
+        FnTy, llvm::GlobalValue::ExternalLinkage, MangledName, M);
+  }
+  return SpecConstantFn;
+}
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index a5ab9df01dba9..59f14b3e35fd0 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4850,6 +4850,12 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
                                    ReturnValueSlot ReturnValue);
+
+  // Returns a builtin function that the SPIR-V backend will expand into a spec
+  // constant.
+  llvm::Function *
+  getSpecConstantFunction(const clang::QualType &SpecConstantType);
+
   llvm::Value *EmitDirectXBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitSPIRVBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 1bf72e5bb7b9d..e1cccf068b5aa 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2890,6 +2890,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
     NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
                                          WS->getPreferred(),
                                          WS->getSpelledArgsCount());
+  else if (const auto *CI = dyn_cast<HLSLVkConstantIdAttr>(Attr))
+    NewAttr = S.HLSL().mergeVkConstantIdAttr(D, *CI, CI->getId());
   else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
     NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
   else if (isa<SuppressAttr>(Attr))
@@ -13757,6 +13759,10 @@ void Sema::AddInitializerToDecl(Decl *RealDecl, Expr *Init, bool DirectInit) {
     return;
   }
 
+  if (getLangOpts().HLSL)
+    if (!HLSL().handleInitialization(VDecl, Init))
+      return;
+
   // Get the decls type and save a reference for later, since
   // CheckInitializerTypes may change it.
   QualType DclT = VDecl->getType(), SavT = DclT;
@@ -14179,6 +14185,13 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
       }
     }
 
+    // HLSL variable with the `vk::constant_id` attribute must be initialized.
+    if (!Var->isInvalidDecl() && Var->hasAttr<HLSLVkConstantIdAttr>()) {
+      Diag(Var->getLocation(), diag::err_specialization_const);
+      Var->setInvalidDecl();
+      return;
+    }
+
     if (!Var->isInvalidDecl() && RealDecl->hasAttr<LoaderUninitializedAttr>()) {
       if (Var->getStorageClass() == SC_Extern) {
         Diag(Var->getLocation(), diag::err_loader_uninitialized_extern_decl)
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 1c2fa80e782d4..eba29e609cb05 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7590,6 +7590,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_HLSLVkExtBuiltinInput:
     S.HLSL().handleVkExtBuiltinInputAttr(D, AL);
     break;
+  case ParsedAttr::AT_HLSLVkConstantId:
+    S.HLSL().handleVkConstantIdAttr(D, AL);
+    break;
   case ParsedAttr::AT_HLSLSV_GroupThreadID:
     S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
     break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index b55f4fd786b58..9b43ee00810b2 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -119,6 +119,40 @@ static ResourceClass getResourceClass(RegisterType RT) {
   llvm_unreachable("unexpected RegisterType value");
 }
 
+static Builtin::ID getSpecConstBuiltinId(QualType Type) {
+  const auto *BT = dyn_cast<BuiltinType>(Type);
+  if (!BT) {
+    if (!Type->isEnumeralType())
+      return Builtin::NotBuiltin;
+    return Builtin::BI__builtin_get_spirv_spec_constant_int;
+  }
+
+  switch (BT->getKind()) {
+  case BuiltinType::Bool:
+    return Builtin::BI__builtin_get_spirv_spec_constant_bool;
+  case BuiltinType::Short:
+    return Builtin::BI__builtin_get_spirv_spec_constant_short;
+  case BuiltinType::Int:
+    return Builtin::BI__builtin_get_spirv_spec_constant_int;
+  case BuiltinType::LongLong:
+    return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
+  case BuiltinType::UShort:
+    return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
+  case BuiltinType::UInt:
+    return Builtin::BI__builtin_get_spirv_spec_constant_uint;
+  case BuiltinType::ULongLong:
+    return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
+  case BuiltinType::Half:
+    return Builtin::BI__builtin_get_spirv_spec_constant_half;
+  case BuiltinType::Float:
+    return Builtin::BI__builtin_get_spirv_spec_constant_float;
+  case BuiltinType::Double:
+    return Builtin::BI__builtin_get_spirv_spec_constant_double;
+  default:
+    return Builtin::NotBuiltin;
+  }
+}
+
 DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
                                                       ResourceClass ResClass) {
   assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
@@ -607,6 +641,41 @@ HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
   return Result;
 }
 
+HLSLVkConstantIdAttr *
+SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
+                                int Id) {
+
+  auto &TargetInfo = getASTContext().getTargetInfo();
+  if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
+    Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
+    return nullptr;
+  }
+
+  auto *VD = cast<VarDecl>(D);
+
+  if (getSpecConstBuiltinId(VD->getType()) == Builtin::NotBuiltin) {
+    Diag(VD->getLocation(), diag::err_specialization_const);
+    return nullptr;
+  }
+
+  if (!VD->getType().isConstQualified()) {
+    Diag(VD->getLocation(), diag::err_specialization_const);
+    return nullptr;
+  }
+
+  if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
+    if (CI->getId() != Id) {
+      Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+
+  HLSLVkConstantIdAttr *Result =
+      ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
+  return Result;
+}
+
 HLSLShaderAttr *
 SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                           llvm::Triple::EnvironmentType ShaderType) {
@@ -1157,6 +1226,15 @@ void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
                  HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
 }
 
+void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
+  uint32_t Id;
+  if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
+    return;
+  HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
+  if (NewAttr)
+    D->addAttr(NewAttr);
+}
+
 bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
   const auto *VT = T->getAs<VectorType>();
 
@@ -3206,6 +3284,7 @@ static bool IsDefaultBufferConstantDecl(VarDecl *VD) {
   return VD->getDeclContext()->isTranslationUnit() &&
          QT.getAddressSpace() == LangAS::Default &&
          VD->getStorageClass() != SC_Static &&
+         !VD->hasAttr<HLSLVkConstantIdAttr>() &&
          !isInvalidConstantBufferLeafElementType(QT.getTypePtr());
 }
 
@@ -3273,7 +3352,8 @@ void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
     const Type *VarType = VD->getType().getTypePtr();
     while (VarType->isArrayType())
       VarType = VarType->getArrayElementTypeNoTypeQual();
-    if (VarType->isHLSLResourceRecord()) {
+    if (VarType->isHLSLResourceRecord() ||
+        VD->hasAttr<HLSLVkConstantIdAttr>()) {
       // Make the variable for resources static. The global externally visible
       // storage is accessed through the handle, which is a member. The variable
       // itself is not externally visible.
@@ -3696,3 +3776,41 @@ bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
     Init->updateInit(Ctx, I, NewInit->getInit(I));
   return true;
 }
+
+bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
+  const HLSLVkConstantIdAttr *ConstIdAttr =
+      VDecl->getAttr<HLSLVkConstantIdAttr>();
+  if (!ConstIdAttr)
+    return true;
+
+  ASTContext &Context = SemaRef.getASTContext();
+
+  APValue InitValue;
+  if (!Init->isCXX11ConstantExpr(Context, &InitValue)) {
+    Diag(VDecl->getLocation(), diag::err_specialization_const);
+    VDecl->setInvalidDecl();
+    return false;
+  }
+
+  Builtin::ID BID = getSpecConstBuiltinId(VDecl->getType());
+
+  // Argument 1: The ID from the attribute
+  int ConstantID = ConstIdAttr->getId();
+  llvm::APInt IDVal(Context.getIntWidth(Context.IntTy), ConstantID);
+  Expr *IdExpr = IntegerLiteral::Create(Context, IDVal, Context.IntTy,
+                                        ConstIdAttr->getLocation());
+
+  SmallVector<Expr *, 2> Args = {IdExpr, Init};
+  Expr *C = SemaRef.BuildBuiltinCallExpr(Init->getExprLoc(), BID, Args);
+  if (C->getType()->getCanonicalTypeUnqualified() !=
+      VDecl->getType()->getCanonicalTypeUnqualified()) {
+    C = SemaRef
+            .BuildCStyleCastExpr(SourceLocation(),
+                                 Context.getTrivialTypeSourceInfo(
+                                     Init->getType(), Init->getExprLoc()),
+                                 SourceLocation(), C)
+            .get();
+  }
+  Init = C;
+  return true;
+}
diff --git a/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl
new file mode 100644
index 0000000000000..c0955c1ea7b43
--- /dev/null
+++ b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl
@@ -0,0 +1,130 @@
+// RUN: %clang_cc1 -finclude-default-header -triple spirv-unknown-vulkan-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+
+// CHECK: VarDecl {{.*}} bool_const 'const hlsl_private bool' static cinit
+// CHECK-NEXT: CallExpr {{.*}} 'bool'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'bool (*)(unsigned int, bool) noexcept' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'bool (unsigned int, bool) noexcept' lvalue Function {{.*}} '__builtin_get_spirv_spec_constant_bool' 'bool (unsigned int, bool) noexcept'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'unsigned int' <IntegralCast>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 1
+// CHECK-NEXT: CXXBoolLiteralExpr {{.*}} 'bool' true
+[[vk::constant_id(1)]]
+const bool bool_const = true;
+
+// CHECK: VarDecl {{.*}} short_const 'const hlsl_private short' static cinit
+// CHECK-NEXT: CallExpr {{.*}} 'short'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'short (*)(unsigned int, short) noexcept' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'short (unsigned int, short) noexcept' lvalue Function {{.*}} '__builtin_get_spirv_spec_constant_short' 'short (unsigned int, short) noexcept'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'unsigned int' <IntegralCast>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 2
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'short' <IntegralCast>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 4
+[[vk::constant_id(2)]]
+const short short_const = 4;
+
+// CHECK: VarDecl {{.*}} int_const 'const hlsl_private int' stat...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-hlsl

Author: Steven Perron (s-perron)

Changes
  • Reapply "[HLSL][SPIRV] Add vk::constant_id attribute." (#144812)
  • Fix memory leak.

Patch is 38.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144902.diff

15 Files Affected:

  • (modified) clang/include/clang/Basic/Attr.td (+8)
  • (modified) clang/include/clang/Basic/AttrDocs.td (+15)
  • (modified) clang/include/clang/Basic/Builtins.td (+13)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+4)
  • (modified) clang/include/clang/Sema/SemaHLSL.h (+4-1)
  • (modified) clang/lib/CodeGen/CGHLSLBuiltins.cpp (+75)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+6)
  • (modified) clang/lib/Sema/SemaDecl.cpp (+13)
  • (modified) clang/lib/Sema/SemaDeclAttr.cpp (+3)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+119-1)
  • (added) clang/test/AST/HLSL/vk.spec-constant.usage.hlsl (+130)
  • (renamed) clang/test/CodeGenHLSL/vk-features/SpirvType.alignment.hlsl ()
  • (renamed) clang/test/CodeGenHLSL/vk-features/SpirvType.hlsl ()
  • (added) clang/test/CodeGenHLSL/vk-features/vk.spec-constant.hlsl (+210)
  • (added) clang/test/SemaHLSL/vk.spec-constant.error.hlsl (+37)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index f113cd2ba2fbf..27fea7dea0a5e 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -5023,6 +5023,14 @@ def HLSLVkExtBuiltinInput : InheritableAttr {
   let Documentation = [HLSLVkExtBuiltinInputDocs];
 }
 
+def HLSLVkConstantId : InheritableAttr {
+  let Spellings = [CXX11<"vk", "constant_id">];
+  let Args = [IntArgument<"Id">];
+  let Subjects = SubjectList<[ExternalGlobalVar]>;
+  let LangOpts = [HLSL];
+  let Documentation = [VkConstantIdDocs];
+}
+
 def RandomizeLayout : InheritableAttr {
   let Spellings = [GCC<"randomize_layout">];
   let Subjects = SubjectList<[Record]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 6051e1fc45111..43442f177ab7b 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -8252,6 +8252,21 @@ and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
   }];
 }
 
+def VkConstantIdDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+The ``vk::constant_id`` attribute specifies the id for a SPIR-V specialization
+constant. The attribute applies to const global scalar variables. The variable must be initialized with a C++11 constexpr.
+In SPIR-V, the
+variable will be replaced with an `OpSpecConstant` with the given id.
+The syntax is:
+
+.. code-block:: text
+
+  ``[[vk::constant_id(<Id>)]] const T Name = <Init>``
+}];
+}
+
 def RootSignatureDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 68cd3d790e78a..d65b3a5d2f447 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5065,6 +5065,19 @@ def HLSLGroupMemoryBarrierWithGroupSync: LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void()";
 }
 
+class HLSLScalarTemplate
+    : Template<["bool", "char", "short", "int", "long long int",
+                "unsigned short", "unsigned int", "unsigned long long int",
+                "__fp16", "float", "double"],
+               ["_bool", "_char", "_short", "_int", "_longlong", "_ushort",
+                "_uint", "_ulonglong", "_half", "_float", "_double"]>;
+
+def HLSLGetSpirvSpecConstant : LangBuiltin<"HLSL_LANG">, HLSLScalarTemplate {
+  let Spellings = ["__builtin_get_spirv_spec_constant"];
+  let Attributes = [NoThrow, Const, Pure];
+  let Prototype = "T(unsigned int, T)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 979ff60b73b75..34b798a09c216 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12927,6 +12927,10 @@ def err_spirv_enum_not_int : Error<
 def err_spirv_enum_not_valid : Error<
    "invalid value for %select{storage class}0 argument">;
 
+def err_specialization_const
+    : Error<"variable with 'vk::constant_id' attribute must be a const "
+            "int/float/enum/bool and be initialized with a literal">;
+
 // errors of expect.with.probability
 def err_probability_not_constant_float : Error<
    "probability argument to __builtin_expect_with_probability must be constant "
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 33c4b8d1568bf..97091792ba236 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -98,6 +98,8 @@ class SemaHLSL : public SemaBase {
   HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
                                       int Min, int Max, int Preferred,
                                       int SpelledArgsCount);
+  HLSLVkConstantIdAttr *
+  mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id);
   HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                                   llvm::Triple::EnvironmentType ShaderType);
   HLSLParamModifierAttr *
@@ -135,6 +137,7 @@ class SemaHLSL : public SemaBase {
   void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL);
   void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
   void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
+  void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
@@ -171,7 +174,7 @@ class SemaHLSL : public SemaBase {
   QualType getInoutParameterType(QualType Ty);
 
   bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init);
-
+  bool handleInitialization(VarDecl *VDecl, Expr *&Init);
   void deduceAddressSpace(VarDecl *Decl);
 
 private:
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index ccf45c0c6ff1d..2a60a0909c93e 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -12,6 +12,7 @@
 
 #include "CGBuiltin.h"
 #include "CGHLSLRuntime.h"
+#include "CodeGenFunction.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -214,6 +215,44 @@ static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
   }
 }
 
+// Returns the mangled name for a builtin function that the SPIR-V backend
+// will expand into a spec Constant.
+static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,
+                                               ASTContext &Context) {
+  // The parameter types for our conceptual intrinsic function.
+  QualType ClangParamTypes[] = {Context.IntTy, SpecConstantType};
+
+  // Create a temporary FunctionDecl for the builtin fuction. It won't be
+  // added to the AST.
+  FunctionProtoType::ExtProtoInfo EPI;
+  QualType FnType =
+      Context.getFunctionType(SpecConstantType, ClangParamTypes, EPI);
+  DeclarationName FuncName = &Context.Idents.get("__spirv_SpecConstant");
+  FunctionDecl *FnDeclForMangling = FunctionDecl::Create(
+      Context, Context.getTranslationUnitDecl(), SourceLocation(),
+      SourceLocation(), FuncName, FnType, /*TSI=*/nullptr, SC_Extern);
+
+  // Attach the created parameter declarations to the function declaration.
+  SmallVector<ParmVarDecl *, 2> ParamDecls;
+  for (QualType ParamType : ClangParamTypes) {
+    ParmVarDecl *PD = ParmVarDecl::Create(
+        Context, FnDeclForMangling, SourceLocation(), SourceLocation(),
+        /*IdentifierInfo*/ nullptr, ParamType, /*TSI*/ nullptr, SC_None,
+        /*DefaultArg*/ nullptr);
+    ParamDecls.push_back(PD);
+  }
+  FnDeclForMangling->setParams(ParamDecls);
+
+  // Get the mangled name.
+  std::string Name;
+  llvm::raw_string_ostream MangledNameStream(Name);
+  std::unique_ptr<MangleContext> Mangler(Context.createMangleContext());
+  Mangler->mangleName(FnDeclForMangling, MangledNameStream);
+  MangledNameStream.flush();
+
+  return Name;
+}
+
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
                                             const CallExpr *E,
                                             ReturnValueSlot ReturnValue) {
@@ -773,6 +812,42 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     return EmitRuntimeCall(
         Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
   }
+  case Builtin::BI__builtin_get_spirv_spec_constant_bool:
+  case Builtin::BI__builtin_get_spirv_spec_constant_short:
+  case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
+  case Builtin::BI__builtin_get_spirv_spec_constant_int:
+  case Builtin::BI__builtin_get_spirv_spec_constant_uint:
+  case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
+  case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
+  case Builtin::BI__builtin_get_spirv_spec_constant_half:
+  case Builtin::BI__builtin_get_spirv_spec_constant_float:
+  case Builtin::BI__builtin_get_spirv_spec_constant_double: {
+    llvm::Function *SpecConstantFn = getSpecConstantFunction(E->getType());
+    llvm::Value *SpecId = EmitScalarExpr(E->getArg(0));
+    llvm::Value *DefaultVal = EmitScalarExpr(E->getArg(1));
+    llvm::Value *Args[] = {SpecId, DefaultVal};
+    return Builder.CreateCall(SpecConstantFn, Args);
+  }
   }
   return nullptr;
 }
+
+llvm::Function *clang::CodeGen::CodeGenFunction::getSpecConstantFunction(
+    const clang::QualType &SpecConstantType) {
+
+  // Find or create the declaration for the function.
+  llvm::Module *M = &CGM.getModule();
+  std::string MangledName =
+      getSpecConstantFunctionName(SpecConstantType, getContext());
+  llvm::Function *SpecConstantFn = M->getFunction(MangledName);
+
+  if (!SpecConstantFn) {
+    llvm::Type *IntType = ConvertType(getContext().IntTy);
+    llvm::Type *RetTy = ConvertType(SpecConstantType);
+    llvm::Type *ArgTypes[] = {IntType, RetTy};
+    llvm::FunctionType *FnTy = llvm::FunctionType::get(RetTy, ArgTypes, false);
+    SpecConstantFn = llvm::Function::Create(
+        FnTy, llvm::GlobalValue::ExternalLinkage, MangledName, M);
+  }
+  return SpecConstantFn;
+}
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index a5ab9df01dba9..59f14b3e35fd0 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4850,6 +4850,12 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
                                    ReturnValueSlot ReturnValue);
+
+  // Returns a builtin function that the SPIR-V backend will expand into a spec
+  // constant.
+  llvm::Function *
+  getSpecConstantFunction(const clang::QualType &SpecConstantType);
+
   llvm::Value *EmitDirectXBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitSPIRVBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 1bf72e5bb7b9d..e1cccf068b5aa 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2890,6 +2890,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
     NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
                                          WS->getPreferred(),
                                          WS->getSpelledArgsCount());
+  else if (const auto *CI = dyn_cast<HLSLVkConstantIdAttr>(Attr))
+    NewAttr = S.HLSL().mergeVkConstantIdAttr(D, *CI, CI->getId());
   else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
     NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
   else if (isa<SuppressAttr>(Attr))
@@ -13757,6 +13759,10 @@ void Sema::AddInitializerToDecl(Decl *RealDecl, Expr *Init, bool DirectInit) {
     return;
   }
 
+  if (getLangOpts().HLSL)
+    if (!HLSL().handleInitialization(VDecl, Init))
+      return;
+
   // Get the decls type and save a reference for later, since
   // CheckInitializerTypes may change it.
   QualType DclT = VDecl->getType(), SavT = DclT;
@@ -14179,6 +14185,13 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
       }
     }
 
+    // HLSL variable with the `vk::constant_id` attribute must be initialized.
+    if (!Var->isInvalidDecl() && Var->hasAttr<HLSLVkConstantIdAttr>()) {
+      Diag(Var->getLocation(), diag::err_specialization_const);
+      Var->setInvalidDecl();
+      return;
+    }
+
     if (!Var->isInvalidDecl() && RealDecl->hasAttr<LoaderUninitializedAttr>()) {
       if (Var->getStorageClass() == SC_Extern) {
         Diag(Var->getLocation(), diag::err_loader_uninitialized_extern_decl)
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 1c2fa80e782d4..eba29e609cb05 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7590,6 +7590,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_HLSLVkExtBuiltinInput:
     S.HLSL().handleVkExtBuiltinInputAttr(D, AL);
     break;
+  case ParsedAttr::AT_HLSLVkConstantId:
+    S.HLSL().handleVkConstantIdAttr(D, AL);
+    break;
   case ParsedAttr::AT_HLSLSV_GroupThreadID:
     S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
     break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index b55f4fd786b58..9b43ee00810b2 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -119,6 +119,40 @@ static ResourceClass getResourceClass(RegisterType RT) {
   llvm_unreachable("unexpected RegisterType value");
 }
 
+static Builtin::ID getSpecConstBuiltinId(QualType Type) {
+  const auto *BT = dyn_cast<BuiltinType>(Type);
+  if (!BT) {
+    if (!Type->isEnumeralType())
+      return Builtin::NotBuiltin;
+    return Builtin::BI__builtin_get_spirv_spec_constant_int;
+  }
+
+  switch (BT->getKind()) {
+  case BuiltinType::Bool:
+    return Builtin::BI__builtin_get_spirv_spec_constant_bool;
+  case BuiltinType::Short:
+    return Builtin::BI__builtin_get_spirv_spec_constant_short;
+  case BuiltinType::Int:
+    return Builtin::BI__builtin_get_spirv_spec_constant_int;
+  case BuiltinType::LongLong:
+    return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
+  case BuiltinType::UShort:
+    return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
+  case BuiltinType::UInt:
+    return Builtin::BI__builtin_get_spirv_spec_constant_uint;
+  case BuiltinType::ULongLong:
+    return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
+  case BuiltinType::Half:
+    return Builtin::BI__builtin_get_spirv_spec_constant_half;
+  case BuiltinType::Float:
+    return Builtin::BI__builtin_get_spirv_spec_constant_float;
+  case BuiltinType::Double:
+    return Builtin::BI__builtin_get_spirv_spec_constant_double;
+  default:
+    return Builtin::NotBuiltin;
+  }
+}
+
 DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
                                                       ResourceClass ResClass) {
   assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
@@ -607,6 +641,41 @@ HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
   return Result;
 }
 
+HLSLVkConstantIdAttr *
+SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
+                                int Id) {
+
+  auto &TargetInfo = getASTContext().getTargetInfo();
+  if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
+    Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
+    return nullptr;
+  }
+
+  auto *VD = cast<VarDecl>(D);
+
+  if (getSpecConstBuiltinId(VD->getType()) == Builtin::NotBuiltin) {
+    Diag(VD->getLocation(), diag::err_specialization_const);
+    return nullptr;
+  }
+
+  if (!VD->getType().isConstQualified()) {
+    Diag(VD->getLocation(), diag::err_specialization_const);
+    return nullptr;
+  }
+
+  if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
+    if (CI->getId() != Id) {
+      Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+
+  HLSLVkConstantIdAttr *Result =
+      ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
+  return Result;
+}
+
 HLSLShaderAttr *
 SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                           llvm::Triple::EnvironmentType ShaderType) {
@@ -1157,6 +1226,15 @@ void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
                  HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
 }
 
+void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
+  uint32_t Id;
+  if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
+    return;
+  HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
+  if (NewAttr)
+    D->addAttr(NewAttr);
+}
+
 bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
   const auto *VT = T->getAs<VectorType>();
 
@@ -3206,6 +3284,7 @@ static bool IsDefaultBufferConstantDecl(VarDecl *VD) {
   return VD->getDeclContext()->isTranslationUnit() &&
          QT.getAddressSpace() == LangAS::Default &&
          VD->getStorageClass() != SC_Static &&
+         !VD->hasAttr<HLSLVkConstantIdAttr>() &&
          !isInvalidConstantBufferLeafElementType(QT.getTypePtr());
 }
 
@@ -3273,7 +3352,8 @@ void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
     const Type *VarType = VD->getType().getTypePtr();
     while (VarType->isArrayType())
       VarType = VarType->getArrayElementTypeNoTypeQual();
-    if (VarType->isHLSLResourceRecord()) {
+    if (VarType->isHLSLResourceRecord() ||
+        VD->hasAttr<HLSLVkConstantIdAttr>()) {
       // Make the variable for resources static. The global externally visible
       // storage is accessed through the handle, which is a member. The variable
       // itself is not externally visible.
@@ -3696,3 +3776,41 @@ bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
     Init->updateInit(Ctx, I, NewInit->getInit(I));
   return true;
 }
+
+bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
+  const HLSLVkConstantIdAttr *ConstIdAttr =
+      VDecl->getAttr<HLSLVkConstantIdAttr>();
+  if (!ConstIdAttr)
+    return true;
+
+  ASTContext &Context = SemaRef.getASTContext();
+
+  APValue InitValue;
+  if (!Init->isCXX11ConstantExpr(Context, &InitValue)) {
+    Diag(VDecl->getLocation(), diag::err_specialization_const);
+    VDecl->setInvalidDecl();
+    return false;
+  }
+
+  Builtin::ID BID = getSpecConstBuiltinId(VDecl->getType());
+
+  // Argument 1: The ID from the attribute
+  int ConstantID = ConstIdAttr->getId();
+  llvm::APInt IDVal(Context.getIntWidth(Context.IntTy), ConstantID);
+  Expr *IdExpr = IntegerLiteral::Create(Context, IDVal, Context.IntTy,
+                                        ConstIdAttr->getLocation());
+
+  SmallVector<Expr *, 2> Args = {IdExpr, Init};
+  Expr *C = SemaRef.BuildBuiltinCallExpr(Init->getExprLoc(), BID, Args);
+  if (C->getType()->getCanonicalTypeUnqualified() !=
+      VDecl->getType()->getCanonicalTypeUnqualified()) {
+    C = SemaRef
+            .BuildCStyleCastExpr(SourceLocation(),
+                                 Context.getTrivialTypeSourceInfo(
+                                     Init->getType(), Init->getExprLoc()),
+                                 SourceLocation(), C)
+            .get();
+  }
+  Init = C;
+  return true;
+}
diff --git a/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl
new file mode 100644
index 0000000000000..c0955c1ea7b43
--- /dev/null
+++ b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl
@@ -0,0 +1,130 @@
+// RUN: %clang_cc1 -finclude-default-header -triple spirv-unknown-vulkan-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+
+// CHECK: VarDecl {{.*}} bool_const 'const hlsl_private bool' static cinit
+// CHECK-NEXT: CallExpr {{.*}} 'bool'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'bool (*)(unsigned int, bool) noexcept' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'bool (unsigned int, bool) noexcept' lvalue Function {{.*}} '__builtin_get_spirv_spec_constant_bool' 'bool (unsigned int, bool) noexcept'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'unsigned int' <IntegralCast>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 1
+// CHECK-NEXT: CXXBoolLiteralExpr {{.*}} 'bool' true
+[[vk::constant_id(1)]]
+const bool bool_const = true;
+
+// CHECK: VarDecl {{.*}} short_const 'const hlsl_private short' static cinit
+// CHECK-NEXT: CallExpr {{.*}} 'short'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'short (*)(unsigned int, short) noexcept' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'short (unsigned int, short) noexcept' lvalue Function {{.*}} '__builtin_get_spirv_spec_constant_short' 'short (unsigned int, short) noexcept'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'unsigned int' <IntegralCast>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 2
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'short' <IntegralCast>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 4
+[[vk::constant_id(2)]]
+const short short_const = 4;
+
+// CHECK: VarDecl {{.*}} int_const 'const hlsl_private int' stat...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-clang

Author: Steven Perron (s-perron)

Changes
  • Reapply "[HLSL][SPIRV] Add vk::constant_id attribute." (#144812)
  • Fix memory leak.

Patch is 38.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144902.diff

15 Files Affected:

  • (modified) clang/include/clang/Basic/Attr.td (+8)
  • (modified) clang/include/clang/Basic/AttrDocs.td (+15)
  • (modified) clang/include/clang/Basic/Builtins.td (+13)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+4)
  • (modified) clang/include/clang/Sema/SemaHLSL.h (+4-1)
  • (modified) clang/lib/CodeGen/CGHLSLBuiltins.cpp (+75)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+6)
  • (modified) clang/lib/Sema/SemaDecl.cpp (+13)
  • (modified) clang/lib/Sema/SemaDeclAttr.cpp (+3)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+119-1)
  • (added) clang/test/AST/HLSL/vk.spec-constant.usage.hlsl (+130)
  • (renamed) clang/test/CodeGenHLSL/vk-features/SpirvType.alignment.hlsl ()
  • (renamed) clang/test/CodeGenHLSL/vk-features/SpirvType.hlsl ()
  • (added) clang/test/CodeGenHLSL/vk-features/vk.spec-constant.hlsl (+210)
  • (added) clang/test/SemaHLSL/vk.spec-constant.error.hlsl (+37)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index f113cd2ba2fbf..27fea7dea0a5e 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -5023,6 +5023,14 @@ def HLSLVkExtBuiltinInput : InheritableAttr {
   let Documentation = [HLSLVkExtBuiltinInputDocs];
 }
 
+def HLSLVkConstantId : InheritableAttr {
+  let Spellings = [CXX11<"vk", "constant_id">];
+  let Args = [IntArgument<"Id">];
+  let Subjects = SubjectList<[ExternalGlobalVar]>;
+  let LangOpts = [HLSL];
+  let Documentation = [VkConstantIdDocs];
+}
+
 def RandomizeLayout : InheritableAttr {
   let Spellings = [GCC<"randomize_layout">];
   let Subjects = SubjectList<[Record]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 6051e1fc45111..43442f177ab7b 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -8252,6 +8252,21 @@ and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
   }];
 }
 
+def VkConstantIdDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+The ``vk::constant_id`` attribute specifies the id for a SPIR-V specialization
+constant. The attribute applies to const global scalar variables. The variable must be initialized with a C++11 constexpr.
+In SPIR-V, the
+variable will be replaced with an `OpSpecConstant` with the given id.
+The syntax is:
+
+.. code-block:: text
+
+  ``[[vk::constant_id(<Id>)]] const T Name = <Init>``
+}];
+}
+
 def RootSignatureDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 68cd3d790e78a..d65b3a5d2f447 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5065,6 +5065,19 @@ def HLSLGroupMemoryBarrierWithGroupSync: LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void()";
 }
 
+class HLSLScalarTemplate
+    : Template<["bool", "char", "short", "int", "long long int",
+                "unsigned short", "unsigned int", "unsigned long long int",
+                "__fp16", "float", "double"],
+               ["_bool", "_char", "_short", "_int", "_longlong", "_ushort",
+                "_uint", "_ulonglong", "_half", "_float", "_double"]>;
+
+def HLSLGetSpirvSpecConstant : LangBuiltin<"HLSL_LANG">, HLSLScalarTemplate {
+  let Spellings = ["__builtin_get_spirv_spec_constant"];
+  let Attributes = [NoThrow, Const, Pure];
+  let Prototype = "T(unsigned int, T)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 979ff60b73b75..34b798a09c216 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12927,6 +12927,10 @@ def err_spirv_enum_not_int : Error<
 def err_spirv_enum_not_valid : Error<
    "invalid value for %select{storage class}0 argument">;
 
+def err_specialization_const
+    : Error<"variable with 'vk::constant_id' attribute must be a const "
+            "int/float/enum/bool and be initialized with a literal">;
+
 // errors of expect.with.probability
 def err_probability_not_constant_float : Error<
    "probability argument to __builtin_expect_with_probability must be constant "
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 33c4b8d1568bf..97091792ba236 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -98,6 +98,8 @@ class SemaHLSL : public SemaBase {
   HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
                                       int Min, int Max, int Preferred,
                                       int SpelledArgsCount);
+  HLSLVkConstantIdAttr *
+  mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id);
   HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                                   llvm::Triple::EnvironmentType ShaderType);
   HLSLParamModifierAttr *
@@ -135,6 +137,7 @@ class SemaHLSL : public SemaBase {
   void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL);
   void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
   void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
+  void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
@@ -171,7 +174,7 @@ class SemaHLSL : public SemaBase {
   QualType getInoutParameterType(QualType Ty);
 
   bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init);
-
+  bool handleInitialization(VarDecl *VDecl, Expr *&Init);
   void deduceAddressSpace(VarDecl *Decl);
 
 private:
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index ccf45c0c6ff1d..2a60a0909c93e 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -12,6 +12,7 @@
 
 #include "CGBuiltin.h"
 #include "CGHLSLRuntime.h"
+#include "CodeGenFunction.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -214,6 +215,44 @@ static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
   }
 }
 
+// Returns the mangled name for a builtin function that the SPIR-V backend
+// will expand into a spec Constant.
+static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,
+                                               ASTContext &Context) {
+  // The parameter types for our conceptual intrinsic function.
+  QualType ClangParamTypes[] = {Context.IntTy, SpecConstantType};
+
+  // Create a temporary FunctionDecl for the builtin fuction. It won't be
+  // added to the AST.
+  FunctionProtoType::ExtProtoInfo EPI;
+  QualType FnType =
+      Context.getFunctionType(SpecConstantType, ClangParamTypes, EPI);
+  DeclarationName FuncName = &Context.Idents.get("__spirv_SpecConstant");
+  FunctionDecl *FnDeclForMangling = FunctionDecl::Create(
+      Context, Context.getTranslationUnitDecl(), SourceLocation(),
+      SourceLocation(), FuncName, FnType, /*TSI=*/nullptr, SC_Extern);
+
+  // Attach the created parameter declarations to the function declaration.
+  SmallVector<ParmVarDecl *, 2> ParamDecls;
+  for (QualType ParamType : ClangParamTypes) {
+    ParmVarDecl *PD = ParmVarDecl::Create(
+        Context, FnDeclForMangling, SourceLocation(), SourceLocation(),
+        /*IdentifierInfo*/ nullptr, ParamType, /*TSI*/ nullptr, SC_None,
+        /*DefaultArg*/ nullptr);
+    ParamDecls.push_back(PD);
+  }
+  FnDeclForMangling->setParams(ParamDecls);
+
+  // Get the mangled name.
+  std::string Name;
+  llvm::raw_string_ostream MangledNameStream(Name);
+  std::unique_ptr<MangleContext> Mangler(Context.createMangleContext());
+  Mangler->mangleName(FnDeclForMangling, MangledNameStream);
+  MangledNameStream.flush();
+
+  return Name;
+}
+
 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
                                             const CallExpr *E,
                                             ReturnValueSlot ReturnValue) {
@@ -773,6 +812,42 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     return EmitRuntimeCall(
         Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
   }
+  case Builtin::BI__builtin_get_spirv_spec_constant_bool:
+  case Builtin::BI__builtin_get_spirv_spec_constant_short:
+  case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
+  case Builtin::BI__builtin_get_spirv_spec_constant_int:
+  case Builtin::BI__builtin_get_spirv_spec_constant_uint:
+  case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
+  case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
+  case Builtin::BI__builtin_get_spirv_spec_constant_half:
+  case Builtin::BI__builtin_get_spirv_spec_constant_float:
+  case Builtin::BI__builtin_get_spirv_spec_constant_double: {
+    llvm::Function *SpecConstantFn = getSpecConstantFunction(E->getType());
+    llvm::Value *SpecId = EmitScalarExpr(E->getArg(0));
+    llvm::Value *DefaultVal = EmitScalarExpr(E->getArg(1));
+    llvm::Value *Args[] = {SpecId, DefaultVal};
+    return Builder.CreateCall(SpecConstantFn, Args);
+  }
   }
   return nullptr;
 }
+
+llvm::Function *clang::CodeGen::CodeGenFunction::getSpecConstantFunction(
+    const clang::QualType &SpecConstantType) {
+
+  // Find or create the declaration for the function.
+  llvm::Module *M = &CGM.getModule();
+  std::string MangledName =
+      getSpecConstantFunctionName(SpecConstantType, getContext());
+  llvm::Function *SpecConstantFn = M->getFunction(MangledName);
+
+  if (!SpecConstantFn) {
+    llvm::Type *IntType = ConvertType(getContext().IntTy);
+    llvm::Type *RetTy = ConvertType(SpecConstantType);
+    llvm::Type *ArgTypes[] = {IntType, RetTy};
+    llvm::FunctionType *FnTy = llvm::FunctionType::get(RetTy, ArgTypes, false);
+    SpecConstantFn = llvm::Function::Create(
+        FnTy, llvm::GlobalValue::ExternalLinkage, MangledName, M);
+  }
+  return SpecConstantFn;
+}
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index a5ab9df01dba9..59f14b3e35fd0 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4850,6 +4850,12 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
                                    ReturnValueSlot ReturnValue);
+
+  // Returns a builtin function that the SPIR-V backend will expand into a spec
+  // constant.
+  llvm::Function *
+  getSpecConstantFunction(const clang::QualType &SpecConstantType);
+
   llvm::Value *EmitDirectXBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitSPIRVBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 1bf72e5bb7b9d..e1cccf068b5aa 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2890,6 +2890,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
     NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
                                          WS->getPreferred(),
                                          WS->getSpelledArgsCount());
+  else if (const auto *CI = dyn_cast<HLSLVkConstantIdAttr>(Attr))
+    NewAttr = S.HLSL().mergeVkConstantIdAttr(D, *CI, CI->getId());
   else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
     NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
   else if (isa<SuppressAttr>(Attr))
@@ -13757,6 +13759,10 @@ void Sema::AddInitializerToDecl(Decl *RealDecl, Expr *Init, bool DirectInit) {
     return;
   }
 
+  if (getLangOpts().HLSL)
+    if (!HLSL().handleInitialization(VDecl, Init))
+      return;
+
   // Get the decls type and save a reference for later, since
   // CheckInitializerTypes may change it.
   QualType DclT = VDecl->getType(), SavT = DclT;
@@ -14179,6 +14185,13 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
       }
     }
 
+    // HLSL variable with the `vk::constant_id` attribute must be initialized.
+    if (!Var->isInvalidDecl() && Var->hasAttr<HLSLVkConstantIdAttr>()) {
+      Diag(Var->getLocation(), diag::err_specialization_const);
+      Var->setInvalidDecl();
+      return;
+    }
+
     if (!Var->isInvalidDecl() && RealDecl->hasAttr<LoaderUninitializedAttr>()) {
       if (Var->getStorageClass() == SC_Extern) {
         Diag(Var->getLocation(), diag::err_loader_uninitialized_extern_decl)
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 1c2fa80e782d4..eba29e609cb05 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7590,6 +7590,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_HLSLVkExtBuiltinInput:
     S.HLSL().handleVkExtBuiltinInputAttr(D, AL);
     break;
+  case ParsedAttr::AT_HLSLVkConstantId:
+    S.HLSL().handleVkConstantIdAttr(D, AL);
+    break;
   case ParsedAttr::AT_HLSLSV_GroupThreadID:
     S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
     break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index b55f4fd786b58..9b43ee00810b2 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -119,6 +119,40 @@ static ResourceClass getResourceClass(RegisterType RT) {
   llvm_unreachable("unexpected RegisterType value");
 }
 
+static Builtin::ID getSpecConstBuiltinId(QualType Type) {
+  const auto *BT = dyn_cast<BuiltinType>(Type);
+  if (!BT) {
+    if (!Type->isEnumeralType())
+      return Builtin::NotBuiltin;
+    return Builtin::BI__builtin_get_spirv_spec_constant_int;
+  }
+
+  switch (BT->getKind()) {
+  case BuiltinType::Bool:
+    return Builtin::BI__builtin_get_spirv_spec_constant_bool;
+  case BuiltinType::Short:
+    return Builtin::BI__builtin_get_spirv_spec_constant_short;
+  case BuiltinType::Int:
+    return Builtin::BI__builtin_get_spirv_spec_constant_int;
+  case BuiltinType::LongLong:
+    return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
+  case BuiltinType::UShort:
+    return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
+  case BuiltinType::UInt:
+    return Builtin::BI__builtin_get_spirv_spec_constant_uint;
+  case BuiltinType::ULongLong:
+    return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
+  case BuiltinType::Half:
+    return Builtin::BI__builtin_get_spirv_spec_constant_half;
+  case BuiltinType::Float:
+    return Builtin::BI__builtin_get_spirv_spec_constant_float;
+  case BuiltinType::Double:
+    return Builtin::BI__builtin_get_spirv_spec_constant_double;
+  default:
+    return Builtin::NotBuiltin;
+  }
+}
+
 DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
                                                       ResourceClass ResClass) {
   assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
@@ -607,6 +641,41 @@ HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
   return Result;
 }
 
+HLSLVkConstantIdAttr *
+SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
+                                int Id) {
+
+  auto &TargetInfo = getASTContext().getTargetInfo();
+  if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
+    Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
+    return nullptr;
+  }
+
+  auto *VD = cast<VarDecl>(D);
+
+  if (getSpecConstBuiltinId(VD->getType()) == Builtin::NotBuiltin) {
+    Diag(VD->getLocation(), diag::err_specialization_const);
+    return nullptr;
+  }
+
+  if (!VD->getType().isConstQualified()) {
+    Diag(VD->getLocation(), diag::err_specialization_const);
+    return nullptr;
+  }
+
+  if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
+    if (CI->getId() != Id) {
+      Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+
+  HLSLVkConstantIdAttr *Result =
+      ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
+  return Result;
+}
+
 HLSLShaderAttr *
 SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                           llvm::Triple::EnvironmentType ShaderType) {
@@ -1157,6 +1226,15 @@ void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
                  HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
 }
 
+void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
+  uint32_t Id;
+  if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
+    return;
+  HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
+  if (NewAttr)
+    D->addAttr(NewAttr);
+}
+
 bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
   const auto *VT = T->getAs<VectorType>();
 
@@ -3206,6 +3284,7 @@ static bool IsDefaultBufferConstantDecl(VarDecl *VD) {
   return VD->getDeclContext()->isTranslationUnit() &&
          QT.getAddressSpace() == LangAS::Default &&
          VD->getStorageClass() != SC_Static &&
+         !VD->hasAttr<HLSLVkConstantIdAttr>() &&
          !isInvalidConstantBufferLeafElementType(QT.getTypePtr());
 }
 
@@ -3273,7 +3352,8 @@ void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
     const Type *VarType = VD->getType().getTypePtr();
     while (VarType->isArrayType())
       VarType = VarType->getArrayElementTypeNoTypeQual();
-    if (VarType->isHLSLResourceRecord()) {
+    if (VarType->isHLSLResourceRecord() ||
+        VD->hasAttr<HLSLVkConstantIdAttr>()) {
       // Make the variable for resources static. The global externally visible
       // storage is accessed through the handle, which is a member. The variable
       // itself is not externally visible.
@@ -3696,3 +3776,41 @@ bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
     Init->updateInit(Ctx, I, NewInit->getInit(I));
   return true;
 }
+
+bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
+  const HLSLVkConstantIdAttr *ConstIdAttr =
+      VDecl->getAttr<HLSLVkConstantIdAttr>();
+  if (!ConstIdAttr)
+    return true;
+
+  ASTContext &Context = SemaRef.getASTContext();
+
+  APValue InitValue;
+  if (!Init->isCXX11ConstantExpr(Context, &InitValue)) {
+    Diag(VDecl->getLocation(), diag::err_specialization_const);
+    VDecl->setInvalidDecl();
+    return false;
+  }
+
+  Builtin::ID BID = getSpecConstBuiltinId(VDecl->getType());
+
+  // Argument 1: The ID from the attribute
+  int ConstantID = ConstIdAttr->getId();
+  llvm::APInt IDVal(Context.getIntWidth(Context.IntTy), ConstantID);
+  Expr *IdExpr = IntegerLiteral::Create(Context, IDVal, Context.IntTy,
+                                        ConstIdAttr->getLocation());
+
+  SmallVector<Expr *, 2> Args = {IdExpr, Init};
+  Expr *C = SemaRef.BuildBuiltinCallExpr(Init->getExprLoc(), BID, Args);
+  if (C->getType()->getCanonicalTypeUnqualified() !=
+      VDecl->getType()->getCanonicalTypeUnqualified()) {
+    C = SemaRef
+            .BuildCStyleCastExpr(SourceLocation(),
+                                 Context.getTrivialTypeSourceInfo(
+                                     Init->getType(), Init->getExprLoc()),
+                                 SourceLocation(), C)
+            .get();
+  }
+  Init = C;
+  return true;
+}
diff --git a/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl
new file mode 100644
index 0000000000000..c0955c1ea7b43
--- /dev/null
+++ b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl
@@ -0,0 +1,130 @@
+// RUN: %clang_cc1 -finclude-default-header -triple spirv-unknown-vulkan-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+
+// CHECK: VarDecl {{.*}} bool_const 'const hlsl_private bool' static cinit
+// CHECK-NEXT: CallExpr {{.*}} 'bool'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'bool (*)(unsigned int, bool) noexcept' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'bool (unsigned int, bool) noexcept' lvalue Function {{.*}} '__builtin_get_spirv_spec_constant_bool' 'bool (unsigned int, bool) noexcept'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'unsigned int' <IntegralCast>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 1
+// CHECK-NEXT: CXXBoolLiteralExpr {{.*}} 'bool' true
+[[vk::constant_id(1)]]
+const bool bool_const = true;
+
+// CHECK: VarDecl {{.*}} short_const 'const hlsl_private short' static cinit
+// CHECK-NEXT: CallExpr {{.*}} 'short'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'short (*)(unsigned int, short) noexcept' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'short (unsigned int, short) noexcept' lvalue Function {{.*}} '__builtin_get_spirv_spec_constant_short' 'short (unsigned int, short) noexcept'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'unsigned int' <IntegralCast>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 2
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'short' <IntegralCast>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 4
+[[vk::constant_id(2)]]
+const short short_const = 4;
+
+// CHECK: VarDecl {{.*}} int_const 'const hlsl_private int' stat...
[truncated]

@Keenuts
Copy link
Contributor

Keenuts commented Jun 19, 2025

You might want to change the PR title to add the tags etc

@s-perron s-perron changed the title spec constant [HLSL][SPIRV] Reapply "[HLSL][SPIRV] Add vk::constant_id attribute." Jun 19, 2025
@s-perron s-perron merged commit 01d648a into llvm:main Jun 19, 2025
14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants