@@ -830,6 +830,13 @@ class KernelObjVisitor {
830
830
else if (ElementTy->isStructureOrClassType ())
831
831
VisitRecord (Owner, ArrayField, ElementTy->getAsCXXRecordDecl (),
832
832
handlers...);
833
+ else if (ElementTy->isUnionType ())
834
+ // TODO: This check is still necessary I think?! Array seems to handle
835
+ // this differently (see above) for structs I think.
836
+ // if (KF_FOR_EACH(handleUnionType, Field, FieldTy)) {
837
+ VisitUnion (Owner, ArrayField, ElementTy->getAsCXXRecordDecl (),
838
+ handlers...);
839
+ // }
833
840
else if (ElementTy->isArrayType ())
834
841
VisitArrayElements (ArrayField, ElementTy, handlers...);
835
842
else if (ElementTy->isScalarType ())
@@ -857,6 +864,41 @@ class KernelObjVisitor {
857
864
void VisitRecord (const CXXRecordDecl *Owner, ParentTy &Parent,
858
865
const CXXRecordDecl *Wrapper, Handlers &... handlers);
859
866
867
+ // Base case, only calls these when filtered.
868
+ template <typename ... FilteredHandlers, typename ParentTy>
869
+ void VisitUnion (const CXXRecordDecl *Owner, ParentTy &Parent,
870
+ const CXXRecordDecl *Wrapper,
871
+ FilteredHandlers &... handlers) {
872
+ (void )std::initializer_list<int >{
873
+ (handlers.enterUnion (Owner, Parent), 0 )...};
874
+ VisitRecordHelper (Wrapper, Wrapper->fields (), handlers...);
875
+ (void )std::initializer_list<int >{
876
+ (handlers.leaveUnion (Owner, Parent), 0 )...};
877
+ }
878
+
879
+
880
+ template <typename ... FilteredHandlers, typename ParentTy,
881
+ typename CurHandler, typename ... Handlers>
882
+ std::enable_if_t <!CurHandler::VisitUnionBody>
883
+ VisitUnion (const CXXRecordDecl *Owner, ParentTy &Parent,
884
+ const CXXRecordDecl *Wrapper,
885
+ FilteredHandlers &... filtered_handlers,
886
+ CurHandler &cur_handler, Handlers &... handlers) {
887
+ VisitUnion<FilteredHandlers...>(
888
+ Owner, Parent, Wrapper, filtered_handlers..., handlers...);
889
+ }
890
+
891
+ template <typename ... FilteredHandlers, typename ParentTy,
892
+ typename CurHandler, typename ... Handlers>
893
+ std::enable_if_t <CurHandler::VisitUnionBody>
894
+ VisitUnion (const CXXRecordDecl *Owner, ParentTy &Parent,
895
+ const CXXRecordDecl *Wrapper,
896
+ FilteredHandlers &... filtered_handlers,
897
+ CurHandler &cur_handler, Handlers &... handlers) {
898
+ VisitUnion<FilteredHandlers..., CurHandler>(
899
+ Owner, Parent, Wrapper, filtered_handlers..., cur_handler, handlers...);
900
+ }
901
+
860
902
template <typename ... Handlers>
861
903
void VisitRecordHelper (const CXXRecordDecl *Owner,
862
904
clang::CXXRecordDecl::base_class_const_range Range,
@@ -942,6 +984,11 @@ class KernelObjVisitor {
942
984
CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl ();
943
985
VisitRecord (Owner, Field, RD, handlers...);
944
986
}
987
+ } else if (FieldTy->isUnionType ()) {
988
+ if (KF_FOR_EACH (handleUnionType, Field, FieldTy)) {
989
+ CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl ();
990
+ VisitUnion (Owner, Field, RD, handlers...);
991
+ }
945
992
} else if (FieldTy->isReferenceType ())
946
993
KF_FOR_EACH (handleReferenceType, Field, FieldTy);
947
994
else if (FieldTy->isPointerType ())
@@ -1005,6 +1052,7 @@ class SyclKernelFieldHandler {
1005
1052
}
1006
1053
virtual bool handleSyclHalfType (FieldDecl *, QualType) { return true ; }
1007
1054
virtual bool handleStructType (FieldDecl *, QualType) { return true ; }
1055
+ virtual bool handleUnionType (FieldDecl *, QualType) { return true ; }
1008
1056
virtual bool handleReferenceType (FieldDecl *, QualType) { return true ; }
1009
1057
virtual bool handlePointerType (FieldDecl *, QualType) { return true ; }
1010
1058
virtual bool handleArrayType (FieldDecl *, QualType) { return true ; }
@@ -1024,6 +1072,8 @@ class SyclKernelFieldHandler {
1024
1072
virtual bool leaveStruct (const CXXRecordDecl *, const CXXBaseSpecifier &) {
1025
1073
return true ;
1026
1074
}
1075
+ virtual bool enterUnion (const CXXRecordDecl *, FieldDecl *) { return true ; }
1076
+ virtual bool leaveUnion (const CXXRecordDecl *, FieldDecl *) { return true ; }
1027
1077
1028
1078
// The following are used for stepping through array elements.
1029
1079
@@ -1201,6 +1251,65 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
1201
1251
}
1202
1252
};
1203
1253
1254
+ // A type to check the validity of passing union with accessor/sampler/stream
1255
+ // member as a kernel argument types.
1256
+ class SyclKernelUnionBodyChecker : public SyclKernelFieldHandler {
1257
+ static constexpr const bool VisitUnionBody = true ;
1258
+ int UnionCount = 0 ;
1259
+ bool IsInvalid = false ;
1260
+ DiagnosticsEngine &Diag;
1261
+
1262
+ public:
1263
+ SyclKernelUnionBodyChecker (Sema &S)
1264
+ : SyclKernelFieldHandler(S), Diag(S.getASTContext().getDiagnostics()) {}
1265
+ bool isValid () { return !IsInvalid; }
1266
+
1267
+ bool enterUnion (const CXXRecordDecl *RD, FieldDecl *FD) {
1268
+ ++UnionCount;
1269
+ return true ;
1270
+ }
1271
+
1272
+ bool leaveUnion (const CXXRecordDecl *RD, FieldDecl *FD) {
1273
+ --UnionCount;
1274
+ return true ;
1275
+ }
1276
+
1277
+ bool handlePointerType (FieldDecl *FD, QualType FieldTy) final {
1278
+ if (UnionCount) {
1279
+ IsInvalid = true ;
1280
+ Diag.Report (FD->getLocation (), diag::err_bad_kernel_param_type)
1281
+ << FieldTy;
1282
+ }
1283
+ return isValid ();
1284
+ }
1285
+
1286
+ bool handleSyclAccessorType (FieldDecl *FD, QualType FieldTy) final {
1287
+ if (UnionCount) {
1288
+ IsInvalid = true ;
1289
+ Diag.Report (FD->getLocation (), diag::err_bad_kernel_param_type)
1290
+ << FieldTy;
1291
+ }
1292
+ return isValid ();
1293
+ }
1294
+
1295
+ bool handleSyclSamplerType (FieldDecl *FD, QualType FieldTy) final {
1296
+ if (UnionCount) {
1297
+ IsInvalid = true ;
1298
+ Diag.Report (FD->getLocation (), diag::err_bad_kernel_param_type)
1299
+ << FieldTy;
1300
+ }
1301
+ return isValid ();
1302
+ }
1303
+ bool handleSyclStreamType (FieldDecl *FD, QualType FieldTy) final {
1304
+ if (UnionCount) {
1305
+ IsInvalid = true ;
1306
+ Diag.Report (FD->getLocation (), diag::err_bad_kernel_param_type)
1307
+ << FieldTy;
1308
+ }
1309
+ return isValid ();
1310
+ }
1311
+ };
1312
+
1204
1313
// A type to Create and own the FunctionDecl for the kernel.
1205
1314
class SyclKernelDeclCreator : public SyclKernelFieldHandler {
1206
1315
FunctionDecl *KernelDecl;
@@ -1416,6 +1525,10 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
1416
1525
return true ;
1417
1526
}
1418
1527
1528
+ bool handleUnionType (FieldDecl *FD, QualType FieldTy) final {
1529
+ return handleScalarType (FD, FieldTy);
1530
+ }
1531
+
1419
1532
bool handleSyclHalfType (FieldDecl *FD, QualType FieldTy) final {
1420
1533
addParam (FD, FieldTy);
1421
1534
return true ;
@@ -1751,6 +1864,10 @@ class SyclKernelBodyCreator : public SyclKernelFieldHandler {
1751
1864
return true ;
1752
1865
}
1753
1866
1867
+ bool handleUnionType (FieldDecl *FD, QualType FieldTy) final {
1868
+ return handleScalarType (FD, FieldTy);
1869
+ }
1870
+
1754
1871
bool enterStruct (const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final {
1755
1872
CXXCastPath BasePath;
1756
1873
QualType DerivedTy (RD->getTypeForDecl (), 0 );
@@ -1955,6 +2072,10 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
1955
2072
return true ;
1956
2073
}
1957
2074
2075
+ bool handleUnionType (FieldDecl *FD, QualType FieldTy) final {
2076
+ return handleScalarType (FD, FieldTy);
2077
+ }
2078
+
1958
2079
bool handleSyclStreamType (FieldDecl *FD, QualType FieldTy) final {
1959
2080
addParam (FD, FieldTy, SYCLIntegrationHeader::kind_std_layout);
1960
2081
return true ;
0 commit comments