Skip to content

Commit e7020a1

Browse files
[SYCL] Add support for kernel name types templated using enums. (#1675)
Modify printing policy and add test for enums in anonymous namespace Added diagnostic to handle unscoped enums with no fixed underlying type. Also fixed handling of template argument type - enum type. The prior patch handled only enum values, not the type itself. Signed-off-by: Elizabeth Andrews <[email protected]>
1 parent 3192ee7 commit e7020a1

File tree

9 files changed

+394
-26
lines changed

9 files changed

+394
-26
lines changed

clang/include/clang/AST/Decl.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,15 +283,16 @@ class NamedDecl : public Decl {
283283
/// Creating this name is expensive, so it should be called only when
284284
/// performance doesn't matter.
285285
void printQualifiedName(raw_ostream &OS) const;
286-
void printQualifiedName(raw_ostream &OS, const PrintingPolicy &Policy) const;
286+
void printQualifiedName(raw_ostream &OS, const PrintingPolicy &Policy,
287+
bool WithGlobalNsPrefix = false) const;
287288

288289
/// Print only the nested name specifier part of a fully-qualified name,
289290
/// including the '::' at the end. E.g.
290291
/// when `printQualifiedName(D)` prints "A::B::i",
291292
/// this function prints "A::B::".
292293
void printNestedNameSpecifier(raw_ostream &OS) const;
293-
void printNestedNameSpecifier(raw_ostream &OS,
294-
const PrintingPolicy &Policy) const;
294+
void printNestedNameSpecifier(raw_ostream &OS, const PrintingPolicy &Policy,
295+
bool WithGlobalNsPrefix = false) const;
295296

296297
// FIXME: Remove string version.
297298
std::string getQualifiedNameAsString() const;

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10789,7 +10789,9 @@ def err_builtin_matrix_disabled: Error<
1078910789
// SYCL-specific diagnostics
1079010790
def err_sycl_kernel_incorrectly_named : Error<
1079110791
"kernel %select{name is missing"
10792-
"|needs to have a globally-visible name}0">;
10792+
"|needs to have a globally-visible name"
10793+
"|name is invalid. Unscoped enum requires fixed underlying type"
10794+
"}0">;
1079310795
def err_sycl_restrict : Error<
1079410796
"SYCL kernel cannot "
1079510797
"%select{use a non-const global variable"

clang/lib/AST/Decl.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,14 +1538,14 @@ void NamedDecl::printQualifiedName(raw_ostream &OS) const {
15381538
printQualifiedName(OS, getASTContext().getPrintingPolicy());
15391539
}
15401540

1541-
void NamedDecl::printQualifiedName(raw_ostream &OS,
1542-
const PrintingPolicy &P) const {
1541+
void NamedDecl::printQualifiedName(raw_ostream &OS, const PrintingPolicy &P,
1542+
bool WithGlobalNsPrefix) const {
15431543
if (getDeclContext()->isFunctionOrMethod()) {
15441544
// We do not print '(anonymous)' for function parameters without name.
15451545
printName(OS);
15461546
return;
15471547
}
1548-
printNestedNameSpecifier(OS, P);
1548+
printNestedNameSpecifier(OS, P, WithGlobalNsPrefix);
15491549
if (getDeclName())
15501550
OS << *this;
15511551
else {
@@ -1566,7 +1566,8 @@ void NamedDecl::printNestedNameSpecifier(raw_ostream &OS) const {
15661566
}
15671567

15681568
void NamedDecl::printNestedNameSpecifier(raw_ostream &OS,
1569-
const PrintingPolicy &P) const {
1569+
const PrintingPolicy &P,
1570+
bool WithGlobalNsPrefix) const {
15701571
const DeclContext *Ctx = getDeclContext();
15711572

15721573
// For ObjC methods and properties, look through categories and use the
@@ -1593,6 +1594,9 @@ void NamedDecl::printNestedNameSpecifier(raw_ostream &OS,
15931594
Ctx = Ctx->getParent();
15941595
}
15951596

1597+
if (WithGlobalNsPrefix)
1598+
OS << "::";
1599+
15961600
for (const DeclContext *DC : llvm::reverse(Contexts)) {
15971601
if (const auto *Spec = dyn_cast<ClassTemplateSpecializationDecl>(DC)) {
15981602
OS << Spec->getName();
@@ -1605,8 +1609,7 @@ void NamedDecl::printNestedNameSpecifier(raw_ostream &OS,
16051609
if (ND->isAnonymousNamespace()) {
16061610
OS << (P.MSVCFormatting ? "`anonymous namespace\'"
16071611
: "(anonymous namespace)");
1608-
}
1609-
else
1612+
} else
16101613
OS << *ND;
16111614
} else if (const auto *RD = dyn_cast<RecordDecl>(DC)) {
16121615
if (!RD->getIdentifier())

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 132 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1638,6 +1638,18 @@ static std::string eraseAnonNamespace(std::string S) {
16381638
return S;
16391639
}
16401640

1641+
static bool checkEnumTemplateParameter(const EnumDecl *ED,
1642+
DiagnosticsEngine &Diag,
1643+
SourceLocation KernelLocation) {
1644+
if (!ED->isScoped() && !ED->isFixed()) {
1645+
Diag.Report(KernelLocation, diag::err_sycl_kernel_incorrectly_named) << 2;
1646+
Diag.Report(ED->getSourceRange().getBegin(), diag::note_entity_declared_at)
1647+
<< ED;
1648+
return true;
1649+
}
1650+
return false;
1651+
}
1652+
16411653
// Emits a forward declaration
16421654
void SYCLIntegrationHeader::emitFwdDecl(raw_ostream &O, const Decl *D,
16431655
SourceLocation KernelLocation) {
@@ -1691,10 +1703,22 @@ void SYCLIntegrationHeader::emitFwdDecl(raw_ostream &O, const Decl *D,
16911703
PrintingPolicy P(D->getASTContext().getLangOpts());
16921704
P.adjustForCPlusPlusFwdDecl();
16931705
P.SuppressTypedefs = true;
1706+
P.SuppressUnwrittenScope = true;
16941707
std::string S;
16951708
llvm::raw_string_ostream SO(S);
16961709
D->print(SO, P);
1697-
O << SO.str() << ";\n";
1710+
O << SO.str();
1711+
1712+
if (const auto *ED = dyn_cast<EnumDecl>(D)) {
1713+
QualType T = ED->getIntegerType();
1714+
// Backup since getIntegerType() returns null for enum forward
1715+
// declaration with no fixed underlying type
1716+
if (T.isNull())
1717+
T = ED->getPromotionType();
1718+
O << " : " << T.getAsString();
1719+
}
1720+
1721+
O << ";\n";
16981722

16991723
// print closing braces for namespaces if needed
17001724
for (unsigned I = 0; I < NamespaceCnt; ++I)
@@ -1763,8 +1787,20 @@ void SYCLIntegrationHeader::emitForwardClassDecls(
17631787

17641788
switch (Arg.getKind()) {
17651789
case TemplateArgument::ArgKind::Type:
1766-
emitForwardClassDecls(O, Arg.getAsType(), KernelLocation, Printed);
1790+
case TemplateArgument::ArgKind::Integral: {
1791+
QualType T = (Arg.getKind() == TemplateArgument::ArgKind::Type)
1792+
? Arg.getAsType()
1793+
: Arg.getIntegralType();
1794+
1795+
// Handle Kernel Name Type templated using enum type and value.
1796+
if (const auto *ET = T->getAs<EnumType>()) {
1797+
const EnumDecl *ED = ET->getDecl();
1798+
if (!checkEnumTemplateParameter(ED, Diag, KernelLocation))
1799+
emitFwdDecl(O, ED, KernelLocation);
1800+
} else if (Arg.getKind() == TemplateArgument::ArgKind::Type)
1801+
emitForwardClassDecls(O, T, KernelLocation, Printed);
17671802
break;
1803+
}
17681804
case TemplateArgument::ArgKind::Pack: {
17691805
ArrayRef<TemplateArgument> Pack = Arg.getPackAsArray();
17701806

@@ -1823,6 +1859,97 @@ static std::string getCPPTypeString(QualType Ty) {
18231859
return eraseAnonNamespace(Ty.getAsString(P));
18241860
}
18251861

1862+
static void printArguments(ASTContext &Ctx, raw_ostream &ArgOS,
1863+
ArrayRef<TemplateArgument> Args,
1864+
const PrintingPolicy &P);
1865+
1866+
static void printArgument(ASTContext &Ctx, raw_ostream &ArgOS,
1867+
TemplateArgument Arg, const PrintingPolicy &P) {
1868+
switch (Arg.getKind()) {
1869+
case TemplateArgument::ArgKind::Pack: {
1870+
printArguments(Ctx, ArgOS, Arg.getPackAsArray(), P);
1871+
break;
1872+
}
1873+
case TemplateArgument::ArgKind::Integral: {
1874+
QualType T = Arg.getIntegralType();
1875+
const EnumType *ET = T->getAs<EnumType>();
1876+
1877+
if (ET) {
1878+
const llvm::APSInt &Val = Arg.getAsIntegral();
1879+
ArgOS << "(" << ET->getDecl()->getQualifiedNameAsString() << ")" << Val;
1880+
} else {
1881+
Arg.print(P, ArgOS);
1882+
}
1883+
break;
1884+
}
1885+
case TemplateArgument::ArgKind::Type: {
1886+
LangOptions LO;
1887+
PrintingPolicy TypePolicy(LO);
1888+
TypePolicy.SuppressTypedefs = true;
1889+
TypePolicy.SuppressTagKeyword = true;
1890+
QualType T = Arg.getAsType();
1891+
QualType FullyQualifiedType = TypeName::getFullyQualifiedType(T, Ctx, true);
1892+
ArgOS << FullyQualifiedType.getAsString(TypePolicy);
1893+
break;
1894+
}
1895+
default:
1896+
Arg.print(P, ArgOS);
1897+
}
1898+
}
1899+
1900+
static void printArguments(ASTContext &Ctx, raw_ostream &ArgOS,
1901+
ArrayRef<TemplateArgument> Args,
1902+
const PrintingPolicy &P) {
1903+
for (unsigned I = 0; I < Args.size(); I++) {
1904+
const TemplateArgument &Arg = Args[I];
1905+
1906+
if (I != 0)
1907+
ArgOS << ", ";
1908+
1909+
printArgument(Ctx, ArgOS, Arg, P);
1910+
}
1911+
}
1912+
1913+
static void printTemplateArguments(ASTContext &Ctx, raw_ostream &ArgOS,
1914+
ArrayRef<TemplateArgument> Args,
1915+
const PrintingPolicy &P) {
1916+
ArgOS << "<";
1917+
printArguments(Ctx, ArgOS, Args, P);
1918+
ArgOS << ">";
1919+
}
1920+
1921+
static std::string getKernelNameTypeString(QualType T) {
1922+
1923+
const CXXRecordDecl *RD = T->getAsCXXRecordDecl();
1924+
1925+
if (!RD)
1926+
return getCPPTypeString(T);
1927+
1928+
// If kernel name type is a template specialization with enum type
1929+
// template parameters, enumerators in name type string should be
1930+
// replaced with their underlying value since the enum definition
1931+
// is not visible in integration header.
1932+
if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
1933+
LangOptions LO;
1934+
PrintingPolicy P(LO);
1935+
P.SuppressTypedefs = true;
1936+
SmallString<64> Buf;
1937+
llvm::raw_svector_ostream ArgOS(Buf);
1938+
1939+
// Print template class name
1940+
TSD->printQualifiedName(ArgOS, P, /*WithGlobalNsPrefix*/ true);
1941+
1942+
// Print template arguments substituting enumerators
1943+
ASTContext &Ctx = RD->getASTContext();
1944+
const TemplateArgumentList &Args = TSD->getTemplateArgs();
1945+
printTemplateArguments(Ctx, ArgOS, Args.asArray(), P);
1946+
1947+
return eraseAnonNamespace(ArgOS.str().str());
1948+
}
1949+
1950+
return getCPPTypeString(T);
1951+
}
1952+
18261953
void SYCLIntegrationHeader::emit(raw_ostream &O) {
18271954
O << "// This is auto-generated SYCL integration header.\n";
18281955
O << "\n";
@@ -1939,8 +2066,9 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
19392066
O << "', '" << c;
19402067
O << "'> {\n";
19412068
} else {
1942-
O << "template <> struct KernelInfo<" << getCPPTypeString(K.NameType)
1943-
<< "> {\n";
2069+
2070+
O << "template <> struct KernelInfo<"
2071+
<< getKernelNameTypeString(K.NameType) << "> {\n";
19442072
}
19452073
O << " DLL_LOCAL\n";
19462074
O << " static constexpr const char* getName() { return \"" << K.Name

clang/test/CodeGenSYCL/int_header1.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
// CHECK:template <> struct KernelInfo<class KernelName> {
55
// CHECK:template <> struct KernelInfo<::nm1::nm2::KernelName0> {
66
// CHECK:template <> struct KernelInfo<::nm1::KernelName1> {
7-
// CHECK:template <> struct KernelInfo<::nm1::KernelName3< ::nm1::nm2::KernelName0>> {
8-
// CHECK:template <> struct KernelInfo<::nm1::KernelName3< ::nm1::KernelName1>> {
9-
// CHECK:template <> struct KernelInfo<::nm1::KernelName4< ::nm1::nm2::KernelName0>> {
10-
// CHECK:template <> struct KernelInfo<::nm1::KernelName4< ::nm1::KernelName1>> {
7+
// CHECK:template <> struct KernelInfo<::nm1::KernelName3<::nm1::nm2::KernelName0>> {
8+
// CHECK:template <> struct KernelInfo<::nm1::KernelName3<::nm1::KernelName1>> {
9+
// CHECK:template <> struct KernelInfo<::nm1::KernelName4<::nm1::nm2::KernelName0>> {
10+
// CHECK:template <> struct KernelInfo<::nm1::KernelName4<::nm1::KernelName1>> {
1111
// CHECK:template <> struct KernelInfo<::nm1::KernelName3<KernelName5>> {
1212
// CHECK:template <> struct KernelInfo<::nm1::KernelName4<KernelName7>> {
13-
// CHECK:template <> struct KernelInfo<::nm1::KernelName8< ::nm1::nm2::C>> {
14-
// CHECK:template <> struct KernelInfo<class TmplClassInAnonNS<class ClassInAnonNS>> {
13+
// CHECK:template <> struct KernelInfo<::nm1::KernelName8<::nm1::nm2::C>> {
14+
// CHECK:template <> struct KernelInfo<::TmplClassInAnonNS<ClassInAnonNS>> {
1515

1616
// This test checks if the SYCL device compiler is able to generate correct
1717
// integration header when the kernel name class is expressed in different

clang/test/CodeGenSYCL/integration_header.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@
5757
//
5858
// CHECK: template <> struct KernelInfo<class first_kernel> {
5959
// CHECK: template <> struct KernelInfo<::second_namespace::second_kernel<char>> {
60-
// CHECK: template <> struct KernelInfo<::third_kernel<1, int, ::point<X> >> {
61-
// CHECK: template <> struct KernelInfo<::fourth_kernel< ::template_arg_ns::namespaced_arg<1> >> {
60+
// CHECK: template <> struct KernelInfo<::third_kernel<1, int, ::point<X>>> {
61+
// CHECK: template <> struct KernelInfo<::fourth_kernel<::template_arg_ns::namespaced_arg<1>>> {
6262

6363
#include "sycl.hpp"
6464

clang/test/CodeGenSYCL/kernel_name_with_typedefs.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,16 +117,16 @@ int main() {
117117
single_task<kernel_name2<const space::long_t, space::a_t>>(f);
118118
// CHECK: template <> struct KernelInfo<::kernel_name2<const volatile long, const ::space::B>> {
119119
single_task<kernel_name2<volatile space::clong_t, const space::b_t>>(f);
120-
// CHECK: template <> struct KernelInfo<::kernel_name2< ::A, long>> {
120+
// CHECK: template <> struct KernelInfo<::kernel_name2<::A, long>> {
121121
single_task<kernel_name2<space::a_t, space::long_t>>(f);
122-
// CHECK: template <> struct KernelInfo<::kernel_name2< ::space::B, int>> {
122+
// CHECK: template <> struct KernelInfo<::kernel_name2<::space::B, int>> {
123123
single_task<kernel_name2<space::b_t, int_t>>(f);
124124
// full template specialization
125125
// CHECK: template <> struct KernelInfo<::kernel_name2<int, const unsigned int>> {
126126
single_task<kernel_name2<int_t, const uint_t>>(f);
127-
// CHECK: template <> struct KernelInfo<::kernel_name2<const long, volatile const unsigned long>> {
127+
// CHECK: template <> struct KernelInfo<::kernel_name2<const long, const volatile unsigned long>> {
128128
single_task<kernel_name2<space::clong_t, volatile space::culong_t>>(f);
129-
// CHECK: template <> struct KernelInfo<::kernel_name2< ::A, volatile ::space::B>> {
129+
// CHECK: template <> struct KernelInfo<::kernel_name2<::A, volatile ::space::B>> {
130130
single_task<kernel_name2<space::a_t, volatile space::b_t>>(f);
131131
// CHECK: template <> struct KernelInfo<::kernel_name3<1>> {
132132
single_task<kernel_name3<1>>(f);

0 commit comments

Comments
 (0)