@@ -78,6 +78,10 @@ class Util {
78
78
// / \param Tmpl whether the class is template instantiation or simple record
79
79
static bool isSyclType (const QualType &Ty, StringRef Name, bool Tmpl = false );
80
80
81
+ // / Checks whether given clang type is a full specialization of the SYCL
82
+ // / specialization constant class.
83
+ static bool isSyclSpecConstantType (const QualType &Ty);
84
+
81
85
// / Checks whether given clang type is declared in the given hierarchy of
82
86
// / declaration contexts.
83
87
// / \param Ty the clang type being checked
@@ -773,6 +777,14 @@ static CompoundStmt *CreateOpenCLKernelBody(Sema &S,
773
777
getExprForSpecialSYCLObj (FldType, WrapperFld,
774
778
WrapperFldCRD, Base,
775
779
InitMethodName, BodyStmts);
780
+ } else if (Util::isSyclSpecConstantType (FldType)) {
781
+ // Specialization constants are "invisible" to the
782
+ // kernel argument creation and device-side SYCL object
783
+ // materialization infrastructure in this source.
784
+ // It is OK not to really materialize them on the kernel
785
+ // side, because their only use can be via
786
+ // 'spec_const_obj.get()' method, which is translated to
787
+ // an intrinsic and 'this' is really never used.
776
788
} else {
777
789
// Field is a structure or class so change the wrapper
778
790
// object and recursively search for accessor field.
@@ -816,6 +828,8 @@ static CompoundStmt *CreateOpenCLKernelBody(Sema &S,
816
828
InitExprs.push_back (MemberInit.get ());
817
829
getExprForSpecialSYCLObj (FieldType, Field, CRD, KernelObjCloneRef,
818
830
InitMethodName, BodyStmts);
831
+ } else if (Util::isSyclSpecConstantType (FieldType)) {
832
+ // Just skip specialization constants - not part of signature.
819
833
} else if (CRD || FieldType->isScalarType ()) {
820
834
// If field has built-in or a structure/class type just initialize
821
835
// this field with corresponding kernel argument using copy
@@ -959,11 +973,13 @@ static bool buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
959
973
QualType FldType = WrapperFld->getType ();
960
974
if (FldType->isStructureOrClassType ()) {
961
975
if (Util::isSyclAccessorType (FldType)) {
962
- // accessor field is found - create descriptor
976
+ // Accessor field is found - create descriptor.
963
977
createSpecialSYCLObjParamDesc (WrapperFld, FldType);
978
+ } else if (Util::isSyclSpecConstantType (FldType)) {
979
+ // Don't try recursive search below.
964
980
} else {
965
- // field is some class or struct - recursively check for
966
- // accessor fields
981
+ // Field is some class or struct - recursively check for
982
+ // accessor fields.
967
983
createParamDescForWrappedAccessors (WrapperFld, FldType);
968
984
}
969
985
}
@@ -985,6 +1001,8 @@ static bool buildArgTys(ASTContext &Context, CXXRecordDecl *KernelObj,
985
1001
QualType ArgTy = Fld->getType ();
986
1002
if (Util::isSyclAccessorType (ArgTy) || Util::isSyclSamplerType (ArgTy)) {
987
1003
createSpecialSYCLObjParamDesc (Fld, ArgTy);
1004
+ } else if (Util::isSyclSpecConstantType (ArgTy)) {
1005
+ // Specialization constants are not added as arguments.
988
1006
} else if (ArgTy->isStructureOrClassType ()) {
989
1007
if (Context.getLangOpts ().SYCLStdLayoutKernelParams ) {
990
1008
if (!ArgTy->isStandardLayoutType ()) {
@@ -1127,6 +1145,21 @@ static void populateIntHeader(SYCLIntegrationHeader &H, const StringRef Name,
1127
1145
uint64_t Sz = Ctx.getTypeSizeInChars (Fld->getType ()).getQuantity ();
1128
1146
H.addParamDesc (SYCLIntegrationHeader::kind_pointer,
1129
1147
static_cast <unsigned >(Sz), static_cast <unsigned >(Offset));
1148
+ } else if (Util::isSyclSpecConstantType (ArgTy)) {
1149
+ // Add specialization constant ID to the header.
1150
+ auto *TmplSpec =
1151
+ cast<ClassTemplateSpecializationDecl>(ArgTy->getAsCXXRecordDecl ());
1152
+ const TemplateArgumentList *TemplateArgs =
1153
+ &TmplSpec->getTemplateInstantiationArgs ();
1154
+ // Get specialization constant ID type, which is the second template
1155
+ // argument.
1156
+ QualType SpecConstIDTy = TypeName::getFullyQualifiedType (
1157
+ TemplateArgs->get (1 ).getAsType (), Ctx, true )
1158
+ .getCanonicalType ();
1159
+ const std::string SpecConstName = PredefinedExpr::ComputeName (
1160
+ Ctx, PredefinedExpr::UniqueStableNameExpr, SpecConstIDTy);
1161
+ H.addSpecConstant (SpecConstName, SpecConstIDTy);
1162
+ // Spec constant lambda capture does not become a kernel argument.
1130
1163
} else if (ArgTy->isStructureOrClassType () || ArgTy->isScalarType ()) {
1131
1164
// the parameter is an object of standard layout type or scalar;
1132
1165
// the check for standard layout is done elsewhere
@@ -1658,6 +1691,13 @@ void SYCLIntegrationHeader::emitForwardClassDecls(
1658
1691
}
1659
1692
}
1660
1693
1694
+ static std::string getCPPTypeString (QualType Ty) {
1695
+ LangOptions LO;
1696
+ PrintingPolicy P (LO);
1697
+ P.SuppressTypedefs = true ;
1698
+ return eraseAnonNamespace (Ty.getAsString (P));
1699
+ }
1700
+
1661
1701
void SYCLIntegrationHeader::emit (raw_ostream &O) {
1662
1702
O << " // This is auto-generated SYCL integration header.\n " ;
1663
1703
O << " \n " ;
@@ -1666,6 +1706,33 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
1666
1706
O << " #include <CL/sycl/detail/kernel_desc.hpp>\n " ;
1667
1707
1668
1708
O << " \n " ;
1709
+
1710
+ if (SpecConsts.size () > 0 ) {
1711
+ // Remove duplicates.
1712
+ std::sort (SpecConsts.begin (), SpecConsts.end (),
1713
+ [](const SpecConstID &SC1, const SpecConstID &SC2) {
1714
+ // Sort by string IDs for stable spec consts order in the
1715
+ // header.
1716
+ return SC1.second .compare (SC2.second ) < 0 ;
1717
+ });
1718
+ SpecConstID *End =
1719
+ std::unique (SpecConsts.begin (), SpecConsts.end (),
1720
+ [](const SpecConstID &SC1, const SpecConstID &SC2) {
1721
+ // Here can do faster comparison of types.
1722
+ return SC1.first == SC2.first ;
1723
+ });
1724
+ O << " // Specialization constants IDs:\n " ;
1725
+ for (const auto &P : llvm::make_range (SpecConsts.begin (), End)) {
1726
+ std::string CPPName = getCPPTypeString (P.first );
1727
+ O << " template <> struct sycl::detail::SpecConstantInfo<" << CPPName
1728
+ << " > {\n " ;
1729
+ O << " static constexpr const char* getName() {\n " ;
1730
+ O << " return \" " << P.second << " \" ;\n " ;
1731
+ O << " }\n " ;
1732
+ O << " };\n " ;
1733
+ }
1734
+ }
1735
+
1669
1736
if (!UnnamedLambdaSupport) {
1670
1737
O << " // Forward declarations of templated kernel function types:\n " ;
1671
1738
@@ -1747,11 +1814,8 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
1747
1814
O << " ', '" << c;
1748
1815
O << " '> {\n " ;
1749
1816
} else {
1750
- LangOptions LO;
1751
- PrintingPolicy P (LO);
1752
- P.SuppressTypedefs = true ;
1753
- O << " template <> struct KernelInfo<"
1754
- << eraseAnonNamespace (K.NameType .getAsString (P)) << " > {\n " ;
1817
+ O << " template <> struct KernelInfo<" << getCPPTypeString (K.NameType )
1818
+ << " > {\n " ;
1755
1819
}
1756
1820
O << " DLL_LOCAL\n " ;
1757
1821
O << " static constexpr const char* getName() { return \" " << K.Name
@@ -1815,6 +1879,10 @@ void SYCLIntegrationHeader::endKernel() {
1815
1879
// nop for now
1816
1880
}
1817
1881
1882
+ void SYCLIntegrationHeader::addSpecConstant (StringRef IDName, QualType IDType) {
1883
+ SpecConsts.emplace_back (std::make_pair (IDType, IDName.str ()));
1884
+ }
1885
+
1818
1886
SYCLIntegrationHeader::SYCLIntegrationHeader (DiagnosticsEngine &_Diag,
1819
1887
bool _UnnamedLambdaSupport)
1820
1888
: Diag(_Diag), UnnamedLambdaSupport(_UnnamedLambdaSupport) {}
@@ -1835,6 +1903,16 @@ bool Util::isSyclStreamType(const QualType &Ty) {
1835
1903
return isSyclType (Ty, " stream" );
1836
1904
}
1837
1905
1906
+ bool Util::isSyclSpecConstantType (const QualType &Ty) {
1907
+ const StringRef &Name = " spec_constant" ;
1908
+ std::array<DeclContextDesc, 4 > Scopes = {
1909
+ Util::DeclContextDesc{clang::Decl::Kind::Namespace, " cl" },
1910
+ Util::DeclContextDesc{clang::Decl::Kind::Namespace, " sycl" },
1911
+ Util::DeclContextDesc{clang::Decl::Kind::Namespace, " experimental" },
1912
+ Util::DeclContextDesc{Decl::Kind::ClassTemplateSpecialization, Name}};
1913
+ return matchQualifiedTypeName (Ty, Scopes);
1914
+ }
1915
+
1838
1916
bool Util::isSyclType (const QualType &Ty, StringRef Name, bool Tmpl) {
1839
1917
Decl::Kind ClassDeclKind =
1840
1918
Tmpl ? Decl::Kind::ClassTemplateSpecialization : Decl::Kind::CXXRecord;
0 commit comments