@@ -3322,21 +3322,6 @@ static const char *paramKind2Str(KernelParamKind K) {
3322
3322
#undef CASE
3323
3323
}
3324
3324
3325
- // Removes all "(anonymous namespace)::" substrings from given string, and emits
3326
- // it.
3327
- static void emitWithoutAnonNamespaces (llvm::raw_ostream &OS, StringRef Source) {
3328
- const char S1[] = " (anonymous namespace)::" ;
3329
-
3330
- size_t Pos;
3331
-
3332
- while ((Pos = Source.find (S1)) != StringRef::npos) {
3333
- OS << Source.take_front (Pos);
3334
- Source = Source.drop_front (Pos + sizeof (S1) - 1 );
3335
- }
3336
-
3337
- OS << Source;
3338
- }
3339
-
3340
3325
// Emits a forward declaration
3341
3326
void SYCLIntegrationHeader::emitFwdDecl (raw_ostream &O, const Decl *D,
3342
3327
SourceLocation KernelLocation) {
@@ -3555,139 +3540,133 @@ void SYCLIntegrationHeader::emitForwardClassDecls(
3555
3540
}
3556
3541
}
3557
3542
3558
- static void emitCPPTypeString (raw_ostream &OS, QualType Ty) {
3559
- LangOptions LO;
3560
- PrintingPolicy P (LO);
3561
- P.SuppressTypedefs = true ;
3562
- emitWithoutAnonNamespaces (OS, Ty.getAsString (P));
3563
- }
3564
-
3565
- static void printArguments (ASTContext &Ctx, raw_ostream &ArgOS,
3566
- ArrayRef<TemplateArgument> Args,
3567
- const PrintingPolicy &P);
3543
+ class SYCLKernelNameTypePrinter
3544
+ : public TypeVisitor<SYCLKernelNameTypePrinter>,
3545
+ public ConstTemplateArgumentVisitor<SYCLKernelNameTypePrinter> {
3546
+ using InnerTypeVisitor = TypeVisitor<SYCLKernelNameTypePrinter>;
3547
+ using InnerTemplArgVisitor =
3548
+ ConstTemplateArgumentVisitor<SYCLKernelNameTypePrinter>;
3549
+ raw_ostream &OS;
3550
+ PrintingPolicy &Policy;
3551
+
3552
+ void printTemplateArgs (ArrayRef<TemplateArgument> Args) {
3553
+ for (size_t I = 0 , E = Args.size (); I < E; ++I) {
3554
+ const TemplateArgument &Arg = Args[I];
3555
+ // If argument is an empty pack argument, skip printing comma and
3556
+ // argument.
3557
+ if (Arg.getKind () == TemplateArgument::ArgKind::Pack && !Arg.pack_size ())
3558
+ continue ;
3568
3559
3569
- static void emitKernelNameType (QualType T, ASTContext &Ctx, raw_ostream &OS,
3570
- const PrintingPolicy &TypePolicy) ;
3560
+ if (I)
3561
+ OS << " , " ;
3571
3562
3572
- static void printArgument (ASTContext &Ctx, raw_ostream &ArgOS,
3573
- TemplateArgument Arg, const PrintingPolicy &P) {
3574
- switch (Arg.getKind ()) {
3575
- case TemplateArgument::ArgKind::Pack: {
3576
- printArguments (Ctx, ArgOS, Arg.getPackAsArray (), P);
3577
- break ;
3578
- }
3579
- case TemplateArgument::ArgKind::Integral: {
3580
- QualType T = Arg.getIntegralType ();
3581
- const EnumType *ET = T->getAs <EnumType>();
3582
-
3583
- if (ET) {
3584
- const llvm::APSInt &Val = Arg.getAsIntegral ();
3585
- ArgOS << " static_cast<"
3586
- << ET->getDecl ()->getQualifiedNameAsString (
3587
- /* WithGlobalNsPrefix*/ true )
3588
- << " >"
3589
- << " (" << Val << " )" ;
3590
- } else {
3591
- Arg.print (P, ArgOS);
3563
+ Visit (Arg);
3592
3564
}
3593
- break ;
3594
3565
}
3595
- case TemplateArgument::ArgKind::Type: {
3596
- LangOptions LO;
3597
- PrintingPolicy TypePolicy (LO);
3598
- TypePolicy.SuppressTypedefs = true ;
3599
- TypePolicy.SuppressTagKeyword = true ;
3600
- QualType T = Arg.getAsType ();
3601
3566
3602
- emitKernelNameType (T, Ctx, ArgOS, TypePolicy);
3603
- break ;
3604
- }
3605
- case TemplateArgument::ArgKind::Template: {
3606
- TemplateDecl *TD = Arg.getAsTemplate ().getAsTemplateDecl ();
3607
- ArgOS << TD->getQualifiedNameAsString ();
3608
- break ;
3609
- }
3610
- default :
3611
- Arg.print (P, ArgOS);
3567
+ void VisitQualifiers (Qualifiers Quals) {
3568
+ Quals.print (OS, Policy, /* appendSpaceIfNotEmpty*/ true );
3612
3569
}
3613
- }
3614
3570
3615
- static void printArguments (ASTContext &Ctx, raw_ostream &ArgOS,
3616
- ArrayRef<TemplateArgument> Args,
3617
- const PrintingPolicy &P) {
3618
- for (unsigned I = 0 ; I < Args.size (); I++) {
3619
- const TemplateArgument &Arg = Args[I];
3571
+ public:
3572
+ SYCLKernelNameTypePrinter (raw_ostream &OS, PrintingPolicy &Policy)
3573
+ : OS(OS), Policy(Policy) {}
3620
3574
3621
- // If argument is an empty pack argument, skip printing comma and argument.
3622
- if (Arg. getKind () == TemplateArgument::ArgKind::Pack && !Arg. pack_size ())
3623
- continue ;
3575
+ void Visit (QualType T) {
3576
+ if (T. isNull ())
3577
+ return ;
3624
3578
3625
- if (I != 0 )
3626
- ArgOS << " , " ;
3579
+ QualType CT = T. getCanonicalType ();
3580
+ VisitQualifiers (CT. getQualifiers ()) ;
3627
3581
3628
- printArgument (Ctx, ArgOS, Arg, P );
3582
+ InnerTypeVisitor::Visit (CT. getTypePtr () );
3629
3583
}
3630
- }
3631
3584
3632
- static void printTemplateArguments (ASTContext &Ctx, raw_ostream &ArgOS,
3633
- ArrayRef<TemplateArgument> Args,
3634
- const PrintingPolicy &P) {
3635
- ArgOS << " <" ;
3636
- printArguments (Ctx, ArgOS, Args, P);
3637
- ArgOS << " >" ;
3638
- }
3585
+ void VisitType (const Type *T) {
3586
+ OS << QualType::getAsString (T, Qualifiers (), Policy);
3587
+ }
3639
3588
3640
- static void emitRecordType (raw_ostream &OS, QualType T, const CXXRecordDecl *RD,
3641
- const PrintingPolicy &TypePolicy) {
3642
- SmallString<64 > Buf;
3643
- llvm::raw_svector_ostream RecOS (Buf);
3644
- T.getCanonicalType ().getQualifiers ().print (RecOS, TypePolicy,
3645
- /* appendSpaceIfNotEmpty*/ true );
3646
- if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
3589
+ void Visit (const TemplateArgument &TA) {
3590
+ if (TA.isNull ())
3591
+ return ;
3592
+ InnerTemplArgVisitor::Visit (TA);
3593
+ }
3594
+
3595
+ void VisitTagType (const TagType *T) {
3596
+ TagDecl *RD = T->getDecl ();
3597
+ if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
3647
3598
3648
- // Print template class name
3649
- TSD->printQualifiedName (RecOS, TypePolicy , /* WithGlobalNsPrefix*/ true );
3599
+ // Print template class name
3600
+ TSD->printQualifiedName (OS, Policy , /* WithGlobalNsPrefix*/ true );
3650
3601
3651
- // Print template arguments substituting enumerators
3652
- ASTContext &Ctx = RD-> getASTContext () ;
3653
- const TemplateArgumentList &Args = TSD-> getTemplateArgs ( );
3654
- printTemplateArguments (Ctx, RecOS, Args. asArray (), TypePolicy) ;
3602
+ ArrayRef<TemplateArgument> Args = TSD-> getTemplateArgs (). asArray ();
3603
+ OS << " < " ;
3604
+ printTemplateArgs (Args );
3605
+ OS << " > " ;
3655
3606
3656
- emitWithoutAnonNamespaces (OS, RecOS.str ());
3657
- return ;
3607
+ return ;
3608
+ }
3609
+ // TODO: Next part of code results in printing of "class" keyword before
3610
+ // class name in case if kernel name doesn't belong to some namespace. It
3611
+ // seems if we don't print it, the integration header still represents valid
3612
+ // c++ code. Probably we don't need to print it at all.
3613
+ if (RD->getDeclContext ()->isFunctionOrMethod ()) {
3614
+ OS << QualType::getAsString (T, Qualifiers (), Policy);
3615
+ return ;
3616
+ }
3617
+
3618
+ const NamespaceDecl *NS = dyn_cast<NamespaceDecl>(RD->getDeclContext ());
3619
+ RD->printQualifiedName (OS, Policy, !(NS && NS->isAnonymousNamespace ()));
3658
3620
}
3659
- if (RD-> getDeclContext ()-> isFunctionOrMethod ()) {
3660
- emitWithoutAnonNamespaces (OS, T. getCanonicalType (). getAsString (TypePolicy));
3661
- return ;
3621
+
3622
+ void VisitTemplateArgument ( const TemplateArgument &TA) {
3623
+ TA. print (Policy, OS) ;
3662
3624
}
3663
3625
3664
- const NamespaceDecl *NS = dyn_cast<NamespaceDecl>(RD->getDeclContext ());
3665
- RD->printQualifiedName (RecOS, TypePolicy,
3666
- !(NS && NS->isAnonymousNamespace ()));
3667
- emitWithoutAnonNamespaces (OS, RecOS.str ());
3668
- }
3626
+ void VisitTypeTemplateArgument (const TemplateArgument &TA) {
3627
+ Policy.SuppressTagKeyword = true ;
3628
+ QualType T = TA.getAsType ();
3629
+ Visit (T);
3630
+ Policy.SuppressTagKeyword = false ;
3631
+ }
3669
3632
3670
- static void emitKernelNameType (QualType T, ASTContext &Ctx, raw_ostream &OS,
3671
- const PrintingPolicy &TypePolicy) {
3672
- if (T->isRecordType ()) {
3673
- emitRecordType (OS, T, T->getAsCXXRecordDecl (), TypePolicy);
3674
- return ;
3633
+ void VisitIntegralTemplateArgument (const TemplateArgument &TA) {
3634
+ QualType T = TA.getIntegralType ();
3635
+ if (const EnumType *ET = T->getAs <EnumType>()) {
3636
+ const llvm::APSInt &Val = TA.getAsIntegral ();
3637
+ OS << " static_cast<" ;
3638
+ ET->getDecl ()->printQualifiedName (OS, Policy,
3639
+ /* WithGlobalNsPrefix*/ true );
3640
+ OS << " >(" << Val << " )" ;
3641
+ } else {
3642
+ TA.print (Policy, OS);
3643
+ }
3675
3644
}
3676
3645
3677
- if (T->isEnumeralType ())
3678
- OS << " ::" ;
3679
- emitWithoutAnonNamespaces (OS, T.getCanonicalType ().getAsString (TypePolicy));
3680
- }
3646
+ void VisitTemplateTemplateArgument (const TemplateArgument &TA) {
3647
+ TemplateDecl *TD = TA.getAsTemplate ().getAsTemplateDecl ();
3648
+ TD->printQualifiedName (OS, Policy);
3649
+ }
3650
+
3651
+ void VisitPackTemplateArgument (const TemplateArgument &TA) {
3652
+ printTemplateArgs (TA.getPackAsArray ());
3653
+ }
3654
+ };
3681
3655
3682
3656
void SYCLIntegrationHeader::emit (raw_ostream &O) {
3683
3657
O << " // This is auto-generated SYCL integration header.\n " ;
3684
3658
O << " \n " ;
3685
3659
3686
- O << " #include <CL/sycl/detail/defines .hpp>\n " ;
3660
+ O << " #include <CL/sycl/detail/defines_elementary .hpp>\n " ;
3687
3661
O << " #include <CL/sycl/detail/kernel_desc.hpp>\n " ;
3688
3662
3689
3663
O << " \n " ;
3690
3664
3665
+ LangOptions LO;
3666
+ PrintingPolicy Policy (LO);
3667
+ Policy.SuppressTypedefs = true ;
3668
+ Policy.SuppressUnwrittenScope = true ;
3669
+
3691
3670
if (SpecConsts.size () > 0 ) {
3692
3671
// Remove duplicates.
3693
3672
std::sort (SpecConsts.begin (), SpecConsts.end (),
@@ -3705,7 +3684,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
3705
3684
O << " // Specialization constants IDs:\n " ;
3706
3685
for (const auto &P : llvm::make_range (SpecConsts.begin (), End)) {
3707
3686
O << " template <> struct sycl::detail::SpecConstantInfo<" ;
3708
- emitCPPTypeString (O, P.first );
3687
+ O << P.first . getAsString (Policy );
3709
3688
O << " > {\n " ;
3710
3689
O << " static constexpr const char* getName() {\n " ;
3711
3690
O << " return \" " << P.second << " \" ;\n " ;
@@ -3773,19 +3752,17 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
3773
3752
O << " ', '" << c;
3774
3753
O << " '> {\n " ;
3775
3754
} else {
3776
- LangOptions LO;
3777
- PrintingPolicy P (LO);
3778
- P.SuppressTypedefs = true ;
3779
3755
O << " template <> struct KernelInfo<" ;
3780
- emitKernelNameType (K.NameType , S.getASTContext (), O, P);
3756
+ SYCLKernelNameTypePrinter Printer (O, Policy);
3757
+ Printer.Visit (K.NameType );
3781
3758
O << " > {\n " ;
3782
3759
}
3783
- O << " DLL_LOCAL \n " ;
3760
+ O << " __SYCL_DLL_LOCAL \n " ;
3784
3761
O << " static constexpr const char* getName() { return \" " << K.Name
3785
3762
<< " \" ; }\n " ;
3786
- O << " DLL_LOCAL \n " ;
3763
+ O << " __SYCL_DLL_LOCAL \n " ;
3787
3764
O << " static constexpr unsigned getNumParams() { return " << N << " ; }\n " ;
3788
- O << " DLL_LOCAL \n " ;
3765
+ O << " __SYCL_DLL_LOCAL \n " ;
3789
3766
O << " static constexpr const kernel_param_desc_t& " ;
3790
3767
O << " getParamDesc(unsigned i) {\n " ;
3791
3768
O << " return kernel_signatures[i+" << CurStart << " ];\n " ;
0 commit comments