@@ -54,6 +54,8 @@ enum KernelInvocationKind {
54
54
const static std::string InitMethodName = " __init" ;
55
55
const static std::string FinalizeMethodName = " __finalize" ;
56
56
57
+ namespace {
58
+
57
59
// / Various utilities.
58
60
class Util {
59
61
public:
@@ -91,6 +93,8 @@ class Util {
91
93
ArrayRef<Util::DeclContextDesc> Scopes);
92
94
};
93
95
96
+ } // anonymous namespace
97
+
94
98
// This information is from Section 4.13 of the SYCL spec
95
99
// https://www.khronos.org/registry/SYCL/specs/sycl-1.2.1.pdf
96
100
// This function returns false if the math lib function
@@ -206,19 +210,9 @@ static bool isZeroSizedArray(QualType Ty) {
206
210
return false ;
207
211
}
208
212
209
- static Sema::DeviceDiagBuilder
210
- emitDeferredDiagnosticAndNote (Sema &S, SourceRange Loc, unsigned DiagID,
211
- SourceRange UsedAtLoc) {
212
- Sema::DeviceDiagBuilder builder =
213
- S.SYCLDiagIfDeviceCode (Loc.getBegin (), DiagID);
214
- if (UsedAtLoc.isValid ())
215
- S.SYCLDiagIfDeviceCode (UsedAtLoc.getBegin (), diag::note_sycl_used_here);
216
- return builder;
217
- }
218
-
219
- static void checkSYCLVarType (Sema &S, QualType Ty, SourceRange Loc,
220
- llvm::DenseSet<QualType> Visited,
221
- SourceRange UsedAtLoc = SourceRange()) {
213
+ static void checkSYCLType (Sema &S, QualType Ty, SourceRange Loc,
214
+ llvm::DenseSet<QualType> Visited,
215
+ SourceRange UsedAtLoc = SourceRange()) {
222
216
// Not all variable types are supported inside SYCL kernels,
223
217
// for example the quad type __float128 will cause errors in the
224
218
// SPIR-V translation phase.
@@ -229,16 +223,21 @@ static void checkSYCLVarType(Sema &S, QualType Ty, SourceRange Loc,
229
223
// different location than the variable declaration and we need to
230
224
// inform the user of both, e.g. struct member usage vs declaration.
231
225
226
+ bool Emitting = false ;
227
+
232
228
// --- check types ---
233
229
234
230
// zero length arrays
235
- if (isZeroSizedArray (Ty))
236
- emitDeferredDiagnosticAndNote (S, Loc, diag::err_typecheck_zero_array_size,
237
- UsedAtLoc);
231
+ if (isZeroSizedArray (Ty)) {
232
+ S.SYCLDiagIfDeviceCode (Loc.getBegin (), diag::err_typecheck_zero_array_size);
233
+ Emitting = true ;
234
+ }
238
235
239
236
// variable length arrays
240
- if (Ty->isVariableArrayType ())
241
- emitDeferredDiagnosticAndNote (S, Loc, diag::err_vla_unsupported, UsedAtLoc);
237
+ if (Ty->isVariableArrayType ()) {
238
+ S.SYCLDiagIfDeviceCode (Loc.getBegin (), diag::err_vla_unsupported);
239
+ Emitting = true ;
240
+ }
242
241
243
242
// Sub-reference array or pointer, then proceed with that type.
244
243
while (Ty->isAnyPointerType () || Ty->isArrayType ())
@@ -249,9 +248,14 @@ static void checkSYCLVarType(Sema &S, QualType Ty, SourceRange Loc,
249
248
Ty->isSpecificBuiltinType (BuiltinType::UInt128) ||
250
249
Ty->isSpecificBuiltinType (BuiltinType::LongDouble) ||
251
250
(Ty->isSpecificBuiltinType (BuiltinType::Float128) &&
252
- !S.Context .getTargetInfo ().hasFloat128Type ()))
253
- emitDeferredDiagnosticAndNote (S, Loc, diag::err_type_unsupported, UsedAtLoc )
251
+ !S.Context .getTargetInfo ().hasFloat128Type ())) {
252
+ S. SYCLDiagIfDeviceCode ( Loc. getBegin () , diag::err_type_unsupported)
254
253
<< Ty.getUnqualifiedType ().getCanonicalType ();
254
+ Emitting = true ;
255
+ }
256
+
257
+ if (Emitting && UsedAtLoc.isValid ())
258
+ S.SYCLDiagIfDeviceCode (UsedAtLoc.getBegin (), diag::note_used_here);
255
259
256
260
// --- now recurse ---
257
261
// Pointers complicate recursion. Add this type to Visited.
@@ -260,16 +264,15 @@ static void checkSYCLVarType(Sema &S, QualType Ty, SourceRange Loc,
260
264
return ;
261
265
262
266
if (const auto *ATy = dyn_cast<AttributedType>(Ty))
263
- return checkSYCLVarType (S, ATy->getModifiedType (), Loc, Visited);
267
+ return checkSYCLType (S, ATy->getModifiedType (), Loc, Visited);
264
268
265
269
if (const auto *RD = Ty->getAsRecordDecl ()) {
266
270
for (const auto &Field : RD->fields ())
267
- checkSYCLVarType (S, Field->getType (), Field->getSourceRange (), Visited,
268
- Loc);
271
+ checkSYCLType (S, Field->getType (), Field->getSourceRange (), Visited, Loc);
269
272
} else if (const auto *FPTy = dyn_cast<FunctionProtoType>(Ty)) {
270
273
for (const auto &ParamTy : FPTy->param_types ())
271
- checkSYCLVarType (S, ParamTy, Loc, Visited);
272
- checkSYCLVarType (S, FPTy->getReturnType (), Loc, Visited);
274
+ checkSYCLType (S, ParamTy, Loc, Visited);
275
+ checkSYCLType (S, FPTy->getReturnType (), Loc, Visited);
273
276
}
274
277
}
275
278
@@ -280,7 +283,7 @@ void Sema::checkSYCLDeviceVarDecl(VarDecl *Var) {
280
283
SourceRange Loc = Var->getLocation ();
281
284
llvm::DenseSet<QualType> Visited;
282
285
283
- checkSYCLVarType (*this , Ty, Loc, Visited);
286
+ checkSYCLType (*this , Ty, Loc, Visited);
284
287
}
285
288
286
289
class MarkDeviceFunction : public RecursiveASTVisitor <MarkDeviceFunction> {
@@ -801,6 +804,22 @@ class SyclKernelFieldChecker
801
804
bool IsInvalid = false ;
802
805
DiagnosticsEngine &Diag;
803
806
807
+ void checkAccessorType (QualType Ty, SourceRange Loc) {
808
+ assert (Util::isSyclAccessorType (Ty) &&
809
+ " Should only be called on SYCL accessor types." );
810
+
811
+ const RecordDecl *RecD = Ty->getAsRecordDecl ();
812
+ if (const ClassTemplateSpecializationDecl *CTSD =
813
+ dyn_cast<ClassTemplateSpecializationDecl>(RecD)) {
814
+ const TemplateArgumentList &TAL = CTSD->getTemplateArgs ();
815
+ TemplateArgument TA = TAL.get (0 );
816
+ const QualType TemplateArgTy = TA.getAsType ();
817
+
818
+ llvm::DenseSet<QualType> Visited;
819
+ checkSYCLType (SemaRef, TemplateArgTy, Loc, Visited);
820
+ }
821
+ }
822
+
804
823
public:
805
824
SyclKernelFieldChecker (Sema &S)
806
825
: SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {}
@@ -832,6 +851,15 @@ class SyclKernelFieldChecker
832
851
}
833
852
}
834
853
854
+ void handleSyclAccessorType (const CXXBaseSpecifier &BS,
855
+ QualType FieldTy) final {
856
+ checkAccessorType (FieldTy, BS.getBeginLoc ());
857
+ }
858
+
859
+ void handleSyclAccessorType (FieldDecl *FD, QualType FieldTy) final {
860
+ checkAccessorType (FieldTy, FD->getLocation ());
861
+ }
862
+
835
863
// We should be able to handle this, so we made it part of the visitor, but
836
864
// this is 'to be implemented'.
837
865
void handleArrayType (FieldDecl *FD, QualType FieldTy) final {
@@ -1454,7 +1482,6 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc,
1454
1482
}
1455
1483
1456
1484
void Sema::MarkDevice (void ) {
1457
- // Let's mark all called functions with SYCL Device attribute.
1458
1485
// Create the call graph so we can detect recursion and check the validity
1459
1486
// of new operator overrides. Add the kernel function itself in case
1460
1487
// it is recursive.
@@ -1540,7 +1567,9 @@ Sema::DeviceDiagBuilder Sema::SYCLDiagIfDeviceCode(SourceLocation Loc,
1540
1567
" Should only be called during SYCL compilation" );
1541
1568
FunctionDecl *FD = dyn_cast<FunctionDecl>(getCurLexicalContext ());
1542
1569
DeviceDiagBuilder::Kind DiagKind = [this , FD] {
1543
- if (ConstructingOpenCLKernel || !FD)
1570
+ if (ConstructingOpenCLKernel)
1571
+ return DeviceDiagBuilder::K_ImmediateWithCallStack;
1572
+ if (!FD)
1544
1573
return DeviceDiagBuilder::K_Nop;
1545
1574
if (getEmissionStatus (FD) == Sema::FunctionEmissionStatus::Emitted)
1546
1575
return DeviceDiagBuilder::K_ImmediateWithCallStack;
@@ -1863,6 +1892,9 @@ static void printArguments(ASTContext &Ctx, raw_ostream &ArgOS,
1863
1892
ArrayRef<TemplateArgument> Args,
1864
1893
const PrintingPolicy &P);
1865
1894
1895
+ static std::string getKernelNameTypeString (QualType T, ASTContext &Ctx,
1896
+ const PrintingPolicy &TypePolicy);
1897
+
1866
1898
static void printArgument (ASTContext &Ctx, raw_ostream &ArgOS,
1867
1899
TemplateArgument Arg, const PrintingPolicy &P) {
1868
1900
switch (Arg.getKind ()) {
@@ -1888,8 +1920,7 @@ static void printArgument(ASTContext &Ctx, raw_ostream &ArgOS,
1888
1920
TypePolicy.SuppressTypedefs = true ;
1889
1921
TypePolicy.SuppressTagKeyword = true ;
1890
1922
QualType T = Arg.getAsType ();
1891
- QualType FullyQualifiedType = TypeName::getFullyQualifiedType (T, Ctx, true );
1892
- ArgOS << FullyQualifiedType.getAsString (TypePolicy);
1923
+ ArgOS << getKernelNameTypeString (T, Ctx, TypePolicy);
1893
1924
break ;
1894
1925
}
1895
1926
default :
@@ -1903,6 +1934,10 @@ static void printArguments(ASTContext &Ctx, raw_ostream &ArgOS,
1903
1934
for (unsigned I = 0 ; I < Args.size (); I++) {
1904
1935
const TemplateArgument &Arg = Args[I];
1905
1936
1937
+ // If argument is an empty pack argument, skip printing comma and argument.
1938
+ if (Arg.getKind () == TemplateArgument::ArgKind::Pack && !Arg.pack_size ())
1939
+ continue ;
1940
+
1906
1941
if (I != 0 )
1907
1942
ArgOS << " , " ;
1908
1943
@@ -1918,36 +1953,36 @@ static void printTemplateArguments(ASTContext &Ctx, raw_ostream &ArgOS,
1918
1953
ArgOS << " >" ;
1919
1954
}
1920
1955
1921
- static std::string getKernelNameTypeString (QualType T) {
1956
+ static std::string getKernelNameTypeString (QualType T, ASTContext &Ctx,
1957
+ const PrintingPolicy &TypePolicy) {
1958
+
1959
+ QualType FullyQualifiedType = TypeName::getFullyQualifiedType (T, Ctx, true );
1922
1960
1923
1961
const CXXRecordDecl *RD = T->getAsCXXRecordDecl ();
1924
1962
1925
1963
if (!RD)
1926
- return getCPPTypeString (T );
1964
+ return eraseAnonNamespace (FullyQualifiedType. getAsString (TypePolicy) );
1927
1965
1928
1966
// If kernel name type is a template specialization with enum type
1929
1967
// template parameters, enumerators in name type string should be
1930
1968
// replaced with their underlying value since the enum definition
1931
1969
// is not visible in integration header.
1932
1970
if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
1933
- LangOptions LO;
1934
- PrintingPolicy P (LO);
1935
- P.SuppressTypedefs = true ;
1936
1971
SmallString<64 > Buf;
1937
1972
llvm::raw_svector_ostream ArgOS (Buf);
1938
1973
1939
1974
// Print template class name
1940
- TSD->printQualifiedName (ArgOS, P , /* WithGlobalNsPrefix*/ true );
1975
+ TSD->printQualifiedName (ArgOS, TypePolicy , /* WithGlobalNsPrefix*/ true );
1941
1976
1942
1977
// Print template arguments substituting enumerators
1943
1978
ASTContext &Ctx = RD->getASTContext ();
1944
1979
const TemplateArgumentList &Args = TSD->getTemplateArgs ();
1945
- printTemplateArguments (Ctx, ArgOS, Args.asArray (), P );
1980
+ printTemplateArguments (Ctx, ArgOS, Args.asArray (), TypePolicy );
1946
1981
1947
1982
return eraseAnonNamespace (ArgOS.str ().str ());
1948
1983
}
1949
1984
1950
- return getCPPTypeString (T );
1985
+ return eraseAnonNamespace (FullyQualifiedType. getAsString (TypePolicy) );
1951
1986
}
1952
1987
1953
1988
void SYCLIntegrationHeader::emit (raw_ostream &O) {
@@ -2066,9 +2101,11 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
2066
2101
O << " ', '" << c;
2067
2102
O << " '> {\n " ;
2068
2103
} else {
2069
-
2104
+ LangOptions LO;
2105
+ PrintingPolicy P (LO);
2106
+ P.SuppressTypedefs = true ;
2070
2107
O << " template <> struct KernelInfo<"
2071
- << getKernelNameTypeString (K.NameType ) << " > {\n " ;
2108
+ << getKernelNameTypeString (K.NameType , S. getASTContext (), P ) << " > {\n " ;
2072
2109
}
2073
2110
O << " DLL_LOCAL\n " ;
2074
2111
O << " static constexpr const char* getName() { return \" " << K.Name
@@ -2137,8 +2174,9 @@ void SYCLIntegrationHeader::addSpecConstant(StringRef IDName, QualType IDType) {
2137
2174
}
2138
2175
2139
2176
SYCLIntegrationHeader::SYCLIntegrationHeader (DiagnosticsEngine &_Diag,
2140
- bool _UnnamedLambdaSupport)
2141
- : Diag(_Diag), UnnamedLambdaSupport(_UnnamedLambdaSupport) {}
2177
+ bool _UnnamedLambdaSupport,
2178
+ Sema &_S)
2179
+ : Diag(_Diag), UnnamedLambdaSupport(_UnnamedLambdaSupport), S(_S) {}
2142
2180
2143
2181
// -----------------------------------------------------------------------------
2144
2182
// Utility class methods
0 commit comments