14
14
#include " clang/AST/QualTypeNames.h"
15
15
#include " clang/AST/RecordLayout.h"
16
16
#include " clang/AST/RecursiveASTVisitor.h"
17
+ #include " clang/AST/TemplateArgumentVisitor.h"
18
+ #include " clang/AST/TypeVisitor.h"
17
19
#include " clang/Analysis/CallGraph.h"
18
20
#include " clang/Basic/Attributes.h"
19
21
#include " clang/Basic/Builtins.h"
@@ -2473,9 +2475,111 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
2473
2475
2474
2476
} // namespace
2475
2477
2478
+ class SYCLKernelNameTypeVisitor
2479
+ : public TypeVisitor<SYCLKernelNameTypeVisitor>,
2480
+ public ConstTemplateArgumentVisitor<SYCLKernelNameTypeVisitor> {
2481
+ Sema &S;
2482
+ SourceLocation KernelInvocationFuncLoc;
2483
+ using InnerTypeVisitor = TypeVisitor<SYCLKernelNameTypeVisitor>;
2484
+ using InnerTAVisitor =
2485
+ ConstTemplateArgumentVisitor<SYCLKernelNameTypeVisitor>;
2486
+
2487
+ public:
2488
+ SYCLKernelNameTypeVisitor (Sema &S, SourceLocation KernelInvocationFuncLoc)
2489
+ : S(S), KernelInvocationFuncLoc(KernelInvocationFuncLoc) {}
2490
+
2491
+ void Visit (QualType T) {
2492
+ if (T.isNull ())
2493
+ return ;
2494
+ const CXXRecordDecl *RD = T->getAsCXXRecordDecl ();
2495
+ if (!RD)
2496
+ return ;
2497
+ // If KernelNameType has template args visit each template arg via
2498
+ // ConstTemplateArgumentVisitor
2499
+ if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
2500
+ const TemplateArgumentList &Args = TSD->getTemplateArgs ();
2501
+ for (unsigned I = 0 ; I < Args.size (); I++) {
2502
+ Visit (Args[I]);
2503
+ }
2504
+ } else {
2505
+ InnerTypeVisitor::Visit (T.getTypePtr ());
2506
+ }
2507
+ }
2508
+
2509
+ void Visit (const TemplateArgument &TA) {
2510
+ if (TA.isNull ())
2511
+ return ;
2512
+ InnerTAVisitor::Visit (TA);
2513
+ }
2514
+
2515
+ void VisitEnumType (const EnumType *T) {
2516
+ const EnumDecl *ED = T->getDecl ();
2517
+ if (!ED->isScoped () && !ED->isFixed ()) {
2518
+ S.Diag (KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named)
2519
+ << /* Unscoped enum requires fixed underlying type */ 2 ;
2520
+ S.Diag (ED->getSourceRange ().getBegin (), diag::note_entity_declared_at)
2521
+ << ED;
2522
+ }
2523
+ }
2524
+
2525
+ void VisitRecordType (const RecordType *T) {
2526
+ return VisitTagDecl (T->getDecl ());
2527
+ }
2528
+
2529
+ void VisitTagDecl (const TagDecl *Tag) {
2530
+ bool UnnamedLambdaEnabled =
2531
+ S.getASTContext ().getLangOpts ().SYCLUnnamedLambda ;
2532
+ if (!Tag->getDeclContext ()->isTranslationUnit () &&
2533
+ !isa<NamespaceDecl>(Tag->getDeclContext ()) && !UnnamedLambdaEnabled) {
2534
+ const bool KernelNameIsMissing = Tag->getName ().empty ();
2535
+ if (KernelNameIsMissing) {
2536
+ S.Diag (KernelInvocationFuncLoc, diag::err_sycl_kernel_incorrectly_named)
2537
+ << /* kernel name is missing */ 0 ;
2538
+ } else {
2539
+ if (Tag->isCompleteDefinition ())
2540
+ S.Diag (KernelInvocationFuncLoc,
2541
+ diag::err_sycl_kernel_incorrectly_named)
2542
+ << /* kernel name is not globally-visible */ 1 ;
2543
+ else
2544
+ S.Diag (KernelInvocationFuncLoc, diag::warn_sycl_implicit_decl);
2545
+
2546
+ S.Diag (Tag->getSourceRange ().getBegin (), diag::note_previous_decl)
2547
+ << Tag->getName ();
2548
+ }
2549
+ }
2550
+ }
2551
+
2552
+ void VisitTypeTemplateArgument (const TemplateArgument &TA) {
2553
+ QualType T = TA.getAsType ();
2554
+ if (const auto *ET = T->getAs <EnumType>())
2555
+ VisitEnumType (ET);
2556
+ else
2557
+ Visit (T);
2558
+ }
2559
+
2560
+ void VisitIntegralTemplateArgument (const TemplateArgument &TA) {
2561
+ QualType T = TA.getIntegralType ();
2562
+ if (const EnumType *ET = T->getAs <EnumType>())
2563
+ VisitEnumType (ET);
2564
+ }
2565
+
2566
+ void VisitTemplateTemplateArgument (const TemplateArgument &TA) {
2567
+ TemplateDecl *TD = TA.getAsTemplate ().getAsTemplateDecl ();
2568
+ TemplateParameterList *TemplateParams = TD->getTemplateParameters ();
2569
+ for (NamedDecl *P : *TemplateParams) {
2570
+ if (NonTypeTemplateParmDecl *TemplateParam =
2571
+ dyn_cast<NonTypeTemplateParmDecl>(P))
2572
+ if (const EnumType *ET = TemplateParam->getType ()->getAs <EnumType>())
2573
+ VisitEnumType (ET);
2574
+ }
2575
+ }
2576
+ };
2577
+
2476
2578
void Sema::CheckSYCLKernelCall (FunctionDecl *KernelFunc, SourceRange CallLoc,
2477
2579
ArrayRef<const Expr *> Args) {
2478
2580
const CXXRecordDecl *KernelObj = getKernelObjectType (KernelFunc);
2581
+ QualType KernelNameType =
2582
+ calculateKernelNameType (getASTContext (), KernelFunc);
2479
2583
if (!KernelObj) {
2480
2584
Diag (Args[0 ]->getExprLoc (), diag::err_sycl_kernel_not_function_object);
2481
2585
KernelFunc->setInvalidDecl ();
@@ -2511,6 +2615,10 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
2511
2615
return ;
2512
2616
2513
2617
KernelObjVisitor Visitor{*this };
2618
+ SYCLKernelNameTypeVisitor KernelTypeVisitor (*this , Args[0 ]->getExprLoc ());
2619
+ // Emit diagnostics for SYCL device kernels only
2620
+ if (LangOpts.SYCLIsDevice )
2621
+ KernelTypeVisitor.Visit (KernelNameType);
2514
2622
DiagnosingSYCLKernel = true ;
2515
2623
Visitor.VisitRecordBases (KernelObj, FieldChecker, UnionChecker,
2516
2624
ArgsSizeChecker);
@@ -2856,18 +2964,6 @@ static void emitWithoutAnonNamespaces(llvm::raw_ostream &OS, StringRef Source) {
2856
2964
OS << Source;
2857
2965
}
2858
2966
2859
- static bool checkEnumTemplateParameter (const EnumDecl *ED,
2860
- DiagnosticsEngine &Diag,
2861
- SourceLocation KernelLocation) {
2862
- if (!ED->isScoped () && !ED->isFixed ()) {
2863
- Diag.Report (KernelLocation, diag::err_sycl_kernel_incorrectly_named) << 2 ;
2864
- Diag.Report (ED->getSourceRange ().getBegin (), diag::note_entity_declared_at)
2865
- << ED;
2866
- return true ;
2867
- }
2868
- return false ;
2869
- }
2870
-
2871
2967
// Emits a forward declaration
2872
2968
void SYCLIntegrationHeader::emitFwdDecl (raw_ostream &O, const Decl *D,
2873
2969
SourceLocation KernelLocation) {
@@ -2880,32 +2976,6 @@ void SYCLIntegrationHeader::emitFwdDecl(raw_ostream &O, const Decl *D,
2880
2976
auto *NS = dyn_cast_or_null<NamespaceDecl>(DC);
2881
2977
2882
2978
if (!NS) {
2883
- if (!DC->isTranslationUnit ()) {
2884
- const TagDecl *TD = isa<ClassTemplateDecl>(D)
2885
- ? cast<ClassTemplateDecl>(D)->getTemplatedDecl ()
2886
- : dyn_cast<TagDecl>(D);
2887
-
2888
- if (TD && !UnnamedLambdaSupport) {
2889
- // defined class constituting the kernel name is not globally
2890
- // accessible - contradicts the spec
2891
- const bool KernelNameIsMissing = TD->getName ().empty ();
2892
- if (KernelNameIsMissing) {
2893
- Diag.Report (KernelLocation, diag::err_sycl_kernel_incorrectly_named)
2894
- << /* kernel name is missing */ 0 ;
2895
- // Don't emit note if kernel name was completely omitted
2896
- } else {
2897
- if (TD->isCompleteDefinition ())
2898
- Diag.Report (KernelLocation,
2899
- diag::err_sycl_kernel_incorrectly_named)
2900
- << /* kernel name is not globally-visible */ 1 ;
2901
- else
2902
- Diag.Report (KernelLocation, diag::warn_sycl_implicit_decl);
2903
- Diag.Report (D->getSourceRange ().getBegin (),
2904
- diag::note_previous_decl)
2905
- << TD->getName ();
2906
- }
2907
- }
2908
- }
2909
2979
break ;
2910
2980
}
2911
2981
@@ -3025,7 +3095,6 @@ void SYCLIntegrationHeader::emitForwardClassDecls(
3025
3095
// Handle Kernel Name Type templated using enum type and value.
3026
3096
if (const auto *ET = T->getAs <EnumType>()) {
3027
3097
const EnumDecl *ED = ET->getDecl ();
3028
- if (!checkEnumTemplateParameter (ED, Diag, KernelLocation))
3029
3098
emitFwdDecl (O, ED, KernelLocation);
3030
3099
} else if (Arg.getKind () == TemplateArgument::ArgKind::Type)
3031
3100
emitForwardClassDecls (O, T, KernelLocation, Printed);
@@ -3085,7 +3154,6 @@ void SYCLIntegrationHeader::emitForwardClassDecls(
3085
3154
QualType T = TemplateParam->getType ();
3086
3155
if (const auto *ET = T->getAs <EnumType>()) {
3087
3156
const EnumDecl *ED = ET->getDecl ();
3088
- if (!checkEnumTemplateParameter (ED, Diag, KernelLocation))
3089
3157
emitFwdDecl (O, ED, KernelLocation);
3090
3158
}
3091
3159
}
0 commit comments