Skip to content

Commit 6e232d5

Browse files
authored
[NFC][SYCL] Use visitor to emit forward declarations (#2670)
1 parent 43f2d4b commit 6e232d5

File tree

2 files changed

+161
-184
lines changed

2 files changed

+161
-184
lines changed

clang/include/clang/Sema/Sema.h

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -388,26 +388,6 @@ class SYCLIntegrationHeader {
388388
: nullptr;
389389
}
390390

391-
/// Emits a forward declaration for given declaration.
392-
void emitFwdDecl(raw_ostream &O, const Decl *D,
393-
SourceLocation KernelLocation);
394-
395-
/// Emits forward declarations of classes and template classes on which
396-
/// declaration of given type depends. See example in the comments for the
397-
/// implementation.
398-
/// \param O
399-
/// stream to emit to
400-
/// \param T
401-
/// type to emit forward declarations for
402-
/// \param KernelLocation
403-
/// source location of the SYCL kernel function, used to emit nicer
404-
/// diagnostic messages if kernel name is missing
405-
/// \param Emitted
406-
/// a set of declarations forward declrations has been emitted for already
407-
void emitForwardClassDecls(raw_ostream &O, QualType T,
408-
SourceLocation KernelLocation,
409-
llvm::SmallPtrSetImpl<const void *> &Emitted);
410-
411391
private:
412392
/// Keeps invocation descriptors for each kernel invocation started by
413393
/// SYCLIntegrationHeader::startKernel

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 161 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -3344,58 +3344,6 @@ static const char *paramKind2Str(KernelParamKind K) {
33443344
#undef CASE
33453345
}
33463346

3347-
// Emits a forward declaration
3348-
void SYCLIntegrationHeader::emitFwdDecl(raw_ostream &O, const Decl *D,
3349-
SourceLocation KernelLocation) {
3350-
// wrap the declaration into namespaces if needed
3351-
unsigned NamespaceCnt = 0;
3352-
std::string NSStr = "";
3353-
const DeclContext *DC = D->getDeclContext();
3354-
3355-
while (DC) {
3356-
auto *NS = dyn_cast_or_null<NamespaceDecl>(DC);
3357-
3358-
if (!NS) {
3359-
break;
3360-
}
3361-
3362-
++NamespaceCnt;
3363-
const StringRef NSInlinePrefix = NS->isInline() ? "inline " : "";
3364-
NSStr.insert(
3365-
0, Twine(NSInlinePrefix + "namespace " + NS->getName() + " { ").str());
3366-
DC = NS->getDeclContext();
3367-
}
3368-
O << NSStr;
3369-
if (NamespaceCnt > 0)
3370-
O << "\n";
3371-
// print declaration into a string:
3372-
PrintingPolicy P(D->getASTContext().getLangOpts());
3373-
P.adjustForCPlusPlusFwdDecl();
3374-
P.SuppressTypedefs = true;
3375-
P.SuppressUnwrittenScope = true;
3376-
std::string S;
3377-
llvm::raw_string_ostream SO(S);
3378-
D->print(SO, P);
3379-
O << SO.str();
3380-
3381-
if (const auto *ED = dyn_cast<EnumDecl>(D)) {
3382-
QualType T = ED->getIntegerType();
3383-
// Backup since getIntegerType() returns null for enum forward
3384-
// declaration with no fixed underlying type
3385-
if (T.isNull())
3386-
T = ED->getPromotionType();
3387-
O << " : " << T.getAsString();
3388-
}
3389-
3390-
O << ";\n";
3391-
3392-
// print closing braces for namespaces if needed
3393-
for (unsigned I = 0; I < NamespaceCnt; ++I)
3394-
O << "}";
3395-
if (NamespaceCnt > 0)
3396-
O << "\n";
3397-
}
3398-
33993347
// Emits forward declarations of classes and template classes on which
34003348
// declaration of given type depends.
34013349
// For example, consider SimpleVadd
@@ -3432,126 +3380,176 @@ void SYCLIntegrationHeader::emitFwdDecl(raw_ostream &O, const Decl *D,
34323380
// template <typename T> class MyTmplClass;
34333381
// template <typename T1, unsigned int N, typename ...T2> class SimpleVadd;
34343382
//
3435-
void SYCLIntegrationHeader::emitForwardClassDecls(
3436-
raw_ostream &O, QualType T, SourceLocation KernelLocation,
3437-
llvm::SmallPtrSetImpl<const void *> &Printed) {
3383+
class SYCLFwdDeclEmitter
3384+
: public TypeVisitor<SYCLFwdDeclEmitter>,
3385+
public ConstTemplateArgumentVisitor<SYCLFwdDeclEmitter> {
3386+
using InnerTypeVisitor = TypeVisitor<SYCLFwdDeclEmitter>;
3387+
using InnerTemplArgVisitor = ConstTemplateArgumentVisitor<SYCLFwdDeclEmitter>;
3388+
raw_ostream &OS;
3389+
llvm::SmallPtrSet<const NamedDecl *, 4> Printed;
3390+
PrintingPolicy Policy;
34383391

3439-
// peel off the pointer types and get the class/struct type:
3440-
for (; T->isPointerType(); T = T->getPointeeType())
3441-
;
3442-
const CXXRecordDecl *RD = T->getAsCXXRecordDecl();
3392+
void printForwardDecl(NamedDecl *D) {
3393+
// wrap the declaration into namespaces if needed
3394+
unsigned NamespaceCnt = 0;
3395+
std::string NSStr = "";
3396+
const DeclContext *DC = D->getDeclContext();
34433397

3444-
if (!RD) {
3398+
while (DC) {
3399+
const auto *NS = dyn_cast_or_null<NamespaceDecl>(DC);
34453400

3446-
return;
3401+
if (!NS)
3402+
break;
3403+
3404+
++NamespaceCnt;
3405+
const StringRef NSInlinePrefix = NS->isInline() ? "inline " : "";
3406+
NSStr.insert(
3407+
0,
3408+
Twine(NSInlinePrefix + "namespace " + NS->getName() + " { ").str());
3409+
DC = NS->getDeclContext();
3410+
}
3411+
OS << NSStr;
3412+
if (NamespaceCnt > 0)
3413+
OS << "\n";
3414+
3415+
D->print(OS, Policy);
3416+
3417+
if (const auto *ED = dyn_cast<EnumDecl>(D)) {
3418+
QualType T = ED->getIntegerType();
3419+
// Backup since getIntegerType() returns null for enum forward
3420+
// declaration with no fixed underlying type
3421+
if (T.isNull())
3422+
T = ED->getPromotionType();
3423+
OS << " : " << T.getAsString();
3424+
}
3425+
3426+
OS << ";\n";
3427+
3428+
// print closing braces for namespaces if needed
3429+
for (unsigned I = 0; I < NamespaceCnt; ++I)
3430+
OS << "}";
3431+
if (NamespaceCnt > 0)
3432+
OS << "\n";
34473433
}
34483434

3449-
// see if this is a template specialization ...
3450-
if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
3451-
// ... yes, it is template specialization:
3452-
// - first, recurse into template parameters and emit needed forward
3453-
// declarations
3454-
const TemplateArgumentList &Args = TSD->getTemplateArgs();
3435+
// Checks if we've already printed forward declaration and prints it if not.
3436+
void checkAndEmitForwardDecl(NamedDecl *D) {
3437+
if (Printed.insert(D).second)
3438+
printForwardDecl(D);
3439+
}
34553440

3456-
for (unsigned I = 0; I < Args.size(); I++) {
3457-
const TemplateArgument &Arg = Args[I];
3441+
void VisitTemplateArgs(ArrayRef<TemplateArgument> Args) {
3442+
for (size_t I = 0, E = Args.size(); I < E; ++I)
3443+
Visit(Args[I]);
3444+
}
34583445

3459-
switch (Arg.getKind()) {
3460-
case TemplateArgument::ArgKind::Type:
3461-
case TemplateArgument::ArgKind::Integral: {
3462-
QualType T = (Arg.getKind() == TemplateArgument::ArgKind::Type)
3463-
? Arg.getAsType()
3464-
: Arg.getIntegralType();
3465-
3466-
// Handle Kernel Name Type templated using enum type and value.
3467-
if (const auto *ET = T->getAs<EnumType>()) {
3468-
const EnumDecl *ED = ET->getDecl();
3469-
emitFwdDecl(O, ED, KernelLocation);
3470-
} else if (Arg.getKind() == TemplateArgument::ArgKind::Type)
3471-
emitForwardClassDecls(O, T, KernelLocation, Printed);
3472-
break;
3473-
}
3474-
case TemplateArgument::ArgKind::Pack: {
3475-
ArrayRef<TemplateArgument> Pack = Arg.getPackAsArray();
3446+
public:
3447+
SYCLFwdDeclEmitter(raw_ostream &OS, LangOptions LO) : OS(OS), Policy(LO) {
3448+
Policy.adjustForCPlusPlusFwdDecl();
3449+
Policy.SuppressTypedefs = true;
3450+
Policy.SuppressUnwrittenScope = true;
3451+
}
34763452

3477-
for (const auto &T : Pack) {
3478-
if (T.getKind() == TemplateArgument::ArgKind::Type) {
3479-
emitForwardClassDecls(O, T.getAsType(), KernelLocation, Printed);
3480-
}
3481-
}
3482-
break;
3483-
}
3484-
case TemplateArgument::ArgKind::Template: {
3485-
// recursion is not required, since the maximum possible nesting level
3486-
// equals two for template argument
3487-
//
3488-
// for example:
3489-
// template <typename T> class Bar;
3490-
// template <template <typename> class> class Baz;
3491-
// template <template <template <typename> class> class T>
3492-
// class Foo;
3493-
//
3494-
// The Baz is a template class. The Baz<Bar> is a class. The class Foo
3495-
// should be specialized with template class, not a class. The correct
3496-
// specialization of template class Foo is Foo<Baz>. The incorrect
3497-
// specialization of template class Foo is Foo<Baz<Bar>>. In this case
3498-
// template class Foo specialized by class Baz<Bar>, not a template
3499-
// class template <template <typename> class> class T as it should.
3500-
TemplateDecl *TD = Arg.getAsTemplate().getAsTemplateDecl();
3501-
TemplateParameterList *TemplateParams = TD->getTemplateParameters();
3502-
for (NamedDecl *P : *TemplateParams) {
3503-
// If template template paramter type has an enum value template
3504-
// parameter, forward declaration of enum type is required. Only enum
3505-
// values (not types) need to be handled. For example, consider the
3506-
// following kernel name type:
3507-
//
3508-
// template <typename EnumTypeOut, template <EnumValueIn EnumValue,
3509-
// typename TypeIn> class T> class Foo;
3510-
//
3511-
// The correct specialization for Foo (with enum type) is:
3512-
// Foo<EnumTypeOut, Baz>, where Baz is a template class.
3513-
//
3514-
// Therefore the forward class declarations generated in the
3515-
// integration header are:
3516-
// template <EnumValueIn EnumValue, typename TypeIn> class Baz;
3517-
// template <typename EnumTypeOut, template <EnumValueIn EnumValue,
3518-
// typename EnumTypeIn> class T> class Foo;
3519-
//
3520-
// This requires the following enum forward declarations:
3521-
// enum class EnumTypeOut : int; (Used to template Foo)
3522-
// enum class EnumValueIn : int; (Used to template Baz)
3523-
if (NonTypeTemplateParmDecl *TemplateParam =
3524-
dyn_cast<NonTypeTemplateParmDecl>(P)) {
3525-
QualType T = TemplateParam->getType();
3526-
if (const auto *ET = T->getAs<EnumType>()) {
3527-
const EnumDecl *ED = ET->getDecl();
3528-
emitFwdDecl(O, ED, KernelLocation);
3529-
}
3530-
}
3531-
}
3532-
if (Printed.insert(TD).second) {
3533-
emitFwdDecl(O, TD, KernelLocation);
3534-
}
3535-
break;
3536-
}
3537-
default:
3538-
break; // nop
3539-
}
3453+
void Visit(QualType T) {
3454+
if (T.isNull())
3455+
return;
3456+
InnerTypeVisitor::Visit(T.getTypePtr());
3457+
}
3458+
3459+
void Visit(const TemplateArgument &TA) {
3460+
if (TA.isNull())
3461+
return;
3462+
InnerTemplArgVisitor::Visit(TA);
3463+
}
3464+
3465+
void VisitPointerType(const PointerType *T) {
3466+
// Peel off the pointer types.
3467+
QualType PT = T->getPointeeType();
3468+
while (PT->isPointerType())
3469+
PT = PT->getPointeeType();
3470+
Visit(PT);
3471+
}
3472+
3473+
void VisitTagType(const TagType *T) {
3474+
TagDecl *TD = T->getDecl();
3475+
if (const auto *TSD = dyn_cast<ClassTemplateSpecializationDecl>(TD)) {
3476+
// - first, recurse into template parameters and emit needed forward
3477+
// declarations
3478+
ArrayRef<TemplateArgument> Args = TSD->getTemplateArgs().asArray();
3479+
VisitTemplateArgs(Args);
3480+
// - second, emit forward declaration for the template class being
3481+
// specialized
3482+
ClassTemplateDecl *CTD = TSD->getSpecializedTemplate();
3483+
assert(CTD && "template declaration must be available");
3484+
3485+
checkAndEmitForwardDecl(CTD);
3486+
return;
35403487
}
3541-
// - second, emit forward declaration for the template class being
3542-
// specialized
3543-
ClassTemplateDecl *CTD = TSD->getSpecializedTemplate();
3544-
assert(CTD && "template declaration must be available");
3488+
checkAndEmitForwardDecl(TD);
3489+
}
3490+
3491+
void VisitTypeTemplateArgument(const TemplateArgument &TA) {
3492+
QualType T = TA.getAsType();
3493+
Visit(T);
3494+
}
35453495

3546-
if (Printed.insert(CTD).second) {
3547-
emitFwdDecl(O, CTD, KernelLocation);
3496+
void VisitIntegralTemplateArgument(const TemplateArgument &TA) {
3497+
QualType T = TA.getIntegralType();
3498+
if (const EnumType *ET = T->getAs<EnumType>())
3499+
VisitTagType(ET);
3500+
}
3501+
3502+
void VisitTemplateTemplateArgument(const TemplateArgument &TA) {
3503+
// recursion is not required, since the maximum possible nesting level
3504+
// equals two for template argument
3505+
//
3506+
// for example:
3507+
// template <typename T> class Bar;
3508+
// template <template <typename> class> class Baz;
3509+
// template <template <template <typename> class> class T>
3510+
// class Foo;
3511+
//
3512+
// The Baz is a template class. The Baz<Bar> is a class. The class Foo
3513+
// should be specialized with template class, not a class. The correct
3514+
// specialization of template class Foo is Foo<Baz>. The incorrect
3515+
// specialization of template class Foo is Foo<Baz<Bar>>. In this case
3516+
// template class Foo specialized by class Baz<Bar>, not a template
3517+
// class template <template <typename> class> class T as it should.
3518+
TemplateDecl *TD = TA.getAsTemplate().getAsTemplateDecl();
3519+
TemplateParameterList *TemplateParams = TD->getTemplateParameters();
3520+
for (NamedDecl *P : *TemplateParams) {
3521+
// If template template parameter type has an enum value template
3522+
// parameter, forward declaration of enum type is required. Only enum
3523+
// values (not types) need to be handled. For example, consider the
3524+
// following kernel name type:
3525+
//
3526+
// template <typename EnumTypeOut, template <EnumValueIn EnumValue,
3527+
// typename TypeIn> class T> class Foo;
3528+
//
3529+
// The correct specialization for Foo (with enum type) is:
3530+
// Foo<EnumTypeOut, Baz>, where Baz is a template class.
3531+
//
3532+
// Therefore the forward class declarations generated in the
3533+
// integration header are:
3534+
// template <EnumValueIn EnumValue, typename TypeIn> class Baz;
3535+
// template <typename EnumTypeOut, template <EnumValueIn EnumValue,
3536+
// typename EnumTypeIn> class T> class Foo;
3537+
//
3538+
// This requires the following enum forward declarations:
3539+
// enum class EnumTypeOut : int; (Used to template Foo)
3540+
// enum class EnumValueIn : int; (Used to template Baz)
3541+
if (NonTypeTemplateParmDecl *TemplateParam =
3542+
dyn_cast<NonTypeTemplateParmDecl>(P))
3543+
if (const EnumType *ET = TemplateParam->getType()->getAs<EnumType>())
3544+
VisitTagType(ET);
35483545
}
3549-
} else if (Printed.insert(RD).second) {
3550-
// emit forward declarations for "leaf" classes in the template parameter
3551-
// tree;
3552-
emitFwdDecl(O, RD, KernelLocation);
3546+
checkAndEmitForwardDecl(TD);
35533547
}
3554-
}
3548+
3549+
void VisitPackTemplateArgument(const TemplateArgument &TA) {
3550+
VisitTemplateArgs(TA.getPackAsArray());
3551+
}
3552+
};
35553553

35563554
class SYCLKernelNameTypePrinter
35573555
: public TypeVisitor<SYCLKernelNameTypePrinter>,
@@ -3709,10 +3707,9 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
37093707
if (!UnnamedLambdaSupport) {
37103708
O << "// Forward declarations of templated kernel function types:\n";
37113709

3712-
llvm::SmallPtrSet<const void *, 4> Printed;
3713-
for (const KernelDesc &K : KernelDescs) {
3714-
emitForwardClassDecls(O, K.NameType, K.KernelLocation, Printed);
3715-
}
3710+
SYCLFwdDeclEmitter FwdDeclEmitter(O, S.getLangOpts());
3711+
for (const KernelDesc &K : KernelDescs)
3712+
FwdDeclEmitter.Visit(K.NameType);
37163713
}
37173714
O << "\n";
37183715

0 commit comments

Comments
 (0)