@@ -831,6 +831,9 @@ class KernelObjVisitor {
831
831
else if (ElementTy->isStructureOrClassType ())
832
832
VisitRecord (Owner, ArrayField, ElementTy->getAsCXXRecordDecl (),
833
833
handlers...);
834
+ else if (ElementTy->isUnionType ())
835
+ VisitUnion (Owner, ArrayField, ElementTy->getAsCXXRecordDecl (),
836
+ handlers...);
834
837
else if (ElementTy->isArrayType ())
835
838
VisitArrayElements (ArrayField, ElementTy, handlers...);
836
839
else if (ElementTy->isScalarType ())
@@ -858,6 +861,65 @@ class KernelObjVisitor {
858
861
void VisitRecord (const CXXRecordDecl *Owner, ParentTy &Parent,
859
862
const CXXRecordDecl *Wrapper, Handlers &... handlers);
860
863
864
+ // Base case, only calls these when filtered.
865
+ template <typename ... FilteredHandlers, typename ParentTy>
866
+ std::enable_if_t <(sizeof ...(FilteredHandlers) > 0 )>
867
+ VisitUnion (FilteredHandlers &... handlers, const CXXRecordDecl *Owner,
868
+ ParentTy &Parent, const CXXRecordDecl *Wrapper) {
869
+ (void )std::initializer_list<int >{
870
+ (handlers.enterUnion (Owner, Parent), 0 )...};
871
+ VisitRecordHelper (Wrapper, Wrapper->fields (), handlers...);
872
+ (void )std::initializer_list<int >{
873
+ (handlers.leaveUnion (Owner, Parent), 0 )...};
874
+ }
875
+
876
+ // Handle empty base case.
877
+ template <typename ParentTy>
878
+ void VisitUnion (const CXXRecordDecl *Owner, ParentTy &Parent,
879
+ const CXXRecordDecl *Wrapper) {}
880
+
881
+ template <typename ... FilteredHandlers, typename ParentTy,
882
+ typename CurHandler, typename ... Handlers>
883
+ std::enable_if_t <!CurHandler::VisitUnionBody &&
884
+ (sizeof ...(FilteredHandlers) > 0 )>
885
+ VisitUnion (FilteredHandlers &... filtered_handlers,
886
+ const CXXRecordDecl *Owner, ParentTy &Parent,
887
+ const CXXRecordDecl *Wrapper, CurHandler &cur_handler,
888
+ Handlers &... handlers) {
889
+ VisitUnion<FilteredHandlers...>(filtered_handlers..., Owner, Parent,
890
+ Wrapper, handlers...);
891
+ }
892
+
893
+ template <typename ... FilteredHandlers, typename ParentTy,
894
+ typename CurHandler, typename ... Handlers>
895
+ std::enable_if_t <CurHandler::VisitUnionBody &&
896
+ (sizeof ...(FilteredHandlers) > 0 )>
897
+ VisitUnion (FilteredHandlers &... filtered_handlers,
898
+ const CXXRecordDecl *Owner, ParentTy &Parent,
899
+ const CXXRecordDecl *Wrapper, CurHandler &cur_handler,
900
+ Handlers &... handlers) {
901
+ VisitUnion<FilteredHandlers..., CurHandler>(
902
+ filtered_handlers..., cur_handler, Owner, Parent, Wrapper, handlers...);
903
+ }
904
+
905
+ // Add overloads without having filtered-handlers
906
+ // to handle leading-empty argument packs.
907
+ template <typename ParentTy, typename CurHandler, typename ... Handlers>
908
+ std::enable_if_t <!CurHandler::VisitUnionBody>
909
+ VisitUnion (const CXXRecordDecl *Owner, ParentTy &Parent,
910
+ const CXXRecordDecl *Wrapper, CurHandler &cur_handler,
911
+ Handlers &... handlers) {
912
+ VisitUnion (Owner, Parent, Wrapper, handlers...);
913
+ }
914
+
915
+ template <typename ParentTy, typename CurHandler, typename ... Handlers>
916
+ std::enable_if_t <CurHandler::VisitUnionBody>
917
+ VisitUnion (const CXXRecordDecl *Owner, ParentTy &Parent,
918
+ const CXXRecordDecl *Wrapper, CurHandler &cur_handler,
919
+ Handlers &... handlers) {
920
+ VisitUnion<CurHandler>(cur_handler, Owner, Parent, Wrapper, handlers...);
921
+ }
922
+
861
923
template <typename ... Handlers>
862
924
void VisitRecordHelper (const CXXRecordDecl *Owner,
863
925
clang::CXXRecordDecl::base_class_const_range Range,
@@ -943,6 +1005,11 @@ class KernelObjVisitor {
943
1005
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl ();
944
1006
VisitRecord (Owner, Field, RD, handlers...);
945
1007
}
1008
+ } else if (FieldTy->isUnionType ()) {
1009
+ if (KF_FOR_EACH (handleUnionType, Field, FieldTy)) {
1010
+ CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl ();
1011
+ VisitUnion (Owner, Field, RD, handlers...);
1012
+ }
946
1013
} else if (FieldTy->isReferenceType ())
947
1014
KF_FOR_EACH (handleReferenceType, Field, FieldTy);
948
1015
else if (FieldTy->isPointerType ())
@@ -982,6 +1049,7 @@ class SyclKernelFieldHandler {
982
1049
SyclKernelFieldHandler (Sema &S) : SemaRef(S) {}
983
1050
984
1051
public:
1052
+ static constexpr const bool VisitUnionBody = false ;
985
1053
// Mark these virtual so that we can use override in the implementer classes,
986
1054
// despite virtual dispatch never being used.
987
1055
@@ -1006,6 +1074,7 @@ class SyclKernelFieldHandler {
1006
1074
}
1007
1075
virtual bool handleSyclHalfType (FieldDecl *, QualType) { return true ; }
1008
1076
virtual bool handleStructType (FieldDecl *, QualType) { return true ; }
1077
+ virtual bool handleUnionType (FieldDecl *, QualType) { return true ; }
1009
1078
virtual bool handleReferenceType (FieldDecl *, QualType) { return true ; }
1010
1079
virtual bool handlePointerType (FieldDecl *, QualType) { return true ; }
1011
1080
virtual bool handleArrayType (FieldDecl *, QualType) { return true ; }
@@ -1025,6 +1094,8 @@ class SyclKernelFieldHandler {
1025
1094
virtual bool leaveStruct (const CXXRecordDecl *, const CXXBaseSpecifier &) {
1026
1095
return true ;
1027
1096
}
1097
+ virtual bool enterUnion (const CXXRecordDecl *, FieldDecl *) { return true ; }
1098
+ virtual bool leaveUnion (const CXXRecordDecl *, FieldDecl *) { return true ; }
1028
1099
1029
1100
// The following are used for stepping through array elements.
1030
1101
@@ -1047,7 +1118,6 @@ class SyclKernelFieldHandler {
1047
1118
class SyclKernelFieldChecker : public SyclKernelFieldHandler {
1048
1119
bool IsInvalid = false ;
1049
1120
DiagnosticsEngine &Diag;
1050
-
1051
1121
// Check whether the object should be disallowed from being copied to kernel.
1052
1122
// Return true if not copyable, false if copyable.
1053
1123
bool checkNotCopyableToKernel (const FieldDecl *FD, const QualType &FieldTy) {
@@ -1202,6 +1272,65 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
1202
1272
}
1203
1273
};
1204
1274
1275
+ // A type to check the validity of accessing accessor/sampler/stream
1276
+ // types as kernel parameters inside union.
1277
+ class SyclKernelUnionChecker : public SyclKernelFieldHandler {
1278
+ int UnionCount = 0 ;
1279
+ bool IsInvalid = false ;
1280
+ DiagnosticsEngine &Diag;
1281
+
1282
+ public:
1283
+ SyclKernelUnionChecker (Sema &S)
1284
+ : SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {}
1285
+ bool isValid () { return !IsInvalid; }
1286
+ static constexpr const bool VisitUnionBody = true ;
1287
+
1288
+ bool checkType (SourceLocation Loc, QualType Ty) {
1289
+ if (UnionCount) {
1290
+ IsInvalid = true ;
1291
+ Diag.Report (Loc, diag::err_bad_union_kernel_param_members) << Ty;
1292
+ }
1293
+ return isValid ();
1294
+ }
1295
+
1296
+ bool enterUnion (const CXXRecordDecl *RD, FieldDecl *FD) {
1297
+ ++UnionCount;
1298
+ return true ;
1299
+ }
1300
+
1301
+ bool leaveUnion (const CXXRecordDecl *RD, FieldDecl *FD) {
1302
+ --UnionCount;
1303
+ return true ;
1304
+ }
1305
+
1306
+ bool handleSyclAccessorType (FieldDecl *FD, QualType FieldTy) final {
1307
+ return checkType (FD->getLocation (), FieldTy);
1308
+ }
1309
+
1310
+ bool handleSyclAccessorType (const CXXBaseSpecifier &BS,
1311
+ QualType FieldTy) final {
1312
+ return checkType (BS.getBeginLoc (), FieldTy);
1313
+ }
1314
+
1315
+ bool handleSyclSamplerType (FieldDecl *FD, QualType FieldTy) final {
1316
+ return checkType (FD->getLocation (), FieldTy);
1317
+ }
1318
+
1319
+ bool handleSyclSamplerType (const CXXBaseSpecifier &BS,
1320
+ QualType FieldTy) final {
1321
+ return checkType (BS.getBeginLoc (), FieldTy);
1322
+ }
1323
+
1324
+ bool handleSyclStreamType (FieldDecl *FD, QualType FieldTy) final {
1325
+ return checkType (FD->getLocation (), FieldTy);
1326
+ }
1327
+
1328
+ bool handleSyclStreamType (const CXXBaseSpecifier &BS,
1329
+ QualType FieldTy) final {
1330
+ return checkType (BS.getBeginLoc (), FieldTy);
1331
+ }
1332
+ };
1333
+
1205
1334
// A type to Create and own the FunctionDecl for the kernel.
1206
1335
class SyclKernelDeclCreator : public SyclKernelFieldHandler {
1207
1336
FunctionDecl *KernelDecl;
@@ -1453,6 +1582,10 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
1453
1582
return true ;
1454
1583
}
1455
1584
1585
+ bool handleUnionType (FieldDecl *FD, QualType FieldTy) final {
1586
+ return handleScalarType (FD, FieldTy);
1587
+ }
1588
+
1456
1589
bool handleSyclHalfType (FieldDecl *FD, QualType FieldTy) final {
1457
1590
addParam (FD, FieldTy);
1458
1591
return true ;
@@ -1805,6 +1938,10 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
1805
1938
return true ;
1806
1939
}
1807
1940
1941
+ bool handleUnionType (FieldDecl *FD, QualType FieldTy) final {
1942
+ return handleScalarType (FD, FieldTy);
1943
+ }
1944
+
1808
1945
bool enterStruct (const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final {
1809
1946
CXXCastPath BasePath;
1810
1947
QualType DerivedTy (RD->getTypeForDecl (), 0 );
@@ -2012,6 +2149,10 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
2012
2149
return true ;
2013
2150
}
2014
2151
2152
+ bool handleUnionType (FieldDecl *FD, QualType FieldTy) final {
2153
+ return handleScalarType (FD, FieldTy);
2154
+ }
2155
+
2015
2156
bool handleSyclStreamType (FieldDecl *FD, QualType FieldTy) final {
2016
2157
addParam (FD, FieldTy, SYCLIntegrationHeader::kind_std_layout);
2017
2158
return true ;
@@ -2105,14 +2246,14 @@ void Sema::CheckSYCLKernelCall(FunctionDecl *KernelFunc, SourceRange CallLoc,
2105
2246
}
2106
2247
}
2107
2248
2108
- SyclKernelFieldChecker Checker (*this );
2109
-
2249
+ SyclKernelFieldChecker FieldChecker (*this );
2250
+ SyclKernelUnionChecker UnionChecker (* this );
2110
2251
KernelObjVisitor Visitor{*this };
2111
2252
DiagnosingSYCLKernel = true ;
2112
- Visitor.VisitRecordBases (KernelObj, Checker );
2113
- Visitor.VisitRecordFields (KernelObj, Checker );
2253
+ Visitor.VisitRecordBases (KernelObj, FieldChecker, UnionChecker );
2254
+ Visitor.VisitRecordFields (KernelObj, FieldChecker, UnionChecker );
2114
2255
DiagnosingSYCLKernel = false ;
2115
- if (!Checker .isValid ())
2256
+ if (!FieldChecker. isValid () || !UnionChecker .isValid ())
2116
2257
KernelFunc->setInvalidDecl ();
2117
2258
}
2118
2259
0 commit comments