@@ -3344,58 +3344,6 @@ static const char *paramKind2Str(KernelParamKind K) {
3344
3344
#undef CASE
3345
3345
}
3346
3346
3347
- // Emits a forward declaration
3348
- void SYCLIntegrationHeader::emitFwdDecl (raw_ostream &O, const Decl *D,
3349
- SourceLocation KernelLocation) {
3350
- // wrap the declaration into namespaces if needed
3351
- unsigned NamespaceCnt = 0 ;
3352
- std::string NSStr = " " ;
3353
- const DeclContext *DC = D->getDeclContext ();
3354
-
3355
- while (DC) {
3356
- auto *NS = dyn_cast_or_null<NamespaceDecl>(DC);
3357
-
3358
- if (!NS) {
3359
- break ;
3360
- }
3361
-
3362
- ++NamespaceCnt;
3363
- const StringRef NSInlinePrefix = NS->isInline () ? " inline " : " " ;
3364
- NSStr.insert (
3365
- 0 , Twine (NSInlinePrefix + " namespace " + NS->getName () + " { " ).str ());
3366
- DC = NS->getDeclContext ();
3367
- }
3368
- O << NSStr;
3369
- if (NamespaceCnt > 0 )
3370
- O << " \n " ;
3371
- // print declaration into a string:
3372
- PrintingPolicy P (D->getASTContext ().getLangOpts ());
3373
- P.adjustForCPlusPlusFwdDecl ();
3374
- P.SuppressTypedefs = true ;
3375
- P.SuppressUnwrittenScope = true ;
3376
- std::string S;
3377
- llvm::raw_string_ostream SO (S);
3378
- D->print (SO, P);
3379
- O << SO.str ();
3380
-
3381
- if (const auto *ED = dyn_cast<EnumDecl>(D)) {
3382
- QualType T = ED->getIntegerType ();
3383
- // Backup since getIntegerType() returns null for enum forward
3384
- // declaration with no fixed underlying type
3385
- if (T.isNull ())
3386
- T = ED->getPromotionType ();
3387
- O << " : " << T.getAsString ();
3388
- }
3389
-
3390
- O << " ;\n " ;
3391
-
3392
- // print closing braces for namespaces if needed
3393
- for (unsigned I = 0 ; I < NamespaceCnt; ++I)
3394
- O << " }" ;
3395
- if (NamespaceCnt > 0 )
3396
- O << " \n " ;
3397
- }
3398
-
3399
3347
// Emits forward declarations of classes and template classes on which
3400
3348
// declaration of given type depends.
3401
3349
// For example, consider SimpleVadd
@@ -3432,126 +3380,176 @@ void SYCLIntegrationHeader::emitFwdDecl(raw_ostream &O, const Decl *D,
3432
3380
// template <typename T> class MyTmplClass;
3433
3381
// template <typename T1, unsigned int N, typename ...T2> class SimpleVadd;
3434
3382
//
3435
- void SYCLIntegrationHeader::emitForwardClassDecls (
3436
- raw_ostream &O, QualType T, SourceLocation KernelLocation,
3437
- llvm::SmallPtrSetImpl<const void *> &Printed) {
3383
+ class SYCLFwdDeclEmitter
3384
+ : public TypeVisitor<SYCLFwdDeclEmitter>,
3385
+ public ConstTemplateArgumentVisitor<SYCLFwdDeclEmitter> {
3386
+ using InnerTypeVisitor = TypeVisitor<SYCLFwdDeclEmitter>;
3387
+ using InnerTemplArgVisitor = ConstTemplateArgumentVisitor<SYCLFwdDeclEmitter>;
3388
+ raw_ostream &OS;
3389
+ llvm::SmallPtrSet<const NamedDecl *, 4 > Printed;
3390
+ PrintingPolicy Policy;
3438
3391
3439
- // peel off the pointer types and get the class/struct type:
3440
- for (; T->isPointerType (); T = T->getPointeeType ())
3441
- ;
3442
- const CXXRecordDecl *RD = T->getAsCXXRecordDecl ();
3392
+ void printForwardDecl (NamedDecl *D) {
3393
+ // wrap the declaration into namespaces if needed
3394
+ unsigned NamespaceCnt = 0 ;
3395
+ std::string NSStr = " " ;
3396
+ const DeclContext *DC = D->getDeclContext ();
3443
3397
3444
- if (!RD) {
3398
+ while (DC) {
3399
+ const auto *NS = dyn_cast_or_null<NamespaceDecl>(DC);
3445
3400
3446
- return ;
3401
+ if (!NS)
3402
+ break ;
3403
+
3404
+ ++NamespaceCnt;
3405
+ const StringRef NSInlinePrefix = NS->isInline () ? " inline " : " " ;
3406
+ NSStr.insert (
3407
+ 0 ,
3408
+ Twine (NSInlinePrefix + " namespace " + NS->getName () + " { " ).str ());
3409
+ DC = NS->getDeclContext ();
3410
+ }
3411
+ OS << NSStr;
3412
+ if (NamespaceCnt > 0 )
3413
+ OS << " \n " ;
3414
+
3415
+ D->print (OS, Policy);
3416
+
3417
+ if (const auto *ED = dyn_cast<EnumDecl>(D)) {
3418
+ QualType T = ED->getIntegerType ();
3419
+ // Backup since getIntegerType() returns null for enum forward
3420
+ // declaration with no fixed underlying type
3421
+ if (T.isNull ())
3422
+ T = ED->getPromotionType ();
3423
+ OS << " : " << T.getAsString ();
3424
+ }
3425
+
3426
+ OS << " ;\n " ;
3427
+
3428
+ // print closing braces for namespaces if needed
3429
+ for (unsigned I = 0 ; I < NamespaceCnt; ++I)
3430
+ OS << " }" ;
3431
+ if (NamespaceCnt > 0 )
3432
+ OS << " \n " ;
3447
3433
}
3448
3434
3449
- // see if this is a template specialization ...
3450
- if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
3451
- // ... yes, it is template specialization:
3452
- // - first, recurse into template parameters and emit needed forward
3453
- // declarations
3454
- const TemplateArgumentList &Args = TSD->getTemplateArgs ();
3435
+ // Checks if we've already printed forward declaration and prints it if not.
3436
+ void checkAndEmitForwardDecl (NamedDecl *D) {
3437
+ if (Printed.insert (D).second )
3438
+ printForwardDecl (D);
3439
+ }
3455
3440
3456
- for (unsigned I = 0 ; I < Args.size (); I++) {
3457
- const TemplateArgument &Arg = Args[I];
3441
+ void VisitTemplateArgs (ArrayRef<TemplateArgument> Args) {
3442
+ for (size_t I = 0 , E = Args.size (); I < E; ++I)
3443
+ Visit (Args[I]);
3444
+ }
3458
3445
3459
- switch (Arg.getKind ()) {
3460
- case TemplateArgument::ArgKind::Type:
3461
- case TemplateArgument::ArgKind::Integral: {
3462
- QualType T = (Arg.getKind () == TemplateArgument::ArgKind::Type)
3463
- ? Arg.getAsType ()
3464
- : Arg.getIntegralType ();
3465
-
3466
- // Handle Kernel Name Type templated using enum type and value.
3467
- if (const auto *ET = T->getAs <EnumType>()) {
3468
- const EnumDecl *ED = ET->getDecl ();
3469
- emitFwdDecl (O, ED, KernelLocation);
3470
- } else if (Arg.getKind () == TemplateArgument::ArgKind::Type)
3471
- emitForwardClassDecls (O, T, KernelLocation, Printed);
3472
- break ;
3473
- }
3474
- case TemplateArgument::ArgKind::Pack: {
3475
- ArrayRef<TemplateArgument> Pack = Arg.getPackAsArray ();
3446
+ public:
3447
+ SYCLFwdDeclEmitter (raw_ostream &OS, LangOptions LO) : OS(OS), Policy(LO) {
3448
+ Policy.adjustForCPlusPlusFwdDecl ();
3449
+ Policy.SuppressTypedefs = true ;
3450
+ Policy.SuppressUnwrittenScope = true ;
3451
+ }
3476
3452
3477
- for (const auto &T : Pack) {
3478
- if (T.getKind () == TemplateArgument::ArgKind::Type) {
3479
- emitForwardClassDecls (O, T.getAsType (), KernelLocation, Printed);
3480
- }
3481
- }
3482
- break ;
3483
- }
3484
- case TemplateArgument::ArgKind::Template: {
3485
- // recursion is not required, since the maximum possible nesting level
3486
- // equals two for template argument
3487
- //
3488
- // for example:
3489
- // template <typename T> class Bar;
3490
- // template <template <typename> class> class Baz;
3491
- // template <template <template <typename> class> class T>
3492
- // class Foo;
3493
- //
3494
- // The Baz is a template class. The Baz<Bar> is a class. The class Foo
3495
- // should be specialized with template class, not a class. The correct
3496
- // specialization of template class Foo is Foo<Baz>. The incorrect
3497
- // specialization of template class Foo is Foo<Baz<Bar>>. In this case
3498
- // template class Foo specialized by class Baz<Bar>, not a template
3499
- // class template <template <typename> class> class T as it should.
3500
- TemplateDecl *TD = Arg.getAsTemplate ().getAsTemplateDecl ();
3501
- TemplateParameterList *TemplateParams = TD->getTemplateParameters ();
3502
- for (NamedDecl *P : *TemplateParams) {
3503
- // If template template paramter type has an enum value template
3504
- // parameter, forward declaration of enum type is required. Only enum
3505
- // values (not types) need to be handled. For example, consider the
3506
- // following kernel name type:
3507
- //
3508
- // template <typename EnumTypeOut, template <EnumValueIn EnumValue,
3509
- // typename TypeIn> class T> class Foo;
3510
- //
3511
- // The correct specialization for Foo (with enum type) is:
3512
- // Foo<EnumTypeOut, Baz>, where Baz is a template class.
3513
- //
3514
- // Therefore the forward class declarations generated in the
3515
- // integration header are:
3516
- // template <EnumValueIn EnumValue, typename TypeIn> class Baz;
3517
- // template <typename EnumTypeOut, template <EnumValueIn EnumValue,
3518
- // typename EnumTypeIn> class T> class Foo;
3519
- //
3520
- // This requires the following enum forward declarations:
3521
- // enum class EnumTypeOut : int; (Used to template Foo)
3522
- // enum class EnumValueIn : int; (Used to template Baz)
3523
- if (NonTypeTemplateParmDecl *TemplateParam =
3524
- dyn_cast<NonTypeTemplateParmDecl>(P)) {
3525
- QualType T = TemplateParam->getType ();
3526
- if (const auto *ET = T->getAs <EnumType>()) {
3527
- const EnumDecl *ED = ET->getDecl ();
3528
- emitFwdDecl (O, ED, KernelLocation);
3529
- }
3530
- }
3531
- }
3532
- if (Printed.insert (TD).second ) {
3533
- emitFwdDecl (O, TD, KernelLocation);
3534
- }
3535
- break ;
3536
- }
3537
- default :
3538
- break ; // nop
3539
- }
3453
+ void Visit (QualType T) {
3454
+ if (T.isNull ())
3455
+ return ;
3456
+ InnerTypeVisitor::Visit (T.getTypePtr ());
3457
+ }
3458
+
3459
+ void Visit (const TemplateArgument &TA) {
3460
+ if (TA.isNull ())
3461
+ return ;
3462
+ InnerTemplArgVisitor::Visit (TA);
3463
+ }
3464
+
3465
+ void VisitPointerType (const PointerType *T) {
3466
+ // Peel off the pointer types.
3467
+ QualType PT = T->getPointeeType ();
3468
+ while (PT->isPointerType ())
3469
+ PT = PT->getPointeeType ();
3470
+ Visit (PT);
3471
+ }
3472
+
3473
+ void VisitTagType (const TagType *T) {
3474
+ TagDecl *TD = T->getDecl ();
3475
+ if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(TD)) {
3476
+ // - first, recurse into template parameters and emit needed forward
3477
+ // declarations
3478
+ ArrayRef<TemplateArgument> Args = TSD->getTemplateArgs ().asArray ();
3479
+ VisitTemplateArgs (Args);
3480
+ // - second, emit forward declaration for the template class being
3481
+ // specialized
3482
+ ClassTemplateDecl *CTD = TSD->getSpecializedTemplate ();
3483
+ assert (CTD && " template declaration must be available" );
3484
+
3485
+ checkAndEmitForwardDecl (CTD);
3486
+ return ;
3540
3487
}
3541
- // - second, emit forward declaration for the template class being
3542
- // specialized
3543
- ClassTemplateDecl *CTD = TSD->getSpecializedTemplate ();
3544
- assert (CTD && " template declaration must be available" );
3488
+ checkAndEmitForwardDecl (TD);
3489
+ }
3490
+
3491
+ void VisitTypeTemplateArgument (const TemplateArgument &TA) {
3492
+ QualType T = TA.getAsType ();
3493
+ Visit (T);
3494
+ }
3545
3495
3546
- if (Printed.insert (CTD).second ) {
3547
- emitFwdDecl (O, CTD, KernelLocation);
3496
+ void VisitIntegralTemplateArgument (const TemplateArgument &TA) {
3497
+ QualType T = TA.getIntegralType ();
3498
+ if (const EnumType *ET = T->getAs <EnumType>())
3499
+ VisitTagType (ET);
3500
+ }
3501
+
3502
+ void VisitTemplateTemplateArgument (const TemplateArgument &TA) {
3503
+ // recursion is not required, since the maximum possible nesting level
3504
+ // equals two for template argument
3505
+ //
3506
+ // for example:
3507
+ // template <typename T> class Bar;
3508
+ // template <template <typename> class> class Baz;
3509
+ // template <template <template <typename> class> class T>
3510
+ // class Foo;
3511
+ //
3512
+ // The Baz is a template class. The Baz<Bar> is a class. The class Foo
3513
+ // should be specialized with template class, not a class. The correct
3514
+ // specialization of template class Foo is Foo<Baz>. The incorrect
3515
+ // specialization of template class Foo is Foo<Baz<Bar>>. In this case
3516
+ // template class Foo specialized by class Baz<Bar>, not a template
3517
+ // class template <template <typename> class> class T as it should.
3518
+ TemplateDecl *TD = TA.getAsTemplate ().getAsTemplateDecl ();
3519
+ TemplateParameterList *TemplateParams = TD->getTemplateParameters ();
3520
+ for (NamedDecl *P : *TemplateParams) {
3521
+ // If template template parameter type has an enum value template
3522
+ // parameter, forward declaration of enum type is required. Only enum
3523
+ // values (not types) need to be handled. For example, consider the
3524
+ // following kernel name type:
3525
+ //
3526
+ // template <typename EnumTypeOut, template <EnumValueIn EnumValue,
3527
+ // typename TypeIn> class T> class Foo;
3528
+ //
3529
+ // The correct specialization for Foo (with enum type) is:
3530
+ // Foo<EnumTypeOut, Baz>, where Baz is a template class.
3531
+ //
3532
+ // Therefore the forward class declarations generated in the
3533
+ // integration header are:
3534
+ // template <EnumValueIn EnumValue, typename TypeIn> class Baz;
3535
+ // template <typename EnumTypeOut, template <EnumValueIn EnumValue,
3536
+ // typename EnumTypeIn> class T> class Foo;
3537
+ //
3538
+ // This requires the following enum forward declarations:
3539
+ // enum class EnumTypeOut : int; (Used to template Foo)
3540
+ // enum class EnumValueIn : int; (Used to template Baz)
3541
+ if (NonTypeTemplateParmDecl *TemplateParam =
3542
+ dyn_cast<NonTypeTemplateParmDecl>(P))
3543
+ if (const EnumType *ET = TemplateParam->getType ()->getAs <EnumType>())
3544
+ VisitTagType (ET);
3548
3545
}
3549
- } else if (Printed.insert (RD).second ) {
3550
- // emit forward declarations for "leaf" classes in the template parameter
3551
- // tree;
3552
- emitFwdDecl (O, RD, KernelLocation);
3546
+ checkAndEmitForwardDecl (TD);
3553
3547
}
3554
- }
3548
+
3549
+ void VisitPackTemplateArgument (const TemplateArgument &TA) {
3550
+ VisitTemplateArgs (TA.getPackAsArray ());
3551
+ }
3552
+ };
3555
3553
3556
3554
class SYCLKernelNameTypePrinter
3557
3555
: public TypeVisitor<SYCLKernelNameTypePrinter>,
@@ -3709,10 +3707,9 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
3709
3707
if (!UnnamedLambdaSupport) {
3710
3708
O << " // Forward declarations of templated kernel function types:\n " ;
3711
3709
3712
- llvm::SmallPtrSet<const void *, 4 > Printed;
3713
- for (const KernelDesc &K : KernelDescs) {
3714
- emitForwardClassDecls (O, K.NameType , K.KernelLocation , Printed);
3715
- }
3710
+ SYCLFwdDeclEmitter FwdDeclEmitter (O, S.getLangOpts ());
3711
+ for (const KernelDesc &K : KernelDescs)
3712
+ FwdDeclEmitter.Visit (K.NameType );
3716
3713
}
3717
3714
O << " \n " ;
3718
3715
0 commit comments