Skip to content

Commit f9226d2

Browse files
[SYCL] Handle KernelName templated using type with enum template argument (#1780)
Add support to handle enums when KernelNameType is templated using a type which is in turn templated using enum. Signed-off-by: Elizabeth Andrews <[email protected]>
1 parent 12d14e8 commit f9226d2

File tree

3 files changed

+54
-17
lines changed

3 files changed

+54
-17
lines changed

clang/include/clang/Sema/Sema.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,8 @@ class SYCLIntegrationHeader {
319319
};
320320

321321
public:
322-
SYCLIntegrationHeader(DiagnosticsEngine &Diag, bool UnnamedLambdaSupport);
322+
SYCLIntegrationHeader(DiagnosticsEngine &Diag, bool UnnamedLambdaSupport,
323+
Sema &S);
323324

324325
/// Emits contents of the header into given stream.
325326
void emit(raw_ostream &Out);
@@ -424,6 +425,8 @@ class SYCLIntegrationHeader {
424425

425426
/// Whether header is generated with unnamed lambda support
426427
bool UnnamedLambdaSupport;
428+
429+
Sema &S;
427430
};
428431

429432
/// Keeps track of expected type during expression parsing. The type is tied to
@@ -12584,7 +12587,7 @@ class Sema final {
1258412587
SYCLIntegrationHeader &getSyclIntegrationHeader() {
1258512588
if (SyclIntHeader == nullptr)
1258612589
SyclIntHeader = std::make_unique<SYCLIntegrationHeader>(
12587-
getDiagnostics(), getLangOpts().SYCLUnnamedLambda);
12590+
getDiagnostics(), getLangOpts().SYCLUnnamedLambda, *this);
1258812591
return *SyclIntHeader.get();
1258912592
}
1259012593

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1866,6 +1866,9 @@ static void printArguments(ASTContext &Ctx, raw_ostream &ArgOS,
18661866
ArrayRef<TemplateArgument> Args,
18671867
const PrintingPolicy &P);
18681868

1869+
static std::string getKernelNameTypeString(QualType T, ASTContext &Ctx,
1870+
const PrintingPolicy &TypePolicy);
1871+
18691872
static void printArgument(ASTContext &Ctx, raw_ostream &ArgOS,
18701873
TemplateArgument Arg, const PrintingPolicy &P) {
18711874
switch (Arg.getKind()) {
@@ -1891,8 +1894,7 @@ static void printArgument(ASTContext &Ctx, raw_ostream &ArgOS,
18911894
TypePolicy.SuppressTypedefs = true;
18921895
TypePolicy.SuppressTagKeyword = true;
18931896
QualType T = Arg.getAsType();
1894-
QualType FullyQualifiedType = TypeName::getFullyQualifiedType(T, Ctx, true);
1895-
ArgOS << FullyQualifiedType.getAsString(TypePolicy);
1897+
ArgOS << getKernelNameTypeString(T, Ctx, TypePolicy);
18961898
break;
18971899
}
18981900
default:
@@ -1925,36 +1927,36 @@ static void printTemplateArguments(ASTContext &Ctx, raw_ostream &ArgOS,
19251927
ArgOS << ">";
19261928
}
19271929

1928-
static std::string getKernelNameTypeString(QualType T) {
1930+
static std::string getKernelNameTypeString(QualType T, ASTContext &Ctx,
1931+
const PrintingPolicy &TypePolicy) {
1932+
1933+
QualType FullyQualifiedType = TypeName::getFullyQualifiedType(T, Ctx, true);
19291934

19301935
const CXXRecordDecl *RD = T->getAsCXXRecordDecl();
19311936

19321937
if (!RD)
1933-
return getCPPTypeString(T);
1938+
return eraseAnonNamespace(FullyQualifiedType.getAsString(TypePolicy));
19341939

19351940
// If kernel name type is a template specialization with enum type
19361941
// template parameters, enumerators in name type string should be
19371942
// replaced with their underlying value since the enum definition
19381943
// is not visible in integration header.
19391944
if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
1940-
LangOptions LO;
1941-
PrintingPolicy P(LO);
1942-
P.SuppressTypedefs = true;
19431945
SmallString<64> Buf;
19441946
llvm::raw_svector_ostream ArgOS(Buf);
19451947

19461948
// Print template class name
1947-
TSD->printQualifiedName(ArgOS, P, /*WithGlobalNsPrefix*/ true);
1949+
TSD->printQualifiedName(ArgOS, TypePolicy, /*WithGlobalNsPrefix*/ true);
19481950

19491951
// Print template arguments substituting enumerators
19501952
ASTContext &Ctx = RD->getASTContext();
19511953
const TemplateArgumentList &Args = TSD->getTemplateArgs();
1952-
printTemplateArguments(Ctx, ArgOS, Args.asArray(), P);
1954+
printTemplateArguments(Ctx, ArgOS, Args.asArray(), TypePolicy);
19531955

19541956
return eraseAnonNamespace(ArgOS.str().str());
19551957
}
19561958

1957-
return getCPPTypeString(T);
1959+
return eraseAnonNamespace(FullyQualifiedType.getAsString(TypePolicy));
19581960
}
19591961

19601962
void SYCLIntegrationHeader::emit(raw_ostream &O) {
@@ -2073,9 +2075,11 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
20732075
O << "', '" << c;
20742076
O << "'> {\n";
20752077
} else {
2076-
2078+
LangOptions LO;
2079+
PrintingPolicy P(LO);
2080+
P.SuppressTypedefs = true;
20772081
O << "template <> struct KernelInfo<"
2078-
<< getKernelNameTypeString(K.NameType) << "> {\n";
2082+
<< getKernelNameTypeString(K.NameType, S.getASTContext(), P) << "> {\n";
20792083
}
20802084
O << " DLL_LOCAL\n";
20812085
O << " static constexpr const char* getName() { return \"" << K.Name
@@ -2144,8 +2148,9 @@ void SYCLIntegrationHeader::addSpecConstant(StringRef IDName, QualType IDType) {
21442148
}
21452149

21462150
SYCLIntegrationHeader::SYCLIntegrationHeader(DiagnosticsEngine &_Diag,
2147-
bool _UnnamedLambdaSupport)
2148-
: Diag(_Diag), UnnamedLambdaSupport(_UnnamedLambdaSupport) {}
2151+
bool _UnnamedLambdaSupport,
2152+
Sema &_S)
2153+
: Diag(_Diag), UnnamedLambdaSupport(_UnnamedLambdaSupport), S(_S) {}
21492154

21502155
// -----------------------------------------------------------------------------
21512156
// Utility class methods

clang/test/CodeGenSYCL/kernelname-enum.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,21 @@ class dummy_functor_7 {
7979
void operator()() {}
8080
};
8181

82+
namespace type_argument_template_enum {
83+
enum class E : int {
84+
A,
85+
B,
86+
C
87+
};
88+
}
89+
90+
template <typename T>
91+
class T1 {};
92+
template <type_argument_template_enum::E EnumValue>
93+
class T2 {};
94+
template <typename EnumType>
95+
class T3 {};
96+
8297
int main() {
8398

8499
dummy_functor_1<no_namespace_int::val_1> f1;
@@ -124,6 +139,14 @@ int main() {
124139
cgh.single_task(f8);
125140
});
126141

142+
q.submit([&](cl::sycl::handler &cgh) {
143+
cgh.single_task<T1<T2<type_argument_template_enum::E::A>>>([=]() {});
144+
});
145+
146+
q.submit([&](cl::sycl::handler &cgh) {
147+
cgh.single_task<T1<T3<type_argument_template_enum::E>>>([=]() {});
148+
});
149+
127150
return 0;
128151
}
129152

@@ -145,7 +168,11 @@ int main() {
145168
// CHECK: enum unscoped_enum : int;
146169
// CHECK: template <unscoped_enum EnumType> class dummy_functor_6;
147170
// CHECK: template <typename EnumType> class dummy_functor_7;
148-
171+
// CHECK: namespace type_argument_template_enum {
172+
// CHECK-NEXT: enum class E : int;
173+
// CHECK-NEXT: }
174+
// CHECK: template <type_argument_template_enum::E EnumValue> class T2;
175+
// CHECK: template <typename T> class T1;
149176
// CHECK: Specializations of KernelInfo for kernel function types:
150177
// CHECK: template <> struct KernelInfo<::dummy_functor_1<(no_namespace_int)0>>
151178
// CHECK: template <> struct KernelInfo<::dummy_functor_2<(no_namespace_short)1>>
@@ -155,3 +182,5 @@ int main() {
155182
// CHECK: template <> struct KernelInfo<::dummy_functor_6<(unscoped_enum)0>>
156183
// CHECK: template <> struct KernelInfo<::dummy_functor_7<::no_namespace_int>>
157184
// CHECK: template <> struct KernelInfo<::dummy_functor_7<::internal::namespace_short>>
185+
// CHECK: template <> struct KernelInfo<::T1<::T2<(type_argument_template_enum::E)0>>>
186+
// CHECK: template <> struct KernelInfo<::T1<::T3<::type_argument_template_enum::E>>>

0 commit comments

Comments
 (0)